Skip to content

Commit 4baa5a7

Browse files
authored
Merge pull request #107 from redis/feat/litellm-llmclient-abstraction
Feat/litellm llmclient abstraction — replaces the custom LLM wrapper classes in llms.py with a unified LLMClient abstraction layer backed by LiteLLM. All LLM operations — chat completions, embeddings, and LangChain Embeddings instances — now go through a single entry point, reducing maintenance burden and enabling support for multiple providers without code changes. Issue: #105
2 parents 6141eed + 37f4dea commit 4baa5a7

32 files changed: +2,185 additions, −1,898 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,3 +237,4 @@ TASK_MEMORY.md
237237
*.code-workspace
238238
/agent-memory-client/agent-memory-client-java/.gradle/
239239
augment*.md
240+
dev_docs/

CLAUDE.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ query = VectorQuery(vector=embedding, vector_field_name="vector", return_fields=
110110

111111
## Critical Rules
112112

113+
### Import Placement
114+
Place all imports at the top of modules, not inside functions. Inline imports should only be used when strictly necessary (e.g., avoiding circular dependencies, optional dependencies, or significant startup performance concerns).
115+
113116
### Authentication
114117
- **PRODUCTION**: Never set `DISABLE_AUTH=true` in production
115118
- **DEVELOPMENT**: Use `DISABLE_AUTH=true` for local testing only
@@ -149,7 +152,11 @@ agent_memory_server/
149152
├── summarization.py # Conversation summarization
150153
├── extraction.py # Topic and entity extraction
151154
├── filters.py # Search filtering logic
152-
├── llms.py # LLM provider integrations
155+
├── llm/ # LLM client package (LiteLLM-based)
156+
│ ├── __init__.py # Re-exports for clean imports
157+
│ ├── client.py # LLMClient class with chat/embedding methods
158+
│ ├── types.py # ChatCompletionResponse, EmbeddingResponse, LLMBackend
159+
│ └── exceptions.py # LLMClientError, ModelValidationError, APIKeyMissingError
153160
├── migrations.py # Database schema migrations
154161
├── docket_tasks.py # Background task definitions
155162
├── cli.py # Command-line interface

agent_memory_server/api.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from agent_memory_server.config import settings
1212
from agent_memory_server.dependencies import HybridBackgroundTasks
1313
from agent_memory_server.filters import SessionId, UserId
14-
from agent_memory_server.llms import get_model_client, get_model_config
14+
from agent_memory_server.llm import LLMClient
1515
from agent_memory_server.logging import get_logger
1616
from agent_memory_server.models import (
1717
AckResponse,
@@ -101,7 +101,7 @@ def _get_effective_token_limit(
101101
return context_window_max
102102
# If model_name is provided, get its max_tokens from our config
103103
if model_name is not None:
104-
model_config = get_model_config(model_name)
104+
model_config = LLMClient.get_model_config(model_name)
105105
return model_config.max_tokens
106106
# Otherwise use a conservative default (GPT-3.5 context window)
107107
return 16000 # Conservative default
@@ -238,9 +238,8 @@ async def _summarize_working_memory(
238238
if current_tokens <= token_threshold:
239239
return memory
240240

241-
# Get model client for summarization
242-
client = await get_model_client(model)
243-
model_config = get_model_config(model)
241+
# Get model config for summarization
242+
model_config = LLMClient.get_model_config(model)
244243
summarization_max_tokens = model_config.max_tokens
245244

246245
# Token allocation for summarization (same logic as original summarize_session)
@@ -305,7 +304,6 @@ async def _summarize_working_memory(
305304
# Generate summary
306305
summary, summary_tokens_used = await _incremental_summary(
307306
model,
308-
client,
309307
memory.context, # Use existing context as base
310308
messages_to_summarize,
311309
)

agent_memory_server/extraction.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,7 @@
99
# Lazy-import transformers in get_ner_model to avoid heavy deps at startup
1010
from agent_memory_server.config import settings
1111
from agent_memory_server.filters import DiscreteMemoryExtracted, MemoryType
12-
from agent_memory_server.llms import (
13-
AnthropicClientWrapper,
14-
BedrockClientWrapper,
15-
OpenAIClientWrapper,
16-
get_model_client,
17-
)
12+
from agent_memory_server.llm import LLMClient
1813
from agent_memory_server.logging import get_logger
1914
from agent_memory_server.models import MemoryRecord
2015

@@ -128,15 +123,10 @@ def extract_entities(text: str) -> list[str]:
128123
async def extract_topics_llm(
129124
text: str,
130125
num_topics: int | None = None,
131-
client: OpenAIClientWrapper
132-
| AnthropicClientWrapper
133-
| BedrockClientWrapper
134-
| None = None,
135126
) -> list[str]:
136127
"""
137128
Extract topics from text using the LLM model.
138129
"""
139-
_client = client or await get_model_client(settings.topic_model)
140130
_num_topics = num_topics if num_topics is not None else settings.top_k_topics
141131

142132
prompt = f"""
@@ -152,17 +142,15 @@ async def extract_topics_llm(
152142

153143
async for attempt in AsyncRetrying(stop=stop_after_attempt(3)):
154144
with attempt:
155-
response = await _client.create_chat_completion(
145+
response = await LLMClient.create_chat_completion(
156146
model=settings.generation_model,
157-
prompt=prompt,
147+
messages=[{"role": "user", "content": prompt}],
158148
response_format={"type": "json_object"},
159149
)
160150
try:
161-
topics = json.loads(response.choices[0].message.content)["topics"]
151+
topics = json.loads(response.content)["topics"]
162152
except (json.JSONDecodeError, KeyError):
163-
logger.error(
164-
f"Error decoding JSON: {response.choices[0].message.content}"
165-
)
153+
logger.error(f"Error decoding JSON: {response.content}")
166154
topics = []
167155
if topics:
168156
topics = topics[:_num_topics]
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""
2+
LLM client package for unified LLM operations.
3+
4+
This package provides a single entry point for all LLM interactions,
5+
abstracting away the underlying provider (OpenAI, Anthropic, Bedrock, etc.).
6+
7+
Usage:
8+
from agent_memory_server.llm import LLMClient, ChatCompletionResponse
9+
10+
response = await LLMClient.create_chat_completion(
11+
model="gpt-4o",
12+
messages=[{"role": "user", "content": "Hello"}],
13+
)
14+
"""
15+
16+
from agent_memory_server.llm.client import (
17+
LLMClient,
18+
get_model_config,
19+
optimize_query_for_vector_search,
20+
)
21+
from agent_memory_server.llm.exceptions import (
22+
APIKeyMissingError,
23+
LLMClientError,
24+
ModelValidationError,
25+
)
26+
from agent_memory_server.llm.types import (
27+
ChatCompletionResponse,
28+
EmbeddingResponse,
29+
)
30+
31+
32+
__all__ = [
33+
# Client
34+
"LLMClient",
35+
# Convenience functions
36+
"get_model_config",
37+
"optimize_query_for_vector_search",
38+
# Exceptions
39+
"LLMClientError",
40+
"ModelValidationError",
41+
"APIKeyMissingError",
42+
# Types
43+
"ChatCompletionResponse",
44+
"EmbeddingResponse",
45+
]

0 commit comments

Comments (0)