Include ollama in API

This commit is contained in:
Nexus 2024-06-06 00:15:01 +01:00
parent 6fb9bb839b
commit 2625657699
Signed by: nex
GPG key ID: 0FA334385D0B689F

View file

@ -6,7 +6,7 @@ import os
import secrets
import typing
import time
from fastapi import FastAPI, Depends, HTTPException
from fastapi import FastAPI, Depends, HTTPException, APIRouter
from fastapi.responses import JSONResponse, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from pydantic import BaseModel, Field
@ -40,25 +40,29 @@ def check_credentials(credentials: HTTPBasicCredentials = Depends(security)):
return credentials
def get_db() -> redis.Redis:
uri = os.getenv("REDIS_URL", "redis://redis")
conn = redis.Redis.from_url(uri)
try:
yield conn
finally:
conn.close()
def get_db_factory(n: int = 11) -> typing.Callable[[], typing.Generator[redis.Redis, None, None]]:
def inner():
uri = os.getenv("REDIS_URL", "redis://redis")
conn = redis.Redis.from_url(uri)
conn.select(n)
try:
yield conn
finally:
conn.close()
return inner
app = FastAPI(
title="Jimmy v3 API",
version="3.0.0",
dependencies=[Depends(check_credentials)],
root_path=os.getenv("WEB_ROOT_PATH")
root_path=os.getenv("WEB_ROOT_PATH", "") + "/api"
)
truth_router = APIRouter(prefix="/truths")
@app.get("/api/truths/all")
def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db)):
@truth_router.get("/")
def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db_factory())):
"""Retrieves all stored truths"""
keys = db.keys()
if rich is False:
@ -67,17 +71,17 @@ def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db)):
return truths
@app.get("/api/truths/{truth_id}")
def get_truth(truth_id: str, db: redis.Redis = Depends(get_db)):
@truth_router.get("/{truth_id}")
def get_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
"""Retrieves a stored truth"""
data = db.get(truth_id)
data: str = db.get(truth_id)
if not data:
raise HTTPException(404, detail="%r not found." % id)
return json.loads(data)
@app.head("/api/truths/{truth_id}")
def head_truth(truth_id: str, db: redis.Redis = Depends(get_db)):
@truth_router.head("/{truth_id}")
def head_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
"""Checks that a truth exists"""
data = db.get(truth_id)
if not data:
@ -85,11 +89,11 @@ def head_truth(truth_id: str, db: redis.Redis = Depends(get_db)):
return Response()
@app.post("/api/truths", status_code=201)
def post_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = Depends(get_db)):
@truth_router.post("/", status_code=201)
def post_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = Depends(get_db_factory())):
"""Stores a new truth"""
data = payload.model_dump()
existing = db.get(data["id"])
existing: str = db.get(data["id"])
if existing:
parsed = json.loads(existing)
if parsed == existing:
@ -100,8 +104,8 @@ def post_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis =
return data
@app.put("/api/truths/{truth_id}")
def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db)):
@truth_router.put("/{truth_id}")
def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db_factory())):
"""Replaces a stored truth"""
data = payload.model_dump()
existing = db.get(truth_id)
@ -111,16 +115,51 @@ def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(ge
return data
@app.delete("/api/truths/{truth_id}", status_code=204)
def delete_truth(truth_id: str, db: redis.Redis = Depends(get_db)):
@truth_router.delete("/{truth_id}", status_code=204)
def delete_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
"""Deletes a stored truth"""
if not db.delete(truth_id):
raise HTTPException(404, detail="%r not found." % truth_id)
return Response(status_code=204)
@app.get("/api/health")
def health(db: redis.Redis = Depends(get_db)):
app.include_router(truth_router)
ollama_router = APIRouter(prefix="/ollama")
@app.get("/threads")
def get_ollama_threads(db: redis.Redis = Depends(get_db_factory(0))):
"""
Retrieves all stored threads
This only returns thread keys as returning entire threads would be too much data.
"""
keys = db.keys()
return keys
@app.get("/thread/{thread_id}")
def get_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factory(0))):
"""Retrieves a stored thread"""
data: str = db.get(thread_id)
if not data:
raise HTTPException(404, detail="%r not found." % thread_id)
return json.loads(data)
@app.delete("/thread/{thread_id}", status_code=204)
def delete_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factory(0))):
"""Deletes a stored thread"""
if not db.delete(thread_id):
raise HTTPException(404, detail="%r not found." % thread_id)
return Response(status_code=204)
app.include_router(ollama_router)
@app.get("/health")
def health(db: redis.Redis = Depends(get_db_factory())):
try:
db.ping()
except ConnectionError: