Skip to content
This repository was archived by the owner on Dec 19, 2023. It is now read-only.

Commit be7d443

Browse files
committed
feat: add unit tests for websocket csrf
1 parent 0c3be87 commit be7d443

File tree

3 files changed

+159
-0
lines changed

3 files changed

+159
-0
lines changed

graphql-spring-boot-autoconfigure/src/main/java/graphql/kickstart/autoconfigure/web/servlet/GraphQLSubscriptionWebsocketProperties.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class GraphQLSubscriptionWebsocketProperties {
1515
private CsrfProperties csrf = new CsrfProperties();
1616

1717
@Data
18+
static
1819
class CsrfProperties {
1920

2021
private boolean enabled = false;
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package graphql.kickstart.autoconfigure.web.servlet;
2+
3+
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
4+
import static org.junit.jupiter.api.Assertions.*;
5+
import static org.mockito.ArgumentMatchers.any;
6+
import static org.mockito.Mockito.mock;
7+
import static org.mockito.Mockito.never;
8+
import static org.mockito.Mockito.verify;
9+
import static org.mockito.Mockito.when;
10+
11+
import graphql.kickstart.autoconfigure.web.servlet.GraphQLSubscriptionWebsocketProperties.CsrfProperties;
12+
import jakarta.websocket.server.HandshakeRequest;
13+
import java.util.List;
14+
import java.util.Map;
15+
import org.junit.jupiter.api.Test;
16+
import org.junit.jupiter.api.extension.ExtendWith;
17+
import org.mockito.Mock;
18+
import org.mockito.junit.jupiter.MockitoExtension;
19+
20+
@ExtendWith(MockitoExtension.class)
21+
class WsCsrfFilterTest {
22+
23+
private CsrfProperties csrfProperties = new CsrfProperties();
24+
@Mock private WsCsrfTokenRepository tokenRepository;
25+
@Mock private HandshakeRequest handshakeRequest;
26+
27+
@Test
28+
void givenCsrfDisabled_whenDoFilter_thenDoesNotLoadToken() {
29+
csrfProperties.setEnabled(false);
30+
WsCsrfFilter filter = new WsCsrfFilter(csrfProperties, tokenRepository);
31+
filter.doFilter(handshakeRequest);
32+
33+
verify(tokenRepository, never()).loadToken(any());
34+
}
35+
36+
@Test
37+
void givenCsrfEnabledAndRepositoryNull_whenDoFilter_thenDoesNotGetTokenFromRequest() {
38+
csrfProperties.setEnabled(true);
39+
WsCsrfFilter filter = new WsCsrfFilter(csrfProperties, null);
40+
filter.doFilter(handshakeRequest);
41+
42+
verify(handshakeRequest, never()).getParameterMap();
43+
}
44+
45+
@Test
46+
void givenNoTokenInSession_whenDoFilter_thenGenerateAndSaveToken() {
47+
csrfProperties.setEnabled(true);
48+
when(tokenRepository.loadToken(handshakeRequest)).thenReturn(null);
49+
WsCsrfToken csrfToken = mock(WsCsrfToken.class);
50+
when(tokenRepository.generateToken(handshakeRequest)).thenReturn(csrfToken);
51+
52+
WsCsrfFilter filter = new WsCsrfFilter(csrfProperties, tokenRepository);
53+
filter.doFilter(handshakeRequest);
54+
55+
verify(tokenRepository).saveToken(csrfToken, handshakeRequest);
56+
}
57+
58+
@Test
59+
void givenDifferentActualToken_whenDoFilter_thenThrowsException() {
60+
csrfProperties.setEnabled(true);
61+
WsCsrfToken csrfToken = new DefaultWsCsrfToken("some-token", "_csrf");
62+
when(tokenRepository.loadToken(handshakeRequest)).thenReturn(csrfToken);
63+
when(handshakeRequest.getParameterMap())
64+
.thenReturn(Map.of("_csrf", List.of("different-token")));
65+
66+
WsCsrfFilter filter = new WsCsrfFilter(csrfProperties, tokenRepository);
67+
assertThatThrownBy(() -> filter.doFilter(handshakeRequest))
68+
.isInstanceOf(IllegalStateException.class)
69+
.hasMessage(
70+
"Invalid CSRF Token 'different-token' was found on the request parameter '_csrf'.");
71+
}
72+
73+
@Test
74+
void givenSameToken_whenDoFilter_thenDoesNotThrow() {
75+
csrfProperties.setEnabled(true);
76+
WsCsrfToken csrfToken = new DefaultWsCsrfToken("some-token", "_csrf");
77+
when(tokenRepository.loadToken(handshakeRequest)).thenReturn(csrfToken);
78+
when(handshakeRequest.getParameterMap())
79+
.thenReturn(Map.of("_csrf", List.of("some-token")));
80+
81+
WsCsrfFilter filter = new WsCsrfFilter(csrfProperties, tokenRepository);
82+
assertDoesNotThrow(() -> filter.doFilter(handshakeRequest));
83+
84+
verify(tokenRepository).loadToken(handshakeRequest);
85+
}
86+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package graphql.kickstart.autoconfigure.web.servlet;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.junit.jupiter.api.Assertions.*;
5+
import static org.mockito.Mockito.verify;
6+
import static org.mockito.Mockito.when;
7+
8+
import jakarta.servlet.http.HttpSession;
9+
import jakarta.websocket.server.HandshakeRequest;
10+
import java.util.UUID;
11+
import org.junit.jupiter.api.Test;
12+
import org.junit.jupiter.api.extension.ExtendWith;
13+
import org.mockito.Mock;
14+
import org.mockito.junit.jupiter.MockitoExtension;
15+
16+
@ExtendWith(MockitoExtension.class)
17+
class WsSessionCsrfTokenRepositoryTest {
18+
19+
public static final String TOKEN_SESSION_ATTRIBUTE_NAME =
20+
"org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository.CSRF_TOKEN";
21+
@Mock private HandshakeRequest handshakeRequest;
22+
@Mock private HttpSession httpSession;
23+
@Mock private WsCsrfToken csrfToken;
24+
private WsSessionCsrfTokenRepository tokenRepository = new WsSessionCsrfTokenRepository();
25+
26+
@Test
27+
void givenNoSession_whenSaveToken_thenDoesNotThrow() {
28+
when(handshakeRequest.getHttpSession()).thenReturn(null);
29+
assertDoesNotThrow(() -> tokenRepository.saveToken(csrfToken, handshakeRequest));
30+
}
31+
32+
@Test
33+
void givenNoToken_whenSaveToken_thenRemovesFromSession() {
34+
when(handshakeRequest.getHttpSession()).thenReturn(httpSession);
35+
tokenRepository.saveToken(null, handshakeRequest);
36+
verify(httpSession).removeAttribute(TOKEN_SESSION_ATTRIBUTE_NAME);
37+
}
38+
39+
@Test
40+
void givenToken_whenSaveToken_thenSetsInSession() {
41+
when(handshakeRequest.getHttpSession()).thenReturn(httpSession);
42+
tokenRepository.saveToken(csrfToken, handshakeRequest);
43+
verify(httpSession).setAttribute(TOKEN_SESSION_ATTRIBUTE_NAME, csrfToken);
44+
}
45+
46+
@Test
47+
void givenNoSession_whenLoadToken_thenReturnNull() {
48+
when(handshakeRequest.getHttpSession()).thenReturn(null);
49+
WsCsrfToken csrfToken = tokenRepository.loadToken(handshakeRequest);
50+
assertThat(csrfToken).isNull();
51+
}
52+
53+
@Test
54+
void givenTokenInSession_whenLoadToken_thenReturnTokenFromSession() {
55+
when(handshakeRequest.getHttpSession()).thenReturn(httpSession);
56+
when(httpSession.getAttribute(TOKEN_SESSION_ATTRIBUTE_NAME)).thenReturn(csrfToken);
57+
WsCsrfToken loadedToken = tokenRepository.loadToken(handshakeRequest);
58+
assertThat(loadedToken).isEqualTo(csrfToken);
59+
}
60+
61+
@Test
62+
void whenGenerateToken_thenContainsUUID() {
63+
var generatedToken = tokenRepository.generateToken(handshakeRequest);
64+
assertDoesNotThrow(() -> UUID.fromString(generatedToken.getToken()));
65+
}
66+
67+
@Test
68+
void whenGenerateToken_thenContainsCorrectParameterName() {
69+
var generatedToken = tokenRepository.generateToken(handshakeRequest);
70+
assertThat(generatedToken.getParameterName()).isEqualTo("_csrf");
71+
}
72+
}

0 commit comments

Comments
 (0)