Skip to content

Commit 2b28957

Browse files
tilgalascopybara-github
authored andcommitted
fix: emit multiple LlmResponses in GeminiLlmConnection
A single LiveServerMessage is now converted to a series of LlmResponse messages each corresponding to a different part of the LiveServerMessage, notably the UsageMetadata field is now converted to a GenerateResponseUsageMetadata and emitted downstream. PiperOrigin-RevId: 866010045
1 parent 5607f64 commit 2b28957

File tree

3 files changed

+242
-55
lines changed

3 files changed

+242
-55
lines changed

core/src/main/java/com/google/adk/models/GeminiLlmConnection.java

Lines changed: 92 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import com.google.genai.types.Part;
3737
import io.reactivex.rxjava3.core.Completable;
3838
import io.reactivex.rxjava3.core.Flowable;
39+
import io.reactivex.rxjava3.core.Observable;
3940
import io.reactivex.rxjava3.processors.PublishProcessor;
4041
import java.net.SocketException;
4142
import java.util.List;
@@ -120,53 +121,103 @@ private void handleServerMessage(LiveServerMessage message) {
120121

121122
logger.debug("Received server message: {}", message.toJson());
122123

123-
Optional<LlmResponse> llmResponse = convertToServerResponse(message);
124-
llmResponse.ifPresent(responseProcessor::onNext);
124+
Observable<LlmResponse> llmResponse = convertToServerResponse(message);
125+
llmResponse.subscribe(responseProcessor::onNext, responseProcessor::onError);
125126
}
126127

