From f3c8a7bed787678ccaa8fe71c1b1cb6a46c394a8 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sun, 9 Jun 2024 01:11:52 +0100 Subject: [PATCH] Better document entire API --- src/server.py | 46 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/src/server.py b/src/server.py index 568bef2..da0f36e 100644 --- a/src/server.py +++ b/src/server.py @@ -19,11 +19,34 @@ JSON: typing.Union[ class TruthPayload(BaseModel): + """Represents a truth. This can be used to both create and get truths.""" id: str content: str author: typing.Literal["trump", "tate"] = Field(pattern=r"^(trump|tate)$") timestamp: float = Field(default_factory=time.time, ge=0) extra: typing.Optional[JSON] = None + """Any extra information, JSON compliant, can be entered here.""" + + +class OllamaThread(BaseModel): + """Represents an Ollama thread.""" + + class ThreadMessage(BaseModel): + """Represents a message in an Ollama thread.""" + role: typing.Literal["assistant", "system", "user"] + content: str + images: typing.Optional[list[str]] = [] + """An array of base64 images""" + + member: int + """The author's discord user ID""" + seed: int + """The seed used to generate the thread""" + messages: list[ThreadMessage] + + +class HealthResponse(BaseModel): + status: typing.Literal["ok"] security = HTTPBasic(realm="Jimmy") @@ -62,7 +85,7 @@ def get_db_factory(n: int = 11) -> typing.Callable[[], typing.Generator[redis.Re app = FastAPI( title="Jimmy v3 API", - version="3.0.0", + version="3.1.0", root_path=os.getenv("WEB_ROOT_PATH", "") + "/api" ) truth_router = APIRouter( @@ -72,7 +95,7 @@ truth_router = APIRouter( ) -@truth_router.get("") +@truth_router.get("", response_model=list[TruthPayload]) def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db_factory())): """Retrieves all stored truths""" keys = db.keys() @@ -82,17 +105,18 @@ def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db_factory() return truths -@truth_router.get("/all", deprecated=True) -def get_all_truths_deprecated(rich: bool = True, db: redis.Redis = Depends(get_db_factory())): - """Retrieves all stored truths""" +@truth_router.get("/all", deprecated=True, response_model=list[TruthPayload]) +def get_all_truths_deprecated(response: JSONResponse, rich: bool = True, db: redis.Redis = Depends(get_db_factory())): + """DEPRECATED - USE get_all_truths INSTEAD""" keys = db.keys() if rich is False: return keys truths = [json.loads(db.get(key)) for key in keys] + response.headers["X-Deprecated"] = "true" return truths -@truth_router.get("/{truth_id}") +@truth_router.get("/{truth_id}", response_model=TruthPayload) def get_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())): """Retrieves a stored truth""" data: str = db.get(truth_id) @@ -110,7 +134,7 @@ def head_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())): return Response() -@truth_router.post("", status_code=201) +@truth_router.post("", status_code=201, response_model=TruthPayload) def new_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = Depends(get_db_factory())): """Creates a new truth""" data = payload.model_dump() @@ -125,7 +149,7 @@ def new_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = D return data -@truth_router.put("/{truth_id}") +@truth_router.put("/{truth_id}", response_model=TruthPayload) def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db_factory())): """Replaces a stored truth""" data = payload.model_dump() @@ -152,7 +176,7 @@ ollama_router = APIRouter( ) -@ollama_router.get("/threads") +@ollama_router.get("/threads", response_model=list[str]) def get_ollama_threads(db: redis.Redis = Depends(get_db_factory(0))): """ Retrieves all stored threads @@ -163,7 +187,7 @@ def get_ollama_threads(db: redis.Redis = Depends(get_db_factory(0))): return keys -@ollama_router.get("/thread/{thread_id}") +@ollama_router.get("/thread/{thread_id}", response_model=OllamaThread) def get_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factory(0))): """Retrieves a stored thread""" data: str = db.get(thread_id) @@ -183,7 +207,7 @@ def delete_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factor app.include_router(ollama_router) -@app.get("/health") +@app.get("/health", response_model=HealthResponse) def health(db: redis.Redis = Depends(get_db_factory())): try: db.ping()