Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 36 additions & 26 deletions mcp/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ func writeEvent(w io.Writer, evt Event) (int, error) {
// TODO(rfindley): consider a different API here that makes failure modes more
// apparent.
func scanEvents(r io.Reader) iter.Seq2[Event, error] {
scanner := bufio.NewScanner(r)
const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size
scanner.Buffer(nil, maxTokenSize)
reader := bufio.NewReader(r)

// TODO: investigate proper behavior when events are out of order, or have
// non-standard names.
Expand All @@ -94,31 +92,45 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] {
evt Event
dataBuf *bytes.Buffer // if non-nil, preceding field was also data
)
flushData := func() {
var previousLineType []byte
yieldEvent := func() bool {
if dataBuf != nil {
evt.Data = dataBuf.Bytes()
dataBuf = nil
previousLineType = nil
}
if evt.Empty() {
return true
}
if !yield(evt, nil) {
return false
}
evt = Event{}
return true
}
for scanner.Scan() {
line := scanner.Bytes()
for {
line, err := reader.ReadBytes('\n')
if err != nil && !errors.Is(err, io.EOF) {
yield(Event{}, fmt.Errorf("error reading event: %v", err))
return
}
line = bytes.TrimRight(line, "\r\n")
isEOF := errors.Is(err, io.EOF)

if len(line) == 0 {
flushData()
// \n\n is the record delimiter
if !evt.Empty() && !yield(evt, nil) {
if !yieldEvent() {
return
}
if isEOF {
return
}
evt = Event{}
continue
}
before, after, found := bytes.Cut(line, []byte{':'})
if !found {
yield(Event{}, fmt.Errorf("malformed line in SSE stream: %q", string(line)))
return
}
if !bytes.Equal(before, dataKey) {
flushData()
}
switch {
case bytes.Equal(before, eventKey):
evt.Name = strings.TrimSpace(string(after))
Expand All @@ -128,27 +140,25 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] {
evt.Retry = strings.TrimSpace(string(after))
case bytes.Equal(before, dataKey):
data := bytes.TrimSpace(after)
if dataBuf != nil {
dataBuf.WriteByte('\n')
previousLineEmptyOrData := previousLineType == nil || bytes.Equal(previousLineType, dataKey)
if dataBuf == nil {
dataBuf = new(bytes.Buffer)
dataBuf.Write(data)
} else if !previousLineEmptyOrData {
yield(Event{}, fmt.Errorf("non-continuous data items in the event"))
return
} else {
dataBuf = new(bytes.Buffer)
dataBuf.WriteByte('\n')
dataBuf.Write(data)
}
}
}
if err := scanner.Err(); err != nil {
if errors.Is(err, bufio.ErrTooLong) {
err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize)
}
if !yield(Event{}, err) {
previousLineType = before

if isEOF {
yieldEvent()
return
}
}
flushData()
if !evt.Empty() {
yield(evt, nil)
}
}
}

Expand Down
48 changes: 48 additions & 0 deletions mcp/event_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,54 @@ func TestScanEvents(t *testing.T) {
input: "invalid line\n\n",
wantErr: "malformed line",
},
{
name: "message with 2 data lines and another event",
input: "event: message\ndata: hello\ndata: hello\ndata: hello\n\nevent:keepalive",
want: []Event{
{Name: "message", Data: []byte("hello\nhello\nhello")},
{Name: "keepalive"},
},
},
{
name: "event with multiple lines",
input: "event: message\ndata: hello\ndata: hello\ndata: hello\nid:1",
want: []Event{
{Name: "message", ID: "1", Data: []byte("hello\nhello\nhello")},
},
},
{
name: "multiple events, out of order keys",
input: strings.Join([]string{
"event:message",
"data: hello0",
"\n",
"data: hello1",
"data: hello1",
"id:1",
"event:message",
"\n",
"event:message",
"data: hello3",
"data: hello3",
"id:3",
"\n",
"data: hello4",
"data: hello4",
"id:4",
"event:message",
}, "\n"),
want: []Event{
{Name: "message", Data: []byte("hello0")},
{Name: "message", ID: "1", Data: []byte("hello1\nhello1")},
{Name: "message", ID: "3", Data: []byte("hello3\nhello3")},
{Name: "message", ID: "4", Data: []byte("hello4\nhello4")},
},
},
{
name: "non-continuous data items in the event",
input: "event: foo\ndata: 123\nretry: 5\ndata: 456",
wantErr: "non-continuous data items in the event",
},
}

for _, tt := range tests {
Expand Down
28 changes: 21 additions & 7 deletions mcp/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
s := &sseClientConn{
client: httpClient,
msgEndpoint: msgEndpoint,
incoming: make(chan []byte, 100),
incoming: make(chan sseMessage, 100),
body: resp.Body,
done: make(chan struct{}),
}
Expand All @@ -392,10 +392,14 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {

for evt, err := range scanEvents(resp.Body) {
if err != nil {
select {
case s.incoming <- sseMessage{err: err}:
case <-s.done:
}
return
}
select {
case s.incoming <- evt.Data:
case s.incoming <- sseMessage{data: evt.Data}:
case <-s.done:
return
}
Expand All @@ -405,15 +409,21 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
return s, nil
}

// sseMessage represents a message or error from the SSE stream.
type sseMessage struct {
data []byte
err error
}

// An sseClientConn is a logical jsonrpc2 connection that implements the client
// half of the SSE protocol:
// - Writes are POSTS to the session endpoint.
// - Reads are SSE 'message' events, and pushes them onto a buffered channel.
// - Close terminates the GET request.
type sseClientConn struct {
client *http.Client // HTTP client to use for requests
msgEndpoint *url.URL // session endpoint for POSTs
incoming chan []byte // queue of incoming messages
client *http.Client // HTTP client to use for requests
msgEndpoint *url.URL // session endpoint for POSTs
incoming chan sseMessage // queue of incoming messages or errors

mu sync.Mutex
body io.ReadCloser // body of the hanging GET
Expand All @@ -438,12 +448,16 @@ func (c *sseClientConn) Read(ctx context.Context) (jsonrpc.Message, error) {
case <-c.done:
return nil, io.EOF

case data := <-c.incoming:
case m := <-c.incoming:
if m.err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returning an error from Read breaks the connection. I don't think that's actually what we want to happen here, is it?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that this method return either an error or a jsonrpc.Message, I think this is an acceptable behaviour.
Considering previous behaviour when there was an error during scan events, parsing the body would stop as the loop over scanEvents would break in case there was an error, but without bubbling up the error.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I don't think I sufficiently conveyed the consequences.
The JSON-RPC library expects to operate on a logical stream. Any failure to read from that stream indicates that the logical connection is broken.

So as a consequence of this change, a context cancellation calling a tool will break the entire MCP session.

I think I know the desired behavior: you want to get an error from the CallTool request, but this change doesn't achieve that result.

The jsonrpc2 library doesn't really support this: we'd need to add something like jsonrpc2.Connection.Fail(jsonrpc2.ID, error) to allow this type of layering traversal.

If you'd like to land this PR, I suggest not returning an error here, and I can make the change to achieve what you want.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is this error different from the error returned a few lines bellow by jsonrpc2.DecodeMessage?

Your assumption is correct, I want to bubble up any errors from lower levels, so that for the caller the problem would be clear enough, but this implementation doesn't seem to be using a jsonrpc2 connection, just the encode/decode part. (this is sse implementation, not jsonrpc).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The jsonrpc2.Connection calls Read, which must have stream semantics: an error from Read breaks the stream.

I think there's an argument to be made that a malformed payload should break the stream: if the server sends something that we can't even parse as a jsonrpc2.Message, then the stream is corrupt.

But if the error is due to a network error or client cancellation, we don't want to terminate the session.

See also #683.

Let's focus on fixing the size limit. We can leave the bubbling up of errors, but should not return an error here. I'll make the change to surface the error to the application layer:it's a subtle change to the jsonrpc2 connection.

// TODO: bubble up this error
return nil, nil
}
// TODO(rfindley): do we really need to check this? We receive from c.done above.
if c.isDone() {
return nil, io.EOF
}
msg, err := jsonrpc2.DecodeMessage(data)
msg, err := jsonrpc2.DecodeMessage(m.data)
if err != nil {
return nil, err
}
Expand Down
6 changes: 5 additions & 1 deletion mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1859,7 +1859,11 @@ func (c *streamableClientConn) processStream(ctx context.Context, requestSummary
if ctx.Err() != nil {
return "", 0, true // don't reconnect: client cancelled
}
break

// Network errors during reading should trigger reconnection, not permanent failure.
// Return from processStream so handleSSE can attempt to reconnect.
c.logger.Debug(fmt.Sprintf("%s: stream read error (will attempt reconnect): %v", requestSummary, err))
return lastEventID, reconnectDelay, false
}

if evt.ID != "" {
Expand Down