/*
 *  Copyright 2001-2005 Internet2
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* signtest.cpp - test harness for digital signature operations

   Scott Cantor
   4/17/03

   $History:$
*/

#include "../saml/saml.h"

#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>

#include <xsec/framework/XSECException.hpp>
#include <xsec/enc/XSECCryptoException.hpp>
#include <xsec/enc/OpenSSL/OpenSSLCryptoX509.hpp>
#include <xsec/enc/OpenSSL/OpenSSLCryptoKeyRSA.hpp>

#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/rsa.h>
#include <openssl/evp.h>
#include <openssl/pem.h>

using namespace std;
using namespace saml;

int main(int argc,char* argv[])
{
    char* path=NULL;
    char* key=NULL;
    char* cert=NULL;
    string type;
    bool sign=false;
    string usage="usage: signtest [-k key_file] [-c cert_file] -d schema_dir -type Assertion|Request|Response sign|verify";

    for (int i=1; i<argc; i++)
    {
        if (!strcmp(argv[i],"-k") && i+1<argc)
            key=argv[++i];
        else if (!strcmp(argv[i],"-c") && i+1<argc)
            cert=argv[++i];
        else if (!strcmp(argv[i],"-d") && i+1<argc)
            path=argv[++i];
        else if (!strcmp(argv[i],"-type") && i+1<argc)
        {
            type=argv[++i];
            if (type!="Assertion" && type!="Request" && type!="Response")
            {
                cerr << usage << endl;
                exit(0);
            }
        }
    }
    
    if (argc>1)
    {
        if (!strcmp(argv[argc-1],"sign"))
        {
            sign=true;
            if (!key || !cert)
            {
                cerr << usage << endl;
                exit(0);
            }
        }
        else if (!strcmp(argv[argc-1],"verify"))
            sign=false;
        else
        {
            cerr << usage << endl;
            exit(0);
        }
    }
    else
    {
        cerr << usage << endl;
        exit(0);
    }

    SAMLConfig& conf1=SAMLConfig::getConfig();
    conf1.schema_dir=path;
    //conf1.compatibility_mode=true;
    if (!conf1.init())
        cerr << "unable to initialize SAML runtime" << endl;

    SAMLSignedObject* obj=NULL;
    try
    {
        if (type=="Assertion")
            obj=new SAMLAssertion(cin);
        else if (type=="Request")
            obj=new SAMLRequest(cin);
        else
            obj=new SAMLResponse(cin);
        if (sign)
        {
            BIO *bio=BIO_new(BIO_s_file());
            BIO_read_filename(bio,key);
            EVP_PKEY* pkey=PEM_read_bio_PrivateKey(bio,NULL,NULL,NULL);
            OpenSSLCryptoKeyRSA* xseckey=new OpenSSLCryptoKeyRSA(pkey);
            if (cert)
            {
                // Load the certificate, stripping the first and last lines.
                string certbuf,line;
                auto_ptr<OpenSSLCryptoX509> x509(new OpenSSLCryptoX509());
                ifstream infile(cert);
                while (!getline(infile,line).fail())
                    if (line.find("CERTIFICATE")==string::npos)
                        certbuf+=line + '\n';
                x509->loadX509Base64Bin(certbuf.data(),certbuf.length());

                XSECCryptoX509* certs[] = { x509.get() };
                if (type=="Response")
                {
                    Iterator<SAMLAssertion*> i=dynamic_cast<SAMLResponse*>(obj)->getAssertions();
                    if (i.hasNext())
                        i.next()->sign(xseckey->clone(),ArrayIterator<XSECCryptoX509*>(certs,1));
                }
                obj->sign(xseckey->clone(),ArrayIterator<XSECCryptoX509*>(certs,1));
                delete xseckey;
            }
            else
            {
                if (type=="Response")
                {
                    Iterator<SAMLAssertion*> i=dynamic_cast<SAMLResponse*>(obj)->getAssertions();
                    if (i.hasNext())
                        i.next()->sign(xseckey->clone());
                }
                obj->sign(xseckey);
            }
            EVP_PKEY_free(pkey);
            cout << *obj;
        }
        else
        {
            if (cert)
            {
                // Load the certificate, stripping the first and last lines.
                string certbuf,line;
                auto_ptr<OpenSSLCryptoX509> x509(new OpenSSLCryptoX509());
                ifstream infile(cert);
                while (!getline(infile,line).fail())
                    if (line.find("CERTIFICATE")==string::npos)
                        certbuf+=line + '\n';
                x509->loadX509Base64Bin(certbuf.data(),certbuf.length());
                if (type=="Response")
                {
                    Iterator<SAMLAssertion*> i=dynamic_cast<SAMLResponse*>(obj)->getAssertions();
                    if (i.hasNext())
                    {
                        SAMLAssertion* a = i.next();
                        if (a->isSigned())
                            a->verify(*x509);
                    }
                }
                obj->verify(*x509);
            }
            else
            {
                if (type=="Response")
                {
                    Iterator<SAMLAssertion*> i=dynamic_cast<SAMLResponse*>(obj)->getAssertions();
                    if (i.hasNext())
                    {
                        SAMLAssertion* a = i.next();
                        if (a->isSigned())
                            a->verify();
                    }
                }
                obj->verify();
            }
            cout << "Success!" << endl;
        }
    }
    catch(SAMLException& e)
    {
        cerr << "caught a SAML exception: " << e << endl;
    }
    catch(XSECException& e)
    {
        cerr << "caught an XMLSec exception: "; xmlout(cerr,e.getMsg()); cerr << endl;
    }
    catch(XSECCryptoException& e)
    {
        cerr << "caught an XMLSecCrypto exception: " << e.getMsg() << endl;
    }
    catch(XMLException& e)
    {
        cerr << "caught an XML exception: "; xmlout(cerr,e.getMessage()); cerr << endl;
    }
/*    catch(...)
    {
        cerr << "caught an unknown exception" << endl;
    }*/

    delete obj;
    conf1.term();
    return 0;
}

