File: embedding_layer.hpp

package info (click to toggle)
frugally-deep 0.18.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,036 kB
  • sloc: cpp: 6,680; python: 1,262; makefile: 4; sh: 1
file content (65 lines) | stat: -rw-r--r-- 2,360 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
// Copyright 2016, Tobias Hermann.
// https://github.com/Dobiasd/frugally-deep
// Distributed under the MIT License.
// (See accompanying LICENSE file or at
//  https://opensource.org/licenses/MIT)

#pragma once

#include "fdeep/layers/layer.hpp"

#include <algorithm>
#include <cstddef>
#include <functional>
#include <string>

namespace fdeep {
namespace internal {

    // Token-lookup layer (Keras "Embedding"): maps integer token indices to
    // dense vectors via a row-major [input_dim x output_dim] weight table.
    class embedding_layer : public layer {
    public:
        // name:       layer name forwarded to the base class
        // input_dim:  vocabulary size; valid token indices are [0, input_dim)
        // output_dim: length of each embedding vector
        // weights:    flat lookup table; row i (length output_dim) is the
        //             embedding of token i
        explicit embedding_layer(const std::string& name,
            std::size_t input_dim,
            std::size_t output_dim,
            const float_vec& weights)
            : layer(name)
            , input_dim_(input_dim)
            , output_dim_(output_dim)
            , weights_(weights)
        {
        }

    protected:
        // For each input tensor of shape (1, 1, 1, 1, seq_len) holding token
        // indices, produces a tensor of shape (seq_len, output_dim) whose i-th
        // row is the embedding vector of the i-th token.
        tensors apply_impl(const tensors& inputs) const override final
        {
            const auto input_shapes = fplus::transform(fplus_c_mem_fn_t(tensor, shape, tensor_shape), inputs);

            // ensure that tensor shape is (1, 1, 1, 1, seq_len)
            assertion(inputs.front().shape().size_dim_5_ == 1
                    && inputs.front().shape().size_dim_4_ == 1
                    && inputs.front().shape().height_ == 1
                    && inputs.front().shape().width_ == 1,
                "size_dim_5, size_dim_4, height and width dimension must be 1, but shape is '" + show_tensor_shapes(input_shapes) + "'");

            tensors results;
            results.reserve(inputs.size()); // exactly one output per input
            for (const auto& input : inputs) {
                const std::size_t sequence_len = input.shape().depth_;
                float_vec output_vec(sequence_len * output_dim_);
                auto it = output_vec.begin();

                for (std::size_t i = 0; i < sequence_len; ++i) {
                    const auto raw_index = input.get(tensor_pos(i));
                    // Converting a negative floating-point value to an
                    // unsigned integer type is undefined behavior, so the
                    // sign must be checked before the cast below.
                    assertion(raw_index >= 0, "vocabulary item indices must not be negative");
                    const std::size_t index = static_cast<std::size_t>(raw_index);
                    assertion(index < input_dim_, "vocabulary item indices must all be strictly less than the value of input_dim");
                    // Copy the embedding row for this token into the output.
                    it = std::copy_n(weights_.cbegin() + static_cast<float_vec::const_iterator::difference_type>(index * output_dim_), output_dim_, it);
                }

                results.push_back(tensor(tensor_shape(sequence_len, output_dim_), std::move(output_vec)));
            }
            return results;
        }

        const std::size_t input_dim_;  // vocabulary size
        const std::size_t output_dim_; // embedding vector length
        const float_vec weights_;      // row-major [input_dim_ x output_dim_] table
    };

}
}