Skip to content

Commit 75c110a

Browse files
committed
Add comprehensive registry validation to prevent invalid servers
- Add validation for supported registry types (npm, pypi, oci, nuget, mcpb) - Add validation for registry base URLs matching expected URLs per type - Reject empty registry types and base URLs when package identifier provided - Enforce strict allowlist with no localhost or development URL exceptions - Add comprehensive test coverage with clear pass/fail comparisons 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> :house: Remote-Dev: homespace
1 parent 9910ca2 commit 75c110a

File tree

3 files changed

+255
-0
lines changed

3 files changed

+255
-0
lines changed

internal/validators/constants.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ var (
1313
// Remote validation errors
1414
ErrInvalidRemoteURL = errors.New("invalid remote URL")
1515

16+
// Registry validation errors
17+
ErrUnsupportedRegistryType = errors.New("unsupported registry type")
18+
ErrUnsupportedRegistryBaseURL = errors.New("unsupported registry base URL")
19+
ErrMismatchedRegistryTypeAndURL = errors.New("registry type and base URL do not match")
20+
1621
// Argument validation errors
1722
ErrNamedArgumentNameRequired = errors.New("named argument name is required")
1823
ErrInvalidNamedArgumentName = errors.New("invalid named argument name format")

internal/validators/validators.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,81 @@ func validateMCPBPackage(host string) error {
164164
return nil
165165
}
166166

167+
// validateRegistryType checks if the registry type is supported
168+
func validateRegistryType(registryType string) error {
169+
// Registry type is required
170+
if registryType == "" {
171+
return fmt.Errorf("%w: registry type is required", ErrUnsupportedRegistryType)
172+
}
173+
174+
supportedTypes := []string{
175+
model.RegistryTypeNPM,
176+
model.RegistryTypePyPI,
177+
model.RegistryTypeOCI,
178+
model.RegistryTypeNuGet,
179+
model.RegistryTypeMCPB,
180+
}
181+
182+
for _, supported := range supportedTypes {
183+
if registryType == supported {
184+
return nil
185+
}
186+
}
187+
188+
return fmt.Errorf("%w: '%s'. Supported types: %v", ErrUnsupportedRegistryType, registryType, supportedTypes)
189+
}
190+
191+
// validateRegistryBaseURL checks if the registry base URL is valid for the given registry type
192+
func validateRegistryBaseURL(registryType, baseURL string) error {
193+
// Base URL is required for all registry types except MCPB (which uses direct URLs)
194+
if baseURL == "" {
195+
if registryType == model.RegistryTypeMCPB {
196+
return nil // MCPB packages use direct URLs in the identifier
197+
}
198+
return fmt.Errorf("%w: registry base URL is required for registry type '%s'", ErrUnsupportedRegistryBaseURL, registryType)
199+
}
200+
201+
// Define expected base URLs for each registry type
202+
expectedURLs := map[string][]string{
203+
model.RegistryTypeNPM: {model.RegistryURLNPM},
204+
model.RegistryTypePyPI: {model.RegistryURLPyPI},
205+
model.RegistryTypeOCI: {model.RegistryURLDocker},
206+
model.RegistryTypeNuGet: {model.RegistryURLNuGet},
207+
model.RegistryTypeMCPB: {model.RegistryURLGitHub, model.RegistryURLGitLab},
208+
}
209+
210+
// Check if the base URL is valid for the registry type
211+
if expectedURLsForType, exists := expectedURLs[registryType]; exists {
212+
for _, expected := range expectedURLsForType {
213+
if baseURL == expected {
214+
return nil
215+
}
216+
}
217+
return fmt.Errorf("%w: '%s' is not valid for registry type '%s'. Expected: %v",
218+
ErrMismatchedRegistryTypeAndURL, baseURL, registryType, expectedURLsForType)
219+
}
220+
221+
// If registry type is not in our expected URLs map but base URL is provided,
222+
// it's likely an unsupported base URL
223+
return fmt.Errorf("%w: '%s'", ErrUnsupportedRegistryBaseURL, baseURL)
224+
}
225+
167226
func validatePackage(pkg *model.Package) error {
168227
registryType := strings.ToLower(pkg.RegistryType)
169228

229+
// Only validate if package has an identifier (i.e., it's a real package reference)
230+
if pkg.Identifier != "" {
231+
// Validate registry type is supported
232+
if err := validateRegistryType(registryType); err != nil {
233+
return err
234+
}
235+
236+
// Validate registry base URL matches the registry type
237+
if err := validateRegistryBaseURL(registryType, pkg.RegistryBaseURL); err != nil {
238+
return err
239+
}
240+
}
241+
170242
// For direct download packages (mcpb or direct URLs)
171243
if registryType == model.RegistryTypeMCPB ||
172244
strings.HasPrefix(pkg.Identifier, "http://") || strings.HasPrefix(pkg.Identifier, "https://") {

internal/validators/validators_test.go

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package validators_test
22

33
import (
44
"fmt"
5+
"strings"
56
"testing"
67

78
"github.com/modelcontextprotocol/registry/internal/validators"
@@ -648,6 +649,183 @@ func TestValidateArgument_ValidValueFields(t *testing.T) {
648649
}
649650

650651
// Helper function to create a valid server with a specific argument for testing
652+
func TestValidate_RegistryTypes(t *testing.T) {
653+
testCases := []struct {
654+
name string
655+
registryType string
656+
baseURL string
657+
identifier string
658+
expectError bool
659+
}{
660+
// Valid registry types (should pass)
661+
{"valid_npm", model.RegistryTypeNPM, model.RegistryURLNPM, "test-package", false},
662+
{"valid_pypi", model.RegistryTypePyPI, model.RegistryURLPyPI, "test-package", false},
663+
{"valid_oci", model.RegistryTypeOCI, model.RegistryURLDocker, "test-package", false},
664+
{"valid_nuget", model.RegistryTypeNuGet, model.RegistryURLNuGet, "test-package", false},
665+
{"valid_mcpb_github", model.RegistryTypeMCPB, model.RegistryURLGitHub, "https://github.com/owner/repo", false},
666+
{"valid_mcpb_gitlab", model.RegistryTypeMCPB, model.RegistryURLGitLab, "https://gitlab.com/owner/repo", false},
667+
668+
// Invalid registry types (should fail)
669+
{"invalid_maven", "maven", "https://example.com/registry", "test-package", true},
670+
{"invalid_cargo", "cargo", "https://example.com/registry", "test-package", true},
671+
{"invalid_gem", "gem", "https://example.com/registry", "test-package", true},
672+
{"invalid_invalid", "invalid", "https://example.com/registry", "test-package", true},
673+
{"invalid_unknown", "UNKNOWN", "https://example.com/registry", "test-package", true},
674+
{"invalid_custom", "custom-registry", "https://example.com/registry", "test-package", true},
675+
{"invalid_github", "github", "https://example.com/registry", "test-package", true}, // This is a source, not a registry type
676+
{"invalid_docker", "docker", "https://example.com/registry", "test-package", true}, // Should be "oci"
677+
{"invalid_empty", "", "https://example.com/registry", "test-package", true}, // Empty registry type
678+
}
679+
680+
for _, tc := range testCases {
681+
t.Run(tc.name, func(t *testing.T) {
682+
serverDetail := apiv0.ServerJSON{
683+
Name: "com.example/test-server",
684+
Description: "A test server",
685+
Repository: model.Repository{
686+
URL: "https://github.com/owner/repo",
687+
Source: "github",
688+
ID: "owner/repo",
689+
},
690+
VersionDetail: model.VersionDetail{
691+
Version: "1.0.0",
692+
},
693+
Packages: []model.Package{
694+
{
695+
Identifier: tc.identifier,
696+
RegistryType: tc.registryType,
697+
RegistryBaseURL: tc.baseURL,
698+
},
699+
},
700+
Remotes: []model.Remote{
701+
{
702+
URL: "https://example.com/remote",
703+
},
704+
},
705+
}
706+
707+
err := validators.ValidateServerJSON(&serverDetail)
708+
if tc.expectError {
709+
assert.Error(t, err)
710+
assert.Contains(t, err.Error(), validators.ErrUnsupportedRegistryType.Error())
711+
} else {
712+
assert.NoError(t, err)
713+
}
714+
})
715+
}
716+
}
717+
718+
func TestValidate_RegistryBaseURLs(t *testing.T) {
719+
testCases := []struct {
720+
name string
721+
registryType string
722+
baseURL string
723+
identifier string
724+
expectError bool
725+
}{
726+
// Invalid base URLs for specific registry types
727+
{"npm_wrong_url", model.RegistryTypeNPM, "https://pypi.org", "test-package", true},
728+
{"pypi_wrong_url", model.RegistryTypePyPI, "https://registry.npmjs.org", "test-package", true},
729+
{"oci_wrong_url", model.RegistryTypeOCI, "https://registry.npmjs.org", "test-package", true},
730+
{"nuget_wrong_url", model.RegistryTypeNuGet, "https://docker.io", "test-package", true},
731+
{"mcpb_wrong_url", model.RegistryTypeMCPB, "https://evil.com", "https://github.com/owner/repo", true},
732+
{"empty_base_url", model.RegistryTypeNPM, "", "test-package", true},
733+
{"empty_base_url", model.RegistryTypeNPM, model.RegistryURLDocker, "test-package", true},
734+
{"empty_base_url", model.RegistryTypeOCI, model.RegistryTypeNuGet, "test-package", true},
735+
736+
// Localhost URLs should be rejected - no development exceptions
737+
{"localhost_npm", model.RegistryTypeNPM, "http://localhost:3000", "test-package", true},
738+
{"localhost_ip", model.RegistryTypePyPI, "http://127.0.0.1:8080", "test-package", true},
739+
740+
// Valid combinations (should pass)
741+
{"valid_npm", model.RegistryTypeNPM, model.RegistryURLNPM, "test-package", false},
742+
{"valid_pypi", model.RegistryTypePyPI, model.RegistryURLPyPI, "test-package", false},
743+
{"valid_oci", model.RegistryTypeOCI, model.RegistryURLDocker, "test-package", false},
744+
{"valid_nuget", model.RegistryTypeNuGet, model.RegistryURLNuGet, "test-package", false},
745+
{"valid_mcpb_github", model.RegistryTypeMCPB, model.RegistryURLGitHub, "https://github.com/owner/repo", false},
746+
{"valid_mcpb_gitlab", model.RegistryTypeMCPB, model.RegistryURLGitLab, "https://gitlab.com/owner/repo", false},
747+
748+
// Trailing slash URLs should be rejected - strict exact match only
749+
{"npm_trailing_slash", model.RegistryTypeNPM, "https://registry.npmjs.org/", "test-package", true},
750+
{"pypi_trailing_slash", model.RegistryTypePyPI, "https://pypi.org/", "test-package", true},
751+
}
752+
753+
for _, tc := range testCases {
754+
t.Run(tc.name, func(t *testing.T) {
755+
serverDetail := apiv0.ServerJSON{
756+
Name: "com.example/test-server",
757+
Description: "A test server",
758+
Repository: model.Repository{
759+
URL: "https://github.com/owner/repo",
760+
Source: "github",
761+
ID: "owner/repo",
762+
},
763+
VersionDetail: model.VersionDetail{
764+
Version: "1.0.0",
765+
},
766+
Packages: []model.Package{
767+
{
768+
Identifier: tc.identifier,
769+
RegistryType: tc.registryType,
770+
RegistryBaseURL: tc.baseURL,
771+
},
772+
},
773+
Remotes: []model.Remote{
774+
{
775+
URL: "https://example.com/remote",
776+
},
777+
},
778+
}
779+
780+
err := validators.ValidateServerJSON(&serverDetail)
781+
if tc.expectError {
782+
assert.Error(t, err)
783+
// Check that the error is related to registry validation
784+
errStr := err.Error()
785+
assert.True(t,
786+
strings.Contains(errStr, validators.ErrUnsupportedRegistryBaseURL.Error()) ||
787+
strings.Contains(errStr, validators.ErrMismatchedRegistryTypeAndURL.Error()),
788+
"Expected registry validation error, got: %s", errStr)
789+
} else {
790+
assert.NoError(t, err)
791+
}
792+
})
793+
}
794+
}
795+
796+
func TestValidate_EmptyRegistryType(t *testing.T) {
797+
// Test that empty registry type is rejected
798+
serverDetail := apiv0.ServerJSON{
799+
Name: "com.example/test-server",
800+
Description: "A test server",
801+
Repository: model.Repository{
802+
URL: "https://github.com/owner/repo",
803+
Source: "github",
804+
ID: "owner/repo",
805+
},
806+
VersionDetail: model.VersionDetail{
807+
Version: "1.0.0",
808+
},
809+
Packages: []model.Package{
810+
{
811+
Identifier: "test-package",
812+
RegistryType: "", // Empty registry type
813+
RegistryBaseURL: "",
814+
},
815+
},
816+
Remotes: []model.Remote{
817+
{
818+
URL: "https://example.com/remote",
819+
},
820+
},
821+
}
822+
823+
err := validators.ValidateServerJSON(&serverDetail)
824+
assert.Error(t, err)
825+
assert.Contains(t, err.Error(), validators.ErrUnsupportedRegistryType.Error())
826+
assert.Contains(t, err.Error(), "registry type is required")
827+
}
828+
651829
func createValidServerWithArgument(arg model.Argument) apiv0.ServerJSON {
652830
return apiv0.ServerJSON{
653831
Name: "com.example/test-server",

0 commit comments

Comments
 (0)