001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.activemq.transport;
018
019import java.io.IOException;
020import java.io.InterruptedIOException;
021import java.net.Socket;
022import java.util.concurrent.CountDownLatch;
023import java.util.concurrent.TimeUnit;
024import java.util.concurrent.atomic.AtomicBoolean;
025
026import org.apache.activemq.command.Command;
027import org.apache.activemq.command.WireFormatInfo;
028import org.apache.activemq.openwire.OpenWireFormat;
029import org.apache.activemq.util.IOExceptionSupport;
030import org.slf4j.Logger;
031import org.slf4j.LoggerFactory;
032
033/**
034 * Negotiates the wire format with a new connection
035 */
036public class WireFormatNegotiator extends TransportFilter {
037
038    private static final Logger LOG = LoggerFactory.getLogger(WireFormatNegotiator.class);
039
040    private OpenWireFormat wireFormat;
041    private final int minimumVersion;
042    private long negotiateTimeout = 15000L;
043
044    private final AtomicBoolean firstStart = new AtomicBoolean(true);
045    private final CountDownLatch readyCountDownLatch = new CountDownLatch(1);
046    private final CountDownLatch wireInfoSentDownLatch = new CountDownLatch(1);
047
048    /**
049     * Negotiator
050     * 
051     * @param next
052     */
053    public WireFormatNegotiator(Transport next, OpenWireFormat wireFormat, int minimumVersion) {
054        super(next);
055        this.wireFormat = wireFormat;
056        if (minimumVersion <= 0) {
057            minimumVersion = 1;
058        }
059        this.minimumVersion = minimumVersion;
060        
061        // Setup the initial negociation timeout to be the same as the inital max inactivity delay specified on the wireformat
062        // Does not make sense for us to take longer.
063        try {
064            if( wireFormat.getPreferedWireFormatInfo() !=null ) {
065                setNegotiateTimeout(wireFormat.getPreferedWireFormatInfo().getMaxInactivityDurationInitalDelay());
066            }
067        } catch (IOException e) {
068        }
069    }
070
071    public void start() throws Exception {
072        super.start();
073        if (firstStart.compareAndSet(true, false)) {
074            sendWireFormat();
075        }
076    }
077
078    public void sendWireFormat() throws IOException {
079        try {
080            WireFormatInfo info = wireFormat.getPreferedWireFormatInfo();
081            if (LOG.isDebugEnabled()) {
082                LOG.debug("Sending: " + info);
083            }
084            sendWireFormat(info);
085        } finally {
086            wireInfoSentDownLatch.countDown();
087        }
088    }
089
090    public void stop() throws Exception {
091        super.stop();
092        readyCountDownLatch.countDown();
093    }
094
095    public void oneway(Object command) throws IOException {
096        try {
097            if (!readyCountDownLatch.await(negotiateTimeout, TimeUnit.MILLISECONDS)) {
098                throw new IOException("Wire format negotiation timeout: peer did not send his wire format.");
099            }
100        } catch (InterruptedException e) {
101            Thread.currentThread().interrupt();
102            throw new InterruptedIOException();
103        }
104        super.oneway(command);
105    }
106
107    public void onCommand(Object o) {
108        Command command = (Command)o;
109        if (command.isWireFormatInfo()) {
110            WireFormatInfo info = (WireFormatInfo)command;
111            negociate(info);
112        }
113        getTransportListener().onCommand(command);
114    }
115
116    public void negociate(WireFormatInfo info) {
117        if (LOG.isDebugEnabled()) {
118            LOG.debug("Received WireFormat: " + info);
119        }
120
121        try {
122            wireInfoSentDownLatch.await();
123
124            if (LOG.isDebugEnabled()) {
125                LOG.debug(this + " before negotiation: " + wireFormat);
126            }
127            if (!info.isValid()) {
128                onException(new IOException("Remote wire format magic is invalid"));
129            } else if (info.getVersion() < minimumVersion) {
130                onException(new IOException("Remote wire format (" + info.getVersion() + ") is lower the minimum version required (" + minimumVersion + ")"));
131            }
132
133            wireFormat.renegotiateWireFormat(info);
134            Socket socket = next.narrow(Socket.class);
135            if (socket != null) {
136                socket.setTcpNoDelay(wireFormat.isTcpNoDelayEnabled());
137            }
138
139            if (LOG.isDebugEnabled()) {
140                LOG.debug(this + " after negotiation: " + wireFormat);
141            }
142
143        } catch (IOException e) {
144            onException(e);
145        } catch (InterruptedException e) {
146            onException((IOException)new InterruptedIOException().initCause(e));
147        } catch (Exception e) {
148            onException(IOExceptionSupport.create(e));
149        }
150        readyCountDownLatch.countDown();
151        onWireFormatNegotiated(info);
152    }
153
154    public void onException(IOException error) {
155        readyCountDownLatch.countDown();
156        /*
157         * try { super.oneway(new ExceptionResponse(error)); } catch
158         * (IOException e) { // ignore as we are already throwing an exception }
159         */
160        super.onException(error);
161    }
162
163    public String toString() {
164        return next.toString();
165    }
166
167    protected void sendWireFormat(WireFormatInfo info) throws IOException {
168        next.oneway(info);
169    }
170
171    protected void onWireFormatNegotiated(WireFormatInfo info) {
172    }
173
174    public long getNegotiateTimeout() {
175        return negotiateTimeout;
176    }
177
178    public void setNegotiateTimeout(long negotiateTimeout) {
179        this.negotiateTimeout = negotiateTimeout;
180    }
181}