Skip to content
This repository was archived by the owner on Dec 19, 2023. It is now read-only.

Commit e054af0

Browse files
committed
feat: add support for csrf check on websocket upgrade
fixes #943
1 parent 846c66f commit e054af0

File tree

7 files changed

+145
-4
lines changed

7 files changed

+145
-4
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package graphql.kickstart.autoconfigure.web.servlet;
2+
3+
import lombok.RequiredArgsConstructor;
4+
5+
@RequiredArgsConstructor
6+
class DefaultWsCsrfToken implements WsCsrfToken {
7+
8+
private final String token;
9+
private final String parameterName;
10+
11+
@Override
12+
public String getToken() {
13+
return token;
14+
}
15+
16+
@Override
17+
public String getParameterName() {
18+
return parameterName;
19+
}
20+
}

graphql-spring-boot-autoconfigure/src/main/java/graphql/kickstart/autoconfigure/web/servlet/GraphQLWebsocketAutoConfiguration.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.springframework.boot.context.properties.EnableConfigurationProperties;
2727
import org.springframework.context.annotation.Bean;
2828
import org.springframework.context.annotation.Conditional;
29+
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
2930
import org.springframework.web.servlet.DispatcherServlet;
3031
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
3132
import org.springframework.web.socket.server.standard.ServerEndpointRegistration;
@@ -78,10 +79,25 @@ private Optional<SubscriptionConnectionListener> keepAliveListener() {
7879
return Optional.empty();
7980
}
8081

