/*******************************************************************************
* Copyright 2020-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
/// @example bnorm_u8_via_binary_postops.cpp
/// @copybrief bnorm_u8_via_binary_postops_cpp
/// > Annotated version: @ref bnorm_u8_via_binary_postops_cpp
///
/// @page bnorm_u8_via_binary_postops_cpp_short
/// Bnorm u8 via binary postops example.
///
/// @page bnorm_u8_via_binary_postops_cpp Bnorm u8 by binary post-ops example
/// The example implements the Batch normalization u8 via the following
/// operations: binary_sub(src, mean), binary_div(tmp_dst, variance),
/// binary_mul(tmp_dst, scale), binary_add(tmp_dst, shift).
///
/// Some key take-aways include:
///
/// * How tensors are implemented and submitted to primitives.
/// * How primitives are created.
/// * How to use multiple binary post operations.
/// * How to use different data types in binary.
///
/// @include bnorm_u8_via_binary_postops.cpp
#include <algorithm>
#include <cmath>
#include <iostream>
#include <string>
#include <vector>
#include "dnnl.hpp"
#include "example_utils.hpp"
using namespace dnnl;
/// Implements u8 batch normalization as a chain of binary operations:
/// dst = (((src - mean) / variance) * scale + shift) * output_scale.
/// The subtraction is the main binary primitive; the remaining steps are
/// attached as binary post-ops.
void bnorm_u8_via_binary_postops(dnnl::engine::kind engine_kind) {
    // Execution engine and stream on the requested device.
    dnnl::engine eng(engine_kind, 0);
    dnnl::stream strm(eng);

    // Tensor dimensions (logical NCHW).
    const memory::dim N = 3, // batch size
            IC = 3, // channels
            IH = 150, // tensor height
            IW = 150; // tensor width

    // Full tensor shape and the per-channel parameter shape {1, IC, 1, 1},
    // which broadcasts over N, H, and W.
    memory::dims src_dims = {N, IC, IH, IW};
    memory::dims params_dims = {1, IC, 1, 1};

    // Host-side buffers for the source tensor and the bnorm parameters.
    std::vector<float> src_data(product(src_dims));
    std::vector<float> mean_data(product(params_dims));
    std::vector<float> variance_data(product(params_dims));
    std::vector<float> scale_data(product(params_dims));
    std::vector<float> shift_data(product(params_dims));
    std::vector<float> oscale_data(product(params_dims));

    // Fill the buffers with deterministic synthetic values. Each generator
    // keeps its own counter via an init-capture.
    std::generate(src_data.begin(), src_data.end(),
            [i = 0]() mutable { return std::cos(i++ / 10.f); });
    std::generate(mean_data.begin(), mean_data.end(),
            [i = 0]() mutable { return std::sin(i++ * 2.f); });
    std::generate(variance_data.begin(), variance_data.end(), [i = 0]() mutable {
        const float v = std::abs(std::sin(i++ * 4.f));
        // Avoid division by zero. Variance should be positive.
        return v == 0.f ? 1.f : v;
    });
    std::generate(scale_data.begin(), scale_data.end(),
            [i = 0]() mutable { return std::sin(i++ * 6.f); });
    std::generate(shift_data.begin(), shift_data.end(),
            [i = 0]() mutable { return std::sin(i++ * 8.f); });
    std::generate(
            oscale_data.begin(), oscale_data.end(), []() { return 0.5f; });

    // Memory descriptors: u8 source, f32 parameters. All parameter tensors
    // share one descriptor since their shape/type/layout are identical.
    auto src_md = memory::desc(
            src_dims, memory::data_type::u8, memory::format_tag::nhwc);
    auto params_md = memory::desc(
            params_dims, memory::data_type::f32, memory::format_tag::nhwc);

    // Memory objects bound to the engine.
    auto src_mem = memory(src_md, eng);
    auto mean_mem = memory(params_md, eng);
    auto variance_mem = memory(params_md, eng);
    auto scale_mem = memory(params_md, eng);
    auto shift_mem = memory(params_md, eng);
    auto oscale_mem = memory(params_md, eng);

    // Upload host data into the memory objects.
    write_to_dnnl_memory(src_data.data(), src_mem);
    write_to_dnnl_memory(mean_data.data(), mean_mem);
    write_to_dnnl_memory(variance_data.data(), variance_mem);
    write_to_dnnl_memory(scale_data.data(), scale_mem);
    write_to_dnnl_memory(shift_data.data(), shift_mem);
    write_to_dnnl_memory(oscale_data.data(), oscale_mem);

    // Post-op chain applied after the main subtraction:
    post_ops binary_ops;
    // dst_tmp = dst_tmp / variance
    binary_ops.append_binary(algorithm::binary_div, params_md);
    // dst_tmp = dst_tmp * scale
    binary_ops.append_binary(algorithm::binary_mul, params_md);
    // dst_tmp = dst_tmp + shift
    binary_ops.append_binary(algorithm::binary_add, params_md);
    // dst = dst_tmp * output_scale (only for re-quantization)
    binary_ops.append_binary(algorithm::binary_mul, params_md);

    primitive_attr binary_attr;
    binary_attr.set_post_ops(binary_ops);

    // Main primitive: dst_tmp = src - mean, with the post-ops above.
    auto binary_pd = binary::primitive_desc(eng, algorithm::binary_sub,
            src_md, params_md, src_md, binary_attr);
    auto binary_prim = binary(binary_pd);

    // Execution arguments. The destination aliases the source (in-place);
    // each post-op receives its parameter tensor as its second input.
    std::unordered_map<int, memory> binary_args;
    binary_args.emplace(DNNL_ARG_SRC_0, src_mem);
    binary_args.emplace(DNNL_ARG_SRC_1, mean_mem);
    binary_args.emplace(DNNL_ARG_DST, src_mem);
    binary_args.emplace(
            DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, variance_mem);
    binary_args.emplace(
            DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1, scale_mem);
    binary_args.emplace(
            DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | DNNL_ARG_SRC_1, shift_mem);
    binary_args.emplace(
            DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1, oscale_mem);

    // Run the primitive and wait for completion.
    binary_prim.execute(strm, binary_args);
    strm.wait();

    // Copy the (in-place) result back to the host buffer.
    read_from_dnnl_memory(src_data.data(), src_mem);
}
int main(int argc, char **argv) {
return handle_example_errors(
bnorm_u8_via_binary_postops, parse_engine_kind(argc, argv));
}