File: node.py

#################################################################################################
#
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Base and visitor classes of DAG IR nodes
"""

import ctypes
from re import sub

from cutlass_library import LayoutType

from cutlass.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
from cutlass.backend.evt.ir.tensor import Tensor


class ImplBase:
    """
    Base class for Node Implementation
    """
    def __init__(self, node) -> None:
        self.node = node
        self.name = node.name
        self.tensor = node.tensor
        self._type_decl = None
        self.stride_dtype = "int64_t"

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match function used in get_underlying_impl
        """
        raise NotImplementedError("The `match` function must be overridden in derived classes.")

    @property
    def argument_type(self):
        """
        Default class for Argument Type
        """
        class _Argument(ctypes.Structure):
            _fields_ = []

            def __init__(self, *args, **kwargs) -> None:
                pass

        return _Argument

    @property
    def name_camel(self) -> str:
        """
        Return the CamelCase name.
        """
        return sub(r"(_|-)+", " ", self.name).title().replace(" ", "")

    def _emit_cute_tuple(self, py_tuple):
        """
        Emit the cute tuple to C++ code
        """
        if isinstance(py_tuple, int):
            if py_tuple in [0, 1]:
                return f"cute::Int<{py_tuple}>"
            else:
                return f"{self.stride_dtype}"
        elif isinstance(py_tuple, tuple):
            decl = "cute::Stride<"
            for item in py_tuple:
                decl += self._emit_cute_tuple(item) + ", "
            return decl[:-2] + ">"
        else:
            raise ValueError(f"_emit_cute_tuple only accepts tuple or int, got {type(py_tuple).__name__}")

    @property
    def stride_mnl(self):
        """
        Typename StrideMNL
        """
        stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
        return self._emit_cute_tuple(stride)

    def get_non_constant_stride(self, py_tuple):
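        """
        Return `py_tuple` with all constant (0/1) strides removed,
        or None if `py_tuple` is itself a constant integer stride.
        """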
        if isinstance(py_tuple, int):
            if py_tuple not in [0, 1]:
                return py_tuple
            else:
                return None
        non_constant_stride = []
        for item in py_tuple:
            item_out = self.get_non_constant_stride(item)
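            # item_out is None for a constant stride and may be an empty tuple
            # if a nested stride was entirely constant; skip both cases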
            if item_out:
                non_constant_stride.append(item_out)
        return tuple(non_constant_stride)

    def get_stride_mnl(self):
        """
        Get the stride in (M, N, L) order. This is used in argument construction.
        """
        stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
        return stride

    def get_smem_size(self, *args, **kwargs):
        """
        Get the shared memory size and alignment of the current node
        """
        return (0, 1)


class NoOpImpl(ImplBase):
    """
    NoOpImpl does nothing but forward its input to its users
    """
    def __init__(self, node) -> None:
        super().__init__(node)

    @staticmethod
    def match(node, problem_size: tuple):
        if node.op == "store":
            # A store node that is not a graph output is a no-op
            return not node.is_output
        return False


class NodeBase:
    """
    Base class of DAG Node
    """
    def __init__(self, name: str) -> None:
        self.name = name
        self.underlying_impl = None

        self._tensor = None

        # Whether the node is disabled for emit
        self.disabled = False

    @property
    def name_camel(self) -> str:
        """
        Return the CamelCase name.
        """
        return self.underlying_impl.name_camel

    @property
    def tensor(self) -> Tensor:
        """
        Return the output tensor (concept: cutlass.backend.evt.ir.tensor)
        """
        return self._tensor

    @tensor.setter
    def tensor(self, kwargs):
        """
        Set the output tensor from a dict of Tensor constructor arguments
        """
        self._tensor = Tensor(**kwargs)

    #
    # Helper functions for type/shape propagation
    #

    def shape_propagation(self, input_node_metas):
        """
        Infer the output shape from the input nodes, following NumPy's
        general broadcasting rules: shapes are compared element-wise,
        starting with the trailing (i.e. rightmost) dimension and working
        left. Two dimensions are compatible when
        1. they are equal, or
        2. one of them is 1.
        """
        if self._tensor is not None:
            return

        shape = None
        for src in input_node_metas:
            src_shape = src.tensor.shape
            if shape is None:
                shape = src_shape
            else:
                len_difference = len(shape) - len(src_shape)
                if len_difference > 0:
                    for _ in range(len_difference):
                        src_shape = [1, ] + list(src_shape)
                elif len_difference < 0:
                    for _ in range(-len_difference):
                        shape = [1, ] + list(shape)
                broadcasted_shape = []
                # Infer broadcast shape
                for shape_dim, src_dim in zip(reversed(shape), reversed(src_shape)):
                    if shape_dim == 1:
                        broadcasted_shape = [src_dim, ] + list(broadcasted_shape)
                    elif src_dim == 1:
                        broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
                    elif shape_dim == src_dim:
                        broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
                    else:
                        error_msg = "Dimension mismatch between "
                        for src_ in input_node_metas:
                            error_msg += f"{src_.name}{src_.tensor.shape}, "
                        error_msg = error_msg[:-2] + "."
                        raise RuntimeError(error_msg)
                shape = tuple(broadcasted_shape)

        self._tensor = Tensor(element=self.element_output, shape=shape, layout_tag=LayoutType.RowMajor)

    def type_propagation(self, *args, **kwargs):
        """
        Each node is associated with two data types: `element` and `element_output`.
        The `element_output` is the type of return array of the node. The `element`
        has specific meaning for different node types.
        * Load Node: data type of tensor in gmem
        * Compute Node: element compute
        * Store Node: data type of tensor in gmem
        This function must be overloaded in the derived classes
        """
        raise NotImplementedError(f"Function `type_propagation` is not overloaded in {self.__class__.__name__}")

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Propagate the broadcast in the reversed topological order.
        For example:
            C[l, m, n] = A[m, 1] + B[l, m, n]
        After the broadcast propagation, it will be come
            C[l, m, n] = A[l, m, n] + B[l, m, n]
        and each tensor will have a proper stride accessing the underlying tensor
        """
        if self.tensor is None:
            raise RuntimeError(f"The tensor of node {self.name} is unknown.")
        for child in input_node_metas:
            child.tensor.broadcast(self.tensor.shape)

    def get_underlying_impl(self, problem_size: tuple):
        """
        Get the underlying implementation of the current node.
        """
        if self.tensor is None:
            raise RuntimeError(f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first.")

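        # Try each candidate implementation in order; the first match wins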
        for impl in self.possible_impls:
            if impl.match(self, problem_size):
                self.underlying_impl = impl(self)
                break

        if self.underlying_impl is None:
            raise NotImplementedError(f"No matching op for node {self.name} with stride {self.tensor.stride}.")

#
# Visitor Nodes & Impls
#

class TopoVisitorImpl(ImplBase):
    """
    Impl for topological visitor
    """
    def __init__(self, node) -> None:
        super().__init__(node.output_node)
        self.name = node.name
        self.element_output = node.output_node.element_output

class TopoVisitorNode(NodeBase):
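    """
    DAG node that wraps a subgraph emitted as a topological visitor
    """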
    def __init__(self, name: str, subgraph, output_node) -> None:
        super().__init__(name)
        self.subgraph = subgraph
        self.output_node = output_node
        self.op = "dag"
        self.underlying_impl = TopoVisitorImpl(self)