Skip to content

Commit be90817

Browse files
committed
Validate Connection: Upgrade on protocol switch responses
ProtocolSwitchStrategy now rejects 101 Switching Protocols responses that advertise Upgrade without the required Connection: Upgrade token.
1 parent e005c48 commit be90817

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

httpclient5/src/main/java/org/apache/hc/client5/http/impl/ProtocolSwitchStrategy.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
*/
2727
package org.apache.hc.client5.http.impl;
2828

29+
import java.util.concurrent.atomic.AtomicBoolean;
2930
import java.util.concurrent.atomic.AtomicReference;
3031

3132
import org.apache.hc.core5.annotation.Internal;
@@ -70,6 +71,10 @@ public InternalProtocolException(final ProtocolException cause) {
7071
}
7172

7273
public ProtocolVersion switchProtocol(final HttpMessage response) throws ProtocolException {
74+
if (!containsConnectionUpgrade(response)) {
75+
throw new ProtocolException("Invalid protocol switch response: missing Connection: Upgrade");
76+
}
77+
7378
final AtomicReference<ProtocolVersion> tlsUpgrade = new AtomicReference<>();
7479

7580
parseHeaders(response, HttpHeaders.UPGRADE, (buffer, cursor) -> {
@@ -91,6 +96,16 @@ public ProtocolVersion switchProtocol(final HttpMessage response) throws Protoco
9196
}
9297
}
9398

99+
private boolean containsConnectionUpgrade(final HttpMessage message) {
100+
final AtomicBoolean found = new AtomicBoolean(false);
101+
MessageSupport.parseTokens(message, HttpHeaders.CONNECTION, token -> {
102+
if ("upgrade".equalsIgnoreCase(token)) {
103+
found.set(true);
104+
}
105+
});
106+
return found.get();
107+
}
108+
94109
private ProtocolVersion parseProtocolVersion(final CharSequence buffer, final ParserCursor cursor) throws ProtocolException {
95110
TOKENIZER.skipWhiteSpace(buffer, cursor);
96111
final String proto = TOKENIZER.parseToken(buffer, cursor, LAX_PROTO_DELIMITER);

httpclient5/src/test/java/org/apache/hc/client5/http/impl/TestProtocolSwitchStrategy.java

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,30 +52,36 @@ void setUp() {
5252
@Test
5353
void testSwitchToTLS() throws Exception {
5454
final HttpResponse response1 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
55+
response1.addHeader(HttpHeaders.CONNECTION, "Upgrade");
5556
response1.addHeader(HttpHeaders.UPGRADE, "TLS");
5657
Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response1));
5758

5859
final HttpResponse response2 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
60+
response2.addHeader(HttpHeaders.CONNECTION, "Upgrade");
5961
response2.addHeader(HttpHeaders.UPGRADE, "TLS/1.3");
6062
Assertions.assertEquals(TLS.V_1_3.getVersion(), switchStrategy.switchProtocol(response2));
6163
}
6264

6365
@Test
6466
void testSwitchToHTTP11AndTLS() throws Exception {
6567
final HttpResponse response1 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
68+
response1.addHeader(HttpHeaders.CONNECTION, "Upgrade");
6669
response1.addHeader(HttpHeaders.UPGRADE, "TLS, HTTP/1.1");
6770
Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response1));
6871

6972
final HttpResponse response2 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
73+
response2.addHeader(HttpHeaders.CONNECTION, "Upgrade");
7074
response2.addHeader(HttpHeaders.UPGRADE, ",, HTTP/1.1, TLS, ");
7175
Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response2));
7276

7377
final HttpResponse response3 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
78+
response3.addHeader(HttpHeaders.CONNECTION, "Upgrade");
7479
response3.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1");
7580
response3.addHeader(HttpHeaders.UPGRADE, "TLS/1.2");
7681
Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response3));
7782

