001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.activemq.transport.nio;
019
020import java.io.DataOutputStream;
021import java.io.EOFException;
022import java.io.IOException;
023import java.net.Socket;
024import java.net.URI;
025import java.net.UnknownHostException;
026import java.nio.ByteBuffer;
027import java.util.concurrent.atomic.AtomicInteger;
028
029import javax.net.SocketFactory;
030import javax.net.ssl.SSLContext;
031import javax.net.ssl.SSLEngine;
032import javax.net.ssl.SSLEngineResult;
033import javax.net.ssl.SSLParameters;
034
035import org.apache.activemq.thread.TaskRunnerFactory;
036import org.apache.activemq.util.IOExceptionSupport;
037import org.apache.activemq.util.ServiceStopper;
038import org.apache.activemq.wireformat.WireFormat;
039
040/**
041 * This transport initializes the SSLEngine and reads the first command before
042 * handing off to the detected transport.
043 *
044 */
045public class AutoInitNioSSLTransport extends NIOSSLTransport {
046
047    public AutoInitNioSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
048        super(wireFormat, socketFactory, remoteLocation, localLocation);
049    }
050
051    public AutoInitNioSSLTransport(WireFormat wireFormat, Socket socket) throws IOException {
052        super(wireFormat, socket, null, null, null);
053    }
054
055    @Override
056    public void setSslContext(SSLContext sslContext) {
057        this.sslContext = sslContext;
058    }
059
060    public ByteBuffer getInputBuffer() {
061        return this.inputBuffer;
062    }
063
064    @Override
065    protected void initializeStreams() throws IOException {
066        NIOOutputStream outputStream = null;
067        try {
068            channel = socket.getChannel();
069            channel.configureBlocking(false);
070
071            if (sslContext == null) {
072                sslContext = SSLContext.getDefault();
073            }
074
075            String remoteHost = null;
076            int remotePort = -1;
077
078            try {
079                URI remoteAddress = new URI(this.getRemoteAddress());
080                remoteHost = remoteAddress.getHost();
081                remotePort = remoteAddress.getPort();
082            } catch (Exception e) {
083            }
084
085            // initialize engine, the initial sslSession we get will need to be
086            // updated once the ssl handshake process is completed.
087            if (remoteHost != null && remotePort != -1) {
088                sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
089            } else {
090                sslEngine = sslContext.createSSLEngine();
091            }
092
093            if (verifyHostName) {
094                SSLParameters sslParams = new SSLParameters();
095                sslParams.setEndpointIdentificationAlgorithm("HTTPS");
096                sslEngine.setSSLParameters(sslParams);
097            }
098
099            sslEngine.setUseClientMode(false);
100            if (enabledCipherSuites != null) {
101                sslEngine.setEnabledCipherSuites(enabledCipherSuites);
102            }
103
104            if (enabledProtocols != null) {
105                sslEngine.setEnabledProtocols(enabledProtocols);
106            }
107
108            if (wantClientAuth) {
109                sslEngine.setWantClientAuth(wantClientAuth);
110            }
111
112            if (needClientAuth) {
113                sslEngine.setNeedClientAuth(needClientAuth);
114            }
115
116            sslSession = sslEngine.getSession();
117
118            inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
119            inputBuffer.clear();
120
121            outputStream = new NIOOutputStream(channel);
122            outputStream.setEngine(sslEngine);
123            this.dataOut = new DataOutputStream(outputStream);
124            this.buffOut = outputStream;
125            sslEngine.beginHandshake();
126            handshakeStatus = sslEngine.getHandshakeStatus();
127            doHandshake();
128
129        } catch (Exception e) {
130            try {
131                if(outputStream != null) {
132                    outputStream.close();
133                }
134                super.closeStreams();
135            } catch (Exception ex) {}
136            throw new IOException(e);
137        }
138    }
139
140    @Override
141    protected void doOpenWireInit() throws Exception {
142
143    }
144
145    public SSLEngine getSslSession() {
146        return this.sslEngine;
147    }
148
149    private volatile byte[] readData;
150
151    private final AtomicInteger readSize = new AtomicInteger();
152
153    public byte[] getReadData() {
154        return readData != null ? readData : new byte[0];
155    }
156
157    public AtomicInteger getReadSize() {
158        return readSize;
159    }
160
161    @Override
162    public void serviceRead() {
163        try {
164            if (handshakeInProgress) {
165                doHandshake();
166            }
167
168            ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
169            plain.position(plain.limit());
170
171            while (true) {
172                if (!plain.hasRemaining()) {
173                    int readCount = secureRead(plain);
174
175                    if (readCount == 0) {
176                        break;
177                    }
178
179                    // channel is closed, cleanup
180                    if (readCount == -1) {
181                        onException(new EOFException());
182                        break;
183                    }
184
185                    receiveCounter += readCount;
186                    readSize.addAndGet(readCount);
187                }
188
189                if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
190                    processCommand(plain);
191                    //we have received enough bytes to detect the protocol
192                    if (receiveCounter >= 8) {
193                        break;
194                    }
195                }
196            }
197        } catch (IOException e) {
198            onException(e);
199        } catch (Throwable e) {
200            onException(IOExceptionSupport.create(e));
201        }
202    }
203
204    @Override
205    protected void processCommand(ByteBuffer plain) throws Exception {
206        ByteBuffer newBuffer = ByteBuffer.allocate(receiveCounter);
207        if (readData != null) {
208            newBuffer.put(readData);
209        }
210        newBuffer.put(plain);
211        newBuffer.flip();
212        readData = newBuffer.array();
213    }
214
215
216    @Override
217    public void doStart() throws Exception {
218        taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
219        // no need to init as we can delay that until demand (eg in doHandshake)
220        connect();
221    }
222
223
224    @Override
225    protected void doStop(ServiceStopper stopper) throws Exception {
226        if (taskRunnerFactory != null) {
227            taskRunnerFactory.shutdownNow();
228            taskRunnerFactory = null;
229        }
230    }
231
232
233}