Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion internal/transport/client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
package transport

import (
"strconv"
"sync"
"sync/atomic"
"time"

"golang.org/x/net/http2"
"google.golang.org/grpc/mem"
Expand All @@ -28,6 +31,12 @@ import (
"google.golang.org/grpc/status"
)

// NonGRPCDataMaxLen is the maximum length of nonGRPCDataBuf.
const NonGRPCDataMaxLen = 1024

// nonGRPCDataCollectionTimeout is the timeout for collecting non-gRPC data.
const nonGRPCDataCollectionTimeout = 3 * time.Second

// ClientStream implements streaming functionality for a gRPC client.
type ClientStream struct {
Stream // Embed for common stream functionality.
Expand All @@ -46,14 +55,65 @@ type ClientStream struct {
// headerValid indicates whether a valid header was received. Only
// meaningful after headerChan is closed (always call waitOnHeader() before
// reading its value).
headerValid bool
headerValid bool

collectionMu sync.Mutex
collecting bool // indicates if stream entered the stage of non-gRPC data collection.
collectionTimer *time.Timer // used to limit the time spent on collecting non-gRPC error details.
nonGRPCDataBuf []byte // stores the data of a non-gRPC response.

noHeaders bool // set if the client never received headers (set only after the stream is done).
headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
bytesReceived atomic.Bool // indicates whether any bytes have been received on this stream
unprocessed atomic.Bool // set if the server sends a refused stream or GOAWAY including this stream
statsHandler stats.Handler // nil for internal streams (e.g., health check, ORCA) where telemetry is not supported.
}

func (s *ClientStream) startNonGRPCDataCollection(st *status.Status, onTimeout func()) {
s.collectionMu.Lock()
defer s.collectionMu.Unlock()
if s.collecting {
return
}
s.status = st
s.collecting = true
s.nonGRPCDataBuf = make([]byte, 0, NonGRPCDataMaxLen)
s.collectionTimer = time.AfterFunc(nonGRPCDataCollectionTimeout, onTimeout)
}

// tryHandleNonGRPCData tries to collect non-gRPC body from the given data frame.
// It returns two booleans:
// handle indicates whether the frame should be handled as a non-gRPC response body,
// end indicates whether the stream should be closed after handling this frame.
func (s *ClientStream) tryHandleNonGRPCData(f *parsedDataFrame) (handle bool, end bool) {
s.collectionMu.Lock()
defer s.collectionMu.Unlock()
if !s.collecting {
return false, false
}

n := min(f.data.Len(), NonGRPCDataMaxLen-len(s.nonGRPCDataBuf))
s.nonGRPCDataBuf = append(s.nonGRPCDataBuf, f.data.ReadOnlyData()[0:n]...)
if len(s.nonGRPCDataBuf) >= NonGRPCDataMaxLen || f.StreamEnded() {
return true, true
}
return true, false
}

// stopNonGRPCBodyCollection stops collecting non-gRPC body and appends the collected.
// Should only be called in closeStream.
func (s *ClientStream) stopNonGRPCDataCollectionLocked() {
if !s.collecting {
return
}
if s.collectionTimer != nil {
s.collectionTimer.Stop()
s.collectionTimer = nil
}
data := "\ndata: " + strconv.Quote(string(s.nonGRPCDataBuf))
s.status = status.New(s.status.Code(), s.status.Message()+data)
}

// Read reads an n byte message from the input stream.
func (s *ClientStream) Read(n int) (mem.BufferSlice, error) {
b, err := s.Stream.read(n)
Expand Down Expand Up @@ -126,6 +186,8 @@ func (s *ClientStream) Header() (metadata.MD, error) {
s.waitOnHeader()

if !s.headerValid || s.noHeaders {
s.collectionMu.Lock()
defer s.collectionMu.Unlock()
return nil, s.status.Err()
}

Expand Down
47 changes: 42 additions & 5 deletions internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr, handler s
return s, nil
}

func (t *http2Client) closeStreamWithNonGRPCStatus(s *ClientStream) {
t.closeStream(s, nil, true, http2.ErrCodeProtocol, nil, nil, true)
}

func (t *http2Client) closeStream(s *ClientStream, err error, rst bool, rstCode http2.ErrCode, st *status.Status, mdata map[string][]string, eosReceived bool) {
// Set stream status to done.
if s.swapState(streamDone) == streamDone {
Expand All @@ -942,10 +946,17 @@ func (t *http2Client) closeStream(s *ClientStream, err error, rst bool, rstCode
<-s.done
return
}
// status and trailers can be updated here without any synchronization because the stream goroutine will
// only read it after it sees an io.EOF error from read or write and we'll write those errors
// only after updating this.
s.collectionMu.Lock()
if s.collecting {
// If the stream is collecting data for non-gRPC, stop collection to finalize status
s.stopNonGRPCDataCollectionLocked()
}
if s.status != nil {
st = s.status
err = st.Err()
}
s.status = st
s.collectionMu.Unlock()
if len(mdata) > 0 {
s.trailer = mdata
}
Expand Down Expand Up @@ -1222,6 +1233,21 @@ func (t *http2Client) handleData(f *parsedDataFrame) {
t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false)
return
}

handle, end := s.tryHandleNonGRPCData(f)
if handle {
if w := s.fc.onRead(size); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{
streamID: s.id,
increment: w,
})
}
if end {
t.closeStreamWithNonGRPCStatus(s)
}
return
}

dataLen := f.data.Len()
if f.Header().Flags.Has(http2.FlagDataPadded) {
if w := s.fc.onRead(size - uint32(dataLen)); w > 0 {
Expand Down Expand Up @@ -1561,8 +1587,19 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
errs = append(errs, contentTypeErr)
}

se := status.New(grpcErrorCode, strings.Join(errs, "; "))
t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
errMsg := strings.Join(errs, "; ")
se := status.New(grpcErrorCode, errMsg)
if endStream {
t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, true)
return
}

s.startNonGRPCDataCollection(se, func() {
t.closeStreamWithNonGRPCStatus(s)
})
if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
close(s.headerChan)
}
return
}

