From 56086604bb70d6dece53b810827c94e8b1809168 Mon Sep 17 00:00:00 2001 From: Aveline <352441+ym@users.noreply.github.com> Date: Wed, 9 Apr 2025 19:41:34 +0200 Subject: [PATCH] trickle ice (#36) * refactor: improve WebSocket handling in CreateSession function (#30) * feat: implement Tricke ICE WebRTC signaling with dedicated WebSocket (#31) * feat: implement Tricke ICE WebRTC signaling with dedicated WebSocket handling * Update src/webrtc.ts Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Add more logging (#32) * Add more logging * Refactor logging in WebRTC signaling to remove "WS" prefix for consistency * A tiny bit clearer logging (#33) * Fix/more logging (#34) * A tiny bit clearer logging * Enhance WebSocket close event logging to include closure code and reason * chore: update .gitignore and enhance WebSocket connection handling (#35) * Add .env.development to .gitignore * Improve handling of existing WebSocket connections by waiting for closure before terminating --------- Co-authored-by: Adam Shiervani Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .gitignore | 3 +- package.json | 2 +- src/devices.ts | 2 +- src/index.ts | 28 +-- src/webrtc-signaling.ts | 420 ++++++++++++++++++++++++++++++++++++++++ src/webrtc.ts | 188 +++--------------- 6 files changed, 471 insertions(+), 172 deletions(-) create mode 100644 src/webrtc-signaling.ts diff --git a/.gitignore b/.gitignore index 1af2f6e..e2622bf 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ node_modules .idea -.env \ No newline at end of file +.env +.env.development diff --git a/package.json b/package.json index 6593453..93e4a99 100644 --- a/package.json +++ b/package.json @@ -5,7 +5,7 @@ "main": "index.js", "scripts": { "start": "NODE_ENV=production node -r ts-node/register ./src/index.ts", - "dev": "NODE_ENV=development node --env-file=.env.development -r ts-node/register ./src/index.ts" + "dev": "NODE_ENV=development node --watch --env-file=.env.development -r ts-node/register ./src/index.ts" }, "engines": { "node": "21.1.0" diff --git a/src/devices.ts b/src/devices.ts index 9a66247..c27cd2e 100644 --- a/src/devices.ts +++ b/src/devices.ts @@ -7,9 +7,9 @@ import { UnauthorizedError, UnprocessableEntityError, } from "./errors"; -import { activeConnections } from "./webrtc"; import * as crypto from "crypto"; import { authenticated } from "./auth"; +import { activeConnections } from "./webrtc-signaling"; export const List = async (req: express.Request, res: express.Response) => { const idToken = req.session?.id_token; diff --git a/src/index.ts b/src/index.ts index 1178ee3..ab4145e 100644 --- a/src/index.ts +++ b/src/index.ts @@ -12,6 +12,7 @@ import * as Releases from "./releases"; import { HttpError } from "./errors"; import { authenticated } from "./auth"; import { prisma } from "./db"; +import { initializeWebRTCSignaling } from "./webrtc-signaling"; declare global { namespace NodeJS { @@ -55,22 +56,23 @@ app.use(express.urlencoded({ extended: true })); app.use( cors({ origin: process.env.CORS_ORIGINS?.split(",") || [ - "https://app.jetkvm.com", "http://localhost:5173" + "https://app.jetkvm.com", + "http://localhost:5173", ], credentials: true, }), ); -app.use( - cookieSession({ - name: "session", - path: "/", - httpOnly: true, - keys: [process.env.COOKIE_SECRET], - secure: process.env.NODE_ENV === "production", - sameSite: "strict", - maxAge: 24 * 60 * 60 * 1000, // 24 hours - }), -); +export const cookieSessionMiddleware = cookieSession({ + name: "session", + path: "/", + httpOnly: true, + keys: [process.env.COOKIE_SECRET], + secure: process.env.NODE_ENV === "production", + sameSite: "strict", + maxAge: 24 * 60 * 60 * 1000, // 24 hours +}); + +app.use(cookieSessionMiddleware); function asyncHandler(fn: any) { return (req: express.Request, res: express.Response, next: express.NextFunction) => { @@ -209,4 +211,4 @@ const server = app.listen(3000, () => { console.log("Server started on port 3000"); }); -Webrtc.registerWebsocketServer(server); +initializeWebRTCSignaling(server); diff --git a/src/webrtc-signaling.ts b/src/webrtc-signaling.ts new file mode 100644 index 0000000..310b513 --- /dev/null +++ b/src/webrtc-signaling.ts @@ -0,0 +1,420 @@ +// src/webrtc.ts + +import { MessageEvent, WebSocket, WebSocketServer } from "ws"; +import * as jose from "jose"; +import { prisma } from "./db"; +import { IncomingMessage } from "http"; +import { Socket } from "node:net"; +import { Device } from "@prisma/client"; +import { Server, ServerResponse } from "node:http"; +import { cookieSessionMiddleware } from "."; + +// Maintain the shared state +export const activeConnections: Map = + new Map(); // [deviceWs, ip, version] +export const inFlight: Set = new Set(); + +function toICEServers(str: string) { + return str.split(",").filter(url => url.startsWith("stun:")); +} + +export const iceServers = toICEServers( + process.env.ICE_SERVERS || + "stun.cloudflare.com:3478,stun:stun.l.google.com:19302,stun:stun1.l.google.com:5349", +); + +// Helper function to update device last seen timestamp +async function updateDeviceLastSeen(id: string) { + const device = await prisma.device.findUnique({ where: { id } }); + if (device) { + return prisma.device.update({ where: { id }, data: { lastSeen: new Date() } }); + } +} + +const wssDevice = new WebSocketServer({ noServer: true }); +const wssClient = new WebSocketServer({ noServer: true }); + +// WebSocket router - routes WebSocket connections based on URL path +export function registerWebSocketRouter( + server: Server, +) { + server.on("upgrade", async (req: IncomingMessage, socket: Socket, head: Buffer) => { + const url = new URL(req.url || "", "http://localhost"); // We don't care about the hostname, we're just using the path to route + const path = url.pathname; + + // Route to appropriate handler based on path + // This path should be something like /webrtc/signaling/device, but due to legacy reasons we have to use `/` for device ws regitstrations + if (path === "/") { + await handleDeviceSocketRequest(req, socket, head); + } else if (path === "/webrtc/signaling/client") { + await handleClientSocketRequest(req, socket, head); + } else { + console.log(`[Webrtc] Unrecognized path: ${path}`); + return socket.destroy(); + } + }); +} + +// ========================================================================== +// Device WebSocket handlers +// ========================================================================== + +// Handle device WebSocket connection requests +async function handleDeviceSocketRequest( + req: IncomingMessage, + socket: Socket, + head: Buffer, +) { + try { + // Authenticate device + const device = await authenticateDeviceRequest(req); + if (!device) { + return socket.destroy(); + } + + // Inflight means that the device has connected, a client has connected to that device via HTTP, and they're now doing the signaling dance + if (inFlight.has(device.id)) { + console.log( + `[Device] Device ${device.id} already has an inflight client connection.`, + ); + return socket.destroy(); + } + + // Handle existing connections for this device + if (activeConnections.has(device.id)) { + console.log( + `[Device] Device ${device.id} already connected. Terminating existing connection.`, + ); + + const [existingDeviceWs] = activeConnections.get(device.id)!; + await new Promise(resolve => { + console.log("[Device] Waiting for existing connection to close..."); + existingDeviceWs.on("close", () => { + activeConnections.delete(device.id); + console.log("[Device] Existing connection closed."); + + // Now we continue with the new connection + resolve(true); + }); + + existingDeviceWs.terminate(); + }); + } + + // Complete the WebSocket upgrade + wssDevice.handleUpgrade(req, socket, head, ws => { + setupDeviceWebSocket(ws, device, req); + }); + } catch (error) { + console.error("Error handling device socket request:", error); + socket.destroy(); + } +} + +// Authenticate the device connection +async function authenticateDeviceRequest(req: IncomingMessage) { + const authHeader = req.headers["authorization"]; + const secretToken = authHeader?.split(" ")?.[1]; + + if (!secretToken) { + console.log("[Device] No authorization header provided."); + return null; + } + + try { + const device = await prisma.device.findFirst({ where: { secretToken } }); + if (!device) { + console.log("[Device] Invalid secret token provided."); + return null; + } + + const id = req.headers["x-device-id"] as string; + if (!id || id !== device.id) { + console.log("[Device] Invalid device ID or ID/token mismatch."); + return null; + } + + return device; + } catch (error) { + console.error("[Device] Error authenticating device:", error); + return null; + } +} + +// Setup the device WebSocket after authentication +function setupDeviceWebSocket(deviceWs: WebSocket, device: Device, req: IncomingMessage) { + const id = device.id; + const ip = + (process.env.REAL_IP_HEADER && req.headers[process.env.REAL_IP_HEADER]) || + req.socket.remoteAddress; + + const deviceVersion = req.headers["x-app-version"] as string | null; + + // Store the connection + activeConnections.set(id, [deviceWs, `${ip}`, deviceVersion || null]); + console.log( + `[Device] New connection for device ${id}, with version ${deviceVersion || "unknown"}`, + ); + + // Setup ping/pong for connection health checks + // @ts-ignore + deviceWs.isAlive = true; + deviceWs.on("pong", function heartbeat() { + // @ts-ignore + this.isAlive = true; + }); + + const checkAliveInterval = setInterval(function checkAlive() { + // @ts-ignore + if (deviceWs.isAlive === false) { + console.log(`[Device] ${id} is not alive. Terminating connection.`); + return deviceWs.terminate(); + } + // @ts-ignore + deviceWs.isAlive = false; + deviceWs.ping(); + // We check for aliveness every 10s + }, 10000); + + // Handle errors and connection close + deviceWs.on("error", async error => { + console.log(`[Device] Error for ${id}:`, error); + await cleanup(); + }); + + deviceWs.on("close", async (code, reason) => { + console.log( + `[Device] Connection closed for ${id} with code ${code} and reason ${reason}`, + ); + await cleanup(); + }); + + // Cleanup function + async function cleanup() { + console.log(`[Device] Cleanup for ${id}`); + activeConnections.delete(id); + clearInterval(checkAliveInterval); + await updateDeviceLastSeen(id); + } +} + +// ========================================================================== +// Client WebSocket handlers +// ========================================================================== + +// Handle client WebSocket connection requests +async function handleClientSocketRequest( + req: IncomingMessage, + socket: Socket, + head: Buffer, +) { + try { + // Apply session middleware to access authentication + cookieSessionMiddleware(req as any, {} as any, async () => { + try { + // Authenticate client and get device ID + const { deviceId, token } = await authenticateClientRequest(req as any); + if (!deviceId) { + return socket.destroy(); + } + + // Check if device is connected + if (!activeConnections.has(deviceId)) { + console.log(`[Client] Device ${deviceId} not connected.`); + socket.write("HTTP/1.1 404 Not Found\r\n\r\n"); + return socket.destroy(); + } + + // Complete the WebSocket upgrade + wssClient.handleUpgrade(req, socket, head, ws => { + setupClientWebSocket(ws, deviceId, token); + }); + } catch (error) { + console.error("Error in client WebSocket setup:", error); + socket.destroy(); + } + }); + } catch (error) { + console.error("Error handling client socket request:", error); + socket.destroy(); + } +} + +// Authenticate the client connection +async function authenticateClientRequest(req: Request & { session: any }) { + const session = req.session; + const token = session?.id_token; + + if (!token) { + console.log("[Client] No authentication token."); + return { deviceId: null }; + } + + try { + const { sub } = jose.decodeJwt(token); + const url = new URL(req.url || "", "http://localhost"); + const deviceId = url.searchParams.get("id"); + + if (!deviceId) { + console.log("[Client] No device ID provided."); + return { deviceId: null }; + } + + // Check if device exists and user has access + const device = await prisma.device.findUnique({ + where: { id: deviceId, user: { googleId: sub } }, + select: { id: true }, + }); + + if (!device) { + console.log("[Client] Device not found or user doesn't have access."); + return { deviceId: null }; + } + + return { deviceId, token }; + } catch (error) { + console.error("[Client] Authentication error:", error); + return { deviceId: null }; + } +} + +// Setup the client WebSocket after authentication +function setupClientWebSocket(clientWs: WebSocket, deviceId: string, token: string) { + console.log(`[Client] New connection for device ${deviceId}`); + + // Get device WebSocket + const deviceConn = activeConnections.get(deviceId); + if (!deviceConn) { + console.log(`[Client] No device connection for ${deviceId}`); + return clientWs.close(); + } + + const [deviceWs, ip, version] = deviceConn; + + // If there's an active connection with this device, prevent a new one + if (inFlight.has(deviceId)) { + console.log(`[Client] Device ${deviceId} already has an active client connection.`); + return clientWs.close(); + } + + console.log( + "[Client] Sending client device-metadata, version:", + version, + " - ", + deviceId, + ); + + clientWs.send( + JSON.stringify({ + type: "device-metadata", + data: { deviceVersion: version }, + }), + ); + + // Handle message forwarding from client to device + clientWs.on("message", data => { + // Handle ping/pong + if (data.toString() === "ping") return clientWs.send("pong"); + + try { + const msg = JSON.parse(data.toString()); + + switch (msg.type) { + case "offer": + console.log(`[Client] Sending offer to device ${deviceId}`); + deviceWs.send( + JSON.stringify({ + type: "offer", + data: { + sd: msg.data.sd, + ip, + iceServers, + OidcGoogle: token, + }, + }), + ); + break; + + case "new-ice-candidate": + console.log(`[Client] Sending ICE candidate to device ${deviceId}`); + deviceWs.send( + JSON.stringify({ + type: "new-ice-candidate", + data: msg.data, + }), + ); + break; + } + } catch (error) { + console.error(`[Client] Error processing message for ${deviceId}:`, error); + } + }); + + // Handle message forwarding from device to client + const deviceMessageHandler = (event: MessageEvent) => { + try { + const msg = JSON.parse(event.data as string); + + switch (msg.type) { + case "answer": + console.log(`[Device] Sending answer to client for ${deviceId}`); + clientWs.send(JSON.stringify({ type: "answer", data: msg.data })); + break; + + case "new-ice-candidate": + console.log(`[Device] Sending ICE candidate to client for ${deviceId}`); + clientWs.send(JSON.stringify({ type: "new-ice-candidate", data: msg.data })); + break; + } + } catch (error) { + console.error(`[Device] Error processing message for ${deviceId}:`, error); + } + }; + + // Store original handlers so we can restore them + const originalHandlers = { + onmessage: deviceWs.onmessage, + onerror: deviceWs.onerror, + onclose: deviceWs.onclose, + }; + + // Set up device -> client message handling + deviceWs.onmessage = deviceMessageHandler; + + // Handle device errors and disconnections + deviceWs.onerror = () => { + console.log(`[Device] Error, closing client connection for ${deviceId}`); + cleanup(); + clientWs.close(); + }; + + deviceWs.onclose = () => { + console.log(`[Device] Closed, terminating client connection for ${deviceId}`); + cleanup(); + clientWs.terminate(); + }; + + // Handle client disconnection + clientWs.on("close", () => { + console.log(`[Client] Connection closed for ${deviceId}`); + cleanup(); + }); + + // Cleanup function + function cleanup() { + // Restore original device handlers + deviceWs.onmessage = originalHandlers.onmessage; + deviceWs.onerror = originalHandlers.onerror; + deviceWs.onclose = originalHandlers.onclose; + + // Remove from in-flight set + inFlight.delete(deviceId); + } +} + +// Export a single initialization function +export function initializeWebRTCSignaling( + server: Server, +) { + registerWebSocketRouter(server); +} diff --git a/src/webrtc.ts b/src/webrtc.ts index 8fa0ef6..fca86bc 100644 --- a/src/webrtc.ts +++ b/src/webrtc.ts @@ -3,22 +3,7 @@ import express from "express"; import * as jose from "jose"; import { prisma } from "./db"; import { NotFoundError, UnprocessableEntityError } from "./errors"; -import { IncomingMessage } from "http"; -import { Socket } from "node:net"; -import { Device } from "@prisma/client"; - -export const activeConnections: Map = new Map(); -export const inFlight: Set = new Set(); - -function toICEServers(str: string) { - return str.split(",").filter( - (url) => url.startsWith("stun:") - ); -} - -export const iceServers = toICEServers( - process.env.ICE_SERVERS || "stun.cloudflare.com:3478,stun:stun.l.google.com:19302,stun:stun1.l.google.com:5349" -); +import { activeConnections, iceServers, inFlight } from "./webrtc-signaling"; export const CreateSession = async (req: express.Request, res: express.Response) => { const idToken = req.session?.id_token; @@ -54,11 +39,10 @@ export const CreateSession = async (req: express.Request, res: express.Response) // extract the websocket and ip from the tuple const [ws, ip] = wsTuple; - let wsRes: ((value: unknown) => void) | null = null, - wsRej: ((value: unknown) => void) | null = null; - let timeout: NodeJS.Timeout | undefined; + let httpClose: (() => void) | null = null; + try { inFlight.add(id); const resp: any = await new Promise((res, rej) => { @@ -66,43 +50,50 @@ export const CreateSession = async (req: express.Request, res: express.Response) rej(new Error("Timeout waiting for response from ws")); }, 15000); - // Hoist the res and rej functions to be used in the finally block for cleanup - wsRes = res; - wsRej = rej; + ws.onerror = rej; + ws.onclose = rej; + ws.onmessage = res; - ws.addEventListener("message", wsRes); - ws.addEventListener("error", wsRej); - ws.addEventListener("close", wsRej); + httpClose = () => { + rej(new Error("HTTP client closed the connection")); + }; // If the HTTP client closes the connection before the websocket response is received, reject the promise - req.socket.on("close", wsRej); + req.socket.on("close", httpClose); - ws.send(JSON.stringify({ - sd, - ip, - iceServers, - OidcGoogle: idToken - })); + ws.send( + JSON.stringify({ + sd, + ip, + iceServers, + OidcGoogle: idToken, + }), + ); }); + console.log("[CreateSession] got response from device", id); return res.json(JSON.parse(resp.data)); } catch (e) { - console.error(`Error sending data to kvm with ${id}`, e); - - // If there was an error, remove the socket from the map - ws.close(); // Most likely there is no-one on the other end to close the connection - activeConnections.delete(id); + console.log(`Error sending data to kvm with ${id}`, e); return res .status(500) .json({ error: "There was an error sending and receiving data to the KVM" }); } finally { if (timeout) clearTimeout(timeout); + console.log("Removing in flight", id); inFlight.delete(id); - if (wsRes && wsRej) { - ws.removeEventListener("message", wsRes); - ws.removeEventListener("error", wsRej); - ws.removeEventListener("close", wsRej); + + if (httpClose) { + console.log("Removing http close listener", id); + req.socket.off("close", httpClose); + } + + if (ws) { + console.log("Removing ws listeners", id); + ws.onerror = null; + ws.onclose = null; + ws.onmessage = null; } } }; @@ -153,118 +144,3 @@ export const CreateTurnActivity = async (req: express.Request, res: express.Resp return res.json({ success: true }); }; - -async function updateDeviceLastSeen(id: string) { - const device = await prisma.device.findUnique({ where: { id } }); - if (device) { - return prisma.device.update({ where: { id }, data: { lastSeen: new Date() } }); - } -} - -export const registerWebsocketServer = (server: any) => { - const wss = new WebSocketServer({ noServer: true }); - - server.on("upgrade", async (req: IncomingMessage, socket: Socket, head: Buffer) => { - const authHeader = req.headers["authorization"]; - const secretToken = authHeader?.split(" ")?.[1]; - if (!secretToken) { - console.log("No authorization header provided. Closing socket."); - return socket.destroy(); - } - - let device: Device | null = null; - try { - device = await prisma.device.findFirst({ where: { secretToken } }); - } catch (error) { - console.log("There was an error validating the secret token", error); - return socket.destroy(); - } - - if (!device) { - console.log("Invalid secret token provided. Closing socket."); - return socket.destroy(); - } - - if (activeConnections.has(device.id)) { - console.log( - "Device already in active connection list. Terminating & deleting existing websocket.", - ); - activeConnections.get(device.id)?.[0]?.terminate(); - activeConnections.delete(device.id); - } - - wss.handleUpgrade(req, socket, head, function done(ws) { - wss.emit("connection", ws, req); - }); - }); - - wss.on("connection", async function connection(ws, req) { - const authHeader = req.headers["authorization"]; - const secretToken = authHeader?.split(" ")?.[1]; - - let device: Device | null = null; - try { - device = await prisma.device.findFirst({ where: { secretToken } }); - } catch (error) { - ws.send("There was an error validating the secret token. Closing ws connection."); - console.log("There was an error validating the secret token", error); - return ws.close(); - } - - if (!device) { - ws.send("Invalid secret token provided. Closing ws connection."); - console.log("Invalid secret token provided. Closing ws connection."); - return ws.close(); - } - - const id = req.headers["x-device-id"] as string | undefined; - const hasId = !!id; - - // Ensure id is provided - if (!hasId) { - ws.send("No id provided. Closing ws connection."); - console.log("No id provided. Closing ws connection."); - return ws.close(); - } - - if (!id) { - ws.send("Invalid id provided. Closing ws connection."); - console.log("Invalid id provided. Closing ws connection."); - return ws.close(); - } - - if (id !== device.id) { - ws.send("Id and token mismatch. Closing ws connection."); - console.log("Id and token mismatch. Closing ws connection."); - return ws.close(); - } - - // Ensure id is not inflight - if (inFlight.has(id)) { - ws.send(`ID, ${id} is in flight. Please try again.`); - console.log(`ID, ${id} is in flight. Please try again.`); - return ws.close(); - } - - const ip = ( - process.env.REAL_IP_HEADER && req.headers[process.env.REAL_IP_HEADER] - ) || req.socket.remoteAddress; - - activeConnections.set(id, [ws, `${ip}`]); - console.log("New socket for id", id); - - ws.on("error", async () => { - if (!id) return; - console.log("WS Error - Remove socket ", id); - activeConnections.delete(id); - await updateDeviceLastSeen(id); - }); - - ws.on("close", async () => { - if (!id) return; - console.log("WS Close - Remove socket ", id); - activeConnections.delete(id); - await updateDeviceLastSeen(id); - }); - }); -};