Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions pkg/sync/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sync

import (
"context"
"sync"

"github.com/Jguer/yay/v12/pkg/completion"
"github.com/Jguer/yay/v12/pkg/db"
Expand All @@ -18,6 +19,11 @@ import (
"github.com/leonelquinteros/gotext"
)

var (
completionNeedsUpdate = completion.NeedsUpdate
completionUpdateCache = completion.UpdateCache
)

type OperationService struct {
ctx context.Context
cfg *settings.Configuration
Expand Down Expand Up @@ -69,14 +75,9 @@ func (o *OperationService) Run(ctx context.Context, run *runtime.Runtime,
installer.AddPostInstallHook(cleanAURDirsFunc)
}

if completion.NeedsUpdate(o.cfg.CompletionPath, o.cfg.CompletionInterval, false) {
go func() {
errComp := completion.UpdateCache(ctx, run.HTTPClient, o.dbExecutor,
o.cfg.AURURL, o.cfg.CompletionPath, o.logger)
if errComp != nil {
o.logger.Warnln(errComp)
}
}()
waitForCompletionUpdate := o.startCompletionUpdate(ctx, run)
if waitForCompletionUpdate != nil {
defer waitForCompletionUpdate()
}

srcInfo, errInstall := srcinfo.NewService(o.dbExecutor, o.cfg,
Expand Down Expand Up @@ -123,6 +124,27 @@ func (o *OperationService) Run(ctx context.Context, run *runtime.Runtime,
return multiErr.Return()
}

func (o *OperationService) startCompletionUpdate(ctx context.Context, run *runtime.Runtime) func() {
if !completionNeedsUpdate(o.cfg.CompletionPath, o.cfg.CompletionInterval, false) {
return nil
}

var wg sync.WaitGroup
wg.Add(1)

go func() {
defer wg.Done()

errComp := completionUpdateCache(ctx, run.HTTPClient, o.dbExecutor,
o.cfg.AURURL, o.cfg.CompletionPath, o.logger)
if errComp != nil {
o.logger.Warnln(errComp)
}
}()

return wg.Wait
}

func (o *OperationService) manualConfirmRequired(cmdArgs *parser.Arguments) bool {
return (!cmdArgs.ExistsArg("u", "sysupgrade") && cmdArgs.Op != "Y") || o.cfg.DoubleConfirm
}
Expand Down
116 changes: 116 additions & 0 deletions pkg/sync/sync_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
//go:build !integration
// +build !integration

package sync

import (
"context"
"io"
"net/http"
"strings"
"testing"
"time"

"github.com/Jguer/yay/v12/pkg/completion"
"github.com/Jguer/yay/v12/pkg/download"
"github.com/Jguer/yay/v12/pkg/runtime"
"github.com/Jguer/yay/v12/pkg/settings"
"github.com/Jguer/yay/v12/pkg/text"

"github.com/stretchr/testify/require"
)

const waitGracePeriod = 50 * time.Millisecond

func TestStartCompletionUpdateSkipsWhenCacheIsFresh(t *testing.T) {
originalNeedsUpdate := completionNeedsUpdate
originalUpdateCache := completionUpdateCache
t.Cleanup(func() {
completionNeedsUpdate = originalNeedsUpdate
completionUpdateCache = originalUpdateCache
})

updateCalled := false
completionNeedsUpdate = func(string, int, bool) bool { return false }
completionUpdateCache = func(context.Context, download.HTTPRequestDoer, completion.PkgSynchronizer,
string, string, *text.Logger,
) error {
updateCalled = true
return nil
}

service, run := newTestOperationService()

wait := service.startCompletionUpdate(context.Background(), run)

require.Nil(t, wait)
require.False(t, updateCalled)
}

func TestStartCompletionUpdateWaitsForBackgroundUpdate(t *testing.T) {
originalNeedsUpdate := completionNeedsUpdate
originalUpdateCache := completionUpdateCache
t.Cleanup(func() {
completionNeedsUpdate = originalNeedsUpdate
completionUpdateCache = originalUpdateCache
})

started := make(chan struct{})
release := make(chan struct{})

completionNeedsUpdate = func(string, int, bool) bool { return true }
completionUpdateCache = func(context.Context, download.HTTPRequestDoer, completion.PkgSynchronizer,
string, string, *text.Logger,
) error {
close(started)
<-release
return nil
}

service, run := newTestOperationService()

wait := service.startCompletionUpdate(context.Background(), run)
require.NotNil(t, wait)

select {
case <-started:
case <-time.After(time.Second):
t.Fatal("completion update did not start")
}

waitReturned := make(chan struct{})
go func() {
wait()
close(waitReturned)
}()

select {
case <-waitReturned:
t.Fatal("wait returned before completion update finished")
case <-time.After(waitGracePeriod):
}

close(release)

select {
case <-waitReturned:
case <-time.After(time.Second):
t.Fatal("wait did not return after completion update finished")
}
}

func newTestOperationService() (*OperationService, *runtime.Runtime) {
cfg := &settings.Configuration{
AURURL: "https://aur.archlinux.org",
CompletionPath: "/tmp/completion",
CompletionInterval: 7,
}
logger := text.NewLogger(io.Discard, io.Discard, strings.NewReader(""), false, "test")
run := &runtime.Runtime{
Cfg: cfg,
HTTPClient: &http.Client{},
Logger: logger,
}

return NewOperationService(context.Background(), nil, run), run
}