Skip to content

Commit ecffdbb

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor: refactors parts of the ADK codebase to improve null safety and consistency
This CL refactors parts of the ADK codebase to improve null safety and consistency. The main changes include: 1. **`BaseAgent`**: * `beforeAgentCallback` and `afterAgentCallback` fields and their accessors now use `ImmutableList` (defaulting to empty) instead of `Optional<List>`. * `findAgent` and `findSubAgent` now return `Optional<BaseAgent>`, with `findSubAgent` being reimplemented using Java Streams. 2. **`BaseAgentConfig`**: Getters for `subAgents`, `beforeAgentCallbacks`, and `afterAgentCallbacks` now return an empty list if the underlying field is null. 3. **`CallbackUtil`**: `getBeforeAgentCallbacks` and `getAfterAgentCallbacks` return `ImmutableList.of()` instead of `null` for null inputs. 4. **`LlmAgent`**: The `codeExecutor()` method now returns `Optional<BaseCodeExecutor>`. These changes necessitate updates in `BaseLlmFlow`, `CodeExecution`, and `Runner` to handle the new `Optional` return types. PiperOrigin-RevId: 863294916
1 parent efe58d6 commit ecffdbb

File tree

9 files changed

+143
-131
lines changed

9 files changed

+143
-131
lines changed

