Skip to content

Commit cc22ce8

Browse files
committed
feat: Add new fields to conversation api
Signed-off-by: Javier Aliaga <javier@diagrid.io>
1 parent f6c5400 commit cc22ce8

File tree

9 files changed

+200
-10
lines changed

9 files changed

+200
-10
lines changed

sdk/src/main/java/io/dapr/client/DaprClientImpl.java

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import io.dapr.client.domain.ConversationResponseAlpha2;
4343
import io.dapr.client.domain.ConversationResultAlpha2;
4444
import io.dapr.client.domain.ConversationResultChoices;
45+
import io.dapr.client.domain.ConversationResultCompletionUsage;
4546
import io.dapr.client.domain.ConversationResultMessage;
4647
import io.dapr.client.domain.ConversationToolCalls;
4748
import io.dapr.client.domain.ConversationToolCallsOfFunction;
@@ -1793,6 +1794,7 @@ public Mono<ConversationResponseAlpha2> converseAlpha2(ConversationRequestAlpha2
17931794
DaprAiProtos.ConversationResponseAlpha2 conversationResponse = conversationResponseMono.block();
17941795

17951796
assert conversationResponse != null;
1797+
17961798
List<ConversationResultAlpha2> results = buildConversationResults(conversationResponse.getOutputsList());
17971799
return Mono.just(new ConversationResponseAlpha2(conversationResponse.getContextId(), results));
17981800
} catch (Exception ex) {
@@ -1857,6 +1859,33 @@ private DaprAiProtos.ConversationRequestAlpha2 buildConversationRequestProto(Con
18571859

18581860
builder.addInputs(inputBuilder.build());
18591861
}
1862+
1863+
if (request.getResponseFormat() != null) {
1864+
Map<String, Value> responseParams = request.getResponseFormat()
1865+
.entrySet().stream()
1866+
.collect(Collectors.toMap(
1867+
Map.Entry::getKey,
1868+
e -> {
1869+
try {
1870+
return ProtobufValueHelper.toProtobufValue(e.getValue());
1871+
} catch (IOException ex) {
1872+
throw new RuntimeException(ex);
1873+
}
1874+
}
1875+
));
1876+
1877+
builder.setResponseFormat(Struct.newBuilder().putAllFields(responseParams).build());
1878+
}
1879+
1880+
if (request.getPromptCacheRetention() != null) {
1881+
Duration javaDuration = request.getPromptCacheRetention();
1882+
builder.setPromptCacheRetention(
1883+
com.google.protobuf.Duration.newBuilder()
1884+
.setSeconds(javaDuration.getSeconds())
1885+
.setNanos(javaDuration.getNano())
1886+
.build()
1887+
);
1888+
}
18601889

18611890
return builder.build();
18621891
}
@@ -1974,9 +2003,16 @@ private List<ConversationResultAlpha2> buildConversationResults(
19742003
for (DaprAiProtos.ConversationResultChoices protoChoice : protoResult.getChoicesList()) {
19752004
ConversationResultMessage message = buildConversationResultMessage(protoChoice);
19762005
choices.add(new ConversationResultChoices(protoChoice.getFinishReason(), protoChoice.getIndex(), message));
1977-
}
2006+
}
19782007

1979-
results.add(new ConversationResultAlpha2(choices));
2008+
results.add(new ConversationResultAlpha2(
2009+
choices,
2010+
protoResult.getModel(),
2011+
new ConversationResultCompletionUsage(
2012+
protoResult.getUsage().getCompletionTokens(),
2013+
protoResult.getUsage().getPromptTokens(),
2014+
protoResult.getUsage().getTotalTokens()))
2015+
);
19802016
}
19812017

19822018
return results;

sdk/src/main/java/io/dapr/client/DaprPreviewClient.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,15 @@
1717
import io.dapr.client.domain.BulkPublishRequest;
1818
import io.dapr.client.domain.BulkPublishResponse;
1919
import io.dapr.client.domain.BulkPublishResponseFailedEntry;
20-
import io.dapr.client.domain.CloudEvent;
2120
import io.dapr.client.domain.ConversationRequest;
2221
import io.dapr.client.domain.ConversationRequestAlpha2;
2322
import io.dapr.client.domain.ConversationResponse;
2423
import io.dapr.client.domain.ConversationResponseAlpha2;
2524
import io.dapr.client.domain.DecryptRequestAlpha1;
26-
import io.dapr.client.domain.DeleteJobRequest;
2725
import io.dapr.client.domain.EncryptRequestAlpha1;
28-
import io.dapr.client.domain.GetJobRequest;
29-
import io.dapr.client.domain.GetJobResponse;
3026
import io.dapr.client.domain.LockRequest;
3127
import io.dapr.client.domain.QueryStateRequest;
3228
import io.dapr.client.domain.QueryStateResponse;
33-
import io.dapr.client.domain.ScheduleJobRequest;
3429
import io.dapr.client.domain.UnlockRequest;
3530
import io.dapr.client.domain.UnlockResponseStatus;
3631
import io.dapr.client.domain.query.Query;

