import sys
import os
import numpy as np
import pickle

""" Extracts parameters for sidechain and backbone construction from
high-resolution X-ray structures. The high-resolution structures are defined
in the accompanying PISCES file (cullpdb*). If run outside the sciCORE
environment, you need to adapt the structure path that is passed to
io.LoadMMCIF in this file.

The parameters are printed to stdout and need to be added manually to

 - sidechain_atom_rule_lookup.cc
 - bb_trace_param.cc

in the loop module source. Have fun!
"""




# Exactly one command line argument is expected: the path of the PISCES file.
# Anything else prints the usage string and terminates with exit code 1.
if len(sys.argv) == 2:
    pisces_file = sys.argv[1]
else:
    print("Usage: ost parameter_extraction.py pisces_file")
    sys.exit(1)

# Containers that collect the raw measurements. Every leaf is an initially
# empty list that gets filled while the structures are processed.

# Residue names in the order in which the results are reported.
_residue_names = (
    "ALA", "ARG", "ASN", "ASP", "GLN", "GLU", "LYS", "SER", "CYS", "MET",
    "TRP", "TYR", "THR", "VAL", "ILE", "LEU", "GLY", "PRO", "HIS", "PHE",
)

# backbone parameter, tracked for every residue type
_backbone_bonds = (("N", "CA"), ("CA", "C"), ("C", "O"), ("CA", "CB"))
_backbone_angles = (("N", "CA", "C"), ("C", "CA", "CB"))

# sidechain parameter: bonds as atom name pairs, keyed by residue name.
# Residue types without an entry only get the backbone bonds.
_sidechain_bonds = {
    "ARG": (("CB", "CG"), ("CG", "CD"), ("CD", "NE"), ("NE", "CZ"),
            ("CZ", "NH1"), ("CZ", "NH2")),
    "ASN": (("CB", "CG"), ("CG", "OD1"), ("CG", "ND2")),
    "ASP": (("CB", "CG"), ("CG", "OD1"), ("CG", "OD2")),
    "GLN": (("CB", "CG"), ("CG", "CD"), ("CD", "OE1"), ("CD", "NE2")),
    "GLU": (("CB", "CG"), ("CG", "CD"), ("CD", "OE1"), ("CD", "OE2")),
    "LYS": (("CB", "CG"), ("CG", "CD"), ("CD", "CE"), ("CE", "NZ")),
    "SER": (("CB", "OG"),),
    "CYS": (("CB", "SG"),),
    "MET": (("CB", "CG"), ("CG", "SD"), ("SD", "CE")),
    "TRP": (("CB", "CG"), ("CG", "CD1"), ("CG", "CD2"), ("CD2", "CE2"),
            ("CD1", "NE1"), ("CD2", "CE3"), ("CE3", "CZ3"), ("CZ3", "CH2"),
            ("CH2", "CZ2")),
    "TYR": (("CB", "CG"), ("CG", "CD1"), ("CG", "CD2"), ("CD1", "CE1"),
            ("CD2", "CE2"), ("CE1", "CZ"), ("CZ", "OH")),
    "THR": (("CB", "OG1"), ("CB", "CG2")),
    "VAL": (("CB", "CG1"), ("CB", "CG2")),
    "ILE": (("CB", "CG1"), ("CB", "CG2"), ("CG1", "CD1")),
    "LEU": (("CB", "CG"), ("CG", "CD1"), ("CG", "CD2")),
    "PRO": (("CB", "CG"), ("CG", "CD")),
    "HIS": (("CB", "CG"), ("CG", "ND1"), ("CG", "CD2"), ("ND1", "CE1"),
            ("CD2", "NE2")),
    "PHE": (("CB", "CG"), ("CG", "CD1"), ("CG", "CD2"), ("CD1", "CE1"),
            ("CD2", "CE2"), ("CE1", "CZ")),
}

