from ost import io, seq
from promod3 import modelling, loop

# Example: build a raw model with one gap, generate loop candidates for that
# gap from the databases, close them with CCD, score and rank them, then
# insert the best-ranked candidate and write the model to "model.pdb".

# setup raw model
tpl = io.LoadPDB('data/1crn_cut.pdb')
seq_trg = 'TTCCPSIVARSNFNVCRLPGTPEAICATYTGCIIIPGATCPGDYAN'
seq_tpl = 'TTCCPSIVARSNFNVCRLPGTPEA------GCIIIPGATCPGDYAN'
aln = seq.CreateAlignment(seq.CreateSequence('trg', seq_trg),
                          seq.CreateSequence('tpl', seq_tpl))
aln.AttachView(1, tpl.CreateFullView())
mhandle = modelling.BuildRawModel(aln)
print("Number of gaps in raw model: %d" % len(mhandle.gaps))

# setup default scorers for modelling handle
modelling.SetupDefaultBackboneScoring(mhandle)
modelling.SetupDefaultAllAtomScoring(mhandle)

# setup databases
frag_db = loop.LoadFragDB()
structure_db = loop.LoadStructureDB()
torsion_sampler = loop.LoadTorsionSamplerCoil()

# get data for gap to close
# (fail with a clear message instead of a bare IndexError below)
if not mhandle.gaps:
    raise RuntimeError("Raw model has no gaps to close")
gap = mhandle.gaps[0].Copy()
print("Gap to close: %s" % str(gap))
n_stem = gap.before
c_stem = gap.after
start_resnum = n_stem.GetNumber().GetNum()
# 0-based index into the target sequence profile (residue numbers start at 1)
start_idx = start_resnum - 1

# get loop candidates from the fragment and structure databases
candidates = modelling.LoopCandidates.FillFromDatabase(
    n_stem, c_stem, gap.full_seq, frag_db, structure_db)
print("Number of loop candidates: %d" % len(candidates))

# all scores will be kept in a score container which we update
all_scores = modelling.ScoreContainer()
# the keys used to identify scores are globally defined
print("Stem RMSD key = '%s'"
      % modelling.ScoringWeights.GetStemRMSDsKey())
print("Profile keys = ['%s', '%s']"
      % (modelling.ScoringWeights.GetSequenceProfileScoresKey(),
         modelling.ScoringWeights.GetStructureProfileScoresKey()))
print("Backbone scoring keys = %s"
      % str(modelling.ScoringWeights.GetBackboneScoringKeys()))
print("All atom scoring keys = %s"
      % str(modelling.ScoringWeights.GetAllAtomScoringKeys()))

# get stem RMSDs for each candidate (i.e. how well does it fit?)
# -> this must be done before CCD to be meaningful
candidates.CalculateStemRMSDs(all_scores, n_stem, c_stem)

# close the candidates with CCD; only closed candidates are kept and
# orig_indices maps them back to their position before closing
orig_indices = candidates.ApplyCCD(n_stem, c_stem, torsion_sampler)
print("Number of closed loop candidates: %d" % len(candidates))
# without any closed candidate there is nothing to rank or insert
if len(candidates) == 0:
    raise RuntimeError("No loop candidate could be closed with CCD")

# get subset of previously computed scores (rows of surviving candidates)
all_scores = all_scores.Extract(orig_indices)

# add profile scores (needs profile for target sequence)
prof = io.LoadSequenceProfile("data/1CRNA.hhm")
candidates.CalculateSequenceProfileScores(all_scores, structure_db,
                                          prof, start_idx)
candidates.CalculateStructureProfileScores(all_scores, structure_db,
                                           prof, start_idx)
# add backbone scores
candidates.CalculateBackboneScores(all_scores, mhandle, start_resnum)
# add all atom scores
candidates.CalculateAllAtomScores(all_scores, mhandle, start_resnum)

# use default weights to combine scores
weights = modelling.ScoringWeights.GetWeights(with_db=True,
                                              with_aa=True)
scores = all_scores.LinearCombine(weights)

# rank them (best = lowest "score"); ties are broken by candidate index
arg_sorted_scores = sorted((v, i) for i, v in enumerate(scores))
print("Ranked candidates: score, index")
for v, i in arg_sorted_scores:
    print("%g, %d" % (v, i))

# insert best into model, update scorers and clear gaps
best_candidate = candidates[arg_sorted_scores[0][1]]
modelling.InsertLoopClearGaps(mhandle, best_candidate, gap)
print("Number of gaps in closed model: %d" % len(mhandle.gaps))
io.SavePDB(mhandle.model, "model.pdb")
