Include ollama in API
This commit is contained in:
parent
6fb9bb839b
commit
2625657699
1 changed files with 64 additions and 25 deletions
|
@ -6,7 +6,7 @@ import os
|
|||
import secrets
|
||||
import typing
|
||||
import time
|
||||
from fastapi import FastAPI, Depends, HTTPException
|
||||
from fastapi import FastAPI, Depends, HTTPException, APIRouter
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -40,25 +40,29 @@ def check_credentials(credentials: HTTPBasicCredentials = Depends(security)):
|
|||
return credentials
|
||||
|
||||
|
||||
def get_db() -> redis.Redis:
|
||||
def get_db_factory(n: int = 11) -> typing.Callable[[], typing.Generator[redis.Redis, None, None]]:
|
||||
def inner():
|
||||
uri = os.getenv("REDIS_URL", "redis://redis")
|
||||
conn = redis.Redis.from_url(uri)
|
||||
conn.select(n)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
return inner
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Jimmy v3 API",
|
||||
version="3.0.0",
|
||||
dependencies=[Depends(check_credentials)],
|
||||
root_path=os.getenv("WEB_ROOT_PATH")
|
||||
root_path=os.getenv("WEB_ROOT_PATH", "") + "/api"
|
||||
)
|
||||
truth_router = APIRouter(prefix="/truths")
|
||||
|
||||
|
||||
@app.get("/api/truths/all")
|
||||
def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db)):
|
||||
@truth_router.get("/")
|
||||
def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db_factory())):
|
||||
"""Retrieves all stored truths"""
|
||||
keys = db.keys()
|
||||
if rich is False:
|
||||
|
@ -67,17 +71,17 @@ def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db)):
|
|||
return truths
|
||||
|
||||
|
||||
@app.get("/api/truths/{truth_id}")
|
||||
def get_truth(truth_id: str, db: redis.Redis = Depends(get_db)):
|
||||
@truth_router.get("/{truth_id}")
|
||||
def get_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
|
||||
"""Retrieves a stored truth"""
|
||||
data = db.get(truth_id)
|
||||
data: str = db.get(truth_id)
|
||||
if not data:
|
||||
raise HTTPException(404, detail="%r not found." % id)
|
||||
return json.loads(data)
|
||||
|
||||
|
||||
@app.head("/api/truths/{truth_id}")
|
||||
def head_truth(truth_id: str, db: redis.Redis = Depends(get_db)):
|
||||
@truth_router.head("/{truth_id}")
|
||||
def head_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
|
||||
"""Checks that a truth exists"""
|
||||
data = db.get(truth_id)
|
||||
if not data:
|
||||
|
@ -85,11 +89,11 @@ def head_truth(truth_id: str, db: redis.Redis = Depends(get_db)):
|
|||
return Response()
|
||||
|
||||
|
||||
@app.post("/api/truths", status_code=201)
|
||||
def post_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = Depends(get_db)):
|
||||
@truth_router.post("/", status_code=201)
|
||||
def post_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = Depends(get_db_factory())):
|
||||
"""Stores a new truth"""
|
||||
data = payload.model_dump()
|
||||
existing = db.get(data["id"])
|
||||
existing: str = db.get(data["id"])
|
||||
if existing:
|
||||
parsed = json.loads(existing)
|
||||
if parsed == existing:
|
||||
|
@ -100,8 +104,8 @@ def post_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis =
|
|||
return data
|
||||
|
||||
|
||||
@app.put("/api/truths/{truth_id}")
|
||||
def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db)):
|
||||
@truth_router.put("/{truth_id}")
|
||||
def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db_factory())):
|
||||
"""Replaces a stored truth"""
|
||||
data = payload.model_dump()
|
||||
existing = db.get(truth_id)
|
||||
|
@ -111,16 +115,51 @@ def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(ge
|
|||
return data
|
||||
|
||||
|
||||
@app.delete("/api/truths/{truth_id}", status_code=204)
|
||||
def delete_truth(truth_id: str, db: redis.Redis = Depends(get_db)):
|
||||
@truth_router.delete("/{truth_id}", status_code=204)
|
||||
def delete_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
|
||||
"""Deletes a stored truth"""
|
||||
if not db.delete(truth_id):
|
||||
raise HTTPException(404, detail="%r not found." % truth_id)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@app.get("/api/health")
|
||||
def health(db: redis.Redis = Depends(get_db)):
|
||||
app.include_router(truth_router)
|
||||
ollama_router = APIRouter(prefix="/ollama")
|
||||
|
||||
|
||||
@app.get("/threads")
|
||||
def get_ollama_threads(db: redis.Redis = Depends(get_db_factory(0))):
|
||||
"""
|
||||
Retrieves all stored threads
|
||||
|
||||
This only returns thread keys as returning entire threads would be too much data.
|
||||
"""
|
||||
keys = db.keys()
|
||||
return keys
|
||||
|
||||
|
||||
@app.get("/thread/{thread_id}")
|
||||
def get_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factory(0))):
|
||||
"""Retrieves a stored thread"""
|
||||
data: str = db.get(thread_id)
|
||||
if not data:
|
||||
raise HTTPException(404, detail="%r not found." % thread_id)
|
||||
return json.loads(data)
|
||||
|
||||
|
||||
@app.delete("/thread/{thread_id}", status_code=204)
|
||||
def delete_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factory(0))):
|
||||
"""Deletes a stored thread"""
|
||||
if not db.delete(thread_id):
|
||||
raise HTTPException(404, detail="%r not found." % thread_id)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
app.include_router(ollama_router)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health(db: redis.Redis = Depends(get_db_factory())):
|
||||
try:
|
||||
db.ping()
|
||||
except ConnectionError:
|
||||
|
|
Loading…
Reference in a new issue