"""
Python polyfills for common builtins.
"""
# NOTE: 1. Please do not import any submodule in this directory here, to avoid
#          circular imports.
#       2. When adding a new polyfill module, also add it to
#          POLYFILLED_MODULE_NAMES in loader.py and to the TYPE_CHECKING block
#          below.
# mypy: allow-untyped-defs
from typing import Any, Callable, Sequence, TYPE_CHECKING
import torch

if TYPE_CHECKING:
    # Loaded by torch._dynamo.polyfills.loader
    # See also POLYFILLED_MODULE_NAMES in torch/_dynamo/polyfills/loader.py
    # Put the submodules here to avoid circular imports
    from . import (
        builtins as builtins,
        functools as functools,
        itertools as itertools,
        operator as operator,
        os as os,
        pytree as pytree,
        sys as sys,
    )

from torch.overrides import BaseTorchFunctionMode


# These classes handle support for TorchFunctionModes across graph breaks.
# Today, entering one of the supported TorchFunctionModes simply pushes the
# mode onto the torch function mode stack. Since the stack is mutated by this
# and we replay those mutations, no cleanup logic needs to run once a graph
# break occurs: replaying the mutations ensures the torch function mode stack
# is correct at the graph break, and the stack is reconstructed normally when
# we compile the resume function on the other side of the break.
# However, to ensure we exit properly in the resume function, we need to
# re-enter these contexts as we do other contexts. These contexts do nothing
# on enter, but provide the correct exit logic so that the stack state stays
# correct.
class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
    def __enter__(self):
        pass
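
# A minimal sketch of the intended usage (hypothetical): the replayed stack
# mutations have already pushed `mode` onto the torch function mode stack
# before the resume function runs, so the resume function only needs the
# exit logic:
#
#     with mode:  # NoEnterTorchFunctionMode: __enter__ is a no-op
#         ...     # body of the resume function
#     # __exit__ (inherited from BaseTorchFunctionMode) pops `mode`,
#     # restoring the stack to its pre-context state.
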
def index(iterator, item, start=0, end=None):
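    """
    Polyfill for the ``index`` method of sequences (e.g. ``list.index``).

    >>> index(["a", "b", "c"], "b")
    1
    >>> index(["a", "b", "c"], "d")
    Traceback (most recent call last):
        ...
    ValueError: d is not in <class 'list'>
    """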
    from itertools import islice

    for i, elem in islice(enumerate(iterator), start, end):
        if item == elem:
            return i
    # This will not run in dynamo
    raise ValueError(f"{item} is not in {type(iterator)}")


def repeat(item, count):
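    """
    Polyfill for ``itertools.repeat`` with an explicit finite ``count``.

    >>> list(repeat("x", 3))
    ['x', 'x', 'x']
    """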
    for _ in range(count):
        yield item


def radians(x):
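    """Polyfill for ``math.radians``: convert ``x`` from degrees to radians."""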
    import math

    return math.pi / 180.0 * x


def accumulate_grad(x, new_grad):
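    """
    Accumulate ``new_grad`` into ``x.grad``: the incoming gradient is cloned,
    then either stored (if ``x.grad`` is ``None``) or added in place.
    """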
    new_grad = torch.clone(new_grad)
    if x.grad is None:
        x.grad = new_grad
    else:
        x.grad.add_(new_grad)


def list_cmp(
    op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]
):
"""emulate `(1,2,3) > (1,2)` etc"""
    for a, b in zip(left, right):
        if a != b:
            return op(a, b)
    return op(len(left), len(right))


def set_isdisjoint(set1, set2):
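    """
    Polyfill for ``set.isdisjoint``.

    >>> set_isdisjoint({1, 2}, {3, 4})
    True
    >>> set_isdisjoint({1, 2}, {2, 3})
    False
    """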
    for x in set1:
        if x in set2:
            return False
    return True


def set_intersection(set1, set2):
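    """
    Polyfill for ``set.intersection``.

    >>> sorted(set_intersection({1, 2, 3}, {2, 3, 4}))
    [2, 3]
    """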
    intersection_set = set()
    for x in set1:
        if x in set2:
            intersection_set.add(x)
    return intersection_set


def set_union(set1, set2):
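    """
    Polyfill for ``set.union``.

    >>> sorted(set_union({1, 2}, {2, 3}))
    [1, 2, 3]
    """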
    union_set = set1.copy()
    set_update(union_set, set2)
    return union_set


def set_update(set1, set2):
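    """
    Polyfill for ``set.update``, except that it also returns ``set1``
    (``set.update`` returns ``None``).
    """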
    for x in set2:
        if x not in set1:
            set1.add(x)
    return set1


def set_difference(set1, set2):
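    """
    Polyfill for ``set.difference``.

    >>> sorted(set_difference({1, 2, 3}, {2}))
    [1, 3]
    """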
    difference_set = set()
    for x in set1:
        if x not in set2:
            difference_set.add(x)
    return difference_set


def dropwhile(predicate, iterable):
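    """
    Polyfill for ``itertools.dropwhile``.

    >>> list(dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]))
    [6, 4, 1]
    """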
    iterable = iter(iterable)
    for x in iterable:
        if not predicate(x):
            yield x
            break
    yield from iterable


def zip_longest(*iterables, fillvalue=None):
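    """
    Polyfill for ``itertools.zip_longest``, except that it returns a list
    instead of a lazy iterator.

    >>> zip_longest("ab", "xyz", fillvalue="-")
    [('a', 'x'), ('b', 'y'), ('-', 'z')]
    """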
    # Create a list of iterators from the input iterables
    iterators = [iter(it) for it in iterables]
    result = []
    while True:
        row = []
        active = False
        for it in iterators:
            try:
                # Try to get the next item from the iterator
                value = next(it)
                row.append(value)
                active = True
            except StopIteration:
                # If the iterator is exhausted, use the fillvalue
                row.append(fillvalue)
        if not active:
            break
        result.append(tuple(row))
    return result


def getattr_and_trace(*args, **kwargs):
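    """
    Fetch ``args[1]`` as an attribute of ``args[0]`` and call it with the
    remaining positional arguments and all keyword arguments.

    >>> getattr_and_trace("abc", "upper")
    'ABC'
    """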
    wrapper_obj = args[0]
    attr_name = args[1]
    fn = getattr(wrapper_obj, attr_name)
    return fn(*args[2:], **kwargs)


def mapping_get(obj, key, value=None):
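    """
    Polyfill for ``dict.get`` / ``Mapping.get``.

    >>> mapping_get({"a": 1}, "a")
    1
    >>> mapping_get({"a": 1}, "b", 0)
    0
    """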
    try:
        return obj.__getitem__(key)
    except KeyError:
        return value


def instantiate_user_defined_class_object(cls, /, *args, **kwargs):
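    """
    Emulate ``type.__call__``: allocate the object with ``cls.__new__`` and
    run ``cls.__init__`` only when ``__new__`` actually returned an instance
    of ``cls`` (``__new__`` may return an instance of a different class).
    """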
    obj = cls.__new__(cls, *args, **kwargs)

    # Only call __init__ if the object is an instance of the class
    # Reference: https://github.com/python/cpython/blob/3.12/Objects/typeobject.c#L1670-L1673
    if isinstance(obj, cls):
        obj.__init__(*args, **kwargs)
    return obj


def foreach_lerp_inplace(self, end, weight):
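    """
    In-place polyfill for ``torch._foreach_lerp_``, decomposed as
    ``self_i += weight * (end_i - self_i)`` over the tensor lists.
    """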
    # decompose foreach lerp into constituent ops, prevents a graph break due to
    # converting a value to a scalar when arg[2] is a single tensor
    result = torch._foreach_sub(end, self)
    result = torch._foreach_mul(result, weight)
    return torch._foreach_add_(self, result)


def foreach_pow_scalar(scalar, exps):
    return torch._foreach_pow([scalar for _ in exps], exps)


def addcmul_inplace(self, tensor1, tensor2, value):
    return self.add_(tensor1 * tensor2 * value)


def predicate(obj: Any) -> bool:
    # This will cause the rest of dynamo to handle the if statement correctly,
    # so we don't have to rewrite it here. We can't just use bool() here since
    # we can't trace into that in general.
    if obj:
        return True
    return False


def object_eq(self, other):
    # Mirrors the CPython implementation:
    # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L6228-L6233
    return self is other


def object_ne(self, other):
    # Mirrors the CPython implementation:
    # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L6235-L6255
    # Using `==` is important because `self` might have a user-defined `__eq__`.
    return not (self == other)