1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
|
// Copyright 2016, Tobias Hermann.
// https://github.com/Dobiasd/frugally-deep
// Distributed under the MIT License.
// (See accompanying LICENSE file or at
// https://opensource.org/licenses/MIT)
#pragma once
#include "fdeep/common.hpp"
#include "fdeep/tensor.hpp"
#include "fdeep/node.hpp"
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace fdeep {
namespace internal {
// Forward declarations so the layer interface below can refer to layers and
// activation layers by shared pointer before their full definitions.
class layer;
using layer_ptr = std::shared_ptr<layer>;
using layer_ptrs = std::vector<layer_ptr>;

class activation_layer;
using activation_layer_ptr = std::shared_ptr<activation_layer>;

// Runs the activation layer `ptr` on `input`. Declared here because
// layer::apply calls it; the definition lives elsewhere (presumably next to
// activation_layer — confirm).
tensors apply_activation_layer(const activation_layer_ptr& ptr,
    const tensors& input);
/// Abstract base class of the layers in a model graph.
///
/// Subclasses implement apply_impl(). apply() runs apply_impl() and then, if
/// one was set via set_activation(), the activation layer on its result.
/// `nodes_` (set via set_nodes()) holds this layer's nodes. The node index
/// passed to get_output() selects one of them.
class layer {
public:
    /// Creates a layer named `name` with no nodes and no activation.
    explicit layer(const std::string& name)
        : name_(name)
        , nodes_()
        , activation_(nullptr)
    {
    }

    virtual ~layer() = default;

    /// Sets the activation applied after apply_impl(). A nullptr disables it.
    void set_activation(const activation_layer_ptr& activation)
    {
        activation_ = activation;
    }

    /// Replaces this layer's nodes with a copy of `layer_nodes`.
    void set_nodes(const nodes& layer_nodes)
    {
        nodes_ = layer_nodes;
    }

    /// Computes this layer's outputs for `input`: apply_impl(), followed by
    /// the activation layer if one is set. Marked final so subclasses
    /// customize apply_impl() instead.
    virtual tensors apply(const tensors& input) const final
    {
        const auto result = apply_impl(input);
        if (activation_ == nullptr)
            return result;
        else
            return apply_activation_layer(activation_, result);
    }

    /// Returns output tensor `tensor_idx` of node `node_idx` of this layer.
    ///
    /// All outputs of a node are memoized in `output_cache`, keyed by the
    /// node connection with its tensor index stripped. A node is therefore
    /// evaluated at most once per cache, even when several tensors of it are
    /// requested. `layers` is forwarded to the node so it can resolve its
    /// inputs. Invalid `node_idx` / `tensor_idx` values are reported through
    /// assertion().
    virtual tensor get_output(const layer_ptrs& layers,
        output_dict& output_cache,
        std::size_t node_idx, std::size_t tensor_idx) const
    {
        const node_connection conn(name_, node_idx, tensor_idx);
        const auto key = conn.without_tensor_idx();
        auto cached = output_cache.find(key);
        if (cached == output_cache.end()) {
            assertion(node_idx < nodes_.size(), "invalid node index");
            // Evaluate the node before touching the cache. In
            // `output_cache[key] = ...` a pre-C++17 compiler may run
            // operator[] first and insert an empty placeholder, which would
            // stay in the cache if the (recursive) evaluation throws.
            auto computed = nodes_[node_idx].get_output(layers, output_cache, *this);
            // The recursion may have inserted other entries; re-acquire the
            // position from the insertion itself.
            cached = output_cache.emplace(key, std::move(computed)).first;
        }
        const auto& outputs = cached->second;
        assertion(tensor_idx < outputs.size(),
            "invalid tensor index");
        return outputs[tensor_idx];
    }

    std::string name_;
    nodes nodes_;

protected:
    /// Layer-specific computation, implemented by each concrete layer.
    virtual tensors apply_impl(const tensors& input) const = 0;

    /// Optional activation run by apply(). A nullptr means none.
    activation_layer_ptr activation_;
};
inline tensor get_layer_output(const layer_ptrs& layers,
output_dict& output_cache,
const node_connection& conn)
{
return get_layer(layers, conn.layer_id_)->get_output(layers, output_cache, conn.node_idx_, conn.tensor_idx_);
}
/// Free-function form of layer::apply: runs `target` on `inputs`
/// (apply_impl plus the optional activation) and returns the result.
inline tensors apply_layer(const layer& target, const tensors& inputs)
{
    return target.apply(inputs);
}
/// Returns the first layer in `layers` whose name_ equals `layer_id`.
/// If no layer matches, fails via fplus::throw_on_nothing with
/// error("dangling layer reference: <layer_id>").
/// NOTE(review): assumes `layers` contains no null pointers — confirm.
inline layer_ptr get_layer(const layer_ptrs& layers,
    const std::string& layer_id)
{
    // Captured by reference: the predicate is only used during this call.
    const auto has_requested_name = [&layer_id](const layer_ptr& candidate) -> bool {
        return candidate->name_ == layer_id;
    };
    const auto match = fplus::find_first_by(has_requested_name, layers);
    return fplus::throw_on_nothing(
        error("dangling layer reference: " + layer_id),
        match);
}
}
}
|