Include ollama in API
This commit is contained in:
parent
6fb9bb839b
commit
2625657699
1 changed files with 64 additions and 25 deletions
|
@ -6,7 +6,7 @@ import os
|
||||||
import secrets
|
import secrets
|
||||||
import typing
|
import typing
|
||||||
import time
|
import time
|
||||||
from fastapi import FastAPI, Depends, HTTPException
|
from fastapi import FastAPI, Depends, HTTPException, APIRouter
|
||||||
from fastapi.responses import JSONResponse, Response
|
from fastapi.responses import JSONResponse, Response
|
||||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
@ -40,25 +40,29 @@ def check_credentials(credentials: HTTPBasicCredentials = Depends(security)):
|
||||||
return credentials
|
return credentials
|
||||||
|
|
||||||
|
|
||||||
def get_db() -> redis.Redis:
|
def get_db_factory(n: int = 11) -> typing.Callable[[], typing.Generator[redis.Redis, None, None]]:
|
||||||
|
def inner():
|
||||||
uri = os.getenv("REDIS_URL", "redis://redis")
|
uri = os.getenv("REDIS_URL", "redis://redis")
|
||||||
conn = redis.Redis.from_url(uri)
|
conn = redis.Redis.from_url(uri)
|
||||||
|
conn.select(n)
|
||||||
try:
|
try:
|
||||||
yield conn
|
yield conn
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Jimmy v3 API",
|
title="Jimmy v3 API",
|
||||||
version="3.0.0",
|
version="3.0.0",
|
||||||
dependencies=[Depends(check_credentials)],
|
dependencies=[Depends(check_credentials)],
|
||||||
root_path=os.getenv("WEB_ROOT_PATH")
|
root_path=os.getenv("WEB_ROOT_PATH", "") + "/api"
|
||||||
)
|
)
|
||||||
|
truth_router = APIRouter(prefix="/truths")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/truths/all")
|
@truth_router.get("/")
|
||||||
def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db)):
|
def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db_factory())):
|
||||||
"""Retrieves all stored truths"""
|
"""Retrieves all stored truths"""
|
||||||
keys = db.keys()
|
keys = db.keys()
|
||||||
if rich is False:
|
if rich is False:
|
||||||
|
@ -67,17 +71,17 @@ def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db)):
|
||||||
return truths
|
return truths
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/truths/{truth_id}")
|
@truth_router.get("/{truth_id}")
|
||||||
def get_truth(truth_id: str, db: redis.Redis = Depends(get_db)):
|
def get_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
|
||||||
"""Retrieves a stored truth"""
|
"""Retrieves a stored truth"""
|
||||||
data = db.get(truth_id)
|
data: str = db.get(truth_id)
|
||||||
if not data:
|
if not data:
|
||||||
raise HTTPException(404, detail="%r not found." % id)
|
raise HTTPException(404, detail="%r not found." % id)
|
||||||
return json.loads(data)
|
return json.loads(data)
|
||||||
|
|
||||||
|
|
||||||
@app.head("/api/truths/{truth_id}")
|
@truth_router.head("/{truth_id}")
|
||||||
def head_truth(truth_id: str, db: redis.Redis = Depends(get_db)):
|
def head_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
|
||||||
"""Checks that a truth exists"""
|
"""Checks that a truth exists"""
|
||||||
data = db.get(truth_id)
|
data = db.get(truth_id)
|
||||||
if not data:
|
if not data:
|
||||||
|
@ -85,11 +89,11 @@ def head_truth(truth_id: str, db: redis.Redis = Depends(get_db)):
|
||||||
return Response()
|
return Response()
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/truths", status_code=201)
|
@truth_router.post("/", status_code=201)
|
||||||
def post_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = Depends(get_db)):
|
def post_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = Depends(get_db_factory())):
|
||||||
"""Stores a new truth"""
|
"""Stores a new truth"""
|
||||||
data = payload.model_dump()
|
data = payload.model_dump()
|
||||||
existing = db.get(data["id"])
|
existing: str = db.get(data["id"])
|
||||||
if existing:
|
if existing:
|
||||||
parsed = json.loads(existing)
|
parsed = json.loads(existing)
|
||||||
if parsed == existing:
|
if parsed == existing:
|
||||||
|
@ -100,8 +104,8 @@ def post_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis =
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@app.put("/api/truths/{truth_id}")
|
@truth_router.put("/{truth_id}")
|
||||||
def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db)):
|
def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db_factory())):
|
||||||
"""Replaces a stored truth"""
|
"""Replaces a stored truth"""
|
||||||
data = payload.model_dump()
|
data = payload.model_dump()
|
||||||
existing = db.get(truth_id)
|
existing = db.get(truth_id)
|
||||||
|
@ -111,16 +115,51 @@ def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(ge
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@app.delete("/api/truths/{truth_id}", status_code=204)
|
@truth_router.delete("/{truth_id}", status_code=204)
|
||||||
def delete_truth(truth_id: str, db: redis.Redis = Depends(get_db)):
|
def delete_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
|
||||||
"""Deletes a stored truth"""
|
"""Deletes a stored truth"""
|
||||||
if not db.delete(truth_id):
|
if not db.delete(truth_id):
|
||||||
raise HTTPException(404, detail="%r not found." % truth_id)
|
raise HTTPException(404, detail="%r not found." % truth_id)
|
||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/health")
|
app.include_router(truth_router)
|
||||||
def health(db: redis.Redis = Depends(get_db)):
|
ollama_router = APIRouter(prefix="/ollama")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/threads")
|
||||||
|
def get_ollama_threads(db: redis.Redis = Depends(get_db_factory(0))):
|
||||||
|
"""
|
||||||
|
Retrieves all stored threads
|
||||||
|
|
||||||
|
This only returns thread keys as returning entire threads would be too much data.
|
||||||
|
"""
|
||||||
|
keys = db.keys()
|
||||||
|
return keys
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/thread/{thread_id}")
|
||||||
|
def get_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factory(0))):
|
||||||
|
"""Retrieves a stored thread"""
|
||||||
|
data: str = db.get(thread_id)
|
||||||
|
if not data:
|
||||||
|
raise HTTPException(404, detail="%r not found." % thread_id)
|
||||||
|
return json.loads(data)
|
||||||
|
|
||||||
|
|
||||||
|
@app.delete("/thread/{thread_id}", status_code=204)
|
||||||
|
def delete_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factory(0))):
|
||||||
|
"""Deletes a stored thread"""
|
||||||
|
if not db.delete(thread_id):
|
||||||
|
raise HTTPException(404, detail="%r not found." % thread_id)
|
||||||
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
|
||||||
|
app.include_router(ollama_router)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
def health(db: redis.Redis = Depends(get_db_factory())):
|
||||||
try:
|
try:
|
||||||
db.ping()
|
db.ping()
|
||||||
except ConnectionError:
|
except ConnectionError:
|
||||||
|
|
Loading…
Reference in a new issue