• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// RUN: tf-opt %s -split-input-file -tfl-analyze-variables-pass --cse | FileCheck %s

// Tests for -tfl-analyze-variables-pass. Each split case below asserts the
// value of the tfl._legalize_tfl_variables module attribute in the output.

// Case 1: a variable handle is created and only read inside the same function.
// Expected attribute value: true.
// CHECK: module attributes {tfl._legalize_tfl_variables = true}
module {
  func.func @f() -> tensor<*xi32> {
    %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xi32>>>
    %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf_type.resource<tensor<*xi32>>>) -> tensor<*xi32>
    func.return %2 : tensor<*xi32>
  }
}

// -----

// Case 2: @main calls @f through tf.PartitionedCall with no operands; @f
// creates and reads the variable handle itself, so no resource value crosses
// the call boundary. Expected attribute value: true.
// CHECK: module attributes {tfl._legalize_tfl_variables = true}
module {
  func.func @main() -> tensor<*xi32> {
    %0 = "tf.PartitionedCall"() {f = @f, config = "", config_proto = "", executor_type = ""}
      : () -> tensor<*xi32>
    func.return %0 : tensor<*xi32>
  }
  func.func @f() -> tensor<*xi32> {
    %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xi32>>>
    %1 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf_type.resource<tensor<*xi32>>>) -> tensor<*xi32>
    func.return %1 : tensor<*xi32>
  }
}


// -----

// Case 3: the variable handle created in @main is passed as an operand to
// tf.PartitionedCall, and @f reads it through its resource argument.
// Expected attribute value: false.
// CHECK: module attributes {tfl._legalize_tfl_variables = false}
module {
  func.func @main() -> tensor<*xi32> {
    %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xi32>>>
    %1 = "tf.PartitionedCall"(%0) {f = @f, config = "", config_proto = "", executor_type = ""}
      : (tensor<*x!tf_type.resource<tensor<*xi32>>>) -> tensor<*xi32>
    func.return %1 : tensor<*xi32>
  }
  func.func @f(%arg0 : tensor<*x!tf_type.resource<tensor<*xi32>>>) -> tensor<*xi32> {
    %0 = "tf.ReadVariableOp"(%arg0) {dtype = i32} : (tensor<*x!tf_type.resource<tensor<*xi32>>>) -> tensor<*xi32>
    func.return %0 : tensor<*xi32>
  }
}

// -----

// Case 4: the variable is updated with tf.AssignAddVariableOp before being
// read in the same function. Expected attribute value: false.
// CHECK: module attributes {tfl._legalize_tfl_variables = false}
module {
  func.func @main() -> tensor<*xi32> {
    %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xi32>>>
    %cst = arith.constant dense<2> : tensor<4xi32>
    "tf.AssignAddVariableOp"(%0, %cst) {} : (tensor<*x!tf_type.resource<tensor<*xi32>>>, tensor<4xi32>) -> ()
    %1 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf_type.resource<tensor<*xi32>>>) -> tensor<*xi32>
    func.return %1 : tensor<*xi32>
  }
}

// -----

