File: OutputTree.cpp

package info (click to toggle)
pentobi 29.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 3,892 kB
  • sloc: cpp: 25,719; javascript: 875; xml: 40; makefile: 13; sh: 6
file content (237 lines) | stat: -rw-r--r-- 7,704 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
//-----------------------------------------------------------------------------
/** @file twogtp/OutputTree.cpp
    @author Markus Enzenberger
    @copyright GNU General Public License version 3 or later */
//-----------------------------------------------------------------------------

#include "OutputTree.h"

#include <fstream>
#include "libboardgame_base/TreeReader.h"
#include "libboardgame_base/TreeWriter.h"
#include "libpentobi_base/BoardUtil.h"

using libboardgame_base::ArrayList;
using libboardgame_base::SgfNode;
using libboardgame_base::SgfTree;
using libboardgame_base::TreeReader;
using libboardgame_base::TreeWriter;
using libpentobi_base::get_transforms;
using libpentobi_base::ColorMove;
using libpentobi_base::MovePoints;
using libpentobi_base::get_transformed;

//-----------------------------------------------------------------------------

namespace {

/** Record the result of one game in the statistics comment of a node.
    The comment holds two whitespace-separated triples
    "count real_count avg_result": the first one is used when
    is_player_black is true (index 0), the second one otherwise (index 1).
    A node without a comment starts from all-zero statistics.
    @param tree The tree owning the node; used to store the new comment.
    @param node The node whose statistics are updated.
    @param is_player_black Selects which of the two triples is updated.
    @param is_real_move If true, real_count of the selected triple is
    incremented.
    @param result Game result folded into the running average avg_result.
    @throws runtime_error if an existing comment cannot be parsed. */
void add(PentobiTree& tree, const SgfNode& node, bool is_player_black,
         bool is_real_move, float result)
{
    unsigned index = is_player_black ? 0 : 1;
    // Same value types as get_real_count() uses for parsing this comment:
    // unsigned short would overflow (and fail to re-parse) after 65535
    // games, and a float average would not match the double precision used
    // when writing the comment below.
    array<unsigned, 2> count;
    array<double, 2> avg_result;
    array<unsigned, 2> real_count;
    auto comment = SgfTree::get_comment(node);
    if (comment.empty())
    {
        count.fill(0);
        avg_result.fill(0);
        real_count.fill(0);
        count[index] = 1;
        // Consistent with the update branch: only real moves are counted.
        real_count[index] = is_real_move ? 1 : 0;
        avg_result[index] = result;
    }
    else
    {
        istringstream in(comment);
        in >> count[0] >> real_count[0] >> avg_result[0]
           >> count[1] >> real_count[1] >> avg_result[1];
        if (! in)
            throw runtime_error("OutputTree: invalid comment: " + comment);
        ++count[index];
        // Incremental mean: avg_n = avg_{n-1} + (x - avg_{n-1}) / n
        avg_result[index] += (result - avg_result[index]) / count[index];
        if (is_real_move)
            ++real_count[index];
    }
    ostringstream out;
    out.precision(numeric_limits<double>::digits10);
    out << count[0] << ' ' << real_count[0] << ' ' << avg_result[0] << '\n'
        << count[1] << ' ' << real_count[1] << ' ' << avg_result[1];
    tree.set_comment(node, out.str());
}

bool compare_sequence(ArrayList<ColorMove, Board::max_moves>& s1,
                      ArrayList<ColorMove, Board::max_moves>& s2)
{
    LIBBOARDGAME_ASSERT(s1.size() == s2.size());
    for (unsigned i = 0; i < s1.size(); ++i)
    {
        LIBBOARDGAME_ASSERT(s1[i].color == s2[i].color);
        if (s1[i].move.to_int() < s2[i].move.to_int())
            return true;
        if (s1[i].move.to_int() > s2[i].move.to_int())
            return false;
    }
    return false;
}

/** Return the real-move count stored in the statistics comment of a node
    for one of the two players (first triple if is_player_black, else the
    second triple).
    All six values of the comment are parsed so that a malformed comment
    is always detected.
    @throws runtime_error if the comment cannot be parsed. */
unsigned get_real_count(const SgfNode& node, bool is_player_black)
{
    auto comment = SgfTree::get_comment(node);
    istringstream in(comment);
    array<unsigned, 2> count;
    array<unsigned, 2> real_count;
    array<double, 2> avg_result;
    // Each triple is "count real_count avg_result"; once extraction fails,
    // further extractions are no-ops, so the check after the loop suffices.
    for (unsigned k = 0; k < 2; ++k)
        in >> count[k] >> real_count[k] >> avg_result[k];
    if (! in)
        throw runtime_error("OutputTree: invalid comment: " + comment);
    return real_count[is_player_black ? 0 : 1];
}

} // namespace

//-----------------------------------------------------------------------------

/** Initialize the tree for the given game variant and cache the symmetry
    transformations of the variant together with their inverses; these are
    used by add_game() and generate_move(). */
OutputTree::OutputTree(Variant variant)
    : m_tree(variant)
{
    get_transforms(variant, m_transforms, m_inv_transforms);
}

// Defaulted here rather than inline in the header, which would trigger a
// GCC -Winline warning.
OutputTree::~OutputTree() = default;

