View Javadoc

1   /**
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   * http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing, software
13   * distributed under the License is distributed on an "AS IS" BASIS,
14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15   * See the License for the specific language governing permissions and
16   * limitations under the License.
17   */
18  package org.apache.hadoop.hbase.security;
19  
20  import io.netty.buffer.ByteBuf;
21  import io.netty.channel.Channel;
22  import io.netty.channel.ChannelDuplexHandler;
23  import io.netty.channel.ChannelFuture;
24  import io.netty.channel.ChannelFutureListener;
25  import io.netty.channel.ChannelHandlerContext;
26  import io.netty.channel.ChannelPromise;
27  
28  import org.apache.commons.logging.Log;
29  import org.apache.commons.logging.LogFactory;
30  import org.apache.hadoop.hbase.classification.InterfaceAudience;
31  import org.apache.hadoop.ipc.RemoteException;
32  import org.apache.hadoop.security.UserGroupInformation;
33  import org.apache.hadoop.security.token.Token;
34  import org.apache.hadoop.security.token.TokenIdentifier;
35  
36  import javax.security.auth.callback.CallbackHandler;
37  import javax.security.sasl.Sasl;
38  import javax.security.sasl.SaslClient;
39  import javax.security.sasl.SaslException;
40  
41  import java.io.IOException;
42  import java.nio.charset.Charset;
43  import java.security.PrivilegedExceptionAction;
44  import java.util.Random;
45  
46  /**
47   * Handles Sasl connections
48   */
49  @InterfaceAudience.Private
50  public class SaslClientHandler extends ChannelDuplexHandler {
51    public static final Log LOG = LogFactory.getLog(SaslClientHandler.class);
52  
53    private final boolean fallbackAllowed;
54  
55    private final UserGroupInformation ticket;
56  
57    /**
58     * Used for client or server's token to send or receive from each other.
59     */
60    private final SaslClient saslClient;
61    private final SaslExceptionHandler exceptionHandler;
62    private final SaslSuccessfulConnectHandler successfulConnectHandler;
63    private byte[] saslToken;
64    private boolean firstRead = true;
65  
66    private int retryCount = 0;
67    private Random random;
68  
69    /**
70     * Constructor
71     *
72     * @param ticket                   the ugi
73     * @param method                   auth method
74     * @param token                    for Sasl
75     * @param serverPrincipal          Server's Kerberos principal name
76     * @param fallbackAllowed          True if server may also fall back to less secure connection
77     * @param rpcProtection            Quality of protection. Integrity or privacy
78     * @param exceptionHandler         handler for exceptions
79     * @param successfulConnectHandler handler for succesful connects
80     * @throws java.io.IOException if handler could not be created
81     */
82    public SaslClientHandler(UserGroupInformation ticket, AuthMethod method,
83        Token<? extends TokenIdentifier> token, String serverPrincipal, boolean fallbackAllowed,
84        String rpcProtection, SaslExceptionHandler exceptionHandler,
85        SaslSuccessfulConnectHandler successfulConnectHandler) throws IOException {
86      this.ticket = ticket;
87      this.fallbackAllowed = fallbackAllowed;
88  
89      this.exceptionHandler = exceptionHandler;
90      this.successfulConnectHandler = successfulConnectHandler;
91  
92      SaslUtil.initSaslProperties(rpcProtection);
93      switch (method) {
94      case DIGEST:
95        if (LOG.isDebugEnabled())
96          LOG.debug("Creating SASL " + AuthMethod.DIGEST.getMechanismName()
97              + " client to authenticate to service at " + token.getService());
98        saslClient = createDigestSaslClient(new String[] { AuthMethod.DIGEST.getMechanismName() },
99            SaslUtil.SASL_DEFAULT_REALM, new HBaseSaslRpcClient.SaslClientCallbackHandler(token));
100       break;
101     case KERBEROS:
102       if (LOG.isDebugEnabled()) {
103         LOG.debug("Creating SASL " + AuthMethod.KERBEROS.getMechanismName()
104             + " client. Server's Kerberos principal name is " + serverPrincipal);
105       }
106       if (serverPrincipal == null || serverPrincipal.isEmpty()) {
107         throw new IOException("Failed to specify server's Kerberos principal name");
108       }
109       String[] names = SaslUtil.splitKerberosName(serverPrincipal);
110       if (names.length != 3) {
111         throw new IOException(
112             "Kerberos principal does not have the expected format: " + serverPrincipal);
113       }
114       saslClient = createKerberosSaslClient(new String[] { AuthMethod.KERBEROS.getMechanismName() },
115           names[0], names[1]);
116       break;
117     default:
118       throw new IOException("Unknown authentication method " + method);
119     }
120     if (saslClient == null) {
121       throw new IOException("Unable to find SASL client implementation");
122     }
123   }
124 
125   /**
126    * Create a Digest Sasl client
127    *
128    * @param mechanismNames            names of mechanisms
129    * @param saslDefaultRealm          default realm for sasl
130    * @param saslClientCallbackHandler handler for the client
131    * @return new SaslClient
132    * @throws java.io.IOException if creation went wrong
133    */
134   protected SaslClient createDigestSaslClient(String[] mechanismNames, String saslDefaultRealm,
135       CallbackHandler saslClientCallbackHandler) throws IOException {
136     return Sasl.createSaslClient(mechanismNames, null, null, saslDefaultRealm, SaslUtil.SASL_PROPS,
137         saslClientCallbackHandler);
138   }
139 
140   /**
141    * Create Kerberos client
142    *
143    * @param mechanismNames names of mechanisms
144    * @param userFirstPart  first part of username
145    * @param userSecondPart second part of username
146    * @return new SaslClient
147    * @throws java.io.IOException if fails
148    */
149   protected SaslClient createKerberosSaslClient(String[] mechanismNames, String userFirstPart,
150       String userSecondPart) throws IOException {
151     return Sasl
152         .createSaslClient(mechanismNames, null, userFirstPart, userSecondPart, SaslUtil.SASL_PROPS,
153             null);
154   }
155 
156   @Override
157   public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
158     saslClient.dispose();
159   }
160 
161   private byte[] evaluateChallenge(final byte[] challenge) throws Exception {
162     return ticket.doAs(new PrivilegedExceptionAction<byte[]>() {
163 
164       @Override
165       public byte[] run() throws Exception {
166         return saslClient.evaluateChallenge(challenge);
167       }
168     });
169   }
170 
171   @Override
172   public void handlerAdded(final ChannelHandlerContext ctx) throws Exception {
173     saslToken = new byte[0];
174     if (saslClient.hasInitialResponse()) {
175       saslToken = evaluateChallenge(saslToken);
176     }
177     if (saslToken != null) {
178       writeSaslToken(ctx, saslToken);
179       if (LOG.isDebugEnabled()) {
180         LOG.debug("Have sent token of size " + saslToken.length + " from initSASLContext.");
181       }
182     }
183   }
184 
185   @Override
186   public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
187     ByteBuf in = (ByteBuf) msg;
188 
189     // If not complete, try to negotiate
190     if (!saslClient.isComplete()) {
191       while (!saslClient.isComplete() && in.isReadable()) {
192         readStatus(in);
193         int len = in.readInt();
194         if (firstRead) {
195           firstRead = false;
196           if (len == SaslUtil.SWITCH_TO_SIMPLE_AUTH) {
197             if (!fallbackAllowed) {
198               throw new IOException("Server asks us to fall back to SIMPLE auth, " + "but this "
199                   + "client is configured to only allow secure connections.");
200             }
201             if (LOG.isDebugEnabled()) {
202               LOG.debug("Server asks us to fall back to simple auth.");
203             }
204             saslClient.dispose();
205 
206             ctx.pipeline().remove(this);
207             successfulConnectHandler.onSuccess(ctx.channel());
208             return;
209           }
210         }
211         saslToken = new byte[len];
212         if (LOG.isDebugEnabled()) {
213           LOG.debug("Will read input token of size " + saslToken.length
214               + " for processing by initSASLContext");
215         }
216         in.readBytes(saslToken);
217 
218         saslToken = evaluateChallenge(saslToken);
219         if (saslToken != null) {
220           if (LOG.isDebugEnabled()) {
221             LOG.debug("Will send token of size " + saslToken.length + " from initSASLContext.");
222           }
223           writeSaslToken(ctx, saslToken);
224         }
225       }
226 
227       if (saslClient.isComplete()) {
228         String qop = (String) saslClient.getNegotiatedProperty(Sasl.QOP);
229 
230         if (LOG.isDebugEnabled()) {
231           LOG.debug("SASL client context established. Negotiated QoP: " + qop);
232         }
233 
234         boolean useWrap = qop != null && !"auth".equalsIgnoreCase(qop);
235 
236         if (!useWrap) {
237           ctx.pipeline().remove(this);
238         }
239         successfulConnectHandler.onSuccess(ctx.channel());
240       }
241     }
242     // Normal wrapped reading
243     else {
244       try {
245         int length = in.readInt();
246         if (LOG.isDebugEnabled()) {
247           LOG.debug("Actual length is " + length);
248         }
249         saslToken = new byte[length];
250         in.readBytes(saslToken);
251       } catch (IndexOutOfBoundsException e) {
252         return;
253       }
254       try {
255         ByteBuf b = ctx.channel().alloc().buffer(saslToken.length);
256 
257         b.writeBytes(saslClient.unwrap(saslToken, 0, saslToken.length));
258         ctx.fireChannelRead(b);
259 
260       } catch (SaslException se) {
261         try {
262           saslClient.dispose();
263         } catch (SaslException ignored) {
264           LOG.debug("Ignoring SASL exception", ignored);
265         }
266         throw se;
267       }
268     }
269   }
270 
271   /**
272    * Write SASL token
273    * @param ctx to write to
274    * @param saslToken to write
275    */
276   private void writeSaslToken(final ChannelHandlerContext ctx, byte[] saslToken) {
277     ByteBuf b = ctx.alloc().buffer(4 + saslToken.length);
278     b.writeInt(saslToken.length);
279     b.writeBytes(saslToken, 0, saslToken.length);
280     ctx.writeAndFlush(b).addListener(new ChannelFutureListener() {
281       @Override
282       public void operationComplete(ChannelFuture future) throws Exception {
283         if (!future.isSuccess()) {
284           exceptionCaught(ctx, future.cause());
285         }
286       }
287     });
288   }
289 
290   /**
291    * Get the read status
292    *
293    * @param inStream to read
294    * @throws org.apache.hadoop.ipc.RemoteException if status was not success
295    */
296   private static void readStatus(ByteBuf inStream) throws RemoteException {
297     int status = inStream.readInt(); // read status
298     if (status != SaslStatus.SUCCESS.state) {
299       throw new RemoteException(inStream.toString(Charset.forName("UTF-8")),
300           inStream.toString(Charset.forName("UTF-8")));
301     }
302   }
303 
304   @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
305       throws Exception {
306     saslClient.dispose();
307 
308     ctx.close();
309 
310     if (this.random == null) {
311       this.random = new Random();
312     }
313     exceptionHandler.handle(this.retryCount++, this.random, cause);
314   }
315 
316   @Override
317   public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
318       throws Exception {
319     // If not complete, try to negotiate
320     if (!saslClient.isComplete()) {
321       super.write(ctx, msg, promise);
322     } else {
323       ByteBuf in = (ByteBuf) msg;
324 
325       try {
326         saslToken = saslClient.wrap(in.array(), in.readerIndex(), in.readableBytes());
327       } catch (SaslException se) {
328         try {
329           saslClient.dispose();
330         } catch (SaslException ignored) {
331           LOG.debug("Ignoring SASL exception", ignored);
332         }
333         promise.setFailure(se);
334       }
335       if (saslToken != null) {
336         ByteBuf out = ctx.channel().alloc().buffer(4 + saslToken.length);
337         out.writeInt(saslToken.length);
338         out.writeBytes(saslToken, 0, saslToken.length);
339 
340         ctx.write(out).addListener(new ChannelFutureListener() {
341           @Override public void operationComplete(ChannelFuture future) throws Exception {
342             if (!future.isSuccess()) {
343               exceptionCaught(ctx, future.cause());
344             }
345           }
346         });
347 
348         saslToken = null;
349       }
350     }
351   }
352 
353   /**
354    * Handler for exceptions during Sasl connection
355    */
356   public interface SaslExceptionHandler {
357     /**
358      * Handle the exception
359      *
360      * @param retryCount current retry count
361      * @param random     to create new backoff with
362      * @param cause      of fail
363      */
364     public void handle(int retryCount, Random random, Throwable cause);
365   }
366 
367   /**
368    * Handler for successful connects
369    */
370   public interface SaslSuccessfulConnectHandler {
371     /**
372      * Runs on success
373      *
374      * @param channel which is successfully authenticated
375      */
376     public void onSuccess(Channel channel);
377   }
378 }