diff --git a/src/bun.js/api/server/NodeHTTPResponse.zig b/src/bun.js/api/server/NodeHTTPResponse.zig index 67bc27c3a05e99..cbca6f7bcf5757 100644 --- a/src/bun.js/api/server/NodeHTTPResponse.zig +++ b/src/bun.js/api/server/NodeHTTPResponse.zig @@ -407,6 +407,8 @@ extern "C" fn NodeHTTPServer__writeHead_http( statusMessage: [*]const u8, statusMessageLength: usize, headersObjectValue: jsc.JSValue, + shouldKeepAlive: bool, + keepAliveTimeout: u32, response: *anyopaque, ) void; @@ -415,11 +417,13 @@ extern "C" fn NodeHTTPServer__writeHead_https( statusMessage: [*]const u8, statusMessageLength: usize, headersObjectValue: jsc.JSValue, + shouldKeepAlive: bool, + keepAliveTimeout: u32, response: *anyopaque, ) void; pub fn writeHead(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { - const arguments = callframe.argumentsUndef(3).slice(); + const arguments = callframe.argumentsUndef(5).slice(); if (this.isRequestedCompletedOrEnded()) { return globalObject.ERR(.STREAM_ALREADY_FINISHED, "Stream is already ended", .{}).throw(); @@ -436,6 +440,8 @@ pub fn writeHead(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject, cal const status_code_value: JSValue = if (arguments.len > 0) arguments[0] else .js_undefined; const status_message_value: JSValue = if (arguments.len > 1 and arguments[1] != .null) arguments[1] else .js_undefined; const headers_object_value: JSValue = if (arguments.len > 2 and arguments[2] != .null) arguments[2] else .js_undefined; + const should_keep_alive_value: JSValue = if (arguments.len > 3) arguments[3] else .js_undefined; + const keep_alive_timeout_value: JSValue = if (arguments.len > 4) arguments[4] else .js_undefined; const status_code: i32 = brk: { if (!status_code_value.isUndefined()) { @@ -461,6 +467,16 @@ pub fn writeHead(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject, cal return error.JSError; } + const should_keep_alive: bool = if (!should_keep_alive_value.isUndefined()) + should_keep_alive_value.toBoolean() + else + true; + + const keep_alive_timeout: u32 = if (!keep_alive_timeout_value.isUndefined()) + @intCast(keep_alive_timeout_value.toU32()) + else + 5000; + if (state.isHttpStatusCalled()) { return globalObject.ERR(.HTTP_HEADERS_SENT, "Stream already started", .{}).throw(); } @@ -468,7 +484,7 @@ pub fn writeHead(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject, cal do_it: { if (status_message_slice.len == 0) { if (HTTPStatusText.get(@intCast(status_code))) |status_message| { - writeHeadInternal(this.raw_response.?, globalObject, status_message, headers_object_value); + writeHeadInternal(this.raw_response.?, globalObject, status_message, headers_object_value, should_keep_alive, keep_alive_timeout); break :do_it; } } @@ -476,18 +492,18 @@ pub fn writeHead(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject, cal const message = if (status_message_slice.len > 0) status_message_slice.slice() else "HM"; const status_message = bun.handleOom(std.fmt.allocPrint(allocator, "{d} {s}", .{ status_code, message })); defer allocator.free(status_message); - writeHeadInternal(this.raw_response.?, globalObject, status_message, headers_object_value); + writeHeadInternal(this.raw_response.?, globalObject, status_message, headers_object_value, should_keep_alive, keep_alive_timeout); break :do_it; } return .js_undefined; } -fn writeHeadInternal(response: uws.AnyResponse, globalObject: *jsc.JSGlobalObject, status_message: []const u8, headers: jsc.JSValue) void { +fn writeHeadInternal(response: uws.AnyResponse, globalObject: *jsc.JSGlobalObject, status_message: []const u8, headers: jsc.JSValue, should_keep_alive: bool, keep_alive_timeout: u32) void { log("writeHeadInternal({s})", .{status_message}); switch (response) { - .TCP => NodeHTTPServer__writeHead_http(globalObject, status_message.ptr, status_message.len, headers, @ptrCast(response.TCP)), - .SSL => NodeHTTPServer__writeHead_https(globalObject, status_message.ptr, status_message.len, headers, @ptrCast(response.SSL)), + .TCP => NodeHTTPServer__writeHead_http(globalObject, status_message.ptr, status_message.len, headers, should_keep_alive, keep_alive_timeout, @ptrCast(response.TCP)), + .SSL => NodeHTTPServer__writeHead_https(globalObject, status_message.ptr, status_message.len, headers, should_keep_alive, keep_alive_timeout, @ptrCast(response.SSL)), } } diff --git a/src/bun.js/bindings/NodeHTTP.cpp b/src/bun.js/bindings/NodeHTTP.cpp index 116183e7f72dbe..fba9f8f6b298cf 100644 --- a/src/bun.js/bindings/NodeHTTP.cpp +++ b/src/bun.js/bindings/NodeHTTP.cpp @@ -510,7 +510,7 @@ static void writeResponseHeader(uWS::HttpResponse* res, const WTF::String } template -static void writeFetchHeadersToUWSResponse(WebCore::FetchHeaders& headers, uWS::HttpResponse* res) +static void writeFetchHeadersToUWSResponse(WebCore::FetchHeaders& headers, uWS::HttpResponse* res, bool shouldKeepAlive = true, uint32_t keepAliveTimeout = 5000) { auto& internalHeaders = headers.internalHeaders(); @@ -526,12 +526,18 @@ static void writeFetchHeadersToUWSResponse(WebCore::FetchHeaders& headers, uWS:: } auto* data = res->getHttpResponseData(); + bool hasConnectionHeader = false; for (const auto& header : internalHeaders.commonHeaders()) { const auto& name = WebCore::httpHeaderNameString(header.key); const auto& value = header.value; + // Check if Connection header is already set + if (header.key == WebCore::HTTPHeaderName::Connection) { + hasConnectionHeader = true; + } + // We have to tell uWS not to automatically insert a TransferEncoding or Date header. // Otherwise, you get this when using Fastify; // @@ -568,8 +574,29 @@ static void writeFetchHeadersToUWSResponse(WebCore::FetchHeaders& headers, uWS:: const auto& name = header.key; const auto& value = header.value; + // Check for Connection header in uncommon headers (case-insensitive) + if (name.length() == 10 && WTF::equalIgnoringASCIICase(name, "connection")) { + hasConnectionHeader = true; + } + writeResponseHeader(res, name, value); } + + // Add Connection: keep-alive and Keep-Alive headers if not already set + // This matches Node.js behavior + if (!hasConnectionHeader) { + if (shouldKeepAlive) { + res->writeHeader(std::string_view("Connection", 10), std::string_view("keep-alive", 10)); + + // Format Keep-Alive header with timeout + char keepAliveValue[32]; + uint32_t timeoutSeconds = keepAliveTimeout / 1000; + int len = snprintf(keepAliveValue, sizeof(keepAliveValue), "timeout=%u", timeoutSeconds); + res->writeHeader(std::string_view("Keep-Alive", 10), std::string_view(keepAliveValue, len)); + } else { + res->writeHeader(std::string_view("Connection", 10), std::string_view("close", 5)); + } + } } template @@ -578,6 +605,8 @@ static void NodeHTTPServer__writeHead( const char* statusMessage, size_t statusMessageLength, JSValue headersObjectValue, + bool shouldKeepAlive, + uint32_t keepAliveTimeout, uWS::HttpResponse* response) { auto& vm = globalObject->vm(); @@ -589,9 +618,11 @@ static void NodeHTTPServer__writeHead( } response->writeStatus(std::string_view(statusMessage, statusMessageLength)); + bool hasConnectionHeader = false; + if (headersObject) { if (auto* fetchHeaders = jsDynamicCast(headersObject)) { - writeFetchHeadersToUWSResponse(fetchHeaders->wrapped(), response); + writeFetchHeadersToUWSResponse(fetchHeaders->wrapped(), response, shouldKeepAlive, keepAliveTimeout); return; } @@ -611,6 +642,12 @@ static void NodeHTTPServer__writeHead( } String key = entry.key(); + + // Check for Connection header (case-insensitive) + if (key.length() == 10 && WTF::equalIgnoringASCIICase(key, "connection")) { + hasConnectionHeader = true; + } + String value = headerValue.toWTFString(globalObject); RETURN_IF_EXCEPTION(scope, false); @@ -631,6 +668,12 @@ static void NodeHTTPServer__writeHead( } String key = propertyNames[i].string(); + + // Check for Connection header (case-insensitive) + if (key.length() == 10 && WTF::equalIgnoringASCIICase(key, "connection")) { + hasConnectionHeader = true; + } + String value = headerValue.toWTFString(globalObject); RETURN_IF_EXCEPTION(scope, void()); writeResponseHeader(response, key, value); @@ -638,6 +681,22 @@ static void NodeHTTPServer__writeHead( } } + // Add Connection: keep-alive and Keep-Alive headers if not already set + // This matches Node.js behavior + if (!hasConnectionHeader) { + if (shouldKeepAlive) { + response->writeHeader(std::string_view("Connection", 10), std::string_view("keep-alive", 10)); + + // Format Keep-Alive header with timeout + char keepAliveValue[32]; + uint32_t timeoutSeconds = keepAliveTimeout / 1000; + int len = snprintf(keepAliveValue, sizeof(keepAliveValue), "timeout=%u", timeoutSeconds); + response->writeHeader(std::string_view("Keep-Alive", 10), std::string_view(keepAliveValue, len)); + } else { + response->writeHeader(std::string_view("Connection", 10), std::string_view("close", 5)); + } + } + RELEASE_AND_RETURN(scope, void()); } @@ -646,9 +705,11 @@ extern "C" void NodeHTTPServer__writeHead_http( const char* statusMessage, size_t statusMessageLength, JSValue headersObjectValue, + bool shouldKeepAlive, + uint32_t keepAliveTimeout, uWS::HttpResponse* response) { - return NodeHTTPServer__writeHead(globalObject, statusMessage, statusMessageLength, headersObjectValue, response); + return NodeHTTPServer__writeHead(globalObject, statusMessage, statusMessageLength, headersObjectValue, shouldKeepAlive, keepAliveTimeout, response); } extern "C" void NodeHTTPServer__writeHead_https( @@ -656,9 +717,11 @@ extern "C" void NodeHTTPServer__writeHead_https( const char* statusMessage, size_t statusMessageLength, JSValue headersObjectValue, + bool shouldKeepAlive, + uint32_t keepAliveTimeout, uWS::HttpResponse* response) { - return NodeHTTPServer__writeHead(globalObject, statusMessage, statusMessageLength, headersObjectValue, response); + return NodeHTTPServer__writeHead(globalObject, statusMessage, statusMessageLength, headersObjectValue, shouldKeepAlive, keepAliveTimeout, response); } extern "C" EncodedJSValue NodeHTTPServer__onRequest_http( diff --git a/src/js/node/_http_server.ts b/src/js/node/_http_server.ts index 81d627e5f5aa50..115bbcd9392fcc 100644 --- a/src/js/node/_http_server.ts +++ b/src/js/node/_http_server.ts @@ -1349,7 +1349,9 @@ ServerResponse.prototype.end = function (chunk, encoding, callback) { } if (headerState !== NodeHTTPHeaderState.sent) { handle.cork(() => { - handle.writeHead(this.statusCode, this.statusMessage, this[headersSymbol]); + const shouldKeepAlive = this.shouldKeepAlive; + const keepAliveTimeout = this.socket?.server?.keepAliveTimeout ?? 5000; + handle.writeHead(this.statusCode, this.statusMessage, this[headersSymbol], shouldKeepAlive, keepAliveTimeout); // If handle.writeHead throws, we don't want headersSent to be set to true. // So we set it here. @@ -1459,7 +1461,9 @@ ServerResponse.prototype.write = function (chunk, encoding, callback) { if (this[headerStateSymbol] !== NodeHTTPHeaderState.sent) { handle.cork(() => { - handle.writeHead(this.statusCode, this.statusMessage, this[headersSymbol]); + const shouldKeepAlive = this.shouldKeepAlive; + const keepAliveTimeout = this.socket?.server?.keepAliveTimeout ?? 5000; + handle.writeHead(this.statusCode, this.statusMessage, this[headersSymbol], shouldKeepAlive, keepAliveTimeout); // If handle.writeHead throws, we don't want headersSent to be set to true. // So we set it here. @@ -1562,7 +1566,9 @@ ServerResponse.prototype._send = function (data, encoding, callback, _byteLength if (this[headerStateSymbol] !== NodeHTTPHeaderState.sent) { handle.cork(() => { - handle.writeHead(this.statusCode, this.statusMessage, this[headersSymbol]); + const shouldKeepAlive = this.shouldKeepAlive; + const keepAliveTimeout = this.socket?.server?.keepAliveTimeout ?? 5000; + handle.writeHead(this.statusCode, this.statusMessage, this[headersSymbol], shouldKeepAlive, keepAliveTimeout); this[headerStateSymbol] = NodeHTTPHeaderState.sent; handle.write(data, encoding, callback, strictContentLength(this)); }); @@ -1629,7 +1635,9 @@ ServerResponse.prototype.flushHeaders = function () { if (this[headerStateSymbol] === NodeHTTPHeaderState.assigned) { this[headerStateSymbol] = NodeHTTPHeaderState.sent; - handle.writeHead(this.statusCode, this.statusMessage, this[headersSymbol]); + const shouldKeepAlive = this.shouldKeepAlive; + const keepAliveTimeout = this.socket?.server?.keepAliveTimeout ?? 5000; + handle.writeHead(this.statusCode, this.statusMessage, this[headersSymbol], shouldKeepAlive, keepAliveTimeout); } handle.flushHeaders(); } diff --git a/test/js/node/http/node-http-connection-headers.test.ts b/test/js/node/http/node-http-connection-headers.test.ts new file mode 100644 index 00000000000000..aef00ce4ebdae6 --- /dev/null +++ b/test/js/node/http/node-http-connection-headers.test.ts @@ -0,0 +1,199 @@ +import { expect, test } from "bun:test"; +import { createServer } from "node:http"; + +test("should include Connection: keep-alive and Keep-Alive headers by default", async () => { + const server = createServer((req, res) => { + res.writeHead(200, { "Content-Type": "text/plain" }); + res.end("Hello World"); + }); + + await new Promise(resolve => server.listen(0, () => resolve())); + const port = (server.address() as any).port; + + try { + const response = await fetch(`http://localhost:${port}/`); + const text = await response.text(); + + expect(text).toBe("Hello World"); + expect(response.headers.get("connection")).toBe("keep-alive"); + const keepAlive = response.headers.get("keep-alive"); + expect(keepAlive).toMatch(/timeout=\d+/); + // Default keepAliveTimeout is 5000ms (5 seconds) + expect(keepAlive).toMatch(/timeout=5/); + } finally { + server.close(); + } +}); + +test("should respect user-set Connection: close header", async () => { + const server = createServer((req, res) => { + res.setHeader("Connection", "close"); + res.writeHead(200); + res.end("test"); + }); + + await new Promise(resolve => server.listen(0, () => resolve())); + const port = (server.address() as any).port; + + try { + const response = await fetch(`http://localhost:${port}/`); + const text = await response.text(); + + expect(text).toBe("test"); + expect(response.headers.get("connection")).toBe("close"); + expect(response.headers.get("keep-alive")).toBeNull(); + } finally { + server.close(); + } +}); + +test("should respect user-set Connection: keep-alive header", async () => { + const server = createServer((req, res) => { + res.setHeader("Connection", "keep-alive"); + res.setHeader("Keep-Alive", "timeout=30"); + res.writeHead(200); + res.end("test"); + }); + + await new Promise(resolve => server.listen(0, () => resolve())); + const port = (server.address() as any).port; + + try { + const response = await fetch(`http://localhost:${port}/`); + const text = await response.text(); + + expect(text).toBe("test"); + expect(response.headers.get("connection")).toBe("keep-alive"); + expect(response.headers.get("keep-alive")).toBe("timeout=30"); + } finally { + server.close(); + } +}); + +test("should use default keepAliveTimeout (5 seconds)", async () => { + const server = createServer((req, res) => { + res.writeHead(200); + res.end("test"); + }); + + await new Promise(resolve => server.listen(0, () => resolve())); + const port = (server.address() as any).port; + + try { + const response = await fetch(`http://localhost:${port}/`); + const text = await response.text(); + + expect(text).toBe("test"); + expect(response.headers.get("connection")).toBe("keep-alive"); + expect(response.headers.get("keep-alive")).toBe("timeout=5"); + } finally { + server.close(); + } +}); + +test("should use custom keepAliveTimeout when configured", async () => { + const server = createServer((req, res) => { + res.writeHead(200); + res.end("test"); + }); + + // Set custom keepAliveTimeout to 10 seconds + server.keepAliveTimeout = 10000; + + await new Promise(resolve => server.listen(0, () => resolve())); + const port = (server.address() as any).port; + + try { + const response = await fetch(`http://localhost:${port}/`); + const text = await response.text(); + + expect(text).toBe("test"); + expect(response.headers.get("connection")).toBe("keep-alive"); + expect(response.headers.get("keep-alive")).toBe("timeout=10"); + } finally { + server.close(); + } +}); + +test("should include Connection headers with POST requests", async () => { + const server = createServer((req, res) => { + let body = ""; + req.on("data", chunk => { + body += chunk; + }); + req.on("end", () => { + res.writeHead(200, { "Content-Type": "application/json" }); + res.end( + JSON.stringify({ + success: true, + receivedData: JSON.parse(body), + }), + ); + }); + }); + + await new Promise(resolve => server.listen(0, () => resolve())); + const port = (server.address() as any).port; + + try { + const response = await fetch(`http://localhost:${port}/`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ test: "data" }), + }); + + const data = await response.json(); + + expect(data.success).toBe(true); + expect(data.receivedData.test).toBe("data"); + expect(response.headers.get("connection")).toBe("keep-alive"); + expect(response.headers.get("keep-alive")).toMatch(/timeout=\d+/); + } finally { + server.close(); + } +}); + +test("should include Connection headers when using setHeader before writeHead", async () => { + const server = createServer((req, res) => { + res.setHeader("Content-Type", "text/plain"); + res.setHeader("X-Custom-Header", "value"); + res.writeHead(200); + res.end("test"); + }); + + await new Promise(resolve => server.listen(0, () => resolve())); + const port = (server.address() as any).port; + + try { + const response = await fetch(`http://localhost:${port}/`); + const text = await response.text(); + + expect(text).toBe("test"); + expect(response.headers.get("connection")).toBe("keep-alive"); + expect(response.headers.get("keep-alive")).toMatch(/timeout=\d+/); + expect(response.headers.get("x-custom-header")).toBe("value"); + } finally { + server.close(); + } +}); + +test("should include Connection headers when not calling writeHead explicitly", async () => { + const server = createServer((req, res) => { + // Not calling writeHead - it will be called implicitly + res.end("test"); + }); + + await new Promise(resolve => server.listen(0, () => resolve())); + const port = (server.address() as any).port; + + try { + const response = await fetch(`http://localhost:${port}/`); + const text = await response.text(); + + expect(text).toBe("test"); + expect(response.headers.get("connection")).toBe("keep-alive"); + expect(response.headers.get("keep-alive")).toMatch(/timeout=\d+/); + } finally { + server.close(); + } +});