diff --git a/langchain4j-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/openai/spring/AutoConfig.java b/langchain4j-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/openai/spring/AutoConfig.java index 292f17ec..9ace5bbe 100644 --- a/langchain4j-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/openai/spring/AutoConfig.java +++ b/langchain4j-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/openai/spring/AutoConfig.java @@ -6,6 +6,10 @@ import dev.langchain4j.model.openai.*; import org.springframework.beans.factory.ObjectProvider; import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; +import org.springframework.beans.factory.support.GenericBeanDefinition; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; @@ -13,12 +17,26 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.boot.context.properties.bind.Binder; import org.springframework.context.annotation.Bean; +import org.springframework.core.env.Environment; import org.springframework.core.task.AsyncTaskExecutor; import org.springframework.core.task.support.ContextPropagatingTaskDecorator; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.web.client.RestClient; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.core.env.EnumerablePropertySource; +import org.springframework.core.env.PropertySource; + import static dev.langchain4j.openai.spring.Properties.PREFIX; @AutoConfiguration(after = RestClientAutoConfiguration.class) @@ -378,4 +396,281 @@ HttpClientBuilder openAiImageModelHttpClientBuilder(ObjectProvider KNOWN_PROPERTIES = Set.of( + "base-url", "api-key", "organization-id", "project-id", "model-name", + "temperature", "top-p", "stop", "max-tokens", "max-completion-tokens", + "presence-penalty", "frequency-penalty", "logit-bias", "response-format", + "supported-capabilities", "strict-json-schema", "seed", "user", + "strict-tools", "parallel-tool-calls", "store", "metadata", "service-tier", + "reasoning-effort", "return-thinking", "timeout", "max-retries", + "log-requests", "log-responses", "custom-headers", "custom-query-params", + "custom-parameters" + ); + + private final Environment environment; + private ApplicationContext applicationContext; + + NamedModelBeanRegistrar(Environment environment) { + this.environment = environment; + } + + @Override + public void setApplicationContext(ApplicationContext applicationContext) { + this.applicationContext = applicationContext; + } + + @Override + public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) { + registerNamedChatModels(registry, CHAT_MODEL_PREFIX, "openAiChatModel", false); + registerNamedChatModels(registry, STREAMING_CHAT_MODEL_PREFIX, "openAiStreamingChatModel", true); + } + + private void registerNamedChatModels(BeanDefinitionRegistry registry, String prefix, + String beanNamePrefix, boolean streaming) { + ChatModelProperties globalProps = Binder.get(environment) + .bind(prefix, ChatModelProperties.class) + .orElse(null); + + Set namedModelKeys = findNamedModelKeys(prefix); + + for (String modelName : namedModelKeys) { + String namedPrefix = prefix + "." + modelName; + ChatModelProperties namedProps = Binder.get(environment) + .bind(namedPrefix, ChatModelProperties.class) + .orElse(null); + + if (namedProps == null) { + continue; + } + + ChatModelProperties mergedProps = mergeWithGlobal(globalProps, namedProps); + + if (mergedProps.apiKey() == null) { + continue; + } + + String beanName = beanNamePrefix + toPascalCase(modelName); + + GenericBeanDefinition beanDefinition = new GenericBeanDefinition(); + beanDefinition.setScope(BeanDefinition.SCOPE_SINGLETON); + + if (streaming) { + beanDefinition.setBeanClass(OpenAiStreamingChatModel.class); + beanDefinition.setInstanceSupplier(() -> { + HttpClientBuilder httpClientBuilder = createHttpClientBuilder(true); + List listeners = getListeners(); + return createStreamingChatModel(mergedProps, httpClientBuilder, listeners); + }); + } else { + beanDefinition.setBeanClass(OpenAiChatModel.class); + beanDefinition.setInstanceSupplier(() -> { + HttpClientBuilder httpClientBuilder = createHttpClientBuilder(false); + List listeners = getListeners(); + return createChatModel(mergedProps, httpClientBuilder, listeners); + }); + } + + registry.registerBeanDefinition(beanName, beanDefinition); + } + } + + private Set findNamedModelKeys(String prefix) { + Set namedKeys = new HashSet<>(); + String searchPrefix = prefix + "."; + + if (environment instanceof ConfigurableEnvironment configurableEnv) { + for (PropertySource propertySource : configurableEnv.getPropertySources()) { + if (propertySource instanceof EnumerablePropertySource enumerable) { + for (String propertyName : enumerable.getPropertyNames()) { + if (propertyName.startsWith(searchPrefix)) { + String remainder = propertyName.substring(searchPrefix.length()); + int dotIndex = remainder.indexOf('.'); + String firstSegment = dotIndex > 0 ? remainder.substring(0, dotIndex) : remainder; + + if (!KNOWN_PROPERTIES.contains(firstSegment)) { + namedKeys.add(firstSegment); + } + } + } + } + } + } + + return namedKeys; + } + + private HttpClientBuilder createHttpClientBuilder(boolean streaming) { + ObjectProvider restClientBuilderProvider = + applicationContext.getBeanProvider(RestClient.Builder.class); + RestClient.Builder restClientBuilder = restClientBuilderProvider.getIfAvailable(RestClient::builder); + + if (streaming) { + ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); + executor.setQueueCapacity(0); + executor.setThreadNamePrefix(TASK_EXECUTOR_THREAD_NAME_PREFIX); + executor.initialize(); + return SpringRestClient.builder() + .restClientBuilder(restClientBuilder) + .streamingRequestExecutor(executor); + } else { + return SpringRestClient.builder() + .restClientBuilder(restClientBuilder) + .createDefaultStreamingRequestExecutor(false); + } + } + + private List getListeners() { + return applicationContext.getBeanProvider(ChatModelListener.class) + .orderedStream() + .toList(); + } + + private ChatModelProperties mergeWithGlobal(ChatModelProperties global, ChatModelProperties named) { + if (global == null) { + return named; + } + return new ChatModelProperties( + named.baseUrl() != null ? named.baseUrl() : global.baseUrl(), + named.apiKey() != null ? named.apiKey() : global.apiKey(), + named.organizationId() != null ? named.organizationId() : global.organizationId(), + named.projectId() != null ? named.projectId() : global.projectId(), + named.modelName() != null ? named.modelName() : global.modelName(), + named.temperature() != null ? named.temperature() : global.temperature(), + named.topP() != null ? named.topP() : global.topP(), + named.stop() != null ? named.stop() : global.stop(), + named.maxTokens() != null ? named.maxTokens() : global.maxTokens(), + named.maxCompletionTokens() != null ? named.maxCompletionTokens() : global.maxCompletionTokens(), + named.presencePenalty() != null ? named.presencePenalty() : global.presencePenalty(), + named.frequencyPenalty() != null ? named.frequencyPenalty() : global.frequencyPenalty(), + named.logitBias() != null ? named.logitBias() : global.logitBias(), + named.responseFormat() != null ? named.responseFormat() : global.responseFormat(), + named.supportedCapabilities() != null ? named.supportedCapabilities() : global.supportedCapabilities(), + named.strictJsonSchema() != null ? named.strictJsonSchema() : global.strictJsonSchema(), + named.seed() != null ? named.seed() : global.seed(), + named.user() != null ? named.user() : global.user(), + named.strictTools() != null ? named.strictTools() : global.strictTools(), + named.parallelToolCalls() != null ? named.parallelToolCalls() : global.parallelToolCalls(), + named.store() != null ? named.store() : global.store(), + named.metadata() != null ? named.metadata() : global.metadata(), + named.serviceTier() != null ? named.serviceTier() : global.serviceTier(), + named.reasoningEffort() != null ? named.reasoningEffort() : global.reasoningEffort(), + named.returnThinking() != null ? named.returnThinking() : global.returnThinking(), + named.timeout() != null ? named.timeout() : global.timeout(), + named.maxRetries() != null ? named.maxRetries() : global.maxRetries(), + named.logRequests() != null ? named.logRequests() : global.logRequests(), + named.logResponses() != null ? named.logResponses() : global.logResponses(), + named.customHeaders() != null ? named.customHeaders() : global.customHeaders(), + named.customQueryParams() != null ? named.customQueryParams() : global.customQueryParams(), + named.customParameters() != null ? named.customParameters() : global.customParameters() + ); + } + + private OpenAiChatModel createChatModel(ChatModelProperties props, + HttpClientBuilder httpClientBuilder, + List listeners) { + return OpenAiChatModel.builder() + .httpClientBuilder(httpClientBuilder) + .baseUrl(props.baseUrl()) + .apiKey(props.apiKey()) + .organizationId(props.organizationId()) + .projectId(props.projectId()) + .modelName(props.modelName()) + .temperature(props.temperature()) + .topP(props.topP()) + .stop(props.stop()) + .maxTokens(props.maxTokens()) + .maxCompletionTokens(props.maxCompletionTokens()) + .presencePenalty(props.presencePenalty()) + .frequencyPenalty(props.frequencyPenalty()) + .logitBias(props.logitBias()) + .responseFormat(props.responseFormat()) + .supportedCapabilities(props.supportedCapabilities()) + .strictJsonSchema(props.strictJsonSchema()) + .seed(props.seed()) + .user(props.user()) + .strictTools(props.strictTools()) + .parallelToolCalls(props.parallelToolCalls()) + .store(props.store()) + .metadata(props.metadata()) + .serviceTier(props.serviceTier()) + .defaultRequestParameters(OpenAiChatRequestParameters.builder() + .reasoningEffort(props.reasoningEffort()) + .customParameters(props.customParameters()) + .build()) + .returnThinking(props.returnThinking()) + .timeout(props.timeout()) + .maxRetries(props.maxRetries()) + .logRequests(props.logRequests()) + .logResponses(props.logResponses()) + .customHeaders(props.customHeaders()) + .customQueryParams(props.customQueryParams()) + .listeners(listeners) + .build(); + } + + private OpenAiStreamingChatModel createStreamingChatModel(ChatModelProperties props, + HttpClientBuilder httpClientBuilder, + List listeners) { + return OpenAiStreamingChatModel.builder() + .httpClientBuilder(httpClientBuilder) + .baseUrl(props.baseUrl()) + .apiKey(props.apiKey()) + .organizationId(props.organizationId()) + .projectId(props.projectId()) + .modelName(props.modelName()) + .temperature(props.temperature()) + .topP(props.topP()) + .stop(props.stop()) + .maxTokens(props.maxTokens()) + .maxCompletionTokens(props.maxCompletionTokens()) + .presencePenalty(props.presencePenalty()) + .frequencyPenalty(props.frequencyPenalty()) + .logitBias(props.logitBias()) + .responseFormat(props.responseFormat()) + .seed(props.seed()) + .user(props.user()) + .strictTools(props.strictTools()) + .parallelToolCalls(props.parallelToolCalls()) + .store(props.store()) + .metadata(props.metadata()) + .serviceTier(props.serviceTier()) + .defaultRequestParameters(OpenAiChatRequestParameters.builder() + .reasoningEffort(props.reasoningEffort()) + .customParameters(props.customParameters()) + .build()) + .returnThinking(props.returnThinking()) + .timeout(props.timeout()) + .logRequests(props.logRequests()) + .logResponses(props.logResponses()) + .customHeaders(props.customHeaders()) + .customQueryParams(props.customQueryParams()) + .listeners(listeners) + .build(); + } + + private String toPascalCase(String input) { + if (input == null || input.isEmpty()) { + return input; + } + return Arrays.stream(input.split("[-_]")) + .map(segment -> { + if (segment.isEmpty()) { + return ""; + } + return Character.toUpperCase(segment.charAt(0)) + segment.substring(1).toLowerCase(); + }) + .collect(Collectors.joining()); + } + } } \ No newline at end of file diff --git a/langchain4j-open-ai-spring-boot-starter/src/test/java/dev/langchain4j/openai/spring/AutoConfigIT.java b/langchain4j-open-ai-spring-boot-starter/src/test/java/dev/langchain4j/openai/spring/AutoConfigIT.java index ded2abe9..00032019 100644 --- a/langchain4j-open-ai-spring-boot-starter/src/test/java/dev/langchain4j/openai/spring/AutoConfigIT.java +++ b/langchain4j-open-ai-spring-boot-starter/src/test/java/dev/langchain4j/openai/spring/AutoConfigIT.java @@ -589,6 +589,140 @@ void should_bind_custom_parameters_from_properties() { }); } + @Test + void should_provide_named_chat_models() { + contextRunner + .withPropertyValues( + "langchain4j.open-ai.chat-model.base-url=" + BASE_URL, + "langchain4j.open-ai.chat-model.api-key=" + API_KEY, + "langchain4j.open-ai.chat-model.fast.model-name=gpt-4o-mini", + "langchain4j.open-ai.chat-model.fast.max-tokens=20", + "langchain4j.open-ai.chat-model.smart.model-name=gpt-4o-mini", + "langchain4j.open-ai.chat-model.smart.max-tokens=20" + ) + .run(context -> { + OpenAiChatModel fastModel = context.getBean("openAiChatModelFast", OpenAiChatModel.class); + OpenAiChatModel smartModel = context.getBean("openAiChatModelSmart", OpenAiChatModel.class); + + assertThat(fastModel).isNotNull(); + assertThat(smartModel).isNotNull(); + assertThat(fastModel).isNotSameAs(smartModel); + + assertThat(fastModel.chat("What is 2+2?")).contains("4"); + assertThat(smartModel.chat("What is 3+3?")).contains("6"); + }); + } + + @Test + void should_provide_default_and_named_chat_models_together() { + contextRunner + .withPropertyValues( + "langchain4j.open-ai.chat-model.base-url=" + BASE_URL, + "langchain4j.open-ai.chat-model.api-key=" + API_KEY, + "langchain4j.open-ai.chat-model.model-name=gpt-4o-mini", + "langchain4j.open-ai.chat-model.max-tokens=20", + "langchain4j.open-ai.chat-model.mini.model-name=gpt-4o-mini", + "langchain4j.open-ai.chat-model.mini.max-tokens=20" + ) + .run(context -> { + OpenAiChatModel defaultModel = context.getBean("openAiChatModel", OpenAiChatModel.class); + assertThat(defaultModel).isNotNull(); + + OpenAiChatModel miniModel = context.getBean("openAiChatModelMini", OpenAiChatModel.class); + assertThat(miniModel).isNotNull(); + + assertThat(defaultModel).isNotSameAs(miniModel); + + assertThat(defaultModel.chat("What is 1+1?")).contains("2"); + assertThat(miniModel.chat("What is 2+2?")).contains("4"); + }); + } + + @Test + void should_provide_named_streaming_chat_models() { + contextRunner + .withPropertyValues( + "langchain4j.open-ai.streaming-chat-model.api-key=" + API_KEY, + "langchain4j.open-ai.streaming-chat-model.fast.model-name=gpt-4o-mini", + "langchain4j.open-ai.streaming-chat-model.fast.max-tokens=20" + ) + .run(context -> { + OpenAiStreamingChatModel fastModel = context.getBean("openAiStreamingChatModelFast", OpenAiStreamingChatModel.class); + assertThat(fastModel).isNotNull(); + + CompletableFuture future = new CompletableFuture<>(); + fastModel.chat("What is 2+2?", new StreamingChatResponseHandler() { + @Override + public void onPartialResponse(String partialResponse) { + } + + @Override + public void onCompleteResponse(ChatResponse completeResponse) { + future.complete(completeResponse); + } + + @Override + public void onError(Throwable error) { + future.completeExceptionally(error); + } + }); + + ChatResponse response = future.get(30, SECONDS); + assertThat(response.aiMessage().text()).contains("4"); + }); + } + + @Test + void should_not_create_named_model_without_api_key() { + contextRunner + .withPropertyValues( + "langchain4j.open-ai.chat-model.nokey.model-name=gpt-4o-mini" + ) + .run(context -> { + assertThat(context.containsBean("openAiChatModelNokey")).isFalse(); + }); + } + + @Test + void should_inherit_global_api_key() { + contextRunner + .withPropertyValues( + "langchain4j.open-ai.chat-model.base-url=" + BASE_URL, + "langchain4j.open-ai.chat-model.api-key=" + API_KEY, + "langchain4j.open-ai.chat-model.inherited.model-name=gpt-4o-mini", + "langchain4j.open-ai.chat-model.inherited.max-tokens=20" + ) + .run(context -> { + OpenAiChatModel inheritedModel = context.getBean("openAiChatModelInherited", OpenAiChatModel.class); + assertThat(inheritedModel).isNotNull(); + assertThat(inheritedModel.chat("What is 5+5?")).contains("10"); + }); + } + + @Test + void should_provide_named_chat_model_with_listeners() { + contextRunner + .withPropertyValues( + "langchain4j.open-ai.chat-model.base-url=" + BASE_URL, + "langchain4j.open-ai.chat-model.api-key=" + API_KEY, + "langchain4j.open-ai.chat-model.fast.model-name=gpt-4o-mini", + "langchain4j.open-ai.chat-model.fast.max-tokens=20" + ) + .withUserConfiguration(ListenerConfig.class) + .run(context -> { + OpenAiChatModel fastModel = context.getBean("openAiChatModelFast", OpenAiChatModel.class); + + fastModel.chat("What is 2+2?"); + + ChatModelListener listener1 = context.getBean("listener1", ChatModelListener.class); + ChatModelListener listener2 = context.getBean("listener2", ChatModelListener.class); + verify(listener1).onRequest(any()); + verify(listener1).onResponse(any()); + verify(listener2).onRequest(any()); + verify(listener2).onResponse(any()); + }); + } + @Configuration static class ListenerConfig {