diff --git a/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift b/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift index 20163633..322bce1a 100644 --- a/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift +++ b/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift @@ -452,7 +452,7 @@ internal struct LambdaHTTPServer { await self.responsePool.push( LocalServerResponse( id: requestId, - status: .ok, + status: .accepted, // the local server has no mecanism to collect headers set by the lambda function headers: HTTPHeaders(), body: body, diff --git a/Sources/AWSLambdaRuntime/Lambda.swift b/Sources/AWSLambdaRuntime/Lambda.swift index 706fe567..f6223cb5 100644 --- a/Sources/AWSLambdaRuntime/Lambda.swift +++ b/Sources/AWSLambdaRuntime/Lambda.swift @@ -41,6 +41,8 @@ public enum Lambda { var logger = logger do { while !Task.isCancelled { + + logger.trace("Waiting for next invocation") let (invocation, writer) = try await runtimeClient.nextInvocation() logger[metadataKey: "aws-request-id"] = "\(invocation.metadata.requestID)" @@ -76,14 +78,18 @@ public enum Lambda { logger: logger ) ) + logger.trace("Handler finished processing invocation") } catch { + logger.trace("Handler failed processing invocation", metadata: ["Handler error": "\(error)"]) try await writer.reportError(error) continue } + logger.handler.metadata.removeValue(forKey: "aws-request-id") } } catch is CancellationError { // don't allow cancellation error to propagate further } + } /// The default EventLoop the Lambda is scheduled on. diff --git a/Sources/AWSLambdaRuntime/LambdaRuntime.swift b/Sources/AWSLambdaRuntime/LambdaRuntime.swift index 5f66df6f..a639ac31 100644 --- a/Sources/AWSLambdaRuntime/LambdaRuntime.swift +++ b/Sources/AWSLambdaRuntime/LambdaRuntime.swift @@ -94,16 +94,29 @@ public final class LambdaRuntime: Sendable where Handler: StreamingLamb let ip = String(ipAndPort[0]) guard let port = Int(ipAndPort[1]) else { throw LambdaRuntimeError(code: .invalidPort) } - try await LambdaRuntimeClient.withRuntimeClient( - configuration: .init(ip: ip, port: port), - eventLoop: self.eventLoop, - logger: self.logger - ) { runtimeClient in - try await Lambda.runLoop( - runtimeClient: runtimeClient, - handler: handler, + do { + try await LambdaRuntimeClient.withRuntimeClient( + configuration: .init(ip: ip, port: port), + eventLoop: self.eventLoop, logger: self.logger - ) + ) { runtimeClient in + try await Lambda.runLoop( + runtimeClient: runtimeClient, + handler: handler, + logger: self.logger + ) + } + } catch { + // catch top level errors that have not been handled until now + // this avoids the runtime to crash and generate a backtrace + self.logger.error("LambdaRuntime.run() failed with error", metadata: ["error": "\(error)"]) + if let error = error as? LambdaRuntimeError, + error.code != .connectionToControlPlaneLost + { + // if the error is a LambdaRuntimeError but not a connection error, + // we rethrow it to preserve existing behaviour + throw error + } } } else { diff --git a/Sources/AWSLambdaRuntime/LambdaRuntimeClient.swift b/Sources/AWSLambdaRuntime/LambdaRuntimeClient.swift index a1afb464..1d90e2c9 100644 --- a/Sources/AWSLambdaRuntime/LambdaRuntimeClient.swift +++ b/Sources/AWSLambdaRuntime/LambdaRuntimeClient.swift @@ -97,6 +97,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol { private let configuration: Configuration private var connectionState: ConnectionState = .disconnected + private var lambdaState: LambdaState = .idle(previousRequestID: nil) private var closingState: ClosingState = .notClosing @@ -118,10 +119,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol { } catch { result = .failure(error) } - await runtime.close() - - //try? await runtime.close() return try result.get() } @@ -163,12 +161,16 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol { @usableFromInline func nextInvocation() async throws -> (Invocation, Writer) { - try await withTaskCancellationHandler { + + try Task.checkCancellation() + + return try await withTaskCancellationHandler { switch self.lambdaState { case .idle: self.lambdaState = .waitingForNextInvocation let handler = try await self.makeOrGetConnection() let invocation = try await handler.nextInvocation() + guard case .waitingForNextInvocation = self.lambdaState else { fatalError("Invalid state: \(self.lambdaState)") } @@ -283,7 +285,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol { case (.connecting(let array), .notClosing): self.connectionState = .disconnected for continuation in array { - continuation.resume(throwing: LambdaRuntimeError(code: .lostConnectionToControlPlane)) + continuation.resume(throwing: LambdaRuntimeError(code: .connectionToControlPlaneLost)) } case (.connecting(let array), .closing(let continuation)): @@ -363,7 +365,9 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol { ) channel.closeFuture.whenComplete { result in self.assumeIsolated { runtimeClient in + // close the channel runtimeClient.channelClosed(channel) + runtimeClient.connectionState = .disconnected } } @@ -382,6 +386,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol { return handler } } catch { + switch self.connectionState { case .disconnected, .connected: fatalError("Unexpected state: \(self.connectionState)") @@ -430,7 +435,6 @@ extension LambdaRuntimeClient: LambdaChannelHandlerDelegate { } isolated.connectionState = .disconnected - } } } @@ -884,8 +888,16 @@ extension LambdaChannelHandler: ChannelInboundHandler { func channelInactive(context: ChannelHandlerContext) { // fail any pending responses with last error or assume peer disconnected switch self.state { - case .connected(_, .waitingForNextInvocation(let continuation)): - continuation.resume(throwing: self.lastError ?? ChannelError.ioOnClosedChannel) + case .connected(_, let lambdaState): + switch lambdaState { + case .waitingForNextInvocation(let continuation): + continuation.resume(throwing: self.lastError ?? ChannelError.ioOnClosedChannel) + case .sentResponse(let continuation): + continuation.resume(throwing: self.lastError ?? ChannelError.ioOnClosedChannel) + case .idle, .sendingResponse, .waitingForResponse: + break + } + self.state = .disconnected default: break } diff --git a/Sources/AWSLambdaRuntime/LambdaRuntimeError.swift b/Sources/AWSLambdaRuntime/LambdaRuntimeError.swift index a9c0cbca..bc4865db 100644 --- a/Sources/AWSLambdaRuntime/LambdaRuntimeError.swift +++ b/Sources/AWSLambdaRuntime/LambdaRuntimeError.swift @@ -25,7 +25,6 @@ package struct LambdaRuntimeError: Error { case writeAfterFinishHasBeenSent case finishAfterFinishHasBeenSent - case lostConnectionToControlPlane case unexpectedStatusCodeForRequest case nextInvocationMissingHeaderRequestID diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift index ada1d765..7405b5bd 100644 --- a/Sources/MockServer/MockHTTPServer.swift +++ b/Sources/MockServer/MockHTTPServer.swift @@ -216,7 +216,7 @@ struct HttpServer { } else if requestHead.uri.hasSuffix("/response") { responseStatus = .accepted } else if requestHead.uri.hasSuffix("/error") { - responseStatus = .ok + responseStatus = .accepted } else { responseStatus = .notFound } diff --git a/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift b/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift index 35e89c3a..fe85f973 100644 --- a/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift +++ b/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift @@ -42,10 +42,10 @@ struct LambdaRuntimeClientTests { .success((self.requestId, self.event)) } - func processResponse(requestId: String, response: String?) -> Result { + func processResponse(requestId: String, response: String?) -> Result { #expect(self.requestId == requestId) #expect(self.event == response) - return .success(()) + return .success(nil) } func processError(requestId: String, error: ErrorResponse) -> Result { @@ -102,9 +102,9 @@ struct LambdaRuntimeClientTests { .success((self.requestId, self.event)) } - func processResponse(requestId: String, response: String?) -> Result { + func processResponse(requestId: String, response: String?) -> Result { #expect(self.requestId == requestId) - return .success(()) + return .success(nil) } mutating func captureHeaders(_ headers: HTTPHeaders) { @@ -197,10 +197,10 @@ struct LambdaRuntimeClientTests { .success((self.requestId, self.event)) } - func processResponse(requestId: String, response: String?) -> Result { + func processResponse(requestId: String, response: String?) -> Result { #expect(self.requestId == requestId) #expect(self.event == response) - return .success(()) + return .success(nil) } func processError(requestId: String, error: ErrorResponse) -> Result { @@ -238,4 +238,91 @@ struct LambdaRuntimeClientTests { } } } + + struct DisconnectAfterSendingResponseBehavior: LambdaServerBehavior { + func getInvocation() -> GetInvocationResult { + .success((UUID().uuidString, "hello")) + } + + func processResponse(requestId: String, response: String?) -> Result { + // Return "delayed-disconnect" to trigger server closing the connection + // after having accepted the first response + .success("delayed-disconnect") + } + + func processError(requestId: String, error: ErrorResponse) -> Result { + Issue.record("should not report error") + return .failure(.internalServerError) + } + + func processInitError(error: ErrorResponse) -> Result { + Issue.record("should not report init error") + return .failure(.internalServerError) + } + } + + struct DisconnectBehavior: LambdaServerBehavior { + func getInvocation() -> GetInvocationResult { + .success(("disconnect", "0")) + } + + func processResponse(requestId: String, response: String?) -> Result { + .success(nil) + } + + func processError(requestId: String, error: ErrorResponse) -> Result { + Issue.record("should not report error") + return .failure(.internalServerError) + } + + func processInitError(error: ErrorResponse) -> Result { + Issue.record("should not report init error") + return .failure(.internalServerError) + } + } + + @Test( + "Server closing the connection when waiting for next invocation throws an error", + arguments: [DisconnectBehavior(), DisconnectAfterSendingResponseBehavior()] as [any LambdaServerBehavior] + ) + func testChannelCloseFutureWithWaitingForNextInvocation(behavior: LambdaServerBehavior) async throws { + try await withMockServer(behaviour: behavior) { port in + let configuration = LambdaRuntimeClient.Configuration(ip: "127.0.0.1", port: port) + + try await LambdaRuntimeClient.withRuntimeClient( + configuration: configuration, + eventLoop: NIOSingletons.posixEventLoopGroup.next(), + logger: self.logger + ) { runtimeClient in + do { + + // simulate traffic until the server reports it has closed the connection + // or a timeout, whichever comes first + // result is ignored here, either there is a connection error or a timeout + let _ = try await timeout(deadline: .seconds(1)) { + while true { + let (_, writer) = try await runtimeClient.nextInvocation() + try await writer.writeAndFinish(ByteBuffer(string: "hello")) + } + } + // result is ignored here, we should never reach this line + Issue.record("Connection reset test did not throw an error") + + } catch is CancellationError { + Issue.record("Runtime client did not send connection closed error") + } catch let error as LambdaRuntimeError { + logger.trace("LambdaRuntimeError - expected") + #expect(error.code == .connectionToControlPlaneLost) + } catch let error as ChannelError { + logger.trace("ChannelError - expected") + #expect(error == .ioOnClosedChannel) + } catch let error as IOError { + logger.trace("IOError - expected") + #expect(error.errnoCode == ECONNRESET || error.errnoCode == EPIPE) + } catch { + Issue.record("Unexpected error type: \(error)") + } + } + } + } } diff --git a/Tests/AWSLambdaRuntimeTests/MockLambdaServer.swift b/Tests/AWSLambdaRuntimeTests/MockLambdaServer.swift index 5d307ce2..d5ad8876 100644 --- a/Tests/AWSLambdaRuntimeTests/MockLambdaServer.swift +++ b/Tests/AWSLambdaRuntimeTests/MockLambdaServer.swift @@ -160,6 +160,7 @@ final class HTTPHandler: ChannelInboundHandler { var responseStatus: HTTPResponseStatus var responseBody: String? var responseHeaders: [(String, String)]? + var disconnectAfterSend = false // Handle post-init-error first to avoid matching the less specific post-error suffix. if request.head.uri.hasSuffix(Consts.postInitErrorURL) { @@ -202,8 +203,11 @@ final class HTTPHandler: ChannelInboundHandler { behavior.captureHeaders(request.head.headers) switch behavior.processResponse(requestId: String(requestId), response: requestBody) { - case .success: + case .success(let next): responseStatus = .accepted + if next == "delayed-disconnect" { + disconnectAfterSend = true + } case .failure(let error): responseStatus = .init(statusCode: error.rawValue) } @@ -223,14 +227,21 @@ final class HTTPHandler: ChannelInboundHandler { } else { responseStatus = .notFound } - self.writeResponse(context: context, status: responseStatus, headers: responseHeaders, body: responseBody) + self.writeResponse( + context: context, + status: responseStatus, + headers: responseHeaders, + body: responseBody, + closeConnection: disconnectAfterSend + ) } func writeResponse( context: ChannelHandlerContext, status: HTTPResponseStatus, headers: [(String, String)]? = nil, - body: String? = nil + body: String? = nil, + closeConnection: Bool = false ) { var headers = HTTPHeaders(headers ?? []) headers.add(name: "Content-Length", value: "\(body?.utf8.count ?? 0)") @@ -253,14 +264,19 @@ final class HTTPHandler: ChannelInboundHandler { } let loopBoundContext = NIOLoopBound(context, eventLoop: context.eventLoop) - let keepAlive = self.keepAlive context.writeAndFlush(wrapOutboundOut(.end(nil))).whenComplete { result in + let context = loopBoundContext.value + if closeConnection { + context.close(promise: nil) + return + } + if case .failure(let error) = result { logger.error("write error \(error)") } + if !keepAlive { - let context = loopBoundContext.value context.close().whenFailure { error in logger.error("close error \(error)") } @@ -271,7 +287,7 @@ final class HTTPHandler: ChannelInboundHandler { protocol LambdaServerBehavior: Sendable { func getInvocation() -> GetInvocationResult - func processResponse(requestId: String, response: String?) -> Result + func processResponse(requestId: String, response: String?) -> Result func processError(requestId: String, error: ErrorResponse) -> Result func processInitError(error: ErrorResponse) -> Result diff --git a/Tests/AWSLambdaRuntimeTests/Timeout.swift b/Tests/AWSLambdaRuntimeTests/Timeout.swift new file mode 100644 index 00000000..6a8dc5dc --- /dev/null +++ b/Tests/AWSLambdaRuntimeTests/Timeout.swift @@ -0,0 +1,67 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftAWSLambdaRuntime open source project +// +// Copyright (c) 2025 Apple Inc. and the SwiftAWSLambdaRuntime project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +// as suggested by https://github.com/vapor/postgres-nio/issues/489#issuecomment-2186509773 +func timeout( + deadline: Duration, + _ closure: @escaping @Sendable () async throws -> Success +) async throws -> Success { + + let clock = ContinuousClock() + + let result = await withTaskGroup(of: TimeoutResult.self, returning: Result.self) { + taskGroup in + taskGroup.addTask { + do { + try await clock.sleep(until: clock.now + deadline, tolerance: nil) + return .deadlineHit + } catch { + return .deadlineCancelled + } + } + + taskGroup.addTask { + do { + let success = try await closure() + return .workFinished(.success(success)) + } catch let error { + return .workFinished(.failure(error)) + } + } + + var r: Swift.Result? + while let taskResult = await taskGroup.next() { + switch taskResult { + case .deadlineCancelled: + continue // loop + + case .deadlineHit: + taskGroup.cancelAll() + + case .workFinished(let result): + taskGroup.cancelAll() + r = result + } + } + return r! + } + + return try result.get() +} + +enum TimeoutResult { + case deadlineHit + case deadlineCancelled + case workFinished(Result) +}