Skip to content
This repository was archived by the owner on Dec 19, 2023. It is now read-only.

Commit a306294

Browse files
authored
Merge pull request #944 from graphql-java-kickstart/943-support-csrf-on-websockets-to-secure-against-cross-site-attacks
feat: add support for csrf check on websocket upgrade
2 parents 846c66f + be7d443 commit a306294

File tree

12 files changed

+406
-16
lines changed

12 files changed

+406
-16
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package graphql.kickstart.autoconfigure.web.servlet;
2+
3+
import lombok.RequiredArgsConstructor;
4+
5+
@RequiredArgsConstructor
6+
class DefaultWsCsrfToken implements WsCsrfToken {
7+
8+
private final String token;
9+
private final String parameterName;
10+
11+
@Override
12+
public String getToken() {
13+
return token;
14+
}
15+
16+
@Override
17+
public String getParameterName() {
18+
return parameterName;
19+
}
20+
}

graphql-spring-boot-autoconfigure/src/main/java/graphql/kickstart/autoconfigure/web/servlet/GraphQLSubscriptionWebsocketProperties.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,12 @@ class GraphQLSubscriptionWebsocketProperties {
1212

1313
private String path = "/subscriptions";
1414
private List<String> allowedOrigins = emptyList();
15+
private CsrfProperties csrf = new CsrfProperties();
16+
17+
@Data
18+
static
19+
class CsrfProperties {
20+
21+
private boolean enabled = false;
22+
}
1523
}

graphql-spring-boot-autoconfigure/src/main/java/graphql/kickstart/autoconfigure/web/servlet/GraphQLWebsocketAutoConfiguration.java

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.springframework.boot.context.properties.EnableConfigurationProperties;
2727
import org.springframework.context.annotation.Bean;
2828
import org.springframework.context.annotation.Conditional;
29+
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
2930
import org.springframework.web.servlet.DispatcherServlet;
3031
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
3132
import org.springframework.web.socket.server.standard.ServerEndpointRegistration;
@@ -62,11 +63,7 @@ public GraphQLWebsocketServlet graphQLWebsocketServlet(
6263
}
6364
keepAliveListener().ifPresent(listeners::add);
6465
return new GraphQLWebsocketServlet(
65-
graphQLInvoker,
66-
invocationInputFactory,
67-
graphQLObjectMapper,
68-
listeners,
69-
websocketProperties.getAllowedOrigins());
66+
graphQLInvoker, invocationInputFactory, graphQLObjectMapper, listeners);
7067
}
7168

