diff --git a/core/src/main/java/io/grpc/internal/ServerCallImpl.java b/core/src/main/java/io/grpc/internal/ServerCallImpl.java index e224384ce8f..22d5912050a 100644 --- a/core/src/main/java/io/grpc/internal/ServerCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerCallImpl.java @@ -327,16 +327,20 @@ private void messagesAvailableInternal(final MessageProducer producer) { return; } - InputStream message; try { + InputStream message; while ((message = producer.next()) != null) { - try { - listener.onMessage(call.method.parseRequest(message)); - } catch (Throwable t) { + ReqT parsedMessage; + try (InputStream ignored = message) { + parsedMessage = call.method.parseRequest(message); + } catch (StatusRuntimeException e) { GrpcUtil.closeQuietly(message); - throw t; + GrpcUtil.closeQuietly(producer); + call.cancelled = true; + call.close(e.getStatus(), new Metadata()); + return; } - message.close(); + listener.onMessage(parsedMessage); } } catch (Throwable t) { GrpcUtil.closeQuietly(producer); diff --git a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java index 7394c83eab2..abe8fb0ee56 100644 --- a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java @@ -48,9 +48,11 @@ import io.grpc.SecurityLevel; import io.grpc.ServerCall; import io.grpc.Status; +import io.grpc.StatusRuntimeException; import io.grpc.internal.ServerCallImpl.ServerStreamListenerImpl; import io.perfmark.PerfMark; import java.io.ByteArrayInputStream; +import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import org.junit.Before; @@ -69,6 +71,8 @@ public class ServerCallImplTest { @Mock private ServerStream stream; @Mock private ServerCall.Listener callListener; + @Mock private StreamListener.MessageProducer messageProducer; + @Mock private InputStream message; private final CallTracer serverCallTracer = CallTracer.getDefaultFactory().create(); private ServerCallImpl call; @@ -493,6 +497,43 @@ public void streamListener_unexpectedRuntimeException() { assertThat(e).hasMessageThat().isEqualTo("unexpected exception"); } + @Test + public void streamListener_statusRuntimeException() throws IOException { + MethodDescriptor failingParseMethod = MethodDescriptor.newBuilder() + .setType(MethodType.UNARY) + .setFullMethodName("service/method") + .setRequestMarshaller(new LongMarshaller() { + @Override + public Long parse(InputStream stream) { + throw new StatusRuntimeException(Status.RESOURCE_EXHAUSTED + .withDescription("Decompressed gRPC message exceeds maximum size")); + } + }) + .setResponseMarshaller(new LongMarshaller()) + .build(); + + call = new ServerCallImpl<>(stream, failingParseMethod, requestHeaders, context, + DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance(), + serverCallTracer, PerfMark.createTag()); + + ServerStreamListenerImpl streamListener = + new ServerCallImpl.ServerStreamListenerImpl<>(call, callListener, context); + + when(messageProducer.next()).thenReturn(message, (InputStream) null); + streamListener.messagesAvailable(messageProducer); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); + + verify(stream).close(statusCaptor.capture(), metadataCaptor.capture()); + Status status = statusCaptor.getValue(); + assertEquals(Status.RESOURCE_EXHAUSTED.getCode(), status.getCode()); + assertEquals("Decompressed gRPC message exceeds maximum size", status.getDescription()); + + streamListener.halfClosed(); + verify(callListener, never()).onHalfClose(); + verify(callListener, never()).onMessage(any()); + } + private static class LongMarshaller implements Marshaller { @Override public InputStream stream(Long value) { diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java index 11455790497..843019433aa 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java @@ -2030,7 +2030,7 @@ private void assertPayload(Payload expected, Payload actual) { } } - private static void assertCodeEquals(Status.Code expected, Status actual) { + protected static void assertCodeEquals(Status.Code expected, Status actual) { assertWithMessage("Unexpected status: %s", actual).that(actual.getCode()).isEqualTo(expected); } diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java index b9692383254..33cd624aebb 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java @@ -17,6 +17,7 @@ package io.grpc.testing.integration; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import com.google.protobuf.ByteString; @@ -37,6 +38,8 @@ import io.grpc.ServerCall.Listener; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; +import io.grpc.Status.Code; +import io.grpc.StatusRuntimeException; import io.grpc.internal.GrpcUtil; import io.grpc.netty.InternalNettyChannelBuilder; import io.grpc.netty.InternalNettyServerBuilder; @@ -53,7 +56,9 @@ import java.io.OutputStream; import org.junit.Before; import org.junit.BeforeClass; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TestName; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -84,10 +89,16 @@ public static void registerCompressors() { compressors.register(Codec.Identity.NONE); } + @Rule + public final TestName currentTest = new TestName(); + @Override protected ServerBuilder getServerBuilder() { NettyServerBuilder builder = NettyServerBuilder.forPort(0, InsecureServerCredentials.create()) - .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) + .maxInboundMessageSize( + DECOMPRESSED_MESSAGE_TOO_LONG_METHOD_NAME.equals(currentTest.getMethodName()) + ? 1000 + : AbstractInteropTest.MAX_MESSAGE_SIZE) .compressorRegistry(compressors) .decompressorRegistry(decompressors) .intercept(new ServerInterceptor() { @@ -126,6 +137,22 @@ public void compresses() { assertTrue(FZIPPER.anyWritten); } + private static final String DECOMPRESSED_MESSAGE_TOO_LONG_METHOD_NAME = + "decompressedMessageTooLong"; + + @Test + public void decompressedMessageTooLong() { + assertEquals(DECOMPRESSED_MESSAGE_TOO_LONG_METHOD_NAME, currentTest.getMethodName()); + final SimpleRequest bigRequest = SimpleRequest.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[10_000]))) + .build(); + StatusRuntimeException e = assertThrows(StatusRuntimeException.class, + () -> blockingStub.withCompression("gzip").unaryCall(bigRequest)); + assertCodeEquals(Code.RESOURCE_EXHAUSTED, e.getStatus()); + assertEquals("Decompressed gRPC message exceeds maximum size 1000", + e.getStatus().getDescription()); + } + @Override protected NettyChannelBuilder createChannelBuilder() { NettyChannelBuilder builder = NettyChannelBuilder.forAddress(getListenAddress())