7883
final HttpResponse response4 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
84+
response4.addHeader(HttpHeaders.CONNECTION, "Upgrade");
7985
response4.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1");
8086
response4.addHeader(HttpHeaders.UPGRADE, "TLS/1.2, TLS/1.3");
8187
Assertions.assertEquals(TLS.V_1_3.getVersion(), switchStrategy.switchProtocol(response4));
@@ -84,42 +90,49 @@ void testSwitchToHTTP11AndTLS() throws Exception {
8490
@Test
8591
void testSwitchInvalid() {
8692
final HttpResponse response1 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
93+
response1.addHeader(HttpHeaders.CONNECTION, "Upgrade");
8794
response1.addHeader(HttpHeaders.UPGRADE, "Crap");
8895
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response1));
8996

9097
final HttpResponse response2 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
98+
response2.addHeader(HttpHeaders.CONNECTION, "Upgrade");
9199
response2.addHeader(HttpHeaders.UPGRADE, "TLS, huh?");
92100
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response2));
93101

94102
final HttpResponse response3 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
103+
response3.addHeader(HttpHeaders.CONNECTION, "Upgrade");
95104
response3.addHeader(HttpHeaders.UPGRADE, ",,,");
96105
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response3));
97106
}
98107

99108
@Test
100109
void testWhitespaceOnlyToken() throws ProtocolException {
101110
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
111+
response.addHeader(HttpHeaders.CONNECTION, "Upgrade");
102112
response.addHeader(HttpHeaders.UPGRADE, " , TLS");
103113
Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response));
104114
}
105115

106116
@Test
107117
void testUnsupportedTlsVersion() throws Exception {
108118
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
119+
response.addHeader(HttpHeaders.CONNECTION, "Upgrade");
109120
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.4");
110121
Assertions.assertEquals(new ProtocolVersion("TLS", 1, 4), switchStrategy.switchProtocol(response));
111122
}
112123

113124
@Test
114125
void testUnsupportedTlsMajorVersion() throws Exception {
115126
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
127+
response.addHeader(HttpHeaders.CONNECTION, "Upgrade");
116128
response.addHeader(HttpHeaders.UPGRADE, "TLS/2.0");
117129
Assertions.assertEquals(new ProtocolVersion("TLS", 2, 0), switchStrategy.switchProtocol(response));
118130
}
119131

