File: PtrUnion.hpp

package info (click to toggle)
rccl 5.4.3-3
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 4,332 kB
  • sloc: cpp: 33,357; ansic: 6,717; xml: 5,265; makefile: 508; sh: 365; awk: 243; python: 85
file content (98 lines) | stat: -rw-r--r-- 3,504 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
/*************************************************************************
 * Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#pragma once
#include <algorithm>
#include <cstdlib>
#include <string>

#include "ErrCode.hpp"
#include "rccl/rccl.h"
#include "rccl_bfloat16.h"
#include "hip/hip_fp16.h"

namespace RcclUnitTesting
{
  // Applies one of the basic reduction operators to a pair of values.
  //
  //   op   : reduction to perform (ncclSum, ncclProd, ncclMax or ncclMin)
  //   A, B : operands
  //
  // Returns A + B, A * B, std::max(A, B) or std::min(A, B) respectively, so T
  // must provide operator+, operator* and operator< (the latter is what
  // std::max / std::min use).
  //
  // Any other operator value is reported through ERROR() and the process is
  // terminated with a non-zero exit status, so the failure cannot be mistaken
  // for a successful run.
  template <typename T>
  T ReduceOp(ncclRedOp_t const op, T const A, T const B)
  {
    switch (op)
    {
    case ncclSum:  return A + B;
    case ncclProd: return A * B;
    case ncclMax:  return std::max(A, B);
    case ncclMin:  return std::min(A, B);
    default:
      // Cast explicitly so the variadic argument matches the %d conversion
      // regardless of the enum's underlying type.
      ERROR("Unsupported reduction operator (%d)\n", static_cast<int>(op));
      exit(EXIT_FAILURE);
    }
  }

  // Number of bytes associated with dataType. Defined out of line; presumably
  // the size of a single element of that type (TODO confirm in the .cpp).
  size_t DataTypeToBytes(ncclDataType_t const dataType);

  // PtrUnion encapsulates a pointer of all the different supported datatypes.
  // Because this is a union, every member shares the same storage: the typed
  // members are different views of one stored address, and the trailing
  // comment on each member names the ncclDataType_t it corresponds to.
  //
  // NOTE: Half-precision float (ncclFloat16 / F2) tests are currently
  //       unsupported because __half is supported on the GPU only and not on
  //       the host.
  union PtrUnion
  {
    void*          ptr;
    int8_t*        I1; // ncclInt8
    uint8_t*       U1; // ncclUint8
    int32_t*       I4; // ncclInt32
    uint32_t*      U4; // ncclUint32
    int64_t*       I8; // ncclInt64
    uint64_t*      U8; // ncclUint64
    __half*        F2; // ncclFloat16
    float*         F4; // ncclFloat32
    double*        F8; // ncclFloat64
    rccl_bfloat16* B2; // ncclBfloat16

    // Make this union refer to a raw address, or to the same address held by
    // another PtrUnion. Presumably non-owning (no allocation/copy of the
    // pointed-to data) -- TODO confirm in the implementation.
    ErrCode Attach(void *ptr);
    ErrCode Attach(PtrUnion ptrUnion);

    // Allocate numBytes of GPU or CPU memory and store the pointer in this
    // union. useManagedMem defaults to false; presumably it requests managed
    // memory for the GPU allocation -- confirm.
    ErrCode AllocateGpuMem(size_t const numBytes, bool const useManagedMem = false);
    ErrCode AllocateCpuMem(size_t const numBytes);

    // Release the GPU / CPU memory this union points to. Assumed to pair with
    // the matching Allocate*Mem call above -- confirm.
    ErrCode FreeGpuMem();
    ErrCode FreeCpuMem();

    // Clear numBytes of GPU / CPU memory starting at this pointer
    // (presumably a zero-fill -- confirm).
    ErrCode ClearGpuMem(size_t const numBytes);
    ErrCode ClearCpuMem(size_t const numBytes);

    // Fill numElements elements of type dataType with a test pattern.
    // globalRank is an input to the pattern (presumably so each rank's data is
    // distinguishable), and isGpuMem states whether this pointer refers to GPU
    // memory.
    ErrCode FillPattern(ncclDataType_t const dataType,
                        size_t         const numElements,
                        int            const globalRank,
                        bool           const isGpuMem);

    // Write / read the element at index idx, interpreted as dataType.
    // Values cross the interface as an integer (valueI) and a floating-point
    // value (valueF); presumably only the one matching dataType is used.
    // Get returns the values through its reference parameters.
    ErrCode Set(ncclDataType_t const dataType, int const idx, int valueI, double valueF);
    ErrCode Get(ncclDataType_t const dataType, int const idx, int& valueI, double& valueF) const;

    // Multiplies in-place each element by scalarsPerRank[rank]
    // (numElements elements of type dataType).
    ErrCode Scale(ncclDataType_t const  dataType,
                  size_t         const  numElements,
                  PtrUnion       const& scalarsPerRank,
                  int            const  rank);

    // Reduces inputCpu into this PtrUnion using op, over numElements elements
    // of type dataType. Per its name, inputCpu is expected to point to host
    // memory -- confirm.
    ErrCode Reduce(ncclDataType_t const  dataType,
                   size_t         const  numElements,
                   PtrUnion       const& inputCpu,
                   ncclRedOp_t    const  op);

    // Divide each element by an integer value (divisor).
    ErrCode DivideByInt(ncclDataType_t const dataType,
                        size_t         const numElements,
                        int            const divisor);

    // Compares numElements elements against expected for equality (fuzzy
    // comparison for floating-point types). The outcome is written to
    // isMatch; verbose presumably enables reporting of mismatches.
    ErrCode IsEqual(ncclDataType_t const  dataType,
                    size_t         const  numElements,
                    PtrUnion       const& expected,
                    bool           const  verbose,
                    bool&                 isMatch);

    // Formats numElements elements of type dataType as a string (for debugging).
    std::string ToString(ncclDataType_t const  dataType,
                         size_t         const  numElements) const;
  };
}