Include ollama in API

This commit is contained in:
Nexus 2024-06-06 00:15:01 +01:00
parent 6fb9bb839b
commit 2625657699
Signed by: nex
GPG key ID: 0FA334385D0B689F

View file

@ -6,7 +6,7 @@ import os
import secrets import secrets
import typing import typing
import time import time
from fastapi import FastAPI, Depends, HTTPException from fastapi import FastAPI, Depends, HTTPException, APIRouter
from fastapi.responses import JSONResponse, Response from fastapi.responses import JSONResponse, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security import HTTPBasic, HTTPBasicCredentials
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -40,25 +40,29 @@ def check_credentials(credentials: HTTPBasicCredentials = Depends(security)):
return credentials return credentials
def get_db() -> redis.Redis: def get_db_factory(n: int = 11) -> typing.Callable[[], typing.Generator[redis.Redis, None, None]]:
def inner():
uri = os.getenv("REDIS_URL", "redis://redis") uri = os.getenv("REDIS_URL", "redis://redis")
conn = redis.Redis.from_url(uri) conn = redis.Redis.from_url(uri)
conn.select(n)
try: try:
yield conn yield conn
finally: finally:
conn.close() conn.close()
return inner
app = FastAPI( app = FastAPI(
title="Jimmy v3 API", title="Jimmy v3 API",
version="3.0.0", version="3.0.0",
dependencies=[Depends(check_credentials)], dependencies=[Depends(check_credentials)],
root_path=os.getenv("WEB_ROOT_PATH") root_path=os.getenv("WEB_ROOT_PATH", "") + "/api"
) )
truth_router = APIRouter(prefix="/truths")
@app.get("/api/truths/all") @truth_router.get("/")
def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db)): def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db_factory())):
"""Retrieves all stored truths""" """Retrieves all stored truths"""
keys = db.keys() keys = db.keys()
if rich is False: if rich is False:
@ -67,17 +71,17 @@ def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db)):
return truths return truths
@app.get("/api/truths/{truth_id}") @truth_router.get("/{truth_id}")
def get_truth(truth_id: str, db: redis.Redis = Depends(get_db)): def get_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
"""Retrieves a stored truth""" """Retrieves a stored truth"""
data = db.get(truth_id) data: str = db.get(truth_id)
if not data: if not data:
raise HTTPException(404, detail="%r not found." % id) raise HTTPException(404, detail="%r not found." % id)
return json.loads(data) return json.loads(data)
@app.head("/api/truths/{truth_id}") @truth_router.head("/{truth_id}")
def head_truth(truth_id: str, db: redis.Redis = Depends(get_db)): def head_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
"""Checks that a truth exists""" """Checks that a truth exists"""
data = db.get(truth_id) data = db.get(truth_id)
if not data: if not data:
@ -85,11 +89,11 @@ def head_truth(truth_id: str, db: redis.Redis = Depends(get_db)):
return Response() return Response()
@app.post("/api/truths", status_code=201) @truth_router.post("/", status_code=201)
def post_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = Depends(get_db)): def post_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = Depends(get_db_factory())):
"""Stores a new truth""" """Stores a new truth"""
data = payload.model_dump() data = payload.model_dump()
existing = db.get(data["id"]) existing: str = db.get(data["id"])
if existing: if existing:
parsed = json.loads(existing) parsed = json.loads(existing)
if parsed == existing: if parsed == existing:
@ -100,8 +104,8 @@ def post_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis =
return data return data
@app.put("/api/truths/{truth_id}") @truth_router.put("/{truth_id}")
def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db)): def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db_factory())):
"""Replaces a stored truth""" """Replaces a stored truth"""
data = payload.model_dump() data = payload.model_dump()
existing = db.get(truth_id) existing = db.get(truth_id)
@ -111,16 +115,51 @@ def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(ge
return data return data
@app.delete("/api/truths/{truth_id}", status_code=204) @truth_router.delete("/{truth_id}", status_code=204)
def delete_truth(truth_id: str, db: redis.Redis = Depends(get_db)): def delete_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
"""Deletes a stored truth""" """Deletes a stored truth"""
if not db.delete(truth_id): if not db.delete(truth_id):
raise HTTPException(404, detail="%r not found." % truth_id) raise HTTPException(404, detail="%r not found." % truth_id)
return Response(status_code=204) return Response(status_code=204)
@app.get("/api/health") app.include_router(truth_router)
def health(db: redis.Redis = Depends(get_db)): ollama_router = APIRouter(prefix="/ollama")
@app.get("/threads")
def get_ollama_threads(db: redis.Redis = Depends(get_db_factory(0))):
"""
Retrieves all stored threads
This only returns thread keys as returning entire threads would be too much data.
"""
keys = db.keys()
return keys
@app.get("/thread/{thread_id}")
def get_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factory(0))):
"""Retrieves a stored thread"""
data: str = db.get(thread_id)
if not data:
raise HTTPException(404, detail="%r not found." % thread_id)
return json.loads(data)
@app.delete("/thread/{thread_id}", status_code=204)
def delete_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factory(0))):
"""Deletes a stored thread"""
if not db.delete(thread_id):
raise HTTPException(404, detail="%r not found." % thread_id)
return Response(status_code=204)
app.include_router(ollama_router)
@app.get("/health")
def health(db: redis.Redis = Depends(get_db_factory())):
try: try:
db.ping() db.ping()
except ConnectionError: except ConnectionError: