/*
 * ========== licence begin GPL
    Copyright (C) 2002-2003 SAP AG

    This program is free software; you can redistribute it and/or
    modify it under the terms of the GNU General Public License
    as published by the Free Software Foundation; either version 2
    of the License, or (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
    ========== licence end
 * 
 */
package com.sap.dbtech.rte.comm;

import java.io.IOException;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.Properties;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import javax.security.auth.x500.X500Principal;

import com.sap.dbtech.jdbc.DriverSapDB;
import com.sap.dbtech.util.MessageKey;
import com.sap.dbtech.util.MessageTranslator;

/**
 * Communication that uses the SSL features of Java 1.4.x to connect to the
 * database.
 * <p>
 * Two connection properties influence how the server is authenticated:
 * <ul>
 * <li><code>acceptServerCertAlways</code> (default <code>false</code>): when
 * <code>true</code>, a trust manager that accepts every certificate chain is
 * installed, so the server certificate is not validated at all. Such a
 * connection is exposed to man-in-the-middle attacks.</li>
 * <li><code>ignoreHostNameInCert</code> (default <code>false</code>): when
 * <code>true</code>, the check that the first RDN of the server certificate
 * subject equals <code>CN=&lt;host&gt;</code> is skipped.</li>
 * </ul>
 */
public class SecureCommunication extends BasicSocketComm {

    /** Factory through which the driver creates secure communications. */
    public final static JdbcCommFactory factory = new JdbcCommFactory() {
        /**
         * Opens an SSL connection to <code>host</code>, then calls
         * <code>connectDB(dbname)</code> on it.
         */
        public JdbcCommunication open(String host, String dbname,
                Properties properties) throws RTEException {
            SecureCommunication sc = new SecureCommunication(host, properties);
            // NOTE(review): if connectDB throws, sc's socket is not released
            // here - confirm whether connectDB cleans up on failure.
            sc.connectDB(dbname);
            return sc;
        }

        /**
         * Opens an SSL connection to <code>host</code>, then calls
         * <code>connectAdmin(db, dbroot, pgm)</code> on it.
         */
        public JdbcCommunication xopen(String host, String db, String dbroot,
                String pgm, Properties properties) throws RTEException {
            SecureCommunication sc = new SecureCommunication(host, properties);
            sc.connectAdmin(db, dbroot, pgm);
            return sc;
        }
    };

    // instance variables

    /** Value of the "acceptServerCertAlways" property. */
    private final boolean ignoreServerCertificate;

    /** Value of the "ignoreHostNameInCert" property. */
    private final boolean ignoreHostNameInCertificate;

    /**
     * The properties this communication was created with (may be
     * <code>null</code>). They are handed on by {@link #getNewCommunication()}
     * so that follow-up connections use the same SSL settings.
     */
    private final Properties connectProperties;

    /**
     * Creates a new SSL socket connection to the host given in
     * <code>hostPort</code>.
     * 
     * @param hostPort
     *            host, optionally followed by ":port"
     * @param properties
     *            connection properties, may be <code>null</code>
     * @exception RTEException
     *                if the connection cannot be established
     */
    private SecureCommunication(String hostPort, Properties properties)
            throws RTEException {
        super(hostPort, properties);
        this.connectProperties = properties;
        this.ignoreHostNameInCertificate = DriverSapDB.getBooleanProperty(
                properties, "ignoreHostNameInCert", false);
        this.ignoreServerCertificate = DriverSapDB.getBooleanProperty(
                properties, "acceptServerCertAlways", false);
        this.openSocket();
    }

    /**
     * Opens the SSL socket connection, performs the handshake and, unless
     * disabled, verifies the host name against the server certificate.
     * 
     * This converts any java.net / java.security specific exceptions to
     * RTEException. If opening fails after the socket was created, the socket
     * is closed again so it does not leak.
     * 
     * @exception RTEException
     *                if the connection cannot be established or verified
     */
    protected void openSocket() throws RTEException {
        SSLSocket sslsocket = null;
        boolean opened = false;
        try {
            SSLSocketFactory socketFactory;
            if (this.ignoreServerCertificate) {
                socketFactory = createTrustAllSocketFactory();
            } else {
                socketFactory = (SSLSocketFactory) SSLSocketFactory.getDefault();
            }
            this.socket = socketFactory.createSocket(this.host, this.lookupPort());
            sslsocket = (SSLSocket) this.socket;
            sslsocket.startHandshake();

            if (!this.ignoreHostNameInCertificate) {
                verifyHostName(sslsocket);
            }

            try {
                sslsocket.setSoTimeout(this.socketTimeOut);
                sslsocket.setTcpNoDelay(true);
                sslsocket.setReceiveBufferSize(36864);
                sslsocket.setSendBufferSize(36864);
            } catch (SocketException socketEx) {
                // ignore, as it is harmless
            }

            this.instream = this.socket.getInputStream();
            this.outstream = this.socket.getOutputStream();
            opened = true;
        } catch (java.security.NoSuchAlgorithmException noSuchAlg) {
            throw hostConnectError(noSuchAlg.getMessage());
        } catch (java.security.KeyManagementException keyManagementEx) {
            throw hostConnectError(keyManagementEx.getMessage());
        } catch (UnknownHostException uhexc) {
            throw new RTEException(
                    MessageTranslator
                            .translate(
                                    MessageKey.ERROR_UNKNOWN_HOST,
                                    this.host,
                                    uhexc.getMessage(),
                                    new Integer(
                                            RteC.CommunicationErrorCodeMap_C[RteC.SQLSERVER_OR_DB_UNKNOWN_C])),
                    RteC.CommunicationErrorCodeMap_C[RteC.SQLSERVER_OR_DB_UNKNOWN_C]);
        } catch (IOException ioexc) {
            throw hostConnectError(ioexc.getMessage());
        } finally {
            // Do not leave a half-opened (or unverified) connection behind.
            if (!opened && sslsocket != null) {
                closeQuietly(sslsocket);
            }
        }
    }

