Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 109 additions & 87 deletions apps/webapp/app/presenters/v3/RunStreamPresenter.server.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import { TaskRun } from "@trigger.dev/database";
import { eventStream } from "remix-utils/sse/server";
import { PrismaClient, prisma } from "~/db.server";
import { logger } from "~/services/logger.server";
import { singleton } from "~/utils/singleton";
import { createSSELoader } from "~/utils/sse";
import { throttle } from "~/utils/throttle";
import { tracePubSub } from "~/v3/services/tracePubSub.server";

const pingInterval = 1000;
const PING_INTERVAL = 1000;
const STREAM_TIMEOUT = 30 * 1000; // 30 seconds

export class RunStreamPresenter {
#prismaClient: PrismaClient;
Expand All @@ -14,105 +15,126 @@ export class RunStreamPresenter {
this.#prismaClient = prismaClient;
}

public async call({
request,
runFriendlyId,
}: {
request: Request;
runFriendlyId: TaskRun["friendlyId"];
}) {
const run = await this.#prismaClient.taskRun.findFirst({
where: {
friendlyId: runFriendlyId,
},
select: {
traceId: true,
},
});
public createLoader() {
const prismaClient = this.#prismaClient;

if (!run) {
return new Response("Not found", { status: 404 });
}
return createSSELoader({
timeout: STREAM_TIMEOUT,
interval: PING_INTERVAL,
handler: async (context) => {
const runFriendlyId = context.params.runParam;

logger.info("RunStreamPresenter.call", {
runFriendlyId,
traceId: run.traceId,
});

let pinger: NodeJS.Timeout | undefined = undefined;

const { unsubscribe, eventEmitter } = await tracePubSub.subscribeToTrace(run.traceId);

return eventStream(request.signal, (send, close) => {
const safeSend = (args: { event?: string; data: string }) => {
try {
send(args);
} catch (error) {
if (error instanceof Error) {
if (error.name !== "TypeError") {
logger.debug("Error sending SSE, aborting", {
error: {
name: error.name,
message: error.message,
stack: error.stack,
},
args,
});
}
} else {
logger.debug("Unknown error sending SSE, aborting", {
error,
args,
});
}

close();
if (!runFriendlyId) {
throw new Response("Missing runParam", { status: 400 });
}
};

const throttledSend = throttle(safeSend, 1000);

eventEmitter.addListener("message", (event) => {
throttledSend({ data: event });
});
const run = await prismaClient.taskRun.findFirst({
where: {
friendlyId: runFriendlyId,
},
select: {
traceId: true,
},
});

pinger = setInterval(() => {
if (request.signal.aborted) {
return close();
if (!run) {
throw new Response("Not found", { status: 404 });
}

safeSend({ event: "ping", data: new Date().toISOString() });
}, pingInterval);

return function clear() {
logger.info("RunStreamPresenter.abort", {
logger.info("RunStreamPresenter.start", {
runFriendlyId,
traceId: run.traceId,
});

clearInterval(pinger);

eventEmitter.removeAllListeners();
// Subscribe to trace updates
const { unsubscribe, eventEmitter } = await tracePubSub.subscribeToTrace(run.traceId);

// Store throttled send function and message listener for cleanup
let throttledSend: ReturnType<typeof throttle> | undefined;
let messageListener: ((event: string) => void) | undefined;

return {
initStream: ({ send }) => {
// Create throttled send function
throttledSend = throttle((args: { event?: string; data: string }) => {
try {
send(args);
} catch (error) {
if (error instanceof Error) {
if (error.name !== "TypeError") {
logger.debug("Error sending SSE in RunStreamPresenter", {
error: {
name: error.name,
message: error.message,
stack: error.stack,
},
});
}
}
// Abort the stream on send error
context.controller.abort("Send error");
}
}, 1000);

// Set up message listener for pub/sub events
messageListener = (event: string) => {
throttledSend?.({ data: event });
};
eventEmitter.addListener("message", messageListener);

context.debug("Subscribed to trace pub/sub");
},

iterator: ({ send }) => {
// Send ping to keep connection alive
try {
send({ event: "ping", data: new Date().toISOString() });
} catch (error) {
// If we can't send a ping, the connection is likely dead
return false;
}
},

unsubscribe()
.then(() => {
logger.info("RunStreamPresenter.abort.unsubscribe succeeded", {
runFriendlyId,
traceId: run.traceId,
});
})
.catch((error) => {
logger.error("RunStreamPresenter.abort.unsubscribe failed", {
cleanup: () => {
logger.info("RunStreamPresenter.cleanup", {
runFriendlyId,
traceId: run.traceId,
error: {
name: error.name,
message: error.message,
stack: error.stack,
},
});
});
};

// Remove message listener
if (messageListener) {
eventEmitter.removeListener("message", messageListener);
}
eventEmitter.removeAllListeners();

// Unsubscribe from Redis pub/sub
unsubscribe()
.then(() => {
logger.info("RunStreamPresenter.cleanup.unsubscribe succeeded", {
runFriendlyId,
traceId: run.traceId,
});
})
.catch((error) => {
logger.error("RunStreamPresenter.cleanup.unsubscribe failed", {
runFriendlyId,
traceId: run.traceId,
error: {
name: error.name,
message: error.message,
stack: error.stack,
},
});
});
},
};
},
});
}
}

// Export a singleton loader for the route to use
export const runStreamLoader = singleton("runStreamLoader", () => {
const presenter = new RunStreamPresenter();
return presenter.createLoader();
});
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import type { LoaderFunctionArgs } from "@remix-run/server-runtime";
import { z } from "zod";
import { RunStreamPresenter } from "~/presenters/v3/RunStreamPresenter.server";
import { runStreamLoader } from "~/presenters/v3/RunStreamPresenter.server";
import { requireUserId } from "~/services/session.server";

export async function loader({ request, params }: LoaderFunctionArgs) {
await requireUserId(request);
export async function loader(args: LoaderFunctionArgs) {
// Authenticate the user before starting the stream
await requireUserId(args.request);

const { runParam } = z.object({ runParam: z.string() }).parse(params);

const presenter = new RunStreamPresenter();
return presenter.call({ request, runFriendlyId: runParam });
// Delegate to the SSE loader
return runStreamLoader(args);
}
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ export class RedisRealtimeStreams implements StreamIngestor, StreamResponder {
await redis.quit().catch(console.error);
}

signal.addEventListener("abort", cleanup);
signal.addEventListener("abort", cleanup, { once: true });

return new Response(stream, {
headers: {
Expand Down
31 changes: 25 additions & 6 deletions apps/webapp/app/v3/services/tracePubSub.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import { createRedisClient, RedisClient, RedisWithClusterOptions } from "~/redis
import { EventEmitter } from "node:events";
import { env } from "~/env.server";
import { singleton } from "~/utils/singleton";
import { Gauge } from "prom-client";
import { metricsRegister } from "~/metrics.server";

export type TracePubSubOptions = {
redis: RedisWithClusterOptions;
Expand All @@ -15,7 +17,10 @@ export class TracePubSub {
this._publisher = createRedisClient("trigger:eventRepoPublisher", this._options.redis);
}

// TODO: do this more efficiently
get subscriberCount() {
return this._subscriberCount;
}

async publish(traceIds: string[]) {
if (traceIds.length === 0) return;
const uniqueTraces = new Set(traceIds.map((e) => `events:${e}`));
Expand All @@ -40,15 +45,18 @@ export class TracePubSub {

const eventEmitter = new EventEmitter();

// Define the message handler.
redis.on("message", (_, message) => {
// Define the message handler - store reference so we can remove it later.
const messageHandler = (_: string, message: string) => {
eventEmitter.emit("message", message);
});
};
redis.on("message", messageHandler);

// Return a function that can be used to unsubscribe.
const unsubscribe = async () => {
// Remove the message listener before closing the connection
redis.off("message", messageHandler);
await redis.unsubscribe(channel);
redis.quit();
await redis.quit();
this._subscriberCount--;
};

Expand All @@ -62,7 +70,7 @@ export class TracePubSub {
export const tracePubSub = singleton("tracePubSub", initializeTracePubSub);

function initializeTracePubSub() {
return new TracePubSub({
const pubSub = new TracePubSub({
redis: {
port: env.PUBSUB_REDIS_PORT,
host: env.PUBSUB_REDIS_HOST,
Expand All @@ -72,4 +80,15 @@ function initializeTracePubSub() {
clusterMode: env.PUBSUB_REDIS_CLUSTER_MODE_ENABLED === "1",
},
});

new Gauge({
name: "trace_pub_sub_subscribers",
help: "Number of trace pub sub subscribers",
collect() {
this.set(pubSub.subscriberCount);
},
registers: [metricsRegister],
});

return pubSub;
}