1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.hadoop.hbase.rest.filter;
19
20 import java.io.IOException;
21 import java.util.HashMap;
22 import java.util.HashSet;
23 import java.util.Map;
24 import java.util.Set;
25 import java.util.regex.Matcher;
26 import java.util.regex.Pattern;
27
28 import javax.servlet.Filter;
29 import javax.servlet.FilterChain;
30 import javax.servlet.FilterConfig;
31 import javax.servlet.ServletException;
32 import javax.servlet.ServletRequest;
33 import javax.servlet.ServletResponse;
34 import javax.servlet.http.HttpServletRequest;
35 import javax.servlet.http.HttpServletResponse;
36
37 import org.apache.hadoop.classification.InterfaceAudience;
38 import org.apache.hadoop.classification.InterfaceStability;
39 import org.apache.hadoop.conf.Configuration;
40
41 import org.slf4j.Logger;
42 import org.slf4j.LoggerFactory;
43
44
45
46
47
48
49
50
51 @InterfaceAudience.Public
52 @InterfaceStability.Evolving
53 public class RestCsrfPreventionFilter implements Filter {
54
55 private static final Logger LOG =
56 LoggerFactory.getLogger(RestCsrfPreventionFilter.class);
57
58 public static final String HEADER_USER_AGENT = "User-Agent";
59 public static final String BROWSER_USER_AGENT_PARAM =
60 "browser-useragents-regex";
61 public static final String CUSTOM_HEADER_PARAM = "custom-header";
62 public static final String CUSTOM_METHODS_TO_IGNORE_PARAM =
63 "methods-to-ignore";
64 static final String BROWSER_USER_AGENTS_DEFAULT = "^Mozilla.*,^Opera.*";
65 static final String HEADER_DEFAULT = "X-XSRF-HEADER";
66 static final String METHODS_TO_IGNORE_DEFAULT = "GET,OPTIONS,HEAD,TRACE";
67 private String headerName = HEADER_DEFAULT;
68 private Set<String> methodsToIgnore = null;
69 private Set<Pattern> browserUserAgents;
70
71 @Override
72 public void init(FilterConfig filterConfig) throws ServletException {
73 String customHeader = filterConfig.getInitParameter(CUSTOM_HEADER_PARAM);
74 if (customHeader != null) {
75 headerName = customHeader;
76 }
77 String customMethodsToIgnore =
78 filterConfig.getInitParameter(CUSTOM_METHODS_TO_IGNORE_PARAM);
79 if (customMethodsToIgnore != null) {
80 parseMethodsToIgnore(customMethodsToIgnore);
81 } else {
82 parseMethodsToIgnore(METHODS_TO_IGNORE_DEFAULT);
83 }
84
85 String agents = filterConfig.getInitParameter(BROWSER_USER_AGENT_PARAM);
86 if (agents == null) {
87 agents = BROWSER_USER_AGENTS_DEFAULT;
88 }
89 parseBrowserUserAgents(agents);
90 LOG.info("Adding cross-site request forgery (CSRF) protection, "
91 + "headerName = {}, methodsToIgnore = {}, browserUserAgents = {}",
92 headerName, methodsToIgnore, browserUserAgents);
93 }
94
95 void parseBrowserUserAgents(String userAgents) {
96 String[] agentsArray = userAgents.split(",");
97 browserUserAgents = new HashSet<Pattern>();
98 for (String patternString : agentsArray) {
99 browserUserAgents.add(Pattern.compile(patternString));
100 }
101 }
102
103 void parseMethodsToIgnore(String mti) {
104 String[] methods = mti.split(",");
105 methodsToIgnore = new HashSet<String>();
106 for (int i = 0; i < methods.length; i++) {
107 methodsToIgnore.add(methods[i]);
108 }
109 }
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126 protected boolean isBrowser(String userAgent) {
127 if (userAgent == null) {
128 return false;
129 }
130 for (Pattern pattern : browserUserAgents) {
131 Matcher matcher = pattern.matcher(userAgent);
132 if (matcher.matches()) {
133 return true;
134 }
135 }
136 return false;
137 }
138
139
140
141
142
143
144
145
146
147
148 public interface HttpInteraction {
149
150
151
152
153
154
155
156 String getHeader(String header);
157
158
159
160
161
162
163 String getMethod();
164
165
166
167
168
169
170
171
172 void proceed() throws IOException, ServletException;
173
174
175
176
177
178
179
180
181
182 void sendError(int code, String message) throws IOException;
183 }
184
185
186
187
188
189
190
191
192
193 public void handleHttpInteraction(HttpInteraction httpInteraction)
194 throws IOException, ServletException {
195 if (!isBrowser(httpInteraction.getHeader(HEADER_USER_AGENT)) ||
196 methodsToIgnore.contains(httpInteraction.getMethod()) ||
197 httpInteraction.getHeader(headerName) != null) {
198 httpInteraction.proceed();
199 } else {
200 httpInteraction.sendError(HttpServletResponse.SC_BAD_REQUEST,
201 "Missing Required Header for CSRF Vulnerability Protection");
202 }
203 }
204
205 @Override
206 public void doFilter(ServletRequest request, ServletResponse response,
207 final FilterChain chain) throws IOException, ServletException {
208 final HttpServletRequest httpRequest = (HttpServletRequest)request;
209 final HttpServletResponse httpResponse = (HttpServletResponse)response;
210 handleHttpInteraction(new ServletFilterHttpInteraction(httpRequest,
211 httpResponse, chain));
212 }
213
214 @Override
215 public void destroy() {
216 }
217
218
219
220
221
222
223
224
225
226
227
228
229 public static Map<String, String> getFilterParams(Configuration conf,
230 String confPrefix) {
231 Map<String, String> filterConfigMap = new HashMap<>();
232 for (Map.Entry<String, String> entry : conf) {
233 String name = entry.getKey();
234 if (name.startsWith(confPrefix)) {
235 String value = conf.get(name);
236 name = name.substring(confPrefix.length());
237 filterConfigMap.put(name, value);
238 }
239 }
240 return filterConfigMap;
241 }
242
243
244
245
246 private static final class ServletFilterHttpInteraction
247 implements HttpInteraction {
248
249 private final FilterChain chain;
250 private final HttpServletRequest httpRequest;
251 private final HttpServletResponse httpResponse;
252
253
254
255
256
257
258
259
260 public ServletFilterHttpInteraction(HttpServletRequest httpRequest,
261 HttpServletResponse httpResponse, FilterChain chain) {
262 this.httpRequest = httpRequest;
263 this.httpResponse = httpResponse;
264 this.chain = chain;
265 }
266
267 @Override
268 public String getHeader(String header) {
269 return httpRequest.getHeader(header);
270 }
271
272 @Override
273 public String getMethod() {
274 return httpRequest.getMethod();
275 }
276
277 @Override
278 public void proceed() throws IOException, ServletException {
279 chain.doFilter(httpRequest, httpResponse);
280 }
281
282 @Override
283 public void sendError(int code, String message) throws IOException {
284 httpResponse.sendError(code, message);
285 }
286 }
287 }