Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 154 additions & 5 deletions agentendpoint/agentendpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ type agentEndpointServiceTestServer struct {
streamClose chan struct{}
streamSend chan struct{}
permissionError chan struct{}
resourceExhaustedError chan struct{}
taskStart bool
execTaskProgress bool
patchTaskProgress bool
Expand All @@ -142,25 +143,42 @@ type agentEndpointServiceTestServer struct {
patchTaskComplete bool
applyConfigTaskComplete bool
runTaskIDs []string
registerAgentReq *agentendpointpb.RegisterAgentRequest
}

func newAgentEndpointServiceTestServer() *agentEndpointServiceTestServer {
return &agentEndpointServiceTestServer{
streamClose: make(chan struct{}, 1),
streamSend: make(chan struct{}, 1),
permissionError: make(chan struct{}, 1),
streamClose: make(chan struct{}, 1),
streamSend: make(chan struct{}, 1),
permissionError: make(chan struct{}, 1),
resourceExhaustedError: make(chan struct{}, 1),
}
}

// causePermissionError triggers a PermissionDenied error in ReceiveTaskNotification.
func (s *agentEndpointServiceTestServer) causePermissionError() {
s.permissionError <- struct{}{}
}

// causeResourceExhaustedError triggers a ResourceExhausted error in ReceiveTaskNotification.
func (s *agentEndpointServiceTestServer) causeResourceExhaustedError() {
s.resourceExhaustedError <- struct{}{}
}

// ReceiveTaskNotification sends task notifications or errors based on the state of the test server's channels.
func (s *agentEndpointServiceTestServer) ReceiveTaskNotification(req *agentendpointpb.ReceiveTaskNotificationRequest, srv agentendpointpb.AgentEndpointService_ReceiveTaskNotificationServer) error {
for {
select {
case <-s.streamClose:
return nil
case <-srv.Context().Done():
return srv.Context().Err()
case <-s.streamSend:
srv.Send(&agentendpointpb.ReceiveTaskNotificationResponse{})
case <-s.permissionError:
return status.Errorf(codes.PermissionDenied, "")
case <-s.resourceExhaustedError:
return status.Errorf(codes.ResourceExhausted, "")
}
}
}
Expand Down Expand Up @@ -217,8 +235,10 @@ func (s *agentEndpointServiceTestServer) ReportTaskComplete(ctx context.Context,
return &agentendpointpb.ReportTaskCompleteResponse{}, nil
}

func (*agentEndpointServiceTestServer) RegisterAgent(ctx context.Context, req *agentendpointpb.RegisterAgentRequest) (*agentendpointpb.RegisterAgentResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method RegisterAgent not implemented")
// RegisterAgent is a mock implementation of the RegisterAgent RPC.
func (s *agentEndpointServiceTestServer) RegisterAgent(ctx context.Context, req *agentendpointpb.RegisterAgentRequest) (*agentendpointpb.RegisterAgentResponse, error) {
s.registerAgentReq = req
return &agentendpointpb.RegisterAgentResponse{}, nil
}

func (*agentEndpointServiceTestServer) ReportInventory(ctx context.Context, req *agentendpointpb.ReportInventoryRequest) (*agentendpointpb.ReportInventoryResponse, error) {
Expand Down Expand Up @@ -285,6 +305,13 @@ func TestWaitForTaskErrors(t *testing.T) {
if err := tc.client.waitForTask(ctx); err != nil {
t.Errorf("did not expect error from a closed stream: %v", err)
}

// Error from receiveTaskNotification
ctx, cancel := context.WithCancel(ctx)
cancel()
if err := tc.client.waitForTask(ctx); !errors.Is(err, context.Canceled) {
t.Errorf("did not get expected context.Canceled, got: %v", err)
}
}

func TestLoadPatchTaskFromState(t *testing.T) {
Expand Down Expand Up @@ -353,3 +380,125 @@ func TestLoadPatchTaskFromState(t *testing.T) {
t.Errorf("first entry in runTaskIDs does not match taskID, %q, %q", srv.runTaskIDs, taskID)
}
}

// TestClose verifies that the Client can be closed multiple times without panicking.
func TestClose(t *testing.T) {
ctx := context.Background()
tc, err := newTestClient(ctx, newAgentEndpointServiceTestServer())
if err != nil {
t.Fatalf("newTestClient error: %v", err)
}
defer tc.s.Stop()

if tc.client.Closed() {
t.Errorf("Closed() = true, want false")
}
tc.client.Close()
}

// TestWaitForTaskNotification verifies that WaitForTaskNotification correctly handles various scenarios.
// It tests successful notification, service disablement, multiple calls, and context cancellation.
func TestWaitForTaskNotification(t *testing.T) {
ctx := context.Background()

waitUntil := func(t *testing.T, condition func() bool, timeout time.Duration) bool {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
for {
if condition() {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(10 * time.Millisecond):
}
}
}

tests := []struct {
name string
setup func(srv *agentEndpointServiceTestServer, cancel context.CancelFunc)
check func(t *testing.T, tc *testClient)
}{
{
name: "service disabled error handling",
setup: func(srv *agentEndpointServiceTestServer, cancel context.CancelFunc) {
srv.causePermissionError()
},
check: func(t *testing.T, tc *testClient) {
if !waitUntil(t, tc.client.Closed, 400*time.Millisecond) {
t.Error("Expected client to be closed after service disabled error")
}
},
},
{
name: "multiple notification calls",
setup: func(srv *agentEndpointServiceTestServer, cancel context.CancelFunc) {
},
check: func(t *testing.T, tc *testClient) {
tc.client.WaitForTaskNotification(context.Background())
if waitUntil(t, tc.client.Closed, 50*time.Millisecond) {
t.Error("Expected client to remain open")
}
},
},
{
name: "context cancellation handling",
setup: func(srv *agentEndpointServiceTestServer, cancel context.CancelFunc) {
cancel()
},
check: func(t *testing.T, tc *testClient) {
if waitUntil(t, tc.client.Closed, 50*time.Millisecond) {
t.Error("Expected client to remain open after context cancellation")
}
},
},
{
name: "resource exhausted error handling",
setup: func(srv *agentEndpointServiceTestServer, cancel context.CancelFunc) {
srv.causeResourceExhaustedError()
},
check: func(t *testing.T, tc *testClient) {
if waitUntil(t, tc.client.Closed, 100*time.Millisecond) {
t.Error("Expected client to remain open after resource exhausted error")
}
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tCtx, tCancel := context.WithCancel(ctx)
defer tCancel()
srv := newAgentEndpointServiceTestServer()
tc, err := newTestClient(tCtx, srv)
if err != nil {
t.Fatalf("%s: newTestClient error: %v", tt.name, err)
}
defer tc.s.Stop()

tt.setup(srv, tCancel)
tc.client.WaitForTaskNotification(tCtx)

if tt.check != nil {
tt.check(t, tc)
}
})
}
}

// TestRegisterAgent verifies that RegisterAgent correctly sends a registration request to the server.
func TestRegisterAgent(t *testing.T) {
ctx := context.Background()
tc, err := newTestClient(ctx, newAgentEndpointServiceTestServer())
if err != nil {
t.Fatalf("newTestClient error: %v", err)
}
defer tc.s.Stop()

if err := tc.client.RegisterAgent(ctx); err != nil {
t.Errorf("RegisterAgent() error: %v", err)
}
}
Loading