diff --git a/mcp/shared.go b/mcp/shared.go index bda631fe..393caf42 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -14,6 +14,7 @@ import ( "encoding/json" "fmt" "log" + "net/http" "reflect" "slices" "strings" @@ -396,6 +397,7 @@ type ServerRequest[P Params] struct { // the transport layer. type RequestExtra struct { TokenInfo *auth.TokenInfo // bearer token info (e.g. from OAuth) if any + Header http.Header // header from HTTP request, if any } func (*ClientRequest[P]) isRequest() {} diff --git a/mcp/streamable.go b/mcp/streamable.go index c51f3cc4..d50eadf8 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -580,10 +580,6 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // This also requires access to the negotiated version, which would either be // set by the MCP-Protocol-Version header, or would require peeking into the // session. - if err != nil { - http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) - return - } incoming, _, err := readBatch(body) if err != nil { http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) @@ -592,17 +588,20 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques requests := make(map[jsonrpc.ID]struct{}) tokenInfo := auth.TokenInfoFromContext(req.Context()) for _, msg := range incoming { - if req, ok := msg.(*jsonrpc.Request); ok { + if jreq, ok := msg.(*jsonrpc.Request); ok { // Preemptively check that this is a valid request, so that we can fail // the HTTP request. If we didn't do this, a request with a bad method or // missing ID could be silently swallowed. - if _, err := checkRequest(req, serverMethodInfos); err != nil { + if _, err := checkRequest(jreq, serverMethodInfos); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - req.Extra = &RequestExtra{TokenInfo: tokenInfo} - if req.IsCall() { - requests[req.ID] = struct{}{} + jreq.Extra = &RequestExtra{ + TokenInfo: tokenInfo, + Header: req.Header, + } + if jreq.IsCall() { + requests[jreq.ID] = struct{}{} } } }