Expand Down
113 changes: 113 additions & 0 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"os"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -6787,6 +6788,118 @@ func (s) TestAuthorityHeader(t *testing.T) {
}
}

func (s) TestHTTPServerSendsNonGRPCHeaderSurfaceFurtherData(t *testing.T) {
tests := []struct {
name string
responses []httpServerResponse
wantCode codes.Code
wantErr string
}{
{
name: "non-gRPC content-type without payload",
responses: []httpServerResponse{
{
headers: [][]string{
{
":status", "200",
"content-type", "text/html",
},
},
// payload: nil
},
},
wantCode: codes.Unknown,
wantErr: `rpc error: code = Unknown desc = unexpected HTTP status code received from server: 200 (OK); transport: received unexpected content-type "text/html"
data: ""`,
},
{
name: "non-gRPC content-type with payload",
responses: []httpServerResponse{
{
headers: [][]string{
{
":status", "200",
"content-type", "text/html",
},
},
payload: []byte(`<html><body>Hello World</body></html>`),
},
},
wantCode: codes.Unknown,
wantErr: `rpc error: code = Unknown desc = unexpected HTTP status code received from server: 200 (OK); transport: received unexpected content-type "text/html"
data: "<html><body>Hello World</body></html>"`,
},
{
name: "non-gRPC content-type with bytes payload length more than transport.NonGRPCDataMaxLen",
responses: []httpServerResponse{
{
headers: [][]string{
{
":status", "200",
"content-type", "text/html",
},
},
payload: bytes.Repeat([]byte("a"), transport.NonGRPCDataMaxLen+1),
},
},
wantCode: codes.Unknown,
wantErr: `rpc error: code = Unknown desc = unexpected HTTP status code received from server: 200 (OK); transport: received unexpected content-type "text/html"
data: ` + strconv.Quote(strings.Repeat("a", transport.NonGRPCDataMaxLen)),
},
{
name: "content-type not provided",
responses: []httpServerResponse{
{
headers: [][]string{{
":status", "502",
}},
payload: []byte("hello"),
},
},
wantCode: codes.Unavailable,
wantErr: `rpc error: code = Unavailable desc = unexpected HTTP status code received from server: 502 (Bad Gateway); malformed header: missing HTTP content-type
data: "hello"`,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("net.Listen() failed: %v", err)
}
defer lis.Close()

hs := &httpServer{responses: test.responses}
hs.start(t, lis)

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
}
defer cc.Close()

client := testgrpc.NewTestServiceClient(cc)
_, err = client.EmptyCall(ctx, &testpb.Empty{})

if err == nil {
t.Fatalf("EmptyCall() = nil; want non-nil error due to non-gRPC response")
}

if got, want := status.Code(err), test.wantCode; got != want {
t.Fatalf("Unexpected error code: got %v, want %v\nfull error:\n%v", got, want, err)
}

if err.Error() != test.wantErr {
t.Errorf("Unexpected error message: got\n %v, want\n %v", err.Error(), test.wantErr)
}
})
}
}

// wrapCloseListener tracks Accepts/Closes and maintains a counter of the
// number of open connections.
type wrapCloseListener struct {
Expand Down
Loading