Skip to content

Commit 9838305

Browse files
authored
add test for client middleware options (#1306)
* add test for client middleware options * Create warm-jokes-drive.md * tidy * tidy
1 parent dd23687 commit 9838305

2 files changed

Lines changed: 75 additions & 0 deletions

File tree

.changeset/warm-jokes-drive.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@livekit/protocol": patch
3+
---
4+
5+
add test for client middleware options

rpc/typed_api_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package rpc
2+
3+
import (
4+
"context"
5+
"fmt"
6+
reflect "reflect"
7+
"runtime"
8+
"slices"
9+
"testing"
10+
"time"
11+
12+
"github.com/prometheus/client_golang/prometheus"
13+
"github.com/stretchr/testify/require"
14+
"google.golang.org/protobuf/proto"
15+
16+
"github.com/livekit/protocol/logger"
17+
"github.com/livekit/psrpc"
18+
"github.com/livekit/psrpc/pkg/middleware"
19+
)
20+
21+
func TestMiddleware(t *testing.T) {
22+
t.Run("common middleware propagate client request args", func(t *testing.T) {
23+
InitPSRPCStats(prometheus.Labels{})
24+
25+
cases := []struct {
26+
label string
27+
opt psrpc.ClientOption
28+
}{
29+
{"WithClientLogger", WithClientLogger(logger.GetLogger())},
30+
{"WithClientMetrics", middleware.WithClientMetrics(PSRPCMetricsObserver{})},
31+
{"WithClientObservability", WithClientObservability(logger.GetLogger())},
32+
}
33+
34+
for _, c := range cases {
35+
var o psrpc.ClientOpts
36+
c.opt(&o)
37+
t.Run(c.label, func(t *testing.T) {
38+
for _, c := range o.RpcInterceptors {
39+
ch := make(chan []psrpc.RequestOption, 1)
40+
call := c(psrpc.RPCInfo{}, func(ctx context.Context, req proto.Message, opts ...psrpc.RequestOption) (proto.Message, error) {
41+
ch <- opts
42+
return nil, nil
43+
})
44+
45+
expected := []psrpc.RequestOption{func(*psrpc.RequestOpts) {}, func(*psrpc.RequestOpts) {}}
46+
go call(context.Background(), nil, expected...)
47+
48+
eqPtr := func(a psrpc.RequestOption) func(a psrpc.RequestOption) bool {
49+
return func(b psrpc.RequestOption) bool {
50+
return reflect.ValueOf(a).Pointer() == reflect.ValueOf(b).Pointer()
51+
}
52+
}
53+
54+
fp := reflect.ValueOf(c).Pointer()
55+
f := runtime.FuncForPC(fp)
56+
file, line := f.FileLine(fp)
57+
name := fmt.Sprintf("%s:%d %s", file, line, f.Name())
58+
59+
select {
60+
case res := <-ch:
61+
require.True(t, slices.ContainsFunc(res, eqPtr(expected[0])), "failed to receive option 0 from %s", name)
62+
require.True(t, slices.ContainsFunc(res, eqPtr(expected[1])), "failed to receive option 0 from %s", name)
63+
case <-time.After(time.Second):
64+
require.FailNow(t, "timeout")
65+
}
66+
}
67+
})
68+
}
69+
})
70+
}

0 commit comments

Comments
 (0)