Skip to content

Commit 5166c0e

Browse files
authored
🔧 Allow backend agnostic configuration (#23)
1 parent 1d7bf65 commit 5166c0e

File tree

5 files changed

+31
-29
lines changed

5 files changed

+31
-29
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ rand = "0.9"
2727
humantime = "2.1"
2828
log = "0.4"
2929

30-
llm = { version = "1.3", features = ["openai", "rustls-tls"] }
30+
llm = { version = "1.3", features = ["rustls-tls"] }
3131
rmcp = { version = "0.6.0", features = [
3232
"client",
3333
"transport-sse-client-reqwest",

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,16 @@ The bot is in a very early stage of development but already usable.
2222

2323
The project requires the following environment variables:
2424

25-
- `OPENAI_TOKEN`: Your OpenAI API token.
25+
- `API_KEY`: Your LLM provider API key.
26+
- `LLM_PROVIDER`: The LLM provider to use. Supported values: `openai`, `anthropic`, `ollama`, `deepseek`, `xai`, `phind`, `google`, `groq`, `azureopenai`, `elevenlabs`, `cohere`, `mistral`, `openrouter`.
2627
- `MODEL`: The model to use.
2728
- `DISCORD_TOKEN`: Your Discord bot token.
2829
- `TEMPLATE_DIR`: The directory where your Tera templates are located. Defaults to `templates`.
2930
- `RATE_LIMIT_CONFIG`: The path to your rate limit configuration file. Defaults to `rate_limits.toml`.
3031
- `DATABASE_URL`: The URL to your database. For example `mysql://user:password@localhost/database`.
3132
- `WHITELIST`: A comma-separated list of Discord snowflakes for channels, categories, or guilds in which the bot should respond. If empty, the bot will respond in all channels. Defaults to an empty string.
3233
- `OPT_OUT_LOCKOUT`: The duration a user is locked out from the bot after opting out. Defaults to `30d`. Can use any time format supported by the `humantime` crate.
34+
- `COMPLETION_TIMEOUT`: The timeout for LLM completion requests. Defaults to `60s`. Can use any time format supported by the `humantime` crate.
3335

3436
## License
3537

src/handler/completion.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,7 @@ pub async fn handle_completion(
150150

151151
let typing_notification = typing_indicator(ctx, new_message.channel_id);
152152

153-
let completion_request = tokio::time::timeout(
154-
std::time::Duration::from_secs(60),
155-
generate_llm_response(ctx, app, new_message),
156-
);
153+
let completion_request = tokio::time::timeout(app.completion_timeout, generate_llm_response(ctx, app, new_message));
157154

158155
// assuming typing notifications don't fail, we can just wait for the fork to finish and will keep sending typing
159156
// notifications in the meantime
@@ -366,10 +363,10 @@ async fn generate_llm_response<'a>(
366363
}
367364

368365
// add assistant's tool call to conversation
369-
conversation.push(ChatMessage::assistant().tool_use(tool_calls.clone()).content("").build());
366+
conversation.push(ChatMessage::assistant().tool_use(tool_calls.clone()).build());
370367

371368
// add tool results to conversation
372-
conversation.push(ChatMessage::assistant().tool_result(tool_results.clone()).content("").build());
369+
conversation.push(ChatMessage::user().tool_result(tool_results.clone()).build());
373370
} else {
374371
// No tool calls - we have our final response
375372
let content = response.text().ok_or(miette!("LLM response has no content"))?;

src/main.rs

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,11 @@ lazy_static! {
103103

104104
#[derive(Envconfig)]
105105
struct EnvConfig {
106-
#[envconfig(from = "OPENAI_TOKEN")]
107-
openai_token: String,
106+
#[envconfig(from = "API_KEY")]
107+
api_key: String,
108+
109+
#[envconfig(from = "LLM_PROVIDER")]
110+
llm_provider: String,
108111

109112
#[envconfig(from = "MODEL")]
110113
model: String,
@@ -126,6 +129,19 @@ struct EnvConfig {
126129

127130
#[envconfig(from = "WHITELIST", default = "")]
128131
whitelist: Whitelist,
132+
133+
#[envconfig(from = "COMPLETION_TIMEOUT", default = "60s")]
134+
completion_timeout: ParsedDuration,
135+
}
136+
137+
impl EnvConfig {
138+
/// Converts the provider string to LLMBackend enum
139+
fn get_llm_backend(&self) -> Result<LLMBackend> {
140+
use std::str::FromStr;
141+
LLMBackend::from_str(&self.llm_provider)
142+
.into_diagnostic()
143+
.wrap_err_with(|| format!("unsupported LLM provider: {}", self.llm_provider))
144+
}
129145
}
130146

131147
struct ParsedDuration(Duration);
@@ -224,6 +240,7 @@ struct AppState {
224240
context_settings: InvocationContextSettings,
225241
whitelist: Whitelist,
226242
opt_out_lockout: Duration,
243+
completion_timeout: Duration,
227244
}
228245

229246
type Context<'a> = poise::Context<'a, AppState, Report>;
@@ -248,9 +265,11 @@ async fn main() -> Result<()> {
248265
};
249266

250267
let llm_client = {
268+
let backend = env_config.get_llm_backend()?;
269+
251270
let mut builder = LLMBuilder::new()
252-
.backend(LLMBackend::OpenAI)
253-
.api_key(&env_config.openai_token)
271+
.backend(backend)
272+
.api_key(&env_config.api_key)
254273
.model(&env_config.model)
255274
.max_tokens(2000);
256275

@@ -355,6 +374,7 @@ async fn main() -> Result<()> {
355374
},
356375
whitelist: env_config.whitelist,
357376
opt_out_lockout: env_config.opt_out_lockout.0,
377+
completion_timeout: env_config.completion_timeout.0,
358378
})
359379
})
360380
})

src/mcp_config.rs

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,6 @@ mod tests {
112112
},
113113
_ => panic!("Expected HTTP server config"),
114114
}
115-
116-
assert!(server.is_http_based());
117-
assert_eq!(
118-
server.get_connection_url(),
119-
Some("http://192.168.200.10:8096/servers/web-search/sse")
120-
);
121115
}
122116

123117
/// Test parsing an SSE server configuration
@@ -153,8 +147,6 @@ mod tests {
153147
},
154148
_ => panic!("Expected SSE server config"),
155149
}
156-
157-
assert!(server.is_http_based());
158150
}
159151

160152
/// Test parsing a stdio server configuration
@@ -194,9 +186,6 @@ mod tests {
194186
},
195187
_ => panic!("Expected Stdio server config"),
196188
}
197-
198-
assert!(!server.is_http_based());
199-
assert_eq!(server.get_connection_url(), None);
200189
}
201190

202191
/// Test parsing multiple servers with mixed types
@@ -228,12 +217,6 @@ mod tests {
228217
assert!(config.servers.contains_key("web-search"));
229218
assert!(config.servers.contains_key("web-fetch"));
230219
assert!(config.servers.contains_key("local-tool"));
231-
232-
let http_count = config.servers.values().filter(|s| s.is_http_based()).count();
233-
let stdio_count = config.servers.values().filter(|s| !s.is_http_based()).count();
234-
235-
assert_eq!(http_count, 2);
236-
assert_eq!(stdio_count, 1);
237220
}
238221

239222
/// Test parsing config similar to the provided mcp.json

0 commit comments

Comments
 (0)