Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module github.com/abevier/go-sqs
go 1.18

require (
github.com/abevier/tsk v0.0.0-20221228184442-7aa6a1d7f829
github.com/abevier/tsk v0.0.0-20230712145722-249b1e98b01c
github.com/aws/aws-sdk-go-v2/config v1.1.5
github.com/aws/aws-sdk-go-v2/service/sqs v1.19.17
)
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
github.com/abevier/tsk v0.0.0-20221228184442-7aa6a1d7f829 h1:DFeide3v0BCOL+BVJltLiyMDImaeIQ+5LWsEBVmRpOE=
github.com/abevier/tsk v0.0.0-20221228184442-7aa6a1d7f829/go.mod h1:5stFO5XvX4fFShiWpYidc8K6TNGpmOZrVXJjF2Nbhmw=
github.com/abevier/tsk v0.0.0-20230712145722-249b1e98b01c h1:lmYm6DZwDH8IyZUBw1oqWjDk2KSn8v0zgkSPKa8gXFk=
github.com/abevier/tsk v0.0.0-20230712145722-249b1e98b01c/go.mod h1:5stFO5XvX4fFShiWpYidc8K6TNGpmOZrVXJjF2Nbhmw=
github.com/aws/aws-sdk-go-v2 v1.3.2/go.mod h1:7OaACgj2SX3XGWnrIjGlJM22h6yD6MEWKvm7levnnM8=
github.com/aws/aws-sdk-go-v2 v1.17.3 h1:shN7NlnVzvDUgPQ+1rLMSxY8OWRNDRYtiqe0p/PgrhY=
github.com/aws/aws-sdk-go-v2 v1.17.3/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw=
Expand Down
66 changes: 48 additions & 18 deletions gosqs/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ type SQSConsumer struct {

callbackFunc MessageCallbackFunc

workerWG *sync.WaitGroup
isShutudown uint32
shutdownChan chan struct{}
workerWG *sync.WaitGroup
isShutdown uint32
isRxShutdown uint32
shutdownInitiatedChan chan struct{}
shutdownChan chan struct{}
rxShutdownChan chan struct{}
}

func NewConsumer(opts Opts, publisher *SQSPublisher, callback MessageCallbackFunc) *SQSConsumer {
Expand All @@ -45,9 +48,12 @@ func NewConsumer(opts Opts, publisher *SQSPublisher, callback MessageCallbackFun

callbackFunc: callback,

workerWG: &sync.WaitGroup{},
isShutudown: 0,
shutdownChan: make(chan struct{}),
workerWG: &sync.WaitGroup{},
isShutdown: 0,
isRxShutdown: 0,
shutdownInitiatedChan: make(chan struct{}),
shutdownChan: make(chan struct{}),
rxShutdownChan: make(chan struct{}),
}
}

