Skip to content

Commit ca19f6f

Browse files
M-ElsaeedMohammed Ehab
andauthored
Allow ClientContext.Custom unmarshaling for non-string (JSON) values (#620)
* Allow ClientContext.Custom unmarshaling for non-string (JSON) values * Lint * Adding Tests * More tests * empty commit to retrigger flakey tests --------- Co-authored-by: Mohammed Ehab <[email protected]>
1 parent 9c32960 commit ca19f6f

File tree

3 files changed

+105
-0
lines changed

3 files changed

+105
-0
lines changed

lambda/invoke_loop_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,46 @@ func TestContextDeserializationErrors(t *testing.T) {
389389
}`, string(record.responses[2]))
390390
}
391391

392+
func TestClientContextWithNestedCustomValues(t *testing.T) {
393+
metadata := defaultInvokeMetadata()
394+
metadata.clientContext = `{
395+
"Client": {
396+
"app_title": "test",
397+
"installation_id": "install1",
398+
"app_version_code": "1.0",
399+
"app_package_name": "com.test"
400+
},
401+
"custom": {
402+
"bedrockAgentCoreTargetId": "target-123",
403+
"bedrockAgentCorePropagatedHeaders": {"x-id": "my-custom-id"}
404+
}
405+
}`
406+
407+
ts, record := runtimeAPIServer(`{}`, 1, metadata)
408+
defer ts.Close()
409+
handler := NewHandler(func(ctx context.Context) (interface{}, error) {
410+
lc, _ := lambdacontext.FromContext(ctx)
411+
return lc.ClientContext, nil
412+
})
413+
endpoint := strings.Split(ts.URL, "://")[1]
414+
_ = startRuntimeAPILoop(endpoint, handler)
415+
416+
expected := `{
417+
"Client": {
418+
"installation_id": "install1",
419+
"app_title": "test",
420+
"app_version_code": "1.0",
421+
"app_package_name": "com.test"
422+
},
423+
"env": null,
424+
"custom": {
425+
"bedrockAgentCoreTargetId": "target-123",
426+
"bedrockAgentCorePropagatedHeaders": "{\"x-id\": \"my-custom-id\"}"
427+
}
428+
}`
429+
assert.JSONEq(t, expected, string(record.responses[0]))
430+
}
431+
392432
type invalidPayload struct{}
393433

394434
func (invalidPayload) MarshalJSON() ([]byte, error) {

lambdacontext/context.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ package lambdacontext
1111

1212
import (
1313
"context"
14+
"encoding/json"
1415
"os"
1516
"strconv"
1617
)
@@ -68,6 +69,35 @@ type ClientContext struct {
6869
Custom map[string]string `json:"custom"`
6970
}
7071

72+
// UnmarshalJSON implements custom JSON unmarshaling for ClientContext.
73+
// This handles the case where values in the "custom" map are not strings
74+
// (e.g. nested JSON objects), by serializing non-string values back to
75+
// their JSON string representation.
76+
func (cc *ClientContext) UnmarshalJSON(data []byte) error {
77+
var raw struct {
78+
Client ClientApplication `json:"Client"`
79+
Env map[string]string `json:"env"`
80+
Custom map[string]json.RawMessage `json:"custom"`
81+
}
82+
if err := json.Unmarshal(data, &raw); err != nil {
83+
return err
84+
}
85+
cc.Client = raw.Client
86+
cc.Env = raw.Env
87+
if raw.Custom != nil {
88+
cc.Custom = make(map[string]string, len(raw.Custom))
89+
for k, v := range raw.Custom {
90+
var s string
91+
if err := json.Unmarshal(v, &s); err == nil {
92+
cc.Custom[k] = s
93+
} else {
94+
cc.Custom[k] = string(v)
95+
}
96+
}
97+
}
98+
return nil
99+
}
100+
71101
// CognitoIdentity is the cognito identity used by the calling application.
72102
type CognitoIdentity struct {
73103
CognitoIdentityID string

lambdacontext/context_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package lambdacontext
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestClientContextUnmarshalJSON(t *testing.T) {
12+
t.Run("non-string custom values are serialized to string", func(t *testing.T) {
13+
input := `{
14+
"Client": {"installation_id": "install1"},
15+
"custom": {
16+
"key1": "stringval",
17+
"key2": {"nested": "object"},
18+
"key3": 42
19+
}
20+
}`
21+
var cc ClientContext
22+
err := json.Unmarshal([]byte(input), &cc)
23+
require.NoError(t, err)
24+
assert.Equal(t, "install1", cc.Client.InstallationID)
25+
assert.Equal(t, "stringval", cc.Custom["key1"])
26+
assert.JSONEq(t, `{"nested":"object"}`, cc.Custom["key2"])
27+
assert.Equal(t, "42", cc.Custom["key3"])
28+
})
29+
30+
t.Run("invalid JSON returns error", func(t *testing.T) {
31+
var cc ClientContext
32+
err := json.Unmarshal([]byte(`not valid json`), &cc)
33+
assert.Error(t, err)
34+
})
35+
}

0 commit comments

Comments
 (0)