11/*
2- * Copyright 2012-2017 the original author or authors.
2+ * Copyright 2012-2020 the original author or authors.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
1616
1717package org .springframework .security .web .firewall ;
1818
19- import javax .servlet .http .HttpServletRequest ;
20- import javax .servlet .http .HttpServletResponse ;
2119import java .util .Arrays ;
2220import java .util .Collection ;
2321import java .util .Collections ;
2422import java .util .HashSet ;
2523import java .util .List ;
2624import java .util .Set ;
25+ import javax .servlet .http .HttpServletRequest ;
26+ import javax .servlet .http .HttpServletResponse ;
2727
2828/**
2929 * <p>
5959 * Rejects URLs that contain a URL encoded percent. See
6060 * {@link #setAllowUrlEncodedPercent(boolean)}
6161 * </li>
62+ * <li>
63+ * Rejects hosts that are not allowed. See
64+ * {@link #setAllowedHostnames(Collection)}
65+ * </li>
6266 * </ul>
6367 *
6468 * @see DefaultHttpFirewall
6569 * @author Rob Winch
70+ * @author Eddú Meléndez
6671 * @since 4.2.4
6772 */
6873public class StrictHttpFirewall implements HttpFirewall {
@@ -82,6 +87,8 @@ public class StrictHttpFirewall implements HttpFirewall {
8287
8388 private Set <String > decodedUrlBlacklist = new HashSet <String >();
8489
90+ private Collection <String > allowedHostnames ;
91+
8592 public StrictHttpFirewall () {
8693 urlBlacklistsAddAll (FORBIDDEN_SEMICOLON );
8794 urlBlacklistsAddAll (FORBIDDEN_FORWARDSLASH );
@@ -230,6 +237,13 @@ public void setAllowUrlEncodedPercent(boolean allowUrlEncodedPercent) {
230237 }
231238 }
232239
240+ public void setAllowedHostnames (Collection <String > allowedHostnames ) {
241+ if (allowedHostnames == null ) {
242+ throw new IllegalArgumentException ("allowedHostnames cannot be null" );
243+ }
244+ this .allowedHostnames = allowedHostnames ;
245+ }
246+
233247 private void urlBlacklistsAddAll (Collection <String > values ) {
234248 this .encodedUrlBlacklist .addAll (values );
235249 this .decodedUrlBlacklist .addAll (values );
@@ -243,6 +257,7 @@ private void urlBlacklistsRemoveAll(Collection<String> values) {
243257 @ Override
244258 public FirewalledRequest getFirewalledRequest (HttpServletRequest request ) throws RequestRejectedException {
245259 rejectedBlacklistedUrls (request );
260+ rejectedUntrustedHosts (request );
246261
247262 if (!isNormalized (request )) {
248263 throw new RequestRejectedException ("The request was rejected because the URL was not normalized." );
@@ -272,6 +287,19 @@ private void rejectedBlacklistedUrls(HttpServletRequest request) {
272287 }
273288 }
274289
290+ private void rejectedUntrustedHosts (HttpServletRequest request ) {
291+ String serverName = request .getServerName ();
292+ if (serverName == null ) {
293+ return ;
294+ }
295+ if (this .allowedHostnames == null ) {
296+ return ;
297+ }
298+ if (!this .allowedHostnames .contains (serverName )) {
299+ throw new RequestRejectedException ("The request was rejected because the domain " + serverName + " is untrusted." );
300+ }
301+ }
302+
275303 @ Override
276304 public HttpServletResponse getFirewalledResponse (HttpServletResponse response ) {
277305 return new FirewalledResponse (response );
0 commit comments