File: model_layer.hpp

package info (click to toggle)
frugally-deep 0.18.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,036 kB
  • sloc: cpp: 6,680; python: 1,262; makefile: 4; sh: 1
file content (73 lines) | stat: -rw-r--r-- 2,567 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
// Copyright 2016, Tobias Hermann.
// https://github.com/Dobiasd/frugally-deep
// Distributed under the MIT License.
// (See accompanying LICENSE file or at
//  https://opensource.org/licenses/MIT)

#pragma once

#include "fdeep/common.hpp"

#include "fdeep/tensor.hpp"

#include "fdeep/layers/layer.hpp"

#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>

namespace fdeep {
namespace internal {

    // Layer that wraps a complete (nested) model, so that the model can be
    // used as a single layer inside an enclosing model.
    //
    // The wrapped graph is described by its sub-layers, the connections the
    // model inputs are bound to, and the connections that yield its outputs.
    class model_layer : public layer {
    public:
        // name:               name of this nested-model layer.
        // layers:             sub-layers of the nested model; their names
        //                     must be unique (checked via `assertion`).
        // input_connections:  one connection per model input, in input order.
        // output_connections: one connection per model output, in output order.
        explicit model_layer(const std::string& name,
            const layer_ptrs& layers,
            const node_connections& input_connections,
            const node_connections& output_connections)
            : layer(name)
            , layers_(layers)
            , input_connections_(input_connections)
            , output_connections_(output_connections)
        {
            const auto sub_layer_names = fplus::transform(fplus_get_ptr_mem(name_), layers);
            assertion(fplus::all_unique(sub_layer_names),
                "layer names must be unique");
        }

        // Resolves output `tensor_idx` of node `node_idx` of this layer.
        //
        // Any node index >= 1 is shifted down by one (0 is kept as is). The
        // shifted index is then validated against nodes_ and handed to
        // layer::get_output. The linked discussion explains this offset for
        // nested Keras models:
        // https://stackoverflow.com/questions/46011749/understanding-keras-model-architecture-node-index-of-nested-model
        tensor get_output(const layer_ptrs& layers, output_dict& output_cache,
            std::size_t node_idx, std::size_t tensor_idx) const override
        {
            const std::size_t local_node_idx = (node_idx == 0) ? node_idx : node_idx - 1;
            assertion(local_node_idx < nodes_.size(),
                "invalid node index: " + std::to_string(local_node_idx) + " of " + std::to_string(nodes_.size()));
            return layer::get_output(layers, output_cache, local_node_idx, tensor_idx);
        }

    protected:
        // Runs the nested model on `inputs`.
        //
        // Each call starts from a fresh output cache. The cache is seeded so
        // that inputs[i] is stored under input_connections_[i] with its tensor
        // index stripped. Every output connection is then resolved against
        // layers_ via get_layer_output, in the order of output_connections_.
        tensors apply_impl(const tensors& inputs) const override
        {
            assertion(inputs.size() == input_connections_.size(),
                "invalid number of input tensors for this model: " + fplus::show(input_connections_.size()) + " required but " + fplus::show(inputs.size()) + " provided");

            output_dict cache;
            std::size_t input_pos = 0;
            for (const auto& input : inputs) {
                cache[input_connections_[input_pos].without_tensor_idx()] = { input };
                ++input_pos;
            }

            tensors results;
            results.reserve(output_connections_.size());
            for (const auto& conn : output_connections_) {
                results.push_back(get_layer_output(layers_, cache, conn));
            }
            return results;
        }

        // Sub-layers making up the nested model.
        layer_ptrs layers_;
        // Connections the model inputs are bound to, one per input.
        node_connections input_connections_;
        // Connections whose tensors form the model outputs, one per output.
        node_connections output_connections_;
    };

}
}