diff --git a/go.mod b/go.mod index 5117687..17a942c 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/abevier/go-sqs go 1.18 require ( - github.com/abevier/tsk v0.0.0-20221228184442-7aa6a1d7f829 + github.com/abevier/tsk v0.0.0-20230712145722-249b1e98b01c github.com/aws/aws-sdk-go-v2/config v1.1.5 github.com/aws/aws-sdk-go-v2/service/sqs v1.19.17 ) diff --git a/go.sum b/go.sum index aefd007..284dc1b 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/abevier/tsk v0.0.0-20221228184442-7aa6a1d7f829 h1:DFeide3v0BCOL+BVJltLiyMDImaeIQ+5LWsEBVmRpOE= -github.com/abevier/tsk v0.0.0-20221228184442-7aa6a1d7f829/go.mod h1:5stFO5XvX4fFShiWpYidc8K6TNGpmOZrVXJjF2Nbhmw= +github.com/abevier/tsk v0.0.0-20230712145722-249b1e98b01c h1:lmYm6DZwDH8IyZUBw1oqWjDk2KSn8v0zgkSPKa8gXFk= +github.com/abevier/tsk v0.0.0-20230712145722-249b1e98b01c/go.mod h1:5stFO5XvX4fFShiWpYidc8K6TNGpmOZrVXJjF2Nbhmw= github.com/aws/aws-sdk-go-v2 v1.3.2/go.mod h1:7OaACgj2SX3XGWnrIjGlJM22h6yD6MEWKvm7levnnM8= github.com/aws/aws-sdk-go-v2 v1.17.3 h1:shN7NlnVzvDUgPQ+1rLMSxY8OWRNDRYtiqe0p/PgrhY= github.com/aws/aws-sdk-go-v2 v1.17.3/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= diff --git a/gosqs/consumer.go b/gosqs/consumer.go index 38bef92..ebe32c9 100644 --- a/gosqs/consumer.go +++ b/gosqs/consumer.go @@ -30,9 +30,12 @@ type SQSConsumer struct { callbackFunc MessageCallbackFunc - workerWG *sync.WaitGroup - isShutudown uint32 - shutdownChan chan struct{} + workerWG *sync.WaitGroup + isShutdown uint32 + isRxShutdown uint32 + shutdownInitiatedChan chan struct{} + shutdownChan chan struct{} + rxShutdownChan chan struct{} } func NewConsumer(opts Opts, publisher *SQSPublisher, callback MessageCallbackFunc) *SQSConsumer { @@ -45,9 +48,12 @@ func NewConsumer(opts Opts, publisher *SQSPublisher, callback MessageCallbackFun callbackFunc: callback, - workerWG: &sync.WaitGroup{}, - isShutudown: 0, - shutdownChan: make(chan struct{}), + workerWG: &sync.WaitGroup{}, + isShutdown: 0, + isRxShutdown: 0, + shutdownInitiatedChan: make(chan struct{}), + shutdownChan: make(chan struct{}), + rxShutdownChan: make(chan struct{}), } } @@ -60,11 +66,11 @@ func (c *SQSConsumer) Start() { // do a request // count number of messages pulled // if greater than 7 - make 2 requests - // if less than 3 - do no make a request another request (unless 0 reuqests would be outstanding) - // also check this against the number i'm allowed to prefetch + // if less than 3 - do not make a request another request (unless 0 requests would be outstanding) + // also check this against the number I'm allowed to prefetch go func() { - retreivedMsgChan := make(chan []SQSMessage) + retrievedMsgChan := make(chan []SQSMessage) numReceivedMessages := 0 numInflightRetrieveRequests := 0 @@ -73,9 +79,17 @@ func (c *SQSConsumer) Start() { calc := newCalculator(c.maxReceivedMessages, c.maxInflightReceiveMessageRequests) for { - if atomic.LoadUint32(&c.isShutudown) == 1 { - if numInflightRetrieveRequests == 0 && numReceivedMessages == 0 { - break + if atomic.LoadUint32(&c.isShutdown) == 1 { + if numInflightRetrieveRequests == 0 { + // Close the rx shutdown channel to signal that all SQS receive requests have completed. Note that + // the processing for these requests, or for the messages returned for the requests, could still be + // in progress after the channel is closed. + if atomic.CompareAndSwapUint32(&c.isRxShutdown, 0, 1) { + close(c.rxShutdownChan) + } + if numReceivedMessages == 0 { + break + } } } else { neededRequests := calc.NeededReceiveRequests(numReceivedMessages, numInflightRetrieveRequests, retrieveRequestLimit) @@ -85,14 +99,14 @@ func (c *SQSConsumer) Start() { for i := 0; i < neededRequests; i++ { go func() { - receiveMessageWorker(c.publisher, retreivedMsgChan) + receiveMessageWorker(c.publisher, retrievedMsgChan) }() } numInflightRetrieveRequests += neededRequests } select { - case msgs := <-retreivedMsgChan: + case msgs := <-retrievedMsgChan: numInflightRetrieveRequests-- numReceivedMessages += len(msgs) retrieveRequestLimit = calc.NewReceiveRequestLimit(retrieveRequestLimit, len(msgs)) @@ -102,11 +116,16 @@ func (c *SQSConsumer) Start() { case <-msgProcessingCompleteChannel: numReceivedMessages-- + case _, ok := <-c.shutdownInitiatedChan: + if ok { + log.Printf("gosqs: initiating shutdown...") + } + // TODO: Do something else if the channel was already closed? } } // All writers to these channels should have completed by the time the above loop exits - close(retreivedMsgChan) + close(retrievedMsgChan) close(messageChan) }() @@ -117,7 +136,8 @@ func (c *SQSConsumer) Start() { defer c.workerWG.Done() for msg := range messageChan { - log.Printf("Handling message: %v on worker: %v\n", msg.body, id) + // TODO: disabling for now, will add back with better logging level support + //log.Printf("Handling message: %v on worker: %v\n", msg.body, id) c.processMessage(msg) msgProcessingCompleteChannel <- struct{}{} } @@ -167,16 +187,26 @@ func (c *SQSConsumer) processMessage(msg SQSMessage) { if err := msg.ack(); err != nil { // TODO: log something?? - log.Printf("err acking - what do? %v", err) + log.Printf("gosqs: err acking - what do? %v", err) } } func (c *SQSConsumer) Shutdown() { - if atomic.CompareAndSwapUint32(&c.isShutudown, 0, 1) { + if atomic.CompareAndSwapUint32(&c.isShutdown, 0, 1) { + c.shutdownInitiatedChan <- struct{}{} go func() { c.workerWG.Wait() close(c.shutdownChan) }() } <-c.shutdownChan + log.Printf("gosqs: shutdown complete") +} + +func (c *SQSConsumer) WaitForRxShutdown() { + // This channel is only used to detect the close of SQS receive operations and signal the app that there are no more + // messages forthcoming beyond the ones currently in flight. + log.Printf("gosqs: waiting for rx shutdown...") + <-c.rxShutdownChan + log.Printf("gosqs: rx shutdown complete") } diff --git a/gosqs/consumer_test.go b/gosqs/consumer_test.go index d2863a1..6ac7135 100644 --- a/gosqs/consumer_test.go +++ b/gosqs/consumer_test.go @@ -2,9 +2,14 @@ package gosqs import ( "context" + "math" + "sync" "testing" "time" + "github.com/abevier/tsk/batch" + "github.com/abevier/tsk/results" + "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/aws/aws-sdk-go-v2/service/sqs/types" ) @@ -27,6 +32,49 @@ func TestConsumer(t *testing.T) { c.Shutdown() } +func TestConsumerShutdownWithBatchExecutor(t *testing.T) { + client := &TestSQSClient{} + + publisher := NewPublisher(client, "queue-url", 20*time.Millisecond) + batchEx := batch.New[int, bool]( + // batch will never fill up or expire - only a flush can clear it + batch.Opts{MaxSize: 100, MaxLinger: math.MaxInt64}, + func(reqs []int) ([]results.Result[bool], error) { + res := make([]results.Result[bool], len(reqs)) + for idx, _ := range res { + res[idx] = results.New(true, nil) + } + return res, nil + }, + ) + cb := func(ctx context.Context, msg string) error { + batchEx.Submit(context.Background(), 0) + return nil + } + + c := NewConsumer(Opts{MaxReceivedMessages: 50, MaxWorkers: 50, MaxInflightReceiveMessageRequests: 10}, publisher, cb) + + c.Start() + + time.Sleep(5 * time.Second) + + batchWg := sync.WaitGroup{} + batchWg.Add(2) + go func() { + defer batchWg.Done() + // This will start the shutdown process + c.Shutdown() + }() + go func() { + defer batchWg.Done() + // This will wait for all in flight receive requests to complete + c.WaitForRxShutdown() + // This will flush the batch and unblock all workers waiting for the batch to be created + batchEx.Flush() + }() + batchWg.Wait() +} + type TestSQSClient struct { }