File: embedding_lookup_idx.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (57 lines) | stat: -rw-r--r-- 1,674 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#pragma once

#include <cstdint>

namespace caffe2 {

// clang-format off
/**
 * Embedding lookup with reduction, where the rows that feed each output row
 * are delimited by consecutive entries of `offsets`.
 *
 * Template parameters:
 *   IndexType             element type of `indices` and `offsets`
 *   InType                element type of `input`
 *   OutType               element type of `out`
 *   IS_WEIGHT_POSITIONAL  selects how `weights` is indexed (see the pseudocode
 *                         below); defaults to false
 *
 * Arguments:
 * `input` of size data_size * block_size
 * `indices` of size index_size
 * `offsets` of size output_size + 1: the pseudocode below reads offsets[i+1]
 *     for i up to output_size-1, so offsets[output_size] (the end of the last
 *     segment) must be readable
 * `weights` nullptr or array of size index_size
 *     NOTE(review): when IS_WEIGHT_POSITIONAL is true the pseudocode indexes
 *     `weights` by position within the segment (j - start_offset) rather than
 *     by `pos` -- confirm the required size for that mode against callers
 * `scale_bias` optional scale & bias params for uint8 input; its layout and
 *     how it is applied are not shown in this header or in the pseudocode
 * `normalize_by_lengths` when true, each non-empty output row is divided by
 *     the number of indices reduced into it
 * `out` of size output_size * block_size
 *
 * Behavior is roughly equivalent to pseudocode:
 *
 * pos = 0
 * for (i = 0..output_size-1)
 *   for (k = 0..block_size-1)
 *     out[i*block_size + k] = 0
 *   start_offset = offsets[i]
 *   end_offset = offsets[i+1]
 *   length = end_offset - start_offset
 *   for (j = start_offset..end_offset-1)
 *     for (k = 0..block_size-1)
 *       out[i*block_size + k] += input[indices[pos]*block_size + k] *
 *           (weights ? weights[IS_WEIGHT_POSITIONAL ? j - start_offset : pos] : 1.0)
 *     pos += 1
 *   if (normalize_by_lengths && length > 0)
 *     for (k = 0..block_size-1)
 *       out[i*block_size + k] /= length
 *
 * TODO: make this API also take "offsets" rather than "lengths" to match the
 *       API for PyTorch's EmbeddingBag
 * NOTE(review): the TODO above looks stale for this function -- the signature
 *       below already takes `offsets`. Presumably it was carried over from a
 *       lengths-based sibling API; confirm and remove.
 */
// clang-format on
template <
    typename IndexType,
    typename InType,
    typename OutType,
    bool IS_WEIGHT_POSITIONAL = false>
void EmbeddingLookupIdx(
    const std::int64_t block_size,
    const std::int64_t output_size,
    const std::int64_t index_size,
    const std::int64_t data_size,
    const InType* input,
    const IndexType* indices,
    const IndexType* offsets,
    const float* weights, // optional, can be null for non-weighted sum
    const float* scale_bias, // optional scale & bias params for uint8 input
    bool normalize_by_lengths, // divide each non-empty output row by its segment length
    OutType* out);

} // namespace caffe2