// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect --mhlo-broadcast-propagation --canonicalize --cse %s | FileCheck %s

// Shape computations shall be reified.
// CHECK-LABEL: @shape_of_unary
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>)
func @shape_of_unary(%arg : tensor<?x32xi16>) {
  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<?x32xi16> -> tensor<?xindex>
  // CHECK: "use"(%[[SHAPE]])
  %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
  %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
  "use"(%1) : (tensor<?xindex>) -> ()
  return
}

// -----

// Shape computations shall be reified.
// CHECK-LABEL: @shape_of_nary
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>)
func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor<?x32xf16> -> tensor<?xindex>
  // CHECK: "use"(%[[SHAPE]])
  %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
  %1 = mhlo.subtract %0, %arg1 : tensor<?x32xf16>
  %2 = shape.shape_of %1 : tensor<?x32xf16> -> tensor<?xindex>
  "use"(%2) : (tensor<?xindex>) -> ()
  return
}

// -----

// Broadcasts can be moved up over unary shape-preserving operations.
// CHECK-LABEL: @bcast_unary
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>, %[[OUT_DIMS:.*]]: tensor<3xindex>)
func @bcast_unary(%arg : tensor<?x32xi16>, %out_dims : tensor<3xindex>)
    -> tensor<?x?x32xf16> {
  // CHECK:      %[[BCASTED_OPERAND:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[OUT_DIMS]])
  // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xi16>, tensor<3xindex>) -> tensor<?x?x32xi16>
  // CHECK:      "mhlo.convert"(%[[BCASTED_OPERAND]]) : (tensor<?x?x32xi16>) -> tensor<?x?x32xf16>
  %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
  %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %out_dims) {
      broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } :
      (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
  return %1 : tensor<?x?x32xf16>
}

// -----

// Broadcasts can be moved up over n-ary shape-preserving operations.
// CHECK-LABEL: @bcast_nary
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf32>, %[[ARG1:.*]]: tensor<?x32xf32>, %[[OUT_DIMS:.*]]: tensor<3xindex>)
func @bcast_nary(%arg0 : tensor<?x32xf32>, %arg1 : tensor<?x32xf32>,
    %out_dims : tensor<3xindex>) -> tensor<?x?x32xf32> {
  // CHECK-NOT: subtract
  // CHECK:     %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[OUT_DIMS]])
  // CHECK:     %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[OUT_DIMS]])
  // CHECK:     %{{.*}} = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]] : tensor<?x?x32xf32>
  %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf32>
  %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %out_dims) {
      broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } :
      (tensor<?x32xf32>, tensor<3xindex>) -> tensor<?x?x32xf32>
  return %1 : tensor<?x?x32xf32>
}

// -----

// Exemplary IR as it appears in the lowering of `tf.Sub` and `tf.Cast`.
// CHECK-LABEL: @cast_sub
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xi16>, %[[ARG1:.*]]: tensor<?x?x32xf16>) -> tensor<?x?x32xf16>
func @cast_sub(%arg0: tensor<?x32xi16>, %arg1: tensor<?x?x32xf16>)
    -> tensor<?x?x32xf16> {
  // CHECK-NOT: convert
  // CHECK:     %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %{{.*}})
  // CHECK:     %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %{{.*}})
  // CHECK:     %[[CONVERTED_BCASTED_ARG0:.*]] = "mhlo.convert"(%[[BCASTED_ARG0]]) : (tensor<?x?x32xi16>) -> tensor<?x?x32xf16>
  // CHECK:     %{{.*}} = mhlo.subtract %[[BCASTED_ARG1]], %[[CONVERTED_BCASTED_ARG0]] : tensor<?x?x32xf16>
  %0 = "mhlo.convert"(%arg0) : (tensor<?x32xi16>) -> tensor<?x32xf16>
  %1 = shape.shape_of %arg1 : tensor<?x?x32xf16> -> tensor<?xindex>
  %2 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
  %3 = shape.cstr_broadcastable %1, %2 : tensor<?xindex>, tensor<?xindex>
  %4 = shape.assuming %3 -> (tensor<?x?x32xf16>) {
    %5 = shape.shape_of %arg1 : tensor<?x?x32xf16> -> tensor<?xindex>
    %6 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
    %7 = shape.broadcast %5, %6 : tensor<?xindex>, tensor<?xindex>
        -> tensor<?xindex>
    %8 = tensor.cast %7 : tensor<?xindex> to tensor<3xindex>
    %9 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %8) {
        broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} :
        (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
    %10 = "mhlo.dynamic_broadcast_in_dim"(%0, %8) {
        broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} :
        (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
    %11 = mhlo.subtract %9, %10 : tensor<?x?x32xf16>
    shape.assuming_yield %11 : tensor<?x?x32xf16>
  }
  return %4 : tensor<?x?x32xf16>
}

// -----

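// Operands of a shape broadcast can be inlined into a subsequent
// broadcastability constraint.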
// CHECK-LABEL: @inline_bcasted_shape_operands
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>)
func @inline_bcasted_shape_operands(%a : tensor<?xindex>, %b : tensor<?xindex>,
    %c : tensor<?xindex>) -> !shape.witness {
  // CHECK-NOT: shape.broadcast
  // CHECK:     %[[WITNESS:.*]] = shape.cstr_broadcastable %[[A]], %[[B]], %[[C]]
  // CHECK:     return %[[WITNESS]] : !shape.witness
  %0 = shape.broadcast %a, %b : tensor<?xindex>, tensor<?xindex>
      -> tensor<?xindex>
  %1 = shape.cstr_broadcastable %0, %c : tensor<?xindex>, tensor<?xindex>
  return %1 : !shape.witness
}

// -----

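// `shape.shape_of` can be moved into an assuming region when its operand is
// defined inside of it.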
// CHECK-LABEL: @move_shape_of_into_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?x32xf32>)
func @move_shape_of_into_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<?x32xf32>) -> tensor<3xindex> {
  // CHECK:     %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<?x32xf32>, tensor<?x32xf32>, tensor<3xindex>) {
  // CHECK:       %[[DUMMY_TENSOR:.*]] = "dummy.tensor"() : () -> tensor<?x32xf32>
  // CHECK:       %[[SHAPE:.*]] = shape.shape_of %[[DUMMY_TENSOR]]
  // CHECK:       shape.assuming_yield %[[ARG1]], %[[DUMMY_TENSOR]], %[[SHAPE]]
  // CHECK:     }
  // CHECK-NOT: shape_of
  // CHECK:     return %[[ASSUMING_RESULTS]]#2
  %0:2 = shape.assuming %arg0 -> (tensor<?x32xf32>, tensor<?x32xf32>) {
    %1 = "dummy.tensor"() : () -> tensor<?x32xf32>
    shape.assuming_yield %arg1, %1 : tensor<?x32xf32>, tensor<?x32xf32>
  }
  %2 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
  "use"(%0#0, %0#1) : (tensor<?x32xf32>, tensor<?x32xf32>) -> ()
  return %2 : tensor<3xindex>
}

// -----

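// `shape.cstr_broadcastable` can be moved into an assuming region when one of
// its operands is defined inside of it.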
// CHECK-LABEL: @move_cstr_broadcastable_into_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>)
func @move_cstr_broadcastable_into_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<2xindex>) -> !shape.witness {
  // CHECK:     %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<2xindex>, tensor<3xindex>, !shape.witness) {
  // CHECK:       %[[DUMMY_TENSOR:.*]] = "dummy.tensor"() : () -> tensor<3xindex>
  // CHECK:       %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[DUMMY_TENSOR]]
  // CHECK:       shape.assuming_yield %[[ARG1]], %[[DUMMY_TENSOR]], %[[WITNESS]]
  // CHECK:     }
  // CHECK-NOT: cstr_broadcastable
  // CHECK:     return %[[ASSUMING_RESULTS]]#2
  %0:2 = shape.assuming %arg0 -> (tensor<2xindex>, tensor<3xindex>) {
    %1 = "dummy.tensor"() : () -> tensor<3xindex>
    shape.assuming_yield %arg1, %1 : tensor<2xindex>, tensor<3xindex>
  }
  %1 = shape.cstr_broadcastable %arg1, %0#1 : tensor<2xindex>, tensor<3xindex>
  "use"(%0#0, %0#1) : (tensor<2xindex>, tensor<3xindex>) -> ()
  return %1 : !shape.witness
}

// -----

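// Do not move `shape.shape_of` into the assuming region when its operand is
// defined outside of it.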
// CHECK-LABEL: @not_move_shape_of_into_assuming
func @not_move_shape_of_into_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<?x32xf32>, %arg2 : tensor<?x32xf32>) -> tensor<3xindex> {
  // CHECK:      shape.assuming
  // CHECK-SAME: {
  // CHECK-NOT:    shape_of
  // CHECK:      }
  // CHECK:     "some.other.op"
  // CHECK:     shape_of
  %0:2 = shape.assuming %arg0 -> (tensor<?x32xf32>, tensor<?x32xf32>) {
    shape.assuming_yield %arg1, %arg2 : tensor<?x32xf32>, tensor<?x32xf32>
  }
  "some.other.op"() : () -> ()
  %2 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
  return %2 : tensor<3xindex>
}

// -----

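// `shape.cstr_broadcastable` can be moved out of an assuming region when all
// its operands are defined outside of it.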
// CHECK-LABEL: @move_cstr_broadcastable_out_of_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>, %[[ARG2:.*]]: tensor<3xindex>)
func @move_cstr_broadcastable_out_of_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<2xindex>, %arg2 : tensor<3xindex>) -> !shape.witness {
  // CHECK:     %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
  // CHECK-NOT: assuming
  // CHECK-NOT: cstr_broadcastable
  // CHECK:     return %[[WITNESS]]
  %0 = shape.assuming %arg0 -> (!shape.witness) {
    %1 = shape.cstr_broadcastable %arg1, %arg2 : tensor<2xindex>, tensor<3xindex>
    shape.assuming_yield %1 : !shape.witness
  }
  return %0 : !shape.witness
}

// -----

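// Element-wise ops can be moved into an assuming region when they depend on
// one of its results.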
// CHECK-LABEL: @move_elementwise_into_assuming
// CHECK-SAME:  (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?xf32>)
func @move_elementwise_into_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:     %[[RES:.*]] = shape.assuming %[[ARG0]]
  // CHECK:       %[[SOME:.*]] = "some.op"
  // CHECK:       %[[TANH:.*]] = "mhlo.tanh"(%[[ARG1]])
  // CHECK:       %[[BCAST_ADD:.*]] = chlo.broadcast_add %[[TANH]], %[[SOME]]
  // CHECK:       shape.assuming_yield %[[BCAST_ADD]]
  // CHECK-NOT: tanh
  // CHECK-NOT: broadcast_add
  // CHECK:     return %[[RES]]
  %0:2 = shape.assuming %arg0 -> (tensor<?xf32>, tensor<?xf32>) {
    %1 = "some.op"() : () -> tensor<?xf32>
    shape.assuming_yield %arg1, %1 : tensor<?xf32>, tensor<?xf32>
  }
  %1 = "mhlo.tanh"(%arg1) : (tensor<?xf32>) -> tensor<?xf32>
  %2 = chlo.broadcast_add %1, %0#1
      : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  return %2 : tensor<?xf32>
}

// -----

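// `shape.shape_of` can be moved out of an assuming region when its operand is
// defined outside of it.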
// CHECK-LABEL: @move_shape_of_out_of_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
func @move_shape_of_out_of_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<2x?xf32>) -> tensor<2xindex> {
  // CHECK:     %[[SHAPE:.*]] = shape.shape_of %[[ARG1]]
  // CHECK-NOT: assuming
  // CHECK-NOT: cstr_broadcastable
  // CHECK:     return %[[SHAPE]]
  %0 = shape.assuming %arg0 -> (tensor<2xindex>) {
    %1 = shape.shape_of %arg1 : tensor<2x?xf32> -> tensor<2xindex>
    shape.assuming_yield %1 : tensor<2xindex>
  }
  return %0 : tensor<2xindex>
}

// -----

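// The shape computation is also moved out when the assuming region yields
// additional results.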
// CHECK-LABEL: @move_shape_of_out_of_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
func @move_shape_of_out_of_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<2x?xf32>) -> tensor<2xindex> {
  // CHECK:     %[[SHAPE:.*]] = shape.shape_of %[[ARG1]]
  // CHECK:     %{{.*}} = shape.assuming %[[ARG0]] -> (tensor<2x?xf32>) {
  // CHECK:       %[[SOME_VAL:.*]] = "some.op"() : () -> tensor<2x?xf32>
  // CHECK:       shape.assuming_yield %[[SOME_VAL]] : tensor<2x?xf32>
  // CHECK:     }
  // CHECK:     return %[[SHAPE]]
  %0:2 = shape.assuming %arg0 -> (tensor<2x?xf32>, tensor<2xindex>) {
    %1 = "some.op"() : () -> (tensor<2x?xf32>)
    %2 = shape.shape_of %arg1 : tensor<2x?xf32> -> tensor<2xindex>
    shape.assuming_yield %1, %2 : tensor<2x?xf32>, tensor<2xindex>
  }
  "use"(%0#0, %0#1) : (tensor<2x?xf32>, tensor<2xindex>) -> ()
  return %0#1 : tensor<2xindex>
}

// -----

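// Do not move `shape.shape_of` out of the assuming region when its operand is
// defined inside of it.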
// CHECK-LABEL: @not_move_shape_of_out_of_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
func @not_move_shape_of_out_of_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<2x?xf32>) -> tensor<2xindex> {
  // CHECK-NOT:  shape_of
  // CHECK:      shape.assuming
  // CHECK-SAME: {
  // CHECK:        "some.tensor"
  // CHECK:        shape_of
  // CHECK:      }
  %0 = shape.assuming %arg0 -> (tensor<2xindex>) {
    %1 = "some.tensor"() : () -> tensor<2x?xf32>
    %2 = shape.shape_of %1 : tensor<2x?xf32> -> tensor<2xindex>
    shape.assuming_yield %2 : tensor<2xindex>
  }
  return %0 : tensor<2xindex>
}

// -----

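// Subsequent assuming regions can be merged into one under a combined witness.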
// CHECK-LABEL: @merge_assuming_ops
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
func @merge_assuming_ops(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
    %arg2: tensor<?x?x32xf16>) -> tensor<?x?x32xf16> {
  // CHECK:      %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]]
  // CHECK:      %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]]
  // CHECK:      %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
  // CHECK:      %[[WITNESS0:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]]
  // CHECK:      %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]]
  // CHECK:      %[[COMBINED_WITNESS:.*]] = shape.assuming_all %[[WITNESS0]], %[[WITNESS1]]
  // CHECK:      %[[MERGED:.*]]:2 = shape.assuming %[[COMBINED_WITNESS]]
  // CHECK-SAME: {
  // CHECK:        "some.op"
  // CHECK:        %[[RESULT0:.*]] = "some.producer"
  // CHECK:        "another.op"
  // CHECK:        %[[RESULT1:.*]] = "another.producer"
  // CHECK:        shape.assuming_yield %[[RESULT0]], %[[RESULT1]]
  // CHECK:      }
  // CHECK:      return %[[MERGED]]#1
  %0 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
  %1 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
  %2 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
  %3 = shape.cstr_broadcastable %0, %1 : tensor<2xindex>, tensor<2xindex>
  %4 = shape.cstr_broadcastable %0, %1, %2 : tensor<2xindex>, tensor<2xindex>,
      tensor<3xindex>
  %5 = shape.assuming %3 -> (tensor<?x32xf16>) {
    "some.op"() : () -> ()
    %6 = "some.producer"() : () -> tensor<?x32xf16>
    shape.assuming_yield %6 : tensor<?x32xf16>
  }
  %7 = shape.assuming %4 -> (tensor<?x?x32xf16>) {
    "another.op"() : () -> ()
    %8 = "another.producer"() : () -> tensor<?x?x32xf16>
    shape.assuming_yield %8 : tensor<?x?x32xf16>
  }
  "use"(%5, %7) : (tensor<?x32xf16>, tensor<?x?x32xf16>) -> ()
  return %7 : tensor<?x?x32xf16>
}

// -----

// Do not merge assuming ops if the witness will not dominate its use.
// CHECK-LABEL: @do_not_merge_assuming_ops
func @do_not_merge_assuming_ops() {
  // CHECK: shape.assuming
  // CHECK: shape.assuming
  %0 = "some.witness"() : () -> !shape.witness
  %1 = shape.assuming %0 -> (!shape.witness) {
    %2 = "some.witness"() : () -> !shape.witness
    shape.assuming_yield %2 : !shape.witness
  }
  shape.assuming %1 {
    "some.op"() : () -> ()
    shape.assuming_yield
  }
  return
}

// -----

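// A `tensor.cast` on an extent tensor can be folded into the producing
// `shape.shape_of`.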
// CHECK-LABEL: @eliminate_extent_tensor_cast
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?x4xf32>)
func @eliminate_extent_tensor_cast(%arg : tensor<2x?x4xf32>) {
  // CHECK-NOT:  shape_of
  // CHECK:      %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<2x?x4xf32> -> tensor<3xindex>
  // CHECK-NEXT: "use"(%[[RESULT]]) : (tensor<3xindex>) -> ()
  %0 = shape.shape_of %arg : tensor<2x?x4xf32> -> tensor<?xindex>
  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
  "use"(%1) : (tensor<3xindex>) -> ()
  return
}

// -----

// Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops.
// CHECK-LABEL: @sub_sub
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
    %arg2: tensor<?x?x32xf16>) -> tensor<?x?x32xf16> {
  // CHECK-DAG:  %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]]
  // CHECK-DAG:  %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]]
  // CHECK-DAG:  %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
  // CHECK-DAG:  %[[WITNESS0:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]]
  // CHECK-DAG:  %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]]
  // CHECK-DAG:  %[[COMBINED_WITNESS:.*]] = shape.assuming_all %[[WITNESS0]], %[[WITNESS1]]
  // CHECK:      %[[ASSUMING_RESULT:.*]] = shape.assuming %[[COMBINED_WITNESS]]
  // CHECK:        %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
  // CHECK:        %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
  // CHECK:        %[[BCASTED_ARG2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG2]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[0, 1, 2]>
  // CHECK:        %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
  // CHECK:        %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
  // CHECK:        %[[TMP:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]]
  // CHECK:        %[[RESULT:.*]] = mhlo.subtract %[[BCASTED_ARG2]], %[[TMP]]
  // CHECK:        shape.assuming_yield %[[RESULT]]
  // CHECK:      return %[[ASSUMING_RESULT]]
  %0 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
  %1 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
  %2 = shape.cstr_broadcastable %0, %1 : tensor<2xindex>, tensor<2xindex>
  %3 = shape.assuming %2 -> (tensor<?x32xf16>) {
    %8 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
    %9 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
    %10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<?xindex>
    %11 = tensor.cast %10 : tensor<?xindex> to tensor<2xindex>
    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %11) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
    %13 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %11) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
    %14 = mhlo.subtract %12, %13 : tensor<?x32xf16>
    shape.assuming_yield %14 : tensor<?x32xf16>
  }
  %4 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
  %5 = shape.shape_of %3 : tensor<?x32xf16> -> tensor<2xindex>
  %6 = shape.cstr_broadcastable %4, %5 : tensor<3xindex>, tensor<2xindex>
  %7 = shape.assuming %6 -> (tensor<?x?x32xf16>) {
    %8 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
    %9 = shape.shape_of %3 : tensor<?x32xf16> -> tensor<2xindex>
    %10 = shape.broadcast %8, %9 : tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
    %11 = tensor.cast %10 : tensor<?xindex> to tensor<3xindex>
    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %11) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
    %13 = "mhlo.dynamic_broadcast_in_dim"(%3, %11) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
    %14 = mhlo.subtract %12, %13 : tensor<?x?x32xf16>
    shape.assuming_yield %14 : tensor<?x?x32xf16>
  }
  return %7 : tensor<?x?x32xf16>
}

// -----

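// Redundant broadcastability constraints are deduplicated so that only one
// witness remains.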
// CHECK-LABEL: @redundant_cstr_broadcastable
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<?xindex>)
func @redundant_cstr_broadcastable(%arg0: tensor<?xindex>,
    %arg1 : tensor<?xindex>) {
  // CHECK-DAG:  %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG0]], %[[ARG1]]
  // CHECK:      shape.assuming %[[WITNESS]]
  %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
  %1 = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
  %2 = shape.assuming_all %0, %1
  shape.assuming %2 -> () {
    "some.op"() : () -> ()
    shape.assuming_yield
  }
  return
}