• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FILECHECK_OPTS="" FileCheck %s

// Single-result case: an i32 index (the constant 1) followed by three f32
// operands (56, 12, 13). There is one region per branch; each region takes a
// single f32 argument and returns a single f32 value (negate, copy, floor).
func @main() -> tensor<f32> {
  %cst = constant dense<1> : tensor<i32>
  %cst_0 = constant dense<5.600000e+01> : tensor<f32>
  %cst_1 = constant dense<1.200000e+01> : tensor<f32>
  %cst_2 = constant dense<1.300000e+01> : tensor<f32>
  %0 = "mhlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( {
  ^bb0(%arg0: tensor<f32>):
    %1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  },  {
  ^bb0(%arg0: tensor<f32>):
    %1 = "mhlo.copy"(%arg0) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  },  {
  ^bb0(%arg0: tensor<f32>):
    %1 = "mhlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  }) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

// Expected output: one HLO computation per region, emitted in region order
// (negate, copy, floor) before the entry computation. In each computation,
// parameter(0) feeds the region's single op, and that op is the ROOT.
// CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] {
// CHECK:   %[[ARG:.*]] = f32[] parameter(0)
// CHECK:   ROOT %[[RESULT:.*]] = f32[] negate(f32[] %[[ARG]])
// CHECK: }

// CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] {
// CHECK:   %[[ARG:.*]] = f32[] parameter(0)
// CHECK:   ROOT %[[RESULT:.*]] = f32[] copy(f32[] %[[ARG]])
// CHECK: }

// CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] {
// CHECK:   %[[ARG:.*]] = f32[] parameter(0)
// CHECK:   ROOT %[[RESULT:.*]] = f32[] floor(f32[] %[[ARG]])
// CHECK: }

// The entry computation materializes the index and the three operands as
// constants in operand order. Its ROOT is a `conditional` whose
// branch_computations list the three computations above, in region order.
// CHECK-LABEL: ENTRY
// CHECK-SAME:  () -> f32[]

// CHECK: %[[INDEX:.*]] = s32[] constant(1)
// CHECK: %[[OPERAND_1:.*]] = f32[] constant(56)
// CHECK: %[[OPERAND_2:.*]] = f32[] constant(12)
// CHECK: %[[OPERAND_3:.*]] = f32[] constant(13)
// CHECK: ROOT %[[RESULT:.*]] = f32[] conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]}

// -----

// Multi-result case: same index and operands as the single-result case, but
// every region returns its op result twice. The case op therefore has two f32
// results, and both are returned from @main.
func @main() -> (tensor<f32>, tensor<f32>) {
  %cst = constant dense<1> : tensor<i32>
  %cst_0 = constant dense<5.600000e+01> : tensor<f32>
  %cst_1 = constant dense<1.200000e+01> : tensor<f32>
  %cst_2 = constant dense<1.300000e+01> : tensor<f32>
  %0:2 = "mhlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( {
  ^bb0(%arg0: tensor<f32>):
    %1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
  },  {
  ^bb0(%arg0: tensor<f32>):
    %1 = "mhlo.copy"(%arg0) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
  },  {
  ^bb0(%arg0: tensor<f32>):
    %1 = "mhlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
  }) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
  return %0#0, %0#1 : tensor<f32>, tensor<f32>
}

// Expected output: each branch computation returns (f32[], f32[]). Its ROOT is
// a two-element tuple that repeats the result of the region's single op.
// CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) {
// CHECK:   %[[ARG:.*]] = f32[] parameter(0)
// CHECK:   %[[NEGATE:.*]] = f32[] negate(f32[] %[[ARG]])
// CHECK:   ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[NEGATE]], f32[] %[[NEGATE]])
// CHECK: }

// CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) {
// CHECK:   %[[ARG:.*]] = f32[] parameter(0)
// CHECK:   %[[COPY:.*]] = f32[] copy(f32[] %[[ARG]])
// CHECK:   ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[COPY]], f32[] %[[COPY]])
// CHECK: }

// CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) {
// CHECK:   %[[ARG:.*]] = f32[] parameter(0)
// CHECK:   %[[FLOOR:.*]] = f32[] floor(f32[] %[[ARG]])
// CHECK:   ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[FLOOR]], f32[] %[[FLOOR]])
// CHECK: }

// In the entry computation, the `conditional` produces a tuple. The tuple is
// unpacked with get-tuple-element (index 0, then index 1), and the two
// elements are re-packed into the ROOT tuple.
// CHECK-LABEL: ENTRY
// CHECK-SAME:  () -> (f32[], f32[])

// CHECK: %[[INDEX:.*]] = s32[] constant(1)
// CHECK: %[[OPERAND_1:.*]] = f32[] constant(56)
// CHECK: %[[OPERAND_2:.*]] = f32[] constant(12)
// CHECK: %[[OPERAND_3:.*]] = f32[] constant(13)
// CHECK: %[[TUPLE:.*]] = (f32[], f32[]) conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]}
// CHECK: %[[RES_1:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[TUPLE]]), index=0
// CHECK: %[[RES_2:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[TUPLE]]), index=1
// CHECK: ROOT %[[RESULT:.*]] = (f32[], f32[]) tuple(f32[] %[[RES_1]], f32[] %[[RES_2]])