File: lm_merged.h

package info (click to toggle)
onboard 1.4.1-5
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, bullseye
  • size: 31,548 kB
  • sloc: python: 29,215; cpp: 5,965; ansic: 5,735; xml: 1,026; sh: 163; makefile: 39
file content (140 lines) | stat: -rw-r--r-- 4,797 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
/*
 * Copyright © 2009-2010, 2013-2014 marmuta <marmvta@gmail.com>
 *
 * This file is part of Onboard.
 *
 * Onboard is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Onboard is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */

#ifndef LM_MERGED_H
#define LM_MERGED_H

#include <map>
#include <string>
#include <vector>

#include "lm.h"

//------------------------------------------------------------------------
// MergedModel - abstract container for one or more component language models
//------------------------------------------------------------------------

// Ordering predicate for std::wstring map keys: plain lexicographic
// comparison (same ordering as std::less<std::wstring>).
struct map_wstr_cmp
{
    bool operator() (const std::wstring& lhs, const std::wstring& rhs) const
    {
        return lhs.compare(rhs) < 0;
    }
};

// Maps each candidate word to a double score; used as the destination
// container of MergedModel::merge().
typedef std::map<std::wstring, double, map_wstr_cmp> ResultsMap;

// Disabled alternative, kept for reference.
// NOTE(review): keying by const wchar_t* would hash/compare pointer
// values rather than string contents.
//#include <unordered_map>
//typedef std::unordered_map<const wchar_t*, double> ResultsMap;

class MergedModel : public LanguageModel
{
    public:
        // language model overloads
        virtual bool is_model_valid()
        {
            for (unsigned i=0; i<components.size(); i++)
                if (!components[i]->is_model_valid())
                    return false;
            return true;
        };

        virtual void predict(std::vector<LanguageModel::Result>& results,
                             const std::vector<wchar_t*>& context,
                             int limit=-1,
                             uint32_t options = DEFAULT_OPTIONS);

        virtual LMError load(const char* filename)
        {return ERR_NOT_IMPL;}
        virtual LMError save(const char* filename)
        {return ERR_NOT_IMPL;}

        // merged model interface
        virtual void set_models(const std::vector<LanguageModel*>& models)
        { components = models;}

    protected:
        // merged model interface
        virtual void init_merge() {}
        virtual bool can_limit_components() {return false;}
        virtual void merge(ResultsMap& dst, const std::vector<Result>& values,
                                      int model_index) = 0;
        virtual bool needs_normalization() {return false;}

    private:
        void normalize(std::vector<Result>& results, int result_size);

    protected:
        std::vector<LanguageModel*> components;
};

//------------------------------------------------------------------------
// OverlayModel - merge by overlaying language models
//------------------------------------------------------------------------

class OverlayModel : public MergedModel
{
    protected:
        // Merge one component's results into 'dst' by overlaying
        // (implemented out of line).
        virtual void merge(ResultsMap& dst,
                           const std::vector<Result>& values,
                           int model_index);

        // Overlaying tolerates a limit on the number of prediction
        // results requested from each component model.
        virtual bool can_limit_components()
        {
            return true;
        }

        // Overlaid results are normalized in an explicit pass afterwards.
        virtual bool needs_normalization()
        {
            return true;
        }
};

//------------------------------------------------------------------------
// LinintModel - linearly interpolate language models
//------------------------------------------------------------------------

class LinintModel : public MergedModel
{
    public:
        // Give weight_sum a defined starting value. Previously it was left
        // indeterminate until first assigned, so reading it earlier was
        // undefined behavior.
        // NOTE(review): weight_sum is presumably recomputed by init_merge()
        // (defined out of line); confirm.
        LinintModel() : weight_sum(0.0) {}

        // Set the interpolation weights (copied). Presumably one weight per
        // component model, in set_models() order; confirm against init_merge().
        void set_weights(const std::vector<double>& weights)
        { this->weights = weights; }

        virtual void init_merge();
        virtual void merge(ResultsMap& dst, const std::vector<Result>& values,
                                      int model_index);
        virtual double get_probability(const wchar_t* const* ngram, int n);

    protected:
        std::vector<double> weights;  // interpolation weights
        double weight_sum;            // presumably the sum of 'weights'
};


//------------------------------------------------------------------------
// LoglinintModel - log-linear interpolation of language models
//------------------------------------------------------------------------

class LoglinintModel : public MergedModel
{
    public:
        // Set the interpolation weights (copied).
        void set_weights(const std::vector<double>& weights)
        {
            this->weights = weights;
        }

        virtual void init_merge();
        virtual void merge(ResultsMap& dst,
                           const std::vector<Result>& values,
                           int model_index);

        // There appears to be no simple way to produce normalized results
        // directly, so request an explicit normalization pass instead.
        virtual bool needs_normalization()
        {
            return true;
        }

    protected:
        std::vector<double> weights;  // interpolation weights
};

#endif