Expand All @@ -60,11 +66,11 @@ func (c *SQSConsumer) Start() {
// do a request
// count number of messages pulled
// if greater than 7 - make 2 requests
// if less than 3 - do no make a request another request (unless 0 reuqests would be outstanding)
// also check this against the number i'm allowed to prefetch
// if less than 3 - do not make a request another request (unless 0 requests would be outstanding)
// also check this against the number I'm allowed to prefetch

go func() {
retreivedMsgChan := make(chan []SQSMessage)
retrievedMsgChan := make(chan []SQSMessage)

numReceivedMessages := 0
numInflightRetrieveRequests := 0
Expand All @@ -73,9 +79,17 @@ func (c *SQSConsumer) Start() {
calc := newCalculator(c.maxReceivedMessages, c.maxInflightReceiveMessageRequests)

for {
if atomic.LoadUint32(&c.isShutudown) == 1 {
if numInflightRetrieveRequests == 0 && numReceivedMessages == 0 {
break
if atomic.LoadUint32(&c.isShutdown) == 1 {
if numInflightRetrieveRequests == 0 {
// Close the rx shutdown channel to signal that all SQS receive requests have completed. Note that
// the processing for these requests, or for the messages returned for the requests, could still be
// in progress after the channel is closed.
if atomic.CompareAndSwapUint32(&c.isRxShutdown, 0, 1) {
close(c.rxShutdownChan)
}
if numReceivedMessages == 0 {
break
}
}
} else {
neededRequests := calc.NeededReceiveRequests(numReceivedMessages, numInflightRetrieveRequests, retrieveRequestLimit)
Expand All @@ -85,14 +99,14 @@ func (c *SQSConsumer) Start() {

for i := 0; i < neededRequests; i++ {
go func() {
receiveMessageWorker(c.publisher, retreivedMsgChan)
receiveMessageWorker(c.publisher, retrievedMsgChan)
}()
}
numInflightRetrieveRequests += neededRequests
}

select {
case msgs := <-retreivedMsgChan:
case msgs := <-retrievedMsgChan:
numInflightRetrieveRequests--
numReceivedMessages += len(msgs)
retrieveRequestLimit = calc.NewReceiveRequestLimit(retrieveRequestLimit, len(msgs))
Expand All @@ -102,11 +116,16 @@ func (c *SQSConsumer) Start() {

case <-msgProcessingCompleteChannel:
numReceivedMessages--
case _, ok := <-c.shutdownInitiatedChan:
if ok {
log.Printf("gosqs: initiating shutdown...")
}
// TODO: Do something else if the channel was already closed?
}
}

// All writers to these channels should have completed by the time the above loop exits
close(retreivedMsgChan)
close(retrievedMsgChan)
close(messageChan)
}()

Expand All @@ -117,7 +136,8 @@ func (c *SQSConsumer) Start() {
defer c.workerWG.Done()

for msg := range messageChan {
log.Printf("Handling message: %v on worker: %v\n", msg.body, id)
// TODO: disabling for now, will add back with better logging level support
//log.Printf("Handling message: %v on worker: %v\n", msg.body, id)
c.processMessage(msg)
msgProcessingCompleteChannel <- struct{}{}
}
Expand Down Expand Up @@ -167,16 +187,26 @@ func (c *SQSConsumer) processMessage(msg SQSMessage) {

if err := msg.ack(); err != nil {
// TODO: log something??
log.Printf("err acking - what do? %v", err)
log.Printf("gosqs: err acking - what do? %v", err)
}
}

func (c *SQSConsumer) Shutdown() {
if atomic.CompareAndSwapUint32(&c.isShutudown, 0, 1) {
if atomic.CompareAndSwapUint32(&c.isShutdown, 0, 1) {
c.shutdownInitiatedChan <- struct{}{}
go func() {
c.workerWG.Wait()
close(c.shutdownChan)
}()
}
<-c.shutdownChan
log.Printf("gosqs: shutdown complete")
}

func (c *SQSConsumer) WaitForRxShutdown() {
// This channel is only used to detect the close of SQS receive operations and signal the app that there are no more
// messages forthcoming beyond the ones currently in flight.
log.Printf("gosqs: waiting for rx shutdown...")
<-c.rxShutdownChan
log.Printf("gosqs: rx shutdown complete")
}
48 changes: 48 additions & 0 deletions gosqs/consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@ package gosqs

import (
"context"
"math"
"sync"
"testing"
"time"

"github.com/abevier/tsk/batch"
"github.com/abevier/tsk/results"

"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
)
Expand All @@ -27,6 +32,49 @@ func TestConsumer(t *testing.T) {
c.Shutdown()
}

func TestConsumerShutdownWithBatchExecutor(t *testing.T) {
client := &TestSQSClient{}

publisher := NewPublisher(client, "queue-url", 20*time.Millisecond)
batchEx := batch.New[int, bool](
// batch will never fill up or expire - only a flush can clear it
batch.Opts{MaxSize: 100, MaxLinger: math.MaxInt64},
func(reqs []int) ([]results.Result[bool], error) {
res := make([]results.Result[bool], len(reqs))
for idx, _ := range res {
res[idx] = results.New(true, nil)
}
return res, nil
},
)
cb := func(ctx context.Context, msg string) error {
batchEx.Submit(context.Background(), 0)
return nil
}

c := NewConsumer(Opts{MaxReceivedMessages: 50, MaxWorkers: 50, MaxInflightReceiveMessageRequests: 10}, publisher, cb)

c.Start()

time.Sleep(5 * time.Second)

batchWg := sync.WaitGroup{}
batchWg.Add(2)
go func() {
defer batchWg.Done()
// This will start the shutdown process
c.Shutdown()
}()
go func() {
defer batchWg.Done()
// This will wait for all in flight receive requests to complete
c.WaitForRxShutdown()
// This will flush the batch and unblock all workers waiting for the batch to be created
batchEx.Flush()
}()
batchWg.Wait()
}

type TestSQSClient struct {
}

Expand Down