package com.sap.dbtech.util.security;

import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

import com.sap.dbtech.util.Tracer;

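/**
 * Client-side computation of the SCRAM-MD5 authentication proof
 * (salted password, HMAC-MD5, client-key XOR shared-key).
 *
 * @author D031096
 */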
public class SCRAMMD5 {
    /** Name under which this authentication algorithm is identified. */
    public final static String algorithmname = "SCRAMMD5";

    /** MD5 block size in bytes; the HMAC key is padded (or hashed) to this length. */
    private static final int MD5_BLOCK_SIZE = 64;

    /**
     * Computes HMAC-MD5 of {@code data} under {@code key} as specified in RFC 2104:
     * {@code H((K ^ opad) + H((K ^ ipad) + data))}.
     *
     * @param data the message to authenticate
     * @param key  the secret key; keys longer than 64 bytes are first replaced by
     *             their MD5 hash, shorter keys are implicitly zero-padded
     * @return the 16-byte MAC
     * @throws NoSuchAlgorithmException if no MD5 provider is available
     */
    private static byte[] hmacMD5(byte[] data, byte[] key) throws NoSuchAlgorithmException {
        MessageDigest md5 = MessageDigest.getInstance("MD5");

        // RFC 2104: a key longer than the block size is hashed first. Without this,
        // such a key overran the 64-byte pad arrays (ArrayIndexOutOfBoundsException).
        byte[] effectiveKey = key.length > MD5_BLOCK_SIZE ? md5.digest(key) : key;

        byte[] ipad = new byte[MD5_BLOCK_SIZE];
        byte[] opad = new byte[MD5_BLOCK_SIZE];
        for (int i = 0; i < MD5_BLOCK_SIZE; i++) {
            // Positions past the end of the key act as zero bytes (0 ^ pad == pad).
            int k = i < effectiveKey.length ? effectiveKey[i] : 0;
            ipad[i] = (byte) (k ^ 0x36);
            opad[i] = (byte) (k ^ 0x5c);
        }

        // inner = H((K ^ ipad) + data); digest() also resets md5 for reuse.
        md5.update(ipad);
        md5.update(data);
        byte[] inner = md5.digest();

        // result = H((K ^ opad) + inner)
        md5.update(opad);
        md5.update(inner);
        return md5.digest();
    }

    /**
     * Prints {@code text}, a colon and the hex rendering of {@code bytearr}
     * (as produced by {@code Tracer.Hex2String}) to standard output.
     *
     * @param text    label printed before the bytes
     * @param bytearr bytes to dump
     */
    private static void RTESec_DumpHex(String text, byte[] bytearr) {
        System.out.println(text + ": " + Tracer.Hex2String(bytearr));
    }

    /*
     * SCRAM in functional notation:
     *
     *   +          octet concatenation
     *   XOR        the exclusive-or function
     *   AU         the authentication user identity (NUL terminated)
     *   AZ         the authorization user identity (NUL terminated);
     *              if AZ is the same as AU, a single NUL is used instead
     *   csecinfo   client security layer option bits and buffer size
     *   ssecinfo   server security layer option bits and buffer size
     *   service    the name of the service and server (NUL terminated)
     *   pass       the plain-text passphrase
     *   H(x)       a one-way hash function applied to "x", such as MD5
     *   MAC(x, y)  a message authentication code (MAC) such as HMAC-MD5;
     *              "y" is the key and "x" is the text signed by the key
     *   salt       a per-user salt value the server stores
     *   Us         a unique nonce the server sends to the client
     *   Uc         a unique nonce the client sends to the server
     *
     * The SCRAM computations and exchange are as follows:
     *
     *   client-msg-1    = AZ + AU + Uc
     *   (1) client -> server: client-msg-1
     *   server-msg-1    = salt + ssecinfo + service + Us
     *   (2) server -> client: server-msg-1
     *   salted-pass     = MAC(salt, pass)
     *   client-key      = H(salted-pass)
     *   client-verifier = H(client-key)
     *   shared-key      = MAC(server-msg-1 + client-msg-1 + csecinfo, client-verifier)
     *   client-proof    = client-key XOR shared-key
     *   (3) client -> server: csecinfo + client-proof
     *   server-key      = MAC(salt, salted-pass)
     *   server-proof    = MAC(client-msg-1 + server-msg-1 + csecinfo, server-key)
     *   (4) server -> client: server-proof
     *
     * scrammMD5 below computes only the client-proof. Its shared-key MAC input is
     * salt + serverkey + clientkey; no csecinfo is appended.
     */

    /**
     * Computes the SCRAM-MD5 client proof.
     *
     * <pre>
     *   salted-pass     = HMAC-MD5(key = password, data = salt)
     *   client-key      = MD5(salted-pass)
     *   client-verifier = MD5(client-key)
     *   shared-key      = HMAC-MD5(key = client-verifier, data = salt + serverkey + clientkey)
     *   client-proof    = client-key XOR shared-key
     * </pre>
     *
     * @param salt      the per-user salt
     * @param password  the password bytes, used as the HMAC key for salting
     * @param clientkey client-side challenge data, appended last to the shared-key
     *                  input (presumably user identity + client nonce, per the
     *                  self-test data; TODO confirm against the server side)
     * @param serverkey server-side challenge data, appended after the salt to the
     *                  shared-key input (presumably the server nonce; TODO confirm)
     * @return the 16-byte client proof
     * @throws NoSuchAlgorithmException if no MD5 provider is available
     */
    public static byte[] scrammMD5(byte[] salt, byte[] password,
            byte[] clientkey, byte[] serverkey) throws NoSuchAlgorithmException {

        MessageDigest md5 = MessageDigest.getInstance("MD5");

        byte[] salted_pass = hmacMD5(salt, password);
        byte[] client_key = md5.digest(salted_pass);
        byte[] client_verifier = md5.digest(client_key);

        // shared-key input = salt + serverkey + clientkey
        int saltLen = salt.length;
        int serverkeyLen = serverkey.length;
        int clientkeyLen = clientkey.length;
        byte[] content = new byte[saltLen + serverkeyLen + clientkeyLen];
        System.arraycopy(salt, 0, content, 0, saltLen);
        System.arraycopy(serverkey, 0, content, saltLen, serverkeyLen);
        System.arraycopy(clientkey, 0, content, saltLen + serverkeyLen, clientkeyLen);

        byte[] shared_key = hmacMD5(content, client_verifier);

        // Both operands are MD5 outputs (16 bytes), so the indices line up.
        byte[] client_proof = new byte[shared_key.length];
        for (int i = 0; i < client_proof.length; i++) {
            client_proof[i] = (byte) (shared_key[i] ^ client_key[i]);
        }
        return client_proof;
    }

    /**
     * Self-test: computes the client proof for fixed inputs and compares it
     * case-insensitively with a known hex value. The result is printed to
     * standard output; exceptions are printed, not rethrown.
     *
     * @param args ignored
     */
    public static void maintest(String[] args) {
        String client_proof = "5d7f4505accba4e92c5778ea808dbc6a";
        byte[] salt = "der Salt".getBytes();
        byte[] password = "secret".getBytes();
        // "a user id and a random number"
        byte[] clientkey = "eine UserId und eine Zufallszahl".getBytes();
        // "-value and another random number"
        byte[] serverkey = "-Value und eine andere Zufallszahl".getBytes();
        try {
            byte[] erg = scrammMD5(salt, password, clientkey, serverkey);
            RTESec_DumpHex("scrammMD5: client_proof", erg);
            String ergAsString = Tracer.Hex2String(erg);
            if (!client_proof.equalsIgnoreCase(ergAsString)) {
                System.out.println("Error wrong client proof: \nfound: "
                        + ergAsString + "\nexpected " + client_proof);
            } else {
                System.out.println("Correct client proof computed");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
