001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.activemq.transport.nio; 019 020import java.io.DataOutputStream; 021import java.io.EOFException; 022import java.io.IOException; 023import java.net.Socket; 024import java.net.URI; 025import java.net.UnknownHostException; 026import java.nio.ByteBuffer; 027import java.util.concurrent.atomic.AtomicInteger; 028 029import javax.net.SocketFactory; 030import javax.net.ssl.SSLContext; 031import javax.net.ssl.SSLEngine; 032import javax.net.ssl.SSLEngineResult; 033import javax.net.ssl.SSLParameters; 034 035import org.apache.activemq.thread.TaskRunnerFactory; 036import org.apache.activemq.util.IOExceptionSupport; 037import org.apache.activemq.util.ServiceStopper; 038import org.apache.activemq.wireformat.WireFormat; 039 040/** 041 * This transport initializes the SSLEngine and reads the first command before 042 * handing off to the detected transport. 043 * 044 */ 045public class AutoInitNioSSLTransport extends NIOSSLTransport { 046 047 public AutoInitNioSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException { 048 super(wireFormat, socketFactory, remoteLocation, localLocation); 049 } 050 051 public AutoInitNioSSLTransport(WireFormat wireFormat, Socket socket) throws IOException { 052 super(wireFormat, socket, null, null, null); 053 } 054 055 @Override 056 public void setSslContext(SSLContext sslContext) { 057 this.sslContext = sslContext; 058 } 059 060 public ByteBuffer getInputBuffer() { 061 return this.inputBuffer; 062 } 063 064 @Override 065 protected void initializeStreams() throws IOException { 066 NIOOutputStream outputStream = null; 067 try { 068 channel = socket.getChannel(); 069 channel.configureBlocking(false); 070 071 if (sslContext == null) { 072 sslContext = SSLContext.getDefault(); 073 } 074 075 String remoteHost = null; 076 int remotePort = -1; 077 078 try { 079 URI remoteAddress = new URI(this.getRemoteAddress()); 080 remoteHost = remoteAddress.getHost(); 081 remotePort = remoteAddress.getPort(); 082 } catch (Exception e) { 083 } 084 085 // initialize engine, the initial sslSession we get will need to be 086 // updated once the ssl handshake process is completed. 087 if (remoteHost != null && remotePort != -1) { 088 sslEngine = sslContext.createSSLEngine(remoteHost, remotePort); 089 } else { 090 sslEngine = sslContext.createSSLEngine(); 091 } 092 093 if (verifyHostName) { 094 SSLParameters sslParams = new SSLParameters(); 095 sslParams.setEndpointIdentificationAlgorithm("HTTPS"); 096 sslEngine.setSSLParameters(sslParams); 097 } 098 099 sslEngine.setUseClientMode(false); 100 if (enabledCipherSuites != null) { 101 sslEngine.setEnabledCipherSuites(enabledCipherSuites); 102 } 103 104 if (enabledProtocols != null) { 105 sslEngine.setEnabledProtocols(enabledProtocols); 106 } 107 108 if (wantClientAuth) { 109 sslEngine.setWantClientAuth(wantClientAuth); 110 } 111 112 if (needClientAuth) { 113 sslEngine.setNeedClientAuth(needClientAuth); 114 } 115 116 sslSession = sslEngine.getSession(); 117 118 inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize()); 119 inputBuffer.clear(); 120 121 outputStream = new NIOOutputStream(channel); 122 outputStream.setEngine(sslEngine); 123 this.dataOut = new DataOutputStream(outputStream); 124 this.buffOut = outputStream; 125 sslEngine.beginHandshake(); 126 handshakeStatus = sslEngine.getHandshakeStatus(); 127 doHandshake(); 128 129 } catch (Exception e) { 130 try { 131 if(outputStream != null) { 132 outputStream.close(); 133 } 134 super.closeStreams(); 135 } catch (Exception ex) {} 136 throw new IOException(e); 137 } 138 } 139 140 @Override 141 protected void doOpenWireInit() throws Exception { 142 143 } 144 145 public SSLEngine getSslSession() { 146 return this.sslEngine; 147 } 148 149 private volatile byte[] readData; 150 151 private final AtomicInteger readSize = new AtomicInteger(); 152 153 public byte[] getReadData() { 154 return readData != null ? readData : new byte[0]; 155 } 156 157 public AtomicInteger getReadSize() { 158 return readSize; 159 } 160 161 @Override 162 public void serviceRead() { 163 try { 164 if (handshakeInProgress) { 165 doHandshake(); 166 } 167 168 ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize()); 169 plain.position(plain.limit()); 170 171 while (true) { 172 if (!plain.hasRemaining()) { 173 int readCount = secureRead(plain); 174 175 if (readCount == 0) { 176 break; 177 } 178 179 // channel is closed, cleanup 180 if (readCount == -1) { 181 onException(new EOFException()); 182 break; 183 } 184 185 receiveCounter += readCount; 186 readSize.addAndGet(readCount); 187 } 188 189 if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { 190 processCommand(plain); 191 //we have received enough bytes to detect the protocol 192 if (receiveCounter >= 8) { 193 break; 194 } 195 } 196 } 197 } catch (IOException e) { 198 onException(e); 199 } catch (Throwable e) { 200 onException(IOExceptionSupport.create(e)); 201 } 202 } 203 204 @Override 205 protected void processCommand(ByteBuffer plain) throws Exception { 206 ByteBuffer newBuffer = ByteBuffer.allocate(receiveCounter); 207 if (readData != null) { 208 newBuffer.put(readData); 209 } 210 newBuffer.put(plain); 211 newBuffer.flip(); 212 readData = newBuffer.array(); 213 } 214 215 216 @Override 217 public void doStart() throws Exception { 218 taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task"); 219 // no need to init as we can delay that until demand (eg in doHandshake) 220 connect(); 221 } 222 223 224 @Override 225 protected void doStop(ServiceStopper stopper) throws Exception { 226 if (taskRunnerFactory != null) { 227 taskRunnerFactory.shutdownNow(); 228 taskRunnerFactory = null; 229 } 230 } 231 232 233}