Skip to content

Commit 3b8b13b

Browse files
committed
feat: Add new fields to conversation api
Signed-off-by: Javier Aliaga <javier@diagrid.io>
1 parent f6c5400 commit 3b8b13b

File tree

7 files changed

+146
-10
lines changed

7 files changed

+146
-10
lines changed

sdk/src/main/java/io/dapr/client/DaprClientImpl.java

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import io.dapr.client.domain.ConversationResponseAlpha2;
4343
import io.dapr.client.domain.ConversationResultAlpha2;
4444
import io.dapr.client.domain.ConversationResultChoices;
45+
import io.dapr.client.domain.ConversationResultCompletionUsage;
4546
import io.dapr.client.domain.ConversationResultMessage;
4647
import io.dapr.client.domain.ConversationToolCalls;
4748
import io.dapr.client.domain.ConversationToolCallsOfFunction;
@@ -1793,6 +1794,7 @@ public Mono<ConversationResponseAlpha2> converseAlpha2(ConversationRequestAlpha2
17931794
DaprAiProtos.ConversationResponseAlpha2 conversationResponse = conversationResponseMono.block();
17941795

17951796
assert conversationResponse != null;
1797+
17961798
List<ConversationResultAlpha2> results = buildConversationResults(conversationResponse.getOutputsList());
17971799
return Mono.just(new ConversationResponseAlpha2(conversationResponse.getContextId(), results));
17981800
} catch (Exception ex) {
@@ -1857,6 +1859,33 @@ private DaprAiProtos.ConversationRequestAlpha2 buildConversationRequestProto(Con
18571859

18581860
builder.addInputs(inputBuilder.build());
18591861
}
1862+
1863+
if (request.getResponseFormat() != null) {
1864+
Map<String, Value> responseParams = request.getResponseFormat()
1865+
.entrySet().stream()
1866+
.collect(Collectors.toMap(
1867+
Map.Entry::getKey,
1868+
e -> {
1869+
try {
1870+
return ProtobufValueHelper.toProtobufValue(e.getValue());
1871+
} catch (IOException ex) {
1872+
throw new RuntimeException(ex);
1873+
}
1874+
}
1875+
));
1876+
1877+
builder.setResponseFormat(Struct.newBuilder().putAllFields(responseParams).build());
1878+
}
1879+
1880+
if (request.getPromptCacheRetention() != null) {
1881+
Duration javaDuration = request.getPromptCacheRetention();
1882+
builder.setPromptCacheRetention(
1883+
com.google.protobuf.Duration.newBuilder()
1884+
.setSeconds(javaDuration.getSeconds())
1885+
.setNanos(javaDuration.getNano())
1886+
.build()
1887+
);
1888+
}
18601889

18611890
return builder.build();
18621891
}
@@ -1974,9 +2003,16 @@ private List<ConversationResultAlpha2> buildConversationResults(
19742003
for (DaprAiProtos.ConversationResultChoices protoChoice : protoResult.getChoicesList()) {
19752004
ConversationResultMessage message = buildConversationResultMessage(protoChoice);
19762005
choices.add(new ConversationResultChoices(protoChoice.getFinishReason(), protoChoice.getIndex(), message));
1977-
}
2006+
}
19782007

1979-
results.add(new ConversationResultAlpha2(choices));
2008+
results.add(new ConversationResultAlpha2(
2009+
choices,
2010+
protoResult.getModel(),
2011+
new ConversationResultCompletionUsage(
2012+
protoResult.getUsage().getCompletionTokens(),
2013+
protoResult.getUsage().getPromptTokens(),
2014+
protoResult.getUsage().getTotalTokens()))
2015+
);
19802016
}
19812017

19822018
return results;

sdk/src/main/java/io/dapr/client/DaprPreviewClient.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,15 @@
1717
import io.dapr.client.domain.BulkPublishRequest;
1818
import io.dapr.client.domain.BulkPublishResponse;
1919
import io.dapr.client.domain.BulkPublishResponseFailedEntry;
20-
import io.dapr.client.domain.CloudEvent;
2120
import io.dapr.client.domain.ConversationRequest;
2221
import io.dapr.client.domain.ConversationRequestAlpha2;
2322
import io.dapr.client.domain.ConversationResponse;
2423
import io.dapr.client.domain.ConversationResponseAlpha2;
2524
import io.dapr.client.domain.DecryptRequestAlpha1;
26-
import io.dapr.client.domain.DeleteJobRequest;
2725
import io.dapr.client.domain.EncryptRequestAlpha1;
28-
import io.dapr.client.domain.GetJobRequest;
29-
import io.dapr.client.domain.GetJobResponse;
3026
import io.dapr.client.domain.LockRequest;
3127
import io.dapr.client.domain.QueryStateRequest;
3228
import io.dapr.client.domain.QueryStateResponse;
33-
import io.dapr.client.domain.ScheduleJobRequest;
3429
import io.dapr.client.domain.UnlockRequest;
3530
import io.dapr.client.domain.UnlockResponseStatus;
3631
import io.dapr.client.domain.query.Query;