core/src/main/java/com/google/adk/agents/BaseAgent.java

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ public abstract class BaseAgent {
5757
*/
5858
private BaseAgent parentAgent;
5959

60-
private final List<? extends BaseAgent> subAgents;
60+
private final ImmutableList<? extends BaseAgent> subAgents;
6161

62-
private final Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback;
63-
private final Optional<List<? extends AfterAgentCallback>> afterAgentCallback;
62+
private final ImmutableList<? extends BeforeAgentCallback> beforeAgentCallback;
63+
private final ImmutableList<? extends AfterAgentCallback> afterAgentCallback;
6464

6565
/**
6666
* Creates a new BaseAgent.
@@ -82,9 +82,13 @@ public BaseAgent(
8282
this.name = name;
8383
this.description = description;
8484
this.parentAgent = null;
85-
this.subAgents = subAgents != null ? subAgents : ImmutableList.of();
86-
this.beforeAgentCallback = Optional.ofNullable(beforeAgentCallback);
87-
this.afterAgentCallback = Optional.ofNullable(afterAgentCallback);
85+
this.subAgents = subAgents == null ? ImmutableList.of() : ImmutableList.copyOf(subAgents);
86+
this.beforeAgentCallback =
87+
beforeAgentCallback == null
88+
? ImmutableList.of()
89+
: ImmutableList.copyOf(beforeAgentCallback);
90+
this.afterAgentCallback =
91+
afterAgentCallback == null ? ImmutableList.of() : ImmutableList.copyOf(afterAgentCallback);
8892

8993
// Establish parent relationships for all sub-agents if needed.
9094
for (BaseAgent subAgent : this.subAgents) {
@@ -144,38 +148,38 @@ public BaseAgent rootAgent() {
144148
/**
145149
* Finds an agent (this or descendant) by name.
146150
*
147-
* @return the agent or descendant with the given name, or {@code null} if not found.
151+
* @return an {@link Optional} containing the agent or descendant with the given name, or {@link
152+
* Optional#empty()} if not found.
148153
*/
149-
public BaseAgent findAgent(String name) {
154+
public Optional<BaseAgent> findAgent(String name) {
150155
if (this.name().equals(name)) {
151-
return this;
156+
return Optional.of(this);
152157
}
153158
return findSubAgent(name);
154159
}
155160

156-
/** Recursively search sub agent by name. */
157-
public @Nullable BaseAgent findSubAgent(String name) {
158-
for (BaseAgent subAgent : subAgents) {
159-
if (subAgent.name().equals(name)) {
160-
return subAgent;
161-
}
162-
BaseAgent result = subAgent.findSubAgent(name);
163-
if (result != null) {
164-
return result;
165-
}
166-
}
167-
return null;
161+
/**
162+
* Recursively search sub agent by name.
163+
*
164+
* @return an {@link Optional} containing the sub agent with the given name, or {@link
165+
* Optional#empty()} if not found.
166+
*/
167+
public Optional<BaseAgent> findSubAgent(String name) {
168+
return subAgents.stream()
169+
.map(subAgent -> subAgent.findAgent(name))
170+
.flatMap(Optional::stream)
171+
.findFirst();
168172
}
169173

170174
public List<? extends BaseAgent> subAgents() {
171175
return subAgents;
172176
}
173177

174-
public Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback() {
178+
public ImmutableList<? extends BeforeAgentCallback> beforeAgentCallback() {
175179
return beforeAgentCallback;
176180
}
177181

178-
public Optional<List<? extends AfterAgentCallback>> afterAgentCallback() {
182+
public ImmutableList<? extends AfterAgentCallback> afterAgentCallback() {
179183
return afterAgentCallback;
180184
}
181185

@@ -184,17 +188,17 @@ public Optional<List<? extends AfterAgentCallback>> afterAgentCallback() {
184188
*
185189
* <p>This method is only for use by Agent Development Kit.
186190
*/
187-
public List<? extends BeforeAgentCallback> canonicalBeforeAgentCallbacks() {
188-
return beforeAgentCallback.orElse(ImmutableList.of());
191+
public ImmutableList<? extends BeforeAgentCallback> canonicalBeforeAgentCallbacks() {
192+
return beforeAgentCallback;
189193
}
190194

191195
/**
192196
* The resolved afterAgentCallback field as a list.
193197
*
194198
* <p>This method is only for use by Agent Development Kit.
195199
*/
196-
public List<? extends AfterAgentCallback> canonicalAfterAgentCallbacks() {
197-
return afterAgentCallback.orElse(ImmutableList.of());
200+
public ImmutableList<? extends AfterAgentCallback> canonicalAfterAgentCallbacks() {
201+
return afterAgentCallback;
198202
}
199203

200204
/**
@@ -239,8 +243,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
239243
() ->
240244
callCallback(
241245
beforeCallbacksToFunctions(
242-
invocationContext.pluginManager(),
243-
beforeAgentCallback.orElse(ImmutableList.of())),
246+
invocationContext.pluginManager(), beforeAgentCallback),
244247
invocationContext)
245248
.flatMapPublisher(
246249
beforeEventOpt -> {
@@ -257,7 +260,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
257260
callCallback(
258261
afterCallbacksToFunctions(
259262
invocationContext.pluginManager(),
260-
afterAgentCallback.orElse(ImmutableList.of())),
263+
afterAgentCallback),
261264
invocationContext)
262265
.flatMapPublisher(Flowable::fromOptional));
263266

core/src/main/java/com/google/adk/agents/BaseAgentConfig.java

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.adk.agents;
1818

19+
import com.google.common.collect.ImmutableList;
1920
import java.util.List;
2021

2122
/**
@@ -27,11 +28,11 @@ public class BaseAgentConfig {
2728
private String name;
2829
private String description = "";
2930
private String agentClass;
30-
private List<AgentRefConfig> subAgents;
31+
private ImmutableList<AgentRefConfig> subAgents = ImmutableList.of();
3132

3233
// Callback configuration (names resolved via ComponentRegistry)
33-
private List<CallbackRef> beforeAgentCallbacks;
34-
private List<CallbackRef> afterAgentCallbacks;
34+
private ImmutableList<CallbackRef> beforeAgentCallbacks = ImmutableList.of();
35+
private ImmutableList<CallbackRef> afterAgentCallbacks = ImmutableList.of();
3536

3637
/** Reference to a callback stored in the ComponentRegistry. */
3738
public static class CallbackRef {
@@ -131,27 +132,33 @@ public String agentClass() {
131132
return agentClass;
132133
}
133134

134-
public List<AgentRefConfig> subAgents() {
135+
public ImmutableList<AgentRefConfig> subAgents() {
135136
return subAgents;
136137
}
137138

138139
public void setSubAgents(List<AgentRefConfig> subAgents) {
139-
this.subAgents = subAgents;
140+
this.subAgents = subAgents == null ? ImmutableList.of() : ImmutableList.copyOf(subAgents);
140141
}
141142

142-
public List<CallbackRef> beforeAgentCallbacks() {
143+
public ImmutableList<CallbackRef> beforeAgentCallbacks() {
143144
return beforeAgentCallbacks;
144145
}
145146

146147
public void setBeforeAgentCallbacks(List<CallbackRef> beforeAgentCallbacks) {
147-
this.beforeAgentCallbacks = beforeAgentCallbacks;
148+
this.beforeAgentCallbacks =
149+
beforeAgentCallbacks == null
150+
? ImmutableList.of()
151+
: ImmutableList.copyOf(beforeAgentCallbacks);
148152
}
149153

150-
public List<CallbackRef> afterAgentCallbacks() {
154+
public ImmutableList<CallbackRef> afterAgentCallbacks() {
151155
return afterAgentCallbacks;
152156
}
153157

154158
public void setAfterAgentCallbacks(List<CallbackRef> afterAgentCallbacks) {
155-
this.afterAgentCallbacks = afterAgentCallbacks;
159+
this.afterAgentCallbacks =
160+
afterAgentCallbacks == null
161+
? ImmutableList.of()
162+
: ImmutableList.copyOf(afterAgentCallbacks);
156163
}
157164
}

core/src/main/java/com/google/adk/agents/CallbackUtil.java

Lines changed: 45 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
import com.google.errorprone.annotations.CanIgnoreReturnValue;
2727
import io.reactivex.rxjava3.core.Maybe;
2828
import java.util.List;
29-
import org.jspecify.annotations.Nullable;
29+
import java.util.function.Function;
30+
import java.util.stream.Stream;
3031
import org.slf4j.Logger;
3132
import org.slf4j.LoggerFactory;
3233

@@ -37,65 +38,62 @@ public final class CallbackUtil {
3738
/**
3839
* Normalizes before-agent callbacks.
3940
*
40-
* @param beforeAgentCallback Callback list (sync or async).
41-
* @return normalized async callbacks, or null if input is null.
41+
* @param beforeAgentCallbacks Callback list (sync or async).
42+
* @return normalized async callbacks, or empty list if input is null.
4243
*/
4344
@CanIgnoreReturnValue
44-
public static @Nullable ImmutableList<BeforeAgentCallback> getBeforeAgentCallbacks(
45-
List<BeforeAgentCallbackBase> beforeAgentCallback) {
46-
if (beforeAgentCallback == null) {
47-
return null;
48-
} else if (beforeAgentCallback.isEmpty()) {
49-
return ImmutableList.of();
50-
} else {
51-
ImmutableList.Builder<BeforeAgentCallback> builder = ImmutableList.builder();
52-
for (BeforeAgentCallbackBase callback : beforeAgentCallback) {
53-
if (callback instanceof BeforeAgentCallback beforeAgentCallbackInstance) {
54-
builder.add(beforeAgentCallbackInstance);
55-
} else if (callback instanceof BeforeAgentCallbackSync beforeAgentCallbackSyncInstance) {
56-
builder.add(
57-
(callbackContext) ->
58-
Maybe.fromOptional(beforeAgentCallbackSyncInstance.call(callbackContext)));
59-
} else {
60-
logger.warn(
61-
"Invalid beforeAgentCallback callback type: {}. Ignoring this callback.",
62-
callback.getClass().getName());
63-
}
64-
}
65-
return builder.build();
66-
}
45+
public static ImmutableList<BeforeAgentCallback> getBeforeAgentCallbacks(
46+
List<BeforeAgentCallbackBase> beforeAgentCallbacks) {
47+
return getCallbacks(
48+
beforeAgentCallbacks,
49+
BeforeAgentCallback.class,
50+
BeforeAgentCallbackSync.class,
51+
sync -> (callbackContext -> Maybe.fromOptional(sync.call(callbackContext))),
52+
"beforeAgentCallbacks");
6753
}
6854

6955
/**
7056
* Normalizes after-agent callbacks.
7157
*
7258
* @param afterAgentCallback Callback list (sync or async).
73-
* @return normalized async callbacks, or null if input is null.
59+
* @return normalized async callbacks, or empty list if input is null.
7460
*/
7561
@CanIgnoreReturnValue
76-
public static @Nullable ImmutableList<AfterAgentCallback> getAfterAgentCallbacks(
62+
public static ImmutableList<AfterAgentCallback> getAfterAgentCallbacks(
7763
List<AfterAgentCallbackBase> afterAgentCallback) {
78-
if (afterAgentCallback == null) {
79-
return null;
80-
} else if (afterAgentCallback.isEmpty()) {
64+
return getCallbacks(
65+
afterAgentCallback,
66+
AfterAgentCallback.class,
67+
AfterAgentCallbackSync.class,
68+
sync -> (callbackContext -> Maybe.fromOptional(sync.call(callbackContext))),
69+
"afterAgentCallback");
70+
}
71+
72+
private static <B, A extends B, S extends B> ImmutableList<A> getCallbacks(
73+
List<B> callbacks,
74+
Class<A> asyncClass,
75+
Class<S> syncClass,
76+
Function<S, A> converter,
77+
String callbackTypeForLogging) {
78+
if (callbacks == null) {
8179
return ImmutableList.of();
82-
} else {
83-
ImmutableList.Builder<AfterAgentCallback> builder = ImmutableList.builder();
84-
for (AfterAgentCallbackBase callback : afterAgentCallback) {
85-
if (callback instanceof AfterAgentCallback afterAgentCallbackInstance) {
86-
builder.add(afterAgentCallbackInstance);
87-
} else if (callback instanceof AfterAgentCallbackSync afterAgentCallbackSyncInstance) {
88-
builder.add(
89-
(callbackContext) ->
90-
Maybe.fromOptional(afterAgentCallbackSyncInstance.call(callbackContext)));
91-
} else {
92-
logger.warn(
93-
"Invalid afterAgentCallback callback type: {}. Ignoring this callback.",
94-
callback.getClass().getName());
95-
}
96-
}
97-
return builder.build();
9880
}
81+
return callbacks.stream()
82+
.flatMap(
83+
callback -> {
84+
if (asyncClass.isInstance(callback)) {
85+
return Stream.of(asyncClass.cast(callback));
86+
} else if (syncClass.isInstance(callback)) {
87+
return Stream.of(converter.apply(syncClass.cast(callback)));
88+
} else {
89+
logger.warn(
90+
"Invalid {} callback type: {}. Ignoring this callback.",
91+
callbackTypeForLogging,
92+
callback.getClass().getName());
93+
return Stream.empty();
94+
}
95+
})
96+
.collect(ImmutableList.toImmutableList());
9997
}
10098

10199
private CallbackUtil() {}

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -935,9 +935,8 @@ public Optional<String> outputKey() {
935935
return outputKey;
936936
}
937937

938-
@Nullable
939-
public BaseCodeExecutor codeExecutor() {
940-
return codeExecutor.orElse(null);
938+
public Optional<BaseCodeExecutor> codeExecutor() {
939+
return codeExecutor;
941940
}
942941

943942
public Model resolvedModel() {

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -388,15 +388,15 @@ private Flowable<Event> runOneStep(InvocationContext context) {
388388
String agentToTransfer = event.actions().transferToAgent().get();
389389
logger.debug("Transferring to agent: {}", agentToTransfer);
390390
BaseAgent rootAgent = context.agent().rootAgent();
391-
BaseAgent nextAgent = rootAgent.findAgent(agentToTransfer);
392-
if (nextAgent == null) {
391+
Optional<BaseAgent> nextAgent = rootAgent.findAgent(agentToTransfer);
392+
if (nextAgent.isEmpty()) {
393393
String errorMsg = "Agent not found for transfer: " + agentToTransfer;
394394
logger.error(errorMsg);
395395
return postProcessedEvents.concatWith(
396396
Flowable.error(new IllegalStateException(errorMsg)));
397397
}
398398
return postProcessedEvents.concatWith(
399-
Flowable.defer(() -> nextAgent.runAsync(context)));
399+
Flowable.defer(() -> nextAgent.get().runAsync(context)));
400400
}
401401
return postProcessedEvents;
402402
});
@@ -574,14 +574,14 @@ public void onError(Throwable e) {
574574
Flowable<Event> events = Flowable.just(event);
575575
if (event.actions().transferToAgent().isPresent()) {
576576
BaseAgent rootAgent = invocationContext.agent().rootAgent();
577-
BaseAgent nextAgent =
577+
Optional<BaseAgent> nextAgent =
578578
rootAgent.findAgent(event.actions().transferToAgent().get());
579-
if (nextAgent == null) {
579+
if (nextAgent.isEmpty()) {
580580
throw new IllegalStateException(
581581
"Agent not found: " + event.actions().transferToAgent().get());
582582
}
583583
Flowable<Event> nextAgentEvents =
584-
nextAgent.runLive(invocationContext);
584+
nextAgent.get().runLive(invocationContext);
585585
events = Flowable.concat(events, nextAgentEvents);
586586
}
587587
return events;

0 commit comments

Comments
 (0)