1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.hadoop.hbase.security;
19
20 import io.netty.buffer.ByteBuf;
21 import io.netty.channel.Channel;
22 import io.netty.channel.ChannelDuplexHandler;
23 import io.netty.channel.ChannelFuture;
24 import io.netty.channel.ChannelFutureListener;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelPromise;
27
28 import org.apache.commons.logging.Log;
29 import org.apache.commons.logging.LogFactory;
30 import org.apache.hadoop.hbase.classification.InterfaceAudience;
31 import org.apache.hadoop.ipc.RemoteException;
32 import org.apache.hadoop.security.UserGroupInformation;
33 import org.apache.hadoop.security.token.Token;
34 import org.apache.hadoop.security.token.TokenIdentifier;
35
36 import javax.security.auth.callback.CallbackHandler;
37 import javax.security.sasl.Sasl;
38 import javax.security.sasl.SaslClient;
39 import javax.security.sasl.SaslException;
40
41 import java.io.IOException;
42 import java.nio.charset.Charset;
43 import java.security.PrivilegedExceptionAction;
44 import java.util.Random;
45
46
47
48
49 @InterfaceAudience.Private
50 public class SaslClientHandler extends ChannelDuplexHandler {
51 public static final Log LOG = LogFactory.getLog(SaslClientHandler.class);
52
53 private final boolean fallbackAllowed;
54
55 private final UserGroupInformation ticket;
56
57
58
59
60 private final SaslClient saslClient;
61 private final SaslExceptionHandler exceptionHandler;
62 private final SaslSuccessfulConnectHandler successfulConnectHandler;
63 private byte[] saslToken;
64 private boolean firstRead = true;
65
66 private int retryCount = 0;
67 private Random random;
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82 public SaslClientHandler(UserGroupInformation ticket, AuthMethod method,
83 Token<? extends TokenIdentifier> token, String serverPrincipal, boolean fallbackAllowed,
84 String rpcProtection, SaslExceptionHandler exceptionHandler,
85 SaslSuccessfulConnectHandler successfulConnectHandler) throws IOException {
86 this.ticket = ticket;
87 this.fallbackAllowed = fallbackAllowed;
88
89 this.exceptionHandler = exceptionHandler;
90 this.successfulConnectHandler = successfulConnectHandler;
91
92 SaslUtil.initSaslProperties(rpcProtection);
93 switch (method) {
94 case DIGEST:
95 if (LOG.isDebugEnabled())
96 LOG.debug("Creating SASL " + AuthMethod.DIGEST.getMechanismName()
97 + " client to authenticate to service at " + token.getService());
98 saslClient = createDigestSaslClient(new String[] { AuthMethod.DIGEST.getMechanismName() },
99 SaslUtil.SASL_DEFAULT_REALM, new HBaseSaslRpcClient.SaslClientCallbackHandler(token));
100 break;
101 case KERBEROS:
102 if (LOG.isDebugEnabled()) {
103 LOG.debug("Creating SASL " + AuthMethod.KERBEROS.getMechanismName()
104 + " client. Server's Kerberos principal name is " + serverPrincipal);
105 }
106 if (serverPrincipal == null || serverPrincipal.isEmpty()) {
107 throw new IOException("Failed to specify server's Kerberos principal name");
108 }
109 String[] names = SaslUtil.splitKerberosName(serverPrincipal);
110 if (names.length != 3) {
111 throw new IOException(
112 "Kerberos principal does not have the expected format: " + serverPrincipal);
113 }
114 saslClient = createKerberosSaslClient(new String[] { AuthMethod.KERBEROS.getMechanismName() },
115 names[0], names[1]);
116 break;
117 default:
118 throw new IOException("Unknown authentication method " + method);
119 }
120 if (saslClient == null) {
121 throw new IOException("Unable to find SASL client implementation");
122 }
123 }
124
125
126
127
128
129
130
131
132
133
134 protected SaslClient createDigestSaslClient(String[] mechanismNames, String saslDefaultRealm,
135 CallbackHandler saslClientCallbackHandler) throws IOException {
136 return Sasl.createSaslClient(mechanismNames, null, null, saslDefaultRealm, SaslUtil.SASL_PROPS,
137 saslClientCallbackHandler);
138 }
139
140
141
142
143
144
145
146
147
148
149 protected SaslClient createKerberosSaslClient(String[] mechanismNames, String userFirstPart,
150 String userSecondPart) throws IOException {
151 return Sasl
152 .createSaslClient(mechanismNames, null, userFirstPart, userSecondPart, SaslUtil.SASL_PROPS,
153 null);
154 }
155
156 @Override
157 public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
158 saslClient.dispose();
159 }
160
161 private byte[] evaluateChallenge(final byte[] challenge) throws Exception {
162 return ticket.doAs(new PrivilegedExceptionAction<byte[]>() {
163
164 @Override
165 public byte[] run() throws Exception {
166 return saslClient.evaluateChallenge(challenge);
167 }
168 });
169 }
170
171 @Override
172 public void handlerAdded(final ChannelHandlerContext ctx) throws Exception {
173 saslToken = new byte[0];
174 if (saslClient.hasInitialResponse()) {
175 saslToken = evaluateChallenge(saslToken);
176 }
177 if (saslToken != null) {
178 writeSaslToken(ctx, saslToken);
179 if (LOG.isDebugEnabled()) {
180 LOG.debug("Have sent token of size " + saslToken.length + " from initSASLContext.");
181 }
182 }
183 }
184
185 @Override
186 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
187 ByteBuf in = (ByteBuf) msg;
188
189
190 if (!saslClient.isComplete()) {
191 while (!saslClient.isComplete() && in.isReadable()) {
192 readStatus(in);
193 int len = in.readInt();
194 if (firstRead) {
195 firstRead = false;
196 if (len == SaslUtil.SWITCH_TO_SIMPLE_AUTH) {
197 if (!fallbackAllowed) {
198 throw new IOException("Server asks us to fall back to SIMPLE auth, " + "but this "
199 + "client is configured to only allow secure connections.");
200 }
201 if (LOG.isDebugEnabled()) {
202 LOG.debug("Server asks us to fall back to simple auth.");
203 }
204 saslClient.dispose();
205
206 ctx.pipeline().remove(this);
207 successfulConnectHandler.onSuccess(ctx.channel());
208 return;
209 }
210 }
211 saslToken = new byte[len];
212 if (LOG.isDebugEnabled()) {
213 LOG.debug("Will read input token of size " + saslToken.length
214 + " for processing by initSASLContext");
215 }
216 in.readBytes(saslToken);
217
218 saslToken = evaluateChallenge(saslToken);
219 if (saslToken != null) {
220 if (LOG.isDebugEnabled()) {
221 LOG.debug("Will send token of size " + saslToken.length + " from initSASLContext.");
222 }
223 writeSaslToken(ctx, saslToken);
224 }
225 }
226
227 if (saslClient.isComplete()) {
228 String qop = (String) saslClient.getNegotiatedProperty(Sasl.QOP);
229
230 if (LOG.isDebugEnabled()) {
231 LOG.debug("SASL client context established. Negotiated QoP: " + qop);
232 }
233
234 boolean useWrap = qop != null && !"auth".equalsIgnoreCase(qop);
235
236 if (!useWrap) {
237 ctx.pipeline().remove(this);
238 }
239 successfulConnectHandler.onSuccess(ctx.channel());
240 }
241 }
242
243 else {
244 try {
245 int length = in.readInt();
246 if (LOG.isDebugEnabled()) {
247 LOG.debug("Actual length is " + length);
248 }
249 saslToken = new byte[length];
250 in.readBytes(saslToken);
251 } catch (IndexOutOfBoundsException e) {
252 return;
253 }
254 try {
255 ByteBuf b = ctx.channel().alloc().buffer(saslToken.length);
256
257 b.writeBytes(saslClient.unwrap(saslToken, 0, saslToken.length));
258 ctx.fireChannelRead(b);
259
260 } catch (SaslException se) {
261 try {
262 saslClient.dispose();
263 } catch (SaslException ignored) {
264 LOG.debug("Ignoring SASL exception", ignored);
265 }
266 throw se;
267 }
268 }
269 }
270
271
272
273
274
275
276 private void writeSaslToken(final ChannelHandlerContext ctx, byte[] saslToken) {
277 ByteBuf b = ctx.alloc().buffer(4 + saslToken.length);
278 b.writeInt(saslToken.length);
279 b.writeBytes(saslToken, 0, saslToken.length);
280 ctx.writeAndFlush(b).addListener(new ChannelFutureListener() {
281 @Override
282 public void operationComplete(ChannelFuture future) throws Exception {
283 if (!future.isSuccess()) {
284 exceptionCaught(ctx, future.cause());
285 }
286 }
287 });
288 }
289
290
291
292
293
294
295
296 private static void readStatus(ByteBuf inStream) throws RemoteException {
297 int status = inStream.readInt();
298 if (status != SaslStatus.SUCCESS.state) {
299 throw new RemoteException(inStream.toString(Charset.forName("UTF-8")),
300 inStream.toString(Charset.forName("UTF-8")));
301 }
302 }
303
304 @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
305 throws Exception {
306 saslClient.dispose();
307
308 ctx.close();
309
310 if (this.random == null) {
311 this.random = new Random();
312 }
313 exceptionHandler.handle(this.retryCount++, this.random, cause);
314 }
315
316 @Override
317 public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
318 throws Exception {
319
320 if (!saslClient.isComplete()) {
321 super.write(ctx, msg, promise);
322 } else {
323 ByteBuf in = (ByteBuf) msg;
324
325 try {
326 saslToken = saslClient.wrap(in.array(), in.readerIndex(), in.readableBytes());
327 } catch (SaslException se) {
328 try {
329 saslClient.dispose();
330 } catch (SaslException ignored) {
331 LOG.debug("Ignoring SASL exception", ignored);
332 }
333 promise.setFailure(se);
334 }
335 if (saslToken != null) {
336 ByteBuf out = ctx.channel().alloc().buffer(4 + saslToken.length);
337 out.writeInt(saslToken.length);
338 out.writeBytes(saslToken, 0, saslToken.length);
339
340 ctx.write(out).addListener(new ChannelFutureListener() {
341 @Override public void operationComplete(ChannelFuture future) throws Exception {
342 if (!future.isSuccess()) {
343 exceptionCaught(ctx, future.cause());
344 }
345 }
346 });
347
348 saslToken = null;
349 }
350 }
351 }
352
353
354
355
356 public interface SaslExceptionHandler {
357
358
359
360
361
362
363
364 public void handle(int retryCount, Random random, Throwable cause);
365 }
366
367
368
369
370 public interface SaslSuccessfulConnectHandler {
371
372
373
374
375
376 public void onSuccess(Channel channel);
377 }
378 }