11/*
2- * Copyright 2012-2017 the original author or authors.
2+ * Copyright 2012-2020 the original author or authors.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
2626import java .util .HashSet ;
2727import java .util .List ;
2828import java .util .Set ;
29+ import java .util .function .Predicate ;
2930
3031/**
3132 * <p>
6667 * Rejects URLs that contain a URL encoded percent. See
6768 * {@link #setAllowUrlEncodedPercent(boolean)}
6869 * </li>
70+ * <li>
71+ * Rejects hosts that are not allowed. See
72+ * {@link #setAllowedHostnames(Predicate)}
73+ * </li>
6974 * </ul>
7075 *
7176 * @see DefaultHttpFirewall
7277 * @author Rob Winch
78+ * @author Eddú Meléndez
7379 * @since 4.2.4
7480 */
7581public class StrictHttpFirewall implements HttpFirewall {
@@ -96,6 +102,8 @@ public class StrictHttpFirewall implements HttpFirewall {
96102
97103 private Set <String > allowedHttpMethods = createDefaultAllowedHttpMethods ();
98104
105+ private Predicate <String > allowedHostnames = hostname -> true ;
106+
99107 public StrictHttpFirewall () {
100108 urlBlacklistsAddAll (FORBIDDEN_SEMICOLON );
101109 urlBlacklistsAddAll (FORBIDDEN_FORWARDSLASH );
@@ -277,6 +285,13 @@ public void setAllowUrlEncodedPercent(boolean allowUrlEncodedPercent) {
277285 }
278286 }
279287
288+ public void setAllowedHostnames (Predicate <String > allowedHostnames ) {
289+ if (allowedHostnames == null ) {
290+ throw new IllegalArgumentException ("allowedHostnames cannot be null" );
291+ }
292+ this .allowedHostnames = allowedHostnames ;
293+ }
294+
280295 private void urlBlacklistsAddAll (Collection <String > values ) {
281296 this .encodedUrlBlacklist .addAll (values );
282297 this .decodedUrlBlacklist .addAll (values );
@@ -291,6 +306,7 @@ private void urlBlacklistsRemoveAll(Collection<String> values) {
291306 public FirewalledRequest getFirewalledRequest (HttpServletRequest request ) throws RequestRejectedException {
292307 rejectForbiddenHttpMethod (request );
293308 rejectedBlacklistedUrls (request );
309+ rejectedUntrustedHosts (request );
294310
295311 if (!isNormalized (request )) {
296312 throw new RequestRejectedException ("The request was rejected because the URL was not normalized." );
@@ -332,6 +348,13 @@ private void rejectedBlacklistedUrls(HttpServletRequest request) {
332348 }
333349 }
334350
351+ private void rejectedUntrustedHosts (HttpServletRequest request ) {
352+ String serverName = request .getServerName ();
353+ if (serverName != null && !this .allowedHostnames .test (serverName )) {
354+ throw new RequestRejectedException ("The request was rejected because the domain " + serverName + " is untrusted." );
355+ }
356+ }
357+
335358 @ Override
336359 public HttpServletResponse getFirewalledResponse (HttpServletResponse response ) {
337360 return new FirewalledResponse (response );
0 commit comments