11/*
2- * Copyright 2012-2017 the original author or authors.
2+ * Copyright 2012-2020 the original author or authors.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
2424import java .util .HashSet ;
2525import java .util .List ;
2626import java .util .Set ;
27+ import java .util .function .Predicate ;
2728
2829/**
2930 * <p>
5960 * Rejects URLs that contain a URL encoded percent. See
6061 * {@link #setAllowUrlEncodedPercent(boolean)}
6162 * </li>
63+ * <li>
64+ * Rejects hosts that are not allowed. See
65+ * {@link #setAllowedHostnames(Predicate)}
66+ * </li>
6267 * </ul>
6368 *
6469 * @see DefaultHttpFirewall
6570 * @author Rob Winch
71+ * @author Eddú Meléndez
6672 * @since 5.0.1
6773 */
6874public class StrictHttpFirewall implements HttpFirewall {
@@ -82,6 +88,8 @@ public class StrictHttpFirewall implements HttpFirewall {
8288
8389 private Set <String > decodedUrlBlacklist = new HashSet <String >();
8490
91+ private Predicate <String > allowedHostnames = hostname -> true ;
92+
8593 public StrictHttpFirewall () {
8694 urlBlacklistsAddAll (FORBIDDEN_SEMICOLON );
8795 urlBlacklistsAddAll (FORBIDDEN_FORWARDSLASH );
@@ -230,6 +238,13 @@ public void setAllowUrlEncodedPercent(boolean allowUrlEncodedPercent) {
230238 }
231239 }
232240
241+ public void setAllowedHostnames (Predicate <String > allowedHostnames ) {
242+ if (allowedHostnames == null ) {
243+ throw new IllegalArgumentException ("allowedHostnames cannot be null" );
244+ }
245+ this .allowedHostnames = allowedHostnames ;
246+ }
247+
233248 private void urlBlacklistsAddAll (Collection <String > values ) {
234249 this .encodedUrlBlacklist .addAll (values );
235250 this .decodedUrlBlacklist .addAll (values );
@@ -243,6 +258,7 @@ private void urlBlacklistsRemoveAll(Collection<String> values) {
243258 @ Override
244259 public FirewalledRequest getFirewalledRequest (HttpServletRequest request ) throws RequestRejectedException {
245260 rejectedBlacklistedUrls (request );
261+ rejectedUntrustedHosts (request );
246262
247263 if (!isNormalized (request )) {
248264 throw new RequestRejectedException ("The request was rejected because the URL was not normalized." );
@@ -272,6 +288,13 @@ private void rejectedBlacklistedUrls(HttpServletRequest request) {
272288 }
273289 }
274290
291+ private void rejectedUntrustedHosts (HttpServletRequest request ) {
292+ String serverName = request .getServerName ();
293+ if (serverName != null && !this .allowedHostnames .test (serverName )) {
294+ throw new RequestRejectedException ("The request was rejected because the domain " + serverName + " is untrusted." );
295+ }
296+ }
297+
275298 @ Override
276299 public HttpServletResponse getFirewalledResponse (HttpServletResponse response ) {
277300 return new FirewalledResponse (response );
0 commit comments