File: Tree.cpp

package info (click to toggle)
stopt 5.12%2Bdfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 8,860 kB
  • sloc: cpp: 70,456; python: 5,950; makefile: 72; sh: 57
file content (69 lines) | stat: -rw-r--r-- 1,935 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
// Copyright (C) 2019 EDF
// All Rights Reserved
// This code is published under the GNU Lesser General Public License (GNU LGPL)
#include <iostream>
#include "StOpt/tree/Tree.h"

using namespace std;
using namespace Eigen;

namespace StOpt
{

/// Default constructor: builds an empty tree with no probabilities and no connections.
/// The next-date node count is explicitly zeroed. It is otherwise only assigned by the
/// filling constructor or by update(), and leaving an int member uninitialized would make
/// any read of it on a default-constructed tree an indeterminate-value read.
Tree::Tree(): m_nbNodeNextDate(0) {}


/// Build a tree level from a probability table and a connection table.
/// \param p_proba      probabilities; entries are looked up through the second
///                     component of each arc (see expCond)
/// \param p_connected  for each node i, a list of arcs {destination node, probability index}
/// The number of nodes at the next date is one past the largest destination index found
/// (so it is 1 when p_connected contains no arcs at all).
Tree::Tree(const vector< double > &p_proba, const vector< std::vector< std::array<int, 2> > > &p_connected): m_proba(p_proba), m_connected(p_connected), m_nbNodeNextDate(0)
{
    int maxDestination = 0;
    for (const auto &arcs : m_connected)
        for (const auto &arc : arcs)
            maxDestination = std::max(maxDestination, static_cast<int>(arc[0]));
    m_nbNodeNextDate = maxDestination + 1;
}

/// Replace the probability and connection tables and recompute the next-date node count.
/// \param p_proba      new probability table (indexed by the second component of each arc)
/// \param p_connected  new connection table: for each node, arcs {destination node, probability index}
/// The node count is one past the largest destination index (1 when there are no arcs).
void Tree::update(const vector< double > &p_proba,
                  const vector< std::vector< std::array<int, 2> > >  &p_connected)
{
    m_proba = p_proba;
    m_connected = p_connected;

    // Largest destination-node index referenced by any arc, starting from 0.
    int highestDestination = 0;
    for (const auto &arcs : m_connected)
    {
        for (const auto &arc : arcs)
        {
            if (arc[0] > highestDestination)
                highestDestination = arc[0];
        }
    }
    m_nbNodeNextDate = highestDestination + 1;
}


/// Probability-weighted sum of next-date values for every current node
/// (presumably the conditional expectation on the tree).
/// \param p_values  one value per next-date node, indexed by arc destination
/// \return          array of size m_connected.size(); entry i is the sum over the arcs of
///                  node i of m_proba[arc[1]] * p_values(arc[0])
/// NOTE(review): p_values is assumed to hold at least m_nbNodeNextDate entries; not checked here.
ArrayXd  Tree::expCond(const ArrayXd &p_values) const
{
    const size_t nbCurrentNodes = m_connected.size();
    ArrayXd ret = ArrayXd::Zero(nbCurrentNodes);
    for (size_t iNode = 0; iNode < nbCurrentNodes; ++iNode)
    {
        const auto &arcs = m_connected[iNode];
        for (const auto &arc : arcs)
            ret(iNode) += m_proba[arc[1]] * p_values(arc[0]);
    }
    return ret;
}

/// Column-wise version of expCond: each column of p_values is a next-date node and
/// each row is an independent set of values.
/// \param p_values  matrix with one column per next-date node, indexed by arc destination
/// \return          matrix of size p_values.rows() x m_connected.size(); column i is the sum
///                  over the arcs of node i of m_proba[arc[1]] * p_values.col(arc[0])
/// NOTE(review): p_values is assumed to have at least m_nbNodeNextDate columns; not checked here.
ArrayXXd  Tree::expCondMultiple(const ArrayXXd &p_values) const
{
    ArrayXXd ret = ArrayXXd::Zero(p_values.rows(), m_connected.size());
    size_t iCol = 0;
    for (const auto &arcs : m_connected)
    {
        for (const auto &arc : arcs)
            ret.col(iCol) += m_proba[arc[1]] * p_values.col(arc[0]);
        ++iCol;
    }
    return ret;
}
}