# sidechain parameter: angles as atom name triples, keyed by residue name.
# Residue types without an entry only get the backbone angles.
_sidechain_angles = {
    "ARG": (("CA", "CB", "CG"), ("CB", "CG", "CD"), ("CG", "CD", "NE"),
            ("CD", "NE", "CZ"), ("NE", "CZ", "NH1"), ("NE", "CZ", "NH2")),
    "ASN": (("CA", "CB", "CG"), ("CB", "CG", "OD1"), ("CB", "CG", "ND2")),
    "ASP": (("CA", "CB", "CG"), ("CB", "CG", "OD1"), ("CB", "CG", "OD2")),
    "GLN": (("CA", "CB", "CG"), ("CB", "CG", "CD"), ("CG", "CD", "OE1"),
            ("CG", "CD", "NE2")),
    "GLU": (("CA", "CB", "CG"), ("CB", "CG", "CD"), ("CG", "CD", "OE1"),
            ("CG", "CD", "OE2")),
    "LYS": (("CA", "CB", "CG"), ("CB", "CG", "CD"), ("CG", "CD", "CE"),
            ("CD", "CE", "NZ")),
    "SER": (("CA", "CB", "OG"),),
    "CYS": (("CA", "CB", "SG"),),
    "MET": (("CA", "CB", "CG"), ("CB", "CG", "SD"), ("CG", "SD", "CE")),
    "TRP": (("CA", "CB", "CG"), ("CB", "CG", "CD1"), ("CB", "CG", "CD2"),
            ("CG", "CD2", "CE2"), ("CG", "CD1", "NE1"), ("CG", "CD2", "CE3"),
            ("CD2", "CE3", "CZ3"), ("CE3", "CZ3", "CH2"),
            ("CZ3", "CH2", "CZ2")),
    "TYR": (("CA", "CB", "CG"), ("CB", "CG", "CD1"), ("CB", "CG", "CD2"),
            ("CG", "CD1", "CE1"), ("CG", "CD2", "CE2"), ("CD1", "CE1", "CZ"),
            ("CE1", "CZ", "OH")),
    "THR": (("CA", "CB", "OG1"), ("CA", "CB", "CG2")),
    "VAL": (("CA", "CB", "CG1"), ("CA", "CB", "CG2")),
    "ILE": (("CA", "CB", "CG1"), ("CA", "CB", "CG2"), ("CB", "CG1", "CD1")),
    "LEU": (("CA", "CB", "CG"), ("CB", "CG", "CD1"), ("CB", "CG", "CD2")),
    "PRO": (("CA", "CB", "CG"), ("CB", "CG", "CD")),
    "HIS": (("CA", "CB", "CG"), ("CB", "CG", "ND1"), ("CB", "CG", "CD2"),
            ("CG", "ND1", "CE1"), ("CG", "CD2", "NE2")),
    "PHE": (("CA", "CB", "CG"), ("CB", "CG", "CD1"), ("CB", "CG", "CD2"),
            ("CG", "CD1", "CE1"), ("CG", "CD2", "CE2"), ("CD1", "CE1", "CZ")),
}

# aa specific: residue name -> atom name tuple -> list of observed values.
# Backbone entries come first, followed by the sidechain entries.
bond_param = {
    name: {atoms: [] for atoms in _backbone_bonds + _sidechain_bonds.get(name, ())}
    for name in _residue_names
}
angle_param = {
    name: {atoms: [] for atoms in _backbone_angles + _sidechain_angles.get(name, ())}
    for name in _residue_names
}
dihedral_param = {}

# peptide bond
pep_bond = []
ca_c_n_angle = []
c_n_ca_angle = []

# special dihedrals
leu_cd1_cb_cg_cd2_dihedrals = []
ile_cg1_ca_cb_cg2_dihedrals = []
val_cg1_ca_cb_cg2_dihedrals = []
thr_og1_ca_cb_cg2_dihedrals = []

# C_O bond length
c_o_bond = []

# N_C_CA_CB_DiAngle: residue name -> list of observed values
n_c_ca_cb_diangle = {name: [] for name in _residue_names}

