-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathapi_run.py
More file actions
100 lines (90 loc) · 3.44 KB
/
api_run.py
File metadata and controls
100 lines (90 loc) · 3.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# api_run.py - FastAPI HTTP endpoints that drive the `app` object imported from the `graph` module.
from fastapi import FastAPI, HTTPException
from langchain_core.messages import HumanMessage
from graph import app, checkpointer
import pickle
# Application object: the route decorators in this file register their
# handlers on it, and the __main__ guard passes it to uvicorn.run().
api = FastAPI()
@api.post("/chat")
async def chat(msg: str, order_id: str = "", amount: float = 0, thread_id: str = ""):
    """Run one chat turn through ``app`` and return whatever it produces.

    Args:
        msg: User text; wrapped in a ``HumanMessage`` and sent as the only
            entry of the ``messages`` input field.
        order_id: Forwarded unchanged as the ``order_id`` input field.
        amount: Forwarded unchanged as the ``amount`` input field.
        thread_id: Placed under ``configurable.thread_id`` in the run config.

    Returns:
        The value returned by ``app.invoke`` for this input and config.

    NOTE(review): ``app.invoke`` is called without ``await`` inside an
    ``async`` handler; if it does blocking work the event loop stalls for
    the whole call - confirm this is acceptable.
    NOTE(review): every caller that omits ``thread_id`` sends the empty
    string; presumably those callers then share one thread's saved state -
    verify against the checkpointer.
    """
    graph_input = dict(
        messages=[HumanMessage(content=msg)],
        order_id=order_id,
        amount=amount,
    )
    # The thread id selects which conversation/session this turn belongs to.
    run_config = {"configurable": {"thread_id": thread_id}}
    return app.invoke(graph_input, config=run_config)
# Human approval endpoint
@api.post("/approve_refund")
async def approve(thread_id: str):
    """Invoke ``app`` with no new input on an existing thread.

    Args:
        thread_id: Placed under ``configurable.thread_id``; intended to be
            the same id that was used for the run awaiting approval.

    Returns:
        The value returned by ``app.invoke``.

    NOTE(review): ``resume`` and ``value`` sit at the top level of the
    config, next to ``configurable``. Nothing visible in this file shows
    the graph reading them - confirm the approval decision actually reaches
    the graph.
    """
    resume_config = {
        "configurable": {"thread_id": thread_id},
        "resume": True,
        "value": True,  # marks the request as approved by a human
    }
    return app.invoke(None, config=resume_config)
def _unpickle_column(raw):
    """Unpickle one stored column value, encoding ``str`` input to bytes first.

    SECURITY: ``pickle.loads`` can execute arbitrary code embedded in the
    payload. This is only safe while the ``checkpoints`` table is written
    exclusively by trusted code; do not point it at data that users or
    other systems can modify.
    """
    if isinstance(raw, str):
        raw = raw.encode()
    return pickle.loads(raw)


def _deserialize_checkpoint_row(row):
    """Convert one ``checkpoints`` row (a dict-like object) into a response entry.

    Returns a dict with ``checkpoint_id`` and ``parent_checkpoint_id`` plus
    either the unpickled ``checkpoint``/``metadata`` or, when unpickling
    fails, an ``error`` string. A bad row never aborts the whole listing.
    """
    entry = {
        "checkpoint_id": row.get('checkpoint_id'),
        "parent_checkpoint_id": row.get('parent_checkpoint_id'),
    }
    try:
        entry["checkpoint"] = _unpickle_column(row.get('checkpoint'))
        entry["metadata"] = _unpickle_column(row.get('metadata'))
    except Exception as e:
        # Unpickling can raise almost anything (UnpicklingError, EOFError,
        # AttributeError, ImportError, TypeError for NULL columns, ...), so
        # the broad catch is deliberate: report the failure for this row only.
        entry.pop("checkpoint", None)
        entry["error"] = f"Failed to deserialize: {e}"
    return entry


@api.get("/checkpoint/{thread_id}")
async def get_checkpoint(thread_id: str):
    """Retrieve checkpoint information for a given thread_id from the database.

    Args:
        thread_id: Value matched against the ``thread_id`` column of the
            ``checkpoints`` table.

    Returns:
        A dict with the requested ``thread_id``, the number of checkpoints
        found (``count``) and the list of ``checkpoints`` ordered by
        ``checkpoint_id`` descending. Rows that cannot be unpickled are
        reported with an ``error`` field instead of their content.

    Raises:
        HTTPException: 404 when the thread has no checkpoints; 500 when the
            database query itself fails.

    NOTE(review): this relies on the private ``checkpointer._cursor()``
    helper and performs the query without ``await`` inside an ``async``
    handler - confirm both are acceptable.
    NOTE(review): the unpickled objects are returned as-is; presumably they
    are JSON-encodable by the framework - verify.
    """
    # Keep the cursor open only for the query; rows are processed afterwards
    # so slow deserialization does not hold the database resource.
    try:
        with checkpointer._cursor() as cur:
            cur.execute(
                "SELECT checkpoint_id, parent_checkpoint_id, checkpoint, metadata "
                "FROM checkpoints WHERE thread_id=%s "
                "ORDER BY checkpoint_id DESC",
                (thread_id,)
            )
            rows = cur.fetchall()
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(status_code=500, detail=f"Database error: {e}") from e

    if not rows:
        raise HTTPException(status_code=404, detail=f"No checkpoints found for thread_id: {thread_id}")

    # The cursor is expected to yield dict-like rows (see the helper's .get calls).
    checkpoints = [_deserialize_checkpoint_row(row) for row in rows]
    return {
        "thread_id": thread_id,
        "count": len(checkpoints),
        "checkpoints": checkpoints
    }
if __name__ == "__main__":
    # Direct execution only: serve `api` with uvicorn on every interface,
    # port 8203, with auto-reload turned off.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8203, "reload": False}
    uvicorn.run(api, **server_options)