sdk/src/main/java/io/dapr/client/domain/ConversationOutput.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
package io.dapr.client.domain;
1515

16-
import java.util.Collections;
1716
import java.util.Map;
1817

1918
/**

sdk/src/main/java/io/dapr/client/domain/ConversationRequestAlpha2.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
package io.dapr.client.domain;
1515

16+
import java.time.Duration;
1617
import java.util.List;
1718
import java.util.Map;
1819

@@ -31,6 +32,8 @@ public class ConversationRequestAlpha2 {
3132
private String toolChoice;
3233
private Map<String, Object> parameters;
3334
private Map<String, String> metadata;
35+
private Map<String, Object> responseFormat;
36+
private Duration promptCacheRetention;
3437

3538
/**
3639
* Constructs a ConversationRequestAlpha2 with a component name and conversation inputs.
@@ -206,4 +209,22 @@ public ConversationRequestAlpha2 setMetadata(Map<String, String> metadata) {
206209
this.metadata = metadata;
207210
return this;
208211
}
212+
213+
public Map<String, Object> getResponseFormat() {
214+
return responseFormat;
215+
}
216+
217+
public ConversationRequestAlpha2 setResponseFormat(Map<String, Object> responseFormat) {
218+
this.responseFormat = responseFormat;
219+
return this;
220+
}
221+
222+
public Duration getPromptCacheRetention() {
223+
return promptCacheRetention;
224+
}
225+
226+
public ConversationRequestAlpha2 setPromptCacheRetention(Duration promptCacheRetention) {
227+
this.promptCacheRetention = promptCacheRetention;
228+
return this;
229+
}
209230
}

sdk/src/main/java/io/dapr/client/domain/ConversationResponse.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
package io.dapr.client.domain;
1515

16-
import java.util.Collections;
1716
import java.util.List;
1817

1918
/**

sdk/src/main/java/io/dapr/client/domain/ConversationResultAlpha2.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,22 @@
2121
public class ConversationResultAlpha2 {
2222

2323
private final List<ConversationResultChoices> choices;
24+
private final String model;
25+
private final ConversationResultCompletionUsage usage;
2426

2527
/**
2628
* Constructor.
2729
*
2830
* @param choices the list of conversation result choices.
31+
* @param model the model used for the conversation.
32+
* @param usage the usage of the model.
2933
*/
30-
public ConversationResultAlpha2(List<ConversationResultChoices> choices) {
34+
public ConversationResultAlpha2(List<ConversationResultChoices> choices,
35+
String model,
36+
ConversationResultCompletionUsage usage) {
3137
this.choices = List.copyOf(choices);
38+
this.model = model;
39+
this.usage = usage;
3240
}
3341

3442
/**
@@ -39,4 +47,22 @@ public ConversationResultAlpha2(List<ConversationResultChoices> choices) {
3947
public List<ConversationResultChoices> getChoices() {
4048
return choices;
4149
}
50+
51+
/**
52+
* Gets the model used for the conversation.
53+
*
54+
* @return the model used for the conversation.
55+
*/
56+
public String getModel() {
57+
return model;
58+
}
59+
60+
/**
61+
* Gets the usage of the model.
62+
*
63+
* @return the usage of the model.
64+
*/
65+
public ConversationResultCompletionUsage getUsage() {
66+
return usage;
67+
}
4268
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright 2026 The Dapr Authors
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
* Unless required by applicable law or agreed to in writing, software
8+
* distributed under the License is distributed on an "AS IS" BASIS,
9+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
* See the License for the specific language governing permissions and
11+
limitations under the License.
12+
*/
13+
14+
package io.dapr.client.domain;
15+
16+
public class ConversationResultCompletionUsage {
17+
private final long completionTokens;
18+
private final long promptTokens;
19+
private final long totalTokens;
20+
21+
/**
22+
* Constructor.
23+
*
24+
* @param completionTokens completion tokens used.
25+
* @param promptTokens prompt tokens used.
26+
* @param totalTokens total tokens used.
27+
*/
28+
public ConversationResultCompletionUsage(long completionTokens, long promptTokens, long totalTokens) {
29+
this.completionTokens = completionTokens;
30+
this.promptTokens = promptTokens;
31+
this.totalTokens = totalTokens;
32+
}
33+
34+
/**
35+
* Completion tokens used.
36+
*
37+
* @return completion tokens used.
38+
*/
39+
public long getCompletionTokens() {
40+
return completionTokens;
41+
}
42+
43+
/**
44+
* Prompt tokens used.
45+
*
46+
* @return prompt tokens used.
47+
*/
48+
public long getPromptTokens() {
49+
return promptTokens;
50+
}
51+
52+
/**
53+
* Total tokens used.
54+
*
55+
* @return total tokens used.
56+
*/
57+
public long getTotalTokens() {
58+
return totalTokens;
59+
}
60+
}

sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import com.fasterxml.jackson.core.JsonProcessingException;
1818
import com.fasterxml.jackson.databind.ObjectMapper;
1919
import com.google.protobuf.ByteString;
20+
import com.google.protobuf.Struct;
21+
import com.google.protobuf.Value;
2022
import io.dapr.client.domain.AssistantMessage;
2123
import io.dapr.client.domain.BulkPublishEntry;
2224
import io.dapr.client.domain.BulkPublishRequest;
@@ -43,6 +45,7 @@
4345
import io.dapr.client.domain.QueryStateRequest;
4446
import io.dapr.client.domain.QueryStateResponse;
4547
import io.dapr.client.domain.SystemMessage;
48+
import io.dapr.client.domain.TestData;
4649
import io.dapr.client.domain.ToolMessage;
4750
import io.dapr.client.domain.UnlockResponseStatus;
4851
import io.dapr.client.domain.UserMessage;
@@ -75,6 +78,7 @@
7578

7679
import java.io.IOException;
7780
import java.nio.charset.StandardCharsets;
81+
import java.time.Duration;
7882
import java.util.ArrayList;
7983
import java.util.Collections;
8084
import java.util.HashMap;
@@ -1061,6 +1065,9 @@ public void converseAlpha2ComplexRequestTest() {
10611065
Map<String, Object> parameters = new HashMap<>();
10621066
parameters.put("max_tokens", "1000");
10631067

1068+
var responseFormat = new HashMap<String, Object>();
1069+
responseFormat.put("temperature", 0.7);
1070+
responseFormat.put("data", new TestData("Peter", 40));
10641071
ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", List.of(input));
10651072
request.setContextId("test-context");
10661073
request.setTemperature(0.7);
@@ -1069,11 +1076,19 @@ public void converseAlpha2ComplexRequestTest() {
10691076
request.setToolChoice("auto");
10701077
request.setMetadata(metadata);
10711078
request.setParameters(parameters);
1079+
request.setPromptCacheRetention(Duration.ofDays(1));
1080+
request.setResponseFormat(responseFormat);
10721081

10731082
// Mock response with tool calls
10741083
DaprAiProtos.ConversationResponseAlpha2 grpcResponse = DaprAiProtos.ConversationResponseAlpha2.newBuilder()
10751084
.setContextId("test-context")
10761085
.addOutputs(DaprAiProtos.ConversationResultAlpha2.newBuilder()
1086+
.setModel("gpt-3.5-turbo")
1087+
.setUsage(DaprAiProtos.ConversationResultAlpha2CompletionUsage.newBuilder()
1088+
.setPromptTokens(100)
1089+
.setCompletionTokens(100)
1090+
.setTotalTokens(200)
1091+
.build())
10771092
.addChoices(DaprAiProtos.ConversationResultChoices.newBuilder()
10781093
.setFinishReason("tool_calls")
10791094
.setIndex(0)
@@ -1108,6 +1123,11 @@ public void converseAlpha2ComplexRequestTest() {
11081123
assertEquals("tool_calls", choice.getFinishReason());
11091124
assertEquals("I'll help you get the weather information.", choice.getMessage().getContent());
11101125
assertEquals(1, choice.getMessage().getToolCalls().size());
1126+
assertEquals("gpt-3.5-turbo", response.getOutputs().get(0).getModel());
1127+
assertEquals(100, response.getOutputs().get(0).getUsage().getCompletionTokens());
1128+
assertEquals(100, response.getOutputs().get(0).getUsage().getPromptTokens());
1129+
assertEquals(200, response.getOutputs().get(0).getUsage().getTotalTokens());
1130+
11111131

11121132
ConversationToolCalls toolCall = choice.getMessage().getToolCalls().get(0);
11131133
assertEquals("call_123", toolCall.getId());
@@ -1128,6 +1148,13 @@ public void converseAlpha2ComplexRequestTest() {
11281148
assertEquals("value1", capturedRequest.getMetadataMap().get("key1"));
11291149
assertEquals(1, capturedRequest.getToolsCount());
11301150
assertEquals("get_weather", capturedRequest.getTools(0).getFunction().getName());
1151+
assertEquals(Struct.newBuilder()
1152+
.putFields("temperature", Value.newBuilder().setNumberValue(0.7).build())
1153+
.putFields("data", Value.newBuilder().setStringValue("TestData{name='Peter', age=40}").build())
1154+
.build(),
1155+
capturedRequest.getResponseFormat());
1156+
assertEquals(Duration.ofDays(1).getSeconds(), capturedRequest.getPromptCacheRetention().getSeconds());
1157+
assertEquals(0, capturedRequest.getPromptCacheRetention().getNanos());
11311158
}
11321159

11331160
@Test
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package io.dapr.client.domain;
2+
3+
public class TestData {
4+
private final String name;
5+
private final int age;
6+
7+
public TestData(String name, int age) {
8+
this.name = name;
9+
this.age = age;
10+
}
11+
12+
public int getAge() {
13+
return age;
14+
}
15+
16+
public String getName() {
17+
return name;
18+
}
19+
20+
@Override
21+
public String toString() {
22+
return "TestData{" +
23+
"name='" + name + "'" +
24+
", age=" + age +
25+
"}";
26+
}
27+
}

0 commit comments

Comments
 (0)