    /**
     * Builds an SSL socket factory whose trust manager accepts every
     * certificate chain (used for "acceptServerCertAlways").
     */
    private static SSLSocketFactory createTrustAllSocketFactory()
            throws java.security.NoSuchAlgorithmException,
            java.security.KeyManagementException {
        TrustManager[] trustAllCerts = new TrustManager[] { new X509TrustManager() {
            public X509Certificate[] getAcceptedIssuers() {
                // The X509TrustManager contract requires a non-null array.
                return new X509Certificate[0];
            }

            public void checkClientTrusted(X509Certificate[] certs,
                    String authType) {
                // accept everything
            }

            public void checkServerTrusted(X509Certificate[] certs,
                    String authType) {
                // accept everything
            }
        } };

        SSLContext context = SSLContext.getInstance("SSL");
        context.init(null, trustAllCerts, new java.security.SecureRandom());
        return context.getSocketFactory();
    }

    /**
     * Checks that the first RDN of the peer certificate's subject (RFC 2253
     * form) is <code>CN=&lt;host&gt;</code>.
     * <p>
     * NOTE(review): subjectAltName entries are not consulted.
     */
    private void verifyHostName(SSLSocket sslsocket) throws RTEException,
            IOException {
        Certificate[] certs = sslsocket.getSession().getPeerCertificates();
        if (certs.length == 0) {
            throw hostConnectError("No certificate in SSL session");
        }
        if (!(certs[0] instanceof X509Certificate)) {
            throw hostConnectError("SSL connection works currently only with X509 certificates");
        }
        X500Principal principal = ((X509Certificate) certs[0])
                .getSubjectX500Principal();
        String rfc2253name = principal.getName(X500Principal.RFC2253);
        if (!validate(rfc2253name, this.host)) {
            throw hostConnectError("Host name verification failed, found "
                    + rfc2253name + ", expected CN=" + this.host);
        }
    }

    /**
     * Returns true if the first RDN of <code>rfc2253name</code> equals
     * <code>CN=host</code>, ignoring case. A distinguished name consisting of a
     * single RDN (no comma) is handled as well.
     */
    private boolean validate(String rfc2253name, String host) {
        if (rfc2253name == null || host == null) {
            return false;
        }
        int comma = rfc2253name.indexOf(',');
        String commonNamePart = (comma < 0) ? rfc2253name : rfc2253name
                .substring(0, comma);
        // equalsIgnoreCase is locale independent, unlike toUpperCase().
        return commonNamePart.equalsIgnoreCase("CN=" + host);
    }

    /**
     * Creates the RTEException used for all "cannot connect to host" failures.
     */
    private RTEException hostConnectError(String reason) {
        return new RTEException(
                MessageTranslator
                        .translate(
                                MessageKey.ERROR_HOST_CONNECT,
                                this.host,
                                reason,
                                new Integer(
                                        RteC.CommunicationErrorCodeMap_C[RteC.SQLSTART_REQUIRED_C])),
                RteC.CommunicationErrorCodeMap_C[RteC.SQLSTART_REQUIRED_C]);
    }

    /** Closes the socket, ignoring errors (best effort during failure cleanup). */
    private static void closeQuietly(SSLSocket sslsocket) {
        try {
            sslsocket.close();
        } catch (IOException ignored) {
            // the original connect failure is the one that gets reported
        }
    }

    /*
     * (non-Javadoc)
     * 
     * @see com.sap.dbtech.rte.comm.BasicSocketComm#getNewCommunication()
     */
    protected BasicSocketComm getNewCommunication() throws RTEException {
        // Pass the original properties on so the SSL trust / host name
        // settings of this connection also apply to the new one.
        return new SecureCommunication(this.host + ":" + this.port,
                this.connectProperties);
    }

    /*
     * (non-Javadoc)
     * 
     * @see com.sap.dbtech.rte.comm.BasicSocketComm#getDefaultPort()
     */
    protected int getDefaultPort() {
        return RteC.defaultSecurePort_C;
    }

}