# Read the PISCES list. Its first line is skipped below (presumably the
# column header - confirm against the PISCES file format).
with open(pisces_file) as fh:
    pisces_data = fh.readlines()

# Root directory of the divided mmCIF mirror. Adapt this when running outside
# of the sciCORE environment.
mmcif_dir = "/scicore/data/managed/PDB/latest/data/structures/divided/mmCIF"

for line_idx, line in enumerate(pisces_data[1:]):
    fields = line.split()
    if not fields:
        # tolerate blank lines (e.g. at the end of the file) instead of
        # failing with an IndexError
        continue

    # the first column concatenates the four letter PDB id and the chain name
    tmp = fields[0]
    pdb_id = tmp[:4]
    chain_id = tmp[4:]
    # the mirror is divided by the two middle characters of the PDB id
    path_prefix = pdb_id.lower()[1:3]

    print(line_idx, pdb_id)

    structure_path = os.path.join(mmcif_dir, path_prefix, pdb_id.lower() + ".cif.gz")
    if not os.path.exists(structure_path):
        print("could not find",structure_path,"skip...")
        continue

    try:
        ent = io.LoadMMCIF(structure_path)
    except Exception:
        # Without skipping here, the code below would operate on the entity
        # of the previous iteration (or on an undefined name in the very
        # first iteration).
        print("failed to load", tmp, "skip...")
        continue

    # pick the first L-peptide chain whose author chain name matches
    chain = None
    for ch in ent.chains:
        if ch.GetType() == mol.CHAINTYPE_POLY_PEPTIDE_L and ch.GetStringProp("pdb_auth_chain_name") == chain_id:
            chain = ch
            break

    if chain is None:
        print(pdb_id, "has no chain", chain_id, "of type POLY_PEPTIDE_L skip...")
        continue

    for r in chain.residues:

        rname = r.GetName()

        # do peptide bond
        r_next = r.GetNext()
        if r.GetChemType() == 'A' and r_next.IsValid() and r_next.GetChemType() == 'A':
            ca = r.FindAtom("CA")
            c = r.FindAtom("C")
            n_next = r_next.FindAtom("N")
            ca_next = r_next.FindAtom("CA")
            if ca.IsValid() and c.IsValid() and n_next.IsValid() and ca_next.IsValid():
                # crude check whether pep bond makes sense
                pep_bond_length = geom.Distance(c.GetPos(), n_next.GetPos())
                if 1.3 < pep_bond_length < 1.36:
                    pep_bond.append(pep_bond_length)
                    ca_c_n_angle.append(geom.Angle(ca.GetPos() - c.GetPos(), n_next.GetPos() - c.GetPos()))
                    c_n_ca_angle.append(geom.Angle(c.GetPos() - n_next.GetPos(), ca_next.GetPos() - n_next.GetPos()))

        # do n_c_ca_cb_diangle
        if rname in n_c_ca_cb_diangle:
            n = r.FindAtom("N")
            c = r.FindAtom("C")
            ca = r.FindAtom("CA")
            cb = r.FindAtom("CB")
            if n.IsValid() and c.IsValid() and ca.IsValid() and cb.IsValid():
                n_c_ca_cb_diangle[rname].append(geom.DihedralAngle(n.GetPos(), c.GetPos(), ca.GetPos(), cb.GetPos()))

        # do rest: all bonds / angles registered for this residue type, a
        # measurement is only taken if all involved atoms are present
        if rname in bond_param:
            for bond, bond_values in bond_param[rname].items():
                a1 = r.FindAtom(bond[0])
                a2 = r.FindAtom(bond[1])
                if a1.IsValid() and a2.IsValid():
                    bond_values.append(geom.Distance(a1.GetPos(), a2.GetPos()))

        if rname in angle_param:
            for angle, angle_values in angle_param[rname].items():
                a1 = r.FindAtom(angle[0])
                a2 = r.FindAtom(angle[1])
                a3 = r.FindAtom(angle[2])
                if a1.IsValid() and a2.IsValid() and a3.IsValid():
                    angle_values.append(geom.Angle(a1.GetPos() - a2.GetPos(), a3.GetPos() - a2.GetPos()))

        # special dihedrals
        if rname == "LEU":
            cd1 = r.FindAtom("CD1")
            cb = r.FindAtom("CB")
            cg = r.FindAtom("CG")
            cd2 = r.FindAtom("CD2")
            # all four atoms that enter the dihedral must be present
            if cd1.IsValid() and cb.IsValid() and cg.IsValid() and cd2.IsValid():
                leu_cd1_cb_cg_cd2_dihedrals.append(geom.DihedralAngle(cd1.GetPos(), cb.GetPos(), cg.GetPos(), cd2.GetPos()))
        if rname == "ILE":
            cg1 = r.FindAtom("CG1")
            ca = r.FindAtom("CA")
            cb = r.FindAtom("CB")
            cg2 = r.FindAtom("CG2")
            if cg1.IsValid() and ca.IsValid() and cb.IsValid() and cg2.IsValid():
                ile_cg1_ca_cb_cg2_dihedrals.append(geom.DihedralAngle(cg1.GetPos(), ca.GetPos(), cb.GetPos(), cg2.GetPos()))
        if rname == "VAL":
            cg1 = r.FindAtom("CG1")
            ca = r.FindAtom("CA")
            cb = r.FindAtom("CB")
            cg2 = r.FindAtom("CG2")
            if cg1.IsValid() and ca.IsValid() and cb.IsValid() and cg2.IsValid():
                val_cg1_ca_cb_cg2_dihedrals.append(geom.DihedralAngle(cg1.GetPos(), ca.GetPos(), cb.GetPos(), cg2.GetPos()))
        if rname == "THR":
            og1 = r.FindAtom("OG1")
            ca = r.FindAtom("CA")
            cb = r.FindAtom("CB")
            cg2 = r.FindAtom("CG2")
            if og1.IsValid() and ca.IsValid() and cb.IsValid() and cg2.IsValid():
                thr_og1_ca_cb_cg2_dihedrals.append(geom.DihedralAngle(og1.GetPos(), ca.GetPos(), cb.GetPos(), cg2.GetPos()))

        # C_O bond, collected for every residue regardless of its name
        c = r.FindAtom("C")
        o = r.FindAtom("O")
        if c.IsValid() and o.IsValid():
            c_o_bond.append(geom.Distance(c.GetPos(), o.GetPos()))


