Sunday, June 23, 2013

Secure WebSockets with Jetty

WebSocket is a protocol that runs on top of TCP and, unlike plain HTTP, allows the server to push data to the client. Let's see how to use WebSockets secured with TLS using Eclipse Jetty.

Add the following dependencies to the project POM:

<dependency>
    <groupId>org.eclipse.jetty</groupId>
    <artifactId>jetty-server</artifactId>
    <version>9.0.3.v20130506</version>
</dependency>
<dependency>
    <groupId>org.eclipse.jetty.websocket</groupId>
    <artifactId>websocket-server</artifactId>
    <version>9.0.3.v20130506</version>
</dependency>
<dependency>
    <groupId>org.eclipse.jetty.websocket</groupId>
    <artifactId>websocket-client</artifactId>
    <version>9.0.3.v20130506</version>
</dependency>

Create a WebSocket endpoint class and annotate it with @WebSocket:

package org.amila.sample.websocket.server;

import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketClose;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketConnect;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketMessage;
import org.eclipse.jetty.websocket.api.annotations.WebSocket;

import java.io.IOException;

/**
 * Server-side WebSocket endpoint.
 *
 * <p>Logs lifecycle events to stdout and answers every text message from the client with a
 * fixed greeting.
 */
@WebSocket
public class MyWebSocket {
    /** Outbound side of the current connection; captured when the socket opens. */
    private RemoteEndpoint peer;

    /**
     * Called by Jetty once the WebSocket handshake has completed.
     *
     * @param session the newly opened session; its remote endpoint is kept for replies
     */
    @OnWebSocketConnect
    public void onConnect(Session session) {
        System.out.println("WebSocket Opened");
        peer = session.getRemote();
    }

    /**
     * Called by Jetty for each incoming text frame.
     *
     * @param message the text sent by the client
     */
    @OnWebSocketMessage
    public void onMessage(String message) {
        System.out.println("Message from Client: " + message);
        replyToClient();
    }

    /**
     * Called by Jetty when the connection is closed.
     *
     * @param statusCode the WebSocket close status code
     * @param reason     the close reason supplied by the peer (not used here)
     */
    @OnWebSocketClose
    public void onClose(int statusCode, String reason) {
        System.out.println("WebSocket Closed. Code:" + statusCode);
    }

    /** Sends the fixed greeting back; send failures are only reported, never rethrown. */
    private void replyToClient() {
        try {
            peer.sendString("Hi Client");
        } catch (IOException sendFailure) {
            sendFailure.printStackTrace();
        }
    }
}

This is the Jetty server configured with TLS. Passing an SslConnectionFactory when creating the connector enables secure communication. For this sample, I generated a keystore and a truststore with Java's keytool. I placed them in the resources directory (src/main/resources in a standard Maven layout) so that they end up on the classpath.
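Call addWebSocket() with your annotated WebSocket POJO class to add WebSockets to the server. Do this before calling initialize(), because initialize() only picks up the WebSockets registered so far.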

package org.amila.sample.websocket.server;

import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.server.*;
import org.eclipse.jetty.server.handler.ContextHandler;
import org.eclipse.jetty.server.handler.HandlerCollection;
import org.eclipse.jetty.util.resource.FileResource;
import org.eclipse.jetty.util.resource.Resource;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.websocket.server.WebSocketHandler;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;

import java.util.ArrayList;
import java.util.List;

/**
 * Minimal embedded Jetty server that serves annotated WebSocket endpoints over TLS.
 *
 * <p>Usage order: configure host/port/keystore via the setters, register endpoints with
 * {@link #addWebSocket(Class, String)}, then call {@link #initialize()} and finally
 * {@link #start()}.
 */
public class WebSocketServer {
    private Server server;
    private String host;
    private int port;
    private Resource keyStoreResource;
    private String keyStorePassword;
    private String keyManagerPassword;
    /** Per-endpoint context handlers; copied into the server's handler collection by initialize(). */
    private final List<Handler> webSocketHandlerList = new ArrayList<>();

    public static void main(String[] args) throws Exception {
        WebSocketServer webSocketServer = new WebSocketServer();
        webSocketServer.setHost("localhost");
        webSocketServer.setPort(8443);
        webSocketServer.setKeyStoreResource(new FileResource(WebSocketServer.class.getResource("/keystore.jks")));
        webSocketServer.setKeyStorePassword("password");
        webSocketServer.setKeyManagerPassword("password");
        webSocketServer.addWebSocket(MyWebSocket.class, "/");
        webSocketServer.initialize();
        webSocketServer.start();
    }

    /**
     * Builds the Jetty {@link Server}: one TLS connector on {@code host:port} and a
     * {@link HandlerCollection} containing every WebSocket registered so far.
     *
     * <p>The handler list is snapshotted here, so endpoints added with
     * {@link #addWebSocket(Class, String)} after this call are not served.
     */
    public void initialize() {
        server = new Server();
        server.addConnector(createSslConnector(server));
        // Handlers are consulted in registration order.
        HandlerCollection handlerCollection = new HandlerCollection();
        handlerCollection.setHandlers(webSocketHandlerList.toArray(new Handler[0]));
        server.setHandler(handlerCollection);
    }

    /** Creates the TLS connector: SSL first, then HTTP/1.1 (over which the WebSocket upgrade happens). */
    private ServerConnector createSslConnector(Server owner) {
        SslContextFactory sslContextFactory = new SslContextFactory();
        sslContextFactory.setKeyStoreResource(keyStoreResource);
        sslContextFactory.setKeyStorePassword(keyStorePassword);
        sslContextFactory.setKeyManagerPassword(keyManagerPassword);
        // The SSL factory decrypts and hands the plaintext stream to the named next protocol.
        SslConnectionFactory sslConnectionFactory = new SslConnectionFactory(sslContextFactory, HttpVersion.HTTP_1_1.asString());
        HttpConnectionFactory httpConnectionFactory = new HttpConnectionFactory(new HttpConfiguration());
        ServerConnector sslConnector = new ServerConnector(owner, sslConnectionFactory, httpConnectionFactory);
        sslConnector.setHost(host);
        sslConnector.setPort(port);
        return sslConnector;
    }

    /**
     * Registers an annotated WebSocket class to be served under {@code pathSpec}.
     *
     * @param webSocket class annotated with {@code @WebSocket}; registered with the
     *                  {@link WebSocketServletFactory} when the handler is configured
     * @param pathSpec  context path the endpoint is mounted on, e.g. {@code "/"} or {@code "/chat"}
     */
    public void addWebSocket(final Class<?> webSocket, String pathSpec) {
        WebSocketHandler wsHandler = new WebSocketHandler() {
            @Override
            public void configure(WebSocketServletFactory webSocketServletFactory) {
                webSocketServletFactory.register(webSocket);
            }
        };
        ContextHandler wsContextHandler = new ContextHandler();
        wsContextHandler.setContextPath(pathSpec);
        // Accept "/chat" as well as "/chat/"; otherwise the context would answer the
        // upgrade request for the bare path with a redirect instead of a handshake.
        wsContextHandler.setAllowNullPathInfo(true);
        wsContextHandler.setHandler(wsHandler);
        // Register the context wrapper, not the bare handler, so pathSpec is actually honoured.
        webSocketHandlerList.add(wsContextHandler);
    }

    /**
     * Starts the server and blocks the calling thread until the server stops.
     *
     * @throws IllegalStateException if {@link #initialize()} has not been called
     * @throws Exception             if Jetty fails to start
     */
    public void start() throws Exception {
        if (server == null) {
            throw new IllegalStateException("initialize() must be called before start()");
        }
        server.start();
        server.join();
    }

    /**
     * Stops the server and waits for its threads to finish; a no-op if the server was never
     * initialized.
     *
     * @throws Exception if Jetty fails to stop cleanly
     */
    public void stop() throws Exception {
        if (server == null) {
            return;
        }
        server.stop();
        server.join();
    }

    public void setHost(String host) {
        this.host = host;
    }
    public void setPort(int port) {
        this.port = port;
    }
    public void setKeyStoreResource(Resource keyStoreResource) {
        this.keyStoreResource = keyStoreResource;
    }
    public void setKeyStorePassword(String keyStorePassword) {
        this.keyStorePassword = keyStorePassword;
    }
    public void setKeyManagerPassword(String keyManagerPassword) {
        this.keyManagerPassword = keyManagerPassword;
    }

}

And finally, the client code. The client-side WebSocket is included as an inner class. Pass an SslContextFactory when creating the client, and make sure to use "wss" as the scheme of the URL.

package org.amila.sample.websocket.client;

import org.eclipse.jetty.util.resource.Resource;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketClose;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketConnect;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketMessage;
import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/**
 * Sample TLS WebSocket client.
 *
 * <p>Connects to a {@code wss://} endpoint, sends one greeting, and waits a bounded time for
 * the connection to close.
 */
public class JettyWebSocketClient {

    /** Upper bound, in seconds, for the handshake and again for waiting on the close. */
    private static final int TIMEOUT_SECONDS = 5;

    public static void main(String[] args) throws IOException, URISyntaxException {
        new JettyWebSocketClient().run(new URI("wss://localhost:8443/"));
    }

    /**
     * Connects to {@code destinationUri} over TLS and waits up to {@value #TIMEOUT_SECONDS}
     * seconds for the session to close, then stops the client.
     *
     * <p>Connection and handshake failures are printed rather than thrown.
     *
     * @param destinationUri the {@code wss://} endpoint to connect to
     * @throws IOException if {@code /truststore.jks} is not on the classpath
     */
    public void run(URI destinationUri) throws IOException {
        java.net.URL trustStoreUrl = this.getClass().getResource("/truststore.jks");
        if (trustStoreUrl == null) {
            throw new IOException("truststore.jks not found on the classpath");
        }
        SslContextFactory sslContextFactory = new SslContextFactory();
        // NOTE(review): the truststore file is configured as the *key* store. This presumably relies
        // on Jetty using the key store as trust store when none is set (and older 9.x versions
        // insisting on a key store) — confirm before switching to setTrustStoreResource().
        Resource keyStoreResource = Resource.newResource(trustStoreUrl);
        sslContextFactory.setKeyStoreResource(keyStoreResource);
        sslContextFactory.setKeyStorePassword("password");
        sslContextFactory.setKeyManagerPassword("password");
        WebSocketClient client = new WebSocketClient(sslContextFactory);
        MyWebSocket socket = new MyWebSocket();
        try {
            client.start();
            ClientUpgradeRequest request = new ClientUpgradeRequest();
            System.out.println("Connecting to : " + destinationUri);
            // connect() is asynchronous; waiting on its Future surfaces TLS/handshake failures
            // here instead of letting them go unnoticed.
            client.connect(socket, destinationUri, request).get(TIMEOUT_SECONDS, TimeUnit.SECONDS);
            socket.awaitClose(TIMEOUT_SECONDS, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            // Preserve the interrupt for whoever called run().
            Thread.currentThread().interrupt();
        } catch (Throwable t) {
            t.printStackTrace();
        } finally {
            try {
                client.stop();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    /** Client-side endpoint: greets the server on connect and logs whatever it receives. */
    @WebSocket
    public class MyWebSocket {
        /** Released once the connection closes so {@link #awaitClose} can return early. */
        private final CountDownLatch closeLatch = new CountDownLatch(1);

        @OnWebSocketConnect
        public void onConnect(Session session) {
            System.out.println("WebSocket Opened in client side");
            try {
                System.out.println("Sending message: Hi server");
                session.getRemote().sendString("Hi Server");
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        @OnWebSocketMessage
        public void onMessage(String message) {
            System.out.println("Message from Server: " + message);
        }

        @OnWebSocketClose
        public void onClose(int statusCode, String reason) {
            System.out.println("WebSocket Closed. Code:" + statusCode);
            // Without this the latch is never released and awaitClose always times out.
            closeLatch.countDown();
        }

        /**
         * Blocks until the connection closes or the timeout elapses.
         *
         * @return {@code true} if the connection closed, {@code false} on timeout
         * @throws InterruptedException if the waiting thread is interrupted
         */
        public boolean awaitClose(int duration, TimeUnit unit) throws InterruptedException {
            return this.closeLatch.await(duration, unit);
        }
    }

}