diff --git a/.github/workflows/go.yaml b/.github/workflows/go.yaml index a33512790..633c63b46 100644 --- a/.github/workflows/go.yaml +++ b/.github/workflows/go.yaml @@ -17,24 +17,30 @@ jobs: uses: actions/setup-go@v5 with: go-version: "1.23.2" + cache: true + cache-dependency-path: go.sum - - name: golangci-lint - uses: golangci/golangci-lint-action@v6 - with: - version: v1.61 - args: --timeout=5m - - name: ./gomod.sh run: ./gomod.sh - - name: exhaustive github.com/nishanths/exhaustive@v0.12.0 + - name: Run `golangci-lint@v2.1.2` + run: | + go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.1.2 + golangci-lint run + + - name: Run `shadow@v0.31.0` + run: | + go install golang.org/x/tools/go/analysis/passes/shadow/cmd/shadow@v0.31.0 + shadow ./... + + - name: Run `exhaustive@v0.12.0` run: | go install github.com/nishanths/exhaustive/cmd/exhaustive@v0.12.0 exhaustive -default-signifies-exhaustive ./... - - - name: deadcode golang.org/x/tools/cmd/deadcode@v0.26.0 + + - name: Run `deadcode@v0.31.0` run: | - go install golang.org/x/tools/cmd/deadcode@v0.26.0 + go install golang.org/x/tools/cmd/deadcode@v0.31.0 output=$(deadcode -test ./...) if [[ -n "$output" ]]; then echo "🚨 Deadcode found:" @@ -44,6 +50,20 @@ jobs: echo "βœ… No deadcode found" fi + - name: Run `goimports@v0.31.0` + run: | + go install golang.org/x/tools/cmd/goimports@v0.31.0 + # Find all .go files excluding paths containing 'mock' and run goimports + non_compliant_files=$(find . -type f -name "*.go" ! -path "*mock*" | xargs goimports -local "github.com/stellar/stellar-disbursement-platform-backend" -l) + + if [ -n "$non_compliant_files" ]; then + echo "🚨 The following files are not compliant with goimports:" + echo "$non_compliant_files" + exit 1 + else + echo "βœ… All files are compliant with goimports." + fi + build: runs-on: ubuntu-latest steps: @@ -54,6 +74,8 @@ jobs: uses: actions/setup-go@v5 with: go-version: "1.23.2" + cache: true + cache-dependency-path: go.sum - name: Build Project run: go build ./... @@ -82,7 +104,7 @@ jobs: PGPASSWORD: postgres PGDATABASE: postgres DATABASE_URL: postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable - + steps: - name: Checkout uses: actions/checkout@v4 @@ -91,6 +113,30 @@ jobs: uses: actions/setup-go@v5 with: go-version: "1.23.2" + cache: true + cache-dependency-path: go.sum + + - name: Install gotestsum@v1.11.0 + run: go install gotest.tools/gotestsum@v1.11.0 + + - name: Run tests + run: gotestsum --format-hide-empty-pkg --format pkgname-and-test-fails -- -coverprofile=c.out ./... -timeout 3m -coverpkg ./... - - name: Run Tests - run: go test -v -race -cover ./... + - name: Validate Test Coverage Threshold + env: + TESTCOVERAGE_THRESHOLD: 65 # percentage + run: | + echo "Quality Gate: Checking if test coverage is above threshold..." + echo "Threshold: $TESTCOVERAGE_THRESHOLD%" + totalCoverage=`./scripts/exclude_from_coverage.sh && go tool cover -func=c.out | grep total: | grep -Eo '[0-9]+\.[0-9]+'` + echo "Test Coverage: $totalCoverage%" + echo "-------------------------" + if (( $(echo "$totalCoverage $TESTCOVERAGE_THRESHOLD" | awk '{print ($1 >= $2)}') )); then + echo " $totalCoverage% > $TESTCOVERAGE_THRESHOLD%" + echo "Current test coverage is above threshold πŸŽ‰πŸŽ‰πŸŽ‰! Please keep up the good work!" + else + echo " $totalCoverage% < $TESTCOVERAGE_THRESHOLD%" + echo "🚨 Current test coverage is below threshold 😱! Please add more unit tests or adjust threshold to a lower value." + echo "Failed 😭" + exit 1 + fi diff --git a/.gitignore b/.gitignore index c79af127c..9011b8e1d 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ captive-core*/ env.sh .idea __debug_bin* +c.out diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 000000000..2d9941d39 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,58 @@ +version: "2" + +formatters: + enable: + - gofmt + - gofumpt + + settings: + gofmt: + simplify: false + +linters: + enable: + - govet + - rowserrcheck + - nilerr + - nilnesserr + - errname + - errorlint + - errcheck + - wrapcheck + + settings: + staticcheck: + checks: + - all + - -ST1000 + - -QF1008 + + errcheck: + # report about assignment of errors to blank identifier: `num, _ := strconv.Atoi(numStr)`. + check-blank: true + + exclusions: + paths: + # mocks are ignored + - ".*.mock.*\\.go$" + + rules: + - path: _test\.go + linters: [wrapcheck] + + - path: (_test\.go|dbtest\.go)$ + linters: [errcheck] + text: '\.Close' + + # Exclude some `staticcheck` messages. + - linters: [staticcheck] + text: '(QF1008|SqlDB should be SQLDB)' + + # Mode of the generated files analysis. + # + # - `strict`: Only fules containing `^// Code generated .* DO NOT EDIT\.$` are excluded. + # - `lax`: Sources containing lines like `autogenerated file`, `code generated`, `do not edit`, etc. are excluded. + # - `disable`: disable the generated files exclusion. + # + # Default: strict + generated: lax diff --git a/.mockery.yml b/.mockery.yml index d048e14b8..c03e1383d 100644 --- a/.mockery.yml +++ b/.mockery.yml @@ -5,7 +5,7 @@ formatter: goimports log-level: info structname: "{{.Mock}}{{.InterfaceName}}" pkgname: "{{.SrcPackageName}}" -filename: "mock_{{.InterfaceName | snakecase}}.go" +filename: "mocks.go" recursive: false require-template-schema-exists: true template: testify @@ -14,5 +14,3 @@ packages: github.com/stellar/wallet-backend/cmd: interfaces: ChAccCmdServiceInterface: - config: - filename: "channel_account_mocks.go" diff --git a/cmd/channel_account.go b/cmd/channel_account.go index fd721fc61..e4e33fced 100644 --- a/cmd/channel_account.go +++ b/cmd/channel_account.go @@ -18,6 +18,7 @@ import ( "github.com/stellar/wallet-backend/internal/services" "github.com/stellar/wallet-backend/internal/signing/store" signingutils "github.com/stellar/wallet-backend/internal/signing/utils" + internalUtils "github.com/stellar/wallet-backend/internal/utils" ) // ChAccCmdServiceInterface is the interface for the channel account command service. It is used to allow mocking the @@ -32,6 +33,7 @@ type ChAccCmdService struct{} var _ ChAccCmdServiceInterface = (*ChAccCmdService)(nil) +//nolint:wrapcheck // Skipping wrapcheck because this is just a proxy to the service. func (s *ChAccCmdService) EnsureChannelAccounts(ctx context.Context, chAccService services.ChannelAccountService, amount int64) error { return chAccService.EnsureChannelAccounts(ctx, amount) } @@ -126,7 +128,7 @@ func (c *channelAccountCmd) Command(cmdService ChAccCmdServiceInterface) *cobra. return nil }, RunE: func(cmd *cobra.Command, args []string) error { - defer distAccSigClientOpts.DBConnectionPool.Close() + defer internalUtils.DeferredClose(cmd.Context(), distAccSigClientOpts.DBConnectionPool, "closing distAccSigClient's db connection pool") amount, err := strconv.Atoi(args[0]) if err != nil { return fmt.Errorf("invalid [amount] argument=%s", args[0]) diff --git a/cmd/channel_accounts_test.go b/cmd/channel_accounts_test.go index 5202ceb14..65ddf1330 100644 --- a/cmd/channel_accounts_test.go +++ b/cmd/channel_accounts_test.go @@ -32,8 +32,8 @@ func Test_ChannelAccountsCommand_EnsureCommand(t *testing.T) { }) t.Run("🟒executes_successfully", func(t *testing.T) { - mChAccService. - On("EnsureChannelAccounts", mock.AnythingOfType("context.backgroundCtx"), mock.AnythingOfType("*services.channelAccountService"), int64(2)). + mChAccService.EXPECT(). + EnsureChannelAccounts(mock.AnythingOfType("context.backgroundCtx"), mock.AnythingOfType("*services.channelAccountService"), int64(2)). Return(nil). Once() err := rootCmd.Execute() @@ -41,8 +41,8 @@ func Test_ChannelAccountsCommand_EnsureCommand(t *testing.T) { }) t.Run("πŸ”΄fails_if_ChannelAccountsService_fails", func(t *testing.T) { - mChAccService. - On("EnsureChannelAccounts", mock.AnythingOfType("context.backgroundCtx"), mock.AnythingOfType("*services.channelAccountService"), int64(2)). + mChAccService.EXPECT(). + EnsureChannelAccounts(mock.AnythingOfType("context.backgroundCtx"), mock.AnythingOfType("*services.channelAccountService"), int64(2)). Return(errors.New("foo bar baz")). Once() err := rootCmd.Execute() @@ -50,6 +50,4 @@ func Test_ChannelAccountsCommand_EnsureCommand(t *testing.T) { assert.ErrorContains(t, err, "ensuring the number of channel accounts is created") assert.ErrorContains(t, err, "foo bar baz") }) - - mChAccService.AssertExpectations(t) } diff --git a/cmd/distribution_account.go b/cmd/distribution_account.go index 073d90d7e..57d44fa40 100644 --- a/cmd/distribution_account.go +++ b/cmd/distribution_account.go @@ -8,6 +8,7 @@ import ( "github.com/spf13/cobra" "github.com/stellar/go/support/config" "github.com/stellar/go/support/log" + "github.com/stellar/wallet-backend/cmd/utils" "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/services" diff --git a/cmd/ingest.go b/cmd/ingest.go index 581dcf569..386e036a3 100644 --- a/cmd/ingest.go +++ b/cmd/ingest.go @@ -10,6 +10,7 @@ import ( "github.com/spf13/cobra" "github.com/stellar/go/support/config" "github.com/stellar/go/support/log" + "github.com/stellar/wallet-backend/cmd/utils" "github.com/stellar/wallet-backend/internal/apptracker/sentry" "github.com/stellar/wallet-backend/internal/ingest" diff --git a/cmd/migrate.go b/cmd/migrate.go index da23bcc4b..789d8b3cd 100644 --- a/cmd/migrate.go +++ b/cmd/migrate.go @@ -9,6 +9,7 @@ import ( "github.com/spf13/cobra" "github.com/stellar/go/support/config" "github.com/stellar/go/support/log" + "github.com/stellar/wallet-backend/cmd/utils" "github.com/stellar/wallet-backend/internal/db" ) @@ -76,7 +77,7 @@ func (c *migrateCmd) RunMigrateDown(ctx context.Context, databaseURL string, arg return fmt.Errorf("invalid [count] argument: %s", args[0]) } if err := executeMigrations(ctx, databaseURL, migrate.Down, count); err != nil { - return fmt.Errorf("executing migrate down: %v", err) + return fmt.Errorf("executing migrate down: %w", err) } return nil } diff --git a/cmd/channel_account_mocks.go b/cmd/mocks.go similarity index 100% rename from cmd/channel_account_mocks.go rename to cmd/mocks.go diff --git a/cmd/utils/global_options.go b/cmd/utils/global_options.go index e91cdcdbd..3ea903126 100644 --- a/cmd/utils/global_options.go +++ b/cmd/utils/global_options.go @@ -140,7 +140,6 @@ func StartLedgerOption(configKey *int) *config.ConfigOption { FlagDefault: 0, Required: true, } - } func EndLedgerOption(configKey *int) *config.ConfigOption { @@ -152,7 +151,6 @@ func EndLedgerOption(configKey *int) *config.ConfigOption { FlagDefault: 0, Required: true, } - } func AWSOptions(awsRegionConfigKey *string, kmsKeyARN *string, required bool) config.ConfigOptions { diff --git a/cmd/utils/password_prompter.go b/cmd/utils/password_prompter.go index 7bc71266b..90060363c 100644 --- a/cmd/utils/password_prompter.go +++ b/cmd/utils/password_prompter.go @@ -21,12 +21,19 @@ type defaultPasswordPrompter struct { var _ PasswordPrompter = (*defaultPasswordPrompter)(nil) func (pp *defaultPasswordPrompter) Run() (string, error) { - fmt.Fprint(pp.stdout, pp.inputLabelText, " ") + _, err := fmt.Fprint(pp.stdout, pp.inputLabelText, " ") + if err != nil { + return "", fmt.Errorf("writing input label text: %w", err) + } + password, err := term.ReadPassword(int(pp.stdin.Fd())) if err != nil { - return "", err + return "", fmt.Errorf("reading password: %w", err) + } + _, err = fmt.Fprintln(pp.stdout) + if err != nil { + return "", fmt.Errorf("writing newline: %w", err) } - fmt.Fprintln(pp.stdout) return string(password), nil } diff --git a/cmd/utils/tss_options.go b/cmd/utils/tss_options.go index 4e239914e..e841a80d6 100644 --- a/cmd/utils/tss_options.go +++ b/cmd/utils/tss_options.go @@ -24,7 +24,6 @@ func RPCCallerChannelMaxWorkersOption(configKey *int) *config.ConfigOption { ConfigKey: configKey, FlagDefault: 100, } - } func ErrorHandlerJitterChannelBufferSizeOption(configKey *int) *config.ConfigOption { @@ -58,7 +57,6 @@ func ErrorHandlerNonJitterChannelBufferSizeOption(configKey *int) *config.Config FlagDefault: 1000, Required: true, } - } func ErrorHandlerNonJitterChannelMaxWorkersOption(configKey *int) *config.ConfigOption { @@ -103,7 +101,6 @@ func ErrorHandlerJitterChannelMaxRetriesOptions(configKey *int) *config.ConfigOp FlagDefault: 3, Required: true, } - } func ErrorHandlerNonJitterChannelMaxRetriesOption(configKey *int) *config.ConfigOption { diff --git a/cmd/utils/utils.go b/cmd/utils/utils.go index 9df3ebb2d..b591be84d 100644 --- a/cmd/utils/utils.go +++ b/cmd/utils/utils.go @@ -5,6 +5,7 @@ import ( "github.com/spf13/cobra" "github.com/stellar/go/support/config" + "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/signing" "github.com/stellar/wallet-backend/internal/signing/awskms" @@ -41,7 +42,14 @@ type SignatureClientOptions struct { EncryptionPassphrase string } -func SignatureClientResolver(signatureClientOpts *SignatureClientOptions) (signing.SignatureClient, error) { +//nolint:wrapcheck // defer is used to wrap the error +func SignatureClientResolver(signatureClientOpts *SignatureClientOptions) (sigClient signing.SignatureClient, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("resolving signature client: %w", err) + } + }() + switch signatureClientOpts.Type { case signing.EnvSignatureClientType: return signing.NewEnvSignatureClient(signatureClientOpts.DistributionAccountSecretKey, signatureClientOpts.NetworkPassphrase) diff --git a/cmd/utils/utils_test.go b/cmd/utils/utils_test.go index 7f6f6acc6..b323be3d9 100644 --- a/cmd/utils/utils_test.go +++ b/cmd/utils/utils_test.go @@ -5,8 +5,9 @@ import ( "github.com/stellar/go/keypair" "github.com/stellar/go/network" - "github.com/stellar/wallet-backend/internal/signing" "github.com/stretchr/testify/assert" + + "github.com/stellar/wallet-backend/internal/signing" ) func TestSignatureClientResolver(t *testing.T) { @@ -25,14 +26,14 @@ func TestSignatureClientResolver(t *testing.T) { Type: signing.EnvSignatureClientType, DistributionAccountSecretKey: "invalid", }) - assert.EqualError(t, err, "parsing distribution account private key: base32 decode failed: illegal base32 data at input byte 7") + assert.EqualError(t, err, "resolving signature client: parsing distribution account private key: base32 decode failed: illegal base32 data at input byte 7") assert.Nil(t, sc) sc, err = SignatureClientResolver(&SignatureClientOptions{ Type: signing.EnvSignatureClientType, DistributionAccountSecretKey: keypair.MustRandom().Seed(), }) - assert.EqualError(t, err, "invalid network passphrase provided: ") + assert.EqualError(t, err, "resolving signature client: invalid network passphrase provided: ") assert.Nil(t, sc) sc, err = SignatureClientResolver(&SignatureClientOptions{ @@ -49,7 +50,7 @@ func TestSignatureClientResolver(t *testing.T) { Type: signing.KMSSignatureClientType, DistributionAccountPublicKey: keypair.MustRandom().Address(), }) - assert.EqualError(t, err, "instantiating kms client: aws region cannot be empty") + assert.EqualError(t, err, "resolving signature client: instantiating kms client: aws region cannot be empty") assert.Nil(t, sc) sc, err = SignatureClientResolver(&SignatureClientOptions{ @@ -65,7 +66,7 @@ func TestSignatureClientResolver(t *testing.T) { DistributionAccountPublicKey: keypair.MustRandom().Address(), AWSRegion: "us-east-2", }) - assert.EqualError(t, err, "invalid network passphrase provided: ") + assert.EqualError(t, err, "resolving signature client: invalid network passphrase provided: ") assert.Nil(t, sc) sc, err = SignatureClientResolver(&SignatureClientOptions{ @@ -74,7 +75,7 @@ func TestSignatureClientResolver(t *testing.T) { NetworkPassphrase: network.PublicNetworkPassphrase, AWSRegion: "us-east-2", }) - assert.EqualError(t, err, "aws key arn cannot be empty") + assert.EqualError(t, err, "resolving signature client: aws key arn cannot be empty") assert.Nil(t, sc) sc, err = SignatureClientResolver(&SignatureClientOptions{ diff --git a/internal/apptracker/mock.go b/internal/apptracker/mocks.go similarity index 100% rename from internal/apptracker/mock.go rename to internal/apptracker/mocks.go diff --git a/internal/apptracker/sentry/mocks.go b/internal/apptracker/sentry/mocks.go new file mode 100644 index 000000000..0da19e957 --- /dev/null +++ b/internal/apptracker/sentry/mocks.go @@ -0,0 +1,74 @@ +package sentry + +import ( + "testing" + "time" + + "github.com/getsentry/sentry-go" + "github.com/stretchr/testify/mock" +) + +// MockSentry is a mock struct to capture function calls +type MockSentry struct { + mock.Mock +} + +func (m *MockSentry) CaptureMessage(message string) *sentry.EventID { + args := m.Called(message) + return args.Get(0).(*sentry.EventID) +} + +func (m *MockSentry) CaptureException(exception error) *sentry.EventID { + args := m.Called(exception) + return args.Get(0).(*sentry.EventID) +} + +func (m *MockSentry) Init(options sentry.ClientOptions) error { + args := m.Called(options) + return args.Error(0) +} + +func (m *MockSentry) Flush(timeout time.Duration) bool { + m.Called(timeout) + return true +} + +func (m *MockSentry) Recover() *sentry.EventID { + args := m.Called() + return args.Get(0).(*sentry.EventID) +} + +func NewMockSentry(t *testing.T) *MockSentry { + t.Helper() + + mockSentry := &MockSentry{} + mockSentry. + On("Flush", mock.Anything).Return(true).Once(). + On("Recover").Return((*sentry.EventID)(nil)).Once() + + // Save the original functions + originalInitFunc := InitFunc + originalFlushFunc := FlushFunc + originalRecoverFunc := RecoverFunc + originalCaptureMessageFunc := captureMessageFunc + originalCaptureExceptionFunc := captureExceptionFunc + + // Set the mock functions + InitFunc = mockSentry.Init + FlushFunc = mockSentry.Flush + RecoverFunc = mockSentry.Recover + captureMessageFunc = mockSentry.CaptureMessage + captureExceptionFunc = mockSentry.CaptureException + + // Restore the original functions after the test + t.Cleanup(func() { + InitFunc = originalInitFunc + FlushFunc = originalFlushFunc + RecoverFunc = originalRecoverFunc + captureMessageFunc = originalCaptureMessageFunc + captureExceptionFunc = originalCaptureExceptionFunc + mockSentry.AssertExpectations(t) + }) + + return mockSentry +} diff --git a/internal/apptracker/sentry/sentry_tracker.go b/internal/apptracker/sentry/sentry_tracker.go index 1747eec48..790fbb824 100644 --- a/internal/apptracker/sentry/sentry_tracker.go +++ b/internal/apptracker/sentry/sentry_tracker.go @@ -1,6 +1,7 @@ package sentry import ( + "fmt" "time" "github.com/getsentry/sentry-go" @@ -31,10 +32,9 @@ func NewSentryTracker(dsn string, env string, flushFreq int) (*SentryTracker, er Dsn: dsn, Environment: env, }); err != nil { - return nil, err + return nil, fmt.Errorf("unable to initialize sentry: %w", err) } defer FlushFunc(time.Second * time.Duration(flushFreq)) defer RecoverFunc() return &SentryTracker{}, nil - } diff --git a/internal/apptracker/sentry/sentry_tracker_test.go b/internal/apptracker/sentry/sentry_tracker_test.go index def1a6255..15fe3da05 100644 --- a/internal/apptracker/sentry/sentry_tracker_test.go +++ b/internal/apptracker/sentry/sentry_tracker_test.go @@ -3,105 +3,63 @@ package sentry import ( "errors" "testing" - "time" "github.com/getsentry/sentry-go" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" ) -// MockSentry is a mock struct to capture function calls -type MockSentry struct { - mock.Mock -} - -func (m *MockSentry) CaptureMessage(message string) *sentry.EventID { - args := m.Called(message) - return args.Get(0).(*sentry.EventID) -} - -func (m *MockSentry) CaptureException(exception error) *sentry.EventID { - args := m.Called(exception) - return args.Get(0).(*sentry.EventID) -} - -func (m *MockSentry) Init(options sentry.ClientOptions) error { - args := m.Called(options) - return args.Error(0) -} - -func (m *MockSentry) Flush(timeout time.Duration) bool { - m.Called(timeout) - return true -} +func TestSentryTracker_CaptureMessage(t *testing.T) { + mockSentry := NewMockSentry(t) + mockSentry. + On("Init", mock.Anything).Return(nil).Once(). + On("CaptureMessage", "Test message").Return((*sentry.EventID)(nil)).Once() -func (m *MockSentry) Recover() *sentry.EventID { - args := m.Called() - return args.Get(0).(*sentry.EventID) -} + tracker, err := NewSentryTracker("dsn", "test-env", 5) + require.NoError(t, err) + require.NotNil(t, tracker) -func TestSentryTracker_CaptureMessage(t *testing.T) { - mockSentry := MockSentry{} - captureMessageFunc = mockSentry.CaptureMessage - defer func() { captureMessageFunc = sentry.CaptureMessage }() - mockSentry.On("CaptureMessage", "Test message").Return((*sentry.EventID)(nil)) - tracker, _ := NewSentryTracker("sentrydsn", "test", 5) tracker.CaptureMessage("Test message") - - mockSentry.AssertCalled(t, "CaptureMessage", "Test message") } func TestSentryTracker_CaptureException(t *testing.T) { - mockSentry := MockSentry{} - captureExceptionFunc = mockSentry.CaptureException - defer func() { captureExceptionFunc = sentry.CaptureException }() // Reset after the test + mockSentry := NewMockSentry(t) testError := errors.New("Test exception") - mockSentry.On("CaptureException", testError).Return((*sentry.EventID)(nil)) - tracker, _ := NewSentryTracker("sentrydsn", "test", 5) - tracker.CaptureException(testError) + mockSentry. + On("Init", mock.Anything).Return(nil).Once(). + On("CaptureException", testError).Return((*sentry.EventID)(nil)).Once() + + tracker, err := NewSentryTracker("dsn", "test-env", 5) + require.NoError(t, err) + require.NotNil(t, tracker) - mockSentry.AssertCalled(t, "CaptureException", testError) + tracker.CaptureException(testError) } func TestNewSentryTracker_Success(t *testing.T) { - mockSentry := MockSentry{} - - InitFunc = mockSentry.Init - FlushFunc = mockSentry.Flush - RecoverFunc = mockSentry.Recover - - defer func() { - InitFunc = sentry.Init - FlushFunc = sentry.Flush - RecoverFunc = sentry.Recover - }() - - mockSentry.On("Init", mock.Anything).Return(nil) - mockSentry.On("Flush", time.Second*5).Return(true) - mockSentry.On("Recover").Return((*sentry.EventID)(nil)) + mockSentry := NewMockSentry(t) + mockSentry. + On("Init", mock.Anything).Return(nil).Once() tracker, err := NewSentryTracker("dsn", "test-env", 5) - - assert.NoError(t, err) - assert.NotNil(t, tracker) - - mockSentry.AssertCalled(t, "Init", mock.Anything) - mockSentry.AssertCalled(t, "Flush", time.Second*5) - mockSentry.AssertCalled(t, "Recover") + require.NoError(t, err) + require.NotNil(t, tracker) } func TestNewSentryTracker_InitFailure(t *testing.T) { mockSentry := MockSentry{} InitFunc = mockSentry.Init - defer func() { + t.Cleanup(func() { InitFunc = sentry.Init - }() + mockSentry.AssertExpectations(t) + }) + initError := errors.New("init error") - mockSentry.On("Init", mock.Anything).Return(initError) - tracker, err := NewSentryTracker("dsn", "test-env", 5) - assert.Error(t, err) - assert.Equal(t, initError, err) - assert.Nil(t, tracker) + mockSentry. + On("Init", mock.Anything).Return(initError).Once() - mockSentry.AssertCalled(t, "Init", mock.Anything) + tracker, err := NewSentryTracker("dsn", "test-env", 5) + require.Error(t, err) + require.ErrorContains(t, err, "unable to initialize sentry: init error") + require.Nil(t, tracker) } diff --git a/internal/data/accounts_test.go b/internal/data/accounts_test.go index 088991680..ed3705908 100644 --- a/internal/data/accounts_test.go +++ b/internal/data/accounts_test.go @@ -6,18 +6,18 @@ import ( "testing" "github.com/stellar/go/keypair" - "github.com/stellar/wallet-backend/internal/db" - "github.com/stellar/wallet-backend/internal/db/dbtest" - "github.com/stellar/wallet-backend/internal/metrics" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + + "github.com/stellar/wallet-backend/internal/db" + "github.com/stellar/wallet-backend/internal/db/dbtest" + "github.com/stellar/wallet-backend/internal/metrics" ) func TestAccountModelInsert(t *testing.T) { dbt := dbtest.Open(t) defer dbt.Close() - dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) require.NoError(t, err) defer dbConnectionPool.Close() @@ -48,7 +48,6 @@ func TestAccountModelInsert(t *testing.T) { func TestAccountModelDelete(t *testing.T) { dbt := dbtest.Open(t) defer dbt.Close() - dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) require.NoError(t, err) defer dbConnectionPool.Close() @@ -82,7 +81,6 @@ func TestAccountModelDelete(t *testing.T) { func TestAccountModelIsAccountFeeBumpEligible(t *testing.T) { dbt := dbtest.Open(t) defer dbt.Close() - dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) require.NoError(t, err) defer dbConnectionPool.Close() diff --git a/internal/data/fixtures.go b/internal/data/fixtures.go index 9ae7f85ac..ad4b13432 100644 --- a/internal/data/fixtures.go +++ b/internal/data/fixtures.go @@ -4,8 +4,9 @@ import ( "context" "testing" - "github.com/stellar/wallet-backend/internal/db" "github.com/stretchr/testify/require" + + "github.com/stellar/wallet-backend/internal/db" ) func InsertTestPayments(t *testing.T, ctx context.Context, payments []Payment, connectionPool db.ConnectionPool) { diff --git a/internal/data/payments.go b/internal/data/payments.go index 55b896db1..4c9f0f504 100644 --- a/internal/data/payments.go +++ b/internal/data/payments.go @@ -44,7 +44,7 @@ func (m *PaymentModel) GetLatestLedgerSynced(ctx context.Context, cursorName str m.MetricsService.ObserveDBQueryDuration("SELECT", "ingest_store", duration) m.MetricsService.IncDBQuery("SELECT", "ingest_store") // First run, key does not exist yet - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return 0, nil } if err != nil { diff --git a/internal/data/payments_test.go b/internal/data/payments_test.go index 315de999f..2e12e94d6 100644 --- a/internal/data/payments_test.go +++ b/internal/data/payments_test.go @@ -7,13 +7,14 @@ import ( "time" "github.com/stellar/go/xdr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stellar/wallet-backend/internal/metrics" "github.com/stellar/wallet-backend/internal/utils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" ) func TestPaymentModelAddPayment(t *testing.T) { diff --git a/internal/data/query_utils.go b/internal/data/query_utils.go index 75751acfa..f04e82ec5 100644 --- a/internal/data/query_utils.go +++ b/internal/data/query_utils.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/jmoiron/sqlx" + "github.com/stellar/wallet-backend/internal/db" ) diff --git a/internal/db/db.go b/internal/db/db.go index 0474d11fa..45c97d147 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -47,10 +47,12 @@ func OpenDBConnectionPool(dataSourceName string) (ConnectionPool, error) { return &ConnectionPoolImplementation{DB: sqlxDB}, nil } +//nolint:wrapcheck // this is a thin layer on top of the sqlx.DB.BeginTxx method func (db *ConnectionPoolImplementation) BeginTxx(ctx context.Context, opts *sql.TxOptions) (Transaction, error) { return db.DB.BeginTxx(ctx, opts) } +//nolint:wrapcheck // this is a thin layer on top of the sqlx.DB.PingContext method func (db *ConnectionPoolImplementation) Ping(ctx context.Context) error { return db.DB.PingContext(ctx) } diff --git a/internal/db/dbtest/dbtest.go b/internal/db/dbtest/dbtest.go index f4c12f18f..45650ab83 100644 --- a/internal/db/dbtest/dbtest.go +++ b/internal/db/dbtest/dbtest.go @@ -7,6 +7,7 @@ import ( migrate "github.com/rubenv/sql-migrate" "github.com/stellar/go/support/db/dbtest" "github.com/stellar/go/support/db/schema" + "github.com/stellar/wallet-backend/internal/db/migrations" ) diff --git a/internal/db/migrate.go b/internal/db/migrate.go index 7d58abde6..3e21d725b 100644 --- a/internal/db/migrate.go +++ b/internal/db/migrate.go @@ -6,7 +6,9 @@ import ( "net/http" migrate "github.com/rubenv/sql-migrate" + "github.com/stellar/wallet-backend/internal/db/migrations" + "github.com/stellar/wallet-backend/internal/utils" ) func Migrate(ctx context.Context, databaseURL string, direction migrate.MigrationDirection, count int) (int, error) { @@ -14,12 +16,17 @@ func Migrate(ctx context.Context, databaseURL string, direction migrate.Migratio if err != nil { return 0, fmt.Errorf("connecting to the database: %w", err) } - defer dbConnectionPool.Close() + defer utils.DeferredClose(ctx, dbConnectionPool, "closing dbConnectionPool in the Migrate function") m := migrate.HttpFileSystemMigrationSource{FileSystem: http.FS(migrations.FS)} db, err := dbConnectionPool.SqlDB(ctx) if err != nil { return 0, fmt.Errorf("fetching sql.DB: %w", err) } - return migrate.ExecMax(db, dbConnectionPool.DriverName(), m, direction, count) + + appliedMigrationsCount, err := migrate.ExecMax(db, dbConnectionPool.DriverName(), m, direction, count) + if err != nil { + return appliedMigrationsCount, fmt.Errorf("applying migrations: %w", err) + } + return appliedMigrationsCount, nil } diff --git a/internal/db/migrate_test.go b/internal/db/migrate_test.go index 6edaa33d8..330f22d8f 100644 --- a/internal/db/migrate_test.go +++ b/internal/db/migrate_test.go @@ -6,10 +6,11 @@ import ( "testing" migrate "github.com/rubenv/sql-migrate" - "github.com/stellar/wallet-backend/internal/db/dbtest" - "github.com/stellar/wallet-backend/internal/db/migrations" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/stellar/wallet-backend/internal/db/dbtest" + "github.com/stellar/wallet-backend/internal/db/migrations" ) func TestMigrate_up_1(t *testing.T) { diff --git a/internal/metrics/mock.go b/internal/metrics/mocks.go similarity index 100% rename from internal/metrics/mock.go rename to internal/metrics/mocks.go diff --git a/internal/serve/auth/mock.go b/internal/serve/auth/mocks.go similarity index 100% rename from internal/serve/auth/mock.go rename to internal/serve/auth/mocks.go diff --git a/internal/serve/auth/signature_verifier.go b/internal/serve/auth/signature_verifier.go index 2dbb502d8..b28441fc2 100644 --- a/internal/serve/auth/signature_verifier.go +++ b/internal/serve/auth/signature_verifier.go @@ -20,28 +20,26 @@ type SignatureVerifier interface { VerifySignature(ctx context.Context, signatureHeaderContent string, rawReqBody []byte) error } -var ( - ErrStellarSignatureNotVerified = errors.New("neither Signature nor X-Stellar-Signature header could be verified") -) +var ErrStellarSignatureNotVerified = errors.New("neither Signature nor X-Stellar-Signature header could be verified") -type ErrInvalidTimestampFormat struct { +type InvalidTimestampFormatError struct { TimestampString string timestampValueError bool } -func (e ErrInvalidTimestampFormat) Error() string { +func (e InvalidTimestampFormatError) Error() string { if e.timestampValueError { return fmt.Sprintf("signature format different than expected. expected unix seconds, got: %s", e.TimestampString) } return fmt.Sprintf("malformed timestamp: %s", e.TimestampString) } -type ErrExpiredSignatureTimestamp struct { +type ExpiredSignatureTimestampError struct { ExpiredSignatureTimestamp time.Time CheckTime time.Time } -func (e ErrExpiredSignatureTimestamp) Error() string { +func (e ExpiredSignatureTimestampError) Error() string { return fmt.Sprintf("signature timestamp has expired. sig timestamp: %s, check time %s", e.ExpiredSignatureTimestamp.Format(time.RFC3339), e.CheckTime.Format(time.RFC3339)) } @@ -99,7 +97,7 @@ func ExtractTimestampedSignature(signatureHeaderContent string) (t string, s str tHeaderContent := parts[0] timestampParts := strings.SplitN(tHeaderContent, "=", 2) if len(timestampParts) != 2 || strings.TrimSpace(timestampParts[0]) != "t" { - return "", "", &ErrInvalidTimestampFormat{TimestampString: tHeaderContent} + return "", "", &InvalidTimestampFormatError{TimestampString: tHeaderContent} } t = strings.TrimSpace(timestampParts[1]) @@ -115,13 +113,17 @@ func ExtractTimestampedSignature(signatureHeaderContent string) (t string, s str func VerifyGracePeriodSeconds(timestampString string, gracePeriod time.Duration) error { // Note: from Nov 20th, 2286 this RegEx will fail because of an extra digit - if ok, _ := regexp.MatchString(`^\d{10}$`, timestampString); !ok { - return &ErrInvalidTimestampFormat{TimestampString: timestampString, timestampValueError: true} + ok, err := regexp.MatchString(`^\d{10}$`, timestampString) + if !ok { + return &InvalidTimestampFormatError{TimestampString: timestampString, timestampValueError: true} + } + if err != nil { + return fmt.Errorf("attempting to parse timestamp %q with regex: %w", timestampString, err) } timestampUnix, err := strconv.ParseInt(timestampString, 10, 64) if err != nil { - return fmt.Errorf("unable to parse timestamp value %s: %v", timestampString, err) + return fmt.Errorf("unable to parse timestamp value %s: %w", timestampString, err) } return verifyGracePeriod(time.Unix(timestampUnix, 0), gracePeriod) @@ -130,7 +132,7 @@ func VerifyGracePeriodSeconds(timestampString string, gracePeriod time.Duration) func verifyGracePeriod(timestamp time.Time, gracePeriod time.Duration) error { now := time.Now() if !timestamp.Add(gracePeriod).After(now) { - return &ErrExpiredSignatureTimestamp{ExpiredSignatureTimestamp: timestamp, CheckTime: now} + return &ExpiredSignatureTimestampError{ExpiredSignatureTimestamp: timestamp, CheckTime: now} } return nil diff --git a/internal/serve/auth/signature_verifier_test.go b/internal/serve/auth/signature_verifier_test.go index aae45f6a4..0f24abc14 100644 --- a/internal/serve/auth/signature_verifier_test.go +++ b/internal/serve/auth/signature_verifier_test.go @@ -31,7 +31,7 @@ func TestSignatureVerifierVerifySignature(t *testing.T) { require.NoError(t, err) signatureHeaderContent := fmt.Sprintf("t=%d, s=%s", now.Unix(), sig) - err := signatureVerifier.VerifySignature(ctx, signatureHeaderContent, []byte(reqBody)) + err = signatureVerifier.VerifySignature(ctx, signatureHeaderContent, []byte(reqBody)) assert.EqualError(t, err, ErrStellarSignatureNotVerified.Error()) }) @@ -66,8 +66,8 @@ func TestExtractTimestampedSignature(t *testing.T) { assert.Empty(t, s) ts, s, err = ExtractTimestampedSignature("a,b") - var errTimestampFormat *ErrInvalidTimestampFormat - assert.ErrorAs(t, err, &errTimestampFormat) + var invalidTimestampFormatErr *InvalidTimestampFormatError + assert.ErrorAs(t, err, &invalidTimestampFormatErr) assert.EqualError(t, err, "malformed timestamp: a") assert.Empty(t, ts) assert.Empty(t, s) @@ -88,26 +88,26 @@ func TestExtractTimestampedSignature(t *testing.T) { func TestVerifyGracePeriodSeconds(t *testing.T) { t.Run("invalid_timestamp", func(t *testing.T) { - var errTimestampFormat *ErrInvalidTimestampFormat + var invalidTimestampFormatErr *InvalidTimestampFormatError err := VerifyGracePeriodSeconds("", 2*time.Second) - assert.ErrorAs(t, err, &errTimestampFormat) + assert.ErrorAs(t, err, &invalidTimestampFormatErr) assert.EqualError(t, err, "signature format different than expected. expected unix seconds, got: ") err = VerifyGracePeriodSeconds("123", 2*time.Second) - assert.ErrorAs(t, err, &errTimestampFormat) + assert.ErrorAs(t, err, &invalidTimestampFormatErr) assert.EqualError(t, err, "signature format different than expected. expected unix seconds, got: 123") err = VerifyGracePeriodSeconds("12345678910", 2*time.Second) - assert.ErrorAs(t, err, &errTimestampFormat) + assert.ErrorAs(t, err, &invalidTimestampFormatErr) assert.EqualError(t, err, "signature format different than expected. expected unix seconds, got: 12345678910") }) t.Run("successfully_verifies_grace_period", func(t *testing.T) { - var errExpiredSignatureTimestamp *ErrExpiredSignatureTimestamp + var expiredSignatureTimestampErr *ExpiredSignatureTimestampError now := time.Now().Add(-5 * time.Second) ts := now.Unix() err := VerifyGracePeriodSeconds(strconv.FormatInt(ts, 10), 2*time.Second) - assert.ErrorAs(t, err, &errExpiredSignatureTimestamp) + assert.ErrorAs(t, err, &expiredSignatureTimestampErr) assert.ErrorContains(t, err, fmt.Sprintf("signature timestamp has expired. sig timestamp: %s, check time", now.Format(time.RFC3339))) now = time.Now().Add(-1 * time.Second) diff --git a/internal/serve/httperror/errors.go b/internal/serve/httperror/errors.go index 2c8810e0c..363945037 100644 --- a/internal/serve/httperror/errors.go +++ b/internal/serve/httperror/errors.go @@ -6,6 +6,7 @@ import ( "github.com/stellar/go/support/log" "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/wallet-backend/internal/apptracker" ) diff --git a/internal/serve/httperror/errors_test.go b/internal/serve/httperror/errors_test.go index de834f831..1997481c7 100644 --- a/internal/serve/httperror/errors_test.go +++ b/internal/serve/httperror/errors_test.go @@ -9,9 +9,10 @@ import ( "net/http/httptest" "testing" - "github.com/stellar/wallet-backend/internal/apptracker" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/stellar/wallet-backend/internal/apptracker" ) func TestErrorResponseRender(t *testing.T) { diff --git a/internal/serve/httphandler/account_handler.go b/internal/serve/httphandler/account_handler.go index d6e90b6a3..f919edc32 100644 --- a/internal/serve/httphandler/account_handler.go +++ b/internal/serve/httphandler/account_handler.go @@ -6,6 +6,7 @@ import ( "github.com/stellar/go/support/render/httpjson" "github.com/stellar/go/txnbuild" + "github.com/stellar/wallet-backend/internal/apptracker" "github.com/stellar/wallet-backend/internal/entities" "github.com/stellar/wallet-backend/internal/serve/httperror" @@ -138,10 +139,10 @@ func (h AccountHandler) CreateFeeBumpTransaction(rw http.ResponseWriter, req *ht feeBumpTxe, networkPassphrase, err := h.AccountSponsorshipService.WrapTransaction(ctx, tx) if err != nil { - var errOperationNotAllowed *services.ErrOperationNotAllowed + var opNotAllowedErr *services.OperationNotAllowedError switch { case errors.Is(err, services.ErrAccountNotEligibleForBeingSponsored), errors.Is(err, services.ErrFeeExceedsMaximumBaseFee), - errors.Is(err, services.ErrNoSignaturesProvided), errors.As(err, &errOperationNotAllowed): + errors.Is(err, services.ErrNoSignaturesProvided), errors.As(err, &opNotAllowedErr): httperror.BadRequest(err.Error(), nil).Render(rw) return default: diff --git a/internal/serve/httphandler/account_handler_test.go b/internal/serve/httphandler/account_handler_test.go index f3f0dad8b..75e1666fc 100644 --- a/internal/serve/httphandler/account_handler_test.go +++ b/internal/serve/httphandler/account_handler_test.go @@ -17,16 +17,16 @@ import ( "github.com/stellar/go/network" "github.com/stellar/go/txnbuild" "github.com/stellar/go/xdr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stellar/wallet-backend/internal/data" "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stellar/wallet-backend/internal/entities" "github.com/stellar/wallet-backend/internal/metrics" "github.com/stellar/wallet-backend/internal/services" - "github.com/stellar/wallet-backend/internal/services/servicesmocks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" ) func TestAccountHandlerRegisterAccount(t *testing.T) { @@ -63,7 +63,8 @@ func TestAccountHandlerRegisterAccount(t *testing.T) { // Prepare request address := keypair.MustRandom().Address() - req, err := http.NewRequest(http.MethodPost, path.Join("/accounts", address), nil) + var req *http.Request + req, err = http.NewRequest(http.MethodPost, path.Join("/accounts", address), nil) require.NoError(t, err) // Serve request @@ -173,7 +174,8 @@ func TestAccountHandlerDeregisterAccount(t *testing.T) { require.NoError(t, err) // Prepare request - req, err := http.NewRequest(http.MethodDelete, path.Join("/accounts", address), nil) + var req *http.Request + req, err = http.NewRequest(http.MethodDelete, path.Join("/accounts", address), nil) require.NoError(t, err) // Serve request @@ -229,7 +231,7 @@ func TestAccountHandlerDeregisterAccount(t *testing.T) { } func TestAccountHandlerSponsorAccountCreation(t *testing.T) { - asService := servicesmocks.AccountSponsorshipServiceMock{} + asService := services.AccountSponsorshipServiceMock{} defer asService.AssertExpectations(t) assets := []entities.Asset{ @@ -552,7 +554,7 @@ func TestAccountHandlerSponsorAccountCreation(t *testing.T) { } func TestAccountHandlerCreateFeeBumpTransaction(t *testing.T) { - asService := servicesmocks.AccountSponsorshipServiceMock{} + asService := services.AccountSponsorshipServiceMock{} defer asService.AssertExpectations(t) handler := &AccountHandler{ @@ -837,7 +839,7 @@ func TestAccountHandlerCreateFeeBumpTransaction(t *testing.T) { asService. On("WrapTransaction", req.Context(), tx). - Return("", "", &services.ErrOperationNotAllowed{OperationType: xdr.OperationTypeLiquidityPoolDeposit}). + Return("", "", &services.OperationNotAllowedError{OperationType: xdr.OperationTypeLiquidityPoolDeposit}). Once() http.HandlerFunc(handler.CreateFeeBumpTransaction).ServeHTTP(rw, req) diff --git a/internal/serve/httphandler/payment_handler.go b/internal/serve/httphandler/payment_handler.go index 51fdbb3d0..49eff95c8 100644 --- a/internal/serve/httphandler/payment_handler.go +++ b/internal/serve/httphandler/payment_handler.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/wallet-backend/internal/apptracker" "github.com/stellar/wallet-backend/internal/data" "github.com/stellar/wallet-backend/internal/entities" diff --git a/internal/serve/httphandler/payment_handler_test.go b/internal/serve/httphandler/payment_handler_test.go index 0cf436c18..2a8c96c10 100644 --- a/internal/serve/httphandler/payment_handler_test.go +++ b/internal/serve/httphandler/payment_handler_test.go @@ -10,15 +10,16 @@ import ( "github.com/go-chi/chi" "github.com/stellar/go/xdr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stellar/wallet-backend/internal/data" "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stellar/wallet-backend/internal/metrics" "github.com/stellar/wallet-backend/internal/services" "github.com/stellar/wallet-backend/internal/utils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" ) func TestPaymentHandlerGetPayments(t *testing.T) { diff --git a/internal/serve/httphandler/request_params_validator.go b/internal/serve/httphandler/request_params_validator.go index 789c01338..8ec0e95c1 100644 --- a/internal/serve/httphandler/request_params_validator.go +++ b/internal/serve/httphandler/request_params_validator.go @@ -2,10 +2,12 @@ package httphandler import ( "context" + "errors" "net/http" "github.com/go-playground/validator/v10" "github.com/stellar/go/support/http/httpdecode" + "github.com/stellar/wallet-backend/internal/apptracker" "github.com/stellar/wallet-backend/internal/serve/httperror" "github.com/stellar/wallet-backend/internal/validators" @@ -39,9 +41,14 @@ func DecodePathAndValidate(ctx context.Context, req *http.Request, reqPath inter } func ValidateRequestParams(ctx context.Context, reqParams interface{}, appTracker apptracker.AppTracker) *httperror.ErrorResponse { - val := validators.NewValidator() + val, err := validators.NewValidator() + if err != nil { + return httperror.InternalServerError(ctx, "Internal error while creating a new validator.", err, nil, appTracker) + } + if err := val.StructCtx(ctx, reqParams); err != nil { - if vErrs, ok := err.(validator.ValidationErrors); ok { + var vErrs validator.ValidationErrors + if ok := errors.As(err, &vErrs); ok { extras := validators.ParseValidationError(vErrs) return httperror.BadRequest("Validation error.", extras) } diff --git a/internal/serve/httphandler/tss_handler.go b/internal/serve/httphandler/tss_handler.go index db64d49de..80f2fde94 100644 --- a/internal/serve/httphandler/tss_handler.go +++ b/internal/serve/httphandler/tss_handler.go @@ -7,6 +7,7 @@ import ( "github.com/stellar/go/support/log" "github.com/stellar/go/support/render/httpjson" "github.com/stellar/go/txnbuild" + "github.com/stellar/wallet-backend/internal/apptracker" "github.com/stellar/wallet-backend/internal/metrics" "github.com/stellar/wallet-backend/internal/serve/httperror" @@ -14,7 +15,8 @@ import ( "github.com/stellar/wallet-backend/internal/tss/router" tssservices "github.com/stellar/wallet-backend/internal/tss/services" "github.com/stellar/wallet-backend/internal/tss/store" - "github.com/stellar/wallet-backend/internal/tss/utils" + tssUtils "github.com/stellar/wallet-backend/internal/tss/utils" + "github.com/stellar/wallet-backend/internal/utils" ) type TSSHandler struct { @@ -59,7 +61,7 @@ func (t *TSSHandler) BuildTransactions(w http.ResponseWriter, r *http.Request) { } var transactionXDRs []string for _, transaction := range reqParams.Transactions { - ops, err := utils.BuildOperations(transaction.Operations) + ops, err := tssUtils.BuildOperations(transaction.Operations) if err != nil { httperror.BadRequest("bad operation xdr", nil).Render(w) return @@ -157,8 +159,9 @@ func (t *TSSHandler) GetTransaction(w http.ResponseWriter, r *http.Request) { return } - if tx == (store.Transaction{}) { + if utils.IsEmpty(tx) { httperror.NotFound.Render(w) + return } tssTry, err := t.Store.GetLatestTry(ctx, tx.Hash) diff --git a/internal/serve/httphandler/tss_handler_test.go b/internal/serve/httphandler/tss_handler_test.go index 7f9371479..5f16942a1 100644 --- a/internal/serve/httphandler/tss_handler_test.go +++ b/internal/serve/httphandler/tss_handler_test.go @@ -17,6 +17,10 @@ import ( xdr3 "github.com/stellar/go-xdr/xdr3" "github.com/stellar/go/keypair" "github.com/stellar/go/txnbuild" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stellar/wallet-backend/internal/apptracker" "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/db/dbtest" @@ -26,9 +30,6 @@ import ( tssservices "github.com/stellar/wallet-backend/internal/tss/services" "github.com/stellar/wallet-backend/internal/tss/store" "github.com/stellar/wallet-backend/internal/tss/utils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" ) func TestBuildTransactions(t *testing.T) { @@ -39,7 +40,8 @@ func TestBuildTransactions(t *testing.T) { require.NoError(t, err) defer dbConnectionPool.Close() mockMetricsService := metrics.NewMockMetricsService() - store, _ := store.NewStore(dbConnectionPool, mockMetricsService) + store, err := store.NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) mockRouter := router.MockRouter{} mockAppTracker := apptracker.MockAppTracker{} mockTxService := tssservices.TransactionServiceMock{} @@ -59,11 +61,13 @@ func TestBuildTransactions(t *testing.T) { Asset: txnbuild.NativeAsset{}, SourceAccount: srcAccount, } - op, _ := p.BuildXDR() + op, err := p.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) @@ -77,9 +81,10 @@ func TestBuildTransactions(t *testing.T) { rw := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, endpoint, strings.NewReader(reqBody)) - expectedOps, _ := utils.BuildOperations([]string{opXDRBase64}) + expectedOps, err := utils.BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) - err := errors.New("unable to find channel account") + err = errors.New("unable to find channel account") mockTxService. On("BuildAndSignTransactionWithChannelAccount", context.Background(), expectedOps, int64(100)). Return(nil, err). @@ -107,8 +112,9 @@ func TestBuildTransactions(t *testing.T) { rw := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, endpoint, strings.NewReader(reqBody)) - expectedOps, _ := utils.BuildOperations([]string{opXDRBase64}) - tx := utils.BuildTestTransaction() + expectedOps, err := utils.BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) + tx := utils.BuildTestTransaction(t) mockTxService. On("BuildAndSignTransactionWithChannelAccount", context.Background(), expectedOps, int64(100)). @@ -123,11 +129,12 @@ func TestBuildTransactions(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) var buildTxResp BuildTransactionsResponse - _ = json.Unmarshal(respBody, &buildTxResp) - expectedTxXDR, _ := tx.Base64() + err = json.Unmarshal(respBody, &buildTxResp) + require.NoError(t, err) + expectedTxXDR, err := tx.Base64() + require.NoError(t, err) assert.Equal(t, expectedTxXDR, buildTxResp.TransactionXDRs[0]) }) - } func TestSubmitTransactions(t *testing.T) { @@ -140,7 +147,8 @@ func TestSubmitTransactions(t *testing.T) { sqlxDB, err := dbConnectionPool.SqlxDB(context.Background()) require.NoError(t, err) metricsService := metrics.NewMetricsService(sqlxDB) - store, _ := store.NewStore(dbConnectionPool, metricsService) + store, err := store.NewStore(dbConnectionPool, metricsService) + require.NoError(t, err) mockRouter := router.MockRouter{} mockAppTracker := apptracker.MockAppTracker{} txServiceMock := tssservices.TransactionServiceMock{} @@ -195,12 +203,12 @@ func TestSubmitTransactions(t *testing.T) { expectedRespBody = `{"error": "bad transaction xdr"}` assert.Equal(t, http.StatusBadRequest, resp.StatusCode) assert.JSONEq(t, expectedRespBody, string(respBody)) - }) t.Run("happy_path", func(t *testing.T) { - tx := utils.BuildTestTransaction() - txXDR, _ := tx.Base64() + tx := utils.BuildTestTransaction(t) + txXDR, err := tx.Base64() + require.NoError(t, err) reqBody := fmt.Sprintf(`{ "webhook": "localhost:8080", "transactions": [%q] @@ -227,7 +235,8 @@ func TestSubmitTransactions(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) var txSubmissionResp TransactionSubmissionResponse - _ = json.Unmarshal(respBody, &txSubmissionResp) + err = json.Unmarshal(respBody, &txSubmissionResp) + require.NoError(t, err) assert.Equal(t, 1, len(txSubmissionResp.TransactionHashes)) @@ -246,7 +255,8 @@ func TestGetTransaction(t *testing.T) { sqlxDB, err := dbConnectionPool.SqlxDB(context.Background()) require.NoError(t, err) metricsService := metrics.NewMetricsService(sqlxDB) - store, _ := store.NewStore(dbConnectionPool, metricsService) + store, err := store.NewStore(dbConnectionPool, metricsService) + require.NoError(t, err) mockRouter := router.MockRouter{} mockAppTracker := apptracker.MockAppTracker{} txServiceMock := tssservices.TransactionServiceMock{} @@ -274,25 +284,29 @@ func TestGetTransaction(t *testing.T) { t.Run("returns_empty_try", func(t *testing.T) { txHash := "hash" ctx := context.Background() - _ = store.UpsertTransaction(ctx, "localhost:8080/webhook", txHash, "xdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}) - req, err := http.NewRequest(http.MethodGet, path.Join(endpoint, txHash), nil) + err = store.UpsertTransaction(ctx, "localhost:8080/webhook", txHash, "xdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + require.NoError(t, err) + var req *http.Request + req, err = http.NewRequest(http.MethodGet, path.Join(endpoint, txHash), nil) require.NoError(t, err) // Serve request rw := httptest.NewRecorder() r.ServeHTTP(rw, req) resp := rw.Result() - respBody, err := io.ReadAll(resp.Body) + var respBody []byte + respBody, err = io.ReadAll(resp.Body) require.NoError(t, err) var tssResp tss.TSSResponse - _ = json.Unmarshal(respBody, &tssResp) + err = json.Unmarshal(respBody, &tssResp) + require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, txHash, tssResp.TransactionHash) assert.Equal(t, fmt.Sprint(tss.NoCode), tssResp.TransactionResultCode) assert.Equal(t, fmt.Sprint(tss.NewStatus), tssResp.Status) - assert.Equal(t, "", tssResp.ResultXDR) - assert.Equal(t, "", tssResp.EnvelopeXDR) + assert.Empty(t, tssResp.ResultXDR) + assert.Empty(t, tssResp.EnvelopeXDR) clearTransactions(ctx) }) @@ -301,8 +315,10 @@ func TestGetTransaction(t *testing.T) { txHash := "hash" resultXdr := "resultXdr" ctx := context.Background() - _ = store.UpsertTransaction(ctx, "localhost:8080/webhook", txHash, "xdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}) - _ = store.UpsertTry(ctx, txHash, "feebumphash", "feebumpxdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}, tss.RPCTXCode{OtherCodes: tss.NewCode}, resultXdr) + err = store.UpsertTransaction(ctx, "localhost:8080/webhook", txHash, "xdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + require.NoError(t, err) + err = store.UpsertTry(ctx, txHash, "feebumphash", "feebumpxdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}, tss.RPCTXCode{OtherCodes: tss.NewCode}, resultXdr) + require.NoError(t, err) req, err := http.NewRequest(http.MethodGet, path.Join(endpoint, txHash), nil) require.NoError(t, err) @@ -313,7 +329,8 @@ func TestGetTransaction(t *testing.T) { respBody, err := io.ReadAll(resp.Body) require.NoError(t, err) var tssResp tss.TSSResponse - _ = json.Unmarshal(respBody, &tssResp) + err = json.Unmarshal(respBody, &tssResp) + require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, txHash, tssResp.TransactionHash) @@ -336,13 +353,12 @@ func TestGetTransaction(t *testing.T) { respBody, err := io.ReadAll(resp.Body) require.NoError(t, err) var tssResp tss.TSSResponse - _ = json.Unmarshal(respBody, &tssResp) + err = json.Unmarshal(respBody, &tssResp) + require.NoError(t, err) assert.Equal(t, http.StatusNotFound, resp.StatusCode) assert.Empty(t, tssResp.TransactionHash) assert.Empty(t, tssResp.EnvelopeXDR) assert.Empty(t, tssResp.Status) - }) - } diff --git a/internal/serve/middleware/metrics_middleware.go b/internal/serve/middleware/metrics_middleware.go index bdc27c2d4..f9f8584df 100644 --- a/internal/serve/middleware/metrics_middleware.go +++ b/internal/serve/middleware/metrics_middleware.go @@ -43,6 +43,7 @@ func (rw *responseWriter) WriteHeader(code int) { rw.ResponseWriter.WriteHeader(code) } +//nolint:wrapcheck // This is a thin wrapper around the ResponseWriter func (rw *responseWriter) Write(b []byte) (int, error) { // If WriteHeader hasn't been called yet, we assume it's a 200 if rw.statusCode == 0 { diff --git a/internal/serve/middleware/middleware.go b/internal/serve/middleware/middleware.go index 0e207be13..9b1c43f11 100644 --- a/internal/serve/middleware/middleware.go +++ b/internal/serve/middleware/middleware.go @@ -8,6 +8,7 @@ import ( "net/http" "github.com/stellar/go/support/log" + "github.com/stellar/wallet-backend/internal/apptracker" "github.com/stellar/wallet-backend/internal/serve/auth" "github.com/stellar/wallet-backend/internal/serve/httperror" diff --git a/internal/serve/middleware/middleware_test.go b/internal/serve/middleware/middleware_test.go index b54f49017..d4479c081 100644 --- a/internal/serve/middleware/middleware_test.go +++ b/internal/serve/middleware/middleware_test.go @@ -11,11 +11,12 @@ import ( "github.com/go-chi/chi" "github.com/stellar/go/support/log" - "github.com/stellar/wallet-backend/internal/apptracker" - "github.com/stellar/wallet-backend/internal/serve/auth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + + "github.com/stellar/wallet-backend/internal/apptracker" + "github.com/stellar/wallet-backend/internal/serve/auth" ) func TestSignatureMiddleware(t *testing.T) { @@ -160,10 +161,10 @@ func TestRecoverHandler(t *testing.T) { // assert response assert.Equal(t, http.StatusInternalServerError, rr.Code) - wantJson := `{ + wantJSON := `{ "error": "An error occurred while processing this request." }` - assert.JSONEq(t, wantJson, rr.Body.String()) + assert.JSONEq(t, wantJSON, rr.Body.String()) entries := getEntries() require.Len(t, entries, 2) diff --git a/internal/serve/serve.go b/internal/serve/serve.go index fa9cb5e78..aceb087f4 100644 --- a/internal/serve/serve.go +++ b/internal/serve/serve.go @@ -199,7 +199,6 @@ func initHandlerDeps(ctx context.Context, cfg Configs) (handlerDeps, error) { RPCService: rpcService, BaseFee: int64(cfg.BaseFee), }) - if err != nil { return handlerDeps{}, fmt.Errorf("instantiating tss transaction service: %w", err) } diff --git a/internal/services/account_service.go b/internal/services/account_service.go index 314eb93ff..d1c8d9ad6 100644 --- a/internal/services/account_service.go +++ b/internal/services/account_service.go @@ -2,9 +2,8 @@ package services import ( "context" - "fmt" - "errors" + "fmt" "github.com/stellar/wallet-backend/internal/data" "github.com/stellar/wallet-backend/internal/metrics" diff --git a/internal/services/account_service_test.go b/internal/services/account_service_test.go index 3b1409004..f31a27b66 100644 --- a/internal/services/account_service_test.go +++ b/internal/services/account_service_test.go @@ -6,13 +6,14 @@ import ( "testing" "github.com/stellar/go/keypair" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stellar/wallet-backend/internal/data" "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stellar/wallet-backend/internal/metrics" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" ) func TestAccountRegister(t *testing.T) { diff --git a/internal/services/account_sponsorship_service.go b/internal/services/account_sponsorship_service.go index aebfa1129..288d8c4dc 100644 --- a/internal/services/account_sponsorship_service.go +++ b/internal/services/account_sponsorship_service.go @@ -24,11 +24,11 @@ var ( ErrAccountNotFound = errors.New("account not found") ) -type ErrOperationNotAllowed struct { +type OperationNotAllowedError struct { OperationType xdr.OperationType } -func (e ErrOperationNotAllowed) Error() string { +func (e OperationNotAllowedError) Error() string { return fmt.Sprintf("operation %s not allowed", e.OperationType.String()) } @@ -67,7 +67,7 @@ func (s *accountSponsorshipService) SponsorAccountCreationTransaction(ctx contex fullSignerWeight, err := entities.ValidateSignersWeights(signers) if err != nil { - return "", "", err + return "", "", fmt.Errorf("validating signers weights: %w", err) } // Make sure the total number of entries does not exceed the numSponsoredThreshold @@ -180,14 +180,14 @@ func (s *accountSponsorshipService) WrapTransaction(ctx context.Context, tx *txn } for _, op := range tx.Operations() { - operationXDR, err := op.BuildXDR() - if err != nil { - return "", "", fmt.Errorf("retrieving xdr for operation: %w", err) + operationXDR, innerErr := op.BuildXDR() + if innerErr != nil { + return "", "", fmt.Errorf("retrieving xdr for operation: %w", innerErr) } if slices.Contains(s.BlockedOperationsTypes, operationXDR.Body.Type) { log.Ctx(ctx).Warnf("blocked operation type: %s", operationXDR.Body.Type.String()) - return "", "", &ErrOperationNotAllowed{OperationType: operationXDR.Body.Type} + return "", "", &OperationNotAllowedError{OperationType: operationXDR.Body.Type} } } @@ -265,7 +265,7 @@ func (o *AccountSponsorshipServiceOptions) Validate() error { func NewAccountSponsorshipService(opts AccountSponsorshipServiceOptions) (*accountSponsorshipService, error) { if err := opts.Validate(); err != nil { - return nil, err + return nil, fmt.Errorf("validating account sponsorship service options: %w", err) } return &accountSponsorshipService{ diff --git a/internal/services/account_sponsorship_service_test.go b/internal/services/account_sponsorship_service_test.go index d118b47c8..80c5b8409 100644 --- a/internal/services/account_sponsorship_service_test.go +++ b/internal/services/account_sponsorship_service_test.go @@ -83,7 +83,7 @@ func TestAccountSponsorshipServiceSponsorAccountCreationTransaction(t *testing.T } txe, networkPassphrase, err := s.SponsorAccountCreationTransaction(ctx, accountToSponsor, signers, []entities.Asset{}) - assert.EqualError(t, err, "no full signers provided") + assert.EqualError(t, err, "validating signers weights: no full signers provided") assert.Empty(t, txe) assert.Empty(t, networkPassphrase) }) @@ -406,8 +406,8 @@ func TestAccountSponsorshipServiceWrapTransaction(t *testing.T) { require.NoError(t, err) feeBumpTxe, networkPassphrase, err := s.WrapTransaction(ctx, tx) - var errOperationNotAllowed *ErrOperationNotAllowed - assert.ErrorAs(t, err, &errOperationNotAllowed) + var opNotAllowedErr *OperationNotAllowedError + assert.ErrorAs(t, err, &opNotAllowedErr) assert.Empty(t, feeBumpTxe) assert.Empty(t, networkPassphrase) }) diff --git a/internal/services/channel_account_service.go b/internal/services/channel_account_service.go index bc32b7ac4..165dd26cb 100644 --- a/internal/services/channel_account_service.go +++ b/internal/services/channel_account_service.go @@ -60,15 +60,15 @@ func (s *channelAccountService) EnsureChannelAccounts(ctx context.Context, numbe ops := make([]txnbuild.Operation, 0, numOfChannelAccountsToCreate) channelAccountsToInsert := []*store.ChannelAccount{} for range numOfChannelAccountsToCreate { - kp, err := keypair.Random() - if err != nil { - return fmt.Errorf("generating random keypair for channel account: %w", err) + kp, innerErr := keypair.Random() + if innerErr != nil { + return fmt.Errorf("generating random keypair for channel account: %w", innerErr) } log.Ctx(ctx).Infof("⏳ Creating Stellar channel account with address: %s", kp.Address()) - encryptedPrivateKey, err := s.PrivateKeyEncrypter.Encrypt(ctx, kp.Seed(), s.EncryptionPassphrase) - if err != nil { - return fmt.Errorf("encrypting channel account private key: %w", err) + encryptedPrivateKey, innerErr := s.PrivateKeyEncrypter.Encrypt(ctx, kp.Seed(), s.EncryptionPassphrase) + if innerErr != nil { + return fmt.Errorf("encrypting channel account private key: %w", innerErr) } ops = append(ops, &txnbuild.CreateAccount{ @@ -245,8 +245,9 @@ func (o *ChannelAccountServiceOptions) Validate() error { } func NewChannelAccountService(ctx context.Context, opts ChannelAccountServiceOptions) (*channelAccountService, error) { - if err := opts.Validate(); err != nil { - return nil, err + err := opts.Validate() + if err != nil { + return nil, fmt.Errorf("validating channel account service options: %w", err) } go opts.RPCService.TrackRPCServiceHealth(ctx) diff --git a/internal/services/channel_account_service_test.go b/internal/services/channel_account_service_test.go index 061a7b70d..90670a39a 100644 --- a/internal/services/channel_account_service_test.go +++ b/internal/services/channel_account_service_test.go @@ -24,8 +24,8 @@ func TestChannelAccountServiceEnsureChannelAccounts(t *testing.T) { dbt := dbtest.Open(t) defer dbt.Close() - dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) - require.NoError(t, err) + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) defer dbConnectionPool.Close() ctx := context.Background() @@ -36,7 +36,7 @@ func TestChannelAccountServiceEnsureChannelAccounts(t *testing.T) { channelAccountStore := store.ChannelAccountStoreMock{} privateKeyEncrypter := signingutils.DefaultPrivateKeyEncrypter{} passphrase := "test" - s, err := NewChannelAccountService(ctx, ChannelAccountServiceOptions{ + s, outerErr := NewChannelAccountService(ctx, ChannelAccountServiceOptions{ DB: dbConnectionPool, RPCService: &mockRPCService, BaseFee: 100 * txnbuild.MinBaseFee, @@ -45,7 +45,8 @@ func TestChannelAccountServiceEnsureChannelAccounts(t *testing.T) { PrivateKeyEncrypter: &privateKeyEncrypter, EncryptionPassphrase: passphrase, }) - require.NoError(t, err) + require.NoError(t, outerErr) + time.Sleep(100 * time.Millisecond) // waiting for the goroutine to call `TrackRPCServiceHealth` t.Run("sufficient_number_of_channel_accounts", func(t *testing.T) { channelAccountStore. @@ -89,7 +90,7 @@ func TestChannelAccountServiceEnsureChannelAccounts(t *testing.T) { channelAccountsAddressesBeingInserted = append(channelAccountsAddressesBeingInserted, caOp.Destination) } - tx, err = tx.Sign(network.TestNetworkPassphrase, distributionAccount) + tx, err := tx.Sign(network.TestNetworkPassphrase, distributionAccount) require.NoError(t, err) signedTx = *tx @@ -106,7 +107,8 @@ func TestChannelAccountServiceEnsureChannelAccounts(t *testing.T) { Return(entities.RPCGetHealthResult{Status: "healthy"}, nil) // Create and set up the heartbeat channel - health, _ := mockRPCService.GetHealth() + health, err := mockRPCService.GetHealth() + require.NoError(t, err) heartbeatChan <- health mockRPCService.On("GetHeartbeatChannel").Return(heartbeatChan) @@ -165,7 +167,7 @@ func TestChannelAccountServiceEnsureChannelAccounts(t *testing.T) { tx, ok := args.Get(1).(*txnbuild.Transaction) require.True(t, ok) - tx, err = tx.Sign(network.TestNetworkPassphrase, distributionAccount) + tx, err := tx.Sign(network.TestNetworkPassphrase, distributionAccount) require.NoError(t, err) signedTx = *tx @@ -182,7 +184,8 @@ func TestChannelAccountServiceEnsureChannelAccounts(t *testing.T) { Return(entities.RPCGetHealthResult{Status: "healthy"}, nil) // Create and set up the heartbeat channel - health, _ := mockRPCService.GetHealth() + health, err := mockRPCService.GetHealth() + require.NoError(t, err) heartbeatChan <- health mockRPCService.On("GetHeartbeatChannel").Return(heartbeatChan) @@ -223,7 +226,7 @@ func TestChannelAccountServiceEnsureChannelAccounts(t *testing.T) { tx, ok := args.Get(1).(*txnbuild.Transaction) require.True(t, ok) - tx, err = tx.Sign(network.TestNetworkPassphrase, distributionAccount) + tx, err := tx.Sign(network.TestNetworkPassphrase, distributionAccount) require.NoError(t, err) signedTx = *tx @@ -258,7 +261,7 @@ func TestChannelAccountServiceEnsureChannelAccounts(t *testing.T) { Once() defer mockRPCService.AssertExpectations(t) - err = s.EnsureChannelAccounts(ctx, 5) + err := s.EnsureChannelAccounts(ctx, 5) require.Error(t, err) assert.Contains(t, err.Error(), "failed with status FAILED and errorResultXdr error_xdr") }) @@ -316,6 +319,7 @@ func TestSubmitTransaction(t *testing.T) { EncryptionPassphrase: passphrase, }) require.NoError(t, err) + time.Sleep(100 * time.Millisecond) // waiting for the goroutine to call `TrackRPCServiceHealth` hash := "test_hash" signedTxXDR := "test_xdr" @@ -371,6 +375,7 @@ func TestWaitForTransactionConfirmation(t *testing.T) { EncryptionPassphrase: passphrase, }) require.NoError(t, err) + time.Sleep(100 * time.Millisecond) // waiting for the goroutine to call `TrackRPCServiceHealth` hash := "test_hash" diff --git a/internal/services/ingest.go b/internal/services/ingest.go index 6a955ef20..e4131777e 100644 --- a/internal/services/ingest.go +++ b/internal/services/ingest.go @@ -157,7 +157,6 @@ func (m *ingestService) Run(ctx context.Context, startLedger uint32, endLedger u } func (m *ingestService) GetLedgerTransactions(ledger int64) ([]entities.Transaction, error) { - var ledgerTransactions []entities.Transaction var cursor string lastLedgerSeen := ledger @@ -181,7 +180,7 @@ func (m *ingestService) GetLedgerTransactions(ledger int64) ([]entities.Transact } func (m *ingestService) ingestPayments(ctx context.Context, ledgerTransactions []entities.Transaction) error { - return db.RunInTransaction(ctx, m.models.Payments.DB, nil, func(dbTx db.Transaction) error { + err := db.RunInTransaction(ctx, m.models.Payments.DB, nil, func(dbTx db.Transaction) error { paymentOpsIngested := 0 pathPaymentStrictSendOpsIngested := 0 pathPaymentStrictReceiveOpsIngested := 0 @@ -244,6 +243,11 @@ func (m *ingestService) ingestPayments(ctx context.Context, ledgerTransactions [ m.metricsService.SetNumPaymentOpsIngestedPerLedger(pathPaymentStrictReceivePrometheusLabel, pathPaymentStrictReceiveOpsIngested) return nil }) + if err != nil { + return fmt.Errorf("ingesting payments: %w", err) + } + + return nil } func (m *ingestService) processTSSTransactions(ctx context.Context, ledgerTransactions []entities.Transaction) error { @@ -294,7 +298,7 @@ func (m *ingestService) processTSSTransactions(ctx context.Context, ledgerTransa CreatedAt: int64(tx.CreatedAt), } payload := tss.Payload{ - RpcGetIngestTxResponse: tssGetIngestResponse, + RPCGetIngestTxResponse: tssGetIngestResponse, } err = m.tssRouter.Route(payload) if err != nil { diff --git a/internal/services/ingest_test.go b/internal/services/ingest_test.go index 901b1bafb..b3c3aecc9 100644 --- a/internal/services/ingest_test.go +++ b/internal/services/ingest_test.go @@ -41,8 +41,10 @@ func TestGetLedgerTransactions(t *testing.T) { mockAppTracker := apptracker.MockAppTracker{} mockRPCService := RPCServiceMock{} mockRouter := tssrouter.MockRouter{} - tssStore, _ := tssstore.NewStore(dbConnectionPool, mockMetricsService) - ingestService, _ := NewIngestService(models, "ingestionLedger", &mockAppTracker, &mockRPCService, &mockRouter, tssStore, mockMetricsService) + tssStore, err := tssstore.NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) + ingestService, err := NewIngestService(models, "ingestionLedger", &mockAppTracker, &mockRPCService, &mockRouter, tssStore, mockMetricsService) + require.NoError(t, err) t.Run("all_ledger_transactions_in_single_gettransactions_call", func(t *testing.T) { defer mockMetricsService.AssertExpectations(t) @@ -123,7 +125,6 @@ func TestGetLedgerTransactions(t *testing.T) { assert.Equal(t, txns[2].Hash, "hash3") assert.NoError(t, err) }) - } func TestProcessTSSTransactions(t *testing.T) { @@ -140,8 +141,10 @@ func TestProcessTSSTransactions(t *testing.T) { mockAppTracker := apptracker.MockAppTracker{} mockRPCService := RPCServiceMock{} mockRouter := tssrouter.MockRouter{} - tssStore, _ := tssstore.NewStore(dbConnectionPool, mockMetricsService) - ingestService, _ := NewIngestService(models, "ingestionLedger", &mockAppTracker, &mockRPCService, &mockRouter, tssStore, mockMetricsService) + tssStore, err := tssstore.NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) + ingestService, err := NewIngestService(models, "ingestionLedger", &mockAppTracker, &mockRPCService, &mockRouter, tssStore, mockMetricsService) + require.NoError(t, err) t.Run("routes_to_tss_router", func(t *testing.T) { mockMetricsService.On("ObserveDBQueryDuration", "INSERT", "tss_transactions", mock.AnythingOfType("float64")).Times(2) @@ -172,8 +175,10 @@ func TestProcessTSSTransactions(t *testing.T) { }, } - _ = tssStore.UpsertTransaction(context.Background(), "localhost:8000/webhook", "hash", "xdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}) - _ = tssStore.UpsertTry(context.Background(), "hash", "feebumphash", "feebumpxdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}, tss.RPCTXCode{OtherCodes: tss.NewCode}, "") + err = tssStore.UpsertTransaction(context.Background(), "localhost:8000/webhook", "hash", "xdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + require.NoError(t, err) + err = tssStore.UpsertTry(context.Background(), "hash", "feebumphash", "feebumpxdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}, tss.RPCTXCode{OtherCodes: tss.NewCode}, "") + require.NoError(t, err) mockRouter. On("Route", mock.AnythingOfType("tss.Payload")). @@ -183,9 +188,11 @@ func TestProcessTSSTransactions(t *testing.T) { err := ingestService.processTSSTransactions(context.Background(), transactions) assert.NoError(t, err) - updatedTX, _ := tssStore.GetTransaction(context.Background(), "hash") + updatedTX, err := tssStore.GetTransaction(context.Background(), "hash") + require.NoError(t, err) assert.Equal(t, string(entities.SuccessStatus), updatedTX.Status) - updatedTry, _ := tssStore.GetTry(context.Background(), "feebumphash") + updatedTry, err := tssStore.GetTry(context.Background(), "feebumphash") + require.NoError(t, err) assert.Equal(t, "AAAAAAAAAMj////9AAAAAA==", updatedTry.ResultXDR) assert.Equal(t, int32(xdr.TransactionResultCodeTxTooLate), updatedTry.Code) }) @@ -205,8 +212,10 @@ func TestIngestPayments(t *testing.T) { mockAppTracker := apptracker.MockAppTracker{} mockRPCService := RPCServiceMock{} mockRouter := tssrouter.MockRouter{} - tssStore, _ := tssstore.NewStore(dbConnectionPool, mockMetricsService) - ingestService, _ := NewIngestService(models, "ingestionLedger", &mockAppTracker, &mockRPCService, &mockRouter, tssStore, mockMetricsService) + tssStore, err := tssstore.NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) + ingestService, err := NewIngestService(models, "ingestionLedger", &mockAppTracker, &mockRPCService, &mockRouter, tssStore, mockMetricsService) + require.NoError(t, err) srcAccount := keypair.MustRandom().Address() destAccount := keypair.MustRandom().Address() usdIssuer := keypair.MustRandom().Address() @@ -224,22 +233,26 @@ func TestIngestPayments(t *testing.T) { mockMetricsService.On("SetNumPaymentOpsIngestedPerLedger", "path_payment_strict_receive", 0).Once() defer mockMetricsService.AssertExpectations(t) - _ = models.Account.Insert(context.Background(), srcAccount) + err = models.Account.Insert(context.Background(), srcAccount) + require.NoError(t, err) paymentOp := txnbuild.Payment{ SourceAccount: srcAccount, Destination: destAccount, Amount: "10", Asset: txnbuild.NativeAsset{}, } - transaction, _ := txnbuild.NewTransaction(txnbuild.TransactionParams{ + var transaction *txnbuild.Transaction + transaction, err = txnbuild.NewTransaction(txnbuild.TransactionParams{ SourceAccount: &txnbuild.SimpleAccount{ AccountID: keypair.MustRandom().Address(), }, Operations: []txnbuild.Operation{&paymentOp}, Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(10)}, }) - - txEnvXDR, _ := transaction.Base64() + require.NoError(t, err) + var txEnvXDR string + txEnvXDR, err = transaction.Base64() + require.NoError(t, err) ledgerTransaction := entities.Transaction{ Status: entities.SuccessStatus, @@ -253,10 +266,11 @@ func TestIngestPayments(t *testing.T) { ledgerTransactions := []entities.Transaction{ledgerTransaction} - err := ingestService.ingestPayments(context.Background(), ledgerTransactions) - assert.NoError(t, err) + err = ingestService.ingestPayments(context.Background(), ledgerTransactions) + require.NoError(t, err) - payments, _, _, err := models.Payments.GetPaymentsPaginated(context.Background(), srcAccount, "", "", data.ASC, 1) + var payments []data.Payment + payments, _, _, err = models.Payments.GetPaymentsPaginated(context.Background(), srcAccount, "", "", data.ASC, 1) assert.NoError(t, err) assert.Equal(t, payments[0].TransactionHash, "abcd") }) @@ -273,7 +287,8 @@ func TestIngestPayments(t *testing.T) { mockMetricsService.On("SetNumPaymentOpsIngestedPerLedger", "path_payment_strict_receive", 0).Once() defer mockMetricsService.AssertExpectations(t) - _ = models.Account.Insert(context.Background(), srcAccount) + err = models.Account.Insert(context.Background(), srcAccount) + require.NoError(t, err) path := []txnbuild.Asset{ txnbuild.CreditAsset{Code: "USD", Issuer: usdIssuer}, @@ -289,20 +304,26 @@ func TestIngestPayments(t *testing.T) { DestAsset: txnbuild.NativeAsset{}, Path: path, } - transaction, _ := txnbuild.NewTransaction(txnbuild.TransactionParams{ + var transaction *txnbuild.Transaction + transaction, err = txnbuild.NewTransaction(txnbuild.TransactionParams{ SourceAccount: &txnbuild.SimpleAccount{ AccountID: keypair.MustRandom().Address(), }, Operations: []txnbuild.Operation{&pathPaymentOp}, Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(10)}, }) + require.NoError(t, err) signer := keypair.MustRandom() - _ = models.Account.Insert(context.Background(), signer.Address()) - - signedTx, _ := transaction.Sign(network.TestNetworkPassphrase, signer) + err = models.Account.Insert(context.Background(), signer.Address()) + require.NoError(t, err) + var signedTx *txnbuild.Transaction + signedTx, err = transaction.Sign(network.TestNetworkPassphrase, signer) + require.NoError(t, err) - txEnvXDR, _ := signedTx.Base64() + var txEnvXDR string + txEnvXDR, err = signedTx.Base64() + require.NoError(t, err) ledgerTransaction := entities.Transaction{ Status: entities.SuccessStatus, @@ -319,7 +340,8 @@ func TestIngestPayments(t *testing.T) { err = ingestService.ingestPayments(context.Background(), ledgerTransactions) require.NoError(t, err) - payments, _, _, err := models.Payments.GetPaymentsPaginated(context.Background(), srcAccount, "", "", data.ASC, 1) + var payments []data.Payment + payments, _, _, err = models.Payments.GetPaymentsPaginated(context.Background(), srcAccount, "", "", data.ASC, 1) require.NoError(t, err) require.NotEmpty(t, payments, "Expected at least one payment") assert.Equal(t, payments[0].TransactionHash, ledgerTransaction.Hash) @@ -343,7 +365,8 @@ func TestIngestPayments(t *testing.T) { mockMetricsService.On("SetNumPaymentOpsIngestedPerLedger", "path_payment_strict_receive", 1).Once() defer mockMetricsService.AssertExpectations(t) - _ = models.Account.Insert(context.Background(), srcAccount) + err = models.Account.Insert(context.Background(), srcAccount) + require.NoError(t, err) path := []txnbuild.Asset{ txnbuild.CreditAsset{Code: "USD", Issuer: usdIssuer}, @@ -359,20 +382,24 @@ func TestIngestPayments(t *testing.T) { DestAsset: txnbuild.NativeAsset{}, Path: path, } - transaction, _ := txnbuild.NewTransaction(txnbuild.TransactionParams{ + transaction, err := txnbuild.NewTransaction(txnbuild.TransactionParams{ SourceAccount: &txnbuild.SimpleAccount{ AccountID: keypair.MustRandom().Address(), }, Operations: []txnbuild.Operation{&pathPaymentOp}, Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(10)}, }) + require.NoError(t, err) signer := keypair.MustRandom() - _ = models.Account.Insert(context.Background(), signer.Address()) + err = models.Account.Insert(context.Background(), signer.Address()) + require.NoError(t, err) - signedTx, _ := transaction.Sign(network.TestNetworkPassphrase, signer) + signedTx, err := transaction.Sign(network.TestNetworkPassphrase, signer) + require.NoError(t, err) - txEnvXDR, _ := signedTx.Base64() + txEnvXDR, err := signedTx.Base64() + require.NoError(t, err) ledgerTransaction := entities.Transaction{ Status: entities.SuccessStatus, @@ -409,7 +436,7 @@ func TestIngest_LatestSyncedLedgerBehindRPC(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() - _ = dbConnectionPool.Close() + require.NoError(t, dbConnectionPool.Close()) dbt.Close() }() @@ -450,14 +477,16 @@ func TestIngest_LatestSyncedLedgerBehindRPC(t *testing.T) { Amount: "10", Asset: txnbuild.NativeAsset{}, } - transaction, _ := txnbuild.NewTransaction(txnbuild.TransactionParams{ + transaction, err := txnbuild.NewTransaction(txnbuild.TransactionParams{ SourceAccount: &txnbuild.SimpleAccount{ AccountID: keypair.MustRandom().Address(), }, Operations: []txnbuild.Operation{&paymentOp}, Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(10)}, }) - txEnvXDR, _ := transaction.Base64() + require.NoError(t, err) + txEnvXDR, err := transaction.Base64() + require.NoError(t, err) mockResult := entities.RPCGetTransactionsResult{ Transactions: []entities.Transaction{{ Status: entities.SuccessStatus, @@ -509,7 +538,7 @@ func TestIngest_LatestSyncedLedgerAheadOfRPC(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() - _ = dbConnectionPool.Close() + require.NoError(t, dbConnectionPool.Close()) dbt.Close() log.DefaultLogger.SetOutput(os.Stderr) }() diff --git a/internal/services/kms_import_key_service.go b/internal/services/kms_import_key_service.go index 583958212..49a4a9aae 100644 --- a/internal/services/kms_import_key_service.go +++ b/internal/services/kms_import_key_service.go @@ -11,6 +11,7 @@ import ( "github.com/aws/aws-sdk-go/service/kms/kmsiface" "github.com/stellar/go/keypair" "github.com/stellar/go/strkey" + "github.com/stellar/wallet-backend/internal/signing/awskms" "github.com/stellar/wallet-backend/internal/signing/store" ) @@ -40,7 +41,7 @@ func (s *kmsImportService) ImportDistributionAccountKey(ctx context.Context, dis kp, err := keypair.ParseFull(distributionAccountSeed) if err != nil { - return fmt.Errorf("parsing distribution private key: %s", err) + return fmt.Errorf("parsing distribution account private key: %w", err) } if kp.Address() != s.distributionAccountPublicKey { @@ -59,9 +60,6 @@ func (s *kmsImportService) ImportDistributionAccountKey(ctx context.Context, dis err = s.keypairStore.Insert(ctx, kp.Address(), output.CiphertextBlob) if err != nil { - if errors.Is(err, store.ErrPublicKeyAlreadyExists) { - return err - } return fmt.Errorf("storing distribution account encrypted private key: %w", err) } diff --git a/internal/services/kms_import_key_service_test.go b/internal/services/kms_import_key_service_test.go index fcf11540c..227417c19 100644 --- a/internal/services/kms_import_key_service_test.go +++ b/internal/services/kms_import_key_service_test.go @@ -7,13 +7,14 @@ import ( "github.com/aws/aws-sdk-go/service/kms" "github.com/stellar/go/keypair" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stellar/wallet-backend/internal/signing/awskms" "github.com/stellar/wallet-backend/internal/signing/store" "github.com/stellar/wallet-backend/internal/utils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestKMSImportServiceImportDistributionAccountKey(t *testing.T) { diff --git a/internal/services/mocks.go b/internal/services/mocks.go index e17bdb551..54e560a5c 100644 --- a/internal/services/mocks.go +++ b/internal/services/mocks.go @@ -5,6 +5,8 @@ import ( "github.com/stretchr/testify/mock" + "github.com/stellar/go/txnbuild" + "github.com/stellar/wallet-backend/internal/entities" ) @@ -52,3 +54,19 @@ func (r *RPCServiceMock) GetAccountLedgerSequence(address string) (int64, error) args := r.Called(address) return args.Get(0).(int64), args.Error(1) } + +type AccountSponsorshipServiceMock struct { + mock.Mock +} + +var _ AccountSponsorshipService = (*AccountSponsorshipServiceMock)(nil) + +func (s *AccountSponsorshipServiceMock) SponsorAccountCreationTransaction(ctx context.Context, accountToSponsor string, signers []entities.Signer, assets []entities.Asset) (string, string, error) { + args := s.Called(ctx, accountToSponsor, signers, assets) + return args.String(0), args.String(1), args.Error(2) +} + +func (s *AccountSponsorshipServiceMock) WrapTransaction(ctx context.Context, tx *txnbuild.Transaction) (string, string, error) { + args := s.Called(ctx, tx) + return args.String(0), args.String(1), args.Error(2) +} diff --git a/internal/services/payment_service_test.go b/internal/services/payment_service_test.go index 9e30c10e3..4d44d7733 100644 --- a/internal/services/payment_service_test.go +++ b/internal/services/payment_service_test.go @@ -7,14 +7,15 @@ import ( "time" "github.com/stellar/go/xdr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stellar/wallet-backend/internal/data" "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stellar/wallet-backend/internal/entities" "github.com/stellar/wallet-backend/internal/metrics" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" ) func TestPaymentServiceGetPaymentsPaginated(t *testing.T) { diff --git a/internal/services/rpc_service.go b/internal/services/rpc_service.go index 77243b882..9183a3730 100644 --- a/internal/services/rpc_service.go +++ b/internal/services/rpc_service.go @@ -10,15 +10,16 @@ import ( "time" "github.com/stellar/go/support/log" + "github.com/stellar/wallet-backend/internal/entities" "github.com/stellar/wallet-backend/internal/metrics" "github.com/stellar/wallet-backend/internal/utils" ) const ( - rpcHealthCheckSleepTime = 5 * time.Second - rpcHealthCheckMaxWaitTime = 60 * time.Second - getHealthMethodName = "getHealth" + defaultHealthCheckTickInterval = 5 * time.Second + defaultHealthCheckWarningInterval = 60 * time.Second + getHealthMethodName = "getHealth" ) type RPCService interface { @@ -33,10 +34,12 @@ type RPCService interface { } type rpcService struct { - rpcURL string - httpClient utils.HTTPClient - heartbeatChannel chan entities.RPCGetHealthResult - metricsService metrics.MetricsService + rpcURL string + httpClient utils.HTTPClient + heartbeatChannel chan entities.RPCGetHealthResult + metricsService metrics.MetricsService + healthCheckWarningInterval time.Duration + healthCheckTickInterval time.Duration } var PageLimit = 200 @@ -56,10 +59,12 @@ func NewRPCService(rpcURL string, httpClient utils.HTTPClient, metricsService me heartbeatChannel := make(chan entities.RPCGetHealthResult, 1) return &rpcService{ - rpcURL: rpcURL, - httpClient: httpClient, - heartbeatChannel: heartbeatChannel, - metricsService: metricsService, + rpcURL: rpcURL, + httpClient: httpClient, + heartbeatChannel: heartbeatChannel, + metricsService: metricsService, + healthCheckWarningInterval: defaultHealthCheckWarningInterval, + healthCheckTickInterval: defaultHealthCheckTickInterval, }, nil } @@ -69,7 +74,6 @@ func (r *rpcService) GetHeartbeatChannel() chan entities.RPCGetHealthResult { func (r *rpcService) GetTransaction(transactionHash string) (entities.RPCGetTransactionResult, error) { resultBytes, err := r.sendRPCRequest("getTransaction", entities.RPCParams{Hash: transactionHash}) - if err != nil { return entities.RPCGetTransactionResult{}, fmt.Errorf("sending getTransaction request: %w", err) } @@ -110,7 +114,7 @@ func (r *rpcService) GetTransactions(startLedger int64, startCursor string, limi func (r *rpcService) GetHealth() (entities.RPCGetHealthResult, error) { resultBytes, err := r.sendRPCRequest("getHealth", entities.RPCParams{}) if err != nil { - return entities.RPCGetHealthResult{}, fmt.Errorf("sending getHealth request: %v", err) + return entities.RPCGetHealthResult{}, fmt.Errorf("sending getHealth request: %w", err) } var result entities.RPCGetHealthResult @@ -139,7 +143,6 @@ func (r *rpcService) GetLedgerEntries(keys []string) (entities.RPCGetLedgerEntri } func (r *rpcService) SendTransaction(transactionXDR string) (entities.RPCSendTransactionResult, error) { - resultBytes, err := r.sendRPCRequest("sendTransaction", entities.RPCParams{Transaction: transactionXDR}) if err != nil { return entities.RPCSendTransactionResult{}, fmt.Errorf("sending sendTransaction request: %w", err) @@ -173,9 +176,23 @@ func (r *rpcService) GetAccountLedgerSequence(address string) (int64, error) { return int64(accountEntry.SeqNum), nil } +func (r *rpcService) HealthCheckWarningInterval() time.Duration { + if utils.IsEmpty(r.healthCheckWarningInterval) { + return defaultHealthCheckWarningInterval + } + return r.healthCheckWarningInterval +} + +func (r *rpcService) HealthCheckTickInterval() time.Duration { + if utils.IsEmpty(r.healthCheckTickInterval) { + return defaultHealthCheckTickInterval + } + return r.healthCheckTickInterval +} + func (r *rpcService) TrackRPCServiceHealth(ctx context.Context) { - healthCheckTicker := time.NewTicker(rpcHealthCheckSleepTime) - warningTicker := time.NewTicker(rpcHealthCheckMaxWaitTime) + healthCheckTicker := time.NewTicker(r.HealthCheckTickInterval()) + warningTicker := time.NewTicker(r.HealthCheckWarningInterval()) defer func() { healthCheckTicker.Stop() warningTicker.Stop() @@ -187,9 +204,9 @@ func (r *rpcService) TrackRPCServiceHealth(ctx context.Context) { case <-ctx.Done(): return case <-warningTicker.C: - log.Warn(fmt.Sprintf("rpc service unhealthy for over %s", rpcHealthCheckMaxWaitTime)) + log.Warn(fmt.Sprintf("rpc service unhealthy for over %s", r.HealthCheckWarningInterval())) r.metricsService.SetRPCServiceHealth(false) - warningTicker.Reset(rpcHealthCheckMaxWaitTime) + warningTicker.Reset(r.HealthCheckWarningInterval()) case <-healthCheckTicker.C: result, err := r.GetHealth() if err != nil { @@ -200,7 +217,7 @@ func (r *rpcService) TrackRPCServiceHealth(ctx context.Context) { r.heartbeatChannel <- result r.metricsService.SetRPCServiceHealth(true) r.metricsService.SetRPCLatestLedger(int64(result.LatestLedger)) - warningTicker.Reset(rpcHealthCheckMaxWaitTime) + warningTicker.Reset(r.HealthCheckWarningInterval()) } } } @@ -234,13 +251,13 @@ func (r *rpcService) sendRPCRequest(method string, params entities.RPCParams) (j r.metricsService.IncRPCEndpointFailure(method) return nil, fmt.Errorf("sending POST request to RPC: %w", err) } - defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { r.metricsService.IncRPCEndpointFailure(method) return nil, fmt.Errorf("unmarshaling RPC response: %w", err) } + defer utils.DeferredClose(context.TODO(), resp.Body, "closing response body in the sendRPCRequest function") var res entities.RPCResponse err = json.Unmarshal(body, &res) diff --git a/internal/services/rpc_service_test.go b/internal/services/rpc_service_test.go index 1d8f90ab0..9455bd1c5 100644 --- a/internal/services/rpc_service_test.go +++ b/internal/services/rpc_service_test.go @@ -12,11 +12,11 @@ import ( "testing" "time" + "github.com/stellar/go/support/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/stellar/go/support/log" "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stellar/wallet-backend/internal/entities" @@ -44,7 +44,8 @@ func TestSendRPCRequest(t *testing.T) { mockMetricsService := metrics.NewMockMetricsService() mockHTTPClient := utils.MockHTTPClient{} rpcURL := "http://api.vibrantapp.com/soroban/rpc" - rpcService, _ := NewRPCService(rpcURL, &mockHTTPClient, mockMetricsService) + rpcService, err := NewRPCService(rpcURL, &mockHTTPClient, mockMetricsService) + require.NoError(t, err) t.Run("successful", func(t *testing.T) { mockMetricsService.On("IncRPCRequests", "sendTransaction").Once() @@ -171,7 +172,8 @@ func TestSendTransaction(t *testing.T) { mockMetricsService := metrics.NewMockMetricsService() mockHTTPClient := utils.MockHTTPClient{} rpcURL := "http://api.vibrantapp.com/soroban/rpc" - rpcService, _ := NewRPCService(rpcURL, &mockHTTPClient, mockMetricsService) + rpcService, err := NewRPCService(rpcURL, &mockHTTPClient, mockMetricsService) + require.NoError(t, err) t.Run("successful", func(t *testing.T) { mockMetricsService.On("IncRPCRequests", "sendTransaction").Once() @@ -188,7 +190,8 @@ func TestSendTransaction(t *testing.T) { "method": "sendTransaction", "params": params, } - jsonData, _ := json.Marshal(payload) + jsonData, err := json.Marshal(payload) + require.NoError(t, err) httpResponse := http.Response{ StatusCode: http.StatusOK, @@ -249,7 +252,8 @@ func TestGetTransaction(t *testing.T) { mockMetricsService := metrics.NewMockMetricsService() mockHTTPClient := utils.MockHTTPClient{} rpcURL := "http://api.vibrantapp.com/soroban/rpc" - rpcService, _ := NewRPCService(rpcURL, &mockHTTPClient, mockMetricsService) + rpcService, err := NewRPCService(rpcURL, &mockHTTPClient, mockMetricsService) + require.NoError(t, err) t.Run("successful", func(t *testing.T) { mockMetricsService.On("IncRPCRequests", "getTransaction").Once() @@ -266,7 +270,8 @@ func TestGetTransaction(t *testing.T) { "method": "getTransaction", "params": params, } - jsonData, _ := json.Marshal(payload) + jsonData, err := json.Marshal(payload) + require.NoError(t, err) httpResponse := http.Response{ StatusCode: http.StatusOK, @@ -342,7 +347,8 @@ func TestGetTransactions(t *testing.T) { mockMetricsService := metrics.NewMockMetricsService() mockHTTPClient := utils.MockHTTPClient{} rpcURL := "http://api.vibrantapp.com/soroban/rpc" - rpcService, _ := NewRPCService(rpcURL, &mockHTTPClient, mockMetricsService) + rpcService, err := NewRPCService(rpcURL, &mockHTTPClient, mockMetricsService) + require.NoError(t, err) t.Run("rpc_request_fails", func(t *testing.T) { mockMetricsService.On("IncRPCRequests", "getTransactions").Once() @@ -376,7 +382,8 @@ func TestGetTransactions(t *testing.T) { "method": "getTransactions", "params": params, } - jsonData, _ := json.Marshal(payload) + jsonData, err := json.Marshal(payload) + require.NoError(t, err) httpResponse := http.Response{ StatusCode: http.StatusOK, @@ -422,7 +429,8 @@ func TestSendGetHealth(t *testing.T) { mockMetricsService := metrics.NewMockMetricsService() mockHTTPClient := utils.MockHTTPClient{} rpcURL := "http://api.vibrantapp.com/soroban/rpc" - rpcService, _ := NewRPCService(rpcURL, &mockHTTPClient, mockMetricsService) + rpcService, err := NewRPCService(rpcURL, &mockHTTPClient, mockMetricsService) + require.NoError(t, err) t.Run("successful", func(t *testing.T) { mockMetricsService.On("IncRPCRequests", "getHealth").Once() @@ -435,7 +443,8 @@ func TestSendGetHealth(t *testing.T) { "id": 1, "method": "getHealth", } - jsonData, _ := json.Marshal(payload) + jsonData, err := json.Marshal(payload) + require.NoError(t, err) httpResponse := http.Response{ StatusCode: http.StatusOK, @@ -536,7 +545,10 @@ func TestTrackRPCServiceHealth_HealthyService(t *testing.T) { } func TestTrackRPCServiceHealth_UnhealthyService(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 70*time.Second) + healthCheckTickInterval := 300 * time.Millisecond + healthCheckWarningInterval := 400 * time.Millisecond + contextTimeout := healthCheckWarningInterval + time.Millisecond*190 + ctx, cancel := context.WithTimeout(context.Background(), contextTimeout) defer cancel() dbt := dbtest.Open(t) @@ -559,36 +571,43 @@ func TestTrackRPCServiceHealth_UnhealthyService(t *testing.T) { getLogs := log.DefaultLogger.StartTest(log.WarnLevel) mockHTTPClient := &utils.MockHTTPClient{} + defer mockHTTPClient.AssertExpectations(t) rpcURL := "http://test-url-track-rpc-service-health" rpcService, err := NewRPCService(rpcURL, mockHTTPClient, mockMetricsService) require.NoError(t, err) + rpcService.healthCheckTickInterval = healthCheckTickInterval + rpcService.healthCheckWarningInterval = healthCheckWarningInterval // Mock error response for GetHealth with a valid http.Response + getHealthRequestBody, err := json.Marshal(map[string]any{"jsonrpc": "2.0", "id": 1, "method": "getHealth"}) + require.NoError(t, err) + getHealthResponseBody := `{ + "jsonrpc": "2.0", + "id": 1, + "error": { + "code": -32601, + "message": "rpc error" + } + }` mockResponse := &http.Response{ - Body: io.NopCloser(bytes.NewBuffer([]byte(`{ - "jsonrpc": "2.0", - "id": 1, - "error": { - "code": -32601, - "message": "rpc error" - } - }`))), + Body: io.NopCloser(strings.NewReader(getHealthResponseBody)), } - mockHTTPClient.On("Post", rpcURL, "application/json", mock.Anything). + mockHTTPClient.On("Post", rpcURL, "application/json", bytes.NewBuffer(getHealthRequestBody)). Return(mockResponse, nil) - // The ctx will timeout after 70 seconds, which is enough for the warning to trigger + // The ctx will timeout after {contextTimeout}, which is enough for the warning to trigger rpcService.TrackRPCServiceHealth(ctx) entries := getLogs() - testFailed := true + testSucceeded := false for _, entry := range entries { - if strings.Contains(entry.Message, "rpc service unhealthy for over 1m0s") { - testFailed = false + t.Logf("entry: %v\n", entry.Message) + if strings.Contains(entry.Message, "rpc service unhealthy for over "+healthCheckWarningInterval.String()) { + testSucceeded = true + break } } - assert.False(t, testFailed) - mockHTTPClient.AssertExpectations(t) + assert.True(t, testSucceeded) } func TestTrackRPCService_ContextCancelled(t *testing.T) { diff --git a/internal/services/servicesmocks/account_sponsorship_mock.go b/internal/services/servicesmocks/account_sponsorship_mock.go deleted file mode 100644 index f60024692..000000000 --- a/internal/services/servicesmocks/account_sponsorship_mock.go +++ /dev/null @@ -1,26 +0,0 @@ -package servicesmocks - -import ( - "context" - - "github.com/stellar/go/txnbuild" - "github.com/stellar/wallet-backend/internal/entities" - "github.com/stellar/wallet-backend/internal/services" - "github.com/stretchr/testify/mock" -) - -type AccountSponsorshipServiceMock struct { - mock.Mock -} - -var _ services.AccountSponsorshipService = (*AccountSponsorshipServiceMock)(nil) - -func (s *AccountSponsorshipServiceMock) SponsorAccountCreationTransaction(ctx context.Context, accountToSponsor string, signers []entities.Signer, assets []entities.Asset) (string, string, error) { - args := s.Called(ctx, accountToSponsor, signers, assets) - return args.String(0), args.String(1), args.Error(2) -} - -func (s *AccountSponsorshipServiceMock) WrapTransaction(ctx context.Context, tx *txnbuild.Transaction) (string, string, error) { - args := s.Called(ctx, tx) - return args.String(0), args.String(1), args.Error(2) -} diff --git a/internal/signing/channel_account_db_signature_client.go b/internal/signing/channel_account_db_signature_client.go index 89b612094..bdecaa5df 100644 --- a/internal/signing/channel_account_db_signature_client.go +++ b/internal/signing/channel_account_db_signature_client.go @@ -10,9 +10,11 @@ import ( "github.com/stellar/go/network" "github.com/stellar/go/support/log" "github.com/stellar/go/txnbuild" + "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/signing/store" signingutils "github.com/stellar/wallet-backend/internal/signing/utils" + "github.com/stellar/wallet-backend/internal/utils" ) type channelAccountDBSignatureClient struct { @@ -21,8 +23,15 @@ type channelAccountDBSignatureClient struct { encryptionPassphrase string privateKeyEncrypter signingutils.PrivateKeyEncrypter channelAccountStore store.ChannelAccountStore + retryInterval time.Duration + retryCount int } +const ( + DefaultRetryInterval = 1 * time.Second + DefaultRetryCount = 6 +) + var _ SignatureClient = (*channelAccountDBSignatureClient)(nil) func NewChannelAccountDBSignatureClient(dbConnectionPool db.ConnectionPool, networkPassphrase string, privateKeyEncrypter signingutils.PrivateKeyEncrypter, encryptionPassphrase string) (*channelAccountDBSignatureClient, error) { @@ -36,9 +45,25 @@ func NewChannelAccountDBSignatureClient(dbConnectionPool db.ConnectionPool, netw channelAccountStore: store.NewChannelAccountModel(dbConnectionPool), privateKeyEncrypter: privateKeyEncrypter, encryptionPassphrase: encryptionPassphrase, + retryInterval: DefaultRetryInterval, + retryCount: DefaultRetryCount, }, nil } +func (sc *channelAccountDBSignatureClient) RetryInterval() time.Duration { + if utils.IsEmpty(sc.retryInterval) { + return DefaultRetryInterval + } + return sc.retryInterval +} + +func (sc *channelAccountDBSignatureClient) RetryCount() int { + if utils.IsEmpty(sc.retryCount) { + return DefaultRetryCount + } + return sc.retryCount +} + func (sc *channelAccountDBSignatureClient) NetworkPassphrase() string { return sc.networkPassphrase } @@ -50,16 +75,17 @@ func (sc *channelAccountDBSignatureClient) GetAccountPublicKey(ctx context.Conte } else { lockedUntil = time.Minute } - for range store.ChannelAccountWaitTime { + for range sc.RetryCount() { // check to see if the variadic parameter for time exists and if so, use it here channelAccount, err := sc.channelAccountStore.GetAndLockIdleChannelAccount(ctx, lockedUntil) if err != nil { if errors.Is(err, store.ErrNoIdleChannelAccountAvailable) { - log.Ctx(ctx).Warn("All channel accounts are in use. Retry in 1 second.") - time.Sleep(1 * time.Second) + log.Ctx(ctx).Warnf("All channel accounts are in use. Retry in %s.", sc.RetryInterval()) + time.Sleep(sc.RetryInterval()) continue } - return "", fmt.Errorf("getting idle channel account: %w", err) + + return "", fmt.Errorf("could not get an idle channel account after %v: %w", time.Duration(sc.RetryCount())*sc.RetryInterval(), err) } return channelAccount.PublicKey, nil diff --git a/internal/signing/channel_account_db_signature_client_test.go b/internal/signing/channel_account_db_signature_client_test.go index 984e6e42c..eebe80e53 100644 --- a/internal/signing/channel_account_db_signature_client_test.go +++ b/internal/signing/channel_account_db_signature_client_test.go @@ -2,6 +2,7 @@ package signing import ( "context" + "fmt" "testing" "time" @@ -9,16 +10,19 @@ import ( "github.com/stellar/go/network" "github.com/stellar/go/support/log" "github.com/stellar/go/txnbuild" - "github.com/stellar/wallet-backend/internal/signing/store" - signingutils "github.com/stellar/wallet-backend/internal/signing/utils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/stellar/wallet-backend/internal/signing/store" + signingutils "github.com/stellar/wallet-backend/internal/signing/utils" ) func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) { t.Parallel() ctx := context.Background() + retryCount := 6 + retryInterval := 100 * time.Millisecond privateKeyEncrypter := signingutils.DefaultPrivateKeyEncrypter{} channelAccountStore := store.ChannelAccountStoreMock{} sc := channelAccountDBSignatureClient{ @@ -26,15 +30,17 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) { encryptionPassphrase: "test", privateKeyEncrypter: &privateKeyEncrypter, channelAccountStore: &channelAccountStore, + retryCount: retryCount, + retryInterval: retryInterval, } t.Run("returns_error_when_couldn't_get_an_idle_channel_account", func(t *testing.T) { channelAccountStore. On("GetAndLockIdleChannelAccount", ctx, time.Duration(100)*time.Second). Return(nil, store.ErrNoIdleChannelAccountAvailable). - Times(6). + Times(retryCount). On("Count", ctx). - Return(5, nil). + Return(retryCount-1, nil). Once() defer channelAccountStore.AssertExpectations(t) @@ -45,10 +51,10 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) { assert.Empty(t, publicKey) entries := getEntries() - require.Len(t, entries, 6) + require.Len(t, entries, retryCount) for _, entry := range entries { - assert.Equal(t, entry.Message, "All channel accounts are in use. Retry in 1 second.") + assert.Equal(t, entry.Message, fmt.Sprintf("All channel accounts are in use. Retry in %s.", retryInterval)) } }) @@ -56,7 +62,7 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) { channelAccountStore. On("GetAndLockIdleChannelAccount", ctx, time.Minute). Return(nil, store.ErrNoIdleChannelAccountAvailable). - Times(6). + Times(retryCount). On("Count", ctx). Return(0, nil). Once() @@ -69,10 +75,10 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) { assert.Empty(t, publicKey) entries := getEntries() - require.Len(t, entries, 6) + require.Len(t, entries, retryCount) for _, entry := range entries { - assert.Equal(t, entry.Message, "All channel accounts are in use. Retry in 1 second.") + assert.Equal(t, fmt.Sprintf("All channel accounts are in use. Retry in %s.", retryInterval), entry.Message) } }) diff --git a/internal/signing/env_signature_client_test.go b/internal/signing/env_signature_client_test.go index 6655a318a..933b72052 100644 --- a/internal/signing/env_signature_client_test.go +++ b/internal/signing/env_signature_client_test.go @@ -17,7 +17,8 @@ func TestEnvSignatureClientGetAccountPublicKey(t *testing.T) { distributionAccount := keypair.MustRandom() sc, err := NewEnvSignatureClient(distributionAccount.Seed(), network.TestNetworkPassphrase) require.NoError(t, err) - publicKey, _ := sc.GetAccountPublicKey(ctx) + publicKey, err := sc.GetAccountPublicKey(ctx) + require.NoError(t, err) assert.Equal(t, distributionAccount.Address(), publicKey) } diff --git a/internal/signing/kms_signature_client.go b/internal/signing/kms_signature_client.go index a297f1fab..f28d89911 100644 --- a/internal/signing/kms_signature_client.go +++ b/internal/signing/kms_signature_client.go @@ -12,6 +12,7 @@ import ( "github.com/stellar/go/network" "github.com/stellar/go/strkey" "github.com/stellar/go/txnbuild" + "github.com/stellar/wallet-backend/internal/signing/awskms" "github.com/stellar/wallet-backend/internal/signing/store" ) @@ -85,7 +86,7 @@ func (sc *kmsSignatureClient) SignStellarTransaction(ctx context.Context, tx *tx kpFull, err := sc.getKPFull(ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("getting keypair full in %T: %w", sc, err) } signedTx, err := tx.Sign(sc.NetworkPassphrase(), kpFull) @@ -103,7 +104,7 @@ func (sc *kmsSignatureClient) SignStellarFeeBumpTransaction(ctx context.Context, kpFull, err := sc.getKPFull(ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("getting keypair full in %T: %w", sc, err) } signedFeeBumpTx, err := feeBumpTx.Sign(sc.NetworkPassphrase(), kpFull) diff --git a/internal/signing/kms_signature_client_test.go b/internal/signing/kms_signature_client_test.go index 26da3d8cd..dcdd8c5c4 100644 --- a/internal/signing/kms_signature_client_test.go +++ b/internal/signing/kms_signature_client_test.go @@ -10,20 +10,22 @@ import ( "github.com/stellar/go/keypair" "github.com/stellar/go/network" "github.com/stellar/go/txnbuild" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stellar/wallet-backend/internal/signing/awskms" "github.com/stellar/wallet-backend/internal/signing/store" "github.com/stellar/wallet-backend/internal/utils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestKMSSignatureClientGetAccountPublicKey(t *testing.T) { ctx := context.Background() distributionAccount := keypair.MustRandom() sc := kmsSignatureClient{distributionAccountPublicKey: distributionAccount.Address()} - publicKey, _ := sc.GetAccountPublicKey(ctx) + publicKey, err := sc.GetAccountPublicKey(ctx) + require.NoError(t, err) assert.Equal(t, distributionAccount.Address(), publicKey) } @@ -171,7 +173,7 @@ func TestKMSSignatureClientSignStellarTransaction(t *testing.T) { defer kmsClient.AssertExpectations(t) signedTx, err := sc.SignStellarTransaction(ctx, tx, distributionAccount.Address()) - assert.EqualError(t, err, "decrypting distribution account private key in *signing.kmsSignatureClient: unexpected error") + assert.ErrorContains(t, err, "decrypting distribution account private key in *signing.kmsSignatureClient: unexpected error") assert.Nil(t, signedTx) }) @@ -214,7 +216,7 @@ func TestKMSSignatureClientSignStellarTransaction(t *testing.T) { defer kmsClient.AssertExpectations(t) signedTx, err := sc.SignStellarTransaction(ctx, tx, distributionAccount.Address()) - assert.EqualError(t, err, "parsing distribution account private key in *signing.kmsSignatureClient: base32 decode failed: illegal base32 data at input byte 7") + assert.ErrorContains(t, err, "parsing distribution account private key in *signing.kmsSignatureClient: base32 decode failed: illegal base32 data at input byte 7") assert.Nil(t, signedTx) }) @@ -369,7 +371,7 @@ func TestKMSSignatureClientSignStellarFeeBumpTransaction(t *testing.T) { defer kmsClient.AssertExpectations(t) signedFeeBumpTx, err := sc.SignStellarFeeBumpTransaction(ctx, feeBumpTx) - assert.EqualError(t, err, "decrypting distribution account private key in *signing.kmsSignatureClient: unexpected error") + assert.ErrorContains(t, err, "decrypting distribution account private key in *signing.kmsSignatureClient: unexpected error") assert.Nil(t, signedFeeBumpTx) }) @@ -419,7 +421,7 @@ func TestKMSSignatureClientSignStellarFeeBumpTransaction(t *testing.T) { defer kmsClient.AssertExpectations(t) signedFeeBumpTx, err := sc.SignStellarFeeBumpTransaction(ctx, feeBumpTx) - assert.EqualError(t, err, "parsing distribution account private key in *signing.kmsSignatureClient: base32 decode failed: illegal base32 data at input byte 7") + assert.ErrorContains(t, err, "parsing distribution account private key in *signing.kmsSignatureClient: base32 decode failed: illegal base32 data at input byte 7") assert.Nil(t, signedFeeBumpTx) }) diff --git a/internal/signing/store/channel_accounts_model.go b/internal/signing/store/channel_accounts_model.go index d3dd1f2ae..146575e08 100644 --- a/internal/signing/store/channel_accounts_model.go +++ b/internal/signing/store/channel_accounts_model.go @@ -8,11 +8,10 @@ import ( "time" "github.com/lib/pq" + "github.com/stellar/wallet-backend/internal/db" ) -const ChannelAccountWaitTime = 6 - var ( ErrNoIdleChannelAccountAvailable = errors.New("no idle channel account available") ErrNoChannelAccountConfigured = errors.New("no channel accounts") diff --git a/internal/signing/store/channel_accounts_model_test.go b/internal/signing/store/channel_accounts_model_test.go index 2cd7f0d7e..2e3872461 100644 --- a/internal/signing/store/channel_accounts_model_test.go +++ b/internal/signing/store/channel_accounts_model_test.go @@ -7,10 +7,11 @@ import ( "github.com/lib/pq" "github.com/stellar/go/keypair" - "github.com/stellar/wallet-backend/internal/db" - "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/stellar/wallet-backend/internal/db" + "github.com/stellar/wallet-backend/internal/db/dbtest" ) func createChannelAccountFixture(t *testing.T, ctx context.Context, dbConnectionPool db.ConnectionPool, channelAccounts ...ChannelAccount) { @@ -149,7 +150,6 @@ func TestAssignTxToChannelAccount(t *testing.T) { channelAccountFromDB, err := m.Get(ctx, dbConnectionPool, channelAccount.Address()) assert.NoError(t, err) assert.Equal(t, "txhash", channelAccountFromDB.LockedTxHash.String) - } func TestUnlockChannelAccountFromTx(t *testing.T) { @@ -183,7 +183,6 @@ func TestUnlockChannelAccountFromTx(t *testing.T) { func TestChannelAccountModelBatchInsert(t *testing.T) { dbt := dbtest.Open(t) defer dbt.Close() - dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) require.NoError(t, err) defer dbConnectionPool.Close() @@ -192,7 +191,7 @@ func TestChannelAccountModelBatchInsert(t *testing.T) { m := NewChannelAccountModel(dbConnectionPool) t.Run("channel_accounts_empty", func(t *testing.T) { - err := m.BatchInsert(ctx, dbConnectionPool, []*ChannelAccount{}) + err = m.BatchInsert(ctx, dbConnectionPool, []*ChannelAccount{}) require.NoError(t, err) }) @@ -202,7 +201,7 @@ func TestChannelAccountModelBatchInsert(t *testing.T) { PublicKey: "", }, } - err := m.BatchInsert(ctx, dbConnectionPool, channelAccounts) + err = m.BatchInsert(ctx, dbConnectionPool, channelAccounts) assert.EqualError(t, err, "public key cannot be empty") channelAccounts = []*ChannelAccount{ diff --git a/internal/signing/store/keypairs_model.go b/internal/signing/store/keypairs_model.go index 26e891472..01364864e 100644 --- a/internal/signing/store/keypairs_model.go +++ b/internal/signing/store/keypairs_model.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/lib/pq" + "github.com/stellar/wallet-backend/internal/db" ) @@ -27,7 +28,8 @@ func (k *KeypairModel) Insert(ctx context.Context, publicKey string, encryptedPr ` _, err := k.DB.ExecContext(ctx, query, publicKey, encryptedPrivateKey) if err != nil { - if pqError, ok := err.(*pq.Error); ok && pqError.Constraint == "keypairs_pkey" { + var pqError *pq.Error + if ok := errors.As(err, &pqError); ok && pqError.Constraint == "keypairs_pkey" { return ErrPublicKeyAlreadyExists } return fmt.Errorf("inserting keypair for public key %s: %w", publicKey, err) diff --git a/internal/signing/store/keypairs_model_test.go b/internal/signing/store/keypairs_model_test.go index 0053c16b7..882b41fb0 100644 --- a/internal/signing/store/keypairs_model_test.go +++ b/internal/signing/store/keypairs_model_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/stellar/go/keypair" - "github.com/stellar/wallet-backend/internal/db" - "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/stellar/wallet-backend/internal/db" + "github.com/stellar/wallet-backend/internal/db/dbtest" ) func createKeypairFixture(t *testing.T, ctx context.Context, dbConnectionPool db.ConnectionPool, kp Keypair) { @@ -64,7 +65,7 @@ func TestKeypairModelInsert(t *testing.T) { t.Run("keypair_already_exists", func(t *testing.T) { kpFull := keypair.MustRandom() createKeypairFixture(t, ctx, dbConnectionPool, Keypair{PublicKey: kpFull.Address(), EncryptedPrivateKey: []byte(kpFull.Seed())}) - err := m.Insert(ctx, kpFull.Address(), []byte(kpFull.Seed())) + err = m.Insert(ctx, kpFull.Address(), []byte(kpFull.Seed())) assert.ErrorIs(t, err, ErrPublicKeyAlreadyExists) }) diff --git a/internal/signing/store/mocks.go b/internal/signing/store/mocks.go index f9fac725e..d74c75369 100644 --- a/internal/signing/store/mocks.go +++ b/internal/signing/store/mocks.go @@ -4,8 +4,9 @@ import ( "context" "time" - "github.com/stellar/wallet-backend/internal/db" "github.com/stretchr/testify/mock" + + "github.com/stellar/wallet-backend/internal/db" ) type ChannelAccountStoreMock struct { diff --git a/internal/signing/utils/encrypter.go b/internal/signing/utils/encrypter.go index 18940c7bb..025693da2 100644 --- a/internal/signing/utils/encrypter.go +++ b/internal/signing/utils/encrypter.go @@ -30,12 +30,12 @@ func (e *DefaultPrivateKeyEncrypter) Encrypt(ctx context.Context, message, passp block, err := aes.NewCipher([]byte(key)) if err != nil { - return "", err + return "", fmt.Errorf("creating aes cipher: %w", err) } gcmCipher, err := cipher.NewGCM(block) if err != nil { - return "", err + return "", fmt.Errorf("creating gcm cipher: %w", err) } nonce := make([]byte, gcmCipher.NonceSize()) @@ -60,17 +60,17 @@ func (e *DefaultPrivateKeyEncrypter) Decrypt(ctx context.Context, encryptedMessa block, err := aes.NewCipher([]byte(key)) if err != nil { - return "", err + return "", fmt.Errorf("creating aes cipher: %w", err) } gcmCipher, err := cipher.NewGCM(block) if err != nil { - return "", err + return "", fmt.Errorf("creating gcm cipher: %w", err) } decodedMsg, err := base64.StdEncoding.DecodeString(encryptedMessage) if err != nil { - return "", err + return "", fmt.Errorf("decoding encrypted message: %w", err) } nonceSize := gcmCipher.NonceSize() @@ -78,7 +78,7 @@ func (e *DefaultPrivateKeyEncrypter) Decrypt(ctx context.Context, encryptedMessa plainText, err := gcmCipher.Open(nil, nonce, cipheredText, nil) if err != nil { - return "", err + return "", fmt.Errorf("decrypting and authenticating message: %w", err) } return string(plainText), nil diff --git a/internal/tss/channels/error_jitter_channel.go b/internal/tss/channels/error_jitter_channel.go index a4ea6c8ec..061672f34 100644 --- a/internal/tss/channels/error_jitter_channel.go +++ b/internal/tss/channels/error_jitter_channel.go @@ -7,11 +7,12 @@ import ( "github.com/alitto/pond" "github.com/stellar/go/support/log" + "golang.org/x/exp/rand" + "github.com/stellar/wallet-backend/internal/metrics" "github.com/stellar/wallet-backend/internal/tss" "github.com/stellar/wallet-backend/internal/tss/router" "github.com/stellar/wallet-backend/internal/tss/services" - "golang.org/x/exp/rand" ) type ErrorJitterChannelConfigs struct { @@ -72,14 +73,14 @@ func (p *errorJitterPool) Receive(payload tss.Payload) { currentBackoff := p.MinWaitBtwnRetriesMS * (1 << i) time.Sleep(jitter(time.Duration(currentBackoff)) * time.Millisecond) - oldStatus := payload.RpcSubmitTxResponse.Status.Status() + oldStatus := payload.RPCSubmitTxResponse.Status.Status() rpcSendResp, err := p.TxManager.BuildAndSubmitTransaction(ctx, ErrorJitterChannelName, payload) if err != nil { log.Errorf("%s: unable to sign and submit transaction: %e", ErrorJitterChannelName, err) return } - payload.RpcSubmitTxResponse = rpcSendResp + payload.RPCSubmitTxResponse = rpcSendResp if !slices.Contains(tss.JitterErrorCodes, rpcSendResp.Code.TxResultCode) { err := p.Router.Route(payload) if err != nil { diff --git a/internal/tss/channels/error_jitter_channel_test.go b/internal/tss/channels/error_jitter_channel_test.go index 279db214e..2d8c28cc2 100644 --- a/internal/tss/channels/error_jitter_channel_test.go +++ b/internal/tss/channels/error_jitter_channel_test.go @@ -5,6 +5,9 @@ import ( "errors" "testing" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stellar/wallet-backend/internal/entities" @@ -12,8 +15,6 @@ import ( "github.com/stellar/wallet-backend/internal/tss" "github.com/stellar/wallet-backend/internal/tss/router" "github.com/stellar/wallet-backend/internal/tss/services" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" ) func TestJitterSend(t *testing.T) { @@ -51,7 +52,7 @@ func TestJitterSend(t *testing.T) { Status: tss.RPCTXStatus{RPCStatus: entities.ErrorStatus}, Code: tss.RPCTXCode{TxResultCode: tss.NonJitterErrorCodes[0]}, } - payload.RpcSubmitTxResponse = rpcResp + payload.RPCSubmitTxResponse = rpcResp txManagerMock. On("BuildAndSubmitTransaction", context.Background(), ErrorJitterChannelName, payload). diff --git a/internal/tss/channels/error_non_jitter_channel.go b/internal/tss/channels/error_non_jitter_channel.go index 0f5e3b781..7da30c9ef 100644 --- a/internal/tss/channels/error_non_jitter_channel.go +++ b/internal/tss/channels/error_non_jitter_channel.go @@ -7,6 +7,7 @@ import ( "github.com/alitto/pond" "github.com/stellar/go/support/log" + "github.com/stellar/wallet-backend/internal/metrics" "github.com/stellar/wallet-backend/internal/tss" "github.com/stellar/wallet-backend/internal/tss/router" @@ -66,14 +67,14 @@ func (p *errorNonJitterPool) Receive(payload tss.Payload) { for i = 0; i < p.MaxRetries; i++ { time.Sleep(time.Duration(p.WaitBtwnRetriesMS) * time.Millisecond) - oldStatus := payload.RpcSubmitTxResponse.Status.Status() + oldStatus := payload.RPCSubmitTxResponse.Status.Status() rpcSendResp, err := p.TxManager.BuildAndSubmitTransaction(ctx, ErrorNonJitterChannelName, payload) if err != nil { log.Errorf("%s: unable to sign and submit transaction: %v", ErrorNonJitterChannelName, err) return } - payload.RpcSubmitTxResponse = rpcSendResp + payload.RPCSubmitTxResponse = rpcSendResp if !slices.Contains(tss.NonJitterErrorCodes, rpcSendResp.Code.TxResultCode) { err := p.Router.Route(payload) if err != nil { diff --git a/internal/tss/channels/error_non_jitter_channel_test.go b/internal/tss/channels/error_non_jitter_channel_test.go index 5a37eecc0..d42c6bb5a 100644 --- a/internal/tss/channels/error_non_jitter_channel_test.go +++ b/internal/tss/channels/error_non_jitter_channel_test.go @@ -5,6 +5,9 @@ import ( "errors" "testing" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stellar/wallet-backend/internal/entities" @@ -12,8 +15,6 @@ import ( "github.com/stellar/wallet-backend/internal/tss" "github.com/stellar/wallet-backend/internal/tss/router" "github.com/stellar/wallet-backend/internal/tss/services" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" ) func TestNonJitterSend(t *testing.T) { @@ -51,7 +52,7 @@ func TestNonJitterSend(t *testing.T) { Status: tss.RPCTXStatus{RPCStatus: entities.ErrorStatus}, Code: tss.RPCTXCode{TxResultCode: tss.JitterErrorCodes[0]}, } - payload.RpcSubmitTxResponse = rpcResp + payload.RPCSubmitTxResponse = rpcResp txManagerMock. On("BuildAndSubmitTransaction", context.Background(), ErrorNonJitterChannelName, payload). diff --git a/internal/tss/channels/rpc_caller_channel.go b/internal/tss/channels/rpc_caller_channel.go index e2061a50f..bc58ed8dd 100644 --- a/internal/tss/channels/rpc_caller_channel.go +++ b/internal/tss/channels/rpc_caller_channel.go @@ -6,6 +6,7 @@ import ( "github.com/alitto/pond" "github.com/stellar/go/support/log" + "github.com/stellar/wallet-backend/internal/metrics" "github.com/stellar/wallet-backend/internal/tss" "github.com/stellar/wallet-backend/internal/tss/router" @@ -59,20 +60,18 @@ func (p *rpcCallerPool) Receive(payload tss.Payload) { ctx := context.Background() // Create a new transaction record in the transactions table. err := p.Store.UpsertTransaction(ctx, payload.WebhookURL, payload.TransactionHash, payload.TransactionXDR, tss.RPCTXStatus{OtherStatus: tss.NewStatus}) - if err != nil { log.Errorf("%s: unable to upsert transaction into transactions table: %e", RPCCallerChannelName, err) return } rpcSendResp, err := p.TxManager.BuildAndSubmitTransaction(ctx, RPCCallerChannelName, payload) - if err != nil { log.Errorf("%s: unable to sign and submit transaction: %e", RPCCallerChannelName, err) return } - payload.RpcSubmitTxResponse = rpcSendResp + payload.RPCSubmitTxResponse = rpcSendResp err = p.Router.Route(payload) if err != nil { log.Errorf("%s: unable to route payload: %e", RPCCallerChannelName, err) diff --git a/internal/tss/channels/rpc_caller_channel_test.go b/internal/tss/channels/rpc_caller_channel_test.go index 6b71b0c14..2164be903 100644 --- a/internal/tss/channels/rpc_caller_channel_test.go +++ b/internal/tss/channels/rpc_caller_channel_test.go @@ -5,6 +5,9 @@ import ( "errors" "testing" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stellar/wallet-backend/internal/entities" @@ -13,8 +16,6 @@ import ( "github.com/stellar/wallet-backend/internal/tss/router" "github.com/stellar/wallet-backend/internal/tss/services" "github.com/stellar/wallet-backend/internal/tss/store" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" ) func TestSend(t *testing.T) { @@ -25,7 +26,8 @@ func TestSend(t *testing.T) { defer dbConnectionPool.Close() mockMetricsService := metrics.NewMockMetricsService() - store, _ := store.NewStore(dbConnectionPool, mockMetricsService) + store, err := store.NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) txManagerMock := services.TransactionManagerMock{} routerMock := router.MockRouter{} cfgs := RPCCallerChannelConfigs{ @@ -52,7 +54,7 @@ func TestSend(t *testing.T) { rpcResp := tss.RPCSendTxResponse{ Status: tss.RPCTXStatus{RPCStatus: entities.TryAgainLaterStatus}, } - payload.RpcSubmitTxResponse = rpcResp + payload.RPCSubmitTxResponse = rpcResp txManagerMock. On("BuildAndSubmitTransaction", context.Background(), RPCCallerChannelName, payload). @@ -78,7 +80,8 @@ func TestReceive(t *testing.T) { defer dbConnectionPool.Close() mockMetricsService := metrics.NewMockMetricsService() - store, _ := store.NewStore(dbConnectionPool, mockMetricsService) + store, err := store.NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) txManagerMock := services.TransactionManagerMock{} routerMock := router.MockRouter{} cfgs := RPCCallerChannelConfigs{ @@ -123,7 +126,7 @@ func TestReceive(t *testing.T) { rpcResp := tss.RPCSendTxResponse{ Status: tss.RPCTXStatus{RPCStatus: entities.ErrorStatus}, } - payload.RpcSubmitTxResponse = rpcResp + payload.RPCSubmitTxResponse = rpcResp txManagerMock. On("BuildAndSubmitTransaction", context.Background(), RPCCallerChannelName, payload). diff --git a/internal/tss/channels/webhook_channel.go b/internal/tss/channels/webhook_channel.go index f187c2d41..205eb5ad8 100644 --- a/internal/tss/channels/webhook_channel.go +++ b/internal/tss/channels/webhook_channel.go @@ -11,8 +11,9 @@ import ( "github.com/alitto/pond" "github.com/stellar/go/support/log" "github.com/stellar/go/txnbuild" - channelAccountStore "github.com/stellar/wallet-backend/internal/signing/store" + "github.com/stellar/wallet-backend/internal/metrics" + channelAccountStore "github.com/stellar/wallet-backend/internal/signing/store" "github.com/stellar/wallet-backend/internal/tss" "github.com/stellar/wallet-backend/internal/tss/store" tssutils "github.com/stellar/wallet-backend/internal/tss/utils" @@ -89,7 +90,7 @@ func (p *webhookPool) Receive(payload tss.Payload) { if err != nil { log.Errorf("%s: error making POST request to webhook: %e", WebhookChannelName, err) } else { - defer httpResp.Body.Close() + defer utils.DeferredClose(ctx, httpResp.Body, "closing response body in the Receive function") if httpResp.StatusCode == http.StatusOK { sent = true err := p.Store.UpsertTransaction( @@ -110,7 +111,6 @@ func (p *webhookPool) Receive(payload tss.Payload) { log.Errorf("%s: error updating transaction status: %e", WebhookChannelName, err) } } - } func (p *webhookPool) UnlockChannelAccount(ctx context.Context, txXDR string) error { diff --git a/internal/tss/channels/webhook_channel_test.go b/internal/tss/channels/webhook_channel_test.go index 7d0f5b899..d107dc8df 100644 --- a/internal/tss/channels/webhook_channel_test.go +++ b/internal/tss/channels/webhook_channel_test.go @@ -12,6 +12,10 @@ import ( "github.com/stellar/go/keypair" "github.com/stellar/go/txnbuild" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stellar/wallet-backend/internal/metrics" @@ -20,9 +24,6 @@ import ( "github.com/stellar/wallet-backend/internal/tss/store" tssutils "github.com/stellar/wallet-backend/internal/tss/utils" "github.com/stellar/wallet-backend/internal/utils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" ) func TestWebhookHandlerServiceChannel(t *testing.T) { @@ -33,7 +34,8 @@ func TestWebhookHandlerServiceChannel(t *testing.T) { defer dbConnectionPool.Close() mockMetricsService := metrics.NewMockMetricsService() - store, _ := store.NewStore(dbConnectionPool, mockMetricsService) + store, err := store.NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) channelAccountStore := channelAccountStore.ChannelAccountStoreMock{} mockHTTPClient := utils.MockHTTPClient{} cfg := WebhookChannelConfigs{ @@ -61,8 +63,8 @@ func TestWebhookHandlerServiceChannel(t *testing.T) { payload.TransactionHash = "hash" payload.TransactionXDR = "xdr" payload.WebhookURL = "www.stellar.org" - jsonData, _ := json.Marshal(tssutils.PayloadTOTSSResponse(payload)) - + jsonData, err := json.Marshal(tssutils.PayloadTOTSSResponse(payload)) + require.NoError(t, err) httpResponse1 := &http.Response{ StatusCode: http.StatusBadGateway, Body: io.NopCloser(strings.NewReader(`{"result": {"status": "OK"}}`)), @@ -101,7 +103,8 @@ func TestUnlockChannelAccount(t *testing.T) { defer dbConnectionPool.Close() mockMetricsService := metrics.NewMockMetricsService() - store, _ := store.NewStore(dbConnectionPool, mockMetricsService) + store, err := store.NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) channelAccountStore := channelAccountStore.ChannelAccountStoreMock{} mockHTTPClient := utils.MockHTTPClient{} cfg := WebhookChannelConfigs{ diff --git a/internal/tss/errors/errors.go b/internal/tss/errors/errors.go index 69f8bc328..a44357565 100644 --- a/internal/tss/errors/errors.go +++ b/internal/tss/errors/errors.go @@ -4,6 +4,4 @@ import ( "errors" ) -var ( - OriginalXDRMalformed = errors.New("transaction string is malformed") -) +var ErrOriginalXDRMalformed = errors.New("transaction string (XDR) is malformed") diff --git a/internal/tss/router/mocks.go b/internal/tss/router/mocks.go index 2f269b7bb..df1c1874b 100644 --- a/internal/tss/router/mocks.go +++ b/internal/tss/router/mocks.go @@ -1,8 +1,9 @@ package router import ( - "github.com/stellar/wallet-backend/internal/tss" "github.com/stretchr/testify/mock" + + "github.com/stellar/wallet-backend/internal/tss" ) type MockRouter struct { diff --git a/internal/tss/router/router.go b/internal/tss/router/router.go index 331368695..bce5f7289 100644 --- a/internal/tss/router/router.go +++ b/internal/tss/router/router.go @@ -39,19 +39,19 @@ func NewRouter(cfg RouterConfigs) Router { func (r *router) Route(payload tss.Payload) error { var channel tss.Channel - if payload.RpcSubmitTxResponse.Status.Status() != "" { - switch payload.RpcSubmitTxResponse.Status { + if payload.RPCSubmitTxResponse.Status.Status() != "" { + switch payload.RPCSubmitTxResponse.Status { case tss.RPCTXStatus{OtherStatus: tss.NewStatus}: channel = r.RPCCallerChannel case tss.RPCTXStatus{RPCStatus: entities.TryAgainLaterStatus}: channel = r.ErrorJitterChannel case tss.RPCTXStatus{RPCStatus: entities.ErrorStatus}: - if payload.RpcSubmitTxResponse.Code.OtherCodes == tss.NoCode { - if slices.Contains(tss.JitterErrorCodes, payload.RpcSubmitTxResponse.Code.TxResultCode) { + if payload.RPCSubmitTxResponse.Code.OtherCodes == tss.NoCode { + if slices.Contains(tss.JitterErrorCodes, payload.RPCSubmitTxResponse.Code.TxResultCode) { channel = r.ErrorJitterChannel - } else if slices.Contains(tss.NonJitterErrorCodes, payload.RpcSubmitTxResponse.Code.TxResultCode) { + } else if slices.Contains(tss.NonJitterErrorCodes, payload.RPCSubmitTxResponse.Code.TxResultCode) { channel = r.ErrorNonJitterChannel - } else if slices.Contains(tss.FinalCodes, payload.RpcSubmitTxResponse.Code.TxResultCode) { + } else if slices.Contains(tss.FinalCodes, payload.RPCSubmitTxResponse.Code.TxResultCode) { channel = r.WebhookChannel } } @@ -63,7 +63,7 @@ func (r *router) Route(payload tss.Payload) error { // Do nothing for PENDING / DUPLICATE statuses return nil } - } else if payload.RpcGetIngestTxResponse.Status != "" { + } else if payload.RPCGetIngestTxResponse.Status != "" { channel = r.WebhookChannel } else { channel = r.RPCCallerChannel diff --git a/internal/tss/router/router_test.go b/internal/tss/router/router_test.go index e9d933591..05db7e798 100644 --- a/internal/tss/router/router_test.go +++ b/internal/tss/router/router_test.go @@ -3,9 +3,10 @@ package router import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/stellar/wallet-backend/internal/entities" "github.com/stellar/wallet-backend/internal/tss" - "github.com/stretchr/testify/assert" ) func TestRouter(t *testing.T) { @@ -26,7 +27,7 @@ func TestRouter(t *testing.T) { }) t.Run("status_new_routes_to_rpc_caller_channel", func(t *testing.T) { payload := tss.Payload{} - payload.RpcSubmitTxResponse.Status = tss.RPCTXStatus{OtherStatus: tss.NewStatus} + payload.RPCSubmitTxResponse.Status = tss.RPCTXStatus{OtherStatus: tss.NewStatus} rpcCallerChannel. On("Send", payload). @@ -40,7 +41,7 @@ func TestRouter(t *testing.T) { }) t.Run("status_try_again_later_routes_to_error_jitter_channel", func(t *testing.T) { payload := tss.Payload{} - payload.RpcSubmitTxResponse.Status = tss.RPCTXStatus{RPCStatus: entities.TryAgainLaterStatus} + payload.RPCSubmitTxResponse.Status = tss.RPCTXStatus{RPCStatus: entities.TryAgainLaterStatus} errorJitterChannel. On("Send", payload). @@ -55,7 +56,7 @@ func TestRouter(t *testing.T) { t.Run("status_failure_routes_to_webhook_channel", func(t *testing.T) { payload := tss.Payload{} - payload.RpcSubmitTxResponse.Status = tss.RPCTXStatus{RPCStatus: entities.FailedStatus} + payload.RPCSubmitTxResponse.Status = tss.RPCTXStatus{RPCStatus: entities.FailedStatus} webhookChannel. On("Send", payload). @@ -70,7 +71,7 @@ func TestRouter(t *testing.T) { t.Run("status_success_routes_to_webhook_channel", func(t *testing.T) { payload := tss.Payload{} - payload.RpcSubmitTxResponse.Status = tss.RPCTXStatus{RPCStatus: entities.SuccessStatus} + payload.RPCSubmitTxResponse.Status = tss.RPCTXStatus{RPCStatus: entities.SuccessStatus} webhookChannel. On("Send", payload). @@ -86,7 +87,7 @@ func TestRouter(t *testing.T) { t.Run("status_error_routes_to_error_jitter_channel", func(t *testing.T) { for _, code := range tss.JitterErrorCodes { payload := tss.Payload{ - RpcSubmitTxResponse: tss.RPCSendTxResponse{ + RPCSubmitTxResponse: tss.RPCSendTxResponse{ Status: tss.RPCTXStatus{ RPCStatus: entities.ErrorStatus, }, @@ -95,7 +96,7 @@ func TestRouter(t *testing.T) { }, }, } - payload.RpcSubmitTxResponse.Code.TxResultCode = code + payload.RPCSubmitTxResponse.Code.TxResultCode = code errorJitterChannel. On("Send", payload). Return(). @@ -110,7 +111,7 @@ func TestRouter(t *testing.T) { t.Run("status_error_routes_to_error_non_jitter_channel", func(t *testing.T) { for _, code := range tss.NonJitterErrorCodes { payload := tss.Payload{ - RpcSubmitTxResponse: tss.RPCSendTxResponse{ + RPCSubmitTxResponse: tss.RPCSendTxResponse{ Status: tss.RPCTXStatus{ RPCStatus: entities.ErrorStatus, }, @@ -119,7 +120,7 @@ func TestRouter(t *testing.T) { }, }, } - payload.RpcSubmitTxResponse.Code.TxResultCode = code + payload.RPCSubmitTxResponse.Code.TxResultCode = code errorNonJitterChannel. On("Send", payload). Return(). @@ -134,7 +135,7 @@ func TestRouter(t *testing.T) { t.Run("status_error_routes_to_webhook_channel", func(t *testing.T) { for _, code := range tss.FinalCodes { payload := tss.Payload{ - RpcSubmitTxResponse: tss.RPCSendTxResponse{ + RPCSubmitTxResponse: tss.RPCSendTxResponse{ Status: tss.RPCTXStatus{ RPCStatus: entities.ErrorStatus, }, @@ -156,7 +157,7 @@ func TestRouter(t *testing.T) { }) t.Run("get_ingest_resp_always_routes_to_webhook_channel", func(t *testing.T) { payload := tss.Payload{ - RpcGetIngestTxResponse: tss.RPCGetIngestTxResponse{ + RPCGetIngestTxResponse: tss.RPCGetIngestTxResponse{ Status: entities.SuccessStatus, Code: tss.RPCTXCode{ TxResultCode: tss.FinalCodes[0], diff --git a/internal/tss/services/mocks.go b/internal/tss/services/mocks.go index 9bb27291b..8240f0b4a 100644 --- a/internal/tss/services/mocks.go +++ b/internal/tss/services/mocks.go @@ -4,6 +4,7 @@ import ( "context" "github.com/stellar/go/txnbuild" + "github.com/stellar/wallet-backend/internal/tss" "github.com/stretchr/testify/mock" diff --git a/internal/tss/services/pool_populator.go b/internal/tss/services/pool_populator.go index efba21802..741c1a747 100644 --- a/internal/tss/services/pool_populator.go +++ b/internal/tss/services/pool_populator.go @@ -7,6 +7,7 @@ import ( "github.com/stellar/go/support/log" "github.com/stellar/go/xdr" + "github.com/stellar/wallet-backend/internal/entities" "github.com/stellar/wallet-backend/internal/services" "github.com/stellar/wallet-backend/internal/tss" @@ -84,7 +85,7 @@ func (p *poolPopulator) routeNewTransactions(ctx context.Context) error { return fmt.Errorf("getting latest try for transaction: %w", err) } if try == (store.Try{}) || try.Code == int32(tss.RPCFailCode) || try.Code == int32(tss.NewCode) { - payload.RpcSubmitTxResponse.Status = tss.RPCTXStatus{OtherStatus: tss.NewStatus} + payload.RPCSubmitTxResponse.Status = tss.RPCTXStatus{OtherStatus: tss.NewStatus} } err = p.Router.Route(payload) if err != nil { @@ -111,7 +112,7 @@ func (p *poolPopulator) routeErrorTransactions(ctx context.Context) error { } if slices.Contains(tss.FinalCodes, xdr.TransactionResultCode(try.Code)) { // route to webhook channel - payload.RpcSubmitTxResponse = tss.RPCSendTxResponse{ + payload.RPCSubmitTxResponse = tss.RPCSendTxResponse{ TransactionHash: try.Hash, TransactionXDR: try.XDR, Status: tss.RPCTXStatus{RPCStatus: entities.ErrorStatus}, @@ -120,12 +121,11 @@ func (p *poolPopulator) routeErrorTransactions(ctx context.Context) error { } } else if try.Code == int32(tss.RPCFailCode) || try.Code == int32(tss.NewCode) { // route to the error jitter channel - payload.RpcSubmitTxResponse = tss.RPCSendTxResponse{ + payload.RPCSubmitTxResponse = tss.RPCSendTxResponse{ TransactionHash: try.Hash, TransactionXDR: try.XDR, Status: tss.RPCTXStatus{RPCStatus: entities.TryAgainLaterStatus}, } - } err = p.Router.Route(payload) if err != nil { @@ -150,7 +150,7 @@ func (p *poolPopulator) routeFinalTransactions(ctx context.Context, status tss.R if err != nil { return fmt.Errorf("gretting latest try for transaction: %w", err) } - payload.RpcGetIngestTxResponse = tss.RPCGetIngestTxResponse{ + payload.RPCGetIngestTxResponse = tss.RPCGetIngestTxResponse{ Status: status.RPCStatus, Code: tss.RPCTXCode{TxResultCode: xdr.TransactionResultCode(try.Code)}, EnvelopeXDR: try.XDR, @@ -179,7 +179,7 @@ func (p *poolPopulator) routeNotSentTransactions(ctx context.Context) error { if err != nil { return fmt.Errorf("gretting latest try for transaction: %w", err) } - payload.RpcSubmitTxResponse = tss.RPCSendTxResponse{ + payload.RPCSubmitTxResponse = tss.RPCSendTxResponse{ TransactionHash: try.Hash, TransactionXDR: try.XDR, Status: tss.RPCTXStatus{RPCStatus: entities.RPCStatus(try.Status)}, diff --git a/internal/tss/services/pool_populator_test.go b/internal/tss/services/pool_populator_test.go index 8c07d9c60..dba92d8df 100644 --- a/internal/tss/services/pool_populator_test.go +++ b/internal/tss/services/pool_populator_test.go @@ -5,6 +5,10 @@ import ( "testing" "github.com/stellar/go/xdr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stellar/wallet-backend/internal/entities" @@ -13,9 +17,6 @@ import ( "github.com/stellar/wallet-backend/internal/tss" "github.com/stellar/wallet-backend/internal/tss/router" "github.com/stellar/wallet-backend/internal/tss/store" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" ) func TestRouteNewTransactions(t *testing.T) { @@ -27,10 +28,13 @@ func TestRouteNewTransactions(t *testing.T) { defer dbConnectionPool.Close() mockMetricsService := metrics.NewMockMetricsService() - store, _ := store.NewStore(dbConnectionPool, mockMetricsService) + store, err := store.NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) mockRouter := router.MockRouter{} mockRPCSerive := services.RPCServiceMock{} - populator, _ := NewPoolPopulator(&mockRouter, store, &mockRPCSerive) + populator, err := NewPoolPopulator(&mockRouter, store, &mockRPCSerive) + require.NoError(t, err) + t.Run("tx_has_no_try", func(t *testing.T) { mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "tss_transaction_submission_tries", mock.AnythingOfType("float64")).Once() mockMetricsService.On("IncDBQuery", "SELECT", "tss_transaction_submission_tries").Once() @@ -40,20 +44,21 @@ func TestRouteNewTransactions(t *testing.T) { mockMetricsService.On("IncDBQuery", "SELECT", "tss_transactions").Once() defer mockMetricsService.AssertExpectations(t) - _ = store.UpsertTransaction(context.Background(), "localhost:8000/webhook", "hash", "xdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + err = store.UpsertTransaction(context.Background(), "localhost:8000/webhook", "hash", "xdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + require.NoError(t, err) expectedPayload := tss.Payload{ TransactionHash: "hash", TransactionXDR: "xdr", WebhookURL: "localhost:8000/webhook", - RpcSubmitTxResponse: tss.RPCSendTxResponse{Status: tss.RPCTXStatus{OtherStatus: tss.NewStatus}}, + RPCSubmitTxResponse: tss.RPCSendTxResponse{Status: tss.RPCTXStatus{OtherStatus: tss.NewStatus}}, } mockRouter. On("Route", expectedPayload). Return(nil). Once() - err := populator.routeNewTransactions(context.Background()) + err = populator.routeNewTransactions(context.Background()) assert.Empty(t, err) }) @@ -68,14 +73,16 @@ func TestRouteNewTransactions(t *testing.T) { mockMetricsService.On("IncDBQuery", "SELECT", "tss_transactions").Once() defer mockMetricsService.AssertExpectations(t) - _ = store.UpsertTransaction(context.Background(), "localhost:8000/webhook", "hash", "xdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}) - _ = store.UpsertTry(context.Background(), "hash", "feebumphash", "feebumpxdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}, tss.RPCTXCode{OtherCodes: tss.NewCode}, "ABCD") + err = store.UpsertTransaction(context.Background(), "localhost:8000/webhook", "hash", "xdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + require.NoError(t, err) + err = store.UpsertTry(context.Background(), "hash", "feebumphash", "feebumpxdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}, tss.RPCTXCode{OtherCodes: tss.NewCode}, "ABCD") + require.NoError(t, err) expectedPayload := tss.Payload{ TransactionHash: "hash", TransactionXDR: "xdr", WebhookURL: "localhost:8000/webhook", - RpcSubmitTxResponse: tss.RPCSendTxResponse{Status: tss.RPCTXStatus{OtherStatus: tss.NewStatus}}, + RPCSubmitTxResponse: tss.RPCSendTxResponse{Status: tss.RPCTXStatus{OtherStatus: tss.NewStatus}}, } mockRouter. @@ -97,10 +104,11 @@ func TestRouteErrorTransactions(t *testing.T) { defer dbConnectionPool.Close() mockMetricsService := metrics.NewMockMetricsService() - store, _ := store.NewStore(dbConnectionPool, mockMetricsService) + store, err := store.NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) mockRouter := router.MockRouter{} mockRPCSerive := services.RPCServiceMock{} - populator, _ := NewPoolPopulator(&mockRouter, store, &mockRPCSerive) + populator, err := NewPoolPopulator(&mockRouter, store, &mockRPCSerive) t.Run("tx_has_final_error_code", func(t *testing.T) { mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "tss_transactions", mock.AnythingOfType("float64")).Once() @@ -113,14 +121,16 @@ func TestRouteErrorTransactions(t *testing.T) { mockMetricsService.On("IncDBQuery", "SELECT", "tss_transaction_submission_tries").Once() defer mockMetricsService.AssertExpectations(t) - _ = store.UpsertTransaction(context.Background(), "localhost:8000/webhook", "hash", "xdr", tss.RPCTXStatus{RPCStatus: entities.ErrorStatus}) - _ = store.UpsertTry(context.Background(), "hash", "feebumphash", "feebumpxdr", tss.RPCTXStatus{RPCStatus: entities.ErrorStatus}, tss.RPCTXCode{TxResultCode: xdr.TransactionResultCodeTxInsufficientBalance}, "ABCD") + err = store.UpsertTransaction(context.Background(), "localhost:8000/webhook", "hash", "xdr", tss.RPCTXStatus{RPCStatus: entities.ErrorStatus}) + require.NoError(t, err) + err = store.UpsertTry(context.Background(), "hash", "feebumphash", "feebumpxdr", tss.RPCTXStatus{RPCStatus: entities.ErrorStatus}, tss.RPCTXCode{TxResultCode: xdr.TransactionResultCodeTxInsufficientBalance}, "ABCD") + require.NoError(t, err) expectedPayload := tss.Payload{ TransactionHash: "hash", TransactionXDR: "xdr", WebhookURL: "localhost:8000/webhook", - RpcSubmitTxResponse: tss.RPCSendTxResponse{ + RPCSubmitTxResponse: tss.RPCSendTxResponse{ TransactionHash: "feebumphash", TransactionXDR: "feebumpxdr", Status: tss.RPCTXStatus{RPCStatus: entities.ErrorStatus}, @@ -134,7 +144,7 @@ func TestRouteErrorTransactions(t *testing.T) { Return(nil). Once() - err := populator.routeErrorTransactions(context.Background()) + err = populator.routeErrorTransactions(context.Background()) assert.Empty(t, err) }) @@ -149,14 +159,16 @@ func TestRouteErrorTransactions(t *testing.T) { mockMetricsService.On("IncDBQuery", "SELECT", "tss_transaction_submission_tries").Once() defer mockMetricsService.AssertExpectations(t) - _ = store.UpsertTransaction(context.Background(), "localhost:8000/webhook", "hash", "xdr", tss.RPCTXStatus{RPCStatus: entities.ErrorStatus}) - _ = store.UpsertTry(context.Background(), "hash", "feebumphash", "feebumpxdr", tss.RPCTXStatus{RPCStatus: entities.ErrorStatus}, tss.RPCTXCode{OtherCodes: tss.RPCFailCode}, "ABCD") + err = store.UpsertTransaction(context.Background(), "localhost:8000/webhook", "hash", "xdr", tss.RPCTXStatus{RPCStatus: entities.ErrorStatus}) + require.NoError(t, err) + err = store.UpsertTry(context.Background(), "hash", "feebumphash", "feebumpxdr", tss.RPCTXStatus{RPCStatus: entities.ErrorStatus}, tss.RPCTXCode{OtherCodes: tss.RPCFailCode}, "ABCD") + require.NoError(t, err) expectedPayload := tss.Payload{ TransactionHash: "hash", TransactionXDR: "xdr", WebhookURL: "localhost:8000/webhook", - RpcSubmitTxResponse: tss.RPCSendTxResponse{ + RPCSubmitTxResponse: tss.RPCSendTxResponse{ TransactionHash: "feebumphash", TransactionXDR: "feebumpxdr", Status: tss.RPCTXStatus{RPCStatus: entities.TryAgainLaterStatus}, @@ -182,11 +194,12 @@ func TestRouteFinalTransactions(t *testing.T) { defer dbConnectionPool.Close() mockMetricsService := metrics.NewMockMetricsService() - store, _ := store.NewStore(dbConnectionPool, mockMetricsService) + store, err := store.NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) mockRouter := router.MockRouter{} mockRPCSerive := services.RPCServiceMock{} - populator, _ := NewPoolPopulator(&mockRouter, store, &mockRPCSerive) - + populator, err := NewPoolPopulator(&mockRouter, store, &mockRPCSerive) + require.NoError(t, err) t.Run("route_successful_tx", func(t *testing.T) { mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "tss_transactions", mock.AnythingOfType("float64")).Once() mockMetricsService.On("IncDBQuery", "SELECT", "tss_transactions").Once() @@ -198,14 +211,16 @@ func TestRouteFinalTransactions(t *testing.T) { mockMetricsService.On("IncDBQuery", "SELECT", "tss_transaction_submission_tries").Once() defer mockMetricsService.AssertExpectations(t) - _ = store.UpsertTransaction(context.Background(), "localhost:8000/webhook", "hash", "xdr", tss.RPCTXStatus{RPCStatus: entities.SuccessStatus}) - _ = store.UpsertTry(context.Background(), "hash", "feebumphash", "feebumpxdr", tss.RPCTXStatus{RPCStatus: entities.SuccessStatus}, tss.RPCTXCode{TxResultCode: xdr.TransactionResultCodeTxSuccess}, "ABCD") + err = store.UpsertTransaction(context.Background(), "localhost:8000/webhook", "hash", "xdr", tss.RPCTXStatus{RPCStatus: entities.SuccessStatus}) + require.NoError(t, err) + err = store.UpsertTry(context.Background(), "hash", "feebumphash", "feebumpxdr", tss.RPCTXStatus{RPCStatus: entities.SuccessStatus}, tss.RPCTXCode{TxResultCode: xdr.TransactionResultCodeTxSuccess}, "ABCD") + require.NoError(t, err) expectedPayload := tss.Payload{ TransactionHash: "hash", TransactionXDR: "xdr", WebhookURL: "localhost:8000/webhook", - RpcGetIngestTxResponse: tss.RPCGetIngestTxResponse{ + RPCGetIngestTxResponse: tss.RPCGetIngestTxResponse{ Status: entities.SuccessStatus, Code: tss.RPCTXCode{TxResultCode: xdr.TransactionResultCodeTxSuccess}, EnvelopeXDR: "feebumpxdr", @@ -232,10 +247,12 @@ func TestNotSentTransactions(t *testing.T) { defer dbConnectionPool.Close() mockMetricsService := metrics.NewMockMetricsService() - store, _ := store.NewStore(dbConnectionPool, mockMetricsService) + store, err := store.NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) mockRouter := router.MockRouter{} mockRPCSerive := services.RPCServiceMock{} - populator, _ := NewPoolPopulator(&mockRouter, store, &mockRPCSerive) + populator, err := NewPoolPopulator(&mockRouter, store, &mockRPCSerive) + require.NoError(t, err) t.Run("routes_not_sent_txns", func(t *testing.T) { mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "tss_transactions", mock.AnythingOfType("float64")).Once() @@ -248,14 +265,16 @@ func TestNotSentTransactions(t *testing.T) { mockMetricsService.On("IncDBQuery", "SELECT", "tss_transaction_submission_tries").Once() defer mockMetricsService.AssertExpectations(t) - _ = store.UpsertTransaction(context.Background(), "localhost:8000/webhook", "hash", "xdr", tss.RPCTXStatus{OtherStatus: tss.NotSentStatus}) - _ = store.UpsertTry(context.Background(), "hash", "feebumphash", "feebumpxdr", tss.RPCTXStatus{RPCStatus: entities.SuccessStatus}, tss.RPCTXCode{TxResultCode: xdr.TransactionResultCodeTxSuccess}, "ABCD") + err = store.UpsertTransaction(context.Background(), "localhost:8000/webhook", "hash", "xdr", tss.RPCTXStatus{OtherStatus: tss.NotSentStatus}) + require.NoError(t, err) + err = store.UpsertTry(context.Background(), "hash", "feebumphash", "feebumpxdr", tss.RPCTXStatus{RPCStatus: entities.SuccessStatus}, tss.RPCTXCode{TxResultCode: xdr.TransactionResultCodeTxSuccess}, "ABCD") + require.NoError(t, err) expectedPayload := tss.Payload{ TransactionHash: "hash", TransactionXDR: "xdr", WebhookURL: "localhost:8000/webhook", - RpcSubmitTxResponse: tss.RPCSendTxResponse{ + RPCSubmitTxResponse: tss.RPCSendTxResponse{ TransactionHash: "feebumphash", TransactionXDR: "feebumpxdr", Status: tss.RPCTXStatus{RPCStatus: entities.SuccessStatus}, diff --git a/internal/tss/services/transaction_manager.go b/internal/tss/services/transaction_manager.go index 6e180de6d..523dfbccf 100644 --- a/internal/tss/services/transaction_manager.go +++ b/internal/tss/services/transaction_manager.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/stellar/go/txnbuild" + "github.com/stellar/wallet-backend/internal/services" "github.com/stellar/wallet-backend/internal/tss" "github.com/stellar/wallet-backend/internal/tss/errors" @@ -39,16 +40,17 @@ func NewTransactionManager(cfg TransactionManagerConfigs) *transactionManager { func (t *transactionManager) BuildAndSubmitTransaction(ctx context.Context, channelName string, payload tss.Payload) (tss.RPCSendTxResponse, error) { genericTx, err := txnbuild.TransactionFromXDR(payload.TransactionXDR) if err != nil { - return tss.RPCSendTxResponse{}, errors.OriginalXDRMalformed + return tss.RPCSendTxResponse{}, errors.ErrOriginalXDRMalformed } tx, txEmpty := genericTx.Transaction() if !txEmpty { - return tss.RPCSendTxResponse{}, errors.OriginalXDRMalformed + return tss.RPCSendTxResponse{}, errors.ErrOriginalXDRMalformed } var tryTxHash string var tryTxXDR string if payload.FeeBump { - feeBumpTx, err := t.TxService.BuildFeeBumpTransaction(ctx, tx) + var feeBumpTx *txnbuild.FeeBumpTransaction + feeBumpTx, err = t.TxService.BuildFeeBumpTransaction(ctx, tx) if err != nil { return tss.RPCSendTxResponse{}, fmt.Errorf("%s: Unable to build fee bump transaction: %w", channelName, err) } diff --git a/internal/tss/services/transaction_manager_test.go b/internal/tss/services/transaction_manager_test.go index acc0bb65f..0096d6fb9 100644 --- a/internal/tss/services/transaction_manager_test.go +++ b/internal/tss/services/transaction_manager_test.go @@ -6,6 +6,10 @@ import ( "testing" "github.com/stellar/go/xdr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stellar/wallet-backend/internal/entities" @@ -14,9 +18,6 @@ import ( "github.com/stellar/wallet-backend/internal/tss" "github.com/stellar/wallet-backend/internal/tss/store" "github.com/stellar/wallet-backend/internal/tss/utils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" ) func TestBuildAndSubmitTransaction(t *testing.T) { @@ -28,21 +29,26 @@ func TestBuildAndSubmitTransaction(t *testing.T) { defer dbConnectionPool.Close() mockMetricsService := metrics.NewMockMetricsService() - store, _ := store.NewStore(dbConnectionPool, mockMetricsService) + dbStore, err := store.NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) txServiceMock := TransactionServiceMock{} rpcServiceMock := services.RPCServiceMock{} txManager := NewTransactionManager(TransactionManagerConfigs{ TxService: &txServiceMock, RPCService: &rpcServiceMock, - Store: store, + Store: dbStore, }) networkPass := "passphrase" - tx := utils.BuildTestTransaction() - txHash, _ := tx.HashHex(networkPass) - txXDR, _ := tx.Base64() - feeBumpTx := utils.BuildTestFeeBumpTransaction() - feeBumpTxXDR, _ := feeBumpTx.Base64() - feeBumpTxHash, _ := feeBumpTx.HashHex(networkPass) + tx := utils.BuildTestTransaction(t) + txHash, err := tx.HashHex(networkPass) + require.NoError(t, err) + txXDR, err := tx.Base64() + require.NoError(t, err) + feeBumpTx := utils.BuildTestFeeBumpTransaction(t) + feeBumpTxXDR, err := feeBumpTx.Base64() + require.NoError(t, err) + feeBumpTxHash, err := feeBumpTx.HashHex(networkPass) + require.NoError(t, err) payload := tss.Payload{} payload.WebhookURL = "www.stellar.com" payload.TransactionHash = txHash @@ -55,19 +61,23 @@ func TestBuildAndSubmitTransaction(t *testing.T) { mockMetricsService.On("IncDBQuery", "SELECT", "tss_transactions").Once() defer mockMetricsService.AssertExpectations(t) - _ = store.UpsertTransaction(context.Background(), payload.WebhookURL, payload.TransactionHash, payload.TransactionXDR, tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + err = dbStore.UpsertTransaction(context.Background(), payload.WebhookURL, payload.TransactionHash, payload.TransactionXDR, tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + require.NoError(t, err) txServiceMock. On("BuildFeeBumpTransaction", context.Background(), tx). Return(nil, errors.New("signing failed")). Once() payload.FeeBump = true - txSendResp, err := txManager.BuildAndSubmitTransaction(context.Background(), "channel", payload) - + var txSendResp tss.RPCSendTxResponse + txSendResp, err = txManager.BuildAndSubmitTransaction(context.Background(), "channel", payload) + require.Error(t, err) assert.Equal(t, tss.RPCSendTxResponse{}, txSendResp) assert.Equal(t, "channel: Unable to build fee bump transaction: signing failed", err.Error()) - tx, _ := store.GetTransaction(context.Background(), payload.TransactionHash) + var tx store.Transaction + tx, err = dbStore.GetTransaction(context.Background(), payload.TransactionHash) + require.NoError(t, err) assert.Equal(t, string(tss.NewStatus), tx.Status) }) @@ -82,7 +92,8 @@ func TestBuildAndSubmitTransaction(t *testing.T) { mockMetricsService.On("IncDBQuery", "SELECT", "tss_transaction_submission_tries").Once() defer mockMetricsService.AssertExpectations(t) - _ = store.UpsertTransaction(context.Background(), payload.WebhookURL, payload.TransactionHash, payload.TransactionXDR, tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + err = dbStore.UpsertTransaction(context.Background(), payload.WebhookURL, payload.TransactionHash, payload.TransactionXDR, tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + require.NoError(t, err) sendResp := entities.RPCSendTransactionResult{Status: entities.ErrorStatus} txServiceMock. @@ -98,16 +109,21 @@ func TestBuildAndSubmitTransaction(t *testing.T) { Once() payload.FeeBump = true - txSendResp, err := txManager.BuildAndSubmitTransaction(context.Background(), "channel", payload) + var txSendResp tss.RPCSendTxResponse + txSendResp, err = txManager.BuildAndSubmitTransaction(context.Background(), "channel", payload) assert.Equal(t, entities.ErrorStatus, txSendResp.Status.RPCStatus) assert.Equal(t, tss.RPCFailCode, txSendResp.Code.OtherCodes) assert.Equal(t, "channel: RPC fail: RPC fail: RPC down", err.Error()) - tx, _ := store.GetTransaction(context.Background(), payload.TransactionHash) + var tx store.Transaction + tx, err = dbStore.GetTransaction(context.Background(), payload.TransactionHash) + require.NoError(t, err) assert.Equal(t, string(tss.NewStatus), tx.Status) - try, _ := store.GetTry(context.Background(), feeBumpTxHash) + var try store.Try + try, err = dbStore.GetTry(context.Background(), feeBumpTxHash) + require.NoError(t, err) assert.Equal(t, string(entities.ErrorStatus), try.Status) assert.Equal(t, int32(tss.RPCFailCode), try.Code) }) @@ -123,7 +139,8 @@ func TestBuildAndSubmitTransaction(t *testing.T) { mockMetricsService.On("IncDBQuery", "SELECT", "tss_transaction_submission_tries").Once() defer mockMetricsService.AssertExpectations(t) - _ = store.UpsertTransaction(context.Background(), payload.WebhookURL, payload.TransactionHash, payload.TransactionXDR, tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + err = dbStore.UpsertTransaction(context.Background(), payload.WebhookURL, payload.TransactionHash, payload.TransactionXDR, tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + require.NoError(t, err) sendResp := entities.RPCSendTransactionResult{ Status: entities.PendingStatus, ErrorResultXDR: "", @@ -142,16 +159,21 @@ func TestBuildAndSubmitTransaction(t *testing.T) { Once() payload.FeeBump = true - txSendResp, err := txManager.BuildAndSubmitTransaction(context.Background(), "channel", payload) + var txSendResp tss.RPCSendTxResponse + txSendResp, err = txManager.BuildAndSubmitTransaction(context.Background(), "channel", payload) assert.Equal(t, entities.PendingStatus, txSendResp.Status.RPCStatus) assert.Equal(t, tss.EmptyCode, txSendResp.Code.OtherCodes) assert.Empty(t, err) - tx, _ := store.GetTransaction(context.Background(), payload.TransactionHash) + var tx store.Transaction + tx, err = dbStore.GetTransaction(context.Background(), payload.TransactionHash) + require.NoError(t, err) assert.Equal(t, string(entities.PendingStatus), tx.Status) - try, _ := store.GetTry(context.Background(), feeBumpTxHash) + var try store.Try + try, err = dbStore.GetTry(context.Background(), feeBumpTxHash) + require.NoError(t, err) assert.Equal(t, string(entities.PendingStatus), try.Status) assert.Equal(t, int32(tss.EmptyCode), try.Code) }) @@ -167,7 +189,8 @@ func TestBuildAndSubmitTransaction(t *testing.T) { mockMetricsService.On("IncDBQuery", "SELECT", "tss_transaction_submission_tries").Once() defer mockMetricsService.AssertExpectations(t) - _ = store.UpsertTransaction(context.Background(), payload.WebhookURL, payload.TransactionHash, payload.TransactionXDR, tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + err = dbStore.UpsertTransaction(context.Background(), payload.WebhookURL, payload.TransactionHash, payload.TransactionXDR, tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + require.NoError(t, err) sendResp := entities.RPCSendTransactionResult{ Status: entities.ErrorStatus, ErrorResultXDR: "ABCD", @@ -186,16 +209,22 @@ func TestBuildAndSubmitTransaction(t *testing.T) { Once() payload.FeeBump = true - txSendResp, err := txManager.BuildAndSubmitTransaction(context.Background(), "channel", payload) + var txSendResp tss.RPCSendTxResponse + txSendResp, err = txManager.BuildAndSubmitTransaction(context.Background(), "channel", payload) + require.Error(t, err) assert.Equal(t, entities.ErrorStatus, txSendResp.Status.RPCStatus) assert.Equal(t, tss.UnmarshalBinaryCode, txSendResp.Code.OtherCodes) assert.Equal(t, "channel: RPC fail: parse error result xdr string: unable to parse: unable to unmarshal errorResultXDR: ABCD", err.Error()) - tx, _ := store.GetTransaction(context.Background(), payload.TransactionHash) + var tx store.Transaction + tx, err = dbStore.GetTransaction(context.Background(), payload.TransactionHash) + require.NoError(t, err) assert.Equal(t, string(tss.NewStatus), tx.Status) - try, _ := store.GetTry(context.Background(), feeBumpTxHash) + var try store.Try + try, err = dbStore.GetTry(context.Background(), feeBumpTxHash) + require.NoError(t, err) assert.Equal(t, string(entities.ErrorStatus), try.Status) assert.Equal(t, int32(tss.UnmarshalBinaryCode), try.Code) }) @@ -211,7 +240,8 @@ func TestBuildAndSubmitTransaction(t *testing.T) { mockMetricsService.On("IncDBQuery", "SELECT", "tss_transaction_submission_tries").Once() defer mockMetricsService.AssertExpectations(t) - _ = store.UpsertTransaction(context.Background(), payload.WebhookURL, payload.TransactionHash, payload.TransactionXDR, tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + err = dbStore.UpsertTransaction(context.Background(), payload.WebhookURL, payload.TransactionHash, payload.TransactionXDR, tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + require.NoError(t, err) sendResp := entities.RPCSendTransactionResult{ Status: entities.ErrorStatus, ErrorResultXDR: "AAAAAAAAAMj////9AAAAAA==", @@ -230,16 +260,21 @@ func TestBuildAndSubmitTransaction(t *testing.T) { Once() payload.FeeBump = true - txSendResp, err := txManager.BuildAndSubmitTransaction(context.Background(), "channel", payload) + var txSendResp tss.RPCSendTxResponse + txSendResp, err = txManager.BuildAndSubmitTransaction(context.Background(), "channel", payload) assert.Equal(t, entities.ErrorStatus, txSendResp.Status.RPCStatus) assert.Equal(t, xdr.TransactionResultCodeTxTooLate, txSendResp.Code.TxResultCode) assert.Empty(t, err) - tx, _ := store.GetTransaction(context.Background(), payload.TransactionHash) + var tx store.Transaction + tx, err = dbStore.GetTransaction(context.Background(), payload.TransactionHash) + require.NoError(t, err) assert.Equal(t, string(entities.ErrorStatus), tx.Status) - try, _ := store.GetTry(context.Background(), feeBumpTxHash) + var try store.Try + try, err = dbStore.GetTry(context.Background(), feeBumpTxHash) + require.NoError(t, err) assert.Equal(t, string(entities.ErrorStatus), try.Status) assert.Equal(t, int32(xdr.TransactionResultCodeTxTooLate), try.Code) }) @@ -255,7 +290,8 @@ func TestBuildAndSubmitTransaction(t *testing.T) { mockMetricsService.On("IncDBQuery", "SELECT", "tss_transaction_submission_tries").Once() defer mockMetricsService.AssertExpectations(t) - _ = store.UpsertTransaction(context.Background(), payload.WebhookURL, payload.TransactionHash, payload.TransactionXDR, tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + err = dbStore.UpsertTransaction(context.Background(), payload.WebhookURL, payload.TransactionHash, payload.TransactionXDR, tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + require.NoError(t, err) sendResp := entities.RPCSendTransactionResult{ Status: entities.ErrorStatus, ErrorResultXDR: "AAAAAAAAAMj////9AAAAAA==", @@ -280,10 +316,12 @@ func TestBuildAndSubmitTransaction(t *testing.T) { assert.Equal(t, xdr.TransactionResultCodeTxTooLate, txSendResp.Code.TxResultCode) assert.Empty(t, err) - tx, _ := store.GetTransaction(context.Background(), payload.TransactionHash) + tx, err := dbStore.GetTransaction(context.Background(), payload.TransactionHash) + require.NoError(t, err) assert.Equal(t, string(entities.ErrorStatus), tx.Status) - try, _ := store.GetTry(context.Background(), txHash) + try, err := dbStore.GetTry(context.Background(), txHash) + require.NoError(t, err) assert.Equal(t, string(entities.ErrorStatus), try.Status) assert.Equal(t, int32(xdr.TransactionResultCodeTxTooLate), try.Code) }) diff --git a/internal/tss/services/transaction_service.go b/internal/tss/services/transaction_service.go index ebccabfbf..dac630f7e 100644 --- a/internal/tss/services/transaction_service.go +++ b/internal/tss/services/transaction_service.go @@ -67,7 +67,7 @@ func (o *TransactionServiceOptions) ValidateOptions() error { func NewTransactionService(opts TransactionServiceOptions) (*transactionService, error) { if err := opts.ValidateOptions(); err != nil { - return nil, err + return nil, fmt.Errorf("validating transaction service options: %w", err) } return &transactionService{ DB: opts.DB, @@ -84,48 +84,41 @@ func (t *transactionService) NetworkPassphrase() string { } func (t *transactionService) BuildAndSignTransactionWithChannelAccount(ctx context.Context, operations []txnbuild.Operation, timeoutInSecs int64) (*txnbuild.Transaction, error) { - var tx *txnbuild.Transaction - var channelAccountPublicKey string - err := db.RunInTransaction(ctx, t.DB, nil, func(dbTx db.Transaction) error { - var err error - channelAccountPublicKey, err = t.ChannelAccountSignatureClient.GetAccountPublicKey(ctx, int(timeoutInSecs)) - if err != nil { - return fmt.Errorf("getting channel account public key: %w", err) - } - channelAccountSeq, err := t.RPCService.GetAccountLedgerSequence(channelAccountPublicKey) - if err != nil { - return fmt.Errorf("getting ledger sequence for channel account public key: %s: %w", channelAccountPublicKey, err) - } - tx, err = txnbuild.NewTransaction( - txnbuild.TransactionParams{ - SourceAccount: &txnbuild.SimpleAccount{ - AccountID: channelAccountPublicKey, - Sequence: channelAccountSeq, - }, - Operations: operations, - BaseFee: int64(t.BaseFee), - Preconditions: txnbuild.Preconditions{ - TimeBounds: txnbuild.NewTimeout(timeoutInSecs), - }, - IncrementSequenceNum: true, + channelAccountPublicKey, err := t.ChannelAccountSignatureClient.GetAccountPublicKey(ctx, int(timeoutInSecs)) + if err != nil { + return nil, fmt.Errorf("getting channel account public key: %w", err) + } + channelAccountSeq, err := t.RPCService.GetAccountLedgerSequence(channelAccountPublicKey) + if err != nil { + return nil, fmt.Errorf("getting ledger sequence for channel account public key %q: %w", channelAccountPublicKey, err) + } + tx, err := txnbuild.NewTransaction( + txnbuild.TransactionParams{ + SourceAccount: &txnbuild.SimpleAccount{ + AccountID: channelAccountPublicKey, + Sequence: channelAccountSeq, }, - ) - if err != nil { - return fmt.Errorf("building transaction: %w", err) - } - txHash, err := tx.HashHex(t.ChannelAccountSignatureClient.NetworkPassphrase()) - if err != nil { - return fmt.Errorf("unable to hashhex transaction: %w", err) - } - err = t.ChannelAccountStore.AssignTxToChannelAccount(ctx, channelAccountPublicKey, txHash) - if err != nil { - return fmt.Errorf("assigning channel account to tx: %w", err) - } - return nil - }) + Operations: operations, + BaseFee: int64(t.BaseFee), + Preconditions: txnbuild.Preconditions{ + TimeBounds: txnbuild.NewTimeout(timeoutInSecs), + }, + IncrementSequenceNum: true, + }, + ) + if err != nil { + return nil, fmt.Errorf("building transaction: %w", err) + } + txHash, err := tx.HashHex(t.ChannelAccountSignatureClient.NetworkPassphrase()) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to hashhex transaction: %w", err) } + + err = t.ChannelAccountStore.AssignTxToChannelAccount(ctx, channelAccountPublicKey, txHash) + if err != nil { + return nil, fmt.Errorf("assigning channel account to tx: %w", err) + } + tx, err = t.ChannelAccountSignatureClient.SignStellarTransaction(ctx, tx, channelAccountPublicKey) if err != nil { return nil, fmt.Errorf("signing transaction with channel account: %w", err) diff --git a/internal/tss/services/transaction_service_test.go b/internal/tss/services/transaction_service_test.go index 18313df2a..012d780bf 100644 --- a/internal/tss/services/transaction_service_test.go +++ b/internal/tss/services/transaction_service_test.go @@ -36,7 +36,6 @@ func TestValidateOptions(t *testing.T) { } err := opts.ValidateOptions() assert.Equal(t, "DB cannot be nil", err.Error()) - }) t.Run("return_error_when_distribution_signature_client_nil", func(t *testing.T) { opts := TransactionServiceOptions{ @@ -49,7 +48,6 @@ func TestValidateOptions(t *testing.T) { } err := opts.ValidateOptions() assert.Equal(t, "distribution account signature client cannot be nil", err.Error()) - }) t.Run("return_error_when_channel_signature_client_nil", func(t *testing.T) { @@ -115,7 +113,7 @@ func TestBuildAndSignTransactionWithChannelAccount(t *testing.T) { channelAccountSignatureClient := signing.SignatureClientMock{} channelAccountStore := store.ChannelAccountStoreMock{} mockRPCService := &services.RPCServiceMock{} - txService, _ := NewTransactionService(TransactionServiceOptions{ + txService, err := NewTransactionService(TransactionServiceOptions{ DB: dbConnectionPool, DistributionAccountSignatureClient: &distributionAccountSignatureClient, ChannelAccountSignatureClient: &channelAccountSignatureClient, @@ -123,7 +121,8 @@ func TestBuildAndSignTransactionWithChannelAccount(t *testing.T) { RPCService: mockRPCService, BaseFee: 114, }) - atomicTxErrorPrefix := "running atomic function in RunInTransactionWithResult: " + require.NoError(t, err) + t.Run("channel_account_signature_client_get_account_public_key_err", func(t *testing.T) { channelAccountSignatureClient. On("GetAccountPublicKey", context.Background()). @@ -134,7 +133,7 @@ func TestBuildAndSignTransactionWithChannelAccount(t *testing.T) { channelAccountSignatureClient.AssertExpectations(t) assert.Empty(t, tx) - assert.Equal(t, atomicTxErrorPrefix+"getting channel account public key: channel accounts unavailable", err.Error()) + assert.Equal(t, "getting channel account public key: channel accounts unavailable", err.Error()) }) t.Run("rpc_client_get_account_seq_err", func(t *testing.T) { @@ -154,8 +153,8 @@ func TestBuildAndSignTransactionWithChannelAccount(t *testing.T) { channelAccountSignatureClient.AssertExpectations(t) assert.Empty(t, tx) - expectedErr := fmt.Errorf("getting ledger sequence for channel account public key: %s: rpc service down", channelAccount.Address()) - assert.Equal(t, atomicTxErrorPrefix+expectedErr.Error(), err.Error()) + expectedErr := fmt.Errorf("getting ledger sequence for channel account public key %q: rpc service down", channelAccount.Address()) + assert.Equal(t, expectedErr.Error(), err.Error()) }) t.Run("build_tx_fails", func(t *testing.T) { @@ -175,8 +174,7 @@ func TestBuildAndSignTransactionWithChannelAccount(t *testing.T) { channelAccountSignatureClient.AssertExpectations(t) assert.Empty(t, tx) - assert.Equal(t, atomicTxErrorPrefix+"building transaction: transaction has no operations", err.Error()) - + assert.Equal(t, "building transaction: transaction has no operations", err.Error()) }) t.Run("lock_channel_account_to_tx_err", func(t *testing.T) { @@ -211,7 +209,7 @@ func TestBuildAndSignTransactionWithChannelAccount(t *testing.T) { channelAccountSignatureClient.AssertExpectations(t) channelAccountStore.AssertExpectations(t) assert.Empty(t, tx) - assert.Equal(t, atomicTxErrorPrefix+"assigning channel account to tx: unable to assign channel account to tx", err.Error()) + assert.Equal(t, "assigning channel account to tx: unable to assign channel account to tx", err.Error()) }) t.Run("sign_stellar_transaction_w_channel_account_err", func(t *testing.T) { @@ -253,7 +251,7 @@ func TestBuildAndSignTransactionWithChannelAccount(t *testing.T) { }) t.Run("returns_signed_tx", func(t *testing.T) { - signedTx := utils.BuildTestTransaction() + signedTx := utils.BuildTestTransaction(t) channelAccount := keypair.MustRandom() channelAccountSignatureClient. On("GetAccountPublicKey", context.Background()). @@ -301,7 +299,7 @@ func TestBuildFeeBumpTransaction(t *testing.T) { channelAccountSignatureClient := signing.SignatureClientMock{} channelAccountStore := store.ChannelAccountStoreMock{} mockRPCService := &services.RPCServiceMock{} - txService, _ := NewTransactionService(TransactionServiceOptions{ + txService, err := NewTransactionService(TransactionServiceOptions{ DB: dbConnectionPool, DistributionAccountSignatureClient: &distributionAccountSignatureClient, ChannelAccountSignatureClient: &channelAccountSignatureClient, @@ -309,9 +307,9 @@ func TestBuildFeeBumpTransaction(t *testing.T) { RPCService: mockRPCService, BaseFee: 114, }) - + require.NoError(t, err) t.Run("distribution_account_signature_client_get_account_public_key_err", func(t *testing.T) { - tx := utils.BuildTestTransaction() + tx := utils.BuildTestTransaction(t) distributionAccountSignatureClient. On("GetAccountPublicKey", context.Background()). Return("", errors.New("channel accounts unavailable")). @@ -339,7 +337,7 @@ func TestBuildFeeBumpTransaction(t *testing.T) { }) t.Run("signing_feebump_tx_fails", func(t *testing.T) { - tx := utils.BuildTestTransaction() + tx := utils.BuildTestTransaction(t) distributionAccount := keypair.MustRandom() distributionAccountSignatureClient. On("GetAccountPublicKey", context.Background()). @@ -357,8 +355,8 @@ func TestBuildFeeBumpTransaction(t *testing.T) { }) t.Run("returns_singed_feebump_tx", func(t *testing.T) { - tx := utils.BuildTestTransaction() - feeBump := utils.BuildTestFeeBumpTransaction() + tx := utils.BuildTestTransaction(t) + feeBump := utils.BuildTestFeeBumpTransaction(t) distributionAccount := keypair.MustRandom() distributionAccountSignatureClient. On("GetAccountPublicKey", context.Background()). @@ -374,5 +372,4 @@ func TestBuildFeeBumpTransaction(t *testing.T) { assert.Equal(t, feeBump, feeBumpTx) assert.NoError(t, err) }) - } diff --git a/internal/tss/store/store_test.go b/internal/tss/store/store_test.go index bee319633..1c133bf93 100644 --- a/internal/tss/store/store_test.go +++ b/internal/tss/store/store_test.go @@ -5,14 +5,15 @@ import ( "testing" "github.com/stellar/go/xdr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stellar/wallet-backend/internal/entities" "github.com/stellar/wallet-backend/internal/metrics" "github.com/stellar/wallet-backend/internal/tss" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" ) func TestUpsertTransaction(t *testing.T) { @@ -21,8 +22,11 @@ func TestUpsertTransaction(t *testing.T) { dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) require.NoError(t, err) defer dbConnectionPool.Close() + mockMetricsService := metrics.NewMockMetricsService() - store, _ := NewStore(dbConnectionPool, mockMetricsService) + store, err := NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) + t.Run("insert", func(t *testing.T) { mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "tss_transactions", mock.AnythingOfType("float64")).Once() mockMetricsService.On("IncDBQuery", "SELECT", "tss_transactions").Once() @@ -30,9 +34,12 @@ func TestUpsertTransaction(t *testing.T) { mockMetricsService.On("IncDBQuery", "INSERT", "tss_transactions").Once() defer mockMetricsService.AssertExpectations(t) - _ = store.UpsertTransaction(context.Background(), "www.stellar.org", "hash", "xdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + err = store.UpsertTransaction(context.Background(), "www.stellar.org", "hash", "xdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + require.NoError(t, err) - tx, _ := store.GetTransaction(context.Background(), "hash") + var tx Transaction + tx, err = store.GetTransaction(context.Background(), "hash") + require.NoError(t, err) assert.Equal(t, "xdr", tx.XDR) assert.Equal(t, string(tss.NewStatus), tx.Status) }) @@ -44,10 +51,13 @@ func TestUpsertTransaction(t *testing.T) { mockMetricsService.On("IncDBQuery", "INSERT", "tss_transactions").Times(2) defer mockMetricsService.AssertExpectations(t) - _ = store.UpsertTransaction(context.Background(), "www.stellar.org", "hash", "xdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}) - _ = store.UpsertTransaction(context.Background(), "www.stellar.org", "hash", "xdr", tss.RPCTXStatus{RPCStatus: entities.SuccessStatus}) + err = store.UpsertTransaction(context.Background(), "www.stellar.org", "hash", "xdr", tss.RPCTXStatus{OtherStatus: tss.NewStatus}) + require.NoError(t, err) + err = store.UpsertTransaction(context.Background(), "www.stellar.org", "hash", "xdr", tss.RPCTXStatus{RPCStatus: entities.SuccessStatus}) + require.NoError(t, err) - tx, _ := store.GetTransaction(context.Background(), "hash") + tx, err := store.GetTransaction(context.Background(), "hash") + require.NoError(t, err) assert.Equal(t, "xdr", tx.XDR) assert.Equal(t, string(entities.SuccessStatus), tx.Status) @@ -64,8 +74,11 @@ func TestUpsertTry(t *testing.T) { dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) require.NoError(t, err) defer dbConnectionPool.Close() + mockMetricsService := metrics.NewMockMetricsService() - store, _ := NewStore(dbConnectionPool, mockMetricsService) + store, err := NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) + t.Run("insert", func(t *testing.T) { mockMetricsService.On("ObserveDBQueryDuration", "INSERT", "tss_transaction_submission_tries", mock.AnythingOfType("float64")).Once() mockMetricsService.On("IncDBQuery", "INSERT", "tss_transaction_submission_tries").Once() @@ -77,8 +90,11 @@ func TestUpsertTry(t *testing.T) { code := tss.RPCTXCode{OtherCodes: tss.NewCode} resultXDR := "ABCD//" err = store.UpsertTry(context.Background(), "hash", "feebumptxhash", "feebumptxxdr", status, code, resultXDR) + require.NoError(t, err) - try, _ := store.GetTry(context.Background(), "feebumptxhash") + var try Try + try, err = store.GetTry(context.Background(), "feebumptxhash") + require.NoError(t, err) assert.Equal(t, "hash", try.OrigTxHash) assert.Equal(t, status.Status(), try.Status) assert.Equal(t, code.Code(), int(try.Code)) @@ -96,11 +112,15 @@ func TestUpsertTry(t *testing.T) { status := tss.RPCTXStatus{OtherStatus: tss.NewStatus} code := tss.RPCTXCode{OtherCodes: tss.NewCode} resultXDR := "ABCD//" - _ = store.UpsertTry(context.Background(), "hash", "feebumptxhash", "feebumptxxdr", status, code, resultXDR) + err = store.UpsertTry(context.Background(), "hash", "feebumptxhash", "feebumptxxdr", status, code, resultXDR) + require.NoError(t, err) code = tss.RPCTXCode{OtherCodes: tss.RPCFailCode} - _ = store.UpsertTry(context.Background(), "hash", "feebumptxhash", "feebumptxxdr", status, code, resultXDR) + err = store.UpsertTry(context.Background(), "hash", "feebumptxhash", "feebumptxxdr", status, code, resultXDR) + require.NoError(t, err) - try, _ := store.GetTry(context.Background(), "feebumptxhash") + var try Try + try, err = store.GetTry(context.Background(), "feebumptxhash") + require.NoError(t, err) assert.Equal(t, "hash", try.OrigTxHash) assert.Equal(t, status.Status(), try.Status) assert.Equal(t, code.Code(), int(try.Code)) @@ -122,11 +142,14 @@ func TestUpsertTry(t *testing.T) { status := tss.RPCTXStatus{RPCStatus: entities.ErrorStatus} code := tss.RPCTXCode{TxResultCode: xdr.TransactionResultCodeTxInsufficientFee} resultXDR := "ABCD//" - _ = store.UpsertTry(context.Background(), "hash", "feebumptxhash", "feebumptxxdr", status, code, resultXDR) + err = store.UpsertTry(context.Background(), "hash", "feebumptxhash", "feebumptxxdr", status, code, resultXDR) + require.NoError(t, err) code = tss.RPCTXCode{TxResultCode: xdr.TransactionResultCodeTxSuccess} - _ = store.UpsertTry(context.Background(), "hash", "feebumptxhash", "feebumptxxdr", status, code, resultXDR) + err = store.UpsertTry(context.Background(), "hash", "feebumptxhash", "feebumptxxdr", status, code, resultXDR) + require.NoError(t, err) - try, _ := store.GetTry(context.Background(), "feebumptxhash") + try, err := store.GetTry(context.Background(), "feebumptxhash") + require.NoError(t, err) assert.Equal(t, "hash", try.OrigTxHash) assert.Equal(t, status.Status(), try.Status) assert.Equal(t, code.Code(), int(try.Code)) @@ -145,8 +168,11 @@ func TestGetTransaction(t *testing.T) { dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) require.NoError(t, err) defer dbConnectionPool.Close() + mockMetricsService := metrics.NewMockMetricsService() - store, _ := NewStore(dbConnectionPool, mockMetricsService) + store, err := NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) + t.Run("transaction_exists", func(t *testing.T) { mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "tss_transactions", mock.AnythingOfType("float64")).Once() mockMetricsService.On("IncDBQuery", "SELECT", "tss_transactions").Once() @@ -155,21 +181,21 @@ func TestGetTransaction(t *testing.T) { defer mockMetricsService.AssertExpectations(t) status := tss.RPCTXStatus{OtherStatus: tss.NewStatus} - _ = store.UpsertTransaction(context.Background(), "localhost:8000", "hash", "xdr", status) + err = store.UpsertTransaction(context.Background(), "localhost:8000", "hash", "xdr", status) + require.NoError(t, err) tx, err := store.GetTransaction(context.Background(), "hash") - + require.NoError(t, err) assert.Equal(t, "xdr", tx.XDR) - assert.Empty(t, err) }) t.Run("transaction_does_not_exist", func(t *testing.T) { mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "tss_transactions", mock.AnythingOfType("float64")).Once() mockMetricsService.On("IncDBQuery", "SELECT", "tss_transactions").Once() defer mockMetricsService.AssertExpectations(t) - tx, _ := store.GetTransaction(context.Background(), "doesnotexist") - assert.Equal(t, Transaction{}, tx) - assert.Empty(t, err) + tx, err := store.GetTransaction(context.Background(), "doesnotexist") + require.NoError(t, err) + assert.Empty(t, tx) }) } @@ -179,8 +205,11 @@ func TestGetTry(t *testing.T) { dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) require.NoError(t, err) defer dbConnectionPool.Close() + mockMetricsService := metrics.NewMockMetricsService() - store, _ := NewStore(dbConnectionPool, mockMetricsService) + store, err := NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) + t.Run("try_exists", func(t *testing.T) { mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "tss_transaction_submission_tries", mock.AnythingOfType("float64")).Once() mockMetricsService.On("IncDBQuery", "SELECT", "tss_transaction_submission_tries").Once() @@ -191,24 +220,24 @@ func TestGetTry(t *testing.T) { status := tss.RPCTXStatus{OtherStatus: tss.NewStatus} code := tss.RPCTXCode{OtherCodes: tss.NewCode} resultXDR := "ABCD//" - _ = store.UpsertTry(context.Background(), "hash", "feebumptxhash", "feebumptxxdr", status, code, resultXDR) + err = store.UpsertTry(context.Background(), "hash", "feebumptxhash", "feebumptxxdr", status, code, resultXDR) + require.NoError(t, err) try, err := store.GetTry(context.Background(), "feebumptxhash") - + require.NoError(t, err) assert.Equal(t, "hash", try.OrigTxHash) assert.Equal(t, status.Status(), try.Status) assert.Equal(t, code.Code(), int(try.Code)) assert.Equal(t, resultXDR, try.ResultXDR) - assert.Empty(t, err) }) t.Run("try_does_not_exist", func(t *testing.T) { mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "tss_transaction_submission_tries", mock.AnythingOfType("float64")).Once() mockMetricsService.On("IncDBQuery", "SELECT", "tss_transaction_submission_tries").Once() defer mockMetricsService.AssertExpectations(t) - try, _ := store.GetTry(context.Background(), "doesnotexist") - assert.Equal(t, Try{}, try) - assert.Empty(t, err) + try, err := store.GetTry(context.Background(), "doesnotexist") + require.NoError(t, err) + assert.Empty(t, try) }) } @@ -218,8 +247,11 @@ func TestGetTryByXDR(t *testing.T) { dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) require.NoError(t, err) defer dbConnectionPool.Close() + mockMetricsService := metrics.NewMockMetricsService() - store, _ := NewStore(dbConnectionPool, mockMetricsService) + store, err := NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) + t.Run("try_exists", func(t *testing.T) { mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "tss_transaction_submission_tries", mock.AnythingOfType("float64")).Once() mockMetricsService.On("IncDBQuery", "SELECT", "tss_transaction_submission_tries").Once() @@ -230,24 +262,24 @@ func TestGetTryByXDR(t *testing.T) { status := tss.RPCTXStatus{OtherStatus: tss.NewStatus} code := tss.RPCTXCode{OtherCodes: tss.NewCode} resultXDR := "ABCD//" - _ = store.UpsertTry(context.Background(), "hash", "feebumptxhash", "feebumptxxdr", status, code, resultXDR) + err = store.UpsertTry(context.Background(), "hash", "feebumptxhash", "feebumptxxdr", status, code, resultXDR) + require.NoError(t, err) try, err := store.GetTryByXDR(context.Background(), "feebumptxxdr") - + require.NoError(t, err) assert.Equal(t, "hash", try.OrigTxHash) assert.Equal(t, status.Status(), try.Status) assert.Equal(t, code.Code(), int(try.Code)) assert.Equal(t, resultXDR, try.ResultXDR) - assert.Empty(t, err) }) t.Run("try_does_not_exist", func(t *testing.T) { mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "tss_transaction_submission_tries", mock.AnythingOfType("float64")).Once() mockMetricsService.On("IncDBQuery", "SELECT", "tss_transaction_submission_tries").Once() defer mockMetricsService.AssertExpectations(t) - try, _ := store.GetTryByXDR(context.Background(), "doesnotexist") - assert.Equal(t, Try{}, try) - assert.Empty(t, err) + try, err := store.GetTryByXDR(context.Background(), "doesnotexist") + require.NoError(t, err) + assert.Empty(t, try) }) } @@ -257,8 +289,10 @@ func TestGetTransactionsWithStatus(t *testing.T) { dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) require.NoError(t, err) defer dbConnectionPool.Close() + mockMetricsService := metrics.NewMockMetricsService() - store, _ := NewStore(dbConnectionPool, mockMetricsService) + store, err := NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) t.Run("transactions_do_not_exist", func(t *testing.T) { mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "tss_transactions", mock.AnythingOfType("float64")).Once() @@ -266,9 +300,10 @@ func TestGetTransactionsWithStatus(t *testing.T) { defer mockMetricsService.AssertExpectations(t) status := tss.RPCTXStatus{OtherStatus: tss.NewStatus} - txns, err := store.GetTransactionsWithStatus(context.Background(), status) - assert.Equal(t, 0, len(txns)) - assert.Empty(t, err) + var txns []Transaction + txns, err = store.GetTransactionsWithStatus(context.Background(), status) + require.NoError(t, err) + assert.Empty(t, txns) }) t.Run("transactions_exist", func(t *testing.T) { @@ -279,15 +314,16 @@ func TestGetTransactionsWithStatus(t *testing.T) { defer mockMetricsService.AssertExpectations(t) status := tss.RPCTXStatus{OtherStatus: tss.NewStatus} - _ = store.UpsertTransaction(context.Background(), "localhost:8000", "hash1", "xdr1", status) - _ = store.UpsertTransaction(context.Background(), "localhost:8000", "hash2", "xdr2", status) + err = store.UpsertTransaction(context.Background(), "localhost:8000", "hash1", "xdr1", status) + require.NoError(t, err) + err = store.UpsertTransaction(context.Background(), "localhost:8000", "hash2", "xdr2", status) + require.NoError(t, err) txns, err := store.GetTransactionsWithStatus(context.Background(), status) - + require.NoError(t, err) assert.Equal(t, 2, len(txns)) assert.Equal(t, "hash1", txns[0].Hash) assert.Equal(t, "hash2", txns[1].Hash) - assert.Empty(t, err) }) } @@ -297,18 +333,20 @@ func TestGetLatestTry(t *testing.T) { dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) require.NoError(t, err) defer dbConnectionPool.Close() + mockMetricsService := metrics.NewMockMetricsService() - store, _ := NewStore(dbConnectionPool, mockMetricsService) + store, err := NewStore(dbConnectionPool, mockMetricsService) + require.NoError(t, err) t.Run("tries_do_not_exist", func(t *testing.T) { mockMetricsService.On("ObserveDBQueryDuration", "SELECT", "tss_transaction_submission_tries", mock.AnythingOfType("float64")).Once() mockMetricsService.On("IncDBQuery", "SELECT", "tss_transaction_submission_tries").Once() defer mockMetricsService.AssertExpectations(t) - try, err := store.GetLatestTry(context.Background(), "hash") - - assert.Equal(t, Try{}, try) - assert.Empty(t, err) + var try Try + try, err = store.GetLatestTry(context.Background(), "hash") + require.NoError(t, err) + assert.Empty(t, try) }) t.Run("tries_exist", func(t *testing.T) { @@ -321,12 +359,13 @@ func TestGetLatestTry(t *testing.T) { status := tss.RPCTXStatus{OtherStatus: tss.NewStatus} code := tss.RPCTXCode{OtherCodes: tss.NewCode} resultXDR := "ABCD//" - _ = store.UpsertTry(context.Background(), "hash", "feebumptxhash1", "feebumptxxdr1", status, code, resultXDR) - _ = store.UpsertTry(context.Background(), "hash", "feebumptxhash2", "feebumptxxdr2", status, code, resultXDR) + err = store.UpsertTry(context.Background(), "hash", "feebumptxhash1", "feebumptxxdr1", status, code, resultXDR) + require.NoError(t, err) + err = store.UpsertTry(context.Background(), "hash", "feebumptxhash2", "feebumptxxdr2", status, code, resultXDR) + require.NoError(t, err) try, err := store.GetLatestTry(context.Background(), "hash") - + require.NoError(t, err) assert.Equal(t, "feebumptxhash2", try.Hash) - assert.Empty(t, err) }) } diff --git a/internal/tss/types.go b/internal/tss/types.go index d7de9c56a..8153adfc3 100644 --- a/internal/tss/types.go +++ b/internal/tss/types.go @@ -8,6 +8,7 @@ import ( xdr3 "github.com/stellar/go-xdr/xdr3" "github.com/stellar/go/xdr" + "github.com/stellar/wallet-backend/internal/entities" ) @@ -27,7 +28,7 @@ type RPCGetIngestTxResponse struct { func ParseToRPCGetIngestTxResponse(result entities.RPCGetTransactionResult, err error) (RPCGetIngestTxResponse, error) { if err != nil { - return RPCGetIngestTxResponse{Status: entities.ErrorStatus}, err + return RPCGetIngestTxResponse{Status: entities.ErrorStatus}, fmt.Errorf("parseing to rpc get ingest tx response: %w", err) } getIngestTxResponse := RPCGetIngestTxResponse{ @@ -198,9 +199,9 @@ type Payload struct { // The xdr of the transaction TransactionXDR string // Relevant fields in an RPC sendTransaction response - RpcSubmitTxResponse RPCSendTxResponse + RPCSubmitTxResponse RPCSendTxResponse // Relevant fields in the transaction list inside the RPC getTransactions response - RpcGetIngestTxResponse RPCGetIngestTxResponse + RPCGetIngestTxResponse RPCGetIngestTxResponse // indicates if the transaction to be built from this payload should be wrapped in a fee bump transaction FeeBump bool } diff --git a/internal/tss/types_test.go b/internal/tss/types_test.go index ad5f0f513..a99003376 100644 --- a/internal/tss/types_test.go +++ b/internal/tss/types_test.go @@ -5,9 +5,10 @@ import ( "testing" "github.com/stellar/go/xdr" - "github.com/stellar/wallet-backend/internal/entities" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/stellar/wallet-backend/internal/entities" ) func TestParseToRPCSendTxResponse(t *testing.T) { @@ -16,7 +17,7 @@ func TestParseToRPCSendTxResponse(t *testing.T) { require.Error(t, err) assert.Equal(t, entities.ErrorStatus, resp.Status) - assert.Equal(t, "sending sendTransaction request: sending POST request to RPC: connection failed", err.Error()) + assert.ErrorContains(t, err, "sending sendTransaction request: sending POST request to RPC: connection failed") }) t.Run("response_has_empty_errorResultXdr", func(t *testing.T) { @@ -57,7 +58,7 @@ func TestParseToRPCGetIngestTxResponse(t *testing.T) { require.Error(t, err) assert.Equal(t, entities.ErrorStatus, resp.Status) - assert.Equal(t, "sending getTransaction request: sending POST request to RPC: connection failed", err.Error()) + assert.ErrorContains(t, err, "sending getTransaction request: sending POST request to RPC: connection failed") }) t.Run("unable_to_parse_createdAt", func(t *testing.T) { @@ -68,7 +69,7 @@ func TestParseToRPCGetIngestTxResponse(t *testing.T) { require.Error(t, err) assert.Equal(t, entities.ErrorStatus, resp.Status) - assert.Equal(t, "unable to parse createdAt: strconv.ParseInt: parsing \"ABCD\": invalid syntax", err.Error()) + assert.ErrorContains(t, err, "unable to parse createdAt: strconv.ParseInt: parsing \"ABCD\": invalid syntax") }) t.Run("response_has_createdAt_field", func(t *testing.T) { diff --git a/internal/tss/utils/helpers.go b/internal/tss/utils/helpers.go index 1c8e25818..a57b7c8fc 100644 --- a/internal/tss/utils/helpers.go +++ b/internal/tss/utils/helpers.go @@ -1,58 +1,23 @@ package utils import ( - "github.com/stellar/go/keypair" - "github.com/stellar/go/txnbuild" "github.com/stellar/wallet-backend/internal/tss" ) func PayloadTOTSSResponse(payload tss.Payload) tss.TSSResponse { response := tss.TSSResponse{} response.TransactionHash = payload.TransactionHash - if payload.RpcSubmitTxResponse.Status.Status() != "" { - response.Status = string(payload.RpcSubmitTxResponse.Status.Status()) - response.TransactionResultCode = payload.RpcSubmitTxResponse.Code.TxResultCode.String() - response.EnvelopeXDR = payload.RpcSubmitTxResponse.TransactionXDR - response.ResultXDR = payload.RpcSubmitTxResponse.ErrorResultXDR - } else if payload.RpcGetIngestTxResponse.Status != "" { - response.Status = string(payload.RpcGetIngestTxResponse.Status) - response.TransactionResultCode = payload.RpcGetIngestTxResponse.Code.TxResultCode.String() - response.EnvelopeXDR = payload.RpcGetIngestTxResponse.EnvelopeXDR - response.ResultXDR = payload.RpcGetIngestTxResponse.ResultXDR - response.CreatedAt = payload.RpcGetIngestTxResponse.CreatedAt + if payload.RPCSubmitTxResponse.Status.Status() != "" { + response.Status = string(payload.RPCSubmitTxResponse.Status.Status()) + response.TransactionResultCode = payload.RPCSubmitTxResponse.Code.TxResultCode.String() + response.EnvelopeXDR = payload.RPCSubmitTxResponse.TransactionXDR + response.ResultXDR = payload.RPCSubmitTxResponse.ErrorResultXDR + } else if payload.RPCGetIngestTxResponse.Status != "" { + response.Status = string(payload.RPCGetIngestTxResponse.Status) + response.TransactionResultCode = payload.RPCGetIngestTxResponse.Code.TxResultCode.String() + response.EnvelopeXDR = payload.RPCGetIngestTxResponse.EnvelopeXDR + response.ResultXDR = payload.RPCGetIngestTxResponse.ResultXDR + response.CreatedAt = payload.RPCGetIngestTxResponse.CreatedAt } return response } - -func BuildTestTransaction() *txnbuild.Transaction { - accountToSponsor := keypair.MustRandom() - - tx, _ := txnbuild.NewTransaction(txnbuild.TransactionParams{ - SourceAccount: &txnbuild.SimpleAccount{ - AccountID: accountToSponsor.Address(), - Sequence: 124, - }, - IncrementSequenceNum: true, - Operations: []txnbuild.Operation{ - &txnbuild.Payment{ - Destination: keypair.MustRandom().Address(), - Amount: "14.0000000", - Asset: txnbuild.NativeAsset{}, - }, - }, - BaseFee: 104, - Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(10)}, - }) - return tx -} - -func BuildTestFeeBumpTransaction() *txnbuild.FeeBumpTransaction { - - feeBumpTx, _ := txnbuild.NewFeeBumpTransaction( - txnbuild.FeeBumpTransactionParams{ - Inner: BuildTestTransaction(), - FeeAccount: keypair.MustRandom().Address(), - BaseFee: 110, - }) - return feeBumpTx -} diff --git a/internal/tss/utils/operation_builder_test.go b/internal/tss/utils/operation_builder_test.go index ca1039dbf..22ad89f37 100644 --- a/internal/tss/utils/operation_builder_test.go +++ b/internal/tss/utils/operation_builder_test.go @@ -10,6 +10,7 @@ import ( "github.com/stellar/go/txnbuild" "github.com/stellar/go/xdr" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBuildOperations(t *testing.T) { @@ -21,14 +22,17 @@ func TestBuildOperations(t *testing.T) { Amount: "10", SourceAccount: srcAccount, } - op, _ := c.BuildXDR() + op, err := c.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) - ops, _ := BuildOperations([]string{opXDRBase64}) + ops, err := BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) assert.Equal(t, srcAccount, ops[0].GetSourceAccount()) assert.Equal(t, dstAccount, ops[0].(*txnbuild.CreateAccount).Destination) @@ -43,14 +47,17 @@ func TestBuildOperations(t *testing.T) { Asset: txnbuild.NativeAsset{}, SourceAccount: srcAccount, } - op, _ := p.BuildXDR() + op, err := p.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) - ops, _ := BuildOperations([]string{opXDRBase64}) + ops, err := BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) assert.Equal(t, srcAccount, ops[0].GetSourceAccount()) assert.Equal(t, string("10.0000000"), ops[0].(*txnbuild.Payment).Amount) @@ -68,14 +75,17 @@ func TestBuildOperations(t *testing.T) { SourceAccount: srcAccount, Price: xdr.Price{N: 10, D: 10}, } - op, _ := m.BuildXDR() + op, err := m.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) - ops, _ := BuildOperations([]string{opXDRBase64}) + ops, err := BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) assert.Equal(t, srcAccount, ops[0].GetSourceAccount()) assert.Equal(t, string("10.0000000"), ops[0].(*txnbuild.ManageSellOffer).Amount) @@ -94,15 +104,17 @@ func TestBuildOperations(t *testing.T) { Price: xdr.Price{N: 10, D: 10}, SourceAccount: srcAccount, } - op, _ := c.BuildXDR() + op, err := c.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) - ops, _ := BuildOperations([]string{opXDRBase64}) - + ops, err := BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) assert.Equal(t, srcAccount, ops[0].GetSourceAccount()) assert.Equal(t, string("10.0000000"), ops[0].(*txnbuild.CreatePassiveSellOffer).Amount) assert.Equal(t, txnbuild.NativeAsset{}, ops[0].(*txnbuild.CreatePassiveSellOffer).Selling) @@ -115,14 +127,17 @@ func TestBuildOperations(t *testing.T) { s := txnbuild.SetOptions{ SourceAccount: srcAccount, } - op, _ := s.BuildXDR() + op, err := s.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) - ops, _ := BuildOperations([]string{opXDRBase64}) + ops, err := BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) assert.Equal(t, srcAccount, ops[0].GetSourceAccount()) }) @@ -134,14 +149,17 @@ func TestBuildOperations(t *testing.T) { Destination: dstAccount, SourceAccount: srcAccount, } - op, _ := a.BuildXDR() + op, err := a.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) - ops, _ := BuildOperations([]string{opXDRBase64}) + ops, err := BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) assert.Equal(t, srcAccount, ops[0].GetSourceAccount()) assert.Equal(t, dstAccount, ops[0].(*txnbuild.AccountMerge).Destination) @@ -152,14 +170,17 @@ func TestBuildOperations(t *testing.T) { i := txnbuild.Inflation{ SourceAccount: srcAccount, } - op, _ := i.BuildXDR() + op, err := i.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) - ops, _ := BuildOperations([]string{opXDRBase64}) + ops, err := BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) assert.Equal(t, srcAccount, ops[0].GetSourceAccount()) }) @@ -170,14 +191,17 @@ func TestBuildOperations(t *testing.T) { Name: "foo", SourceAccount: srcAccount, } - op, _ := m.BuildXDR() + op, err := m.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) - ops, _ := BuildOperations([]string{opXDRBase64}) + ops, err := BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) assert.Equal(t, srcAccount, ops[0].GetSourceAccount()) assert.Equal(t, "foo", ops[0].(*txnbuild.ManageData).Name) @@ -189,14 +213,17 @@ func TestBuildOperations(t *testing.T) { BumpTo: int64(100), SourceAccount: srcAccount, } - op, _ := b.BuildXDR() + op, err := b.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) - ops, _ := BuildOperations([]string{opXDRBase64}) + ops, err := BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) assert.Equal(t, srcAccount, ops[0].GetSourceAccount()) assert.Equal(t, int64(100), ops[0].(*txnbuild.BumpSequence).BumpTo) @@ -212,14 +239,17 @@ func TestBuildOperations(t *testing.T) { OfferID: int64(100), SourceAccount: srcAccount, } - op, _ := m.BuildXDR() + op, err := m.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) - ops, _ := BuildOperations([]string{opXDRBase64}) + ops, err := BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) assert.Equal(t, srcAccount, ops[0].GetSourceAccount()) assert.Equal(t, txnbuild.NativeAsset{}, ops[0].(*txnbuild.ManageBuyOffer).Selling) @@ -241,14 +271,17 @@ func TestBuildOperations(t *testing.T) { Path: []txnbuild.Asset{}, SourceAccount: srcAccount, } - op, _ := p.BuildXDR() + op, err := p.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) - ops, _ := BuildOperations([]string{opXDRBase64}) + ops, err := BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) assert.Equal(t, srcAccount, ops[0].GetSourceAccount()) assert.Equal(t, txnbuild.NativeAsset{}, ops[0].(*txnbuild.PathPaymentStrictSend).SendAsset) @@ -266,14 +299,17 @@ func TestBuildOperations(t *testing.T) { Asset: txnbuild.NativeAsset{}, SourceAccount: srcAccount, } - op, _ := c.BuildXDR() + op, err := c.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) - ops, _ := BuildOperations([]string{opXDRBase64}) + ops, err := BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) assert.Equal(t, srcAccount, ops[0].GetSourceAccount()) assert.Equal(t, "10.0000000", ops[0].(*txnbuild.CreateClaimableBalance).Amount) @@ -285,14 +321,17 @@ func TestBuildOperations(t *testing.T) { e := txnbuild.EndSponsoringFutureReserves{ SourceAccount: srcAccount, } - op, _ := e.BuildXDR() + op, err := e.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) - ops, _ := BuildOperations([]string{opXDRBase64}) + ops, err := BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) assert.Equal(t, srcAccount, ops[0].GetSourceAccount()) }) @@ -305,14 +344,17 @@ func TestBuildOperations(t *testing.T) { MinPrice: xdr.Price{N: 10, D: 10}, MaxPrice: xdr.Price{N: 10, D: 10}, } - op, _ := l.BuildXDR() + op, err := l.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) - ops, _ := BuildOperations([]string{opXDRBase64}) + ops, err := BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) assert.Equal(t, srcAccount, ops[0].GetSourceAccount()) assert.Equal(t, "10.0000000", ops[0].(*txnbuild.LiquidityPoolDeposit).MaxAmountA) @@ -329,14 +371,17 @@ func TestBuildOperations(t *testing.T) { MinAmountA: "10", MinAmountB: "10", } - op, _ := l.BuildXDR() + op, err := l.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) - ops, _ := BuildOperations([]string{opXDRBase64}) + ops, err := BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) assert.Equal(t, srcAccount, ops[0].GetSourceAccount()) assert.Equal(t, "10.0000000", ops[0].(*txnbuild.LiquidityPoolWithdraw).Amount) @@ -350,14 +395,17 @@ func TestBuildOperations(t *testing.T) { ExtendTo: uint32(10), SourceAccount: srcAccount, } - op, _ := e.BuildXDR() + op, err := e.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) - ops, _ := BuildOperations([]string{opXDRBase64}) + ops, err := BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) assert.Equal(t, srcAccount, ops[0].GetSourceAccount()) }) @@ -367,16 +415,18 @@ func TestBuildOperations(t *testing.T) { r := txnbuild.RestoreFootprint{ SourceAccount: srcAccount, } - op, _ := r.BuildXDR() + op, err := r.BuildXDR() + require.NoError(t, err) var buf strings.Builder enc := xdr3.NewEncoder(&buf) - _ = op.EncodeTo(enc) + err = op.EncodeTo(enc) + require.NoError(t, err) opXDR := buf.String() opXDRBase64 := base64.StdEncoding.EncodeToString([]byte(opXDR)) - ops, _ := BuildOperations([]string{opXDRBase64}) + ops, err := BuildOperations([]string{opXDRBase64}) + require.NoError(t, err) assert.Equal(t, srcAccount, ops[0].GetSourceAccount()) }) - } diff --git a/internal/tss/utils/test_helpers.go b/internal/tss/utils/test_helpers.go new file mode 100644 index 000000000..53e1610f7 --- /dev/null +++ b/internal/tss/utils/test_helpers.go @@ -0,0 +1,52 @@ +package utils + +import ( + "testing" + + "github.com/stellar/go/keypair" + "github.com/stellar/go/txnbuild" + "github.com/stretchr/testify/require" +) + +// BuildTestTransaction is a test helper that builds a transaction with a random account. +// It is used to test the transaction manager. +// It is not used to build transactions for the TSS. +// For that, use the `BuildOperations` function. +func BuildTestTransaction(t *testing.T) *txnbuild.Transaction { + t.Helper() + + accountToSponsor := keypair.MustRandom() + + tx, err := txnbuild.NewTransaction(txnbuild.TransactionParams{ + SourceAccount: &txnbuild.SimpleAccount{ + AccountID: accountToSponsor.Address(), + Sequence: 124, + }, + IncrementSequenceNum: true, + Operations: []txnbuild.Operation{ + &txnbuild.Payment{ + Destination: keypair.MustRandom().Address(), + Amount: "14.0000000", + Asset: txnbuild.NativeAsset{}, + }, + }, + BaseFee: 104, + Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(10)}, + }) + require.NoError(t, err) + return tx +} + +// BuildTestFeeBumpTransaction is a test helper that builds a fee bump transaction with a random fee account. +func BuildTestFeeBumpTransaction(t *testing.T) *txnbuild.FeeBumpTransaction { + t.Helper() + + feeBumpTx, err := txnbuild.NewFeeBumpTransaction( + txnbuild.FeeBumpTransactionParams{ + Inner: BuildTestTransaction(t), + FeeAccount: keypair.MustRandom().Address(), + BaseFee: 110, + }) + require.NoError(t, err) + return feeBumpTx +} diff --git a/internal/utils/http_client.go b/internal/utils/http_client.go index 514abf084..76b3a76c0 100644 --- a/internal/utils/http_client.go +++ b/internal/utils/http_client.go @@ -3,19 +3,8 @@ package utils import ( "io" "net/http" - - "github.com/stretchr/testify/mock" ) type HTTPClient interface { Post(url string, t string, body io.Reader) (resp *http.Response, err error) } - -type MockHTTPClient struct { - mock.Mock -} - -func (s *MockHTTPClient) Post(url, contentType string, body io.Reader) (resp *http.Response, err error) { - args := s.Called(url, contentType, body) - return args.Get(0).(*http.Response), args.Error(1) -} diff --git a/internal/utils/ingestion_utils_test.go b/internal/utils/ingestion_utils_test.go index 0614c63ec..38c4c4fc1 100644 --- a/internal/utils/ingestion_utils_test.go +++ b/internal/utils/ingestion_utils_test.go @@ -14,9 +14,9 @@ func TestMemo(t *testing.T) { Type: xdr.MemoTypeMemoNone, } - memo_value, memo_type := Memo(memo, "") - assert.Equal(t, (*string)(nil), memo_value) - assert.Equal(t, xdr.MemoTypeMemoNone.String(), memo_type) + memoValue, memoType := Memo(memo, "") + assert.Equal(t, (*string)(nil), memoValue) + assert.Equal(t, xdr.MemoTypeMemoNone.String(), memoType) }) t.Run("type_text", func(t *testing.T) { @@ -25,9 +25,9 @@ func TestMemo(t *testing.T) { Text: PointOf("test"), } - memo_value, memo_type := Memo(memo, "") - assert.Equal(t, "test", *memo_value) - assert.Equal(t, xdr.MemoTypeMemoText.String(), memo_type) + memoValue, memoType := Memo(memo, "") + assert.Equal(t, "test", *memoValue) + assert.Equal(t, xdr.MemoTypeMemoText.String(), memoType) }) t.Run("type_id", func(t *testing.T) { @@ -36,9 +36,9 @@ func TestMemo(t *testing.T) { Id: PointOf(xdr.Uint64(12345)), } - memo_value, memo_type := Memo(memo, "") - assert.Equal(t, "12345", *memo_value) - assert.Equal(t, xdr.MemoTypeMemoId.String(), memo_type) + memoValue, memoType := Memo(memo, "") + assert.Equal(t, "12345", *memoValue) + assert.Equal(t, xdr.MemoTypeMemoId.String(), memoType) }) t.Run("type_hash", func(t *testing.T) { @@ -50,9 +50,9 @@ func TestMemo(t *testing.T) { Hash: &value, } - memo_value, memo_type := Memo(memo, "") - assert.Equal(t, value.HexString(), *memo_value) - assert.Equal(t, xdr.MemoTypeMemoHash.String(), memo_type) + memoValue, memoType := Memo(memo, "") + assert.Equal(t, value.HexString(), *memoValue) + assert.Equal(t, xdr.MemoTypeMemoHash.String(), memoType) }) t.Run("type_return", func(t *testing.T) { @@ -64,8 +64,8 @@ func TestMemo(t *testing.T) { RetHash: &value, } - memo_value, memo_type := Memo(memo, "") - assert.Equal(t, value.HexString(), *memo_value) - assert.Equal(t, xdr.MemoTypeMemoReturn.String(), memo_type) + memoValue, memoType := Memo(memo, "") + assert.Equal(t, value.HexString(), *memoValue) + assert.Equal(t, xdr.MemoTypeMemoReturn.String(), memoType) }) } diff --git a/internal/utils/mocks.go b/internal/utils/mocks.go new file mode 100644 index 000000000..abf4987d7 --- /dev/null +++ b/internal/utils/mocks.go @@ -0,0 +1,20 @@ +package utils + +import ( + "io" + "net/http" + + "github.com/stretchr/testify/mock" +) + +type MockHTTPClient struct { + mock.Mock +} + +func (s *MockHTTPClient) Post(url, contentType string, body io.Reader) (resp *http.Response, err error) { + args := s.Called(url, contentType, body) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*http.Response), args.Error(1) +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 1dc18b5b7..dac0d762e 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -2,10 +2,14 @@ package utils import ( "bytes" + "context" + "fmt" + "io" "reflect" "strings" "github.com/stellar/go/strkey" + "github.com/stellar/go/support/log" "github.com/stellar/go/xdr" "github.com/stellar/wallet-backend/internal/entities" @@ -41,7 +45,7 @@ func PointOf[T any](value T) *T { func GetAccountLedgerKey(address string) (string, error) { decoded, err := strkey.Decode(strkey.VersionByteAccountID, address) if err != nil { - return "", err + return "", fmt.Errorf("decoding address %q: %w", address, err) } var key xdr.Uint256 copy(key[:], decoded) @@ -55,7 +59,7 @@ func GetAccountLedgerKey(address string) (string, error) { }, }.MarshalBinaryBase64() if err != nil { - return "", err + return "", fmt.Errorf("marshalling ledger key: %w", err) } return keyXdr, nil } @@ -64,7 +68,17 @@ func GetAccountFromLedgerEntry(entry entities.LedgerEntryResult) (xdr.AccountEnt var data xdr.LedgerEntryData err := xdr.SafeUnmarshalBase64(entry.DataXDR, &data) if err != nil { - return xdr.AccountEntry{}, err + return xdr.AccountEntry{}, fmt.Errorf("unmarshalling ledger entry data: %w", err) } return data.MustAccount(), nil } + +// DeferredClose is a function that closes an `io.Closer` resource and logs an error if it fails. +func DeferredClose(ctx context.Context, closer io.Closer, errMsg string) { + if err := closer.Close(); err != nil { + if errMsg == "" { + errMsg = "closing resource" + } + log.Ctx(ctx).Errorf("%s: %v", errMsg, err) + } +} diff --git a/internal/validators/validate.go b/internal/validators/validate.go index 77aa8d5b4..91312728f 100644 --- a/internal/validators/validate.go +++ b/internal/validators/validate.go @@ -12,11 +12,15 @@ import ( "github.com/stellar/go/support/log" ) -func NewValidator() *validator.Validate { +func NewValidator() (*validator.Validate, error) { validate := validator.New() - _ = validate.RegisterValidation("public_key", publicKeyValidation) + err := validate.RegisterValidation("public_key", publicKeyValidation) + if err != nil { + return nil, fmt.Errorf("registering public_key validation: %w", err) + } + validate.RegisterAlias("not_empty", "required") - return validate + return validate, nil } func publicKeyValidation(fl validator.FieldLevel) bool { diff --git a/internal/validators/validate_test.go b/internal/validators/validate_test.go index 29d8f2a32..45c1958de 100644 --- a/internal/validators/validate_test.go +++ b/internal/validators/validate_test.go @@ -1,6 +1,7 @@ package validators import ( + "errors" "testing" "github.com/go-playground/validator/v10" @@ -89,11 +90,13 @@ func TestParseValidationError(t *testing.T) { }, } - val := NewValidator() + val, err := NewValidator() + require.NoError(t, err) for _, tc := range testCases { err := val.Struct(tc.stc) require.Error(t, err) - vErrs, ok := err.(validator.ValidationErrors) + var vErrs validator.ValidationErrors + ok := errors.As(err, &vErrs) require.True(t, ok) fieldErrors := ParseValidationError(vErrs) assert.Equal(t, tc.expectedFieldErrors, fieldErrors) @@ -114,10 +117,13 @@ func TestParseValidationError(t *testing.T) { SliceGT: []string{"a", "b"}, SliceGTE: []string{"a"}, } - val := NewValidator() - err := val.Struct(stc) + val, err := NewValidator() + require.NoError(t, err) + err = val.Struct(stc) require.Error(t, err) - vErrs, ok := err.(validator.ValidationErrors) + + var vErrs validator.ValidationErrors + ok := errors.As(err, &vErrs) require.True(t, ok) fieldErrors := ParseValidationError(vErrs) assert.Equal(t, map[string]interface{}{ @@ -155,10 +161,12 @@ func TestParseValidationError(t *testing.T) { SliceLT: []string{"a", "b"}, SliceLTE: []string{"a", "b", "c"}, } - val := NewValidator() - err := val.Struct(stc) + val, err := NewValidator() + require.NoError(t, err) + err = val.Struct(stc) require.Error(t, err) - vErrs, ok := err.(validator.ValidationErrors) + var vErrs validator.ValidationErrors + ok := errors.As(err, &vErrs) require.True(t, ok) fieldErrors := ParseValidationError(vErrs) assert.Equal(t, map[string]interface{}{ @@ -222,11 +230,13 @@ func TestGetFieldName(t *testing.T) { }, }, } - val := NewValidator() - err := val.Struct(stc) + val, err := NewValidator() + require.NoError(t, err) + err = val.Struct(stc) require.Error(t, err) - vErrs, ok := err.(validator.ValidationErrors) + var vErrs validator.ValidationErrors + ok := errors.As(err, &vErrs) require.True(t, ok) require.Len(t, vErrs, 2) diff --git a/scripts/exclude_from_coverage.sh b/scripts/exclude_from_coverage.sh new file mode 100755 index 000000000..853e398d8 --- /dev/null +++ b/scripts/exclude_from_coverage.sh @@ -0,0 +1,20 @@ +#!/bin/sh + +exclude_terms() { + local terms="$1" + local infile="$2" + local tmpfile="${infile}.tmp" + + while IFS= read -r term || [ -n "$term" ]; do + local exp=".*${term}.*" + grep -v "$exp" "$infile" > "$tmpfile" + mv "$tmpfile" "$infile" + done << EOF +$terms +EOF +} + +# Usage +exclude_terms "mock" "c.out" +exclude_terms "mocks" "c.out" +exclude_terms "fixtures.go" "c.out"