# Report mean and standard deviation of every collected distribution on
# stdout.
print("Pep bond parameters:")
for label, values in (("Pep bond length:", pep_bond),
                      ("CA_C_N angle:", ca_c_n_angle),
                      ("C_N_CA angle:", c_n_ca_angle)):
    print(label, np.mean(values), np.std(values))
print()

print("N C CA CB diangle")
for rname, values in n_c_ca_cb_diangle.items():
    print(rname, np.mean(values), np.std(values))
print()

# per residue type: first all bonds, then all angles
rnames = bond_param.keys()
for rname in rnames:
    print(rname)
    print("bond param")
    for b, values in bond_param[rname].items():
        print(b, np.mean(values), np.std(values))
    print("angle param")
    for a, values in angle_param[rname].items():
        print(a, np.mean(values), np.std(values))
    print()

for label, values in (("LEU CD1_CB_CG_CD2 dihedral:", leu_cd1_cb_cg_cd2_dihedrals),
                      ("ILE CG1_CA_CB_CG2 dihedral:", ile_cg1_ca_cb_cg2_dihedrals),
                      ("VAL CG1_CA_CB_CG2 dihedral:", val_cg1_ca_cb_cg2_dihedrals),
                      ("THR OG1_CA_CB_CG2 dihedral:", thr_og1_ca_cb_cg2_dihedrals),
                      ("C O bond:", c_o_bond)):
    print(label, np.mean(values), np.std(values))
