Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions cmd/aaop/aaop.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ var (
ips = flag.String("image-pull-secret", "", "the imagePullSecret to use for private registries")
port = flag.String("port", "8080", "port to listen to")
metricsPort = flag.String("metrics-port", "9090", "port to listen to for metrics")
maxBundle = flag.Int("max-bundle", 0, "maximum bundle size in bytes to download")
updateCABundle = flag.Bool("update-ca-bundle", false, "regularly update the Provider's caBundle field")
)

Expand Down Expand Up @@ -81,6 +82,12 @@ func main() {
}
}

if *maxBundle > 0 {
slog.Info("setting maximum bundle size",
"bytes", *maxBundle)
fetcher.MaxBundleSize = int64(*maxBundle)
}

// Start the metrics server
go func() {
var mm = http.NewServeMux()
Expand All @@ -95,7 +102,8 @@ func main() {
}
slog.Info("starting Prometheus metrics server",
"url", promSrv.Addr)
if err := promSrv.ListenAndServe(); err != nil {
if err := promSrv.ListenAndServe(); err != nil &&
!errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start metrics server: %v", err)
}
}()
Expand Down Expand Up @@ -255,18 +263,25 @@ func (t *transport) validate(w http.ResponseWriter, r *http.Request) {
return
}

// Limit request body size to prevent DoS attacks (1 MB limit)
const maxRequestSize = 1 << 20 // 1 MB
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
defer r.Body.Close()

// read request body
requestBody, err := io.ReadAll(r.Body)
if err != nil {
sendResponse(w, provider.ErrorResponse(fmt.Sprintf("unable to read request body: %v", err)))
slog.Error("unable to read request body", "error", err)
sendResponse(w, provider.ErrorResponse("unable to read request body"))
return
}

// parse request body
var providerRequest externaldata.ProviderRequest
err = json.Unmarshal(requestBody, &providerRequest)
if err != nil {
sendResponse(w, provider.ErrorResponse(fmt.Sprintf("unable to unmarshal request body: %v", err)))
slog.Error("unable to unmarshal request body", "error", err)
sendResponse(w, provider.ErrorResponse("unable to parse request body"))
return
}

Expand Down
31 changes: 22 additions & 9 deletions pkg/cainjector/injector.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,24 @@ import (
"k8s.io/client-go/dynamic"
)

var propagationDelay = 10 * time.Second

// UpdateCABundle ensures that the `caBundle` field in the Provider object contains the CA certificates in $certsDir/ca.crt.
// If the field is already up to date, no changes are made.
// If an update is made, it sleeps for 10 seconds to allow Gatekeeper to pick up the changes.
// UpdateCABundle removes expired certificates to prevent the bundle from growing indefinitely.
// DefaultPropagationDelay is the default time to wait for Gatekeeper
// to pick up CA bundle changes.
const DefaultPropagationDelay = 10 * time.Second

// UpdateCABundle ensures that the `caBundle` field in the Provider
// object contains the CA certificates in $certsDir/ca.crt. If the
// field is already up to date, no changes are made. If an update is
// made, it sleeps for the specified propagation delay to allow
// Gatekeeper to pick up the changes. UpdateCABundle removes expired
// certificates to prevent the bundle from growing indefinitely.
func UpdateCABundle(ctx context.Context, k8sClient dynamic.Interface, bundlePath string) error {
return UpdateCABundleWithDelay(ctx, k8sClient, bundlePath, DefaultPropagationDelay)
}

