diff --git a/.changes/next-release/feature-AWSSDKforJavav2-e386e28.json b/.changes/next-release/feature-AWSSDKforJavav2-e386e28.json new file mode 100644 index 000000000000..99942ad06e7e --- /dev/null +++ b/.changes/next-release/feature-AWSSDKforJavav2-e386e28.json @@ -0,0 +1,6 @@ +{ + "type": "feature", + "category": "AWS SDK for Java v2", + "contributor": "", + "description": "Add support for payload signing of async streaming requests signed with SigV4 using default `AwsV4HttpSigner` (using `AwsV4HttpSigner.create()`). Note, requests using the `http` URI scheme will not be signed regardless of the value of `AwsV4FamilyHttpSigner.PAYLOAD_SIGNING_ENABLED` to remain consistent with existing behavior. This may change in a future release." +} diff --git a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/crt/internal/signer/AwsChunkedV4aPayloadSigner.java b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/crt/internal/signer/AwsChunkedV4aPayloadSigner.java index de5d16b92799..e0a491c6bfd4 100644 --- a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/crt/internal/signer/AwsChunkedV4aPayloadSigner.java +++ b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/crt/internal/signer/AwsChunkedV4aPayloadSigner.java @@ -21,7 +21,6 @@ import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.STREAMING_ECDSA_SIGNED_PAYLOAD_TRAILER; import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.STREAMING_UNSIGNED_PAYLOAD_TRAILER; import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_TRAILER; -import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerUtils.moveContentLength; import java.io.InputStream; import java.nio.charset.StandardCharsets; @@ -41,6 +40,7 @@ import software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding.TrailerProvider; import software.amazon.awssdk.http.auth.aws.internal.signer.io.ChecksumInputStream; import software.amazon.awssdk.http.auth.aws.internal.signer.io.ResettableContentStreamProvider; +import software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerUtils; import software.amazon.awssdk.http.auth.spi.signer.PayloadChecksumStore; import software.amazon.awssdk.utils.BinaryUtils; import software.amazon.awssdk.utils.Logger; @@ -115,7 +115,7 @@ public ContentStreamProvider sign(ContentStreamProvider payload, V4aRequestSigni @Override public void beforeSigning(SdkHttpRequest.Builder request, ContentStreamProvider payload, String checksum) { long encodedContentLength = 0; - long contentLength = moveContentLength(request, payload); + long contentLength = SignerUtils.computeAndMoveContentLength(request, payload); setupPreExistingTrailers(request); // pre-existing trailers diff --git a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/AwsChunkedV4PayloadSigner.java b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/AwsChunkedV4PayloadSigner.java index 8e2e3a3a168b..7647d4af483d 100644 --- a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/AwsChunkedV4PayloadSigner.java +++ b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/AwsChunkedV4PayloadSigner.java @@ -23,14 +23,17 @@ import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.STREAMING_SIGNED_PAYLOAD_TRAILER; import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.STREAMING_UNSIGNED_PAYLOAD_TRAILER; import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_CONTENT_SHA256; +import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_DECODED_CONTENT_LENGTH; import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_TRAILER; -import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerUtils.moveContentLength; +import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerUtils.computeAndMoveContentLength; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; import org.reactivestreams.Publisher; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.checksums.SdkChecksum; @@ -38,13 +41,17 @@ import software.amazon.awssdk.http.ContentStreamProvider; import software.amazon.awssdk.http.Header; import software.amazon.awssdk.http.SdkHttpRequest; +import software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding.AsyncChunkEncodedPayload; import software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding.ChecksumTrailerProvider; import software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding.ChunkedEncodedInputStream; +import software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding.ChunkedEncodedPayload; +import software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding.ChunkedEncodedPublisher; import software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding.SigV4ChunkExtensionProvider; import software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding.SigV4TrailerProvider; +import software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding.SyncChunkEncodedPayload; import software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding.TrailerProvider; -import software.amazon.awssdk.http.auth.aws.internal.signer.io.ChecksumInputStream; import software.amazon.awssdk.http.auth.aws.internal.signer.io.ResettableContentStreamProvider; +import software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerUtils; import software.amazon.awssdk.http.auth.spi.signer.PayloadChecksumStore; import software.amazon.awssdk.utils.BinaryUtils; import software.amazon.awssdk.utils.Logger; @@ -79,81 +86,140 @@ public static Builder builder() { @Override public ContentStreamProvider sign(ContentStreamProvider payload, V4RequestSigningResult requestSigningResult) { - SdkHttpRequest.Builder request = requestSigningResult.getSignedRequest(); - - String checksum = request.firstMatchingHeader(X_AMZ_CONTENT_SHA256).orElseThrow( - () -> new IllegalArgumentException(X_AMZ_CONTENT_SHA256 + " must be set!") - ); - ChunkedEncodedInputStream.Builder chunkedEncodedInputStreamBuilder = ChunkedEncodedInputStream .builder() .inputStream(payload.newStream()) .chunkSize(chunkSize) .header(chunk -> Integer.toHexString(chunk.remaining()).getBytes(StandardCharsets.UTF_8)); - preExistingTrailers.forEach(trailer -> chunkedEncodedInputStreamBuilder.addTrailer(() -> trailer)); + SyncChunkEncodedPayload chunkedPayload = new SyncChunkEncodedPayload(chunkedEncodedInputStreamBuilder); + signCommon(chunkedPayload, requestSigningResult); + + return new ResettableContentStreamProvider(chunkedEncodedInputStreamBuilder::build); + } + + @Override + public Publisher signAsync(Publisher payload, V4RequestSigningResult requestSigningResult) { + ChunkedEncodedPublisher.Builder chunkedStreamBuilder = ChunkedEncodedPublisher.builder() + .publisher(payload) + .chunkSize(chunkSize) + .addEmptyTrailingChunk(true); + + AsyncChunkEncodedPayload chunkedPayload = new AsyncChunkEncodedPayload(chunkedStreamBuilder); + signCommon(chunkedPayload, requestSigningResult); + + return chunkedStreamBuilder.build(); + } + + private void signCommon(ChunkedEncodedPayload payload, V4RequestSigningResult requestSigningResult) { + preExistingTrailers.forEach(t -> payload.addTrailer(() -> t)); + + SdkHttpRequest.Builder request = requestSigningResult.getSignedRequest(); + + payload.decodedContentLength(request.firstMatchingHeader(X_AMZ_DECODED_CONTENT_LENGTH) + .map(Long::parseLong) + .orElseThrow(() -> { + String msg = String.format("Expected header '%s' to be present", + X_AMZ_DECODED_CONTENT_LENGTH); + return new RuntimeException(msg); + })); + + String checksum = request.firstMatchingHeader(X_AMZ_CONTENT_SHA256).orElseThrow( + () -> new IllegalArgumentException(X_AMZ_CONTENT_SHA256 + " must be set!") + ); switch (checksum) { case STREAMING_SIGNED_PAYLOAD: { RollingSigner rollingSigner = new RollingSigner(requestSigningResult.getSigningKey(), requestSigningResult.getSignature()); - chunkedEncodedInputStreamBuilder.addExtension(new SigV4ChunkExtensionProvider(rollingSigner, credentialScope)); + payload.addExtension(new SigV4ChunkExtensionProvider(rollingSigner, credentialScope)); break; } case STREAMING_UNSIGNED_PAYLOAD_TRAILER: - setupChecksumTrailerIfNeeded(chunkedEncodedInputStreamBuilder); + setupChecksumTrailerIfNeeded(payload); break; case STREAMING_SIGNED_PAYLOAD_TRAILER: { + setupChecksumTrailerIfNeeded(payload); RollingSigner rollingSigner = new RollingSigner(requestSigningResult.getSigningKey(), requestSigningResult.getSignature()); - chunkedEncodedInputStreamBuilder.addExtension(new SigV4ChunkExtensionProvider(rollingSigner, credentialScope)); - setupChecksumTrailerIfNeeded(chunkedEncodedInputStreamBuilder); - chunkedEncodedInputStreamBuilder.addTrailer( - new SigV4TrailerProvider(chunkedEncodedInputStreamBuilder.trailers(), rollingSigner, credentialScope) + payload.addExtension(new SigV4ChunkExtensionProvider(rollingSigner, credentialScope)); + payload.addTrailer( + new SigV4TrailerProvider(payload.trailers(), rollingSigner, credentialScope) ); break; } default: throw new UnsupportedOperationException(); } - - return new ResettableContentStreamProvider(chunkedEncodedInputStreamBuilder::build); - } - - @Override - public Publisher signAsync(Publisher payload, V4RequestSigningResult requestSigningResult) { - // TODO(sra-identity-and-auth): implement this first and remove addFlexibleChecksumInTrailer logic in HttpChecksumStage - throw new UnsupportedOperationException(); } @Override public void beforeSigning(SdkHttpRequest.Builder request, ContentStreamProvider payload) { long encodedContentLength = 0; - long contentLength = moveContentLength(request, payload); + long contentLength = SignerUtils.computeAndMoveContentLength(request, payload); setupPreExistingTrailers(request); // pre-existing trailers + encodedContentLength = calculateEncodedContentLength(request, contentLength); + + if (checksumAlgorithm != null) { + String checksumHeaderName = checksumHeaderName(checksumAlgorithm); + request.appendHeader(X_AMZ_TRAILER, checksumHeaderName); + } + request.putHeader(Header.CONTENT_LENGTH, Long.toString(encodedContentLength)); + request.appendHeader(CONTENT_ENCODING, AWS_CHUNKED); + } + + @Override + public CompletableFuture>>> beforeSigningAsync( + SdkHttpRequest.Builder request, Publisher payload) { + return computeAndMoveContentLength(request, payload) + .thenApply(p -> { + SdkHttpRequest.Builder requestBuilder = p.left(); + setupPreExistingTrailers(requestBuilder); + + long decodedContentLength = requestBuilder.firstMatchingHeader(X_AMZ_DECODED_CONTENT_LENGTH) + .map(Long::parseLong) + // should not happen, this header is added by moveContentLength + .orElseThrow(() -> new RuntimeException(X_AMZ_DECODED_CONTENT_LENGTH + + " header not present")); + + long encodedContentLength = calculateEncodedContentLength(request, decodedContentLength); + + if (checksumAlgorithm != null) { + String checksumHeaderName = checksumHeaderName(checksumAlgorithm); + request.appendHeader(X_AMZ_TRAILER, checksumHeaderName); + } + request.putHeader(Header.CONTENT_LENGTH, Long.toString(encodedContentLength)); + request.appendHeader(CONTENT_ENCODING, AWS_CHUNKED); + return Pair.of(requestBuilder, p.right()); + }); + } + + private long calculateEncodedContentLength(SdkHttpRequest.Builder requestBuilder, long decodedContentLength) { + long encodedContentLength = 0; + encodedContentLength += calculateExistingTrailersLength(); - String checksum = request.firstMatchingHeader(X_AMZ_CONTENT_SHA256).orElseThrow( + String checksum = requestBuilder.firstMatchingHeader(X_AMZ_CONTENT_SHA256).orElseThrow( () -> new IllegalArgumentException(X_AMZ_CONTENT_SHA256 + " must be set!") ); switch (checksum) { case STREAMING_SIGNED_PAYLOAD: { long extensionsLength = 81; // ;chunk-signature: - encodedContentLength += calculateChunksLength(contentLength, extensionsLength); + encodedContentLength += calculateChunksLength(decodedContentLength, extensionsLength); break; } case STREAMING_UNSIGNED_PAYLOAD_TRAILER: if (checksumAlgorithm != null) { encodedContentLength += calculateChecksumTrailerLength(checksumHeaderName(checksumAlgorithm)); } - encodedContentLength += calculateChunksLength(contentLength, 0); + encodedContentLength += calculateChunksLength(decodedContentLength, 0); break; case STREAMING_SIGNED_PAYLOAD_TRAILER: { long extensionsLength = 81; // ;chunk-signature: - encodedContentLength += calculateChunksLength(contentLength, extensionsLength); + encodedContentLength += calculateChunksLength(decodedContentLength, extensionsLength); if (checksumAlgorithm != null) { encodedContentLength += calculateChecksumTrailerLength(checksumHeaderName(checksumAlgorithm)); } @@ -167,12 +233,7 @@ public void beforeSigning(SdkHttpRequest.Builder request, ContentStreamProvider // terminating \r\n encodedContentLength += 2; - if (checksumAlgorithm != null) { - String checksumHeaderName = checksumHeaderName(checksumAlgorithm); - request.appendHeader(X_AMZ_TRAILER, checksumHeaderName); - } - request.putHeader(Header.CONTENT_LENGTH, Long.toString(encodedContentLength)); - request.appendHeader(CONTENT_ENCODING, AWS_CHUNKED); + return encodedContentLength; } /** @@ -256,12 +317,7 @@ private long calculateChecksumTrailerLength(String checksumHeaderName) { return lengthInBytes + 2; } - /** - * Add the checksum as a trailer to the chunk-encoded stream. - *

- * If the checksum-algorithm is not present, then nothing is done. - */ - private void setupChecksumTrailerIfNeeded(ChunkedEncodedInputStream.Builder builder) { + private void setupChecksumTrailerIfNeeded(ChunkedEncodedPayload payload) { if (checksumAlgorithm == null) { return; } @@ -273,20 +329,17 @@ private void setupChecksumTrailerIfNeeded(ChunkedEncodedInputStream.Builder buil if (cachedChecksum != null) { LOG.debug(() -> String.format("Cached payload checksum available for algorithm %s: %s. Using cached value", checksumAlgorithm.algorithmId(), checksumHeaderName)); - builder.addTrailer(() -> Pair.of(checksumHeaderName, Collections.singletonList(cachedChecksum))); + payload.addTrailer(() -> Pair.of(checksumHeaderName, Collections.singletonList(cachedChecksum))); return; } SdkChecksum sdkChecksum = fromChecksumAlgorithm(checksumAlgorithm); - ChecksumInputStream checksumInputStream = new ChecksumInputStream( - builder.inputStream(), - Collections.singleton(sdkChecksum) - ); TrailerProvider checksumTrailer = new ChecksumTrailerProvider(sdkChecksum, checksumHeaderName, checksumAlgorithm, payloadChecksumStore); - builder.inputStream(checksumInputStream).addTrailer(checksumTrailer); + payload.checksumPayload(sdkChecksum); + payload.addTrailer(checksumTrailer); } private String getCachedChecksum() { diff --git a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/DefaultAwsV4HttpSigner.java b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/DefaultAwsV4HttpSigner.java index 66ee6d4cf733..eeb684910f90 100644 --- a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/DefaultAwsV4HttpSigner.java +++ b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/DefaultAwsV4HttpSigner.java @@ -26,11 +26,13 @@ import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_TRAILER; import static software.amazon.awssdk.http.auth.spi.signer.SdkInternalHttpSignerProperty.CHECKSUM_STORE; +import java.nio.ByteBuffer; import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.concurrent.CompletableFuture; import java.util.function.Function; +import org.reactivestreams.Publisher; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.http.ContentStreamProvider; import software.amazon.awssdk.http.SdkHttpRequest; @@ -67,19 +69,19 @@ public SignedRequest sign(SignRequest request) @Override public CompletableFuture signAsync(AsyncSignRequest request) { - Checksummer checksummer = asyncChecksummer(request); + Checksummer checksummer = asyncChecksummer(request, checksumStore(request)); V4Properties v4Properties = v4Properties(request); V4RequestSigner v4RequestSigner = v4RequestSigner(request, v4Properties); V4PayloadSigner payloadSigner = v4PayloadAsyncSigner(request, v4Properties); - return doSign(request, checksummer, v4RequestSigner, payloadSigner); + return doSignAsync(request, checksummer, v4RequestSigner, payloadSigner); } private static V4Properties v4Properties(BaseSignRequest request) { Clock signingClock = request.requireProperty(SIGNING_CLOCK, Clock.systemUTC()); Instant signingInstant = signingClock.instant(); AwsCredentialsIdentity credentials = sanitizeCredentials(request.identity()); - String regionName = request.requireProperty(AwsV4HttpSigner.REGION_NAME); + String regionName = request.requireProperty(REGION_NAME); String serviceSigningName = request.requireProperty(SERVICE_SIGNING_NAME); CredentialScope credentialScope = new CredentialScope(regionName, serviceSigningName, signingInstant); boolean doubleUrlEncode = request.requireProperty(DOUBLE_URL_ENCODE, true); @@ -127,27 +129,29 @@ private static V4RequestSigner v4RequestSigner( return requestSigner.apply(v4Properties); } - /** - * This is needed because of the pre-existing gap (pre-SRA) in behavior where we don't treat async + streaming + http + - * unsigned-payload as signed-payload (fallback). We have to do some finagling of the payload-signing options before - * calling the actual checksummer() method - */ - private static Checksummer asyncChecksummer(BaseSignRequest request) { - boolean isHttp = !"https".equals(request.request().protocol()); - boolean isPayloadSigning = isPayloadSigning(request); - boolean isChunkEncoding = request.requireProperty(CHUNK_ENCODING_ENABLED, false); - boolean shouldTreatAsUnsigned = isHttp && isPayloadSigning && isChunkEncoding; + // TODO: remove this once we consolidate the behavior for plaintext HTTP signing for sync and async + private static Checksummer asyncChecksummer(BaseSignRequest request, + PayloadChecksumStore checksumStore) { + boolean shouldTreatAsUnsigned = asyncShouldTreatAsUnsigned(request); // set the override to false if it should be treated as unsigned, otherwise, null should be passed so that the normal // check for payload signing is done. Boolean overridePayloadSigning = shouldTreatAsUnsigned ? false : null; - return checksummer(request, overridePayloadSigning, PayloadChecksumStore.create()); + return checksummer(request, overridePayloadSigning, checksumStore); + } + + // TODO: remove this once we consolidate the behavior for plaintext HTTP signing for sync and async + private static boolean asyncShouldTreatAsUnsigned(BaseSignRequest request) { + boolean isHttp = !"https".equals(request.request().protocol()); + boolean isPayloadSigning = isPayloadSigning(request); + boolean isChunkEncoding = request.requireProperty(CHUNK_ENCODING_ENABLED, false); + + return isHttp && isPayloadSigning && isChunkEncoding; } private static V4PayloadSigner v4PayloadSigner( - SignRequest request, - V4Properties properties) { + BaseSignRequest request, V4Properties properties) { boolean isPayloadSigning = isPayloadSigning(request); boolean isEventStreaming = isEventStreaming(request.request()); @@ -178,13 +182,16 @@ private static V4PayloadSigner v4PayloadSigner( return V4PayloadSigner.create(); } + // TODO: remove this once we consolidate the behavior for plaintext HTTP signing for sync and async private static V4PayloadSigner v4PayloadAsyncSigner( AsyncSignRequest request, V4Properties properties) { - boolean isPayloadSigning = request.requireProperty(PAYLOAD_SIGNING_ENABLED, true); + boolean isPayloadSigning = isPayloadSigning(request); boolean isEventStreaming = isEventStreaming(request.request()); boolean isChunkEncoding = request.requireProperty(CHUNK_ENCODING_ENABLED, false); + boolean isTrailing = request.request().firstMatchingHeader(X_AMZ_TRAILER).isPresent(); + boolean isFlexible = request.hasProperty(CHECKSUM_ALGORITHM) && !hasChecksumHeader(request); if (isEventStreaming) { if (isPayloadSigning) { @@ -197,13 +204,21 @@ private static V4PayloadSigner v4PayloadAsyncSigner( throw new UnsupportedOperationException("Unsigned payload is not supported with event-streaming."); } - if (isChunkEncoding && isPayloadSigning) { - // TODO(sra-identity-and-auth): We need to implement aws-chunk content-encoding for async. - // For now, we basically have to treat this as an unsigned case because there are existing s3 use-cases for - // Unsigned-payload + HTTP. These requests SHOULD be signed-payload, but are not pre-SRA, hence the problem. This - // will be taken care of in HttpChecksumStage for now, so we shouldn't throw an unsupported exception here, we - // should just fall through to the default since it will already encoded by the time it gets here. - return V4PayloadSigner.create(); + // Note: this check is done after we check if the request is eventstreaming, during which we just use the same logic + // as sync to determine if the body should be signed. If it's not eventstreaming, then async needs to treat this + // request differently to maintain current behavior re: plain HTTP requests. + boolean nonEvenstreamingPayloadSigning = isPayloadSigning; + if (asyncShouldTreatAsUnsigned(request)) { + nonEvenstreamingPayloadSigning = false; + } + + if (useChunkEncoding(nonEvenstreamingPayloadSigning, isChunkEncoding, isTrailing || isFlexible)) { + return AwsChunkedV4PayloadSigner.builder() + .credentialScope(properties.getCredentialScope()) + .chunkSize(DEFAULT_CHUNK_SIZE_IN_BYTES) + .checksumStore(checksumStore(request)) + .checksumAlgorithm(request.property(CHECKSUM_ALGORITHM)) + .build(); } return V4PayloadSigner.create(); @@ -233,19 +248,30 @@ private static SignedRequest doSign(SignRequest doSign(AsyncSignRequest request, - Checksummer checksummer, - V4RequestSigner requestSigner, - V4PayloadSigner payloadSigner) { + private static CompletableFuture doSignAsync(AsyncSignRequest request, + Checksummer checksummer, + V4RequestSigner requestSigner, + V4PayloadSigner payloadSigner) { SdkHttpRequest.Builder requestBuilder = request.request().toBuilder(); + Publisher requestPayload = request.payload().orElse(null); + + return checksummer.checksum(requestPayload, requestBuilder) + .thenCompose(checksummedPayload -> + payloadSigner.beforeSigningAsync(requestBuilder, checksummedPayload)) + .thenApply(p -> { + SdkHttpRequest.Builder requestToSign = p.left(); + Publisher payloadToSign = p.right().orElse(null); + + V4RequestSigningResult requestSigningResult = requestSigner.sign(requestToSign); - return checksummer.checksum(request.payload().orElse(null), requestBuilder) - .thenApply(payload -> { - V4RequestSigningResult requestSigningResultFuture = requestSigner.sign(requestBuilder); + Publisher signedPayload = null; + if (payloadToSign != null) { + signedPayload = payloadSigner.signAsync(payloadToSign, requestSigningResult); + } return AsyncSignedRequest.builder() - .request(requestSigningResultFuture.getSignedRequest().build()) - .payload(payloadSigner.signAsync(payload, requestSigningResultFuture)) + .request(requestSigningResult.getSignedRequest().build()) + .payload(signedPayload) .build(); }); } @@ -265,7 +291,7 @@ private static boolean isBetweenInclusive(Duration start, Duration x, Duration e return start.compareTo(x) <= 0 && x.compareTo(end) <= 0; } - private static PayloadChecksumStore checksumStore(SignRequest request) { + private static PayloadChecksumStore checksumStore(BaseSignRequest request) { PayloadChecksumStore cache = request.property(CHECKSUM_STORE); if (cache == null) { return NoOpPayloadChecksumStore.create(); diff --git a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/V4PayloadSigner.java b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/V4PayloadSigner.java index 189fbe420085..d8a88cf3f91a 100644 --- a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/V4PayloadSigner.java +++ b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/V4PayloadSigner.java @@ -16,10 +16,13 @@ package software.amazon.awssdk.http.auth.aws.internal.signer; import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; import org.reactivestreams.Publisher; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.http.ContentStreamProvider; import software.amazon.awssdk.http.SdkHttpRequest; +import software.amazon.awssdk.utils.Pair; /** * An interface for defining how to sign a payload via SigV4. @@ -48,4 +51,9 @@ static V4PayloadSigner create() { */ default void beforeSigning(SdkHttpRequest.Builder request, ContentStreamProvider payload) { } + + default CompletableFuture>>> beforeSigningAsync( + SdkHttpRequest.Builder request, Publisher payload) { + return CompletableFuture.completedFuture(Pair.of(request, Optional.ofNullable(payload))); + } } diff --git a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/AsyncChunkEncodedPayload.java b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/AsyncChunkEncodedPayload.java new file mode 100644 index 000000000000..4ec26eeaee5d --- /dev/null +++ b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/AsyncChunkEncodedPayload.java @@ -0,0 +1,64 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.List; +import org.reactivestreams.Publisher; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.checksums.SdkChecksum; +import software.amazon.awssdk.http.auth.aws.internal.signer.io.UnbufferedChecksumSubscriber; + +@SdkInternalApi +public class AsyncChunkEncodedPayload implements ChunkedEncodedPayload { + private final ChunkedEncodedPublisher.Builder publisherBuilder; + + public AsyncChunkEncodedPayload(ChunkedEncodedPublisher.Builder publisherBuilder) { + this.publisherBuilder = publisherBuilder; + } + + @Override + public void addTrailer(TrailerProvider trailerProvider) { + publisherBuilder.addTrailer(trailerProvider); + } + + @Override + public List trailers() { + return publisherBuilder.trailers(); + } + + @Override + public void addExtension(ChunkExtensionProvider chunkExtensionProvider) { + publisherBuilder.addExtension(chunkExtensionProvider); + } + + @Override + public void checksumPayload(SdkChecksum checksum) { + Publisher checksumPayload = computeChecksum(publisherBuilder.publisher(), checksum); + publisherBuilder.publisher(checksumPayload); + } + + @Override + public void decodedContentLength(long contentLength) { + publisherBuilder.contentLength(contentLength); + } + + private Publisher computeChecksum(Publisher publisher, SdkChecksum checksum) { + return subscriber -> publisher.subscribe( + new UnbufferedChecksumSubscriber(Collections.singletonList(checksum), subscriber)); + } +} diff --git a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPayload.java b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPayload.java new file mode 100644 index 000000000000..c5c574db1c7f --- /dev/null +++ b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPayload.java @@ -0,0 +1,45 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding; + +import java.util.List; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.checksums.SdkChecksum; +import software.amazon.awssdk.http.auth.aws.internal.signer.AwsChunkedV4PayloadSigner; + +/** + * Abstraction interface to simplify payload signing in {@link AwsChunkedV4PayloadSigner} by allowing us to have a uniform + * interface for signing both sync and async payloads. See the {@code signCommon} method in {@link AwsChunkedV4PayloadSigner}. + */ +@SdkInternalApi +public interface ChunkedEncodedPayload { + void addTrailer(TrailerProvider trailerProvider); + + List trailers(); + + void addExtension(ChunkExtensionProvider chunkExtensionProvider); + + /** + * Update the payload so that its data is fed to the given checksum. + */ + void checksumPayload(SdkChecksum checksum); + + /** + * Set the decoded content length of the payload. + */ + default void decodedContentLength(long contentLength) { + } +} diff --git a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java index d0196b48f5a3..4c1e1940d209 100644 --- a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java +++ b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java @@ -61,6 +61,7 @@ */ @SdkInternalApi public class ChunkedEncodedPublisher implements Publisher { + private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0); private static final byte[] CRLF = {'\r', '\n'}; private static final byte SEMICOLON = ';'; private static final byte EQUALS = '='; @@ -72,9 +73,10 @@ public class ChunkedEncodedPublisher implements Publisher { private final List extensions = new ArrayList<>(); private final List trailers = new ArrayList<>(); private final int chunkSize; - private ByteBuffer chunkBuffer; private final boolean addEmptyTrailingChunk; + private ByteBuffer chunkBuffer; + public ChunkedEncodedPublisher(Builder b) { this.wrapped = b.publisher; this.contentLength = Validate.notNull(b.contentLength, "contentLength must not be null"); @@ -82,35 +84,44 @@ public ChunkedEncodedPublisher(Builder b) { this.extensions.addAll(b.extensions); this.trailers.addAll(b.trailers); this.addEmptyTrailingChunk = b.addEmptyTrailingChunk; + this.chunkBuffer = ByteBuffer.allocate(chunkSize); } @Override public void subscribe(Subscriber subscriber) { + resetState(); + Publisher lengthEnforced = limitLength(wrapped, contentLength); Publisher> chunked = chunk(lengthEnforced); Publisher> trailingAdded = addTrailingChunks(chunked); Publisher flattened = flatten(trailingAdded); - Publisher encoded = map(flattened, this::encodeChunk); - encoded.subscribe(subscriber); + flattened.subscribe(subscriber); } public static Builder builder() { return new Builder(); } + private void resetState() { + extensions.forEach(Resettable::reset); + trailers.forEach(Resettable::reset); + chunkBuffer = ByteBuffer.allocate(chunkSize); + } + private Iterable> getTrailingChunks() { List trailing = new ArrayList<>(); if (chunkBuffer != null) { chunkBuffer.flip(); if (chunkBuffer.hasRemaining()) { - trailing.add(chunkBuffer); + trailing.add(encodeChunk(chunkBuffer)); + chunkBuffer = null; } } if (addEmptyTrailingChunk) { - trailing.add(ByteBuffer.allocate(0)); + trailing.add(encodeChunk(EMPTY_BUFFER.duplicate())); } return Collections.singletonList(trailing); @@ -165,6 +176,7 @@ private ByteBuffer encodeChunk(ByteBuffer byteBuffer) { .mapToInt(t -> t.remaining() + CRLF.length) .sum(); + int encodedLen = chunkSizeHex.length + extensionsLength + CRLF.length + contentLen + trailerLen + CRLF.length; if (isTrailerChunk) { @@ -263,40 +275,60 @@ protected ChunkingSubscriber(Subscriber> subscriber } @Override - public void onNext(ByteBuffer byteBuffer) { - if (chunkBuffer == null) { - chunkBuffer = ByteBuffer.allocate(chunkSize); - } + public void onNext(ByteBuffer incomingData) { + long totalAvailableBytes = (long) chunkBuffer.position() + incomingData.remaining(); + // compute the number full chunks we have currently + int numCompleteChunks = (int) (totalAvailableBytes / chunkSize); - long totalBufferedBytes = (long) chunkBuffer.position() + byteBuffer.remaining(); - int nBufferedChunks = (int) (totalBufferedBytes / chunkSize); + List encodedChunks = new ArrayList<>(numCompleteChunks); - List chunks = new ArrayList<>(nBufferedChunks); + if (numCompleteChunks > 0) { + // We have some data from the previous incomingData + if (chunkBuffer.position() > 0) { + int bytesToFill = chunkBuffer.remaining(); - if (nBufferedChunks > 0) { - for (int i = 0; i < nBufferedChunks; i++) { - ByteBuffer slice = byteBuffer.slice(); - int maxBytesToCopy = Math.min(chunkBuffer.remaining(), slice.remaining()); - slice.limit(maxBytesToCopy); + ByteBuffer dataToFillBuffer = incomingData.slice(); - chunkBuffer.put(slice); - if (!chunkBuffer.hasRemaining()) { - chunkBuffer.flip(); - chunks.add(chunkBuffer); - chunkBuffer = ByteBuffer.allocate(chunkSize); - } + dataToFillBuffer.limit(dataToFillBuffer.position() + bytesToFill); + incomingData.position(incomingData.position() + bytesToFill); - byteBuffer.position(byteBuffer.position() + maxBytesToCopy); + // At this point, we know chunkBuffer is full since incomingData has at least enough bytes to make up a full + // chunk along with the data already in chunkBuffer + chunkBuffer.put(dataToFillBuffer); + chunkBuffer.flip(); + encodedChunks.add(encodeChunk(chunkBuffer)); + + chunkBuffer.flip(); + + numCompleteChunks--; + } + + // Now encode all the remaining full chunks from incomingData. + // At this point chunkBuffer has no data in it; slice off chunks from incomingData and encode directly + for (int i = 0; i < numCompleteChunks; i++) { + ByteBuffer chunkData = incomingData.slice(); + + int maxChunkBytes = Math.min(chunkData.limit(), chunkSize); + chunkData.limit(maxChunkBytes); + + incomingData.position(incomingData.position() + chunkData.remaining()); + + if (chunkData.remaining() >= chunkSize) { + chunkData.limit(chunkData.position() + chunkSize); + encodedChunks.add(encodeChunk(chunkData)); + } else { + chunkBuffer.put(chunkData); + } } - if (byteBuffer.hasRemaining()) { - chunkBuffer.put(byteBuffer); + if (incomingData.hasRemaining()) { + chunkBuffer.put(incomingData); } } else { - chunkBuffer.put(byteBuffer); + chunkBuffer.put(incomingData); } - subscriber.onNext(chunks); + subscriber.onNext(encodedChunks); } } diff --git a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/SyncChunkEncodedPayload.java b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/SyncChunkEncodedPayload.java new file mode 100644 index 000000000000..2fc67519a758 --- /dev/null +++ b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/SyncChunkEncodedPayload.java @@ -0,0 +1,56 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding; + +import java.util.Collections; +import java.util.List; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.checksums.SdkChecksum; +import software.amazon.awssdk.http.auth.aws.internal.signer.io.ChecksumInputStream; + +@SdkInternalApi +public class SyncChunkEncodedPayload implements ChunkedEncodedPayload { + private final ChunkedEncodedInputStream.Builder chunkedInputStream; + + public SyncChunkEncodedPayload(ChunkedEncodedInputStream.Builder chunkedInputStream) { + this.chunkedInputStream = chunkedInputStream; + } + + @Override + public void addTrailer(TrailerProvider trailerProvider) { + chunkedInputStream.addTrailer(trailerProvider); + } + + @Override + public List trailers() { + return chunkedInputStream.trailers(); + } + + @Override + public void addExtension(ChunkExtensionProvider chunkExtensionProvider) { + chunkedInputStream.addExtension(chunkExtensionProvider); + } + + @Override + public void checksumPayload(SdkChecksum checksum) { + ChecksumInputStream checksumInputStream = new ChecksumInputStream( + chunkedInputStream.inputStream(), + Collections.singleton(checksum) + ); + + chunkedInputStream.inputStream(checksumInputStream); + } +} diff --git a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/io/UnbufferedChecksumSubscriber.java b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/io/UnbufferedChecksumSubscriber.java new file mode 100644 index 000000000000..2163fc6e7480 --- /dev/null +++ b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/io/UnbufferedChecksumSubscriber.java @@ -0,0 +1,68 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.auth.aws.internal.signer.io; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.checksums.SdkChecksum; + +/** + * A decorating {@code Subscriber} that updates a list of {@code SdkChecksum}s with the data of each buffer given to + * {@code onNext}. + *

+ * This is "unbuffered", as opposed to {@link ChecksumSubscriber} which does buffer the data. + */ +@SdkInternalApi +public class UnbufferedChecksumSubscriber implements Subscriber { + private final List checksums; + private final Subscriber wrapped; + + public UnbufferedChecksumSubscriber(List checksums, Subscriber wrapped) { + this.checksums = new ArrayList<>(checksums); + this.wrapped = wrapped; + } + + @Override + public void onSubscribe(Subscription subscription) { + if (subscription == null) { + throw new NullPointerException("subscription is null"); + } + wrapped.onSubscribe(subscription); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + checksums.forEach(ck -> ck.update(byteBuffer.duplicate())); + wrapped.onNext(byteBuffer); + } + + @Override + public void onError(Throwable throwable) { + if (throwable == null) { + throw new NullPointerException("throwable is null"); + } + wrapped.onError(throwable); + } + + @Override + public void onComplete() { + wrapped.onComplete(); + } +} diff --git a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/util/LengthCalculatingSubscriber.java b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/util/LengthCalculatingSubscriber.java new file mode 100644 index 000000000000..054f1761a162 --- /dev/null +++ b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/util/LengthCalculatingSubscriber.java @@ -0,0 +1,58 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.auth.aws.internal.signer.util; + +import java.nio.ByteBuffer; +import java.util.concurrent.CompletableFuture; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; + +@SdkInternalApi +class LengthCalculatingSubscriber implements Subscriber { + private final CompletableFuture contentLengthFuture = new CompletableFuture<>(); + private Subscription subscription; + private long length = 0; + + @Override + public void onSubscribe(Subscription subscription) { + if (this.subscription == null) { + this.subscription = subscription; + this.subscription.request(Long.MAX_VALUE); + } else { + subscription.cancel(); + } + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + length += byteBuffer.remaining(); + } + + @Override + public void onError(Throwable throwable) { + contentLengthFuture.completeExceptionally(throwable); + } + + @Override + public void onComplete() { + contentLengthFuture.complete(length); + } + + public CompletableFuture contentLengthFuture() { + return contentLengthFuture; + } +} diff --git a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/util/SignerUtils.java b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/util/SignerUtils.java index e4e0b711eb9e..c8a3d1bc3ffc 100644 --- a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/util/SignerUtils.java +++ b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/util/SignerUtils.java @@ -26,8 +26,10 @@ import java.time.ZoneId; import java.time.format.DateTimeFormatter; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import javax.crypto.Mac; import javax.crypto.spec.SecretKeySpec; +import org.reactivestreams.Publisher; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.checksums.SdkChecksum; import software.amazon.awssdk.http.ContentStreamProvider; @@ -37,6 +39,7 @@ import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; import software.amazon.awssdk.utils.BinaryUtils; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Pair; import software.amazon.awssdk.utils.http.SdkHttpUtils; /** @@ -198,7 +201,7 @@ public static void addDateHeader(SdkHttpRequest.Builder requestBuilder, String d * Move `Content-Length` to `x-amz-decoded-content-length` if not already present. If `Content-Length` is not present, then * the payload is read in its entirety to calculate the length. */ - public static long moveContentLength(SdkHttpRequest.Builder request, ContentStreamProvider contentStreamProvider) { + public static long computeAndMoveContentLength(SdkHttpRequest.Builder request, ContentStreamProvider contentStreamProvider) { Optional decodedContentLength = request.firstMatchingHeader(X_AMZ_DECODED_CONTENT_LENGTH); if (decodedContentLength.isPresent()) { @@ -221,6 +224,41 @@ public static long moveContentLength(SdkHttpRequest.Builder request, ContentStre return contentLength; } + /** + * Move `Content-Length` to `x-amz-decoded-content-length` if not already present. If `Content-Length` is not present, then + * the payload is read in its entirety to calculate the length. + */ + public static CompletableFuture>>> computeAndMoveContentLength( + SdkHttpRequest.Builder request, Publisher contentPublisher) { + Optional decodedContentLength = request.firstMatchingHeader(X_AMZ_DECODED_CONTENT_LENGTH); + + if (decodedContentLength.isPresent()) { + request.removeHeader(Header.CONTENT_LENGTH); + return CompletableFuture.completedFuture(Pair.of(request, Optional.of(contentPublisher))); + } + + CompletableFuture contentLengthFuture; + + Optional contentLengthFromHeader = + request.firstMatchingHeader(Header.CONTENT_LENGTH); + if (contentLengthFromHeader.isPresent()) { + long contentLength = Long.parseLong(contentLengthFromHeader.get()); + contentLengthFuture = CompletableFuture.completedFuture(contentLength); + } else { + if (contentPublisher == null) { + contentLengthFuture = CompletableFuture.completedFuture(0L); + } else { + throw new UnsupportedOperationException("Content-Length header must be specified"); + } + } + + return contentLengthFuture.thenApply(cl -> { + request.putHeader(X_AMZ_DECODED_CONTENT_LENGTH, String.valueOf(cl)) + .removeHeader(Header.CONTENT_LENGTH); + return Pair.of(request, Optional.ofNullable(contentPublisher)); + }); + } + public static InputStream getBinaryRequestPayloadStream(ContentStreamProvider streamProvider) { try { if (streamProvider == null) { diff --git a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/TestUtils.java b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/TestUtils.java index 954ab243bf11..80aad87d88c8 100644 --- a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/TestUtils.java +++ b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/TestUtils.java @@ -4,6 +4,7 @@ import static software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner.SERVICE_SIGNING_NAME; import static software.amazon.awssdk.http.auth.spi.signer.HttpSigner.SIGNING_CLOCK; +import io.reactivex.Flowable; import java.io.ByteArrayInputStream; import java.net.URI; import java.nio.ByteBuffer; @@ -13,6 +14,7 @@ import java.time.ZoneId; import java.time.ZoneOffset; import java.util.function.Consumer; +import org.reactivestreams.Publisher; import software.amazon.awssdk.http.SdkHttpMethod; import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.http.auth.spi.signer.AsyncSignRequest; @@ -58,16 +60,15 @@ public static AsyncSignRequest generateBas Consumer requestOverrides, Consumer> signRequestOverrides ) { - SimplePublisher publisher = new SimplePublisher<>(); - publisher.send(ByteBuffer.wrap(testPayload())); - publisher.complete(); + Publisher publisher = Flowable.just(ByteBuffer.wrap(testPayload())); return AsyncSignRequest.builder(credentials) .request(SdkHttpRequest.builder() .protocol("https") .method(SdkHttpMethod.POST) .putHeader("Host", "demo.us-east-1.amazonaws.com") + .putHeader("content-length", Integer.toString(testPayload().length)) .putHeader("x-amz-archive-description", "test test") .encodedPath("/") .uri(URI.create("https://demo.us-east-1.amazonaws.com")) diff --git a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/AwsChunkedV4PayloadSignerTest.java b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/AwsChunkedV4PayloadSignerTest.java index 01ad8b847151..3c626ec91daa 100644 --- a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/AwsChunkedV4PayloadSignerTest.java +++ b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/AwsChunkedV4PayloadSignerTest.java @@ -17,22 +17,35 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; import static software.amazon.awssdk.checksums.DefaultChecksumAlgorithm.CRC32; import static software.amazon.awssdk.checksums.DefaultChecksumAlgorithm.SHA256; +import io.reactivex.subscribers.TestSubscriber; import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.net.URI; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.time.Instant; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import software.amazon.awssdk.http.ContentStreamProvider; import software.amazon.awssdk.http.Header; import software.amazon.awssdk.http.SdkHttpMethod; import software.amazon.awssdk.http.SdkHttpRequest; +import software.amazon.awssdk.utils.Pair; /** * Test the delegation of signing to the correct implementations. @@ -43,9 +56,7 @@ public class AwsChunkedV4PayloadSignerTest { CredentialScope credentialScope = new CredentialScope("us-east-1", "s3", Instant.EPOCH); - byte[] data = "{\"TableName\": \"foo\"}".getBytes(); - - ContentStreamProvider payload = () -> new ByteArrayInputStream(data); + static final byte[] data = "{\"TableName\": \"foo\"}".getBytes(); SdkHttpRequest.Builder requestBuilder; @@ -61,8 +72,9 @@ public void setUp() { .uri(URI.create("http://demo.us-east-1.amazonaws.com")); } - @Test - public void sign_withSignedPayload_shouldChunkEncodeWithSigV4Ext() throws IOException { + @ParameterizedTest(name = "{0}") + @MethodSource("signingImpls") + void sign_withSignedPayload_shouldChunkEncodeWithSigV4Ext(String name, SigningImplementation impl) { String expectedContent = "4;chunk-signature=082f5b0e588893570e152b401a886161ee772ed066948f68c8f01aee11cca4f8\r\n{\"Ta\r\n" + "4;chunk-signature=777b02ec61ce7934578b1efe6fbe08c21ae4a8cdf66a709d3b4fd320dddd2839\r\nbleN\r\n" + @@ -89,20 +101,21 @@ public void sign_withSignedPayload_shouldChunkEncodeWithSigV4Ext() throws IOExce .chunkSize(chunkSize) .build(); - signer.beforeSigning(requestBuilder, null); - ContentStreamProvider signedPayload = signer.sign(payload, requestSigningResult); + Pair signingResult = impl.sign(signer, requestSigningResult); - assertThat(requestBuilder.firstMatchingHeader("x-amz-decoded-content-length")).hasValue(Integer.toString(data.length)); + SdkHttpRequest.Builder finalRequest = signingResult.left(); + byte[] payloadBytes = signingResult.right(); - byte[] tmp = new byte[1024]; - int actualBytes = readAll(signedPayload.newStream(), tmp); + assertThat(finalRequest.firstMatchingHeader("x-amz-decoded-content-length")).hasValue(Integer.toString(data.length)); - assertThat(requestBuilder.firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue(Integer.toString(actualBytes)); - assertEquals(expectedContent, new String(tmp, 0, actualBytes)); + assertThat(finalRequest.firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue(Integer.toString(payloadBytes.length)); + assertThat(new String(payloadBytes, StandardCharsets.UTF_8)).isEqualTo(expectedContent); } - @Test - public void sign_withSignedPayloadAndChecksum_shouldChunkEncodeWithSigV4ExtAndSigV4Trailer() throws IOException { + @ParameterizedTest(name = "{0}") + @MethodSource("signingImpls") + void sign_withSignedPayloadAndChecksum_shouldChunkEncodeWithSigV4ExtAndSigV4Trailer(String name, + SigningImplementation impl) { String expectedContent = "4;chunk-signature=082f5b0e588893570e152b401a886161ee772ed066948f68c8f01aee11cca4f8\r\n{\"Ta\r\n" + "4;chunk-signature=777b02ec61ce7934578b1efe6fbe08c21ae4a8cdf66a709d3b4fd320dddd2839\r\nbleN\r\n" + @@ -132,21 +145,20 @@ public void sign_withSignedPayloadAndChecksum_shouldChunkEncodeWithSigV4ExtAndSi .checksumAlgorithm(CRC32) .build(); - signer.beforeSigning(requestBuilder, payload); - ContentStreamProvider signedPayload = signer.sign(payload, requestSigningResult); + Pair signingResult = impl.sign(signer, requestSigningResult); + SdkHttpRequest.Builder finalRequest = signingResult.left(); + byte[] payloadBytes = signingResult.right(); - assertThat(requestBuilder.firstMatchingHeader("x-amz-decoded-content-length")).hasValue(Integer.toString(data.length)); - assertThat(requestBuilder.firstMatchingHeader("x-amz-trailer")).hasValue("x-amz-checksum-crc32"); + assertThat(finalRequest.firstMatchingHeader("x-amz-decoded-content-length")).hasValue(Integer.toString(data.length)); + assertThat(finalRequest.firstMatchingHeader("x-amz-trailer")).hasValue("x-amz-checksum-crc32"); - byte[] tmp = new byte[1024]; - int actualBytes = readAll(signedPayload.newStream(), tmp); - - assertThat(requestBuilder.firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue(Integer.toString(actualBytes)); - assertEquals(expectedContent, new String(tmp, 0, actualBytes)); + assertThat(finalRequest.firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue(Integer.toString(payloadBytes.length)); + assertThat(new String(payloadBytes, StandardCharsets.UTF_8)).isEqualTo(expectedContent); } - @Test - public void sign_withChecksum_shouldChunkEncodeWithChecksumTrailer() throws IOException { + @ParameterizedTest(name = "{0}") + @MethodSource("signingImpls") + void sign_withChecksum_shouldChunkEncodeWithChecksumTrailer(String name, SigningImplementation impl) { String expectedContent = "4\r\n{\"Ta\r\n" + "4\r\nbleN\r\n" + @@ -175,21 +187,20 @@ public void sign_withChecksum_shouldChunkEncodeWithChecksumTrailer() throws IOEx .checksumAlgorithm(SHA256) .build(); - signer.beforeSigning(requestBuilder, payload); - ContentStreamProvider signedPayload = signer.sign(payload, requestSigningResult); + Pair signingResult = impl.sign(signer, requestSigningResult); + SdkHttpRequest.Builder finalRequest = signingResult.left(); + byte[] payloadBytes = signingResult.right(); - assertThat(requestBuilder.firstMatchingHeader("x-amz-decoded-content-length")).hasValue(Integer.toString(data.length)); - assertThat(requestBuilder.firstMatchingHeader("x-amz-trailer")).hasValue("x-amz-checksum-sha256"); + assertThat(finalRequest.firstMatchingHeader("x-amz-decoded-content-length")).hasValue(Integer.toString(data.length)); + assertThat(finalRequest.firstMatchingHeader("x-amz-trailer")).hasValue("x-amz-checksum-sha256"); - byte[] tmp = new byte[1024]; - int actualBytes = readAll(signedPayload.newStream(), tmp); - - assertThat(requestBuilder.firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue(Integer.toString(actualBytes)); - assertEquals(expectedContent, new String(tmp, 0, actualBytes)); + assertThat(finalRequest.firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue(Integer.toString(payloadBytes.length)); + assertThat(new String(payloadBytes, StandardCharsets.UTF_8)).isEqualTo(expectedContent); } - @Test - public void sign_withPreExistingTrailers_shouldChunkEncodeWithExistingTrailers() throws IOException { + @ParameterizedTest(name = "{0}") + @MethodSource("signingImpls") + void sign_withPreExistingTrailers_shouldChunkEncodeWithExistingTrailers(String name, SigningImplementation impl) { String expectedContent = "4\r\n{\"Ta\r\n" + "4\r\nbleN\r\n" + @@ -225,23 +236,22 @@ public void sign_withPreExistingTrailers_shouldChunkEncodeWithExistingTrailers() .chunkSize(chunkSize) .build(); - signer.beforeSigning(requestBuilder, payload); - ContentStreamProvider signedPayload = signer.sign(payload, requestSigningResult); - - assertThat(requestBuilder.firstMatchingHeader("x-amz-decoded-content-length")).hasValue(Integer.toString(data.length)); - assertThat(requestBuilder.firstMatchingHeader("PreExistingHeader1")).isNotPresent(); - assertThat(requestBuilder.firstMatchingHeader("PreExistingHeader2")).isNotPresent(); - assertThat(requestBuilder.matchingHeaders("x-amz-trailer")).contains("PreExistingHeader1", "PreExistingHeader2"); + Pair signingResult = impl.sign(signer, requestSigningResult); + SdkHttpRequest.Builder finalRequest = signingResult.left(); + byte[] payloadBytes = signingResult.right(); - byte[] tmp = new byte[1024]; - int actualBytes = readAll(signedPayload.newStream(), tmp); + assertThat(finalRequest.firstMatchingHeader("x-amz-decoded-content-length")).hasValue(Integer.toString(data.length)); + assertThat(finalRequest.firstMatchingHeader("PreExistingHeader1")).isNotPresent(); + assertThat(finalRequest.firstMatchingHeader("PreExistingHeader2")).isNotPresent(); + assertThat(finalRequest.matchingHeaders("x-amz-trailer")).contains("PreExistingHeader1", "PreExistingHeader2"); - assertThat(requestBuilder.firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue(Integer.toString(actualBytes)); - assertEquals(expectedContent, new String(tmp, 0, actualBytes)); + assertThat(finalRequest.firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue(Integer.toString(payloadBytes.length)); + assertThat(new String(payloadBytes, StandardCharsets.UTF_8)).isEqualTo(expectedContent); } - @Test - public void sign_withPreExistingTrailersAndChecksum_shouldChunkEncodeWithTrailers() throws IOException { + @ParameterizedTest(name = "{0}") + @MethodSource("signingImpls") + void sign_withPreExistingTrailersAndChecksum_shouldChunkEncodeWithTrailers(String name, SigningImplementation impl) { String expectedContent = "4\r\n{\"Ta\r\n" + "4\r\nbleN\r\n" + @@ -279,25 +289,25 @@ public void sign_withPreExistingTrailersAndChecksum_shouldChunkEncodeWithTrailer .checksumAlgorithm(CRC32) .build(); - signer.beforeSigning(requestBuilder, payload); - ContentStreamProvider signedPayload = signer.sign(payload, requestSigningResult); + Pair signingResult = impl.sign(signer, requestSigningResult); + SdkHttpRequest.Builder finalRequest = signingResult.left(); + byte[] payloadBytes = signingResult.right(); - assertThat(requestBuilder.firstMatchingHeader("x-amz-decoded-content-length")).hasValue(Integer.toString(data.length)); - assertThat(requestBuilder.firstMatchingHeader("PreExistingHeader1")).isNotPresent(); - assertThat(requestBuilder.firstMatchingHeader("PreExistingHeader2")).isNotPresent(); - assertThat(requestBuilder.matchingHeaders("x-amz-trailer")).contains( + assertThat(finalRequest.firstMatchingHeader("x-amz-decoded-content-length")).hasValue(Integer.toString(data.length)); + assertThat(finalRequest.firstMatchingHeader("PreExistingHeader1")).isNotPresent(); + assertThat(finalRequest.firstMatchingHeader("PreExistingHeader2")).isNotPresent(); + assertThat(finalRequest.matchingHeaders("x-amz-trailer")).contains( "PreExistingHeader1", "PreExistingHeader2", "x-amz-checksum-crc32" ); - byte[] tmp = new byte[1024]; - int actualBytes = readAll(signedPayload.newStream(), tmp); - - assertThat(requestBuilder.firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue(Integer.toString(actualBytes)); - assertEquals(expectedContent, new String(tmp, 0, actualBytes)); + assertThat(finalRequest.firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue(Integer.toString(payloadBytes.length)); + assertThat(new String(payloadBytes, StandardCharsets.UTF_8)).isEqualTo(expectedContent); } - @Test - public void sign_withPreExistingTrailersAndChecksumAndSignedPayload_shouldAwsChunkEncode() throws IOException { + @ParameterizedTest(name = "{0}") + @MethodSource("signingImpls") + void sign_withPreExistingTrailersAndChecksumAndSignedPayload_shouldAwsChunkEncode(String name, + SigningImplementation impl) { String expectedContent = "4;chunk-signature=082f5b0e588893570e152b401a886161ee772ed066948f68c8f01aee11cca4f8\r\n{\"Ta\r\n" + "4;chunk-signature=777b02ec61ce7934578b1efe6fbe08c21ae4a8cdf66a709d3b4fd320dddd2839\r\nbleN\r\n" + @@ -335,23 +345,21 @@ public void sign_withPreExistingTrailersAndChecksumAndSignedPayload_shouldAwsChu .checksumAlgorithm(CRC32) .build(); - signer.beforeSigning(requestBuilder, payload); - ContentStreamProvider signedPayload = signer.sign(payload, requestSigningResult); + Pair signingResult = impl.sign(signer, requestSigningResult); + SdkHttpRequest.Builder finalRequest = signingResult.left(); + byte[] payloadBytes = signingResult.right(); - assertThat(requestBuilder.firstMatchingHeader("x-amz-decoded-content-length")).hasValue(Integer.toString(data.length)); - assertThat(requestBuilder.firstMatchingHeader("PreExistingHeader1")).isNotPresent(); - assertThat(requestBuilder.matchingHeaders("x-amz-trailer")).contains("zzz", "PreExistingHeader1", "x-amz-checksum-crc32"); - - byte[] tmp = new byte[1024]; - int actualBytes = readAll(signedPayload.newStream(), tmp); + assertThat(finalRequest.firstMatchingHeader("x-amz-decoded-content-length")).hasValue(Integer.toString(data.length)); + assertThat(finalRequest.firstMatchingHeader("PreExistingHeader1")).isNotPresent(); + assertThat(finalRequest.matchingHeaders("x-amz-trailer")).contains("zzz", "PreExistingHeader1", "x-amz-checksum-crc32"); - assertThat(requestBuilder.firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue(Integer.toString(actualBytes)); - assertEquals(expectedContent, new String(tmp, 0, actualBytes)); + assertThat(finalRequest.firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue(Integer.toString(payloadBytes.length)); + assertThat(new String(payloadBytes, StandardCharsets.UTF_8)).isEqualTo(expectedContent); } - - @Test - public void sign_withoutContentLength_calculatesContentLengthFromPayload() throws IOException { + @ParameterizedTest(name = "{0}") + @MethodSource("signingImpls") + void sign_withoutContentLength_calculatesContentLengthFromPayload(String name, SigningImplementation impl) { String expectedContent = "4\r\n{\"Ta\r\n" + "4\r\nbleN\r\n" + @@ -382,21 +390,19 @@ public void sign_withoutContentLength_calculatesContentLengthFromPayload() throw .checksumAlgorithm(SHA256) .build(); - signer.beforeSigning(requestBuilder, payload); - ContentStreamProvider signedPayload = signer.sign(payload, requestSigningResult); - - assertThat(requestBuilder.firstMatchingHeader("x-amz-decoded-content-length")).hasValue(Integer.toString(data.length)); - assertThat(requestBuilder.firstMatchingHeader("x-amz-trailer")).hasValue("x-amz-checksum-sha256"); + Pair signingResult = impl.sign(signer, requestSigningResult); + SdkHttpRequest.Builder finalRequest = signingResult.left(); + byte[] payloadBytes = signingResult.right(); - byte[] tmp = new byte[1024]; - int actualBytes = readAll(signedPayload.newStream(), tmp); + assertThat(finalRequest.firstMatchingHeader("x-amz-decoded-content-length")).hasValue(Integer.toString(data.length)); + assertThat(finalRequest.firstMatchingHeader("x-amz-trailer")).hasValue("x-amz-checksum-sha256"); - assertThat(requestBuilder.firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue(Integer.toString(actualBytes)); - assertEquals(expectedContent, new String(tmp, 0, actualBytes)); + assertThat(finalRequest.firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue(Integer.toString(payloadBytes.length)); + assertThat(new String(payloadBytes, StandardCharsets.UTF_8)).isEqualTo(expectedContent); } @Test - public void sign_shouldReturnResettableContentStreamProvider() throws IOException { + void sign_shouldReturnResettableContentStreamProvider() throws IOException { String expectedContent = "4;chunk-signature=082f5b0e588893570e152b401a886161ee772ed066948f68c8f01aee11cca4f8\r\n{\"Ta\r\n" + "4;chunk-signature=777b02ec61ce7934578b1efe6fbe08c21ae4a8cdf66a709d3b4fd320dddd2839\r\nbleN\r\n" + @@ -423,6 +429,8 @@ public void sign_shouldReturnResettableContentStreamProvider() throws IOExceptio .chunkSize(chunkSize) .build(); + ContentStreamProvider payload = () -> new ByteArrayInputStream(data); + signer.beforeSigning(requestBuilder, payload); ContentStreamProvider signedPayload = signer.sign(payload, requestSigningResult); @@ -436,13 +444,58 @@ public void sign_shouldReturnResettableContentStreamProvider() throws IOExceptio } @Test - public void signAsync_throws() { + void signAsync_shouldReturnSameContentToAllSubscriptions() { + String expectedContent = + "4;chunk-signature=082f5b0e588893570e152b401a886161ee772ed066948f68c8f01aee11cca4f8\r\n{\"Ta\r\n" + + "4;chunk-signature=777b02ec61ce7934578b1efe6fbe08c21ae4a8cdf66a709d3b4fd320dddd2839\r\nbleN\r\n" + + "4;chunk-signature=84abdae650f64dee4d703d41c7d87c8bc251c22b8c493c75ce24431b60b73937\r\name\"\r\n" + + "4;chunk-signature=aff22ddad9d4388233fe9bc47e9c552a6e9ba9285af79555d2ce7fdaab726320\r\n: \"f\r\n" + + "4;chunk-signature=30e55f4e1c1fd444c06e9be42d9594b8fd7ead436bc67a58b5350ffd58b6aaa5\r\noo\"}\r\n" + + "0;chunk-signature=825ad80195cae47f54984835543ff2179c2c5a53c324059cd632e50259384ee3\r\n\r\n"; + + requestBuilder.putHeader("x-amz-content-sha256", "STREAMING-AWS4-HMAC-SHA256-PAYLOAD"); + V4CanonicalRequest canonicalRequest = new V4CanonicalRequest( + requestBuilder.build(), + "STREAMING-AWS4-HMAC-SHA256-PAYLOAD", + new V4CanonicalRequest.Options(true, true) + ); + V4RequestSigningResult requestSigningResult = new V4RequestSigningResult( + "STREAMING-AWS4-HMAC-SHA256-PAYLOAD", + "key".getBytes(StandardCharsets.UTF_8), + "sig", + canonicalRequest, + requestBuilder + ); AwsChunkedV4PayloadSigner signer = AwsChunkedV4PayloadSigner.builder() .credentialScope(credentialScope) .chunkSize(chunkSize) .build(); - assertThrows(UnsupportedOperationException.class, () -> signer.signAsync(null, null)); + TestPublisher payload = new TestPublisher(data); + + Pair>> beforeSigningResult = + signer.beforeSigningAsync(requestBuilder, payload).join(); + + Publisher signedPayload = signer.signAsync(beforeSigningResult.right().get(), requestSigningResult); + + // successive subscriptions should result in the same data + for (int i = 0; i < 2; i++) { + TestSubscriber subscriber = new TestSubscriber<>(); + signedPayload.subscribe(subscriber); + + subscriber.awaitTerminalEvent(5, TimeUnit.SECONDS); + subscriber.assertComplete(); + + List signedData = subscriber.values(); + + int signedDataSum = signedData.stream().mapToInt(ByteBuffer::remaining).sum(); + byte[] array = new byte[signedDataSum]; + + ByteBuffer combined = ByteBuffer.wrap(array); + signedData.forEach(combined::put); + + assertThat(new String(array, StandardCharsets.UTF_8)).isEqualTo(expectedContent); + } } private int readAll(InputStream src, byte[] dst) throws IOException { @@ -457,4 +510,93 @@ private int readAll(InputStream src, byte[] dst) throws IOException { } return offset; } + + public static Stream signingImpls() { + return Stream.of( + Arguments.of("ASYNC", (SigningImplementation) AwsChunkedV4PayloadSignerTest::doSignAsync), + Arguments.of("SYNC", (SigningImplementation) AwsChunkedV4PayloadSignerTest::doSign) + ); + } + + private static Pair doSign(AwsChunkedV4PayloadSigner signer, + V4RequestSigningResult requestSigningResult) { + SdkHttpRequest.Builder request = requestSigningResult.getSignedRequest(); + + ContentStreamProvider payload = () -> new ByteArrayInputStream(data); + + signer.beforeSigning(request, payload); + ContentStreamProvider signedPayload = signer.sign(payload, requestSigningResult); + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try { + InputStream is = signedPayload.newStream(); + byte[] buff = new byte[1024]; + int read; + while ((read = is.read(buff)) != -1) { + baos.write(buff, 0, read); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + + return Pair.of(request, baos.toByteArray()); + } + + private static Pair doSignAsync(AwsChunkedV4PayloadSigner signer, + V4RequestSigningResult requestSigningResult) { + SdkHttpRequest.Builder request = requestSigningResult.getSignedRequest(); + + TestPublisher payload = new TestPublisher(data); + + Pair>> beforeSigningResult = + signer.beforeSigningAsync(request, payload).join(); + + request = beforeSigningResult.left(); + Publisher signedPayload = signer.signAsync(beforeSigningResult.right().get(), requestSigningResult); + + TestSubscriber subscriber = new TestSubscriber<>(); + signedPayload.subscribe(subscriber); + + subscriber.awaitTerminalEvent(5, TimeUnit.SECONDS); + subscriber.assertComplete(); + + List signedData = subscriber.values(); + + int signedDataSum = signedData.stream().mapToInt(ByteBuffer::remaining).sum(); + byte[] array = new byte[signedDataSum]; + + ByteBuffer combined = ByteBuffer.wrap(array); + signedData.forEach(combined::put); + + return Pair.of(request, array); + } + + interface SigningImplementation { + Pair sign(AwsChunkedV4PayloadSigner signer, + V4RequestSigningResult requestSigningResult); + } + + private static final class TestPublisher implements Publisher { + private final byte[] data; + + private TestPublisher(byte[] data) { + this.data = data; + } + + @Override + public void subscribe(Subscriber subscriber) { + subscriber.onSubscribe(new Subscription() { + + @Override + public void request(long l) { + subscriber.onNext(ByteBuffer.wrap(data)); + subscriber.onComplete(); + } + + @Override + public void cancel() { + } + }); + } + } } diff --git a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/DefaultAwsV4HttpSignerTest.java b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/DefaultAwsV4HttpSignerTest.java index ebd139e67963..282fc1fe67d5 100644 --- a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/DefaultAwsV4HttpSignerTest.java +++ b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/DefaultAwsV4HttpSignerTest.java @@ -16,6 +16,7 @@ package software.amazon.awssdk.http.auth.aws.internal.signer; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertThrows; import static software.amazon.awssdk.checksums.DefaultChecksumAlgorithm.CRC32; import static software.amazon.awssdk.checksums.DefaultChecksumAlgorithm.SHA256; @@ -30,17 +31,21 @@ import static software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner.PAYLOAD_SIGNING_ENABLED; import static software.amazon.awssdk.http.auth.spi.signer.SdkInternalHttpSignerProperty.CHECKSUM_STORE; +import io.reactivex.Flowable; import java.io.IOException; import java.net.URI; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.time.Duration; -import java.util.Optional; +import java.util.List; +import java.util.stream.Collectors; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.MockedStatic; import org.mockito.Mockito; +import org.reactivestreams.Publisher; import software.amazon.awssdk.checksums.SdkChecksum; import software.amazon.awssdk.checksums.spi.ChecksumAlgorithm; import software.amazon.awssdk.http.Header; @@ -419,8 +424,6 @@ void sign_WithChunkEncodingTrue_DelegatesToAwsChunkedPayloadSigner() { assertThat(signedRequest.request().firstMatchingHeader("x-amz-decoded-content-length")).hasValue("20"); } - // TODO(sra-identity-and-auth): Once chunk-encoding support in async is added, we can enable these tests. - @Disabled("Chunk-encoding is not currently supported in the Async signing path - it is handled in HttpChecksumStage for now.") @Test void signAsync_WithChunkEncodingTrue_DelegatesToAwsChunkedPayloadSigner_futureBehavior() { AsyncSignRequest request = generateBasicAsyncRequest( @@ -440,25 +443,6 @@ void signAsync_WithChunkEncodingTrue_DelegatesToAwsChunkedPayloadSigner_futureBe assertThat(signedRequest.request().firstMatchingHeader("x-amz-decoded-content-length")).hasValue("20"); } - // TODO(sra-identity-and-auth): Replace this test with the above test once chunk-encoding support is added - @Test - void signAsync_WithChunkEncodingTrue_DelegatesToAwsChunkedPayloadSigner() { - AsyncSignRequest request = generateBasicAsyncRequest( - AwsCredentialsIdentity.create("access", "secret"), - httpRequest -> httpRequest - .putHeader(Header.CONTENT_LENGTH, "20"), - signRequest -> signRequest - .putProperty(CHUNK_ENCODING_ENABLED, true) - ); - - AsyncSignedRequest signedRequest = signer.signAsync(request).join(); - - assertThat(signedRequest.request().firstMatchingHeader("x-amz-content-sha256")) - .hasValue("STREAMING-AWS4-HMAC-SHA256-PAYLOAD"); - assertThat(signedRequest.request().firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue("20"); - assertThat(signedRequest.request().firstMatchingHeader("x-amz-decoded-content-length")).isNotPresent(); - } - @Test void sign_WithChunkEncodingTrueAndChecksumAlgorithm_DelegatesToAwsChunkedPayloadSigner() { SignRequest request = generateBasicRequest( @@ -479,8 +463,6 @@ void sign_WithChunkEncodingTrueAndChecksumAlgorithm_DelegatesToAwsChunkedPayload assertThat(signedRequest.request().firstMatchingHeader("x-amz-trailer")).hasValue("x-amz-checksum-crc32"); } - // TODO(sra-identity-and-auth): Once chunk-encoding support in async is added, we can enable these tests. - @Disabled("Chunk-encoding is not currently supported in the Async signing path - it is handled in HttpChecksumStage for now.") @Test void signAsync_WithChunkEncodingTrueAndChecksumAlgorithm_DelegatesToAwsChunkedPayloadSigner_futureBehavior() { AsyncSignRequest request = generateBasicAsyncRequest( @@ -502,27 +484,6 @@ void signAsync_WithChunkEncodingTrueAndChecksumAlgorithm_DelegatesToAwsChunkedPa assertThat(signedRequest.request().firstMatchingHeader("x-amz-trailer")).hasValue("x-amz-checksum-crc32"); } - // TODO(sra-identity-and-auth): Replace this test with the above test once chunk-encoding support is added - @Test - void signAsync_WithChunkEncodingTrueAndChecksumAlgorithm_DelegatesToAwsChunkedPayloadSigner() { - AsyncSignRequest request = generateBasicAsyncRequest( - AwsCredentialsIdentity.create("access", "secret"), - httpRequest -> httpRequest - .putHeader(Header.CONTENT_LENGTH, "20"), - signRequest -> signRequest - .putProperty(CHUNK_ENCODING_ENABLED, true) - .putProperty(CHECKSUM_ALGORITHM, CRC32) - ); - - AsyncSignedRequest signedRequest = signer.signAsync(request).join(); - - assertThat(signedRequest.request().firstMatchingHeader("x-amz-content-sha256")) - .hasValue("STREAMING-AWS4-HMAC-SHA256-PAYLOAD-TRAILER"); - assertThat(signedRequest.request().firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue("20"); - assertThat(signedRequest.request().firstMatchingHeader("x-amz-decoded-content-length")).isNotPresent(); - assertThat(signedRequest.request().firstMatchingHeader("x-amz-trailer")).isNotPresent(); - } - @Test void sign_WithPayloadSigningFalseAndChunkEncodingTrueAndFlexibleChecksum_DelegatesToAwsChunkedPayloadSigner() { SignRequest request = generateBasicRequest( @@ -544,8 +505,6 @@ void sign_WithPayloadSigningFalseAndChunkEncodingTrueAndFlexibleChecksum_Delegat assertThat(signedRequest.request().firstMatchingHeader("x-amz-trailer")).hasValue("x-amz-checksum-crc32"); } - // TODO(sra-identity-and-auth): Once chunk-encoding support in async is added, we can enable these tests. - @Disabled("Chunk-encoding is not currently supported in the Async signing path - it is handled in HttpChecksumStage for now.") @Test void signAsync_WithPayloadSigningFalseAndChunkEncodingTrueAndTrailer_DelegatesToAwsChunkedPayloadSigner_futureBehavior() { AsyncSignRequest request = generateBasicAsyncRequest( @@ -568,28 +527,6 @@ void signAsync_WithPayloadSigningFalseAndChunkEncodingTrueAndTrailer_DelegatesTo assertThat(signedRequest.request().firstMatchingHeader("x-amz-trailer")).hasValue("x-amz-checksum-crc32"); } - // TODO(sra-identity-and-auth): Replace this test with the above test once chunk-encoding support is added - @Test - void signAsync_WithPayloadSigningFalseAndChunkEncodingTrueAndTrailer_DelegatesToAwsChunkedPayloadSigner() { - AsyncSignRequest request = generateBasicAsyncRequest( - AwsCredentialsIdentity.create("access", "secret"), - httpRequest -> httpRequest - .putHeader(Header.CONTENT_LENGTH, "20"), - signRequest -> signRequest - .putProperty(PAYLOAD_SIGNING_ENABLED, false) - .putProperty(CHUNK_ENCODING_ENABLED, true) - .putProperty(CHECKSUM_ALGORITHM, CRC32) - ); - - AsyncSignedRequest signedRequest = signer.signAsync(request).join(); - - assertThat(signedRequest.request().firstMatchingHeader("x-amz-content-sha256")) - .hasValue("STREAMING-UNSIGNED-PAYLOAD-TRAILER"); - assertThat(signedRequest.request().firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue("20"); - assertThat(signedRequest.request().firstMatchingHeader("x-amz-decoded-content-length")).isNotPresent(); - assertThat(signedRequest.request().firstMatchingHeader("x-amz-trailer")).isNotPresent(); - } - @Test void sign_WithPayloadSigningFalseAndChunkEncodingTrue_DelegatesToUnsignedPayload() { // Currently, there is no use-case for unsigned chunk-encoding without trailers, so we should assert it falls back to @@ -776,10 +713,10 @@ void sign_WithPayloadSigningTrueAndChunkEncodingTrueAndHttp_SignsPayload() { assertThat(signedRequest.request().firstMatchingHeader("x-amz-decoded-content-length")).hasValue("20"); } - // TODO(sra-identity-and-auth): Once chunk-encoding is implemented in the async path, the assertions this test makes should - // be different - the assertions should mirror the above case. @Test - void signAsync_WithPayloadSigningTrueAndChunkEncodingTrueAndHttp_IgnoresPayloadSigning() { + @Disabled("Fallback to signing is disabled to match pre-SRA behavior") + // TODO: Enable this test once we figure out what the expected behavior is post SRA. See JAVA-8078 + void signAsync_WithPayloadSigningTrueAndChunkEncodingTrueAndHttp_RespectsPayloadSigning() { AsyncSignRequest request = generateBasicAsyncRequest( AwsCredentialsIdentity.create("access", "secret"), httpRequest -> httpRequest.uri(URI.create("http://demo.us-east-1.amazonaws.com")), @@ -791,10 +728,14 @@ void signAsync_WithPayloadSigningTrueAndChunkEncodingTrueAndHttp_IgnoresPayloadS AsyncSignedRequest signedRequest = signer.signAsync(request).join(); assertThat(signedRequest.request().firstMatchingHeader("x-amz-content-sha256")) - .hasValue("UNSIGNED-PAYLOAD"); + .hasValue("STREAMING-AWS4-HMAC-SHA256-PAYLOAD"); + assertThat(signedRequest.request().firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue("193"); + assertThat(signedRequest.request().firstMatchingHeader("x-amz-decoded-content-length")).hasValue("20"); } @Test + @Disabled("Fallback to signing is disabled to match pre-SRA behavior") + // TODO: Enable this test once we figure out what the expected behavior is post SRA. See JAVA-8078 void sign_WithPayloadSigningFalseAndChunkEncodingTrueAndHttp_SignsPayload() { SignRequest request = generateBasicRequest( AwsCredentialsIdentity.create("access", "secret"), @@ -812,10 +753,10 @@ void sign_WithPayloadSigningFalseAndChunkEncodingTrueAndHttp_SignsPayload() { assertThat(signedRequest.request().firstMatchingHeader("x-amz-decoded-content-length")).hasValue("20"); } - // TODO(sra-identity-and-auth): Once chunk-encoding is implemented in the async path, the assertions this test makes should - // be different - the assertions should mirror the above case. @Test - void signAsync_WithPayloadSigningFalseAndChunkEncodingTrueAndHttp_DoesNotFallBackToPayloadSigning() { + @Disabled("Fallback to signing is disabled to match pre-SRA behavior") + // TODO: Enable this test once we figure out what the expected behavior is post SRA. See JAVA-8078 + void signAsync_WithPayloadSigningFalseAndChunkEncodingTrueAndHttp_FallsBackToPayloadSigning() { AsyncSignRequest request = generateBasicAsyncRequest( AwsCredentialsIdentity.create("access", "secret"), httpRequest -> httpRequest.uri(URI.create("http://demo.us-east-1.amazonaws.com")), @@ -827,7 +768,9 @@ void signAsync_WithPayloadSigningFalseAndChunkEncodingTrueAndHttp_DoesNotFallBac AsyncSignedRequest signedRequest = signer.signAsync(request).join(); assertThat(signedRequest.request().firstMatchingHeader("x-amz-content-sha256")) - .hasValue("UNSIGNED-PAYLOAD"); + .hasValue("STREAMING-AWS4-HMAC-SHA256-PAYLOAD"); + assertThat(signedRequest.request().firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue("193"); + assertThat(signedRequest.request().firstMatchingHeader("x-amz-decoded-content-length")).hasValue("20"); } @Test @@ -850,10 +793,10 @@ void sign_WithPayloadSigningFalseAndChunkEncodingTrueAndFlexibleChecksumAndHttp_ assertThat(signedRequest.request().firstMatchingHeader("x-amz-trailer")).hasValue("x-amz-checksum-crc32"); } - // TODO(sra-identity-and-auth): Once chunk-encoding is implemented in the async path, the assertions this test makes should - // be different - the assertions should mirror the above case. @Test - void signAsync_WithPayloadSigningFalseAndChunkEncodingTrueAndFlexibleChecksumAndHttp_DoesNotFallBackToPayloadSigning() { + @Disabled("Fallback to signing is disabled to match pre-SRA behavior") + // TODO: Enable this test once we figure out what the expected behavior is post SRA. See JAVA-8078 + void signAsync_WithPayloadSigningFalseAndChunkEncodingTrueAndFlexibleChecksumAndHttp_FallsBackToPayloadSigning() { AsyncSignRequest request = generateBasicAsyncRequest( AwsCredentialsIdentity.create("access", "secret"), httpRequest -> httpRequest.uri(URI.create("http://demo.us-east-1.amazonaws.com")), @@ -866,7 +809,10 @@ void signAsync_WithPayloadSigningFalseAndChunkEncodingTrueAndFlexibleChecksumAnd AsyncSignedRequest signedRequest = signer.signAsync(request).join(); assertThat(signedRequest.request().firstMatchingHeader("x-amz-content-sha256")) - .hasValue("STREAMING-UNSIGNED-PAYLOAD-TRAILER"); + .hasValue("STREAMING-AWS4-HMAC-SHA256-PAYLOAD-TRAILER"); + assertThat(signedRequest.request().firstMatchingHeader(Header.CONTENT_LENGTH)).hasValue("314"); + assertThat(signedRequest.request().firstMatchingHeader("x-amz-decoded-content-length")).hasValue("20"); + assertThat(signedRequest.request().firstMatchingHeader("x-amz-trailer")).hasValue("x-amz-checksum-crc32"); } @Test @@ -967,9 +913,84 @@ void sign_withPayloadSigningTrue_chunkEncodingFalse_withChecksum_cacheEmpty_stor assertThat(cache.getChecksumValue(CRC32)).isEqualTo(crc32Value); } + @Test + void signAsync_WithPayloadSigningFalse_chunkEncodingTrue_cacheEmpty_storesComputedChecksum() throws IOException { + PayloadChecksumStore cache = PayloadChecksumStore.create(); + + AsyncSignRequest request = generateBasicAsyncRequest( + AwsCredentialsIdentity.create("access", "secret"), + httpRequest -> httpRequest.uri(URI.create("http://demo.us-east-1.amazonaws.com")), + signRequest -> signRequest + .putProperty(PAYLOAD_SIGNING_ENABLED, false) + .putProperty(CHUNK_ENCODING_ENABLED, true) + .putProperty(CHECKSUM_ALGORITHM, CRC32) + .putProperty(CHECKSUM_STORE, cache) + ); + + AsyncSignedRequest signedRequest = signer.signAsync(request).join(); + + getAllItems(signedRequest.payload().get()); + assertThat(cache.getChecksumValue(CRC32)).isEqualTo(computeChecksum(CRC32, testPayload())); + } + + @Test + void signAsync_WithPayloadSigningFalse_chunkEncodingTrue_cacheContainsChecksum_usesCachedValue() throws IOException { + PayloadChecksumStore cache = PayloadChecksumStore.create(); + + byte[] checksumValue = "my-checksum".getBytes(StandardCharsets.UTF_8); + cache.putChecksumValue(CRC32, checksumValue); + + AsyncSignRequest request = generateBasicAsyncRequest( + AwsCredentialsIdentity.create("access", "secret"), + httpRequest -> httpRequest.uri(URI.create("http://demo.us-east-1.amazonaws.com")), + signRequest -> signRequest + .putProperty(PAYLOAD_SIGNING_ENABLED, false) + .putProperty(CHUNK_ENCODING_ENABLED, true) + .putProperty(CHECKSUM_ALGORITHM, CRC32) + .putProperty(CHECKSUM_STORE, cache) + ); + + AsyncSignedRequest signedRequest = signer.signAsync(request).join(); + + List content = getAllItems(signedRequest.payload().get()); + String contentAsString = content.stream().map(DefaultAwsV4HttpSignerTest::bufferAsString).collect(Collectors.joining()); + assertThat(contentAsString).contains("x-amz-checksum-crc32:" + BinaryUtils.toBase64(checksumValue) + "\r\n"); + } + + @Test + void signAsync_WithPayloadSigningFalse_chunkEncodingTrue_noContentLengthHeader_throws() throws IOException { + PayloadChecksumStore cache = PayloadChecksumStore.create(); + + byte[] checksumValue = "my-checksum".getBytes(StandardCharsets.UTF_8); + cache.putChecksumValue(CRC32, checksumValue); + + AsyncSignRequest request = generateBasicAsyncRequest( + AwsCredentialsIdentity.create("access", "secret"), + httpRequest -> httpRequest.uri(URI.create("http://demo.us-east-1.amazonaws.com")) + .removeHeader("content-length"), + signRequest -> signRequest + .putProperty(PAYLOAD_SIGNING_ENABLED, false) + .putProperty(CHUNK_ENCODING_ENABLED, true) + .putProperty(CHECKSUM_ALGORITHM, CRC32) + ); + + assertThatThrownBy(signer.signAsync(request)::join) + .hasCauseInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("Content-Length header must be specified"); + } + + private static byte[] computeChecksum(ChecksumAlgorithm algorithm, byte[] data) { SdkChecksum checksum = SdkChecksum.forAlgorithm(algorithm); checksum.update(data, 0, data.length); return checksum.getChecksumBytes(); } + + private List getAllItems(Publisher publisher) { + return Flowable.fromPublisher(publisher).toList().blockingGet(); + } + + private static String bufferAsString(ByteBuffer buffer) { + return StandardCharsets.UTF_8.decode(buffer.duplicate()).toString(); + } } diff --git a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java index 7f62802ecd1e..fc6acc0aba69 100644 --- a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java +++ b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java @@ -16,6 +16,7 @@ package software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.mockito.ArgumentMatchers.any; import io.reactivex.Flowable; import io.reactivex.subscribers.TestSubscriber; @@ -32,12 +33,12 @@ import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; import org.mockito.Mockito; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import software.amazon.awssdk.checksums.DefaultChecksumAlgorithm; import software.amazon.awssdk.checksums.SdkChecksum; +import software.amazon.awssdk.utils.BinaryUtils; import software.amazon.awssdk.utils.Pair; public class ChunkedEncodedPublisherTest { @@ -65,7 +66,7 @@ public void subscribe_publisherEmpty_onlyProducesTrailer() { assertThat(chunks.size()).isEqualTo(1); - String trailerAsString = StandardCharsets.UTF_8.decode(chunks.get(0)).toString(); + String trailerAsString = bufferAsString(chunks.get(0)); assertThat(trailerAsString).isEqualTo( "0\r\n" + @@ -92,8 +93,9 @@ void subscribe_trailerProviderPresent_trailerPartAdded() { List chunks = getAllElements(chunkedPublisher); String expectedTrailer = "foo:bar"; - String trailerAsString = StandardCharsets.UTF_8.decode(chunks.get(1)).toString().trim(); + String trailerAsString = bufferAsString(chunks.get(1).duplicate()).trim(); assertThat(trailerAsString).endsWith(expectedTrailer); + assertChunksHaveChecksum(chunks, upstream.wrappedChecksum()); } @Test @@ -114,8 +116,9 @@ void subscribe_trailerProviderPresent_multipleValues_trailerPartAdded() { List chunks = getAllElements(chunkedPublisher); String expectedTrailer = "foo:bar1,bar2,bar3"; - String trailerAsString = StandardCharsets.UTF_8.decode(chunks.get(1)).toString().trim(); + String trailerAsString = bufferAsString(chunks.get(1).duplicate()).trim(); assertThat(trailerAsString).endsWith(expectedTrailer); + assertChunksHaveChecksum(chunks, upstream.wrappedChecksum()); } @Test @@ -132,7 +135,7 @@ void subscribe_trailerProviderPresent_onlyInvokedOnce() { .contentLength(contentLength) .addTrailer(trailerProvider).build(); - getAllElements(chunkedPublisher); + assertChunksHaveChecksum(getAllElements(chunkedPublisher), upstream.wrappedChecksum()); Mockito.verify(trailerProvider, Mockito.times(1)).get(); } @@ -158,7 +161,8 @@ void subscribe_trailerPresent_trailerFormattedCorrectly() { "foo:bar\r\n" + "\r\n"; - assertThat(chunkAsString(last)).isEqualTo(expected); + assertThat(bufferAsString(last.duplicate())).isEqualTo(expected); + assertChunksHaveChecksum(chunks, testPublisher.wrappedChecksum()); } @Test @@ -176,8 +180,9 @@ void subscribe_wrappedDoesNotFillBuffer_allDataInSingleChunk() { List chunks = getAllElements(publisher); assertThat(chunks.size()).isEqualTo(1); - assertThat(stripEncoding(chunks.get(0))) + assertThat(stripEncoding(chunks.get(0).duplicate())) .isEqualTo(element); + assertChunksHaveChecksum(chunks, crc32(content)); } @Test @@ -197,7 +202,8 @@ void subscribe_extensionHasNoValue_formattedCorrectly() { List chunks = getAllElements(chunkPublisher); - assertThat(getHeaderAsString(chunks.get(0))).endsWith(";foo"); + assertThat(getHeaderAsString(chunks.get(0).duplicate())).endsWith(";foo"); + assertChunksHaveChecksum(chunks, testPublisher.wrappedChecksum()); } @Test @@ -217,7 +223,8 @@ void subscribe_multipleExtensions_formattedCorrectly() { List chunks = getAllElements(chunkPublisher.build()); - chunks.forEach(chunk -> assertThat(getHeaderAsString(chunk)).endsWith(";key1=value1;key2=value2;key3=value3")); + chunks.forEach(chunk -> assertThat(getHeaderAsString(chunk.duplicate())).endsWith(";key1=value1;key2=value2;key3=value3")); + assertChunksHaveChecksum(chunks, testPublisher.wrappedChecksum()); } @Test @@ -264,9 +271,10 @@ void subscribe_randomElementSizes_chunksHaveExtensions_dataChunkedCorrectly() { .build(); List chunks = getAllElements(chunkedPublisher); + assertThat(chunks.size()).isEqualTo(24); chunks.forEach(c -> { - String header = StandardCharsets.UTF_8.decode(getHeader(c)).toString(); + String header = bufferAsString(getHeader(c.duplicate())); assertThat(header).isEqualTo("4000;foo=bar"); }); @@ -300,7 +308,8 @@ void subscribe_addTrailingChunkTrue_trailingChunkAdded() { assertThat(chunks.size()).isEqualTo(3); ByteBuffer trailing = chunks.get(chunks.size() - 1); - assertThat(stripEncoding(trailing).remaining()).isEqualTo(0); + assertThat(stripEncoding(trailing.duplicate()).remaining()).isEqualTo(0); + assertChunksHaveChecksum(chunks, testPublisher.wrappedChecksum()); } @Test @@ -318,35 +327,35 @@ void subscribe_addTrailingChunkTrue_upstreamEmpty_trailingChunkAdded() { List chunks = getAllElements(chunkedPublisher); assertThat(chunks.size()).isEqualTo(1); + assertChunksHaveChecksum(chunks, crc32(new byte[0])); } @Test void subscribe_extensionsPresent_extensionsInvokedForEachChunk() { - ChunkExtensionProvider mockProvider = Mockito.spy(new StaticExtensionProvider("foo", "bar")); + StaticExtensionProvider mockProvider = Mockito.spy(new StaticExtensionProvider("foo", "bar")); + int chunkSize = CHUNK_SIZE; int nChunks = 16; - int contentLength = CHUNK_SIZE * nChunks; + int contentLength = chunkSize * nChunks; TestPublisher elements = randomPublisherOfLength(contentLength); ChunkedEncodedPublisher chunkPublisher = ChunkedEncodedPublisher.builder() .publisher(elements) .contentLength(contentLength) - .chunkSize(CHUNK_SIZE) + .chunkSize(chunkSize) .addExtension(mockProvider) .build(); List chunks = getAllElements(chunkPublisher); - - ArgumentCaptor chunkCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - - Mockito.verify(mockProvider, Mockito.times(nChunks)).get(chunkCaptor.capture()); - List extensionChunks = chunkCaptor.getAllValues(); + Mockito.verify(mockProvider, Mockito.times(nChunks)).get(any(ByteBuffer.class)); for (int i = 0; i < chunks.size(); ++i) { ByteBuffer chunk = chunks.get(i); - ByteBuffer extensionChunk = extensionChunks.get(i); - assertThat(stripEncoding(chunk)).isEqualTo(extensionChunk); + ByteBuffer extensionChunk = mockProvider.recordedChunks.get(i); + + assertThat(stripEncoding(chunk.duplicate())).isEqualTo(extensionChunk); } + assertChunksHaveChecksum(chunks, elements.wrappedChecksum()); } @Test @@ -388,7 +397,9 @@ private TestPublisher randomPublisherOfLength(int bytes) { bytes -= elementSize; byte[] elementContent = new byte[elementSize]; - RNG.nextBytes(elementContent); + for (int i = 0; i < elementSize; ++i) { + elementContent[i] = (byte) ('A' + RNG.nextInt(8)); + } CRC32.update(elementContent); elements.add(ByteBuffer.wrap(elementContent)); } @@ -402,8 +413,8 @@ private List getAllElements(Publisher publisher) { return Flowable.fromPublisher(publisher).toList().blockingGet(); } - private String chunkAsString(ByteBuffer chunk) { - return StandardCharsets.UTF_8.decode(chunk).toString(); + private String bufferAsString(ByteBuffer buffer) { + return StandardCharsets.UTF_8.decode(buffer).toString(); } private String getHeaderAsString(ByteBuffer chunk) { @@ -412,22 +423,23 @@ private String getHeaderAsString(ByteBuffer chunk) { private ByteBuffer getHeader(ByteBuffer chunk) { ByteBuffer header = chunk.duplicate(); - byte a = header.get(0); - byte b = header.get(1); + header.mark(); + byte a = header.get(); + byte b = header.get(); int i = 2; for (; i < header.limit() && a != '\r' && b != '\n'; ++i) { a = b; - b = header.get(i); + b = header.get(); } header.limit(i - 2); + header.reset(); return header; } private ByteBuffer stripEncoding(ByteBuffer chunk) { ByteBuffer header = getHeader(chunk); - ByteBuffer lengthHex = header.duplicate(); boolean semiFound = false; @@ -445,7 +457,7 @@ private ByteBuffer stripEncoding(ByteBuffer chunk) { // assume the whole line is the length (no extensions) lengthHex.flip(); - int length = Integer.parseInt(StandardCharsets.UTF_8.decode(lengthHex).toString(), 16); + int length = Integer.parseInt(bufferAsString(lengthHex), 16); ByteBuffer stripped = chunk.duplicate(); @@ -456,8 +468,18 @@ private ByteBuffer stripEncoding(ByteBuffer chunk) { return stripped; } - private long totalRemaining(List buffers) { - return buffers.stream().mapToLong(ByteBuffer::remaining).sum(); + private byte[] crc32(byte[] data) { + CRC32.reset(); + CRC32.update(data); + byte[] checksum = CRC32.getChecksumBytes(); + CRC32.reset(); + return checksum; + } + + private void assertChunksHaveChecksum(List chunks, byte[] checksum) { + CRC32.reset(); + chunks.forEach(chunk -> CRC32.update(stripEncoding(chunk).duplicate())); + assertThat(CRC32.getChecksumBytes()).isEqualTo(checksum); } private static class TestPublisher implements Publisher { @@ -483,6 +505,7 @@ public byte[] wrappedChecksum() { private static class StaticExtensionProvider implements ChunkExtensionProvider { private final byte[] key; private final byte[] value; + private final List recordedChunks = new ArrayList<>(); public StaticExtensionProvider(String key, String value) { this.key = key.getBytes(StandardCharsets.UTF_8); @@ -491,6 +514,7 @@ public StaticExtensionProvider(String key, String value) { @Override public Pair get(ByteBuffer chunk) { + this.recordedChunks.add(BinaryUtils.immutableCopyOf(chunk)); return Pair.of(key, value); } } diff --git a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/io/UnbufferedChecksumSubscriberTckTest.java b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/io/UnbufferedChecksumSubscriberTckTest.java new file mode 100644 index 000000000000..72979d4c5a53 --- /dev/null +++ b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/io/UnbufferedChecksumSubscriberTckTest.java @@ -0,0 +1,44 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.auth.aws.internal.signer.io; + +import io.reactivex.subscribers.TestSubscriber; +import java.nio.ByteBuffer; +import java.util.Collections; +import org.reactivestreams.Subscriber; +import org.reactivestreams.tck.SubscriberBlackboxVerification; +import org.reactivestreams.tck.TestEnvironment; +import software.amazon.awssdk.checksums.DefaultChecksumAlgorithm; +import software.amazon.awssdk.checksums.SdkChecksum; + +public class UnbufferedChecksumSubscriberTckTest extends SubscriberBlackboxVerification { + + public UnbufferedChecksumSubscriberTckTest() { + super(new TestEnvironment()); + } + + @Override + public Subscriber createSubscriber() { + return new UnbufferedChecksumSubscriber( + Collections.singletonList(SdkChecksum.forAlgorithm(DefaultChecksumAlgorithm.CRC32)), + new TestSubscriber<>()); + } + + @Override + public ByteBuffer createElement(int element) { + return ByteBuffer.wrap(String.valueOf(element).getBytes()); + } +} diff --git a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/io/UnbufferedChecksumSubscriberTest.java b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/io/UnbufferedChecksumSubscriberTest.java new file mode 100644 index 000000000000..7de930442a80 --- /dev/null +++ b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/io/UnbufferedChecksumSubscriberTest.java @@ -0,0 +1,89 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.auth.aws.internal.signer.io; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.when; + +import io.reactivex.Flowable; +import io.reactivex.subscribers.TestSubscriber; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.checksums.SdkChecksum; + +public class UnbufferedChecksumSubscriberTest { + @Test + void subscribe_updatesEachChecksumWithIdenticalData() { + List buffers = Arrays.asList(ByteBuffer.wrap("foo".getBytes()), + ByteBuffer.wrap("bar".getBytes()), + ByteBuffer.wrap("baz".getBytes())); + + Publisher publisher = Flowable.fromIterable(buffers); + + SdkChecksum checksum1 = Mockito.mock(SdkChecksum.class); + SdkChecksum checksum2 = Mockito.mock(SdkChecksum.class); + + List checksums = Arrays.asList(checksum1, checksum2); + + UnbufferedChecksumSubscriber subscriber = new UnbufferedChecksumSubscriber(checksums, new TestSubscriber<>()); + + publisher.subscribe(subscriber); + + for (SdkChecksum checksum : checksums) { + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuffer.class); + Mockito.verify(checksum, Mockito.times(3)).update(captor.capture()); + assertThat(captor.getAllValues()).containsExactlyElementsOf(buffers); + } + } + + @Test + public void subscribe_onNextDelegatedToWrappedSubscriber() { + List buffers = Arrays.asList(ByteBuffer.wrap("foo".getBytes()), + ByteBuffer.wrap("bar".getBytes()), + ByteBuffer.wrap("baz".getBytes())); + + Publisher publisher = Flowable.fromIterable(buffers); + + SdkChecksum checksum = Mockito.mock(SdkChecksum.class); + + Subscriber wrappedSubscriber = Mockito.mock(Subscriber.class); + doAnswer(i -> { + ((Subscription) i.getArguments()[0]).request(Long.MAX_VALUE); + return null; + }).when(wrappedSubscriber).onSubscribe(any(Subscription.class)); + + UnbufferedChecksumSubscriber subscriber = new UnbufferedChecksumSubscriber(Collections.singletonList(checksum), + wrappedSubscriber); + + publisher.subscribe(subscriber); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuffer.class); + + Mockito.verify(wrappedSubscriber, Mockito.times(3)).onNext(captor.capture()); + + assertThat(captor.getAllValues()).containsExactlyElementsOf(buffers); + } +} diff --git a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/util/LengthCalculatingSubscriberTckTest.java b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/util/LengthCalculatingSubscriberTckTest.java new file mode 100644 index 000000000000..946a073a56e4 --- /dev/null +++ b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/util/LengthCalculatingSubscriberTckTest.java @@ -0,0 +1,38 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.auth.aws.internal.signer.util; + +import java.nio.ByteBuffer; +import org.reactivestreams.Subscriber; +import org.reactivestreams.tck.SubscriberBlackboxVerification; +import org.reactivestreams.tck.TestEnvironment; + +public class LengthCalculatingSubscriberTckTest extends SubscriberBlackboxVerification { + + public LengthCalculatingSubscriberTckTest() { + super(new TestEnvironment()); + } + + @Override + public Subscriber createSubscriber() { + return new LengthCalculatingSubscriber(); + } + + @Override + public ByteBuffer createElement(int element) { + return ByteBuffer.wrap(Integer.toString(element).getBytes()); + } +} diff --git a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/util/SignerUtilsTest.java b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/util/SignerUtilsTest.java index 5fb5ffd284d4..351cfbfc8be3 100644 --- a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/util/SignerUtilsTest.java +++ b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/util/SignerUtilsTest.java @@ -19,27 +19,32 @@ import static software.amazon.awssdk.http.Header.CONTENT_LENGTH; import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_DECODED_CONTENT_LENGTH; +import io.reactivex.Flowable; import java.io.ByteArrayInputStream; import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.stream.Stream; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; import software.amazon.awssdk.http.ContentStreamProvider; import software.amazon.awssdk.http.SdkHttpRequest; public class SignerUtilsTest { @Test - void moveContentLength_decodedContentLengthPresent_shouldNotInvokeNewStream() { + void computeAndMoveContentLength_decodedContentLengthPresent_shouldNotInvokeNewStream() { SdkHttpRequest.Builder request = SdkHttpRequest.builder() .appendHeader(X_AMZ_DECODED_CONTENT_LENGTH, "10") .appendHeader(CONTENT_LENGTH, "10"); ContentStreamProvider streamProvider = Mockito.mock(ContentStreamProvider.class); - long contentLength = SignerUtils.moveContentLength(request, streamProvider); + long contentLength = SignerUtils.computeAndMoveContentLength(request, streamProvider); Mockito.verify(streamProvider, Mockito.never()).newStream(); assertThat(contentLength).isEqualTo(10L); assertThat(request.firstMatchingHeader(CONTENT_LENGTH)).isEmpty(); @@ -47,36 +52,89 @@ void moveContentLength_decodedContentLengthPresent_shouldNotInvokeNewStream() { } @Test - void moveContentLength_contentLengthPresent_shouldNotInvokeNewStream() { + void computeAndMoveContentLength_contentLengthPresent_shouldNotInvokeNewStream() { SdkHttpRequest.Builder request = SdkHttpRequest.builder() .appendHeader(CONTENT_LENGTH, "10"); ContentStreamProvider streamProvider = Mockito.mock(ContentStreamProvider.class); - long contentLength = SignerUtils.moveContentLength(request, streamProvider); + long contentLength = SignerUtils.computeAndMoveContentLength(request, streamProvider); Mockito.verify(streamProvider, Mockito.never()).newStream(); assertThat(contentLength).isEqualTo(10L); assertThat(request.firstMatchingHeader(CONTENT_LENGTH)).isEmpty(); assertThat(request.firstMatchingHeader(X_AMZ_DECODED_CONTENT_LENGTH)).contains("10"); } - public static Stream streams() { - return Stream.of(Arguments.of(new ByteArrayInputStream("hello".getBytes()), 5), - Arguments.of(null, 0)); - } - - @ParameterizedTest @MethodSource("streams") - void moveContentLength_contentLengthNotPresent_shouldInvokeNewStream(InputStream inputStream, long expectedLength) { + void computeAndMoveContentLength_contentLengthNotPresent_shouldInvokeNewStream(InputStream inputStream, long expectedLength) { SdkHttpRequest.Builder request = SdkHttpRequest.builder(); ContentStreamProvider streamProvider = Mockito.mock(ContentStreamProvider.class); Mockito.when(streamProvider.newStream()).thenReturn(inputStream); - long contentLength = SignerUtils.moveContentLength(request, streamProvider); + long contentLength = SignerUtils.computeAndMoveContentLength(request, streamProvider); Mockito.verify(streamProvider, Mockito.times(1)).newStream(); assertThat(contentLength).isEqualTo(expectedLength); assertThat(request.firstMatchingHeader(CONTENT_LENGTH)).isEmpty(); assertThat(request.firstMatchingHeader(X_AMZ_DECODED_CONTENT_LENGTH)).contains(String.valueOf(expectedLength)); } + + @Test + void computeAndMoveContentLength_async_decodedContentLengthPresent_shouldNotSubscribeToPublisher() { + + SdkHttpRequest.Builder request = SdkHttpRequest.builder() + .appendHeader(X_AMZ_DECODED_CONTENT_LENGTH, "10") + .appendHeader(CONTENT_LENGTH, "10"); + + Publisher contentPublisher = Mockito.spy(Flowable.empty()); + + SignerUtils.computeAndMoveContentLength(request, contentPublisher).join(); + Mockito.verify(contentPublisher, Mockito.never()).subscribe(Mockito.any(Subscriber.class)); + + assertThat(request.firstMatchingHeader(CONTENT_LENGTH)).isEmpty(); + assertThat(request.firstMatchingHeader(X_AMZ_DECODED_CONTENT_LENGTH)).contains("10"); + } + + @Test + void computeAndMoveContentLength_async_contentLengthPresent_shouldNotSubscribeToPublisher() { + SdkHttpRequest.Builder request = SdkHttpRequest.builder() + .appendHeader(CONTENT_LENGTH, "10"); + + Publisher contentPublisher = Mockito.spy(Flowable.empty()); + + SignerUtils.computeAndMoveContentLength(request, contentPublisher).join(); + Mockito.verify(contentPublisher, Mockito.never()).subscribe(Mockito.any(Subscriber.class)); + + assertThat(request.firstMatchingHeader(CONTENT_LENGTH)).isEmpty(); + assertThat(request.firstMatchingHeader(X_AMZ_DECODED_CONTENT_LENGTH)).contains("10"); + } + + @ParameterizedTest + @MethodSource("publishers") + void computeAndMoveContentLength_contentLengthNotPresent_shouldInvokeSubscribe(Flowable publisher, long expectedLength) { + SdkHttpRequest.Builder request = SdkHttpRequest.builder(); + + if (publisher != null) { + publisher = Mockito.spy(publisher); + } + + SignerUtils.computeAndMoveContentLength(request, publisher).join(); + + if (publisher != null) { + Mockito.verify(publisher, Mockito.times(1)).subscribe(Mockito.any(Subscriber.class)); + } + + assertThat(request.firstMatchingHeader(CONTENT_LENGTH)).isEmpty(); + assertThat(request.firstMatchingHeader(X_AMZ_DECODED_CONTENT_LENGTH)).contains(String.valueOf(expectedLength)); + } + + public static Stream streams() { + return Stream.of(Arguments.of(new ByteArrayInputStream("hello".getBytes(StandardCharsets.UTF_8)), 5), + Arguments.of(null, 0)); + } + + public static Stream publishers() { + return Stream.of(Arguments.of(Flowable.just(ByteBuffer.wrap("hello".getBytes(StandardCharsets.UTF_8))), 5), + Arguments.of(null, 0)); + } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncSigningStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncSigningStage.java index ff0c641ec4de..5156395c581f 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncSigningStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncSigningStage.java @@ -43,6 +43,8 @@ import software.amazon.awssdk.http.auth.spi.signer.AsyncSignedRequest; import software.amazon.awssdk.http.auth.spi.signer.BaseSignedRequest; import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.http.auth.spi.signer.PayloadChecksumStore; +import software.amazon.awssdk.http.auth.spi.signer.SdkInternalHttpSignerProperty; import software.amazon.awssdk.http.auth.spi.signer.SignRequest; import software.amazon.awssdk.http.auth.spi.signer.SignedRequest; import software.amazon.awssdk.identity.spi.Identity; @@ -88,11 +90,15 @@ public CompletableFuture execute(SdkHttpFullRequest request, private CompletableFuture sraSignRequest(SdkHttpFullRequest request, RequestExecutionContext context, SelectedAuthScheme selectedAuthScheme) { + // Should not be null, added by HttpChecksumStage for SRA signed requests + PayloadChecksumStore payloadChecksumStore = + context.executionAttributes().getAttribute(SdkInternalExecutionAttribute.CHECKSUM_STORE); + adjustForClockSkew(context.executionAttributes()); CompletableFuture identityFuture = selectedAuthScheme.identity(); return identityFuture.thenCompose(identity -> { CompletableFuture signedRequestFuture = MetricUtils.reportDuration( - () -> doSraSign(request, context, selectedAuthScheme, identity), + () -> doSraSign(request, context, selectedAuthScheme, identity, payloadChecksumStore), context.attemptMetricCollector(), CoreMetric.SIGNING_DURATION); @@ -106,7 +112,8 @@ private CompletableFuture sraSignReques private CompletableFuture doSraSign(SdkHttpFullRequest request, RequestExecutionContext context, SelectedAuthScheme selectedAuthScheme, - T identity) { + T identity, + PayloadChecksumStore payloadChecksumStore) { AuthSchemeOption authSchemeOption = selectedAuthScheme.authSchemeOption(); HttpSigner signer = selectedAuthScheme.signer(); @@ -114,6 +121,7 @@ private CompletableFuture doSraSign(Sdk SignRequest.Builder signRequestBuilder = SignRequest .builder(identity) .putProperty(HttpSigner.SIGNING_CLOCK, signingClock()) + .putProperty(SdkInternalHttpSignerProperty.CHECKSUM_STORE, payloadChecksumStore) .request(request) .payload(request.contentStreamProvider().orElse(null)); authSchemeOption.forEachSignerProperty(signRequestBuilder::putProperty); @@ -125,8 +133,10 @@ private CompletableFuture doSraSign(Sdk AsyncSignRequest.Builder signRequestBuilder = AsyncSignRequest .builder(identity) .putProperty(HttpSigner.SIGNING_CLOCK, signingClock()) + .putProperty(SdkInternalHttpSignerProperty.CHECKSUM_STORE, payloadChecksumStore) .request(request) .payload(context.requestProvider()); + authSchemeOption.forEachSignerProperty(signRequestBuilder::putProperty); CompletableFuture signedRequestFuture = signer.signAsync(signRequestBuilder.build()); diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStage.java index ca4b4d8f7f2c..f66f6ff4566b 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStage.java @@ -28,7 +28,6 @@ import static software.amazon.awssdk.core.internal.util.ChunkContentUtils.calculateStreamContentLength; import static software.amazon.awssdk.core.internal.util.HttpChecksumResolver.getResolvedChecksumSpecs; import static software.amazon.awssdk.core.internal.util.HttpChecksumUtils.isHttpChecksumCalculationNeeded; -import static software.amazon.awssdk.core.internal.util.HttpChecksumUtils.isStreamingUnsignedPayload; import static software.amazon.awssdk.http.Header.CONTENT_LENGTH; import java.io.IOException; @@ -53,7 +52,6 @@ import software.amazon.awssdk.http.ContentStreamProvider; import software.amazon.awssdk.http.Header; import software.amazon.awssdk.http.SdkHttpFullRequest; -import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.http.auth.aws.internal.signer.util.ChecksumUtil; import software.amazon.awssdk.http.auth.spi.signer.PayloadChecksumStore; import software.amazon.awssdk.utils.BinaryUtils; @@ -124,19 +122,6 @@ private SdkHttpFullRequest.Builder sraChecksum(SdkHttpFullRequest.Builder reques } executionAttributes.putAttribute(RESOLVED_CHECKSUM_SPECS, resolvedChecksumSpecs); - SdkHttpRequest httpRequest = context.executionContext().interceptorContext().httpRequest(); - - // TODO(sra-identity-and-auth): payload checksum calculation (trailer) for sync is done in AwsChunkedV4PayloadSigner, - // but async is still in this class. We should first add chunked encoding support for async to - // AwsChunkedV4PayloadSigner - // and remove the logic here. Details in https://github.com/aws/aws-sdk-java-v2/pull/4568 - if (clientType == ClientType.ASYNC && - isStreamingUnsignedPayload(httpRequest, executionAttributes, resolvedChecksumSpecs, - resolvedChecksumSpecs.isRequestStreaming())) { - addFlexibleChecksumInTrailer(request, context, resolvedChecksumSpecs); - return request; - } - return request; } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncSigningStageTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncSigningStageTest.java index 134002694aa9..02a74672db4b 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncSigningStageTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncSigningStageTest.java @@ -24,6 +24,7 @@ import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import static software.amazon.awssdk.core.interceptor.SdkExecutionAttribute.TIME_OFFSET; +import static software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute.CHECKSUM_STORE; import static software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME; import static software.amazon.awssdk.core.metrics.CoreMetric.SIGNING_DURATION; @@ -61,6 +62,8 @@ import software.amazon.awssdk.http.auth.spi.signer.AsyncSignRequest; import software.amazon.awssdk.http.auth.spi.signer.AsyncSignedRequest; import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.http.auth.spi.signer.PayloadChecksumStore; +import software.amazon.awssdk.http.auth.spi.signer.SdkInternalHttpSignerProperty; import software.amazon.awssdk.http.auth.spi.signer.SignRequest; import software.amazon.awssdk.http.auth.spi.signer.SignedRequest; import software.amazon.awssdk.http.auth.spi.signer.SignerProperty; @@ -525,6 +528,36 @@ public void execute_selectedAuthScheme_signer_doesPreSraSign() throws Exception verifyNoInteractions(httpSigner); } + + @Test + public void execute_checksumStoreAttributePresent_propagatesChecksumStoreToSigner() throws Exception { + SelectedAuthScheme selectedAuthScheme = new SelectedAuthScheme<>( + CompletableFuture.completedFuture(identity), + httpSigner, + AuthSchemeOption.builder() + .schemeId("my.auth#myAuth") + .putSignerProperty(SIGNER_PROPERTY, "value") + .build()); + RequestExecutionContext context = createContext(selectedAuthScheme, null); + + PayloadChecksumStore cache = PayloadChecksumStore.create(); + context.executionAttributes().putAttribute(CHECKSUM_STORE, cache); + + SdkHttpRequest signedRequest = ValidSdkObjects.sdkHttpFullRequest().build(); + when(httpSigner.sign(ArgumentMatchers.>any())) + .thenReturn(SignedRequest.builder() + .request(signedRequest) + .build()); + + SdkHttpFullRequest request = ValidSdkObjects.sdkHttpFullRequest().build(); + stage.execute(request, context); + + ArgumentCaptor> signRequestCaptor = ArgumentCaptor.forClass(SignRequest.class); + verify(httpSigner).sign(signRequestCaptor.capture()); + + assertThat(signRequestCaptor.getValue().property(SdkInternalHttpSignerProperty.CHECKSUM_STORE)).isSameAs(cache); + } + private RequestExecutionContext createContext(SelectedAuthScheme selectedAuthScheme, Signer oldSigner) { return createContext(selectedAuthScheme, null, oldSigner); } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStageSraTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStageSraTest.java index ed166597a441..01b6f831fc02 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStageSraTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStageSraTest.java @@ -16,13 +16,11 @@ package software.amazon.awssdk.core.internal.http.pipeline.stages; import static org.assertj.core.api.Assertions.assertThat; -import static software.amazon.awssdk.core.HttpChecksumConstant.HEADER_FOR_TRAILER_REFERENCE; import static software.amazon.awssdk.core.HttpChecksumConstant.SIGNING_METHOD; import static software.amazon.awssdk.core.interceptor.SdkExecutionAttribute.RESOLVED_CHECKSUM_SPECS; import static software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute.AUTH_SCHEMES; import static software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute.CHECKSUM_STORE; import static software.amazon.awssdk.core.internal.signer.SigningMethod.UNSIGNED_PAYLOAD; -import static software.amazon.awssdk.http.Header.CONTENT_LENGTH; import static software.amazon.awssdk.http.Header.CONTENT_MD5; import java.util.HashMap; @@ -123,28 +121,6 @@ public void sync_flexibleChecksumInTrailer_shouldUpdateResolvedChecksumSpec() th assertThat(checksumSpecs.algorithmV2()).isEqualTo(DefaultChecksumAlgorithm.SHA1); } - @Test - public void async_flexibleChecksumInTrailer_addsFlexibleChecksumInTrailer() throws Exception { - SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder(); - boolean isStreaming = true; - RequestExecutionContext ctx = flexibleChecksumRequestContext(ClientType.ASYNC, - ChecksumSpecs.builder() - .algorithmV2(DefaultChecksumAlgorithm.SHA256) - .headerName(ChecksumUtil.checksumHeaderName(DefaultChecksumAlgorithm.SHA1)), - isStreaming); - - new HttpChecksumStage(ClientType.ASYNC).execute(requestBuilder, ctx); - - assertThat(requestBuilder.headers().get(HEADER_FOR_TRAILER_REFERENCE)).containsExactly(CHECKSUM_SPECS_HEADER); - assertThat(requestBuilder.headers().get("Content-encoding")).containsExactly("aws-chunked"); - assertThat(requestBuilder.headers().get("x-amz-content-sha256")).containsExactly("STREAMING-UNSIGNED-PAYLOAD-TRAILER"); - assertThat(requestBuilder.headers().get("x-amz-decoded-content-length")).containsExactly("8"); - assertThat(requestBuilder.headers().get(CONTENT_LENGTH)).containsExactly("86"); - - assertThat(requestBuilder.firstMatchingHeader(CONTENT_MD5)).isEmpty(); - assertThat(requestBuilder.firstMatchingHeader(CHECKSUM_SPECS_HEADER)).isEmpty(); - } - @Test public void execute_checksumStoreAttributeNotPresent_shouldCreate() throws Exception { SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder(); diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AsyncRequestBodyFlexibleChecksumInTrailerTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AsyncRequestBodyFlexibleChecksumInTrailerTest.java index 71600ce2fa94..68eeb8d48bd8 100644 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AsyncRequestBodyFlexibleChecksumInTrailerTest.java +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AsyncRequestBodyFlexibleChecksumInTrailerTest.java @@ -52,7 +52,6 @@ import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.HttpChecksumConstant; import software.amazon.awssdk.core.async.AsyncRequestBody; -import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.checksums.Algorithm; import software.amazon.awssdk.core.checksums.SdkChecksum; import software.amazon.awssdk.core.internal.async.FileAsyncRequestBody; @@ -181,12 +180,10 @@ public void asyncStreaming_FromAsyncRequestBody_VariableChunkSize_NoSigner_addsC FileAsyncRequestBody.builder().path(randomFileOfFixedLength.toPath()) .chunkSizeInBytes(16 * KB) .build()).join(); - verifyHeadersForPutRequest("37948", "37888", "x-amz-checksum-crc32"); + verifyHeadersForPutRequest("37932", "37888", "x-amz-checksum-crc32"); verify(putRequestedFor(anyUrl()).withRequestBody( containing( - "4000" + CRLF + contentString.substring(0, 16 * KB) + CRLF - + "4000" + CRLF + contentString.substring(16 * KB, 32 * KB) + CRLF - + "1400" + CRLF + contentString.substring(32 * KB) + CRLF + "9400" + CRLF + contentString + CRLF + "0" + CRLF + "x-amz-checksum-crc32:" + expectedChecksum + CRLF + CRLF))); } @@ -204,12 +201,10 @@ public void asyncStreaming_withRetry_FromAsyncRequestBody_VariableChunkSize_NoSi FileAsyncRequestBody.builder().path(randomFileOfFixedLength.toPath()) .chunkSizeInBytes(16 * KB) .build()).join(); - verifyHeadersForPutRequest("37948", "37888", "x-amz-checksum-crc32"); + verifyHeadersForPutRequest("37932", "37888", "x-amz-checksum-crc32"); verify(putRequestedFor(anyUrl()).withRequestBody( containing( - "4000" + CRLF + contentString.substring(0, 16 * KB) + CRLF - + "4000" + CRLF + contentString.substring(16 * KB, 32 * KB) + CRLF - + "1400" + CRLF + contentString.substring(32 * KB) + CRLF + "9400" + CRLF + contentString + CRLF + "0" + CRLF + "x-amz-checksum-crc32:" + expectedChecksum + CRLF + CRLF))); }