File: resharding-spmdization.mlir

// RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s

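// Each test pairs a source sharding annotation (mesh.shard) with a target
// sharding annotation (mesh.shard ... annotate_for_users); the
// -test-mesh-resharding-spmdization pass materializes the resharding as mesh
// collectives (all_slice, all_to_all, all_gather, all_reduce) over per-device
// shard types. Two 1-D meshes are used: @mesh_1d with a static size of 2 and
// @mesh_1d_dynamic with a dynamic size.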
mesh.mesh @mesh_1d(shape = 2)
mesh.mesh @mesh_1d_dynamic(shape = ?)

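// Identical source and target shardings (fully replicated): no communication
// is needed and the argument is returned unchanged.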
// CHECK-LABEL: func @same_source_and_target_sharding
func.func @same_source_and_target_sharding(
  // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
  %arg0: tensor<2xf32>
) -> tensor<2xf32> {
  %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xf32>
  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<2xf32>
  // CHECK: return %[[ARG]]
  return %1 : tensor<2xf32>
}

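// Resharding from fully replicated to split along tensor axis 1 over mesh
// axis 0 (2 devices): each device keeps a 3x7 slice, produced by
// mesh.all_slice. The unrealized_conversion_cast bridges the per-device shard
// type back to the unsharded function result type.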
// CHECK-LABEL: func @split_replicated_tensor_axis
func.func @split_replicated_tensor_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
  %arg0: tensor<3x14xf32>
) -> tensor<3x14xf32> {
  // CHECK: %[[ALL_SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 1
  // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
  // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
  %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<3x14xf32>
  %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<3x14xf32>
  // CHECK: return %[[RESULT]] : tensor<3x14xf32>
  return %1 : tensor<3x14xf32>
}

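// Splitting tensor axis 0 over a mesh of dynamic size: the split axis is
// already dynamic, so the shard type matches the argument type and
// mesh.all_slice alone suffices, with no cast.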
// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic
func.func @split_replicated_tensor_axis_dynamic(
  // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
  %arg0: tensor<?x3x?xf32>
) -> tensor<?x3x?xf32> {
  // CHECK: %[[RESULT:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] slice_axis = 0
  // CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32>
  %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[], [], []]> : tensor<?x3x?xf32>
  %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[0]]> annotate_for_users : tensor<?x3x?xf32>
  // CHECK: return %[[RESULT]] : tensor<?x3x?xf32>
  return %1 : tensor<?x3x?xf32>
}

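// Moving the split from tensor axis 0 to tensor axis 1: each 5x14 source
// shard is exchanged with mesh.all_to_all (split_axis = 1, concat_axis = 0)
// to yield a 10x7 target shard.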
// CHECK-LABEL: func @move_split_axis
func.func @move_split_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
  // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<10x14xf32>
  %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<10x14xf32>
  // CHECK: return %[[RES]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}

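// The same axis move on the dynamic-size mesh: shard extents along the split
// and concat axes are unknown, so all_to_all operates on dynamic shapes and a
// tensor.cast refines the statically known extent 10 along axis 0.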
// CHECK-LABEL: func @move_split_axis_dynamic_mesh
func.func @move_split_axis_dynamic_mesh(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
  // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
  // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32>
  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
  %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
  %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[], [0]]> annotate_for_users : tensor<10x14xf32>
  // CHECK: return %[[RES]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}

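// The currently split tensor axis 0 is dynamic, so the source shard type
// matches the argument type; all_to_all moves the split to axis 1 (14 / 2 = 7).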
// CHECK-LABEL: func @move_split_dynamic_axis
func.func @move_split_dynamic_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
  %arg0: tensor<?x14xf32>
) -> tensor<?x14xf32> {
  // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<?x14xf32>
  %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<?x14xf32>
  // CHECK: return %[[RES]] : tensor<?x14xf32>
  return %1 : tensor<?x14xf32>
}

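// From split along tensor axis 0 to fully replicated: mesh.all_gather with
// gather_axis = 0 reassembles the full 10x14 tensor from 5x14 shards.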
// CHECK-LABEL: func @unshard_static_axis
func.func @unshard_static_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<10x14xf32>
  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
  // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}

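// The same unsharding, but the split is on the last tensor axis:
// gather_axis = 1 concatenates 10x7 shards back into 10x14.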
// CHECK-LABEL: func @unshard_static_last_axis
func.func @unshard_static_last_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32>
  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32>
  %0 = mesh.shard %arg0 to <@mesh_1d, [[], [0]]> : tensor<10x14xf32>
  %1 = mesh.shard %0 to <@mesh_1d, [[], []]> annotate_for_users : tensor<10x14xf32>
  // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}

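// The split tensor axis is dynamic, so the shard and result types coincide
// and a single all_gather is enough.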
// CHECK-LABEL: func @unshard_dynamic_axis
func.func @unshard_dynamic_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
  %arg0: tensor<?x14xf32>
) -> tensor<?x14xf32> {
  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<?x14xf32>
  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<?x14xf32>
  // CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32>
  return %1 : tensor<?x14xf32>
}

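// The mesh size is dynamic, so the shard extent along the split axis is
// unknown: all_gather produces a dynamic shape and tensor.cast refines it
// back to the static tensor<10x14xf32>.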
// CHECK-LABEL: func @unshard_static_axis_on_dynamic_mesh_axis
func.func @unshard_static_axis_on_dynamic_mesh_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
  // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32>
  %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
  %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[]]> annotate_for_users : tensor<10x14xf32>
  // CHECK: return %[[RES]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}

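// A pending partial reduction (partial = sum over mesh axis 0) is resolved
// into full replication with mesh.all_reduce.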
// CHECK-LABEL: func @partial_axis_to_full_replication
func.func @partial_axis_to_full_replication(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[ALL_REDUCE:.*]] = mesh.all_reduce %[[ARG]] on @mesh_1d mesh_axes = [0] : tensor<10x14xf32> -> tensor<10x14xf32>
  %0 = mesh.shard %arg0 to <@mesh_1d, [[]], partial = sum[0]> : tensor<10x14xf32>
  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
  // CHECK: return %[[ALL_REDUCE]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}