127128
/** Converts a server message into the standardized LlmResponse format. */
128-
static Optional<LlmResponse> convertToServerResponse(LiveServerMessage message) {
129+
static Observable<LlmResponse> convertToServerResponse(LiveServerMessage message) {
130+
return Observable.create(
131+
emitter -> {
132+
// AtomicBoolean is used to modify state from within lambdas, which
133+
// require captured variables to be effectively final.
134+
final AtomicBoolean handled = new AtomicBoolean(false);
135+
message
136+
.serverContent()
137+
.ifPresent(
138+
serverContent -> {
139+
emitter.onNext(createServerContentResponse(serverContent));
140+
handled.set(true);
141+
});
142+
message
143+
.toolCall()
144+
.ifPresent(
145+
toolCall -> {
146+
emitter.onNext(createToolCallResponse(toolCall));
147+
handled.set(true);
148+
});
149+
message
150+
.usageMetadata()
151+
.ifPresent(
152+
usageMetadata -> {
153+
logger.debug("Received usage metadata: {}", usageMetadata);
154+
emitter.onNext(createUsageMetadataResponse(usageMetadata));
155+
handled.set(true);
156+
});
157+
message
158+
.toolCallCancellation()
159+
.ifPresent(
160+
toolCallCancellation -> {
161+
logger.debug("Received tool call cancellation: {}", toolCallCancellation);
162+
// TODO: implement proper CFC and thus tool call cancellation handling.
163+
handled.set(true);
164+
});
165+
message
166+
.setupComplete()
167+
.ifPresent(
168+
setupComplete -> {
169+
logger.debug("Received setup complete.");
170+
handled.set(true);
171+
});
172+
173+
if (!handled.get()) {
174+
logger.warn("Received unknown or empty server message: {}", message.toJson());
175+
emitter.onNext(createUnknownMessageResponse());
176+
}
177+
emitter.onComplete();
178+
});
179+
}
180+
181+
private static LlmResponse createServerContentResponse(LiveServerContent serverContent) {
129182
LlmResponse.Builder builder = LlmResponse.builder();
183+
serverContent.modelTurn().ifPresent(builder::content);
184+
builder
185+
.partial(serverContent.turnComplete().map(completed -> !completed).orElse(false))
186+
.turnComplete(serverContent.turnComplete().orElse(false))
187+
.interrupted(serverContent.interrupted());
188+
return builder.build();
189+
}
130190

131-
if (message.serverContent().isPresent()) {
132-
LiveServerContent serverContent = message.serverContent().get();
133-
serverContent.modelTurn().ifPresent(builder::content);
134-
builder
135-
.partial(serverContent.turnComplete().map(completed -> !completed).orElse(false))
136-
.turnComplete(serverContent.turnComplete().orElse(false))
137-
.interrupted(serverContent.interrupted());
138-
} else if (message.toolCall().isPresent()) {
139-
LiveServerToolCall toolCall = message.toolCall().get();
140-
toolCall
141-
.functionCalls()
142-
.ifPresent(
143-
calls -> {
144-
for (FunctionCall call : calls) {
145-
builder.content(
146-
Content.builder()
147-
.parts(ImmutableList.of(Part.builder().functionCall(call).build()))
148-
.build());
149-
}
150-
});
151-
builder.partial(false).turnComplete(false);
152-
} else if (message.usageMetadata().isPresent()) {
153-
logger.debug("Received usage metadata: {}", message.usageMetadata().get());
154-
return Optional.empty();
155-
} else if (message.toolCallCancellation().isPresent()) {
156-
logger.debug("Received tool call cancellation: {}", message.toolCallCancellation().get());
157-
// TODO: implement proper CFC and thus tool call cancellation handling.
158-
return Optional.empty();
159-
} else if (message.setupComplete().isPresent()) {
160-
logger.debug("Received setup complete.");
161-
return Optional.empty();
162-
} else {
163-
logger.warn("Received unknown or empty server message: {}", message.toJson());
164-
builder
165-
.errorCode(new FinishReason("Unknown server message."))
166-
.errorMessage("Received unknown server message.");
167-
}
191+
private static LlmResponse createToolCallResponse(LiveServerToolCall toolCall) {
192+
LlmResponse.Builder builder = LlmResponse.builder();
193+
toolCall
194+
.functionCalls()
195+
.ifPresent(
196+
calls -> {
197+
for (FunctionCall call : calls) {
198+
builder.content(
199+
Content.builder()
200+
.parts(ImmutableList.of(Part.builder().functionCall(call).build()))
201+
.build());
202+
}
203+
});
204+
builder.partial(false).turnComplete(false);
205+
return builder.build();
206+
}
207+
208+
private static LlmResponse createUsageMetadataResponse(
209+
com.google.genai.types.UsageMetadata usageMetadata) {
210+
LlmResponse.Builder builder = LlmResponse.builder();
211+
builder.usageMetadata(GeminiUtil.toGenerateContentResponseUsageMetadata(usageMetadata));
212+
return builder.build();
213+
}
168214

169-
return Optional.of(builder.build());
215+
private static LlmResponse createUnknownMessageResponse() {
216+
LlmResponse.Builder builder = LlmResponse.builder();
217+
builder
218+
.errorCode(new FinishReason("Unknown server message."))
219+
.errorMessage("Received unknown server message.");
220+
return builder.build();
170221
}
171222

172223
/** Handles errors that occur *during* the initial connection attempt. */

core/src/main/java/com/google/adk/models/GeminiUtil.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
import com.google.genai.types.Blob;
2525
import com.google.genai.types.Content;
2626
import com.google.genai.types.FileData;
27+
import com.google.genai.types.GenerateContentResponseUsageMetadata;
2728
import com.google.genai.types.Part;
29+
import com.google.genai.types.UsageMetadata;
2830
import java.util.List;
2931
import java.util.Optional;
3032
import java.util.stream.Stream;
@@ -224,4 +226,22 @@ public static List<Content> stripThoughts(List<Content> originalContents) {
224226
})
225227
.collect(toImmutableList());
226228
}
229+
230+
public static GenerateContentResponseUsageMetadata toGenerateContentResponseUsageMetadata(
231+
UsageMetadata usageMetadata) {
232+
GenerateContentResponseUsageMetadata.Builder builder =
233+
GenerateContentResponseUsageMetadata.builder();
234+
usageMetadata.promptTokenCount().ifPresent(builder::promptTokenCount);
235+
usageMetadata.cachedContentTokenCount().ifPresent(builder::cachedContentTokenCount);
236+
usageMetadata.responseTokenCount().ifPresent(builder::candidatesTokenCount);
237+
usageMetadata.toolUsePromptTokenCount().ifPresent(builder::toolUsePromptTokenCount);
238+
usageMetadata.thoughtsTokenCount().ifPresent(builder::thoughtsTokenCount);
239+
usageMetadata.totalTokenCount().ifPresent(builder::totalTokenCount);
240+
usageMetadata.promptTokensDetails().ifPresent(builder::promptTokensDetails);
241+
usageMetadata.cacheTokensDetails().ifPresent(builder::cacheTokensDetails);
242+
usageMetadata.responseTokensDetails().ifPresent(builder::candidatesTokensDetails);
243+
usageMetadata.toolUsePromptTokensDetails().ifPresent(builder::toolUsePromptTokensDetails);
244+
usageMetadata.trafficType().ifPresent(builder::trafficType);
245+
return builder.build();
246+
}
227247
}

0 commit comments

Comments
 (0)