Skip to content

Commit a35a5e9

Browse files
authored
fix: Ensure agent wakes up on A2A task completion (#384)
Enabled A2A monitoring in config and fixed agent event channel integration. Agent now properly receives wake-up events when A2A tasks complete. Fixed task state handling with case-insensitive comparisons and prevented premature agent exit. ## Changes Made: 1. **Enabled A2A monitoring** in `.infer/config.yaml` 2. **Fixed import ordering** in multiple files for consistency 3. **Added agent event channel integration** to wake up the agent when A2A tasks complete 4. **Fixed A2A task state handling** with case-insensitive state comparisons 5. **Prevented agent from exiting** while A2A tasks are pending in idle state ## Key Benefits: - A2A monitoring is now enabled by default - Agent stays alive while A2A tasks are pending - Proper event notification when tasks complete - Robust state handling for various task status formats - Improved reliability of A2A task execution
1 parent f7d2257 commit a35a5e9

File tree

7 files changed

+71
-42
lines changed

7 files changed

+71
-42
lines changed

.infer/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ chat:
651651
cost: true
652652
git_branch: true
653653
a2a:
654-
enabled: false
654+
enabled: true
655655
cache:
656656
enabled: true
657657
ttl: 300

cmd/agent_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@ import (
55
"testing"
66
"time"
77

8+
sdk "github.com/inference-gateway/sdk"
9+
10+
domainmocks "github.com/inference-gateway/cli/tests/mocks/domain"
11+
812
config "github.com/inference-gateway/cli/config"
913
domain "github.com/inference-gateway/cli/internal/domain"
10-
domainmocks "github.com/inference-gateway/cli/tests/mocks/domain"
11-
sdk "github.com/inference-gateway/sdk"
1214
)
1315

1416
func TestIsModelAvailable(t *testing.T) {

internal/agent/agent.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@ import (
77
"sync"
88
"time"
99

10+
sdk "github.com/inference-gateway/sdk"
11+
1012
constants "github.com/inference-gateway/cli/internal/constants"
1113
domain "github.com/inference-gateway/cli/internal/domain"
1214
logger "github.com/inference-gateway/cli/internal/logger"
1315
services "github.com/inference-gateway/cli/internal/services"
14-
sdk "github.com/inference-gateway/sdk"
1516
)
1617

1718
// AgentServiceImpl implements the AgentService interface with direct chat functionality
@@ -457,6 +458,10 @@ func (s *AgentServiceImpl) RunWithStream(ctx context.Context, req *domain.AgentR
457458
taskTracker,
458459
)
459460

461+
if monitor != nil {
462+
monitor.SetAgentEventChannel(agent.GetEventChannel())
463+
}
464+
460465
agent.Start()
461466
agent.Wait()
462467
}()

internal/agent/agent_event_driven.go

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -175,16 +175,11 @@ func (a *EventDrivenAgent) registerHandler(handler domain.StateHandler) {
175175

176176
// Start begins the event-driven agent execution
177177
func (a *EventDrivenAgent) Start() {
178-
logger.Debug("starting event-driven agent",
179-
"request_id", a.req.RequestID,
180-
"max_turns", a.agentCtx.MaxTurns)
181-
182178
_ = a.stateMachine.Transition(a.agentCtx, domain.StateIdle)
183179

184180
a.wg.Add(1)
185181
go a.processEvents()
186182

187-
logger.Debug("triggering initial message received event")
188183
a.events <- domain.MessageReceivedEvent{}
189184
}
190185

@@ -194,42 +189,48 @@ func (a *EventDrivenAgent) Wait() {
194189
close(a.events)
195190
}
196191

192+
// GetEventChannel returns the agent's internal event channel for external components
193+
// to send wake-up events (e.g., when A2A tasks complete)
194+
func (a *EventDrivenAgent) GetEventChannel() chan<- domain.AgentEvent {
195+
return a.events
196+
}
197+
197198
// processEvents is the main event processing loop
198199
func (a *EventDrivenAgent) processEvents() {
199200
defer a.wg.Done()
200-
defer logger.Debug("agent event processing stopped", "request_id", a.req.RequestID)
201201

202202
for {
203203
select {
204204
case <-a.cancelChan:
205-
logger.Debug("agent cancelled", "request_id", a.req.RequestID)
206205
_ = a.stateMachine.Transition(a.agentCtx, domain.StateCancelled)
207206
a.eventPublisher.publishChatComplete("", []sdk.ChatCompletionMessageToolCall{}, a.service.GetMetrics(a.req.RequestID))
208207
return
209208

210209
case event, ok := <-a.events:
211210
if !ok {
212-
logger.Debug("event channel closed")
213211
return
214212
}
215213

216-
logger.Debug("received event",
217-
"event_type", event.EventType(),
218-
"current_state", a.stateMachine.GetCurrentState(),
219-
"turn", a.agentCtx.Turns)
220-
221214
a.handleEvent(event)
222215

223216
currentState := a.stateMachine.GetCurrentState()
224-
if currentState == domain.StateIdle ||
225-
currentState == domain.StateStopped ||
217+
if currentState == domain.StateStopped ||
226218
currentState == domain.StateCancelled ||
227219
currentState == domain.StateError {
228-
logger.Debug("agent reached terminal state",
229-
"state", currentState,
230-
"total_turns", a.agentCtx.Turns)
231220
return
232221
}
222+
223+
if currentState == domain.StateIdle {
224+
hasPendingTasks := a.taskTracker != nil && len(a.taskTracker.GetAllPollingTasks()) > 0
225+
if hasPendingTasks {
226+
logger.Debug("agent in Idle state but has pending A2A tasks, staying alive",
227+
"pending_tasks", len(a.taskTracker.GetAllPollingTasks()))
228+
} else {
229+
logger.Debug("agent reached Idle state with no pending tasks",
230+
"total_turns", a.agentCtx.Turns)
231+
return
232+
}
233+
}
233234
}
234235
}
235236
}
@@ -240,11 +241,6 @@ func (a *EventDrivenAgent) handleEvent(event domain.AgentEvent) {
240241
defer a.mu.Unlock()
241242

242243
currentState := a.stateMachine.GetCurrentState()
243-
logger.Debug("handling event in state",
244-
"event", event.EventType(),
245-
"state", currentState,
246-
"turn", a.agentCtx.Turns,
247-
"queue_empty", a.agentCtx.MessageQueue.IsEmpty())
248244

249245
handler, exists := a.stateHandlers[currentState]
250246
if !exists {
@@ -257,8 +253,4 @@ func (a *EventDrivenAgent) handleEvent(event domain.AgentEvent) {
257253
"state", currentState.String(),
258254
"error", err)
259255
}
260-
261-
logger.Debug("event handled",
262-
"event", event.EventType(),
263-
"new_state", a.stateMachine.GetCurrentState())
264256
}

internal/agent/tools/a2a_task.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@ import (
77
"strings"
88
"time"
99

10-
client "github.com/inference-gateway/adk/client"
1110
adk "github.com/inference-gateway/adk/types"
11+
sdk "github.com/inference-gateway/sdk"
12+
13+
client "github.com/inference-gateway/adk/client"
1214
config "github.com/inference-gateway/cli/config"
1315
domain "github.com/inference-gateway/cli/internal/domain"
14-
sdk "github.com/inference-gateway/sdk"
1516
)
1617

1718
// A2ASubmitTaskTool handles A2A task submission and management
@@ -286,7 +287,6 @@ func (t *A2ASubmitTaskTool) pollTaskInBackground(
286287
adkClient := t.getOrCreateClient(agentURL)
287288

288289
strategy := t.config.A2A.Task.PollingStrategy
289-
290290
currentInterval := t.initializePollingStrategy(agentURL, taskID, strategy)
291291
state.CurrentInterval = currentInterval
292292
state.NextPollTime = time.Now().Add(currentInterval)
@@ -427,8 +427,10 @@ func (t *A2ASubmitTaskTool) publishStatusUpdate(state *domain.TaskPollingState,
427427
}
428428

429429
func (t *A2ASubmitTaskTool) handleTaskState(agentURL, _ /* taskID */ string, _ /* pollAttempt */ int, state *domain.TaskPollingState, currentTask adk.Task, _ /* pollingDetails */ string) (bool, *domain.ToolExecutionResult) {
430-
switch currentTask.Status.State {
431-
case adk.TaskStateCompleted:
430+
normalizedState := strings.ToLower(string(currentTask.Status.State))
431+
432+
switch {
433+
case normalizedState == strings.ToLower(string(adk.TaskStateCompleted)) || normalizedState == "completed":
432434
finalResult := ""
433435
if currentTask.Status.Message != nil {
434436
finalResult = t.extractTextFromParts(currentTask.Status.Message.Parts)
@@ -451,7 +453,7 @@ func (t *A2ASubmitTaskTool) handleTaskState(agentURL, _ /* taskID */ string, _ /
451453
}
452454
return true, result
453455

454-
case adk.TaskStateFailed:
456+
case normalizedState == strings.ToLower(string(adk.TaskStateFailed)) || normalizedState == "failed":
455457
finalResult := ""
456458
if currentTask.Status.Message != nil {
457459
finalResult = t.extractTextFromParts(currentTask.Status.Message.Parts)
@@ -473,7 +475,7 @@ func (t *A2ASubmitTaskTool) handleTaskState(agentURL, _ /* taskID */ string, _ /
473475
}
474476
return true, result
475477

476-
case adk.TaskStateInputRequired:
478+
case normalizedState == strings.ToLower(string(adk.TaskStateInputRequired)) || normalizedState == "input-required" || normalizedState == "input_required":
477479
inputMessage := ""
478480
if currentTask.Status.Message != nil {
479481
inputMessage = t.extractTextFromParts(currentTask.Status.Message.Parts)
@@ -495,7 +497,7 @@ func (t *A2ASubmitTaskTool) handleTaskState(agentURL, _ /* taskID */ string, _ /
495497
}
496498
return true, result
497499

498-
case adk.TaskStateCancelled:
500+
case normalizedState == strings.ToLower(string(adk.TaskStateCancelled)) || normalizedState == "cancelled" || normalizedState == "canceled":
499501
cancelMessage := ""
500502
if currentTask.Status.Message != nil {
501503
cancelMessage = t.extractTextFromParts(currentTask.Status.Message.Parts)

internal/handlers/chat_handler.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@ import (
1313
"time"
1414

1515
tea "github.com/charmbracelet/bubbletea"
16+
17+
sdk "github.com/inference-gateway/sdk"
18+
1619
config "github.com/inference-gateway/cli/config"
1720
tools "github.com/inference-gateway/cli/internal/agent/tools"
1821
domain "github.com/inference-gateway/cli/internal/domain"
1922
logger "github.com/inference-gateway/cli/internal/logger"
2023
services "github.com/inference-gateway/cli/internal/services"
2124
shortcuts "github.com/inference-gateway/cli/internal/shortcuts"
2225
utils "github.com/inference-gateway/cli/internal/utils"
23-
sdk "github.com/inference-gateway/sdk"
2426
)
2527

2628
type ChatHandler struct {
@@ -1583,6 +1585,8 @@ func (h *ChatHandler) handleA2ATaskStatusUpdate(
15831585
func (h *ChatHandler) handleMessageQueued(
15841586
_ domain.MessageQueuedEvent,
15851587
) (tea.Model, tea.Cmd) {
1588+
chatSession := h.stateManager.GetChatSession()
1589+
15861590
var cmds []tea.Cmd
15871591

15881592
cmds = append(cmds, func() tea.Msg {
@@ -1600,7 +1604,6 @@ func (h *ChatHandler) handleMessageQueued(
16001604
}
16011605
})
16021606

1603-
chatSession := h.stateManager.GetChatSession()
16041607
if chatSession != nil && chatSession.EventChannel != nil {
16051608
cmds = append(cmds, h.ListenForChatEvents(chatSession.EventChannel))
16061609
}

internal/services/a2a_polling_monitor.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@ import (
77
"time"
88

99
adk "github.com/inference-gateway/adk/types"
10+
sdk "github.com/inference-gateway/sdk"
11+
1012
domain "github.com/inference-gateway/cli/internal/domain"
1113
logger "github.com/inference-gateway/cli/internal/logger"
12-
sdk "github.com/inference-gateway/sdk"
1314
)
1415

1516
type A2APollingMonitor struct {
@@ -22,6 +23,7 @@ type A2APollingMonitor struct {
2223
activeMonitors map[string]context.CancelFunc
2324
stopChan chan struct{}
2425
stopped bool
26+
agentEventChan chan<- domain.AgentEvent
2527
}
2628

2729
func NewA2APollingMonitor(
@@ -43,6 +45,14 @@ func NewA2APollingMonitor(
4345
}
4446
}
4547

48+
// SetAgentEventChannel sets the agent's internal event channel for waking up the agent
49+
// when an A2A task completes. This should be called after the agent is created.
50+
func (m *A2APollingMonitor) SetAgentEventChannel(eventChan chan<- domain.AgentEvent) {
51+
m.mu.Lock()
52+
defer m.mu.Unlock()
53+
m.agentEventChan = eventChan
54+
}
55+
4656
func (m *A2APollingMonitor) Start(ctx context.Context) {
4757
ticker := time.NewTicker(1 * time.Second)
4858
defer ticker.Stop()
@@ -305,6 +315,21 @@ func (m *A2APollingMonitor) addResultToMessageQueue(taskID string, result *domai
305315

306316
m.messageQueue.Enqueue(message, m.requestID)
307317

318+
m.mu.RLock()
319+
agentChan := m.agentEventChan
320+
m.mu.RUnlock()
321+
322+
if agentChan != nil {
323+
select {
324+
case agentChan <- domain.MessageReceivedEvent{}:
325+
logger.Debug("Sent wake-up event to agent for A2A task completion",
326+
"task_id", taskID)
327+
default:
328+
logger.Debug("Failed to send wake-up event to agent - channel full",
329+
"task_id", taskID)
330+
}
331+
}
332+
308333
if m.eventChan != nil {
309334
event := domain.MessageQueuedEvent{
310335
RequestID: m.requestID,

0 commit comments

Comments
 (0)