File: infer_symbol_values.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (134 lines) | stat: -rw-r--r-- 5,043 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import re
from typing import Any, DefaultDict, Dict, List, Tuple, Union

import numpy as np
import sympy as sp

import torch


# Captures the text inside each "[...]" span that contains no nested "]",
# e.g. "shape '[s0, -1]'" -> "s0, -1".
square_brackets_pattern = r"\[([^]]+)\]"
# Non-greedy capture of the text inside "(...)", e.g. "(s0//2)" -> "s0//2".
# Being non-greedy, it stops at the first ")" and so does not handle nesting.
parentheses_pattern = r"\((.*?)\)"
# Matches symbol-name substrings of the form "s<digits>", e.g. "s0", "s12".
s_pattern = r"s\d+"


def infer_symbol_values(
    symints: List[Union[torch.SymInt, int]],
    init_symints: List[Union[torch.SymInt, int]],
    symbol_idx_dict: Dict[str, int],
    padding_constraints: DefaultDict[torch.SymInt, List[Union[sp.Expr, int]]],
    constraint: str,
) -> None:
    """Adjust ``symints`` in place so the shape check described by ``constraint`` holds.

    ``constraint`` is the text of a failed shape check (presumably a PyTorch
    error message -- verify against callers). The branch taken is selected by
    a fixed phrase in that text; text matching none of the phrases is ignored.

    Args:
        symints: Current symbolic sizes; entries are overwritten in place.
        init_symints: Initial symbolic sizes, used only by the
            "is invalid for input of size" branch to rebuild a padded size.
        symbol_idx_dict: Maps a symbol name (e.g. ``"s1"``) to its index in
            ``symints`` / ``init_symints``.
        padding_constraints: Per-symbol list whose first entry is the size
            expression and whose later entries are divisors; appended to in
            place by the "is invalid for input of size" branch.
        constraint: The failed-check message to parse.
    """
    if "non-singleton" in constraint:
        # Expects exactly two parenthesised groups: the two sizes to equate.
        left_expression, right_expression = re.findall(parentheses_pattern, constraint)
        calculate_value(left_expression, right_expression, symints, symbol_idx_dict)

    elif "first two dimensions of batch2 tensor to be" in constraint:
        # Equate the second comma-separated entry of the first two [...] lists.
        matches = re.findall(square_brackets_pattern, constraint)
        left_expression, right_expression = (
            matches[i].split(",")[1].strip() for i in (0, 1)
        )
        calculate_value(left_expression, right_expression, symints, symbol_idx_dict)

    elif "a and b must have same reduction dim" in constraint:
        # Second entry of the first [...] list vs. first entry of the second.
        matches = re.findall(square_brackets_pattern, constraint)
        left_expression = matches[0].split(",")[1].strip()
        right_expression = matches[1].split(",")[0].strip()
        calculate_value(left_expression, right_expression, symints, symbol_idx_dict)

    elif "Split sizes add up to" in constraint:
        # "... add up to <A> but ... of <B>": equate A and B.
        # NOTE(review): a failed regex match passes None on to calculate_value.
        match_1 = re.search(r"to\s+(.*?)\s+but", constraint)
        extracted_value_1 = match_1.group(1) if match_1 else None
        match_2 = re.search(r"of\s+(.*?)$", constraint)
        extracted_value_2 = match_2.group(1) if match_2 else None
        calculate_value(extracted_value_1, extracted_value_2, symints, symbol_idx_dict)

    elif "is invalid for input of size" in constraint:
        # The requested shape is the first [...] list; a -1 entry is skipped.
        matches = re.findall(square_brackets_pattern, constraint)
        minus_one = sp.sympify("-1")
        left_equation = sp.sympify(1)  # product of the symbolic entries
        left_num = 1  # product of the constant entries
        for left_element in matches[0].split(","):
            element = sp.sympify(left_element)
            if element == minus_one:
                continue
            if element.is_number:
                left_num *= int(left_element)
            else:
                left_equation *= element

        # The input size is the text after the last occurrence of "size".
        right_equation = sp.sympify(constraint.rsplit("size", 1)[1].strip())
        right_equation = sp.cancel(right_equation / left_equation)

        # Divide s0 out of the equation and keep the remaining symbols.
        # NOTE(review): s0 is treated specially -- presumably a batch-like
        # symbol; confirm. A new list is built instead of removing from the
        # list being iterated, and symbols are sorted by name so the one
        # picked below does not depend on set iteration order.
        s0 = sp.Symbol("s0")
        right_vars = []
        for right_var in sorted(right_equation.free_symbols, key=str):
            if right_var == s0:
                right_equation = sp.cancel(right_equation / right_var)
            else:
                right_vars.append(right_var)

        var = right_vars[0]
        idx = symbol_idx_dict[str(var)]
        # First sighting of this symbol: record its size expression at index 0.
        if var not in padding_constraints:
            padding_constraints[var].append(right_equation)
        update_equation(
            symints,
            init_symints,
            padding_constraints,
            padding_constraints[var][0],  # type: ignore[arg-type]
            left_num,
            var,
            idx,
        )


