1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
|
#pragma once
#include <c10/util/Exception.h>
namespace torch {
namespace jit {
namespace fuser {
// Common Functions
// Integer division of `a` by `b`, rounded towards positive infinity.
//
// Unlike the common `(a + b - 1) / b` idiom, this form cannot overflow when
// `a` is close to INT64_MAX, and it also rounds correctly when `a` and/or `b`
// are negative. Results for positive, non-overflowing inputs are unchanged.
//
// Preconditions: b != 0, and not (a == INT64_MIN && b == -1); the latter
// overflows in any formulation of integer division.
constexpr int64_t ceilDiv(int64_t a, int64_t b) {
  // C++ integer division truncates towards zero, so the truncated quotient
  // must be bumped up by one exactly when the remainder is nonzero and the
  // true quotient is positive, i.e. the remainder (which carries the sign of
  // `a`) has the same sign as the divisor.
  return a / b + ((a % b != 0 && ((a % b > 0) == (b > 0))) ? 1 : 0);
}
// Mixin that disables copying, and therefore also moving, for every class
// that derives from it. Usage:
//
//   class Foo : public NonCopyable {
//     ...
//   };
//
class NonCopyable {
 public:
  NonCopyable() = default;

  // Copying is forbidden.
  NonCopyable(const NonCopyable&) = delete;
  NonCopyable& operator=(const NonCopyable&) = delete;

  // Moving is forbidden as well. These two declarations only spell out what
  // the deleted copy operations above already imply: a user-declared copy
  // constructor suppresses the implicitly-declared move operations, so
  // rvalues would fall back to the (deleted) copy operations anyway.
  NonCopyable(NonCopyable&&) = delete;
  NonCopyable& operator=(NonCopyable&&) = delete;
};
// A generic root for a hierarchy of polymorphic classes:
// - It ensures virtual destructors
// - Provides the base->as<Derived>() and node->isA<T>() notation
class PolymorphicBase {
 public:
  virtual ~PolymorphicBase() = default;

  // Replacement for static_cast<T*>(ptr): ptr->as<T>()
  //
  // When NDEBUG is not defined, this is a dynamic_cast and a failed downcast
  // trips TORCH_INTERNAL_ASSERT. When NDEBUG is defined, it is an unchecked
  // static_cast. In that case calling as<T>() on an object whose runtime type
  // is not T (or derived from T) is undefined behavior.
  template <class T>
  T* as() {
    // Delegate to the const overload so the checked/unchecked logic lives in
    // one place. Casting the const away again is safe because this overload
    // is only callable on a non-const object.
    return const_cast<T*>(static_cast<const PolymorphicBase*>(this)->as<T>());
  }

  template <class T>
  const T* as() const {
#ifdef NDEBUG
    auto downcast_ptr = static_cast<const T*>(this);
#else
    auto downcast_ptr = dynamic_cast<const T*>(this);
    TORCH_INTERNAL_ASSERT(downcast_ptr != nullptr);
#endif
    return downcast_ptr;
  }

  // Check if the runtime type is T (or derived from T)
  //
  // NOTE: Don't use this for conditional casts. Use:
  //
  //   if (auto t = dynamic_cast<T*>(p)) { ... }
  //
  // instead of:
  //
  //   if (p->isA<T>()) { auto t = p->as<T>(); ... }
  //
  template <class T>
  bool isA() const {
    return dynamic_cast<const T*>(this) != nullptr;
  }
};
} // namespace fuser
} // namespace jit
} // namespace torch
|