diff --git a/cfn-resources/project/cmd/resource/model.go b/cfn-resources/project/cmd/resource/model.go index a71bd6823..39e1df842 100644 --- a/cfn-resources/project/cmd/resource/model.go +++ b/cfn-resources/project/cmd/resource/model.go @@ -13,6 +13,7 @@ type Model struct { ClusterCount *int `json:",omitempty"` ProjectSettings *ProjectSettings `json:",omitempty"` Profile *string `json:",omitempty"` + LambdaProxyArn *string `json:",omitempty"` ProjectTeams []ProjectTeam `json:",omitempty"` ProjectApiKeys []ProjectApiKey `json:",omitempty"` RegionUsageRestrictions *string `json:",omitempty"` diff --git a/cfn-resources/project/cmd/resource/resource.go b/cfn-resources/project/cmd/resource/resource.go index a4f9d14eb..2d26fa5db 100644 --- a/cfn-resources/project/cmd/resource/resource.go +++ b/cfn-resources/project/cmd/resource/resource.go @@ -20,13 +20,15 @@ import ( "fmt" "reflect" + "go.mongodb.org/atlas-sdk/v20231115014/admin" + "github.com/aws-cloudformation/cloudformation-cli-go-plugin/cfn/handler" "github.com/aws/aws-sdk-go/service/cloudformation" + "github.com/mongodb/mongodbatlas-cloudformation-resources/util" "github.com/mongodb/mongodbatlas-cloudformation-resources/util/constants" "github.com/mongodb/mongodbatlas-cloudformation-resources/util/progressevent" "github.com/mongodb/mongodbatlas-cloudformation-resources/util/validator" - "go.mongodb.org/atlas-sdk/v20231115014/admin" ) var CreateRequiredFields = []string{constants.OrgID, constants.Name} @@ -45,7 +47,8 @@ func initEnvWithLatestClient(req handler.Request, currentModel *Model, requiredF return nil, errEvent } - client, peErr := util.NewAtlasClient(&req, currentModel.Profile) + // client, peErr := util.NewAtlasClient(&req, currentModel.Profile) + client, peErr := util.NewAtlasClientWithLambdaProxySupport(&req, currentModel.Profile, currentModel.LambdaProxyArn) if peErr != nil { return nil, peErr } diff --git a/cfn-resources/project/docs/README.md b/cfn-resources/project/docs/README.md index ad7f67f8c..db4accce1 100644 --- a/cfn-resources/project/docs/README.md +++ b/cfn-resources/project/docs/README.md @@ -18,6 +18,7 @@ To declare this entity in your AWS CloudFormation template, use the following sy "WithDefaultAlertsSettings" : Boolean, "ProjectSettings" : projectSettings, "Profile" : String, + "LambdaProxyArn" : String, "ProjectTeams" : [ projectTeam, ... ], "ProjectApiKeys" : [ projectApiKey, ... ], "RegionUsageRestrictions" : String, @@ -37,6 +38,7 @@ Properties: WithDefaultAlertsSettings: Boolean ProjectSettings: projectSettings Profile: String + LambdaProxyArn: String ProjectTeams: - projectTeam ProjectApiKeys: @@ -105,6 +107,16 @@ _Type_: String _Update requires_: [Replacement](https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/using-cfn-updating-stacks-update-behaviors.html#update-replacement) +#### LambdaProxyArn + +lambda arn + +_Required_: No + +_Type_: String + +_Update requires_: [No interruption](https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/using-cfn-updating-stacks-update-behaviors.html#update-no-interrupt) + #### ProjectTeams Teams to which the authenticated user has access in the project specified using its unique 24-hexadecimal digit identifier. diff --git a/cfn-resources/project/lambdaproxy/ec2proxy.py b/cfn-resources/project/lambdaproxy/ec2proxy.py new file mode 100644 index 000000000..12b085b02 --- /dev/null +++ b/cfn-resources/project/lambdaproxy/ec2proxy.py @@ -0,0 +1,51 @@ +from flask import Flask, request, Response +import requests +import logging + +app = Flask(__name__) + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +# NOTE: additional configuration would be required to also support Realm +TARGET_SERVER = "https://cloud-dev.mongodb.com" +logger.debug(f"EC2 Proxy configured with TARGET_SERVER: {TARGET_SERVER}") + +@app.route('/', defaults={'path': ''}, methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"]) +@app.route('/', methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"]) +def proxy(path): + logger.debug(f"Received request for path: {path}") + # Build the target URL + target_url = f"{TARGET_SERVER}/{path}" + logger.debug(f"Target URL: {target_url}") + + # Copy the incoming headers + headers = {key: value for key, value in request.headers if key.lower() != 'host'} + logger.debug(f"Request headers: {headers}") + + # Forward the request to Atlas + try: + resp = requests.request( + method=request.method, + url=target_url, + headers=headers, + data=request.get_data(), + cookies=request.cookies, + allow_redirects=False + ) + # logger.debug(f"Received response from target server - Status: {resp.status_code}, Headers: {resp.headers}") + except Exception as e: + logger.exception("Error forwarding the request to the target server:") + return Response("Error forwarding request", status=500) + + excluded_headers = ['content-encoding', 'content-length', 'transfer-encoding', 'connection'] + response_headers = [(name, value) for (name, value) in resp.raw.headers.items() + if name.lower() not in excluded_headers] + + response = Response(resp.content, resp.status_code, response_headers) + logger.debug(f"Returning proxied response with status: {resp.status_code}") + return response + +if __name__ == '__main__': + logger.debug("Starting EC2 Proxy Flask app on 0.0.0.0:80") + app.run(host='0.0.0.0', port=80) diff --git a/cfn-resources/project/lambdaproxy/lambda.py b/cfn-resources/project/lambdaproxy/lambda.py new file mode 100644 index 000000000..649453413 --- /dev/null +++ b/cfn-resources/project/lambdaproxy/lambda.py @@ -0,0 +1,73 @@ +import json +import logging +import requests # this requires adding request layer to lambda function +from urllib.parse import urlparse, urlunparse + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger() + +EC2_PROXY_ENDPOINT = "http://XX.X.X.XX" # Replace with private IP of EC2 proxy running Python/Flask +if not EC2_PROXY_ENDPOINT: + logger.error("EC2_PROXY_ENDPOINT is not set") + raise ValueError("EC2_PROXY_ENDPOINT is not set") +logger.debug(f"EC2_PROXY_ENDPOINT: {EC2_PROXY_ENDPOINT}") + +def lambda_handler(event, context): + """ + Expected event example: + { + "method": "GET", + "url": "https://cloud-dev.mongodb.com/api/atlas/v2/groups", + "headers": {"Header-Key": "value", ...}, + "body": "request body as string" + } + + This Lambda function should be deployed in a private subnet. It's corresponding SG + only allows traffic to EC2 proxy running Python/Flask + """ + logger.debug(f"Received event: {json.dumps(event)}") # TODO: remove + try: + method = event.get("method") + url = event.get("url") + headers = event.get("headers", {}) + body = event.get("body", None) + + if method is None or url is None: + msg = "Missing 'method' or 'url' in the event payload" + logger.error(msg) + raise ValueError(msg) + + logger.debug(f"Forwarding request - Method: {method}, URL: {url}, Headers: {headers}, Body: {body}") + + parsed_url = urlparse(url) + if parsed_url.scheme and parsed_url.netloc: + # Extract the path and query string because incoming URL will be in format: + # https://www.cloud.mongodb.com/api/atlas/v2/groups?param=value + path = parsed_url.path or "" + query = f"?{parsed_url.query}" if parsed_url.query else "" + new_url = path + query + logger.debug(f"Extracted relative URL: {new_url} from absolute URL: {url}") + else: + new_url = url + logger.debug(f"URL is relative: {new_url}") + + # Construct URL for the EC2 proxy + full_url = EC2_PROXY_ENDPOINT.rstrip("/") + new_url + logger.debug(f"Constructed full URL for proxy: {full_url}") + + # Forward request to the EC2 proxy + response = requests.request(method, full_url, headers=headers, data=body) + logger.debug(f"Response from EC2 proxy - Status: {response.status_code}, Headers: {dict(response.headers)}, Body: {response.text}") + + return { + "statusCode": response.status_code, + "headers": dict(response.headers), + "body": response.text + } + except Exception as e: + logger.exception("Error processing the request:") + return { + "statusCode": 500, + "headers": {}, + "body": json.dumps({"error": str(e)}) + } diff --git a/cfn-resources/project/mongodb-atlas-project.json b/cfn-resources/project/mongodb-atlas-project.json index 562c64147..c89569e90 100644 --- a/cfn-resources/project/mongodb-atlas-project.json +++ b/cfn-resources/project/mongodb-atlas-project.json @@ -127,6 +127,10 @@ "description": "Profile used to provide credentials information, (a secret with the cfn/atlas/profile/{Profile}, is required), if not provided default is used", "default": "default" }, + "LambdaProxyArn": { + "type": "string", + "description": "lambda arn" + }, "ProjectTeams": { "items": { "$ref": "#/definitions/projectTeam" @@ -177,22 +181,30 @@ "handlers": { "create": { "permissions": [ - "secretsmanager:GetSecretValue" + "secretsmanager:GetSecretValue", + "lambda:InvokeFunction", + "lambda:GetFunction" ] }, "read": { "permissions": [ - "secretsmanager:GetSecretValue" + "secretsmanager:GetSecretValue", + "lambda:InvokeFunction", + "lambda:GetFunction" ] }, "update": { "permissions": [ - "secretsmanager:GetSecretValue" + "secretsmanager:GetSecretValue", + "lambda:InvokeFunction", + "lambda:GetFunction" ] }, "delete": { "permissions": [ - "secretsmanager:GetSecretValue" + "secretsmanager:GetSecretValue", + "lambda:InvokeFunction", + "lambda:GetFunction" ] } }, diff --git a/cfn-resources/project/resource-role.yaml b/cfn-resources/project/resource-role.yaml index 8d3c7258d..6b1249ed2 100644 --- a/cfn-resources/project/resource-role.yaml +++ b/cfn-resources/project/resource-role.yaml @@ -30,6 +30,8 @@ Resources: Statement: - Effect: Allow Action: + - "lambda:GetFunction" + - "lambda:InvokeFunction" - "secretsmanager:GetSecretValue" Resource: "*" Outputs: diff --git a/cfn-resources/util/atlasLambdaClient.go b/cfn-resources/util/atlasLambdaClient.go new file mode 100644 index 000000000..3f4883757 --- /dev/null +++ b/cfn-resources/util/atlasLambdaClient.go @@ -0,0 +1,216 @@ +package util + +import ( + "bytes" + "encoding/json" + "errors" + "io/ioutil" + "log" + "net/http" + "strings" + + "github.com/aws-cloudformation/cloudformation-cli-go-plugin/cfn/handler" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/cloudformation" + "github.com/aws/aws-sdk-go/service/lambda" + "github.com/mongodb-forks/digest" + + "github.com/mongodb/mongodbatlas-cloudformation-resources/profile" +) + +// lambdaForwardingTransport implements http.RoundTripper +type lambdaForwardingTransport struct { + lambdaArn string + lambdaClient *lambda.Lambda +} + +func newLambdaForwardingTransport(req *handler.Request, lambdaArn string) *lambdaForwardingTransport { + log.Printf("Initializing lambdaForwardingTransport with Lambda ARN: %s", lambdaArn) + // Extract region from Lambda ARN - TODO: probably not required, remove + arnParts := strings.Split(lambdaArn, ":") + region := "us-east-1" // Default + if len(arnParts) >= 4 { + region = arnParts[3] + } + svc := lambda.New(req.Session, aws.NewConfig().WithRegion(region)) + + return &lambdaForwardingTransport{ + lambdaArn: lambdaArn, + lambdaClient: svc, + } +} + +// LambdaRequestPayload is sent to the Lambda function +type LambdaRequestPayload struct { + Method string `json:"method"` + URL string `json:"url"` + Headers map[string]string `json:"headers"` + Body string `json:"body"` +} + +// LambdaResponsePayload is returned by the Lambda function +type LambdaResponsePayload struct { + StatusCode int `json:"statusCode"` + Headers map[string]string `json:"headers"` + Body string `json:"body"` +} + +// This method currently uses extensive logging for POC purpose which should be reduced +func (t *lambdaForwardingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + log.Printf("Entering lambdaForwardingTransport.RoundTrip for URL: %s, Method: %s", req.URL.String(), req.Method) + + headers := make(map[string]string) + for key, values := range req.Header { + if len(values) > 0 { + headers[key] = values[0] + } + } + log.Printf("Captured request headers: %+v", headers) + + var bodyBytes []byte + if req.Body != nil { + var err error + bodyBytes, err = ioutil.ReadAll(req.Body) + if err != nil { + log.Printf("Error reading request body: %v", err) + return nil, err + } + req.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes)) + } else { + log.Printf("No request body to read (req.Body is nil)") + } + + payloadStruct := LambdaRequestPayload{ + Method: req.Method, + URL: req.URL.String(), + Headers: headers, + Body: string(bodyBytes), + } + payloadBytes, err := json.Marshal(payloadStruct) + if err != nil { + log.Printf("Error marshaling payload: %v", err) + return nil, err + } + log.Printf("Payload to be sent to Lambda: %s", string(payloadBytes)) + + input := &lambda.InvokeInput{ + FunctionName: aws.String(t.lambdaArn), + Payload: payloadBytes, + } + log.Printf("Invoking Lambda with input: %+v", input) + result, err := t.lambdaClient.Invoke(input) + if err != nil { + log.Printf("Error invoking Lambda: %v", err) + return nil, err + } + if result.FunctionError != nil { + errMsg := "Lambda function error: " + *result.FunctionError + log.Printf(errMsg) + return nil, errors.New(errMsg) + } + log.Printf("Lambda invocation result: %+v", result) + + var respPayload LambdaResponsePayload + err = json.Unmarshal(result.Payload, &respPayload) + if err != nil { + log.Printf("Error unmarshaling Lambda response payload: %v", err) + return nil, err + } + log.Printf("Parsed Lambda response payload: %+v", respPayload) + + resp := &http.Response{ + StatusCode: respPayload.StatusCode, + Status: http.StatusText(respPayload.StatusCode), + Header: make(http.Header), + Body: ioutil.NopCloser(bytes.NewBufferString(respPayload.Body)), + Request: req, + } + for key, value := range respPayload.Headers { + resp.Header.Set(key, value) + } + log.Printf("Returning HTTP response from RoundTrip: %+v", resp) + return resp, nil +} + +// This method currently uses extensive logging for POC purpose which should be reduced +func newAtlasV2ClientWithLambdaProxySupport(req *handler.Request, profileName *string, profileNamePrefixRequired bool, lambdaArn *string) (*MongoDBClient, *handler.ProgressEvent) { + log.Printf("Initializing newAtlasV2Client with profileName: %v", profileName) + prof, err := profile.NewProfile(req, profileName, profileNamePrefixRequired) + if err != nil { + log.Printf("Error creating profile: %v", err) + return nil, &handler.ProgressEvent{ + OperationStatus: handler.Failed, + Message: err.Error(), + HandlerErrorCode: cloudformation.HandlerErrorCodeNotFound, + } + } + + var client *http.Client + if lambdaArn != nil && *lambdaArn != "" { + log.Printf("Using chained digest transport with Lambda forwarding. Lambda ARN: %s", *lambdaArn) + lambdaTransport := newLambdaForwardingTransport(req, *lambdaArn) + digestTransport := digest.NewTransport(prof.PublicKey, prof.PrivateKey) + // Set the underlying transport to our Lambda transport + digestTransport.Transport = lambdaTransport + // Use the digest transport as the client transport + client = &http.Client{Transport: digestTransport} + } else { + log.Printf("Using default digest transport with PublicKey: %s", prof.PublicKey) + transport := digest.NewTransport(prof.PublicKey, prof.PrivateKey) + client, err = transport.Client() + if err != nil { + log.Printf("Error creating digest transport client: %v", err) + return nil, &handler.ProgressEvent{ + OperationStatus: handler.Failed, + Message: err.Error(), + HandlerErrorCode: cloudformation.HandlerErrorCodeInvalidRequest, + } + } + } + + c := Config{BaseURL: prof.BaseURL, DebugClient: prof.UseDebug()} + log.Printf("Config initialized: %+v", c) + + sdk20231115002Client, err := c.NewSDKv20231115002Client(client) + if err != nil { + log.Printf("Error creating SDKv20231115002Client: %v", err) + return nil, &handler.ProgressEvent{ + OperationStatus: handler.Failed, + Message: err.Error(), + HandlerErrorCode: cloudformation.HandlerErrorCodeInvalidRequest, + } + } + + sdk20231115014Client, err := c.NewSDKv20231115014Client(client) + if err != nil { + log.Printf("Error creating SDKv20231115014Client: %v", err) + return nil, &handler.ProgressEvent{ + OperationStatus: handler.Failed, + Message: err.Error(), + HandlerErrorCode: cloudformation.HandlerErrorCodeInvalidRequest, + } + } + + sdkV2LatestClient, err := c.NewSDKV2LatestClient(client) + if err != nil { + log.Printf("Error creating SDKV2LatestClient: %v", err) + return nil, &handler.ProgressEvent{ + OperationStatus: handler.Failed, + Message: err.Error(), + HandlerErrorCode: cloudformation.HandlerErrorCodeInvalidRequest, + } + } + + clients := &MongoDBClient{ + Atlas20231115002: sdk20231115002Client, + Atlas20231115014: sdk20231115014Client, + AtlasSDK: sdkV2LatestClient, + Config: &c, + } + log.Printf("newAtlasV2Client successfully created clients: %+v", clients) + return clients, nil +} + +func NewAtlasClientWithLambdaProxySupport(req *handler.Request, profileName, lambdaARN *string) (*MongoDBClient, *handler.ProgressEvent) { + return newAtlasV2ClientWithLambdaProxySupport(req, profileName, true, lambdaARN) +}