def calculate_value(
    left_expression: Union[str, Any, None],
    right_expression: Union[str, Any, None],
    symints: List[Union[torch.SymInt, int]],
    symbol_idx_dict: Dict[str, int],
) -> None:
    """Solve ``left_expression == right_expression`` and write the result into ``symints``.

    ``solve_equation`` returns the solved symbol's name and integer value. The
    current ``symints`` entry for that symbol is re-parsed from its string
    form with sympy, the value is substituted for the symbol, and the
    resulting sympy object replaces the entry in place.
    """
    symbol_name, solved_value = solve_equation(left_expression, right_expression)
    slot = symbol_idx_dict[symbol_name]
    as_sympy = sp.sympify(f"{symints[slot]}")
    replacement = {sp.sympify(symbol_name): solved_value}
    symints[slot] = as_sympy.subs(replacement)


def solve_equation(
    left_expression: Union[str, Any, None],
    right_expression: Union[str, Any, None],
) -> Tuple[str, int]:
    """Solve ``left_expression == right_expression`` for one ``s<digits>`` symbol.

    Both sides are formatted into the string ``"<left> - <right>"`` and parsed
    with sympy. If that string contains a parenthesised group, the first such
    group is expected to be a floor division ``<symbol>//<coeff>``. The group
    is replaced by a fresh unknown ``x``, the division is modelled as exact
    (``symbol == coeff * x``), and the two equations are solved together.

    Returns:
        ``(symbol_name, value)``, where ``symbol_name`` has no surrounding
        whitespace and ``value`` is the solution converted with ``int()``.

    Raises:
        IndexError: if the expression contains no ``s<digits>`` symbol.
        ValueError: if the first parenthesised group is not of the form ``a//b``.
    """
    expression = f"{left_expression} - {right_expression}"
    var = re.findall(s_pattern, expression)[0]
    paren_groups = re.findall(parentheses_pattern, expression)
    if not paren_groups:
        solution = sp.solve(expression, var)
        # NOTE(review): int() silently truncates a non-integer solution.
        val = int(solution[0])
        return (var, val)

    sub_expression = paren_groups[0]
    # Strip both parts so "s0 // 2" yields the name "s0", not "s0 ", which
    # callers use as a dictionary key.
    var, coeff = (part.strip() for part in sub_expression.split("//"))
    x = sp.symbols("x")
    var_symbol = sp.sympify(var)
    # Floor division is modelled as exact: var == coeff * x.
    sub_equation = var_symbol - sp.sympify(coeff) * x
    # Substitute x for the floor-division term wherever it occurs. This
    # isolates it even when it carries a coefficient other than +1
    # (e.g. "5 - (s0//2)").
    modified_equation = sp.sympify(expression).subs(sp.sympify(sub_expression), x)

    solution = sp.solve((modified_equation, sub_equation), (x, var_symbol))
    return (var, int(solution[var_symbol]))


def update_equation(
    symints: List[Union[torch.SymInt, int]],
    init_symints: List[Union[torch.SymInt, int]],
    padding_constraints: DefaultDict[torch.SymInt, List[Union[sp.Expr, int]]],
    init_eq: sp.Expr,
    new_mod_num: int,
    var: torch.SymInt,
    idx: int,
) -> None:
    """Record a new divisor for ``var`` and rebuild ``symints[idx]`` from it.

    ``padding_constraints[var]`` holds the size expression at index 0 and the
    divisors recorded so far after it; ``new_mod_num`` is appended in place.
    With ``m`` the lcm of all recorded divisors, the new size is
    ``m * init_symints[idx]``, minus ``c mod m`` when ``init_eq`` has an
    additive constant ``c``. For ``init_eq = s + c`` this makes the
    expression divisible by ``m``.

    Args:
        symints: Sizes to update; ``symints[idx]`` is overwritten.
        init_symints: Initial sizes; ``init_symints[idx]`` is the base value.
        padding_constraints: Per-symbol ``[expr, divisor, divisor, ...]`` lists.
        init_eq: The size expression whose additive constant sets the offset.
        new_mod_num: Divisor to add to the constraints for ``var``.
        var: Key into ``padding_constraints``.
        idx: Index into ``symints`` / ``init_symints``.
    """
    constraints = padding_constraints[var]
    constraints.append(new_mod_num)
    # Entry 0 is the size expression; every later entry is a divisor. Convert
    # to a Python int so no numpy scalar ends up in ``symints``.
    mod_num = int(np.lcm.reduce(constraints[1:]))  # type: ignore[arg-type]
    eq = mod_num * init_symints[idx]
    # Only an additive constant shifts the residue. Multiplicative
    # coefficients (3*s) and exponents (s**2) must not be treated as offsets,
    # so take the constant term of the sum rather than any numeric arg.
    const_term, _ = sp.sympify(init_eq).as_coeff_Add()
    if const_term != 0:
        eq -= int(const_term % mod_num)
    symints[idx] = eq