// Copyright 2016, Tobias Hermann.
// https://github.com/Dobiasd/frugally-deep
// Distributed under the MIT License.
// (See accompanying LICENSE file or at
// https://opensource.org/licenses/MIT)
#pragma once
#include "fdeep/common.hpp"
#include "fdeep/tensor.hpp"
#include "fdeep/tensor_shape.hpp"
#include <cassert>
#include <cstddef>
#include <vector>
namespace fdeep {
namespace internal {
    // A single convolution filter: a weight tensor plus a scalar bias.
    class filter {
    public:
        filter(const tensor& m, float_type bias)
            : m_(m)
            , bias_(bias)
        {
        }

        const tensor_shape& shape() const
        {
            return m_.shape();
        }

        // Number of weights in the filter (the bias is not counted).
        std::size_t volume() const
        {
            return m_.shape().volume();
        }

        const tensor& get_tensor() const
        {
            return m_;
        }

        float_type get(const tensor_pos& pos) const
        {
            return m_.get_ignore_rank(pos);
        }

        float_type get_bias() const
        {
            return bias_;
        }

        // Replaces weights and bias; the weight count must match the current shape.
        void set_params(const float_vec& weights, float_type bias)
        {
            assertion(weights.size() == m_.shape().volume(),
                "invalid parameter count");
            m_ = tensor(m_.shape(), float_vec(weights));
            bias_ = bias;
        }

    private:
        tensor m_;
        float_type bias_;
    };

    typedef std::vector<filter> filter_vec;

    // Returns a copy of the filter with its weight tensor dilated; the bias is unchanged.
    inline filter dilate_filter(const shape2& dilation_rate, const filter& undilated)
    {
        return filter(dilate_tensor(dilation_rate, undilated.get_tensor(), false),
            undilated.get_bias());
    }

    // Builds k filters of shape filter_shape from a flat weight vector
    // (one contiguous chunk per filter) and one bias value per filter.
    inline filter_vec generate_filters(
        const shape2& dilation_rate,
        const tensor_shape& filter_shape, std::size_t k,
        const float_vec& weights, const float_vec& bias,
        bool transpose)
    {
        filter_vec filters(k, filter(tensor(filter_shape, 0), 0));
        assertion(!filters.empty(), "at least one filter needed");

        // The flat weight vector must hold exactly one value per filter weight.
        const std::size_t param_count = fplus::sum(fplus::transform(
            fplus_c_mem_fn_t(filter, volume, std::size_t), filters));
        assertion(static_cast<std::size_t>(weights.size()) == param_count,
            "invalid weight size");

        // Split the weights into one chunk per filter.
        const auto filter_param_cnt = filters.front().shape().volume();
        auto filter_weights = fplus::split_every(filter_param_cnt, weights);
        assertion(filter_weights.size() == filters.size(),
            "invalid size of filter weights");
        assertion(bias.size() == filters.size(), "invalid bias size");

        auto it_filter_val = std::begin(filter_weights);
        auto it_filter_bias = std::begin(bias);
        for (auto& filt : filters) {
            filt.set_params(*it_filter_val, *it_filter_bias);
            filt = dilate_filter(dilation_rate, filt);
            if (transpose) {
                // Flip the kernel along both the height and the width dimension.
                filt = filter(reverse_height_dimension(filt.get_tensor()), filt.get_bias());
                filt = filter(reverse_width_dimension(filt.get_tensor()), filt.get_bias());
            }
            ++it_filter_val;
            ++it_filter_bias;
        }
        return filters;
    }
}
}