File: time_distributed_layer.hpp

// Copyright 2016, Tobias Hermann.
// https://github.com/Dobiasd/frugally-deep
// Distributed under the MIT License.
// (See accompanying LICENSE file or at
//  https://opensource.org/licenses/MIT)

#pragma once

#include "fdeep/layers/layer.hpp"
#include "fdeep/recurrent_ops.hpp"
#include <nlohmann/json.hpp>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>

namespace fdeep {
namespace internal {

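    // Mirrors Keras' TimeDistributed wrapper: the inner layer is applied
    // independently to every step along the temporal (outermost) axis,
    // and the per-step outputs are stacked back together. For example,
    // TimeDistributed(Dense(8)) on an input of shape (10, 16) yields an
    // output of shape (10, 8).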
    class time_distributed_layer : public layer {
    public:
        explicit time_distributed_layer(const std::string& name,
            const layer_ptr& inner_layer,
            const std::size_t td_input_len,
            const std::size_t td_output_len)
            : layer(name)
            , inner_layer_(inner_layer)
            , td_input_len_(td_input_len)
            , td_output_len_(td_output_len)
        {
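            // The inner layer's output must retain a time axis (rank >= 2)
            // for the per-step results to be stacked along.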
            assertion(td_output_len_ > 1, "Wrong output dimension");
        }

    protected:
        tensors apply_impl(const tensors& inputs) const override final
        {
            const auto& input = single_tensor_from_tensors(inputs);
            tensors result_time_step = {};
            std::size_t len_series = 0;
            tensors slices = {};
            std::int32_t concat_axis = 0;

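            // Slice the input along its time axis; which axis carries the
            // time steps depends on the rank of the TimeDistributed input.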
            if (td_input_len_ == 2) {
                len_series = input.shape().width_;
                slices = tensor_to_tensors_width_slices(input);
            } else if (td_input_len_ == 3) {
                len_series = input.shape().height_;
                slices = tensor_to_tensors_height_slices(input);
            } else if (td_input_len_ == 4) {
                len_series = input.shape().size_dim_4_;
                slices = tensor_to_tensors_dim4_slices(input);
            } else if (td_input_len_ == 5) {
                len_series = input.shape().size_dim_5_;
                slices = tensor_to_tensors_dim5_slices(input);
            } else {
                raise_error("invalid input dim for TimeDistributed");
            }

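            // Drop the now-singleton time axis so each slice has the rank
            // the inner layer expects.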
            for (auto& slice : slices) {
                slice.shrink_rank();
            }

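            // Map the rank of the inner layer's output to the axis along
            // which the per-step results will be concatenated again.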
            if (td_output_len_ == 2) {
                concat_axis = 2;
            } else if (td_output_len_ == 3) {
                concat_axis = 1;
            } else if (td_output_len_ == 4) {
                concat_axis = 3;
            } else if (td_output_len_ == 5) {
                concat_axis = 4;
            } else {
                raise_error("invalid output dim for TimeDistributed");
            }

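            // Run the inner layer once per time step.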
            for (std::size_t i = 0; i < len_series; ++i) {
                const auto curr_result = inner_layer_->apply({ slices[i] });
                result_time_step.push_back(curr_result.front());
            }

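            // Stitch the per-step outputs back into a single tensor along
            // the axis chosen above.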
            if (concat_axis == 1) {
                return { concatenate_tensors_height(result_time_step) };
            }
            if (concat_axis == 2) {
                return { concatenate_tensors_width(result_time_step) };
            }
            if (concat_axis == 3) {
                return { concatenate_tensors_dim4(result_time_step) };
            }
            if (concat_axis == 4) {
                return { concatenate_tensors_dim5(result_time_step) };
            }
            raise_error("Invalid concat_axis in time_distributed_layer.");
            return {};
        }

        const layer_ptr inner_layer_;
        const std::size_t td_input_len_;
        const std::size_t td_output_len_;
    };
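
    // A minimal usage sketch (assuming `dense` is an already-constructed
    // layer_ptr to some inner layer, e.g. obtained during model import;
    // the names here are illustrative, not part of any public API):
    //
    //   const layer_ptr td = std::make_shared<time_distributed_layer>(
    //       "td_1", dense, /*td_input_len*/ 3, /*td_output_len*/ 3);
    //   const tensors out = td->apply({input}); // one inner pass per step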

} // namespace internal
} // namespace fdeep