sdk/src/main/java/io/dapr/client/domain/ConversationOutput.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
package io.dapr.client.domain;
1515

16-
import java.util.Collections;
1716
import java.util.Map;
1817

1918
/**

sdk/src/main/java/io/dapr/client/domain/ConversationRequestAlpha2.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
package io.dapr.client.domain;
1515

16+
import java.time.Duration;
1617
import java.util.List;
1718
import java.util.Map;
1819

@@ -31,6 +32,8 @@ public class ConversationRequestAlpha2 {
3132
private String toolChoice;
3233
private Map<String, Object> parameters;
3334
private Map<String, String> metadata;
35+
private Map<String, Object> responseFormat;
36+
private Duration promptCacheRetention;
3437

3538
/**
3639
* Constructs a ConversationRequestAlpha2 with a component name and conversation inputs.
@@ -206,4 +209,22 @@ public ConversationRequestAlpha2 setMetadata(Map<String, String> metadata) {
206209
this.metadata = metadata;
207210
return this;
208211
}
212+
213+
public Map<String, Object> getResponseFormat() {
214+
return responseFormat;
215+
}
216+
217+
public ConversationRequestAlpha2 setResponseFormat(Map<String, Object> responseFormat) {
218+
this.responseFormat = responseFormat;
219+
return this;
220+
}
221+
222+
public Duration getPromptCacheRetention() {
223+
return promptCacheRetention;
224+
}
225+
226+
public ConversationRequestAlpha2 setPromptCacheRetention(Duration promptCacheRetention) {
227+
this.promptCacheRetention = promptCacheRetention;
228+
return this;
229+
}
209230
}

sdk/src/main/java/io/dapr/client/domain/ConversationResponse.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
package io.dapr.client.domain;
1515

16-
import java.util.Collections;
1716
import java.util.List;
1817

1918
/**

sdk/src/main/java/io/dapr/client/domain/ConversationResultAlpha2.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,22 @@
2121
public class ConversationResultAlpha2 {
2222

2323
private final List<ConversationResultChoices> choices;
24+
private final String model;
25+
private final ConversationResultCompletionUsage usage;
2426

2527
/**
2628
* Constructor.
2729
*
2830
* @param choices the list of conversation result choices.
31+
* @param model the model used for the conversation.
32+
* @param usage the usage of the model.
2933
*/
30-
public ConversationResultAlpha2(List<ConversationResultChoices> choices) {
34+
public ConversationResultAlpha2(List<ConversationResultChoices> choices,
35+
String model,
36+
ConversationResultCompletionUsage usage) {
3137
this.choices = List.copyOf(choices);
38+
this.model = model;
39+
this.usage = usage;
3240
}
3341

3442
/**
@@ -39,4 +47,22 @@ public ConversationResultAlpha2(List<ConversationResultChoices> choices) {
3947
public List<ConversationResultChoices> getChoices() {
4048
return choices;
4149
}
50+
51+
/**
52+
* Gets the model used for the conversation.
53+
*
54+
* @return the model used for the conversation.
55+
*/
56+
public String getModel() {
57+
return model;
58+
}
59+
60+
/**
61+
* Gets the usage of the model.
62+
*
63+
* @return the usage of the model.
64+
*/
65+
public ConversationResultCompletionUsage getUsage() {
66+
return usage;
67+
}
4268
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright 2026 The Dapr Authors
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
* Unless required by applicable law or agreed to in writing, software
8+
* distributed under the License is distributed on an "AS IS" BASIS,
9+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
* See the License for the specific language governing permissions and
11+
limitations under the License.
12+
*/
13+
14+
package io.dapr.client.domain;
15+
16+
public class ConversationResultCompletionUsage {
17+
private final long completionTokens;
18+
private final long promptTokens;
19+
private final long totalTokens;
20+
21+
/**
22+
* Constructor.
23+
*
24+
* @param completionTokens completion tokens used.
25+
* @param promptTokens prompt tokens used.
26+
* @param totalTokens total tokens used.
27+
*/
28+
public ConversationResultCompletionUsage(long completionTokens, long promptTokens, long totalTokens) {
29+
this.completionTokens = completionTokens;
30+
this.promptTokens = promptTokens;
31+
this.totalTokens = totalTokens;
32+
}
33+
34+
/**
35+
* Completion tokens used.
36+
*
37+
* @return completion tokens used.
38+
*/
39+
public long getCompletionTokens() {
40+
return completionTokens;
41+
}
42+
43+
/**
44+
* Prompt tokens used.
45+
*
46+
* @return prompt tokens used.
47+
*/
48+
public long getPromptTokens() {
49+
return promptTokens;
50+
}
51+
52+
/**
53+
* Total tokens used.
54+
*
55+
* @return total tokens used.
56+
*/
57+
public long getTotalTokens() {
58+
return totalTokens;
59+
}
60+
}

0 commit comments

Comments
 (0)