Skip to content

Commit b9e61f4

Browse files
H31nz3lovr
andauthored
fiix: Improve subscription ID handling and connection management, thanks @H31nz3l (#10485)
--------- Co-authored-by: Dmitry Patsura <[email protected]>
1 parent a5c516b commit b9e61f4

File tree

5 files changed

+206
-19
lines changed

5 files changed

+206
-19
lines changed

packages/cubejs-api-gateway/src/ws/local-subscription-store.ts

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,33 @@ interface LocalSubscriptionStoreOptions {
22
heartBeatInterval?: number;
33
}
44

5-
export type SubscriptionId = string | number;
6-
75
export type LocalSubscriptionStoreSubscription = {
86
message: any,
97
state: any,
108
timestamp: Date,
119
};
1210

1311
export type LocalSubscriptionStoreConnection = {
14-
subscriptions: Map<SubscriptionId, LocalSubscriptionStoreSubscription>,
12+
subscriptions: Map<string, LocalSubscriptionStoreSubscription>,
1513
authContext?: any,
1614
};
1715

1816
export class LocalSubscriptionStore {
1917
protected readonly connections: Map<string, LocalSubscriptionStoreConnection> = new Map();
2018

21-
protected readonly hearBeatInterval: number;
19+
protected readonly heartBeatInterval: number;
2220

2321
public constructor(options: LocalSubscriptionStoreOptions = {}) {
24-
this.hearBeatInterval = options.heartBeatInterval || 60;
22+
this.heartBeatInterval = options.heartBeatInterval || 60;
2523
}
2624

27-
public async getSubscription(connectionId: string, subscriptionId: string) {
28-
const connection = this.getConnectionOrCreate(connectionId);
25+
public async getSubscription(connectionId: string, subscriptionId: string): Promise<LocalSubscriptionStoreSubscription | undefined> {
26+
// only get subscription, do not create connection if it doesn't exist
27+
const connection = this.getConnection(connectionId);
28+
if (!connection) {
29+
return undefined;
30+
}
31+
2932
return connection.subscriptions.get(subscriptionId);
3033
}
3134

@@ -37,14 +40,22 @@ export class LocalSubscriptionStore {
3740
});
3841
}
3942

40-
public async unsubscribe(connectionId: string, subscriptionId: SubscriptionId) {
41-
const connection = this.getConnectionOrCreate(connectionId);
43+
public async unsubscribe(connectionId: string, subscriptionId: string) {
44+
const connection = this.getConnection(connectionId);
45+
if (!connection) {
46+
return;
47+
}
48+
49+
if (!connection.subscriptions.has(subscriptionId)) {
50+
return;
51+
}
52+
4253
connection.subscriptions.delete(subscriptionId);
4354
}
4455

4556
public getAllSubscriptions() {
4657
const now = Date.now();
47-
const staleThreshold = this.hearBeatInterval * 4 * 1000;
58+
const staleThreshold = this.heartBeatInterval * 4 * 1000;
4859
const result: Array<{ connectionId: string } & LocalSubscriptionStoreSubscription> = [];
4960

5061
for (const [connectionId, connection] of this.connections) {
@@ -75,17 +86,21 @@ export class LocalSubscriptionStore {
7586
}
7687

7788
protected getConnectionOrCreate(connectionId: string): LocalSubscriptionStoreConnection {
78-
const connect = this.connections.get(connectionId);
89+
const connect = this.getConnection(connectionId);
7990
if (connect) {
8091
return connect;
8192
}
8293

83-
const connection = { subscriptions: new Map() };
94+
const connection: LocalSubscriptionStoreConnection = { subscriptions: new Map<string, LocalSubscriptionStoreSubscription>() };
8495
this.connections.set(connectionId, connection);
8596

8697
return connection;
8798
}
8899

100+
protected getConnection(connectionId: string): LocalSubscriptionStoreConnection | undefined {
101+
return this.connections.get(connectionId);
102+
}
103+
89104
public clear() {
90105
this.connections.clear();
91106
}

packages/cubejs-api-gateway/src/ws/message-schema.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { z } from 'zod';
22

3-
const messageId = z.union([z.string().max(16), z.int()]);
3+
const messageId = z.union([z.string().max(16), z.int()]).transform(String);
44
const requestId = z.string().max(64).optional();
55

66
export const authMessageSchema = z.object({

packages/cubejs-api-gateway/src/ws/subscription-server.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ export class SubscriptionServer {
3838
) {
3939
}
4040

41-
protected resultFn(connectionId: string, messageId: string | number | undefined, requestId: string | undefined, logNetworkUsage: boolean = true) {
41+
protected resultFn(connectionId: string, messageId: string | undefined, requestId: string | undefined, logNetworkUsage: boolean = true) {
4242
return async (message, { status } = { status: 200 }) => {
4343
if (logNetworkUsage) {
4444
this.apiGateway.log({ type: 'Outgoing network usage', service: 'api-ws', bytes: calcMessageLength(message), }, { requestId });
@@ -158,7 +158,7 @@ export class SubscriptionServer {
158158
throw new UserError(`Unsupported method: ${message.method}`);
159159
}
160160

161-
const subscriptionId = String(message.messageId);
161+
const subscriptionId = message.messageId;
162162
const baseRequestId = message.requestId || `${connectionId}-${subscriptionId}`;
163163
const requestId = `${baseRequestId}-span-${uuidv4()}`;
164164

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import {
2+
LocalSubscriptionStore,
3+
} from '../../src/ws/local-subscription-store';
4+
5+
describe('LocalSubscriptionStore', () => {
6+
it('stores and retrieves subscription by id', async () => {
7+
const store = new LocalSubscriptionStore();
8+
9+
await store.subscribe('conn-1', 'sub-1', {
10+
message: { method: 'load' },
11+
state: { foo: 'bar' }
12+
});
13+
14+
const subscription = await store.getSubscription('conn-1', 'sub-1');
15+
16+
expect(subscription).toBeDefined();
17+
expect(subscription?.message).toEqual({ method: 'load' });
18+
expect(subscription?.state).toEqual({ foo: 'bar' });
19+
expect(subscription?.timestamp).toBeInstanceOf(Date);
20+
});
21+
22+
it('stores and retrieves subscription by string id', async () => {
23+
const store = new LocalSubscriptionStore();
24+
25+
await store.subscribe('conn-1', '123', {
26+
message: { method: 'load' },
27+
state: { answer: true }
28+
});
29+
30+
const result = await store.getSubscription('conn-1', '123');
31+
32+
expect(result).toBeDefined();
33+
expect(result?.state).toEqual({ answer: true });
34+
});
35+
36+
it('does not create a connection when reading missing subscription', async () => {
37+
const store = new LocalSubscriptionStore();
38+
39+
const missing = await store.getSubscription('unknown-conn', 'sub-1');
40+
41+
expect(missing).toBeUndefined();
42+
// eslint-disable-next-line dot-notation
43+
expect(store['connections'].size).toBe(0);
44+
});
45+
46+
it('does not create a connection when unsubscribing unknown connection', async () => {
47+
const store = new LocalSubscriptionStore();
48+
49+
await store.unsubscribe('unknown-conn', 'sub-1');
50+
51+
// eslint-disable-next-line dot-notation
52+
expect(store['connections'].size).toBe(0);
53+
});
54+
55+
it('unsubscribes existing subscription', async () => {
56+
const store = new LocalSubscriptionStore();
57+
58+
await store.subscribe('conn-1', 'sub-1', {
59+
message: { method: 'load' },
60+
state: {}
61+
});
62+
63+
await store.unsubscribe('conn-1', 'sub-1');
64+
65+
const subscription = await store.getSubscription('conn-1', 'sub-1');
66+
expect(subscription).toBeUndefined();
67+
});
68+
69+
it('returns all active subscriptions with connectionId', async () => {
70+
const store = new LocalSubscriptionStore();
71+
72+
await store.subscribe('conn-1', 'sub-1', {
73+
message: { method: 'load' },
74+
state: { a: 1 }
75+
});
76+
await store.subscribe('conn-2', 'sub-2', {
77+
message: { method: 'subscribe' },
78+
state: { b: 2 }
79+
});
80+
81+
const allSubscriptions = store.getAllSubscriptions();
82+
83+
expect(allSubscriptions).toHaveLength(2);
84+
expect(allSubscriptions).toEqual(expect.arrayContaining([
85+
expect.objectContaining({
86+
connectionId: 'conn-1',
87+
message: { method: 'load' },
88+
state: { a: 1 }
89+
}),
90+
expect.objectContaining({
91+
connectionId: 'conn-2',
92+
message: { method: 'subscribe' },
93+
state: { b: 2 }
94+
})
95+
]));
96+
});
97+
98+
it('removes stale subscriptions during getAllSubscriptions', async () => {
99+
const store = new LocalSubscriptionStore({ heartBeatInterval: 1 });
100+
101+
await store.subscribe('conn-1', 'stale', {
102+
message: { method: 'load' },
103+
state: {}
104+
});
105+
await store.subscribe('conn-1', 'active', {
106+
message: { method: 'load' },
107+
state: {}
108+
});
109+
110+
const staleSubscription = await store.getSubscription('conn-1', 'stale');
111+
expect(staleSubscription).toBeDefined();
112+
if (!staleSubscription) {
113+
throw new Error('Expected stale subscription to exist');
114+
}
115+
staleSubscription.timestamp = new Date(Date.now() - 5000);
116+
117+
const allSubscriptions = store.getAllSubscriptions();
118+
119+
expect(allSubscriptions).toHaveLength(1);
120+
expect(allSubscriptions[0].connectionId).toBe('conn-1');
121+
expect(allSubscriptions[0].message).toEqual({ method: 'load' });
122+
123+
const staleAfterCleanup = await store.getSubscription('conn-1', 'stale');
124+
expect(staleAfterCleanup).toBeUndefined();
125+
});
126+
127+
it('stores and retrieves auth context', async () => {
128+
const store = new LocalSubscriptionStore();
129+
130+
const authContext = { securityContext: { userId: 42 } };
131+
await store.setAuthContext('conn-1', authContext);
132+
133+
await expect(store.getAuthContext('conn-1')).resolves.toEqual(authContext);
134+
});
135+
136+
it('removes connection on disconnect', async () => {
137+
const store = new LocalSubscriptionStore();
138+
139+
await store.subscribe('conn-1', 'sub-1', {
140+
message: { method: 'load' },
141+
state: {}
142+
});
143+
144+
await store.disconnect('conn-1');
145+
146+
// eslint-disable-next-line dot-notation
147+
expect(store['connections'].has('conn-1')).toBe(false);
148+
});
149+
150+
it('clears all connections', async () => {
151+
const store = new LocalSubscriptionStore();
152+
153+
await store.subscribe('conn-1', 'sub-1', {
154+
message: { method: 'load' },
155+
state: {}
156+
});
157+
await store.subscribe('conn-2', 'sub-2', {
158+
message: { method: 'subscribe' },
159+
state: {}
160+
});
161+
162+
store.clear();
163+
164+
// eslint-disable-next-line dot-notation
165+
expect(store['connections'].size).toBe(0);
166+
});
167+
});

packages/cubejs-api-gateway/test/ws/subscription-server.test.ts

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,15 @@ describe('SubscriptionServer', () => {
5959
expect(mockSubscriptionStore.unsubscribe).toHaveBeenCalledWith('conn-1', 'msg-1');
6060
});
6161

62-
it('should accept unsubscribe with numeric messageId', async () => {
62+
it('should convert numeric unsubscribe id to string', async () => {
6363
const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor } = createMocks();
6464
const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor);
6565

6666
await server.processMessage('conn-1', JSON.stringify({ unsubscribe: 123 }));
6767

68-
expect(mockSubscriptionStore.unsubscribe).toHaveBeenCalledWith('conn-1', 123);
68+
const callArgs = mockSubscriptionStore.unsubscribe.mock.calls[0];
69+
expect(typeof callArgs[1]).toBe('string');
70+
expect(callArgs[1]).toBe('123');
6971
});
7072

7173
it('should accept valid load message', async () => {
@@ -83,7 +85,7 @@ describe('SubscriptionServer', () => {
8385
expect(sentMessages).toContainEqual({ messageProcessedId: '123' });
8486
});
8587

86-
it('should accept messageId as number', async () => {
88+
it('should convert numeric messageId to string', async () => {
8789
const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor, sentMessages } = createMocks();
8890
const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor);
8991

@@ -95,7 +97,10 @@ describe('SubscriptionServer', () => {
9597
await server.processMessage('conn-1', JSON.stringify(message));
9698

9799
expect(mockApiGateway.load).toHaveBeenCalled();
98-
expect(sentMessages).toContainEqual({ messageProcessedId: 123 });
100+
101+
const processedMsg = sentMessages.find((m) => m.messageProcessedId !== undefined);
102+
expect(typeof processedMsg.messageProcessedId).toBe('string');
103+
expect(processedMsg.messageProcessedId).toBe('123');
99104
});
100105

101106
it('should reject invalid JSON payload', async () => {

0 commit comments

Comments
 (0)