File: prelu.cpp

package info (click to toggle)
onednn 3.9.1+ds-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 79,124 kB
  • sloc: cpp: 850,217; ansic: 37,403; lisp: 16,757; python: 3,463; asm: 831; sh: 78; javascript: 66; makefile: 41
file content (143 lines) | stat: -rw-r--r-- 5,418 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
/*******************************************************************************
* Copyright 2020-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/// @example prelu.cpp
/// > Annotated version: @ref prelu_example_cpp
///
/// @page prelu_example_cpp_short
///
/// This C++ API example demonstrates how to create and execute a
/// [PReLU](@ref dev_guide_prelu) primitive in forward training
/// propagation mode.
///
/// @page prelu_example_cpp Primitive Example
/// @copydetails prelu_example_cpp_short
///
/// @include prelu.cpp

#include <algorithm>
#include <cmath>
#include <string>
#include <vector>

#include "dnnl.hpp"
#include "example_utils.hpp"

using namespace dnnl;

/// Creates and executes a PReLU forward-training primitive on an f32 tensor
/// stored by the user in NCHW layout.
///
/// The src tensor is filled with cos(i / 10) and every weight is set to 0.3.
/// The weights and dst memory descriptors use format_tag::any, so the
/// primitive may choose layouts that differ from the user's NCHW buffers.
/// When that happens, weights are reordered into the primitive's layout
/// before execution, and dst is reordered back into the user buffer after
/// execution.
///
/// @param engine_kind Kind of engine to create and execute on.
void prelu_example(dnnl::engine::kind engine_kind) {

    // Create execution dnnl::engine.
    dnnl::engine engine(engine_kind, 0);

    // Create dnnl::stream.
    dnnl::stream engine_stream(engine);

    // Tensor dimensions.
    const memory::dim N = 3, // batch size
            IC = 3, // channels
            IH = 227, // tensor height
            IW = 227; // tensor width

    // Source (src), weights and destination (dst) tensors dimensions.
    // Weights have exactly the same shape as src here, so each src element
    // is paired with its own weight value.
    const memory::dims src_dims = {N, IC, IH, IW};
    const memory::dims weights_dims = {N, IC, IH, IW};
    const memory::dims dst_dims = {N, IC, IH, IW};

    // Allocate buffers. In this example, out-of-place primitive execution is
    // demonstrated since both src and dst are required for later backward
    // propagation.
    std::vector<float> src_data(product(src_dims));
    std::vector<float> weights_data(product(weights_dims));
    std::vector<float> dst_data(product(dst_dims));

    // Initialize src tensor with cos(i / 10). The counter is a local captured
    // by reference rather than a static inside the lambda, so every call of
    // this function generates the same input sequence starting at i = 0.
    int src_idx = 0;
    std::generate(src_data.begin(), src_data.end(),
            [&src_idx]() { return std::cos(src_idx++ / 10.f); });

    // Initialize weights tensor.
    std::fill(weights_data.begin(), weights_data.end(), 0.3f);

    // Create memory objects for tensor data (src, weights, dst). In this
    // example, NCHW layout is assumed for src, weights and dst.
    auto user_src_mem = memory(
            {src_dims, memory::data_type::f32, memory::format_tag::nchw},
            engine);
    auto user_weights_mem = memory(
            {weights_dims, memory::data_type::f32, memory::format_tag::nchw},
            engine);
    auto user_dst_mem = memory(
            {dst_dims, memory::data_type::f32, memory::format_tag::nchw},
            engine);

    // Create memory descriptors for the primitive. The src tag is set to
    // match the src memory object, so src is used without a reorder. Setting
    // the weights and dst tags to format_tag::any lets the PReLU primitive
    // choose memory layouts for an optimized implementation. Those layouts
    // may differ from the ones provided by the user and are handled with
    // reorders below.
    auto src_md = memory::desc(
            src_dims, memory::data_type::f32, memory::format_tag::nchw);
    auto weights_md = memory::desc(
            weights_dims, memory::data_type::f32, memory::format_tag::any);
    auto dst_md = memory::desc(
            dst_dims, memory::data_type::f32, memory::format_tag::any);

    // Write data to memory object's handle.
    write_to_dnnl_memory(src_data.data(), user_src_mem);
    write_to_dnnl_memory(weights_data.data(), user_weights_mem);

    // Create primitive descriptor.
    auto prelu_pd = prelu_forward::primitive_desc(
            engine, prop_kind::forward_training, src_md, weights_md, dst_md);

    // For now, assume that the weights memory layout generated
    // by the primitive and the one provided by the user are identical.
    auto prelu_weights_mem = user_weights_mem;

    // Reorder the data in case the weights memory layout generated by
    // the primitive and the one provided by the user are different. In this
    // case, we create an additional memory object with internal buffers that
    // will contain the reordered data.
    if (prelu_pd.weights_desc() != user_weights_mem.get_desc()) {
        prelu_weights_mem = memory(prelu_pd.weights_desc(), engine);
        reorder(user_weights_mem, prelu_weights_mem)
                .execute(engine_stream, user_weights_mem, prelu_weights_mem);
    }

    // The primitive writes dst in the layout it selected. If that layout
    // matches the user's dst, write straight into the user buffer.
    // Otherwise, write into an intermediate buffer that is reordered back
    // into user_dst_mem after execution.
    auto prelu_dst_mem = user_dst_mem;
    const bool dst_needs_reorder
            = prelu_pd.dst_desc() != user_dst_mem.get_desc();
    if (dst_needs_reorder) {
        prelu_dst_mem = memory(prelu_pd.dst_desc(), engine);
    }

    // Create the primitive.
    auto prelu_prim = prelu_forward(prelu_pd);

    // Primitive arguments.
    std::unordered_map<int, memory> prelu_args;
    prelu_args.insert({DNNL_ARG_SRC, user_src_mem});
    prelu_args.insert({DNNL_ARG_WEIGHTS, prelu_weights_mem});
    prelu_args.insert({DNNL_ARG_DST, prelu_dst_mem});

    // Primitive execution: PReLU.
    prelu_prim.execute(engine_stream, prelu_args);

    // Bring the result back into the user's NCHW dst buffer if the primitive
    // used a different dst layout.
    if (dst_needs_reorder) {
        reorder(prelu_dst_mem, user_dst_mem)
                .execute(engine_stream, prelu_dst_mem, user_dst_mem);
    }

    // Wait for the computation to finalize.
    engine_stream.wait();

    // Read data from memory object's handle.
    read_from_dnnl_memory(dst_data.data(), user_dst_mem);
}

// Program entry point. The engine kind is derived from the command-line
// arguments, and the PReLU example is run through the shared
// handle_example_errors wrapper, whose return value becomes the process exit
// status.
int main(int argc, char **argv) {
    const auto selected_kind = parse_engine_kind(argc, argv);
    return handle_example_errors(prelu_example, selected_kind);
}