82+
@Bean
83+
public WsCsrfFilter wsCsrfFilter(
84+
@Autowired(required = false) WsCsrfTokenRepository csrfTokenRepository) {
85+
return new WsCsrfFilter(csrfTokenRepository);
86+
}
87+
88+
@Bean
89+
@ConditionalOnMissingBean
90+
@ConditionalOnClass(HttpSessionCsrfTokenRepository.class)
91+
public WsCsrfTokenRepository wsCsrfTokenRepository() {
92+
return new WsSessionCsrfTokenRepository();
93+
}
94+
8195
@Bean
8296
@ConditionalOnClass(ServerContainer.class)
83-
public ServerEndpointRegistration serverEndpointRegistration(GraphQLWebsocketServlet servlet) {
84-
return new GraphQLWsServerEndpointRegistration(websocketProperties.getPath(), servlet);
97+
public ServerEndpointRegistration serverEndpointRegistration(
98+
GraphQLWebsocketServlet servlet, WsCsrfFilter csrfFilter) {
99+
return new GraphQLWsServerEndpointRegistration(
100+
websocketProperties.getPath(), servlet, csrfFilter);
85101
}
86102

87103
@Bean

graphql-spring-boot-autoconfigure/src/main/java/graphql/kickstart/autoconfigure/web/servlet/GraphQLWsServerEndpointRegistration.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,20 @@
77
import org.springframework.context.Lifecycle;
88
import org.springframework.web.socket.server.standard.ServerEndpointRegistration;
99

10-
/** @author Andrew Potter */
10+
/**
11+
* @author Andrew Potter
12+
*/
1113
public class GraphQLWsServerEndpointRegistration extends ServerEndpointRegistration
1214
implements Lifecycle {
1315

1416
private final GraphQLWebsocketServlet servlet;
17+
private final WsCsrfFilter csrfFilter;
1518

16-
public GraphQLWsServerEndpointRegistration(String path, GraphQLWebsocketServlet servlet) {
19+
public GraphQLWsServerEndpointRegistration(
20+
String path, GraphQLWebsocketServlet servlet, WsCsrfFilter csrfFilter) {
1721
super(path, servlet);
1822
this.servlet = servlet;
23+
this.csrfFilter = csrfFilter;
1924
}
2025

2126
@Override
@@ -27,6 +32,7 @@ public boolean checkOrigin(String originHeaderValue) {
2732
public void modifyHandshake(
2833
ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
2934
super.modifyHandshake(sec, request, response);
35+
csrfFilter.doFilter(request);
3036
servlet.modifyHandshake(sec, request, response);
3137
}
3238

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package graphql.kickstart.autoconfigure.web.servlet;
2+
3+
import static org.springframework.util.CollectionUtils.firstElement;
4+
5+
import jakarta.websocket.server.HandshakeRequest;
6+
import java.util.Objects;
7+
import lombok.RequiredArgsConstructor;
8+
9+
@RequiredArgsConstructor
10+
class WsCsrfFilter {
11+
12+
private final WsCsrfTokenRepository tokenRepository;
13+
14+
void doFilter(HandshakeRequest request) {
15+
if (tokenRepository != null) {
16+
WsCsrfToken csrfToken = tokenRepository.loadToken(request);
17+
boolean missingToken = csrfToken == null;
18+
if (missingToken) {
19+
csrfToken = tokenRepository.generateToken(request);
20+
tokenRepository.saveToken(csrfToken, request);
21+
}
22+
23+
String actualToken =
24+
firstElement(request.getParameterMap().get(csrfToken.getParameterName()));
25+
if (!Objects.equals(csrfToken.getToken(), actualToken)) {
26+
throw new IllegalStateException(
27+
"Invalid CSRF Token '"
28+
+ actualToken
29+
+ "' was found on the request parameter '"
30+
+ csrfToken.getParameterName()
31+
+ "'.");
32+
}
33+
}
34+
}
35+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package graphql.kickstart.autoconfigure.web.servlet;
2+
3+
import java.io.Serializable;
4+
5+
public interface WsCsrfToken extends Serializable {
6+
7+
String getToken();
8+
9+
String getParameterName();
10+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package graphql.kickstart.autoconfigure.web.servlet;
2+
3+
import jakarta.websocket.server.HandshakeRequest;
4+
5+
public interface WsCsrfTokenRepository {
6+
7+
WsCsrfToken loadToken(HandshakeRequest request);
8+
9+
WsCsrfToken generateToken(HandshakeRequest request);
10+
11+
void saveToken(WsCsrfToken csrfToken, HandshakeRequest request);
12+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package graphql.kickstart.autoconfigure.web.servlet;
2+
3+
import jakarta.servlet.http.HttpSession;
4+
import jakarta.websocket.server.HandshakeRequest;
5+
import java.util.UUID;
6+
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
7+
8+
class WsSessionCsrfTokenRepository implements WsCsrfTokenRepository {
9+
10+
private static final String DEFAULT_CSRF_PARAMETER_NAME = "_csrf";
11+
12+
private static final String DEFAULT_CSRF_TOKEN_ATTR_NAME =
13+
HttpSessionCsrfTokenRepository.class.getName().concat(".CSRF_TOKEN");
14+
15+
private String sessionAttributeName = DEFAULT_CSRF_TOKEN_ATTR_NAME;
16+
17+
@Override
18+
public void saveToken(WsCsrfToken token, HandshakeRequest request) {
19+
HttpSession session = (HttpSession) request.getHttpSession();
20+
if (session != null) {
21+
if (token == null) {
22+
session.removeAttribute(this.sessionAttributeName);
23+
} else {
24+
session.setAttribute(this.sessionAttributeName, token);
25+
}
26+
}
27+
}
28+
29+
@Override
30+
public WsCsrfToken loadToken(HandshakeRequest request) {
31+
HttpSession session = (HttpSession) request.getHttpSession();
32+
if (session == null) {
33+
return null;
34+
}
35+
return (WsCsrfToken) session.getAttribute(this.sessionAttributeName);
36+
}
37+
38+
@Override
39+
public WsCsrfToken generateToken(HandshakeRequest request) {
40+
return new DefaultWsCsrfToken(UUID.randomUUID().toString(), DEFAULT_CSRF_PARAMETER_NAME);
41+
}
42+
}

0 commit comments

Comments
 (0)