From 26256576990388f5398a1258476cceca1b0f0adc Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Thu, 6 Jun 2024 00:15:01 +0100 Subject: [PATCH] Include ollama in API --- src/server.py | 89 ++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 64 insertions(+), 25 deletions(-) diff --git a/src/server.py b/src/server.py index c46499e..fd89cde 100644 --- a/src/server.py +++ b/src/server.py @@ -6,7 +6,7 @@ import os import secrets import typing import time -from fastapi import FastAPI, Depends, HTTPException +from fastapi import FastAPI, Depends, HTTPException, APIRouter from fastapi.responses import JSONResponse, Response from fastapi.security import HTTPBasic, HTTPBasicCredentials from pydantic import BaseModel, Field @@ -40,25 +40,29 @@ def check_credentials(credentials: HTTPBasicCredentials = Depends(security)): return credentials -def get_db() -> redis.Redis: - uri = os.getenv("REDIS_URL", "redis://redis") - conn = redis.Redis.from_url(uri) - try: - yield conn - finally: - conn.close() +def get_db_factory(n: int = 11) -> typing.Callable[[], typing.Generator[redis.Redis, None, None]]: + def inner(): + uri = os.getenv("REDIS_URL", "redis://redis") + conn = redis.Redis.from_url(uri) + conn.select(n) + try: + yield conn + finally: + conn.close() + return inner app = FastAPI( title="Jimmy v3 API", version="3.0.0", dependencies=[Depends(check_credentials)], - root_path=os.getenv("WEB_ROOT_PATH") + root_path=os.getenv("WEB_ROOT_PATH", "") + "/api" ) +truth_router = APIRouter(prefix="/truths") -@app.get("/api/truths/all") -def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db)): +@truth_router.get("/") +def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db_factory())): """Retrieves all stored truths""" keys = db.keys() if rich is False: @@ -67,17 +71,17 @@ def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db)): return truths -@app.get("/api/truths/{truth_id}") -def get_truth(truth_id: str, db: redis.Redis = Depends(get_db)): +@truth_router.get("/{truth_id}") +def get_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())): """Retrieves a stored truth""" - data = db.get(truth_id) + data: str = db.get(truth_id) if not data: raise HTTPException(404, detail="%r not found." % id) return json.loads(data) -@app.head("/api/truths/{truth_id}") -def head_truth(truth_id: str, db: redis.Redis = Depends(get_db)): +@truth_router.head("/{truth_id}") +def head_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())): """Checks that a truth exists""" data = db.get(truth_id) if not data: @@ -85,11 +89,11 @@ def head_truth(truth_id: str, db: redis.Redis = Depends(get_db)): return Response() -@app.post("/api/truths", status_code=201) -def post_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = Depends(get_db)): +@truth_router.post("/", status_code=201) +def post_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = Depends(get_db_factory())): """Stores a new truth""" data = payload.model_dump() - existing = db.get(data["id"]) + existing: str = db.get(data["id"]) if existing: parsed = json.loads(existing) if parsed == existing: @@ -100,8 +104,8 @@ def post_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = return data -@app.put("/api/truths/{truth_id}") -def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db)): +@truth_router.put("/{truth_id}") +def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db_factory())): """Replaces a stored truth""" data = payload.model_dump() existing = db.get(truth_id) @@ -111,16 +115,51 @@ def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(ge return data -@app.delete("/api/truths/{truth_id}", status_code=204) -def delete_truth(truth_id: str, db: redis.Redis = Depends(get_db)): +@truth_router.delete("/{truth_id}", status_code=204) +def delete_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())): """Deletes a stored truth""" if not db.delete(truth_id): raise HTTPException(404, detail="%r not found." % truth_id) return Response(status_code=204) -@app.get("/api/health") -def health(db: redis.Redis = Depends(get_db)): +app.include_router(truth_router) +ollama_router = APIRouter(prefix="/ollama") + + +@app.get("/threads") +def get_ollama_threads(db: redis.Redis = Depends(get_db_factory(0))): + """ + Retrieves all stored threads + + This only returns thread keys as returning entire threads would be too much data. + """ + keys = db.keys() + return keys + + +@app.get("/thread/{thread_id}") +def get_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factory(0))): + """Retrieves a stored thread""" + data: str = db.get(thread_id) + if not data: + raise HTTPException(404, detail="%r not found." % thread_id) + return json.loads(data) + + +@app.delete("/thread/{thread_id}", status_code=204) +def delete_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factory(0))): + """Deletes a stored thread""" + if not db.delete(thread_id): + raise HTTPException(404, detail="%r not found." % thread_id) + return Response(status_code=204) + + +app.include_router(ollama_router) + + +@app.get("/health") +def health(db: redis.Redis = Depends(get_db_factory())): try: db.ping() except ConnectionError: