File: rescoring3D.C

package info (click to toggle)
ball 1.5.0%2Bgit20180813.37fc53c-6
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 239,888 kB
  • sloc: cpp: 326,149; ansic: 4,208; python: 2,303; yacc: 1,778; lex: 1,099; xml: 958; sh: 322; makefile: 95
file content (175 lines) | stat: -rw-r--r-- 6,252 bytes parent folder | download | duplicates (6)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
// ----------------------------------------------------
// $Maintainer: Marcel Schumann $
// $Authors: Marcel Schumann $
// ----------------------------------------------------

#include <BALL/SCORING/FUNCTIONS/rescoring3D.h>
#include <BALL/KERNEL/molecularInteractions.h>


using namespace BALL;
using namespace std;

// Constructs a Rescoring3D object: all arguments are forwarded unchanged to the
// Rescoring base class, then the name and calibration flag are set and setup_()
// is run to take over the grid geometry of the scoring function.
Rescoring3D::Rescoring3D(AtomContainer& receptor, AtomContainer& reference_ligand, Options& options, String free_energy_label, ScoringFunction* sf)
    : Rescoring(receptor, reference_ligand, options, free_energy_label, sf)
{
    // calibration is switched on for this rescoring type
    use_calibration_ = true;
    name_ = "Rescoring3D";

    // must run after the base class is constructed, since it reads scoring_function_
    setup_();
}

void Rescoring3D::setup_()
{
    scoring_function_->enableStoreInteractions(1);

    const HashGrid3<Atom*>* sf_hashgrid = scoring_function_->getHashGrid();
    sizeX_ = sf_hashgrid->getSizeX();
    sizeY_ = sf_hashgrid->getSizeY();
    sizeZ_ = sf_hashgrid->getSizeZ();
    resolution_ = sf_hashgrid->getUnit()[0];
    origin_ = sf_hashgrid->getOrigin();
    reg3D_grid_names_.push_back("rescore3D grid");
}


// Adds the interaction energy of one atom to the first entry of 'contributions'.
//
// atom          - atom whose stored interaction energy is to be added; a null
//                 pointer only (re)sizes 'contributions' to exactly one entry.
// contributions - accumulator; receives one zero entry if it is still empty.
//
// Atoms without an 'interactions' object leave the accumulated value untouched.
void Rescoring3D::generateAtomScoreContributions_(const Atom* atom, vector<double>& contributions)
{
    const bool have_atom = (atom != 0);

    // no atom: size the vector to one entry; real atom: only make sure slot 0 exists
    if (!have_atom || contributions.empty())
    {
        contributions.resize(1, 0);
    }

    if (have_atom)
    {
        if (atom->interactions)
        {
            contributions[0] += atom->interactions->getInteractionEnergy();
        }
    }
}


// Computes the per-cell score contributions of one molecule.
//
// The atoms of 'mol' are sorted into a temporary hash grid built from origin_,
// sizeX_/sizeY_/sizeZ_ and resolution_. The grid is then traversed in x-, y-,
// z-order and for every cell the contributions of its atoms are summed up via
// generateAtomScoreContributions_().
//
// mol    - molecule whose atoms are binned; must not be null.
// matrix - if non-null, each cell value is appended to (*matrix)[cell index];
//          the matrix is enlarged when it has fewer rows than needed.
// v      - used only if 'matrix' is null: each cell value is appended to *v.
//
// If both 'matrix' and 'v' are null, the values are computed and discarded.
void Rescoring3D::generateScoreContributions_(Molecule* mol, vector<vector<double> >* matrix, vector<double>* v)
{
    HashGrid3<Atom*> hashgrid(origin_, sizeX_, sizeY_, sizeZ_, resolution_);

    // add all atoms of 'mol' to hashgrid
    const Vector3 origin = hashgrid.getOrigin();

    // every axis is checked against its own extent; the grid need not be cubic
    const Size size_x = hashgrid.getSizeX();
    const Size size_y = hashgrid.getSizeY();
    const Size size_z = hashgrid.getSizeZ();

    for (AtomIterator it = mol->beginAtom(); +it; it++)
    {
        // position of the current atom within the HashGrid, in units of cells
        Vector3 atom_pos = it->getPosition()-origin;

        atom_pos[0] /= hashgrid.getUnit()[0];
        atom_pos[1] /= hashgrid.getUnit()[1];
        atom_pos[2] /= hashgrid.getUnit()[2];

        // insert only atoms that are located within the grid boundaries
        const bool inside_x = atom_pos[0] >= 0 && atom_pos[0] < size_x;
        const bool inside_y = atom_pos[1] >= 0 && atom_pos[1] < size_y;
        const bool inside_z = atom_pos[2] >= 0 && atom_pos[2] < size_z;
        if (inside_x && inside_y && inside_z)
        {
            hashgrid.insert(it->getPosition(), &*it);
        }
    }

    // sum up the score of each box
    Size cell_no = 0;
    for (Size i = 0; i < size_x; i++)
    {
        for (Size j = 0; j < size_y; j++)
        {
            for (Size k = 0; k < size_z; k++)
            {
                vector<double> box_scores(0, 0);

                // a null atom initializes the vector with the appropriate size
                generateAtomScoreContributions_(0, box_scores);

                HashGridBox3<Atom*>* box = hashgrid.getBox(i, j, k);
                for (HashGridBox3 < Atom* > ::DataIterator di = box->beginData(); di != box->endData(); di++)
                {
                    generateAtomScoreContributions_(*di, box_scores);
                }

                // cell_no advances once per stored value, so it numbers the output features
                for (Size s = 0; s < box_scores.size(); s++, cell_no++)
                {
                    if (matrix)
                    {
                        // grow on demand so that indexing can never run past the end
                        if (matrix->size() <= cell_no)
                        {
                            matrix->resize(cell_no + 1);
                        }
                        (*matrix)[cell_no].push_back(box_scores[s]);
                    }
                    else if (v)
                    {
                        v->push_back(box_scores[s]);
                    }
                }
            }
        }
    }
}

list<pair<String, RegularData3D*> > Rescoring3D::generateRegularData3DGrids()
{
    if (!model_)
    {
        throw BALL::Exception::GeneralException(__FILE__, __LINE__, "Rescoring3D::generateRegularData3DGrids() Error", "No existing model!");
    }
    const Eigen::MatrixXd* coefficients = model_->getTrainingResult();
    const vector<string>* names = model_->getDescriptorNames();
    if (sizeX_*sizeY_*sizeZ_ < coefficients->rows())
    {
        throw BALL::Exception::GeneralException(__FILE__, __LINE__, "Rescoring3D::generateRegularData3DGrids() Error", "Number of coefficients larger than number of grid cells!");
    }

    list<pair<String, RegularData3D*> > reg3d_list;


    Vector3 resolution(resolution_, resolution_, resolution_);
    Vector3 dimension(resolution_*sizeX_, resolution_*sizeY_, resolution_*sizeZ_);

    vector<RegularData3D*> grids;
    Size no_reg3D_grids = reg3D_grid_names_.size();
    for (Size g = 0; g < no_reg3D_grids; g++)
    {
        RegularData3D* reg3d = new RegularData3D(origin_, dimension, resolution);
        grids.push_back(reg3d);
        reg3d_list.push_back(make_pair(reg3D_grid_names_[g], reg3d));
    }

    vector<string>::const_iterator name_it = names->begin();
    int coeff_index = 1;
    Size cell_no = 0;
    for (Size i = 0; i < sizeX_; i++)
    {
        for (Size j = 0; j < sizeY_; j++)
        {
            for (Size k = 0; k < sizeZ_; k++)
            {
                for (Size g = 0; g < no_reg3D_grids; g++, cell_no++)
                {
                    String name;
                    if (name_it == names->end()) // cell was removed by feature selection
                    {
                        RegularData3D::IndexType index(i, j, k);
                        grids[g]->getData(index) = 0;
                    }
                    else if ((name = *name_it).isDigit())
                    {
                        Size coeff_no = (Size)name.toFloat();
                        RegularData3D::IndexType index(i, j, k);

                        if (coeff_no == cell_no)
                        {
                            grids[g]->getData(index) = (*coefficients)(coeff_index, 1);
                            name_it++;
                            coeff_index++;
                        }
                        else // cell was removed by feature selection
                        {
                            grids[g]->getData(index) = 0;
                        }
                    }
                    else
                    {
                        throw BALL::Exception::GeneralException(__FILE__, __LINE__, "Rescoring3D::generateRegularData3DGrids() Error", "Non-numeric feature label found!");
                    }
                }
            }
        }
    }

    return reg3d_list;
}