1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
|
/*************************************************************************
* Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#pragma once
#include <algorithm>
#include <cstdlib>
#include <string>

#include "ErrCode.hpp"
#include "rccl/rccl.h"
#include "rccl_bfloat16.h"
#include "hip/hip_fp16.h"
namespace RcclUnitTesting
{
// Applies one of the basic binary reduction operators to two values.
//   op   : reduction to apply; only ncclSum, ncclProd, ncclMax and ncclMin are
//          handled here (other operators, e.g. ncclAvg, are presumably composed
//          by callers from these - confirm at call sites)
//   A, B : operands
// Returns op(A, B).
// Any other operator is reported through ERROR and terminates the process with
// a non-zero exit status, so the failure cannot be mistaken for a passing run.
template <typename T>
T ReduceOp(ncclRedOp_t const op, T const A, T const B)
{
  switch (op)
  {
  case ncclSum:  return A + B;
  case ncclProd: return A * B;
  case ncclMax:  return std::max(A, B);
  case ncclMin:  return std::min(A, B);
  default:
    // Explicit int conversion keeps the "%d" varargs argument well-defined
    ERROR("Unsupported reduction operator (%d)\n", static_cast<int>(op));
    // A zero status would signal success to whatever launched the tests
    exit(EXIT_FAILURE);
  }
}
size_t DataTypeToBytes(ncclDataType_t const dataType);
// PtrUnion encapsulates a single pointer that can be viewed as any of the
// supported datatypes; all members share the same storage, so each typed
// member is an alias for the same address held in `ptr`.
// NOTE: Currently half-precision float tests are unsupported because half is
// supported on GPU only and not on the host.
// NOTE(review): F2 is still declared below; confirm in the .cpp which helpers
// actually handle ncclFloat16.
// NOTE(review): member order matters - brace-initialisation targets the first
// member (`ptr`) - so do not reorder members.
// Member functions are defined out of line; all ErrCode-returning members
// report their status through the returned ErrCode.
union PtrUnion
{
void* ptr;
int8_t* I1; // ncclInt8
uint8_t* U1; // ncclUint8
int32_t* I4; // ncclInt32
uint32_t* U4; // ncclUint32
int64_t* I8; // ncclInt64
uint64_t* U8; // ncclUint64
__half* F2; // ncclFloat16
float* F4; // ncclFloat32
double* F8; // ncclFloat64
rccl_bfloat16* B2; // ncclBfloat16
// Points this union at an existing address / another PtrUnion's address
// (presumably non-owning - confirm in the .cpp)
ErrCode Attach(void *ptr);
ErrCode Attach(PtrUnion ptrUnion);
// Allocate / release / zero numBytes of device or host memory.
// useManagedMem presumably selects managed (unified) memory - confirm.
ErrCode AllocateGpuMem(size_t const numBytes, bool const useManagedMem = false);
ErrCode AllocateCpuMem(size_t const numBytes);
ErrCode FreeGpuMem();
ErrCode FreeCpuMem();
ErrCode ClearGpuMem(size_t const numBytes);
ErrCode ClearCpuMem(size_t const numBytes);
// Fills numElements elements of dataType with a test pattern.
// The pattern presumably depends on globalRank; isGpuMem indicates whether
// ptr refers to GPU memory - confirm in the .cpp.
ErrCode FillPattern(ncclDataType_t const dataType,
size_t const numElements,
int const globalRank,
bool const isGpuMem);
// Write / read element idx interpreted as dataType.
// Presumably integer types use valueI and floating-point types use valueF - confirm.
ErrCode Set(ncclDataType_t const dataType, int const idx, int valueI, double valueF);
ErrCode Get(ncclDataType_t const dataType, int const idx, int& valueI, double& valueF) const;
// Multiplies each element in place by scalarsPerRank[rank]
ErrCode Scale(ncclDataType_t const dataType,
size_t const numElements,
PtrUnion const& scalarsPerRank,
int const rank);
// Reduces inputCpu into this PtrUnion element-wise using op
// (presumably via ReduceOp - confirm in the .cpp)
ErrCode Reduce(ncclDataType_t const dataType,
size_t const numElements,
PtrUnion const& inputCpu,
ncclRedOp_t const op);
// Divides each element by an integer value
ErrCode DivideByInt(ncclDataType_t const dataType,
size_t const numElements,
int const divisor);
// Compares for equality (fuzzy comparison for floating-point types);
// the result is written to isMatch. verbose presumably enables mismatch reporting.
ErrCode IsEqual(ncclDataType_t const dataType,
size_t const numElements,
PtrUnion const& expected,
bool const verbose,
bool& isMatch);
// Output to string (for debug)
std::string ToString(ncclDataType_t const dataType,
size_t const numElements) const;
};
}
|