22
22
import java .nio .charset .StandardCharsets ;
23
23
import java .util .ArrayDeque ;
24
24
import java .util .Deque ;
25
- import java .util .concurrent .TimeUnit ;
26
25
27
26
import org .junit .jupiter .api .Test ;
28
27
import org .junit .jupiter .api .extension .RegisterExtension ;
29
28
29
+ import com .linecorp .armeria .client .ClientRequestContext ;
30
+ import com .linecorp .armeria .common .HttpHeaderNames ;
31
+ import com .linecorp .armeria .common .HttpMethod ;
32
+ import com .linecorp .armeria .common .HttpRequest ;
30
33
import com .linecorp .armeria .common .HttpResponse ;
34
+ import com .linecorp .armeria .common .websocket .CloseWebSocketFrame ;
35
+ import com .linecorp .armeria .common .websocket .WebSocket ;
36
+ import com .linecorp .armeria .common .websocket .WebSocketCloseStatus ;
37
+ import com .linecorp .armeria .common .websocket .WebSocketFrame ;
38
+ import com .linecorp .armeria .common .websocket .WebSocketWriter ;
39
+ import com .linecorp .armeria .internal .common .websocket .WebSocketFrameEncoder ;
31
40
import com .linecorp .armeria .internal .testing .netty .SimpleHttp2Connection ;
32
41
import com .linecorp .armeria .internal .testing .netty .SimpleHttp2Connection .Http2Stream ;
42
+ import com .linecorp .armeria .server .websocket .WebSocketService ;
33
43
import com .linecorp .armeria .testing .junit5 .server .ServerExtension ;
34
44
45
+ import io .netty .buffer .ByteBuf ;
35
46
import io .netty .channel .ChannelHandlerContext ;
47
+ import io .netty .handler .codec .http .HttpHeaderValues ;
48
+ import io .netty .handler .codec .http2 .DefaultHttp2DataFrame ;
36
49
import io .netty .handler .codec .http2 .DefaultHttp2Headers ;
37
50
import io .netty .handler .codec .http2 .DefaultHttp2HeadersFrame ;
38
51
import io .netty .handler .codec .http2 .Http2DataFrame ;
@@ -49,6 +62,13 @@ class Http2ResetStreamTest {
49
62
@ Override
50
63
protected void configure (ServerBuilder sb ) throws Exception {
51
64
sb .service ("/" , (ctx , req ) -> HttpResponse .of ("hello" ));
65
+ sb .service ("/ws" , WebSocketService .builder ((ctx , in ) -> {
66
+ final WebSocketWriter out = WebSocket .streaming ();
67
+ in .collect ().whenComplete ((unused , err ) -> {
68
+ out .close ();
69
+ });
70
+ return out ;
71
+ }).allowedOrigin (ignored -> true ).build ());
52
72
}
53
73
};
54
74
@@ -88,8 +108,53 @@ public void logRstStream(Direction direction, ChannelHandlerContext ctx, int str
88
108
assertThat (((Http2DataFrame ) dataFrame ).isEndStream ()).isTrue ();
89
109
ReferenceCountUtil .release (dataFrame );
90
110
91
- await ().atLeast (100 , TimeUnit .MILLISECONDS )
92
- .untilAsserted (() -> assertThat (rstStreamFrames ).isEmpty ());
111
+ Thread .sleep (1000 );
112
+ assertThat (rstStreamFrames ).isEmpty ();
113
+ }
114
+ }
115
+
116
+ @ Test
117
+ void resetForWebsockets () throws Exception {
118
+ final Deque <Integer > rstStreamFrames = new ArrayDeque <>();
119
+ final Http2FrameLogger frameLogger = new Http2FrameLogger (LogLevel .DEBUG , Http2ResetStreamTest .class ) {
120
+ @ Override
121
+ public void logRstStream (Direction direction , ChannelHandlerContext ctx , int streamId ,
122
+ long errorCode ) {
123
+ rstStreamFrames .offer (streamId );
124
+ super .logRstStream (direction , ctx , streamId , errorCode );
125
+ }
126
+ };
127
+ try (SimpleHttp2Connection conn = SimpleHttp2Connection .of (server .httpUri (), frameLogger );
128
+ Http2Stream stream = conn .createStream ()) {
129
+ final DefaultHttp2Headers headers = new DefaultHttp2Headers ();
130
+ headers .method ("CONNECT" );
131
+ headers .path ("/ws" );
132
+ headers .set (HttpHeaderNames .PROTOCOL , HttpHeaderValues .WEBSOCKET .toString ());
133
+ headers .set (HttpHeaderNames .ORIGIN , "localhost" );
134
+ headers .set (HttpHeaderNames .SEC_WEBSOCKET_VERSION , "13" );
135
+ final Http2HeadersFrame headersFrame = new DefaultHttp2HeadersFrame (headers , false );
136
+ stream .sendFrame (headersFrame ).syncUninterruptibly ();
137
+
138
+ final ClientRequestContext ctx = ClientRequestContext .of (HttpRequest .of (HttpMethod .GET , "/" ));
139
+ final CloseWebSocketFrame closeFrame = WebSocketFrame .ofClose (WebSocketCloseStatus .NORMAL_CLOSURE );
140
+ final ByteBuf closeBuf = WebSocketFrameEncoder .of (true ).encode (ctx , closeFrame );
141
+ stream .sendFrame (new DefaultHttp2DataFrame (closeBuf )).syncUninterruptibly ();
142
+ stream .sendFrame (new DefaultHttp2DataFrame (true )).syncUninterruptibly ();
143
+
144
+ Http2Frame frame = stream .take ();
145
+ assertThat (frame ).isInstanceOf (Http2HeadersFrame .class );
146
+ assertThat (((Http2HeadersFrame ) frame ).headers ().status ()).asString ().isEqualTo ("200" );
147
+
148
+ frame = stream .take ();
149
+ assertThat (frame ).isInstanceOf (Http2DataFrame .class );
150
+ assertThat (((Http2DataFrame ) frame ).content ().toString (StandardCharsets .UTF_8 )).endsWith ("Bye" );
151
+
152
+ frame = stream .take ();
153
+ assertThat (frame ).isInstanceOf (Http2DataFrame .class );
154
+ assertThat (((Http2DataFrame ) frame ).isEndStream ()).isTrue ();
155
+
156
+ Thread .sleep (1000 );
157
+ assertThat (rstStreamFrames ).isEmpty ();
93
158
}
94
159
}
95
160
}
0 commit comments