7269
private Optional<SubscriptionConnectionListener> keepAliveListener() {
@@ -78,10 +75,28 @@ private Optional<SubscriptionConnectionListener> keepAliveListener() {
7875
return Optional.empty();
7976
}
8077

78+
@Bean
79+
public WsCsrfFilter wsCsrfFilter(
80+
@Autowired(required = false) WsCsrfTokenRepository csrfTokenRepository) {
81+
return new WsCsrfFilter(websocketProperties.getCsrf(), csrfTokenRepository);
82+
}
83+
84+
@Bean
85+
@ConditionalOnMissingBean
86+
@ConditionalOnClass(HttpSessionCsrfTokenRepository.class)
87+
public WsCsrfTokenRepository wsCsrfTokenRepository() {
88+
return new WsSessionCsrfTokenRepository();
89+
}
90+
8191
@Bean
8292
@ConditionalOnClass(ServerContainer.class)
83-
public ServerEndpointRegistration serverEndpointRegistration(GraphQLWebsocketServlet servlet) {
84-
return new GraphQLWsServerEndpointRegistration(websocketProperties.getPath(), servlet);
93+
public ServerEndpointRegistration serverEndpointRegistration(
94+
GraphQLWebsocketServlet servlet, WsCsrfFilter csrfFilter) {
95+
return new GraphQLWsServerEndpointRegistration(
96+
websocketProperties.getPath(),
97+
servlet,
98+
csrfFilter,
99+
websocketProperties.getAllowedOrigins());
85100
}
86101

87102
@Bean

graphql-spring-boot-autoconfigure/src/main/java/graphql/kickstart/autoconfigure/web/servlet/GraphQLWsServerEndpointRegistration.java

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,63 @@
11
package graphql.kickstart.autoconfigure.web.servlet;
22

33
import graphql.kickstart.servlet.GraphQLWebsocketServlet;
4+
import java.util.ArrayList;
5+
import java.util.List;
46
import jakarta.websocket.HandshakeResponse;
57
import jakarta.websocket.server.HandshakeRequest;
68
import jakarta.websocket.server.ServerEndpointConfig;
79
import org.springframework.context.Lifecycle;
810
import org.springframework.web.socket.server.standard.ServerEndpointRegistration;
911

10-
/** @author Andrew Potter */
12+
/**
13+
* @author Andrew Potter
14+
*/
1115
public class GraphQLWsServerEndpointRegistration extends ServerEndpointRegistration
1216
implements Lifecycle {
1317

18+
private static final String ALL = "*";
1419
private final GraphQLWebsocketServlet servlet;
20+
private final WsCsrfFilter csrfFilter;
21+
private final List<String> allowedOrigins;
1522

16-
public GraphQLWsServerEndpointRegistration(String path, GraphQLWebsocketServlet servlet) {
23+
public GraphQLWsServerEndpointRegistration(
24+
String path,
25+
GraphQLWebsocketServlet servlet,
26+
WsCsrfFilter csrfFilter,
27+
List<String> allowedOrigins) {
1728
super(path, servlet);
1829
this.servlet = servlet;
30+
if (allowedOrigins == null || allowedOrigins.isEmpty()) {
31+
this.allowedOrigins = List.of(ALL);
32+
} else {
33+
this.allowedOrigins = new ArrayList<>(allowedOrigins);
34+
}
35+
this.csrfFilter = csrfFilter;
1936
}
2037

2138
@Override
2239
public boolean checkOrigin(String originHeaderValue) {
23-
return servlet.checkOrigin(originHeaderValue);
40+
if (originHeaderValue == null || originHeaderValue.isBlank()) {
41+
return allowedOrigins.contains(ALL);
42+
}
43+
if (allowedOrigins.contains(ALL)) {
44+
return true;
45+
}
46+
String originToCheck = trimTrailingSlash(originHeaderValue);
47+
return allowedOrigins.stream()
48+
.map(this::trimTrailingSlash)
49+
.anyMatch(originToCheck::equalsIgnoreCase);
50+
}
51+
52+
private String trimTrailingSlash(String origin) {
53+
return (origin.endsWith("/") ? origin.substring(0, origin.length() - 1) : origin);
2454
}
2555

2656
@Override
2757
public void modifyHandshake(
2858
ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
2959
super.modifyHandshake(sec, request, response);
60+
csrfFilter.doFilter(request);
3061
servlet.modifyHandshake(sec, request, response);
3162
}
3263

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package graphql.kickstart.autoconfigure.web.servlet;
2+
3+
import static org.springframework.util.CollectionUtils.firstElement;
4+
5+
import graphql.kickstart.autoconfigure.web.servlet.GraphQLSubscriptionWebsocketProperties.CsrfProperties;
6+
import jakarta.websocket.server.HandshakeRequest;
7+
import java.util.Objects;
8+
import lombok.RequiredArgsConstructor;
9+
10+
@RequiredArgsConstructor
11+
class WsCsrfFilter {
12+
13+
private final CsrfProperties csrfProperties;
14+
private final WsCsrfTokenRepository tokenRepository;
15+
16+
void doFilter(HandshakeRequest request) {
17+
if (csrfProperties.isEnabled() && tokenRepository != null) {
18+
WsCsrfToken csrfToken = tokenRepository.loadToken(request);
19+
boolean missingToken = csrfToken == null;
20+
if (missingToken) {
21+
csrfToken = tokenRepository.generateToken(request);
22+
tokenRepository.saveToken(csrfToken, request);
23+
}
24+
25+
String actualToken =
26+
firstElement(request.getParameterMap().get(csrfToken.getParameterName()));
27+
if (!Objects.equals(csrfToken.getToken(), actualToken)) {
28+
throw new IllegalStateException(
29+
"Invalid CSRF Token '"
30+
+ actualToken
31+
+ "' was found on the request parameter '"
32+
+ csrfToken.getParameterName()
33+
+ "'.");
34+
}
35+
}
36+
}
37+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package graphql.kickstart.autoconfigure.web.servlet;
2+
3+
import java.io.Serializable;
4+
5+
public interface WsCsrfToken extends Serializable {
6+
7+
String getToken();
8+
9+
String getParameterName();
10+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package graphql.kickstart.autoconfigure.web.servlet;
2+
3+
import jakarta.websocket.server.HandshakeRequest;
4+
5+
public interface WsCsrfTokenRepository {
6+
7+
WsCsrfToken loadToken(HandshakeRequest request);
8+
9+
WsCsrfToken generateToken(HandshakeRequest request);
10+
11+
void saveToken(WsCsrfToken csrfToken, HandshakeRequest request);
12+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package graphql.kickstart.autoconfigure.web.servlet;
2+
3+
import jakarta.servlet.http.HttpSession;
4+
import jakarta.websocket.server.HandshakeRequest;
5+
import java.util.UUID;
6+
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
7+
8+
class WsSessionCsrfTokenRepository implements WsCsrfTokenRepository {
9+
10+
private static final String DEFAULT_CSRF_PARAMETER_NAME = "_csrf";
11+
12+
private static final String DEFAULT_CSRF_TOKEN_ATTR_NAME =
13+
HttpSessionCsrfTokenRepository.class.getName().concat(".CSRF_TOKEN");
14+
15+
private String sessionAttributeName = DEFAULT_CSRF_TOKEN_ATTR_NAME;
16+
17+
@Override
18+
public void saveToken(WsCsrfToken token, HandshakeRequest request) {
19+
HttpSession session = (HttpSession) request.getHttpSession();
20+
if (session != null) {
21+
if (token == null) {
22+
session.removeAttribute(this.sessionAttributeName);
23+
} else {
24+
session.setAttribute(this.sessionAttributeName, token);
25+
}
26+
}
27+
}
28+
29+
@Override
30+
public WsCsrfToken loadToken(HandshakeRequest request) {
31+
HttpSession session = (HttpSession) request.getHttpSession();
32+
if (session == null) {
33+
return null;
34+
}
35+
return (WsCsrfToken) session.getAttribute(this.sessionAttributeName);
36+
}
37+
38+
@Override
39+
public WsCsrfToken generateToken(HandshakeRequest request) {
40+
return new DefaultWsCsrfToken(UUID.randomUUID().toString(), DEFAULT_CSRF_PARAMETER_NAME);
41+
}
42+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package graphql.kickstart.autoconfigure.web.servlet;
2+
3+
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
4+
5+
import graphql.kickstart.servlet.GraphQLWebsocketServlet;
6+
import java.util.List;
7+
import org.junit.jupiter.api.extension.ExtendWith;
8+
import org.junit.jupiter.params.ParameterizedTest;
9+
import org.junit.jupiter.params.provider.CsvSource;
10+
import org.mockito.Mock;
11+
import org.mockito.junit.jupiter.MockitoExtension;
12+
13+
@ExtendWith(MockitoExtension.class)
14+
class GraphQLWsServerEndpointRegistrationTest {
15+
16+
private static final String PATH = "/subscriptions";
17+
18+
@Mock private GraphQLWebsocketServlet servlet;
19+
@Mock private WsCsrfFilter csrfFilter;
20+
21+
@ParameterizedTest
22+
@CsvSource(
23+
value = {"https://trusted.com", "NULL", "' '"},
24+
nullValues = {"NULL"})
25+
void givenDefaultAllowedOrigins_whenCheckOrigin_thenReturnTrue(String origin) {
26+
var registration = createRegistration();
27+
var allowed = registration.checkOrigin("null".equals(origin) ? null : origin);
28+
assertThat(allowed).isTrue();
29+
}
30+
31+
private GraphQLWsServerEndpointRegistration createRegistration(String... allowedOrigins) {
32+
return new GraphQLWsServerEndpointRegistration(
33+
PATH, servlet, csrfFilter, List.of(allowedOrigins));
34+
}
35+
36+
@ParameterizedTest(name = "{index} => allowedOrigin=''{0}'', originToCheck=''{1}''")
37+
@CsvSource(
38+
delimiterString = "|",
39+
textBlock =
40+
"""
41+
* | https://trusted.com
42+
https://trusted.com | https://trusted.com
43+
https://trusted.com/ | https://trusted.com
44+
https://trusted.com/ | https://trusted.com/
45+
https://trusted.com | https://trusted.com/
46+
""")
47+
void givenAllowedOrigins_whenCheckOrigin_thenReturnTrue(
48+
String allowedOrigin, String originToCheck) {
49+
var registration = createRegistration(allowedOrigin);
50+
var allowed = registration.checkOrigin(originToCheck);
51+
assertThat(allowed).isTrue();
52+
}
53+
}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package graphql.kickstart.autoconfigure.web.servlet;
2+
3+
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
4+
import static org.junit.jupiter.api.Assertions.*;
5+
import static org.mockito.ArgumentMatchers.any;
6+
import static org.mockito.Mockito.mock;
7+
import static org.mockito.Mockito.never;
8+
import static org.mockito.Mockito.verify;
9+
import static org.mockito.Mockito.when;
10+
11+
import graphql.kickstart.autoconfigure.web.servlet.GraphQLSubscriptionWebsocketProperties.CsrfProperties;
12+
import jakarta.websocket.server.HandshakeRequest;
13+
import java.util.List;
14+
import java.util.Map;
15+
import org.junit.jupiter.api.Test;
16+
import org.junit.jupiter.api.extension.ExtendWith;
17+
import org.mockito.Mock;
18+
import org.mockito.junit.jupiter.MockitoExtension;
19+
20+
@ExtendWith(MockitoExtension.class)
21+
class WsCsrfFilterTest {
22+
23+
private CsrfProperties csrfProperties = new CsrfProperties();
24+
@Mock private WsCsrfTokenRepository tokenRepository;
25+
@Mock private HandshakeRequest handshakeRequest;
26+
27+
@Test
28+
void givenCsrfDisabled_whenDoFilter_thenDoesNotLoadToken() {
29+
csrfProperties.setEnabled(false);
30+
WsCsrfFilter filter = new WsCsrfFilter(csrfProperties, tokenRepository);
31+
filter.doFilter(handshakeRequest);
32+
33+
verify(tokenRepository, never()).loadToken(any());
34+
}
35+
36+
@Test
37+
void givenCsrfEnabledAndRepositoryNull_whenDoFilter_thenDoesNotGetTokenFromRequest() {
38+
csrfProperties.setEnabled(true);
39+
WsCsrfFilter filter = new WsCsrfFilter(csrfProperties, null);
40+
filter.doFilter(handshakeRequest);
41+
42+
verify(handshakeRequest, never()).getParameterMap();
43+
}
44+
45+
@Test
46+
void givenNoTokenInSession_whenDoFilter_thenGenerateAndSaveToken() {
47+
csrfProperties.setEnabled(true);
48+
when(tokenRepository.loadToken(handshakeRequest)).thenReturn(null);
49+
WsCsrfToken csrfToken = mock(WsCsrfToken.class);
50+
when(tokenRepository.generateToken(handshakeRequest)).thenReturn(csrfToken);
51+
52+
WsCsrfFilter filter = new WsCsrfFilter(csrfProperties, tokenRepository);
53+
filter.doFilter(handshakeRequest);
54+
55+
verify(tokenRepository).saveToken(csrfToken, handshakeRequest);
56+
}
57+
58+
@Test
59+
void givenDifferentActualToken_whenDoFilter_thenThrowsException() {
60+
csrfProperties.setEnabled(true);
61+
WsCsrfToken csrfToken = new DefaultWsCsrfToken("some-token", "_csrf");
62+
when(tokenRepository.loadToken(handshakeRequest)).thenReturn(csrfToken);
63+
when(handshakeRequest.getParameterMap())
64+
.thenReturn(Map.of("_csrf", List.of("different-token")));
65+
66+
WsCsrfFilter filter = new WsCsrfFilter(csrfProperties, tokenRepository);
67+
assertThatThrownBy(() -> filter.doFilter(handshakeRequest))
68+
.isInstanceOf(IllegalStateException.class)
69+
.hasMessage(
70+
"Invalid CSRF Token 'different-token' was found on the request parameter '_csrf'.");
71+
}
72+
73+
@Test
74+
void givenSameToken_whenDoFilter_thenDoesNotThrow() {
75+
csrfProperties.setEnabled(true);
76+
WsCsrfToken csrfToken = new DefaultWsCsrfToken("some-token", "_csrf");
77+
when(tokenRepository.loadToken(handshakeRequest)).thenReturn(csrfToken);
78+
when(handshakeRequest.getParameterMap())
79+
.thenReturn(Map.of("_csrf", List.of("some-token")));
80+
81+
WsCsrfFilter filter = new WsCsrfFilter(csrfProperties, tokenRepository);
82+
assertDoesNotThrow(() -> filter.doFilter(handshakeRequest));
83+
84+
verify(tokenRepository).loadToken(handshakeRequest);
85+
}
86+
}

0 commit comments

Comments
 (0)