1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
|
// Copyright 2016, Tobias Hermann.
// https://github.com/Dobiasd/frugally-deep
// Distributed under the MIT License.
// (See accompanying LICENSE file or at
// https://opensource.org/licenses/MIT)
#pragma once
#include "fdeep/common.hpp"
#include "fdeep/tensor.hpp"
#include "fdeep/layers/layer.hpp"
#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
namespace fdeep {
namespace internal {
// A layer that wraps a complete (nested) model: it owns the model's
// sub-layers and the node connections that designate which layer outputs
// act as the model's inputs and which are returned as its outputs.
class model_layer : public layer {
public:
    // name               - name of this (nested-model) layer
    // layers             - the model's sub-layers; their names must be unique
    // input_connections  - one connection per expected input tensor
    // output_connections - one connection per produced output tensor
    explicit model_layer(const std::string& name,
        const layer_ptrs& layers,
        const node_connections& input_connections,
        const node_connections& output_connections)
        : layer(name)
        , layers_(layers)
        , input_connections_(input_connections)
        , output_connections_(output_connections)
    {
        const auto sub_layer_names =
            fplus::transform(fplus_get_ptr_mem(name_), layers);
        assertion(fplus::all_unique(sub_layer_names),
            "layer names must be unique");
    }

    // Resolves the output of one node of this nested model.
    // Any index >= 1 is shifted down by one before the lookup in nodes_;
    // index 0 is passed through unchanged, so both 0 and 1 map to node 0.
    // NOTE(review): the offset follows the Keras nested-model node-index
    // convention discussed at
    // https://stackoverflow.com/questions/46011749/understanding-keras-model-architecture-node-index-of-nested-model
    // - confirm against the model converter before changing it.
    tensor get_output(const layer_ptrs& layers, output_dict& output_cache,
        std::size_t node_idx, std::size_t tensor_idx) const override
    {
        const std::size_t own_node_idx = (node_idx == 0) ? node_idx : node_idx - 1;
        assertion(own_node_idx < nodes_.size(),
            "invalid node index: " + std::to_string(own_node_idx)
                + " of " + std::to_string(nodes_.size()));
        return layer::get_output(layers, output_cache, own_node_idx, tensor_idx);
    }

protected:
    // Runs the nested model: binds each given tensor to the corresponding
    // input connection, then evaluates every output connection in order.
    // The number of inputs must match the number of input connections.
    tensors apply_impl(const tensors& inputs) const override
    {
        assertion(inputs.size() == input_connections_.size(),
            "invalid number of input tensors for this model: " + fplus::show(input_connections_.size()) + " required but " + fplus::show(inputs.size()) + " provided");

        // Pre-seed the cache so lookups of the model's input nodes yield the
        // caller-supplied tensors instead of being computed.
        output_dict output_cache;
        for (std::size_t pos = 0; pos < inputs.size(); ++pos) {
            output_cache[input_connections_[pos].without_tensor_idx()] = { inputs[pos] };
        }

        tensors results;
        results.reserve(output_connections_.size());
        for (const auto& connection : output_connections_) {
            results.push_back(get_layer_output(layers_, output_cache, connection));
        }
        return results;
    }

    layer_ptrs layers_;
    node_connections input_connections_;
    node_connections output_connections_;
};
}
}
|