// Case 5: the variable handle is a loop-carried operand of tfl.while and is
// only read (tf.ReadVariableOp) in the condition region; the body region does
// not touch any resource. Expected attribute value: true.
// NOTE(review): the body region declares its second block argument as
// tensor<i32> although the corresponding while operand is a resource, and it
// yields one value while the op has two results. Kept exactly as in the
// original case -- confirm this is intentional.
// CHECK: module attributes {tfl._legalize_tfl_variables = true}
module {
  func.func @main() -> tensor<i32> {
    %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xi32>>>
    %cst = arith.constant dense<1> : tensor<i32>
    %1:2 = "tfl.while"(%cst, %0) ({
    ^bb0(%arg1: tensor<*xi32>, %arg2: tensor<*x!tf_type.resource<tensor<*xi32>>>):
      %2 = "tf.ReadVariableOp"(%arg2) {dtype = i32} : (tensor<*x!tf_type.resource<tensor<*xi32>>>) -> tensor<*xi32>
      %3 = "tfl.greater"(%arg1, %2) : (tensor<*xi32>, tensor<*xi32>) -> tensor<i1>
      "tfl.yield"(%3) : (tensor<i1>) -> ()
  },  {
    ^bb0(%arg3: tensor<*xi32>, %arg4: tensor<i32>):
      %4 = "tfl.sub"(%arg3, %arg4) {fused_activation_function = "NONE"} :
        (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
      "tfl.yield"(%4) : (tensor<*xi32>) -> ()
  }) : (tensor<i32>, tensor<*x!tf_type.resource<tensor<*xi32>>>) -> (tensor<i32>, tensor<*x!tf_type.resource<tensor<*xi32>>>)
    func.return %1#0 : tensor<i32>
  }
}

// -----

// Case 6: the variable handle is a loop-carried operand of tfl.while and the
// body region updates it with tf.AssignAddVariableOp before reading it.
// Expected attribute value: false.
// CHECK: module attributes {tfl._legalize_tfl_variables = false}
module {
  func.func @main() -> tensor<i32> {
    %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xi32>>>
    %cst = arith.constant dense<1> : tensor<i32>
    %1:2 = "tfl.while"(%cst, %0) ({
    ^bb0(%arg1: tensor<*xi32>, %arg2: tensor<*x!tf_type.resource<tensor<*xi32>>>):
      %2 = "tf.ReadVariableOp"(%arg2) {dtype = i32} : (tensor<*x!tf_type.resource<tensor<*xi32>>>) -> tensor<*xi32>
      %3 = "tfl.greater"(%arg1, %2) : (tensor<*xi32>, tensor<*xi32>) -> tensor<i1>
      "tfl.yield"(%3) : (tensor<i1>) -> ()
  },  {
    ^bb0(%arg3: tensor<*xi32>, %arg4: tensor<*x!tf_type.resource<tensor<*xi32>>>):
      %cst1 = arith.constant dense<2> : tensor<4xi32>
      "tf.AssignAddVariableOp"(%arg4, %cst1) {} : (tensor<*x!tf_type.resource<tensor<*xi32>>>, tensor<4xi32>) -> ()
      %4 = "tf.ReadVariableOp"(%arg4) {dtype = i32} : (tensor<*x!tf_type.resource<tensor<*xi32>>>) -> tensor<*xi32>
      "tfl.yield"(%4) : (tensor<*xi32>) -> ()
  }) : (tensor<i32>, tensor<*x!tf_type.resource<tensor<*xi32>>>) -> (tensor<i32>, tensor<*x!tf_type.resource<tensor<*xi32>>>)
    func.return %1#0 : tensor<i32>
  }
}

// -----

// Case 7: a resource function argument of @main is forwarded to
// tf.ReduceDataset as an extra argument (Targuments = [!tf_type.resource]);
// the reduce function @__reduce_func does not touch any resource.
// Expected attribute value: true.
// NOTE(review): @__reduce_func declares a single i32 parameter while the call
// supplies the state %cst_1 plus %arg0, and "tf.JustPretend" appears to be a
// placeholder op. Kept exactly as in the original case -- confirm intended.
// CHECK: module attributes {tfl._legalize_tfl_variables = true}
module {
  func.func @main(%arg0 : tensor<!tf_type.resource<tensor<4096xf32>>>,
      %arg1 : tensor<*x!tf_type.variant>) {
    %cst_0 = arith.constant dense<2> : tensor<i64>
    %cst_1 = arith.constant dense<0> : tensor<i32>
    %0 = "tf.RepeatDataset"(%arg1, %cst_0) {device = "",
      output_shapes = [#tf_type.shape<?>],
      output_types = [!tf_type.string]} : (tensor<*x!tf_type.variant>, tensor<i64>) -> tensor<!tf_type.variant>

    %1 = "tf.ReduceDataset"(%0, %cst_1, %arg0) {
      Targuments = [!tf_type.resource],
      Tstate = [i32], device = "",
      f = @__reduce_func, f._tf_data_function = true,
      output_shapes = [#tf_type.shape<>],
      output_types = [i32], use_inter_op_parallelism = true} : (tensor<!tf_type.variant>, tensor<i32>, tensor<!tf_type.resource<tensor<4096xf32>>>) -> (tensor<*xi32>)
    func.return
  }

  func.func private @__reduce_func(%arg0: tensor<i32> {tf._user_specified_name = "args_0"}) -> (tensor<i32>) attributes {tf._tf_data_function = true, tf.signature.is_stateful} {
    %0 = "tf.JustPretend"() : () -> (tensor<i32>)
    func.return %0: tensor<i32>
  }
}