void OutputTree::add_game(const Board& bd, unsigned player_black,
                          float result, const array<bool,
                          Board::max_moves>& is_real_move)
{
    if (bd.has_setup())
        throw runtime_error("OutputTree: setup not supported");

    // Find the canonical representation
    ArrayList<ColorMove, Board::max_moves> sequence;
    for (auto& transform : m_transforms)
    {
        ArrayList<ColorMove, Board::max_moves> s;
        for (unsigned i = 0; i < bd.get_nu_moves(); ++i)
        {
            auto mv = bd.get_move(i);
            s.push_back(ColorMove(mv.color,
                                  get_transformed(bd, mv.move, *transform)));
        }
        if (sequence.empty() || compare_sequence(s, sequence))
            sequence = s;
    }

    auto node = &m_tree.get_root();
    add(m_tree, *node, player_black == 0, true, result);
    unsigned nu_moves_3 = 0;
    for (unsigned i = 0; i < sequence.size(); ++i)
    {
        unsigned player;
        auto mv = sequence[i];
        Color c = mv.color;
        if (bd.get_variant() == Variant::classic_3 && c == Color(3))
        {
            player = nu_moves_3 % 3;
            ++nu_moves_3;
        }
        else
            player = c.to_int() % bd.get_nu_players();
        auto child = m_tree.find_child_with_move(*node, mv);
        if (child == nullptr)
        {
            child = &m_tree.create_new_child(*node);
            m_tree.set_move(*child, mv);
            add(m_tree, *child, player == player_black, true, result);
            return;
        }
        add(m_tree, *child, player == player_black, is_real_move[i], result);
        node = child;
    }
}

void OutputTree::generate_move(bool is_player_black, const Board& bd,
                               Color to_play, Move& mv)
{
    bool play_real;
    for (unsigned i = 0; i < m_transforms.size(); ++i)
    {
        generate_move(is_player_black, bd, to_play, *m_transforms[i],
                      *m_inv_transforms[i], mv, play_real);
        if (play_real || ! mv.is_null())
            break;
    }
}

/** Select a move from the tree for one symmetry transformation.
    The moves played on bd are transformed with transform and followed in
    the tree. If the position is not in the tree, or none of its children
    has a positive real count for the player, mv stays Move::null() and
    play_real stays false. Otherwise, with probability 1/sum (sum = total
    real count of the children) play_real is set to true and mv stays null;
    else a child is drawn with probability proportional to its real count
    and its move, mapped back with inv_transform, is stored in mv.
    @param[out] mv The selected move or Move::null().
    @param[out] play_real Whether a real move was requested instead.
    @throws runtime_error if the board has setup stones, or a selected
    child has no move or a move of a color other than to_play. */
void OutputTree::generate_move(bool is_player_black, const Board& bd,
                               Color to_play, const PointTransform& transform,
                               const PointTransform& inv_transform, Move& mv,
                               bool& play_real)
{
    if (bd.has_setup())
        throw runtime_error("OutputTree: setup not supported");
    play_real = false;
    mv = Move::null();
    auto node = &m_tree.get_root();
    for (unsigned i = 0; i < bd.get_nu_moves(); ++i)
    {
        // Named played_mv so that it does not shadow the output parameter.
        auto played_mv = bd.get_move(i);
        ColorMove transformed_mv(played_mv.color,
                                 get_transformed(bd, played_mv.move,
                                                 transform));
        auto child = m_tree.find_child_with_move(*node, transformed_mv);
        if (child == nullptr)
            return;
        node = child;
    }
    unsigned sum = 0;
    for (auto& i : node->get_children())
        sum += get_real_count(i, is_player_black);
    if (sum == 0)
        return;
    uniform_real_distribution<double> distribution(0, 1);
    if (distribution(m_random) < 1.0 / sum)
    {
        play_real = true;
        return;
    }
    // random should lie in [0, sum). Some standard library versions can
    // return the upper bound of uniform_real_distribution (LWG 2524), so
    // clamp to keep the selection loop below guaranteed to terminate.
    auto random = static_cast<unsigned>(distribution(m_random) * sum);
    if (random >= sum)
        random = sum - 1;
    sum = 0;
    for (auto& i : node->get_children())
    {
        auto real_count = get_real_count(i, is_player_black);
        if (real_count == 0)
            continue;
        sum += real_count;
        // A child owns the half-open range [sum - real_count, sum), so it is
        // chosen with probability real_count / total. Using >= here would
        // give the first child one extra slot and the last one one fewer.
        if (sum > random)
        {
            auto color_mv = m_tree.get_move(i);
            if (color_mv.is_null())
                throw runtime_error("OutputTree: tree has node without move");
            if (color_mv.color != to_play)
                throw runtime_error("OutputTree: tree has node wrong move color");
            mv = get_transformed(bd, color_mv.move, inv_transform);
            return;
        }
    }
    LIBBOARDGAME_ASSERT(false);
}

void OutputTree::load(const string& file)
{
    TreeReader reader;
    reader.read(file);
    auto tree = reader.move_tree();
    m_tree.init(tree);
}

void OutputTree::save(const string& file)
{
    ofstream out(file);
    TreeWriter writer(out, m_tree.get_root());
    writer.write();
}

//-----------------------------------------------------------------------------