-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcot_memory.py
More file actions
144 lines (129 loc) · 4.7 KB
/
cot_memory.py
File metadata and controls
144 lines (129 loc) · 4.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
Chain-of-Thought Memory (Phase 16)
====================================
Stores the agent's reasoning steps (not just answers) in vector memory.
Allows the agent to recall WHY it made past decisions.
"""
import asyncio
import json
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
# System prompt for the extraction call: instructs the model to restate a
# question/answer pair as a single JSON object with the five keys listed
# below, and to emit nothing but that JSON.
COT_EXTRACTOR_SYSTEM = "\n".join(
    [
        "You are a reasoning extractor. Given a question and its answer, produce a structured "
        "JSON object capturing the chain of thought. Use this exact format:",
        "{",
        ' "problem": "brief restatement of the problem",',
        ' "steps": ["step 1", "step 2", ...],',
        ' "key_assumptions": ["assumption 1", ...],',
        ' "alternatives_considered": ["alt 1", ...],',
        ' "conclusion": "one-sentence summary of what was decided"',
        "}",
        "Output ONLY valid JSON, no markdown.",
    ]
)
class CoTMemory:
    """
    Chain-of-Thought Memory — captures and stores structured reasoning steps
    in the agent's vector memory for future recall.

    Every recorded chain is also appended to an in-process list that serves
    as a keyword-search fallback when the vector store raises.
    NOTE: that list is never trimmed, so it grows for the lifetime of the
    instance.
    """

    _log = logging.getLogger(__name__)

    def __init__(self, vector_memory: Any, llm_provider: Any, database: Any = None):
        """
        Args:
            vector_memory: Object whose ``add(tenant_id=..., text=..., metadata=...)``
                and ``search(tenant_id=..., query=..., n_results=..., category=...)``
                methods are called by this class.
            llm_provider: Object whose ``call(prompt, system=..., history=...)``
                method is called by ``record``.
            database: Stored on ``self.db``; not used by the methods in this class.
        """
        self.vmem = vector_memory
        self.llm = llm_provider
        self.db = database
        self._local_cache: List[Dict] = []  # in-memory fallback

    def _extract_json(self, text: str) -> Optional[Dict]:
        """
        Extract the first valid JSON object from text.

        The whole (stripped) text is tried first. If that does not yield a
        JSON object, every ``{`` is tried in order as a possible start of an
        object, which tolerates surrounding prose or markdown fences.

        Returns:
            The decoded object as a dict, or None when the text is not a
            string or contains no decodable JSON object. Top-level JSON
            values that are not objects (lists, strings, numbers) are never
            returned, because callers index the result by key.
        """
        if not isinstance(text, str):
            return None
        candidate = text.strip()
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            return parsed

        # raw_decode stops at the end of the first complete value, so trailing
        # text (even text containing braces) does not break the parse.
        decoder = json.JSONDecoder()
        start = candidate.find("{")
        while start != -1:
            try:
                parsed, _ = decoder.raw_decode(candidate[start:])
            except json.JSONDecodeError:
                pass
            else:
                if isinstance(parsed, dict):
                    return parsed
            start = candidate.find("{", start + 1)
        return None

    @staticmethod
    def _as_str_list(value: Any) -> List[str]:
        """Coerce a value to a list of strings (None -> [], scalar -> [str(scalar)])."""
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return [str(item) for item in value if item is not None]
        return [str(value)]

    @staticmethod
    def _normalize_cot(extracted: Optional[Dict], question: str, answer: str) -> Dict:
        """
        Return a chain-of-thought dict with all five required keys well-typed.

        Missing or null fields fall back to values derived from the
        question/answer; ``problem`` and ``conclusion`` become strings and the
        three list fields become lists of strings. Extra keys are preserved.
        The input dict is copied, not mutated.
        """
        cot: Dict = dict(extracted) if extracted else {}

        problem = cot.get("problem")
        cot["problem"] = question[:200] if problem is None else str(problem)

        # An explicitly empty steps list is kept; only a missing/null one
        # is replaced by the truncated answer.
        if cot.get("steps") is None:
            cot["steps"] = [answer[:200]]
        else:
            cot["steps"] = CoTMemory._as_str_list(cot["steps"])

        for key in ("key_assumptions", "alternatives_considered"):
            cot[key] = CoTMemory._as_str_list(cot.get(key))

        conclusion = cot.get("conclusion")
        cot["conclusion"] = answer[:100] if conclusion is None else str(conclusion)
        return cot

    @staticmethod
    def _format_cot_text(cot: Dict) -> str:
        """Build the single-line text stored in vector memory (first three steps only)."""
        steps_summary = " → ".join(cot["steps"][:3])
        return (
            f"[CoT] Problem: {cot['problem']} | "
            f"Steps: {steps_summary} | "
            f"Conclusion: {cot['conclusion']}"
        )

    async def record(
        self,
        tenant_id: int,
        question: str,
        answer: str,
        session_id: str = "default",
    ) -> Optional[Dict]:
        """
        Asynchronously extract and store chain-of-thought for a Q/A pair.

        ``self.llm.call`` is run in a worker thread via ``asyncio.to_thread``.
        If its output contains no JSON object, a minimal chain built from the
        truncated question and answer is used instead, so a dict is always
        returned. Exceptions raised by the LLM call propagate to the caller.

        Storing into vector memory is best-effort: a failure there is logged
        and the chain is still cached locally and returned.

        Args:
            tenant_id: Passed to the vector store and recorded on the chain so
                the local fallback search can stay tenant-scoped.
            question: The question that was answered.
            answer: The answer whose reasoning should be captured.
            session_id: Recorded on the chain and in the stored metadata.

        Returns:
            The chain-of-thought dict, including ``timestamp`` (naive local
            time, ISO format), ``session_id`` and ``tenant_id``.
        """
        prompt = (
            f"QUESTION: {question}\n\n"
            f"ANSWER: {answer}\n\n"
            "Extract the chain of thought as JSON."
        )
        raw = await asyncio.to_thread(
            self.llm.call, prompt, system=COT_EXTRACTOR_SYSTEM, history=[]
        )
        extracted = self._extract_json(raw) if raw else None

        # Guard against partial or mistyped JSON from the LLM.
        cot = self._normalize_cot(extracted, question, answer)
        cot["timestamp"] = datetime.now().isoformat()
        cot["session_id"] = session_id
        cot["tenant_id"] = tenant_id

        # Store in vector memory
        try:
            self.vmem.add(
                tenant_id=tenant_id,
                text=self._format_cot_text(cot),
                metadata={
                    "category": "cot_reasoning",
                    "session": session_id,
                    "problem": cot["problem"][:100],
                },
            )
        except Exception:
            # Deliberately best-effort, but leave a trace instead of hiding it.
            self._log.warning(
                "CoTMemory: failed to store reasoning chain in vector memory",
                exc_info=True,
            )

        self._local_cache.append(cot)
        return cot

    def recall_why(self, tenant_id: int, query: str, n: int = 5) -> List[Dict]:
        """
        Search stored reasoning chains relevant to a query.

        The vector store is queried first and its result is returned as-is
        (its element shape is whatever ``vmem.search`` produces — TODO confirm
        against the vector store). If the search raises, a case-insensitive
        substring match over this tenant's locally cached chains (problem and
        steps) is used instead.

        Args:
            tenant_id: Tenant whose chains are searched.
            query: Free-text query.
            n: Maximum number of results; non-positive values yield no
                fallback results.
        """
        try:
            return self.vmem.search(
                tenant_id=tenant_id,
                query=query,
                n_results=n,
                category="cot_reasoning",
            )
        except Exception:
            self._log.warning(
                "CoTMemory: vector search failed; falling back to in-session cache",
                exc_info=True,
            )

        # Fallback: keyword search in cache
        if n <= 0:
            return []
        needle = query.lower()
        matches = []
        for chain in self._local_cache:
            # Never surface another tenant's reasoning from the shared cache.
            if chain.get("tenant_id") != tenant_id:
                continue
            problem_text = str(chain.get("problem", "")).lower()
            steps_text = " ".join(self._as_str_list(chain.get("steps"))).lower()
            if needle in problem_text or needle in steps_text:
                matches.append(chain)
        return matches[:n]

    def get_recent(self, n: int = 5, tenant_id: Optional[int] = None) -> List[Dict]:
        """
        Return the N most recently stored chains, oldest first.

        Args:
            n: Number of chains to return; non-positive values return an
                empty list (a plain ``[-0:]`` slice would return everything).
            tenant_id: When given, only chains recorded for that tenant are
                considered. When None, all cached chains are considered.
        """
        if n <= 0:
            return []
        if tenant_id is None:
            return self._local_cache[-n:]
        scoped = [c for c in self._local_cache if c.get("tenant_id") == tenant_id]
        return scoped[-n:]

    def describe(self) -> str:
        """Return a one-line summary with the number of locally cached chains."""
        return f"CoTMemory — {len(self._local_cache)} reasoning chains stored in-session."