diff --git a/protocol/triple/triple_invoker.go b/protocol/triple/triple_invoker.go index 53faaf0674..64ea5d0b32 100644 --- a/protocol/triple/triple_invoker.go +++ b/protocol/triple/triple_invoker.go @@ -21,6 +21,8 @@ import ( "context" "errors" "fmt" + "net/http" + "strings" "sync" ) @@ -161,20 +163,28 @@ func mergeAttachmentToOutgoing(ctx context.Context, inv base.Invocation) (contex if timeout, ok := inv.GetAttachment(constant.TimeoutKey); ok { ctx = context.WithValue(ctx, tri.TimeoutKey{}, timeout) } + header := cloneOutgoingHeader(tri.ExtractFromOutgoingContext(ctx)) for key, valRaw := range inv.Attachments() { + lowerKey := strings.ToLower(key) if str, ok := valRaw.(string); ok { - ctx = tri.AppendToOutgoingContext(ctx, key, str) + header[lowerKey] = []string{str} continue } if strs, ok := valRaw.([]string); ok { - for _, str := range strs { - ctx = tri.AppendToOutgoingContext(ctx, key, str) - } + header[lowerKey] = append([]string(nil), strs...) continue } return ctx, fmt.Errorf("triple attachments value with key = %s is invalid, which should be string or []string", key) } - return ctx, nil + return tri.NewOutgoingContext(ctx, header), nil +} + +func cloneOutgoingHeader(header http.Header) http.Header { + cloned := make(http.Header, len(header)) + for key, vals := range header { + cloned[strings.ToLower(key)] = append([]string(nil), vals...) + } + return cloned } // parseInvocation retrieves information from invocation. @@ -202,18 +212,8 @@ func parseInvocation(ctx context.Context, url *common.URL, invocation base.Invoc return callType, inRaw, method, nil } -// parseAttachments retrieves attachments from users passed-in and URL, then injects them into ctx +// parseAttachments injects pre-defined URL attachments into invocation. func parseAttachments(ctx context.Context, url *common.URL, invocation base.Invocation) { - // retrieve users passed-in attachment - attaRaw := ctx.Value(constant.AttachmentKey) - if attaRaw != nil { - if userAtta, ok := attaRaw.(map[string]any); ok { - for key, val := range userAtta { - invocation.SetAttachment(key, val) - } - } - } - // set pre-defined attachments for _, key := range triAttachmentKeys { if val := url.GetParam(key, ""); len(val) > 0 { invocation.SetAttachment(key, val) diff --git a/protocol/triple/triple_invoker_test.go b/protocol/triple/triple_invoker_test.go index 7d7a2e30e9..abcb7147ac 100644 --- a/protocol/triple/triple_invoker_test.go +++ b/protocol/triple/triple_invoker_test.go @@ -19,7 +19,9 @@ package triple import ( "context" + "fmt" "net/http" + "strings" "sync" "testing" ) @@ -27,13 +29,22 @@ import ( import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + oteltrace "go.opentelemetry.io/otel/trace" ) import ( "dubbo.apache.org/dubbo-go/v3/common" "dubbo.apache.org/dubbo-go/v3/common/constant" + "dubbo.apache.org/dubbo-go/v3/common/extension" + _ "dubbo.apache.org/dubbo-go/v3/filter/otel/trace" "dubbo.apache.org/dubbo-go/v3/protocol/base" "dubbo.apache.org/dubbo-go/v3/protocol/invocation" + "dubbo.apache.org/dubbo-go/v3/protocol/result" tri "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol" ) @@ -107,7 +118,7 @@ func Test_parseAttachments(t *testing.T) { ctx func() context.Context url *common.URL invo func() base.Invocation - expect func(t *testing.T, ctx context.Context, err error) + expect func(t *testing.T, inv base.Invocation) }{ { desc: "url has pre-defined keys in triAttachmentKeys", @@ -121,48 +132,28 @@ func Test_parseAttachments(t *testing.T) { invo: func() base.Invocation { return invocation.NewRPCInvocationWithOptions() }, - expect: func(t *testing.T, ctx context.Context, err error) { - require.NoError(t, err) - header := http.Header(tri.ExtractFromOutgoingContext(ctx)) - assert.NotNil(t, header) - assert.Equal(t, "interface", header.Get(constant.InterfaceKey)) - assert.Equal(t, "token", header.Get(constant.TokenKey)) + expect: func(t *testing.T, inv base.Invocation) { + assert.Equal(t, "interface", inv.GetAttachmentInterface(constant.InterfaceKey)) + assert.Equal(t, "token", inv.GetAttachmentInterface(constant.TokenKey)) }, }, { - desc: "user passed-in legal attachments", + desc: "ctx attachments are ignored", ctx: func() context.Context { userDefined := make(map[string]any) userDefined["key1"] = "val1" - userDefined["key2"] = []string{"key2_1", "key2_2"} - return context.WithValue(context.Background(), constant.AttachmentKey, userDefined) - }, - url: common.NewURLWithOptions(), - invo: func() base.Invocation { - return invocation.NewRPCInvocationWithOptions() - }, - expect: func(t *testing.T, ctx context.Context, err error) { - require.NoError(t, err) - header := http.Header(tri.ExtractFromOutgoingContext(ctx)) - assert.NotNil(t, header) - assert.Equal(t, "val1", header.Get("key1")) - assert.Equal(t, []string{"key2_1", "key2_2"}, header.Values("key2")) - }, - }, - { - desc: "user passed-in illegal attachments", - ctx: func() context.Context { - userDefined := make(map[string]any) - userDefined["key1"] = 1 + userDefined["traceparent"] = "old-trace" return context.WithValue(context.Background(), constant.AttachmentKey, userDefined) }, url: common.NewURLWithOptions(), invo: func() base.Invocation { return invocation.NewRPCInvocationWithOptions() }, - expect: func(t *testing.T, ctx context.Context, err error) { - require.Error(t, err) - assert.Contains(t, err.Error(), "invalid") + expect: func(t *testing.T, inv base.Invocation) { + _, ok := inv.GetAttachment("key1") + assert.False(t, ok) + _, ok = inv.GetAttachment("traceparent") + assert.False(t, ok) }, }, } @@ -172,8 +163,7 @@ func Test_parseAttachments(t *testing.T) { ctx := test.ctx() inv := test.invo() parseAttachments(ctx, test.url, inv) - ctx, err := mergeAttachmentToOutgoing(ctx, inv) - test.expect(t, ctx, err) + test.expect(t, inv) }) } } @@ -451,7 +441,7 @@ func Test_mergeAttachmentToOutgoing(t *testing.T) { expect: func(t *testing.T, ctx context.Context, err error) { require.NoError(t, err) header := http.Header(tri.ExtractFromOutgoingContext(ctx)) - assert.Equal(t, "custom-value", header.Get("custom-key")) + assert.Equal(t, []string{"custom-value"}, header["custom-key"]) }, }, { @@ -466,7 +456,40 @@ func Test_mergeAttachmentToOutgoing(t *testing.T) { expect: func(t *testing.T, ctx context.Context, err error) { require.NoError(t, err) header := http.Header(tri.ExtractFromOutgoingContext(ctx)) - assert.Equal(t, []string{"val1", "val2"}, header.Values("multi-key")) + assert.Equal(t, []string{"val1", "val2"}, header["multi-key"]) + }, + }, + { + desc: "preserves unrelated existing outgoing header", + ctx: tri.NewOutgoingContext(context.Background(), http.Header{ + "existing-header": []string{"existing-value"}, + }), + invo: func() base.Invocation { + return invocation.NewRPCInvocationWithOptions( + invocation.WithAttachment("custom-key", "custom-value"), + ) + }, + expect: func(t *testing.T, ctx context.Context, err error) { + require.NoError(t, err) + header := http.Header(tri.ExtractFromOutgoingContext(ctx)) + assert.Equal(t, []string{"existing-value"}, header["existing-header"]) + assert.Equal(t, []string{"custom-value"}, header["custom-key"]) + }, + }, + { + desc: "overwrites existing traceparent instead of appending", + ctx: tri.NewOutgoingContext(context.Background(), http.Header{ + "traceparent": []string{"old-traceparent"}, + }), + invo: func() base.Invocation { + return invocation.NewRPCInvocationWithOptions( + invocation.WithAttachment("traceparent", "new-traceparent"), + ) + }, + expect: func(t *testing.T, ctx context.Context, err error) { + require.NoError(t, err) + header := http.Header(tri.ExtractFromOutgoingContext(ctx)) + assert.Equal(t, []string{"new-traceparent"}, header["traceparent"]) }, }, { @@ -503,3 +526,217 @@ func Test_mergeAttachmentToOutgoing(t *testing.T) { }) } } + +func Test_mergeAttachmentToOutgoing_DoesNotMutatePreviousContext(t *testing.T) { + inv1 := invocation.NewRPCInvocationWithOptions( + invocation.WithAttachment("traceparent", "first-traceparent"), + ) + ctx1, err := mergeAttachmentToOutgoing(context.Background(), inv1) + require.NoError(t, err) + + inv2 := invocation.NewRPCInvocationWithOptions( + invocation.WithAttachment("traceparent", "second-traceparent"), + ) + ctx2, err := mergeAttachmentToOutgoing(ctx1, inv2) + require.NoError(t, err) + + header1 := http.Header(tri.ExtractFromOutgoingContext(ctx1)) + header2 := http.Header(tri.ExtractFromOutgoingContext(ctx2)) + + assert.Equal(t, []string{"first-traceparent"}, header1["traceparent"]) + assert.Equal(t, []string{"second-traceparent"}, header2["traceparent"]) +} + +type capturedTripleCall struct { + method string + activeSpanContext oteltrace.SpanContext + outgoingTraceparent string +} + +type traceCaptureInvoker struct { + base.BaseInvoker + calls []capturedTripleCall +} + +func newTraceCaptureInvoker(url *common.URL) *traceCaptureInvoker { + return &traceCaptureInvoker{ + BaseInvoker: *base.NewBaseInvoker(url), + } +} + +func (i *traceCaptureInvoker) Invoke(ctx context.Context, inv base.Invocation) result.Result { + _, _, _, err := parseInvocation(ctx, i.GetURL(), inv) + if err != nil { + return &result.RPCResult{Err: err} + } + + ctx, err = mergeAttachmentToOutgoing(ctx, inv) + if err != nil { + return &result.RPCResult{Err: err} + } + + header := cloneOutgoingHeader(tri.ExtractFromOutgoingContext(ctx)) + i.calls = append(i.calls, capturedTripleCall{ + method: inv.MethodName(), + activeSpanContext: oteltrace.SpanContextFromContext(ctx), + outgoingTraceparent: firstHeaderValue(header, "traceparent"), + }) + return &result.RPCResult{} +} + +func TestTripleClientOTELTraceparentIsolation(t *testing.T) { + spanRecorder := tracetest.NewSpanRecorder() + tracerProvider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(spanRecorder)) + defer func() { + _ = tracerProvider.Shutdown(context.Background()) + }() + + oldTracerProvider := otel.GetTracerProvider() + oldPropagator := otel.GetTextMapPropagator() + otel.SetTracerProvider(tracerProvider) + otel.SetTextMapPropagator(propagation.TraceContext{}) + defer func() { + otel.SetTracerProvider(oldTracerProvider) + otel.SetTextMapPropagator(oldPropagator) + }() + + clientFilter, ok := extension.GetFilter(constant.OTELClientTraceKey) + require.True(t, ok) + + incomingTraceparent := "00-4bf92f3577b34da6a3ce929d0e0e4736-1111111111111111-01" + upstreamCtx := propagation.TraceContext{}.Extract( + context.Background(), + propagation.MapCarrier{"traceparent": incomingTraceparent}, + ) + + serverTracer := tracerProvider.Tracer("triple-otel-repro") + serverCtx, serverSpan := serverTracer.Start(upstreamCtx, "service-a-handler", oteltrace.WithSpanKind(oteltrace.SpanKindServer)) + serverCtx = context.WithValue(serverCtx, constant.AttachmentKey, map[string]any{ + "traceparent": incomingTraceparent, + }) + serverCtx = tri.NewOutgoingContext(serverCtx, http.Header{ + "x-seed": []string{"seed"}, + }) + + invoker := newTraceCaptureInvoker(common.NewURLWithOptions( + common.WithProtocol("tri"), + common.WithInterface("org.apache.dubbo.test.DownstreamService"), + )) + + callB := newTraceInvocation("CallServiceB") + callC := newTraceInvocation("CallServiceC") + + resB := clientFilter.Invoke(serverCtx, invoker, callB) + require.NoError(t, resB.Error()) + afterBHeader := cloneOutgoingHeader(tri.ExtractFromOutgoingContext(serverCtx)) + + resC := clientFilter.Invoke(serverCtx, invoker, callC) + require.NoError(t, resC.Error()) + afterCHeader := cloneOutgoingHeader(tri.ExtractFromOutgoingContext(serverCtx)) + + serverSpan.End() + + require.Len(t, invoker.calls, 2) + endedSpans := spanRecorder.Ended() + serverReadOnly := findEndedSpanByName(t, endedSpans, "service-a-handler") + clientBReadOnly := findEndedSpanByName(t, endedSpans, "CallServiceB") + clientCReadOnly := findEndedSpanByName(t, endedSpans, "CallServiceC") + + clientBOutgoing := parseTraceparent(t, invoker.calls[0].outgoingTraceparent) + clientCOutgoing := parseTraceparent(t, invoker.calls[1].outgoingTraceparent) + + assert.Equal(t, incomingTraceparent, contextAttachmentTraceparent(serverCtx)) + assert.Equal(t, serverReadOnly.SpanContext().SpanID(), clientBReadOnly.Parent().SpanID()) + assert.Equal(t, serverReadOnly.SpanContext().SpanID(), clientCReadOnly.Parent().SpanID()) + assert.Equal(t, clientBReadOnly.SpanContext().SpanID().String(), clientBOutgoing.spanID) + assert.Equal(t, clientCReadOnly.SpanContext().SpanID().String(), clientCOutgoing.spanID) + assert.NotEqual(t, incomingTraceparent, invoker.calls[0].outgoingTraceparent) + assert.NotEqual(t, invoker.calls[0].outgoingTraceparent, invoker.calls[1].outgoingTraceparent) + assert.Equal(t, "seed", firstHeaderValue(afterBHeader, "x-seed")) + assert.Equal(t, "seed", firstHeaderValue(afterCHeader, "x-seed")) + assert.Empty(t, afterBHeader["traceparent"]) + assert.Empty(t, afterCHeader["traceparent"]) + + t.Logf("incoming traceparent = %s", incomingTraceparent) + t.Logf( + "service A server span: trace_id=%s span_id=%s parent_span_id=%s", + serverReadOnly.SpanContext().TraceID(), + serverReadOnly.SpanContext().SpanID(), + serverReadOnly.Parent().SpanID(), + ) + t.Logf( + "call B client span: trace_id=%s span_id=%s parent_span_id=%s outgoing_traceparent=%s", + clientBReadOnly.SpanContext().TraceID(), + clientBReadOnly.SpanContext().SpanID(), + clientBReadOnly.Parent().SpanID(), + invoker.calls[0].outgoingTraceparent, + ) + t.Logf( + "call C client span: trace_id=%s span_id=%s parent_span_id=%s outgoing_traceparent=%s", + clientCReadOnly.SpanContext().TraceID(), + clientCReadOnly.SpanContext().SpanID(), + clientCReadOnly.Parent().SpanID(), + invoker.calls[1].outgoingTraceparent, + ) + t.Logf("base ctx outgoing header after call B = %v", afterBHeader) + t.Logf("base ctx outgoing header after call C = %v", afterCHeader) +} + +func newTraceInvocation(method string) *invocation.RPCInvocation { + inv := invocation.NewRPCInvocationWithOptions( + invocation.WithMethodName(method), + ) + inv.SetAttribute(constant.CallTypeKey, constant.CallUnary) + return inv +} + +func findEndedSpanByName(t *testing.T, spans []sdktrace.ReadOnlySpan, name string) sdktrace.ReadOnlySpan { + t.Helper() + for _, span := range spans { + if span.Name() == name { + return span + } + } + t.Fatalf("span %q not found", name) + return nil +} + +type parsedTraceparent struct { + traceID string + spanID string + flags string +} + +func parseTraceparent(t *testing.T, traceparent string) parsedTraceparent { + t.Helper() + parts := strings.Split(traceparent, "-") + require.Len(t, parts, 4, "invalid traceparent: %s", traceparent) + return parsedTraceparent{ + traceID: parts[1], + spanID: parts[2], + flags: parts[3], + } +} + +func firstHeaderValue(header http.Header, key string) string { + values := header[strings.ToLower(key)] + if len(values) == 0 { + return "" + } + return values[0] +} + +func contextAttachmentTraceparent(ctx context.Context) string { + raw := ctx.Value(constant.AttachmentKey) + if raw == nil { + return "" + } + attachments, ok := raw.(map[string]any) + if !ok { + return fmt.Sprintf("%v", raw) + } + if value, ok := attachments["traceparent"].(string); ok { + return value + } + return "" +} diff --git a/protocol/triple/triple_protocol/header.go b/protocol/triple/triple_protocol/header.go index e46b0a93b0..b23a5c279d 100644 --- a/protocol/triple/triple_protocol/header.go +++ b/protocol/triple/triple_protocol/header.go @@ -119,17 +119,31 @@ func NewOutgoingContext(ctx context.Context, data http.Header) context.Context { var header = http.Header{} for key, vals := range data { - header[strings.ToLower(key)] = vals + header[strings.ToLower(key)] = append([]string(nil), vals...) } extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) if !ok { extraData = map[string]http.Header{} + } else { + extraData = cloneExtraData(extraData) } extraData[headerOutgoingKey] = header return context.WithValue(ctx, extraDataKey{}, extraData) } +func cloneExtraData(data map[string]http.Header) map[string]http.Header { + cloned := make(map[string]http.Header, len(data)) + for extraKey, header := range data { + headerClone := make(http.Header, len(header)) + for headerKey, vals := range header { + headerClone[headerKey] = append([]string(nil), vals...) + } + cloned[extraKey] = headerClone + } + return cloned +} + // AppendToOutgoingContext merges kv pairs from user and existing headers. // It is used for passing headers to server-side. // It is like grpc.AppendToOutgoingContext.