// UpdateCABundleWithDelay is like UpdateCABundle but allows
// specifying a custom propagation delay. This is useful for testing
// where a shorter or zero delay is desired.
func UpdateCABundleWithDelay(ctx context.Context, k8sClient dynamic.Interface, bundlePath string, propagationDelay time.Duration) error {
provider, err := getProvider(ctx, k8sClient)
if err != nil {
return fmt.Errorf("failed to get Provider object: %w", err)
Expand All @@ -52,9 +63,11 @@ func UpdateCABundle(ctx context.Context, k8sClient dynamic.Interface, bundlePath
}

slog.Info("Successfully updated CA bundle in Provider object.")
slog.Info("Sleeping to allow Gatekeeper to pick up the changes",
"sleep_time", propagationDelay)
time.Sleep(propagationDelay)
if propagationDelay > 0 {
slog.Info("Sleeping to allow Gatekeeper to pick up the changes",
"sleep_time", propagationDelay)
time.Sleep(propagationDelay)
}
slog.Info("Update CA bundle done")

return nil
Expand Down
3 changes: 1 addition & 2 deletions pkg/cainjector/injector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ func TestUpdateCABundle(t *testing.T) {
}
err := v1beta1.AddToScheme(scheme.Scheme)
require.NoError(t, err)
propagationDelay = 0 // speed up tests

for _, test := range cases {
t.Run(test.Name, func(t *testing.T) {
Expand All @@ -105,7 +104,7 @@ func TestUpdateCABundle(t *testing.T) {
caPath := path.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(caPath, test.additionalBundle, 0600))

err = UpdateCABundle(t.Context(), client, caPath)
err = UpdateCABundleWithDelay(t.Context(), client, caPath, 0) // Use zero delay for tests
if test.ErrorMsg != "" {
require.ErrorContains(t, err, test.ErrorMsg)
} else {
Expand Down
37 changes: 29 additions & 8 deletions pkg/fetcher/bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package fetcher

import (
"context"
"errors"
"fmt"
"io"
"runtime"
Expand All @@ -22,6 +23,10 @@ var (
runtime.GOARCH)
)

// MaxBundleSize is the max number of bytes read from a remote OCI registry
// when fetching a bundle. The Default value is 10MB.
var MaxBundleSize int64 = 10 << 20

// DefaultBundleFetcher is the default implementation of the BundleFetcher.
type DefaultBundleFetcher struct{}

Expand All @@ -38,7 +43,7 @@ func (*DefaultBundleFetcher) GetRemoteOptions(ctx context.Context, kc authn.Keyc

// BundleFromName fetches a sigstore bundle for a container from
// a registry.
// This is copied from
// This is based on
// https://github.com/github/policy-controller/blob/09dab43394666d59c15ded66aee622097af58b77/pkg/webhook/bundle.go#L125
func BundleFromName(ref name.Reference, remoteOpts []remote.Option) ([]*bundle.Bundle, *v1.Hash, error) {
desc, err := remote.Get(ref, remoteOpts...)
Expand All @@ -59,26 +64,42 @@ func BundleFromName(ref name.Reference, remoteOpts []remote.Option) ([]*bundle.B
bundles := make([]*bundle.Bundle, 0)

for _, refDesc := range refManifest.Manifests {
var refImg v1.Image
var layers []v1.Layer
var layer0 io.ReadCloser
var err error

if !strings.HasPrefix(refDesc.ArtifactType, "application/vnd.dev.sigstore.bundle") {
continue
}

refImg, err := remote.Image(ref.Context().Digest(refDesc.Digest.String()), remoteOpts...)
refImg, err = remote.Image(ref.Context().Digest(refDesc.Digest.String()), remoteOpts...)
if err != nil {
return nil, nil, fmt.Errorf("error getting referrer image: %w", err)
}
layers, err := refImg.Layers()
layers, err = refImg.Layers()
if err != nil {
return nil, nil, fmt.Errorf("error getting referrer image: %w", err)
return nil, nil, fmt.Errorf("error getting referrer image layers: %w", err)
}
layer0, err := layers[0].Uncompressed()
if len(layers) == 0 {
return nil, nil, errors.New("error getting referrer image: no layers found")
}
layer0, err = layers[0].Uncompressed()
if err != nil {
return nil, nil, fmt.Errorf("error getting referrer image: %w", err)
return nil, nil, fmt.Errorf("error decompressing layer: %w", err)
}
bundleBytes, err := io.ReadAll(layer0)
bundleBytes, err := io.ReadAll(io.LimitReader(layer0,
MaxBundleSize+1))
layer0.Close()
if err != nil {
return nil, nil, fmt.Errorf("error getting referrer image: %w", err)
return nil, nil, fmt.Errorf("error reading bundle layer: %w", err)
}

// check if we didn't read all the data
if int64(len(bundleBytes)) > MaxBundleSize {
return nil, nil, fmt.Errorf("bundle size exceeds maximum allowed size of %d bytes", MaxBundleSize)
}

b := &bundle.Bundle{}
err = b.UnmarshalJSON(bundleBytes)
if err != nil {
Expand Down
5 changes: 3 additions & 2 deletions pkg/verifier/multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package verifier

import (
"crypto/x509"
"errors"
"fmt"
"log/slog"

Expand Down Expand Up @@ -80,10 +81,10 @@ func getIssuer(b *bundle.Bundle) (string, error) {
var err error

if vc, err = b.VerificationContent(); err != nil {
return "", err
return "", fmt.Errorf("no verification content in bundle: %w", err)
}
if c = vc.Certificate(); c == nil {
return "", err
return "", errors.New("no certificate found in bundle")
}

if len(c.Issuer.Organization) != 1 {
Expand Down
29 changes: 29 additions & 0 deletions pkg/verifier/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ func GHVerifier(td string) (*Verifier, error) {
if td == "" || td == "dotcom" {
target = defaultTR
} else {
if !validTrustDomain(td) {
return nil, fmt.Errorf("invalid trust domain: %q", td)
}
target = fmt.Sprintf("%s.%s", td, defaultTR)
}

Expand All @@ -91,6 +94,32 @@ func GHVerifier(td string) (*Verifier, error) {
)
}

// validTrustDomain validates that a trust domain contains only safe
// characters. Trust domains should be alphanumeric with hyphens, similar
// to DNS labels.
func validTrustDomain(td string) bool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use some regex matching here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's fair. Maybe easier to work with.

if len(td) == 0 || len(td) > 63 {
return false
}
for i, c := range td {
if c >= 'a' && c <= 'z' {
continue
}
if c >= 'A' && c <= 'Z' {
continue
}
if c >= '0' && c <= '9' {
continue
}
// Hyphen allowed but not at start or end
if c == '-' && i > 0 && i < len(td)-1 {
continue
}
return false
}
return true
}

// Verify iterates of the provided bundles and returns a set of verification
// results using VerifyOne.
func (v *Verifier) Verify(bundles []*bundle.Bundle, h *v1.Hash) ([]*verify.VerificationResult, error) {
Expand Down
Loading