diff --git a/common/constant/key.go b/common/constant/key.go index b6d32a88c3..e85c204771 100644 --- a/common/constant/key.go +++ b/common/constant/key.go @@ -160,6 +160,8 @@ const ( CallHTTP2 = "http2" CallHTTP3 = "http3" CallHTTP2AndHTTP3 = "http2-and-http3" + AltSvcProtocolH2 = "h2" + AltSvcProtocolH3 = "h3" ServiceInfoKey = "service-info" RpcServiceKey = "rpc-service" ClientInfoKey = "client-info" diff --git a/protocol/triple/client.go b/protocol/triple/client.go index b235ee4c00..061c711033 100644 --- a/protocol/triple/client.go +++ b/protocol/triple/client.go @@ -314,72 +314,3 @@ func genKeepAliveOptions(url *common.URL, tripleConf *global.TripleConfig) ([]tr return cliKeepAliveOpts, keepAliveInterval, keepAliveTimeout, nil } - -// dualTransport is a transport that can handle both HTTP/2 and HTTP/3 -// It uses HTTP Alternative Services (Alt-Svc) for protocol negotiation -type dualTransport struct { - http2Transport *http2.Transport - http3Transport *http3.Transport - // Cache for alternative services to avoid repeated lookups - altSvcCache *tri.AltSvcCache -} - -// newDualTransport creates a new dual transport that supports both HTTP/2 and HTTP/3 -func newDualTransport(tlsConfig *tls.Config, keepAliveInterval, keepAliveTimeout time.Duration) http.RoundTripper { - http2Transport := &http2.Transport{ - TLSClientConfig: tlsConfig, - ReadIdleTimeout: keepAliveInterval, - PingTimeout: keepAliveTimeout, - } - - http3Transport := &http3.Transport{ - TLSClientConfig: tlsConfig, - QUICConfig: &quic.Config{ - KeepAlivePeriod: keepAliveInterval, - MaxIdleTimeout: keepAliveTimeout, - }, - } - - return &dualTransport{ - http2Transport: http2Transport, - http3Transport: http3Transport, - altSvcCache: tri.NewAltSvcCache(), - } -} - -// RoundTrip implements http.RoundTripper interface with HTTP Alternative Services support -func (dt *dualTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // Check if we have cached alternative service information - cachedAltSvc := dt.altSvcCache.Get(req.URL.Host) - - // If we have valid cached alt-svc info and it's for HTTP/3, try HTTP/3 first - // Check if the cached information is still valid (not expired) - if cachedAltSvc != nil && cachedAltSvc.Protocol == "h3" { - logger.Debugf("Using cached HTTP/3 alternative service for %s", req.URL.String()) - resp, err := dt.http3Transport.RoundTrip(req) - if err == nil { - // Update alt-svc cache from response headers - dt.altSvcCache.UpdateFromHeaders(req.URL.Host, resp.Header) - return resp, nil - } - logger.Debugf("Cached HTTP/3 request failed to %s, falling back to HTTP/2: %v", req.URL.String(), err) - } - - // Start with HTTP/2 to get alternative service information - logger.Debugf("Making initial HTTP/2 request to %s to discover alternative services", req.URL.String()) - resp, err := dt.http2Transport.RoundTrip(req) - if err != nil { - logger.Errorf("HTTP/2 request failed to %s: %v", req.URL.String(), err) - return nil, err - } - - // Check for alternative services in the response - dt.altSvcCache.UpdateFromHeaders(req.URL.Host, resp.Header) - - // If the response indicates HTTP/3 is available, try HTTP/3 for future requests - if altSvc := dt.altSvcCache.Get(req.URL.Host); altSvc != nil && altSvc.Protocol == "h3" { - logger.Debugf("Server %s supports HTTP/3, will use HTTP/3 for future requests", req.URL.Host) - } - - return resp, nil -} diff --git a/protocol/triple/dual_transport.go b/protocol/triple/dual_transport.go new file mode 100644 index 0000000000..ca84e62c17 --- /dev/null +++ b/protocol/triple/dual_transport.go @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package triple + +import ( + "context" + "crypto/tls" + "errors" + "io" + "net/http" + "net/url" + "sync" + "time" +) + +import ( + "github.com/dubbogo/gost/log/logger" + + "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/http3" + + "golang.org/x/net/http2" +) + +import ( + "dubbo.apache.org/dubbo-go/v3/common/constant" + tri "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol" +) + +type originMode int + +const ( + // originUnknown means the origin has not advertised HTTP/3 yet. + originUnknown originMode = iota + // originCandidate means the origin advertised HTTP/3 and is waiting for validation. + originCandidate + // originProbing means an out-of-band HTTP/3 probe is in flight. + originProbing + // originH3Healthy means later requests may be sent over HTTP/3. + originH3Healthy + // originCooldown means HTTP/3 recently failed and should be avoided for a while. + originCooldown +) + +const ( + defaultH3ProbeTimeout = 3 * time.Second + defaultH3BaseCooldown = 4 * time.Second + defaultH3MaxCooldown = 1 * time.Minute +) + +// originState tracks whether the current upstream origin is ready for HTTP/3. +type originState struct { + mode originMode + failures int + cooldownUntil time.Time +} + +// dualTransport keeps HTTP/2 as the stable path and only uses HTTP/3 after discovery and probe. +type dualTransport struct { + http2Transport http.RoundTripper + http3Transport http.RoundTripper + // Cache for alternative services to avoid repeated lookups + altSvcCache *tri.AltSvcCache + + // state tracks HTTP/3 readiness for the bound upstream origin. + state originState + + // mu protects state and serializes transitions between unknown/probing/healthy/cooldown. + mu sync.Mutex + + probeTimeout time.Duration + baseCooldown time.Duration + maxCooldown time.Duration +} + +// newDualTransport creates a new dual transport that supports both HTTP/2 and HTTP/3 +func newDualTransport(tlsConfig *tls.Config, keepAliveInterval, keepAliveTimeout time.Duration) http.RoundTripper { + http2Transport := &http2.Transport{ + TLSClientConfig: tlsConfig, + ReadIdleTimeout: keepAliveInterval, + PingTimeout: keepAliveTimeout, + } + + http3Transport := &http3.Transport{ + TLSClientConfig: tlsConfig, + QUICConfig: &quic.Config{ + KeepAlivePeriod: keepAliveInterval, + MaxIdleTimeout: keepAliveTimeout, + }, + } + + return &dualTransport{ + http2Transport: http2Transport, + http3Transport: http3Transport, + altSvcCache: tri.NewAltSvcCache(), + probeTimeout: defaultH3ProbeTimeout, + baseCooldown: defaultH3BaseCooldown, + maxCooldown: defaultH3MaxCooldown, + } +} + +// RoundTrip implements http.RoundTripper interface with HTTP Alternative Services support +func (dt *dualTransport) RoundTrip(req *http.Request) (*http.Response, error) { + host := req.URL.Host + + if dt.shouldUseH3(host) { + // Only use HTTP/3 after a separate probe marks the origin healthy. + // If the HTTP/3 request fails, return the error directly instead of + // replaying the same request over HTTP/2 with a partially consumed body. + resp, err := dt.http3Transport.RoundTrip(req) + if err == nil { + dt.markH3Success(host) + return resp, nil + } + if dt.shouldMarkH3Failure(req, err) { + logger.Warnf("HTTP/3 request failed for %s: %v", req.URL.String(), err) + dt.markH3Failure(host) + } + return nil, err + } + + resp, err := dt.http2Transport.RoundTrip(req) + if err != nil { + return nil, err + } + + // Learn HTTP/3 availability from HTTP/2 responses and validate it with + // an independent probe before routing later requests over HTTP/3. + dt.observeH2Response(req.URL, resp.Header) + return resp, nil +} + +// shouldUseH3 only decides whether the current request may use HTTP/3. +func (dt *dualTransport) shouldUseH3(host string) bool { + if host == "" { + return false + } + + altSvc := dt.altSvcCache.Get(host) + if altSvc == nil || altSvc.Protocol != constant.AltSvcProtocolH3 { + return false + } + + now := time.Now() + + dt.mu.Lock() + defer dt.mu.Unlock() + + switch dt.state.mode { + case originH3Healthy: + return true + + case originCooldown: + // Wait for the cooldown window to expire before probing HTTP/3 again. + if now.Before(dt.state.cooldownUntil) { + return false + } + + logger.Debugf("HTTP/3 cooldown expired for %s", host) + dt.state.mode = originCandidate + dt.state.cooldownUntil = time.Time{} + return false + + case originUnknown, originCandidate, originProbing: + return false + } + return false +} + +func (dt *dualTransport) shouldMarkH3Failure(req *http.Request, err error) bool { + if req.Context().Err() != nil { + return false + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + return true +} + +func (dt *dualTransport) markH3Success(host string) { + dt.mu.Lock() + defer dt.mu.Unlock() + + dt.state.mode = originH3Healthy + dt.state.failures = 0 + dt.state.cooldownUntil = time.Time{} + logger.Debugf("HTTP/3 ready for %s", host) +} + +func (dt *dualTransport) markH3Failure(host string) { + dt.mu.Lock() + defer dt.mu.Unlock() + + dt.state.failures++ + dt.state.mode = originCooldown + dt.state.cooldownUntil = time.Now().Add(dt.nextCooldown(dt.state.failures)) + logger.Debugf( + "HTTP/3 cooldown for %s until %s after %d failure(s)", + host, + dt.state.cooldownUntil.Format(time.RFC3339), + dt.state.failures, + ) +} + +// observeH2Response learns Alt-Svc from the current HTTP/2 response for later requests. +func (dt *dualTransport) observeH2Response(u *url.URL, headers http.Header) { + if u == nil || u.Host == "" { + return + } + + dt.altSvcCache.UpdateFromHeaders(u.Host, headers) + + altSvc := dt.altSvcCache.Get(u.Host) + if altSvc == nil || altSvc.Protocol != constant.AltSvcProtocolH3 { + return + } + + now := time.Now() + + dt.mu.Lock() + switch dt.state.mode { + case originUnknown: + dt.state.mode = originCandidate + logger.Debugf("HTTP/3 advertised by %s", u.Host) + + case originCooldown: + if !now.Before(dt.state.cooldownUntil) { + dt.state.mode = originCandidate + dt.state.cooldownUntil = time.Time{} + logger.Debugf("HTTP/3 reprobe enabled for %s", u.Host) + } + + case originCandidate, originProbing, originH3Healthy: + // Already aware of HTTP/3, nothing to do + } + dt.mu.Unlock() + + dt.maybeStartProbe(u) +} + +func (dt *dualTransport) maybeStartProbe(u *url.URL) { + if u == nil || u.Host == "" { + return + } + host := u.Host + + altSvc := dt.altSvcCache.Get(host) + if altSvc == nil || altSvc.Protocol != constant.AltSvcProtocolH3 { + return + } + + now := time.Now() + + dt.mu.Lock() + switch dt.state.mode { + case originH3Healthy, originProbing: + dt.mu.Unlock() + return + case originCooldown: + if now.Before(dt.state.cooldownUntil) { + dt.mu.Unlock() + return + } + dt.state.mode = originCandidate + dt.state.cooldownUntil = time.Time{} + case originUnknown, originCandidate: + dt.state.mode = originCandidate + } + + probeURL := &url.URL{ + Scheme: u.Scheme, + Host: host, + Path: "/", + } + // Validate HTTP/3 readiness out of band so the current business request + // can stay on HTTP/2. + dt.state.mode = originProbing + dt.mu.Unlock() + + logger.Debugf("Start HTTP/3 probe for %s via %s", host, probeURL.String()) + go dt.runProbe(probeURL) +} + +func (dt *dualTransport) runProbe(probeURL *url.URL) { + if probeURL == nil || probeURL.Host == "" { + return + } + host := probeURL.Host + + ctx, cancel := context.WithTimeout(context.Background(), dt.probeTimeout) + defer cancel() + + // Probe with an independent request so business request bodies never need + // to be replayed across transports. + req, err := http.NewRequestWithContext(ctx, http.MethodOptions, probeURL.String(), nil) + if err != nil { + logger.Debugf("Create HTTP/3 probe request failed for %s: %v", host, err) + dt.markH3Failure(host) + return + } + resp, err := dt.http3Transport.RoundTrip(req) + if err != nil { + logger.Debugf("HTTP/3 probe failed for %s: %v", host, err) + dt.markH3Failure(host) + return + } + defer resp.Body.Close() + + _, _ = io.Copy(io.Discard, resp.Body) + + dt.markH3Success(host) +} + +func (dt *dualTransport) nextCooldown(failures int) time.Duration { + if failures <= 1 { + return dt.baseCooldown + } + // Increase the cooldown window after repeated HTTP/3 failures, capped by maxCooldown. + d := dt.baseCooldown + for i := 1; i < failures; i++ { + d *= 2 + if d >= dt.maxCooldown { + return dt.maxCooldown + } + } + return d +} diff --git a/protocol/triple/dual_transport_test.go b/protocol/triple/dual_transport_test.go new file mode 100644 index 0000000000..7fbc93eec5 --- /dev/null +++ b/protocol/triple/dual_transport_test.go @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package triple + +import ( + "context" + "errors" + "io" + "net/http" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +import ( + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +import ( + tri "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol" +) + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func newTestResponse(status int, headers http.Header) *http.Response { + if headers == nil { + headers = make(http.Header) + } + return &http.Response{ + StatusCode: status, + Header: headers, + Body: io.NopCloser(strings.NewReader("")), + } +} + +func TestDualTransport_H2AltSvcStartsProbeAndPromotesH3Healthy(t *testing.T) { + t.Parallel() + + dt := &dualTransport{ + http2Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + headers := make(http.Header) + headers.Set("Alt-Svc", `h3=":443"; ma=86400`) + return newTestResponse(http.StatusOK, headers), nil + }), + altSvcCache: tri.NewAltSvcCache(), + probeTimeout: 100 * time.Millisecond, + baseCooldown: 4 * time.Second, + maxCooldown: 1 * time.Minute, + } + + probeCalled := make(chan struct{}, 1) + dt.http3Transport = roundTripperFunc(func(req *http.Request) (*http.Response, error) { + assert.Equal(t, http.MethodOptions, req.Method) + assert.Equal(t, "/", req.URL.Path) + probeCalled <- struct{}{} + return newTestResponse(http.StatusMethodNotAllowed, nil), nil + }) + + req, err := http.NewRequest(http.MethodPost, "https://example.com/service", nil) + require.NoError(t, err) + + resp, err := dt.RoundTrip(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + select { + case <-probeCalled: + case <-time.After(time.Second): + t.Fatal("probe was not triggered") + } + + require.Eventually(t, func() bool { + dt.mu.Lock() + defer dt.mu.Unlock() + return dt.state.mode == originH3Healthy + }, time.Second, 10*time.Millisecond) +} + +func TestDualTransport_ConcurrentH2DiscoveryStartsSingleProbe(t *testing.T) { + t.Parallel() + + const numRequests = 8 + + var h2Calls atomic.Int32 + var h3Calls atomic.Int32 + + dt := &dualTransport{ + http2Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + h2Calls.Add(1) + headers := make(http.Header) + headers.Set("Alt-Svc", `h3=":443"; ma=86400`) + return newTestResponse(http.StatusOK, headers), nil + }), + altSvcCache: tri.NewAltSvcCache(), + probeTimeout: 100 * time.Millisecond, + baseCooldown: 4 * time.Second, + maxCooldown: 1 * time.Minute, + } + + probeStarted := make(chan struct{}, 1) + releaseProbe := make(chan struct{}) + dt.http3Transport = roundTripperFunc(func(req *http.Request) (*http.Response, error) { + h3Calls.Add(1) + assert.Equal(t, http.MethodOptions, req.Method) + select { + case probeStarted <- struct{}{}: + default: + } + <-releaseProbe + return newTestResponse(http.StatusMethodNotAllowed, nil), nil + }) + + var wg sync.WaitGroup + wg.Add(numRequests) + for i := 0; i < numRequests; i++ { + req, err := http.NewRequest(http.MethodPost, "https://example.com/service", nil) + require.NoError(t, err) + + go func(req *http.Request) { + defer wg.Done() + + resp, err := dt.RoundTrip(req) + if !assert.NoError(t, err) { + return + } + if assert.NotNil(t, resp) { + assert.Equal(t, http.StatusOK, resp.StatusCode) + } + }(req) + } + + select { + case <-probeStarted: + case <-time.After(time.Second): + t.Fatal("probe was not triggered") + } + + wg.Wait() + assert.Equal(t, int32(numRequests), h2Calls.Load()) + assert.Equal(t, int32(1), h3Calls.Load()) + + close(releaseProbe) + + require.Eventually(t, func() bool { + dt.mu.Lock() + defer dt.mu.Unlock() + return dt.state.mode == originH3Healthy + }, time.Second, 10*time.Millisecond) +} + +func TestDualTransport_ProbeFailureEntersCooldownAndKeepsHTTP2(t *testing.T) { + t.Parallel() + + var h2Calls atomic.Int32 + var h3Calls atomic.Int32 + + dt := &dualTransport{ + http2Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + h2Calls.Add(1) + headers := make(http.Header) + headers.Set("Alt-Svc", `h3=":443"; ma=86400`) + return newTestResponse(http.StatusOK, headers), nil + }), + http3Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + h3Calls.Add(1) + assert.Equal(t, http.MethodOptions, req.Method) + return nil, errors.New("probe failed") + }), + altSvcCache: tri.NewAltSvcCache(), + probeTimeout: 100 * time.Millisecond, + baseCooldown: 4 * time.Second, + maxCooldown: 1 * time.Minute, + } + + req, err := http.NewRequest(http.MethodPost, "https://example.com/service", nil) + require.NoError(t, err) + + resp, err := dt.RoundTrip(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + require.Eventually(t, func() bool { + dt.mu.Lock() + defer dt.mu.Unlock() + return dt.state.mode == originCooldown && dt.state.failures == 1 && !dt.state.cooldownUntil.IsZero() + }, time.Second, 10*time.Millisecond) + + resp, err = dt.RoundTrip(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, int32(2), h2Calls.Load()) + assert.Equal(t, int32(1), h3Calls.Load()) +} + +func TestDualTransport_H3HealthyFailureDoesNotFallbackToHTTP2(t *testing.T) { + t.Parallel() + + var h2Calls atomic.Int32 + var h3Calls atomic.Int32 + + dt := &dualTransport{ + http2Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + h2Calls.Add(1) + return newTestResponse(http.StatusOK, nil), nil + }), + http3Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + h3Calls.Add(1) + return nil, errors.New("h3 failed") + }), + altSvcCache: tri.NewAltSvcCache(), + baseCooldown: 4 * time.Second, + maxCooldown: 1 * time.Minute, + } + dt.altSvcCache.Set("example.com", &tri.AltSvcInfo{ + Protocol: "h3", + Expires: time.Now().Add(time.Hour), + }) + dt.state.mode = originH3Healthy + + req, err := http.NewRequest(http.MethodPost, "https://example.com/service", nil) + require.NoError(t, err) + + resp, err := dt.RoundTrip(req) + require.Error(t, err) + assert.Nil(t, resp) + assert.Equal(t, int32(1), h3Calls.Load()) + assert.Equal(t, int32(0), h2Calls.Load()) + + dt.mu.Lock() + defer dt.mu.Unlock() + assert.Equal(t, originCooldown, dt.state.mode) + assert.False(t, dt.state.cooldownUntil.IsZero()) +} + +func TestDualTransport_H3ContextErrorDoesNotEnterCooldown(t *testing.T) { + t.Parallel() + + var h2Calls atomic.Int32 + var h3Calls atomic.Int32 + + dt := &dualTransport{ + http2Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + h2Calls.Add(1) + return newTestResponse(http.StatusOK, nil), nil + }), + http3Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + h3Calls.Add(1) + return nil, req.Context().Err() + }), + altSvcCache: tri.NewAltSvcCache(), + baseCooldown: 4 * time.Second, + maxCooldown: 1 * time.Minute, + } + dt.altSvcCache.Set("example.com", &tri.AltSvcInfo{ + Protocol: "h3", + Expires: time.Now().Add(time.Hour), + }) + dt.state.mode = originH3Healthy + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://example.com/service", nil) + require.NoError(t, err) + + resp, err := dt.RoundTrip(req) + require.ErrorIs(t, err, context.Canceled) + assert.Nil(t, resp) + assert.Equal(t, int32(1), h3Calls.Load()) + assert.Equal(t, int32(0), h2Calls.Load()) + + dt.mu.Lock() + defer dt.mu.Unlock() + assert.Equal(t, originH3Healthy, dt.state.mode) + assert.Equal(t, 0, dt.state.failures) + assert.True(t, dt.state.cooldownUntil.IsZero()) +} + +func TestDualTransport_CooldownUsesHTTP2(t *testing.T) { + t.Parallel() + + var h2Calls atomic.Int32 + var h3Calls atomic.Int32 + + dt := &dualTransport{ + http2Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + h2Calls.Add(1) + return newTestResponse(http.StatusOK, nil), nil + }), + http3Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + h3Calls.Add(1) + return newTestResponse(http.StatusOK, nil), nil + }), + altSvcCache: tri.NewAltSvcCache(), + baseCooldown: 4 * time.Second, + maxCooldown: 1 * time.Minute, + } + dt.altSvcCache.Set("example.com", &tri.AltSvcInfo{ + Protocol: "h3", + Expires: time.Now().Add(time.Hour), + }) + dt.state.mode = originCooldown + dt.state.cooldownUntil = time.Now().Add(time.Minute) + + req, err := http.NewRequest(http.MethodPost, "https://example.com/service", nil) + require.NoError(t, err) + + resp, err := dt.RoundTrip(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, int32(1), h2Calls.Load()) + assert.Equal(t, int32(0), h3Calls.Load()) +} + +func TestDualTransport_CooldownExpiryStartsProbeAndPromotesH3Healthy(t *testing.T) { + t.Parallel() + + var h2Calls atomic.Int32 + var h3Calls atomic.Int32 + + dt := &dualTransport{ + http2Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + h2Calls.Add(1) + headers := make(http.Header) + headers.Set("Alt-Svc", `h3=":443"; ma=86400`) + return newTestResponse(http.StatusOK, headers), nil + }), + altSvcCache: tri.NewAltSvcCache(), + probeTimeout: 100 * time.Millisecond, + baseCooldown: 4 * time.Second, + maxCooldown: 1 * time.Minute, + } + + probeCalled := make(chan struct{}, 1) + dt.http3Transport = roundTripperFunc(func(req *http.Request) (*http.Response, error) { + h3Calls.Add(1) + assert.Equal(t, http.MethodOptions, req.Method) + assert.Equal(t, "/", req.URL.Path) + probeCalled <- struct{}{} + return newTestResponse(http.StatusMethodNotAllowed, nil), nil + }) + + dt.altSvcCache.Set("example.com", &tri.AltSvcInfo{ + Protocol: "h3", + Expires: time.Now().Add(time.Hour), + }) + dt.state.mode = originCooldown + dt.state.failures = 1 + dt.state.cooldownUntil = time.Now().Add(-time.Second) + + req, err := http.NewRequest(http.MethodPost, "https://example.com/service", nil) + require.NoError(t, err) + + resp, err := dt.RoundTrip(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, int32(1), h2Calls.Load()) + + select { + case <-probeCalled: + case <-time.After(time.Second): + t.Fatal("probe was not triggered after cooldown expiry") + } + + require.Eventually(t, func() bool { + dt.mu.Lock() + defer dt.mu.Unlock() + return dt.state.mode == originH3Healthy && dt.state.failures == 0 && dt.state.cooldownUntil.IsZero() + }, time.Second, 10*time.Millisecond) + assert.Equal(t, int32(1), h3Calls.Load()) +} diff --git a/protocol/triple/triple_protocol/negotiation.go b/protocol/triple/triple_protocol/negotiation.go index 342cd5ff08..2899bff163 100644 --- a/protocol/triple/triple_protocol/negotiation.go +++ b/protocol/triple/triple_protocol/negotiation.go @@ -31,6 +31,10 @@ import ( "github.com/quic-go/quic-go/http3" ) +import ( + "dubbo.apache.org/dubbo-go/v3/common/constant" +) + // AltSvcInfo represents cached alternative service information // This struct stores the parsed information from Alt-Svc HTTP headers // according to RFC 7838 (HTTP Alternative Services) @@ -113,10 +117,10 @@ func (c *AltSvcCache) UpdateFromHeaders(host string, headers http.Header) { // Prefer HTTP/3 over HTTP/2 var preferredAltSvc *AltSvcInfo for _, altSvc := range altSvcs { - if altSvc.Protocol == "h3" { + if altSvc.Protocol == constant.AltSvcProtocolH3 { preferredAltSvc = altSvc break - } else if altSvc.Protocol == "h2" && preferredAltSvc == nil { + } else if altSvc.Protocol == constant.AltSvcProtocolH2 && preferredAltSvc == nil { preferredAltSvc = altSvc } } @@ -160,7 +164,7 @@ func parseAltSvcPart(part string) *AltSvcInfo { } protocol := strings.TrimSpace(part[:eqIndex]) - if protocol != "h3" && protocol != "h2" { + if protocol != constant.AltSvcProtocolH3 && protocol != constant.AltSvcProtocolH2 { return nil }