120132
@Test
121133
void testUnsupportedHttpVersion() {
122134
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
135+
response.addHeader(HttpHeaders.CONNECTION, "Upgrade");
123136
response.addHeader(HttpHeaders.UPGRADE, "HTTP/2.0");
124137
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
125138
switchStrategy.switchProtocol(response));
@@ -129,6 +142,7 @@ void testUnsupportedHttpVersion() {
129142
@Test
130143
void testInvalidTlsFormat() {
131144
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
145+
response.addHeader(HttpHeaders.CONNECTION, "Upgrade");
132146
response.addHeader(HttpHeaders.UPGRADE, "TLS/abc");
133147
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
134148
switchStrategy.switchProtocol(response));
@@ -138,6 +152,7 @@ void testInvalidTlsFormat() {
138152
@Test
139153
void testHttp11Only() {
140154
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
155+
response.addHeader(HttpHeaders.CONNECTION, "Upgrade");
141156
response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1");
142157
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
143158
switchStrategy.switchProtocol(response));
@@ -147,6 +162,7 @@ void testHttp11Only() {
147162
@Test
148163
void testSwitchToTlsValid_TLS_1_2() throws Exception {
149164
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
165+
response.addHeader(HttpHeaders.CONNECTION, "Upgrade");
150166
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.2");
151167
final ProtocolVersion result = switchStrategy.switchProtocol(response);
152168
Assertions.assertEquals(TLS.V_1_2.getVersion(), result);
@@ -155,6 +171,7 @@ void testSwitchToTlsValid_TLS_1_2() throws Exception {
155171
@Test
156172
void testSwitchToTlsValid_TLS_1_0() throws Exception {
157173
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
174+
response.addHeader(HttpHeaders.CONNECTION, "Upgrade");
158175
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.0");
159176
final ProtocolVersion result = switchStrategy.switchProtocol(response);
160177
Assertions.assertEquals(TLS.V_1_0.getVersion(), result);
@@ -163,6 +180,7 @@ void testSwitchToTlsValid_TLS_1_0() throws Exception {
163180
@Test
164181
void testSwitchToTlsValid_TLS_1_1() throws Exception {
165182
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
183+
response.addHeader(HttpHeaders.CONNECTION, "Upgrade");
166184
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.1");
167185
final ProtocolVersion result = switchStrategy.switchProtocol(response);
168186
Assertions.assertEquals(TLS.V_1_1.getVersion(), result);
@@ -171,6 +189,7 @@ void testSwitchToTlsValid_TLS_1_1() throws Exception {
171189
@Test
172190
void testInvalidTlsFormat_NoSlash() {
173191
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
192+
response.addHeader(HttpHeaders.CONNECTION, "Upgrade");
174193
response.addHeader(HttpHeaders.UPGRADE, "TLSv1");
175194
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
176195
switchStrategy.switchProtocol(response));
@@ -180,6 +199,7 @@ void testInvalidTlsFormat_NoSlash() {
180199
@Test
181200
void testSwitchToTlsValid_TLS_1() throws Exception {
182201
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
202+
response.addHeader(HttpHeaders.CONNECTION, "Upgrade");
183203
response.addHeader(HttpHeaders.UPGRADE, "TLS/1");
184204
final ProtocolVersion result = switchStrategy.switchProtocol(response);
185205
Assertions.assertEquals(TLS.V_1_0.getVersion(), result);
@@ -188,6 +208,7 @@ void testSwitchToTlsValid_TLS_1() throws Exception {
188208
@Test
189209
void testInvalidTlsFormat_MissingMajor() {
190210
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
211+
response.addHeader(HttpHeaders.CONNECTION, "Upgrade");
191212
response.addHeader(HttpHeaders.UPGRADE, "TLS/.1");
192213
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
193214
switchStrategy.switchProtocol(response));
@@ -197,6 +218,7 @@ void testInvalidTlsFormat_MissingMajor() {
197218
@Test
198219
void testMultipleHttp11Tokens() {
199220
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
221+
response.addHeader(HttpHeaders.CONNECTION, "Upgrade");
200222
response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1, HTTP/1.1");
201223
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
202224
switchStrategy.switchProtocol(response));
@@ -206,6 +228,7 @@ void testMultipleHttp11Tokens() {
206228
@Test
207229
void testMixedInvalidAndValidTokens() {
208230
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
231+
response.addHeader(HttpHeaders.CONNECTION, "Upgrade");
209232
response.addHeader(HttpHeaders.UPGRADE, "Crap, TLS/1.2, Invalid");
210233
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
211234
switchStrategy.switchProtocol(response));
@@ -215,10 +238,41 @@ void testMixedInvalidAndValidTokens() {
215238
@Test
216239
void testInvalidTlsFormat_NoProtocolName() {
217240
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
241+
response.addHeader(HttpHeaders.CONNECTION, "Upgrade");
218242
response.addHeader(HttpHeaders.UPGRADE, ",,/1.1");
219243
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
220244
switchStrategy.switchProtocol(response));
221245
Assertions.assertEquals("Invalid protocol; error at offset 2: <,,/1.1>", ex.getMessage());
222246
}
223247

248+
@Test
249+
void testMissingConnectionUpgradeRejected() {
250+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
251+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.2");
252+
253+
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
254+
switchStrategy.switchProtocol(response));
255+
Assertions.assertEquals("Invalid protocol switch response: missing Connection: Upgrade", ex.getMessage());
256+
}
257+
258+
@Test
259+
void testConnectionWithoutUpgradeTokenRejected() {
260+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
261+
response.addHeader(HttpHeaders.CONNECTION, "keep-alive");
262+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.2");
263+
264+
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
265+
switchStrategy.switchProtocol(response));
266+
Assertions.assertEquals("Invalid protocol switch response: missing Connection: Upgrade", ex.getMessage());
267+
}
268+
269+
@Test
270+
void testConnectionWithUpgradeTokenInListAccepted() throws Exception {
271+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
272+
response.addHeader(HttpHeaders.CONNECTION, "keep-alive, Upgrade");
273+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.2");
274+
275+
Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response));
276+
}
277+
224278
}

0 commit comments

Comments
 (0)