// RUN: tf-opt -tfl-raise-custom-ops -canonicalize %s --split-input-file | FileCheck %s
// RUN: tf-opt -tfl-raise-custom-ops="test-raise-tf-targets=tf.FakeQuantWithMinMaxVarsPerChannel" -canonicalize %s --split-input-file | FileCheck --check-prefix=WRAPPED %s

// CHECK-LABEL: custom_op
func.func @custom_op(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %0 = "arith.constant" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32>
  %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  // will be preserved since it has uses.
  %2 = "tf.MyCustomOp"(%1, %0) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  // will be preserved since it has side-effect.
  "tf.MyCustomOp"(%1, %0) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  func.return %2 : tensor<4xf32>

// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32>
// CHECK-NEXT: %[[MUL:.*]] = tfl.mul %arg0, %[[CST]] {fused_activation_function = "NONE"} : tensor<4xf32>
// CHECK-NEXT: %[[CUSTOM_1:.*]] = "tfl.custom_tf"(%[[MUL]], %[[CST]]) ({
// CHECK-NEXT:   ^bb0(%arg1: tensor<4xf32>, %arg2: tensor<4xf32>):
// CHECK-NEXT:     %[[MY_CUSTOM:.*]] = "tf.MyCustomOp"(%arg1, %arg2) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT:     "tfl.yield"(%[[MY_CUSTOM]]) : (tensor<4xf32>) -> ()
// CHECK-NEXT: }) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: %[[CUSTOM_2:.*]] = "tfl.custom_tf"(%[[MUL]], %[[CST]]) ({
// CHECK-NEXT:   ^bb0(%arg1: tensor<4xf32>, %arg2: tensor<4xf32>):
// CHECK-NEXT:     %[[MY_CUSTOM:.*]] = "tf.MyCustomOp"(%arg1, %arg2) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT:     "tfl.yield"(%[[MY_CUSTOM]]) : (tensor<4xf32>) -> ()
// CHECK-NEXT: }) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: return %[[CUSTOM_1]] : tensor<4xf32>
}

// -----

// CHECK-LABEL: tf_executor_wrapper
// WRAPPED-LABEL: tf_executor_wrapper
func.func @tf_executor_wrapper(%arg0: tensor<*xf32>) -> tensor<*xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "input", outputs = "output"}} {
  %0 = tf_executor.graph {
    %outputs_14, %control_15 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1.0> : tensor<186xf32>} : () -> tensor<186xf32>
    %outputs_16, %control_17 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<2.0> : tensor<186xf32>} : () -> tensor<186xf32>
    %outputs_18, %control_19 = tf_executor.island wraps "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %outputs_16, %outputs_14) {device = "", narrow_range = true, num_bits = 8 : i64} : (tensor<*xf32>, tensor<186xf32>, tensor<186xf32>) -> tensor<*xf32>
    tf_executor.fetch %outputs_18 : tensor<*xf32>
  }
  func.return %0 : tensor<*xf32>

// CHECK: tf_executor.island wraps "tf.FakeQuantWithMinMaxVarsPerChannel"

// WRAPPED-NEXT: tf_executor.graph {
// WRAPPED-NEXT:   tf_executor.island wraps "tf.Const"() {device = "", value = dense<1.000000e+00> : tensor<186xf32>} : () -> tensor<186xf32>
// WRAPPED-NEXT:   tf_executor.island wraps "tf.Const"() {device = "", value = dense<2.000000e+00> : tensor<186xf32>} : () -> tensor<186xf32>
// WRAPPED-NEXT:   tf_executor.island wraps "tfl.custom_tf"
// WRAPPED-NEXT:     ^bb0(%arg1: tensor<*xf32>, %arg2: tensor<186xf32>, %arg3: tensor<186xf32>):
// WRAPPED-NEXT:       %[[fq:.*]] = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg1, %arg2, %arg3) {device = "", narrow_range = true, num_bits = 8 : i64} : (tensor<*xf32>, tensor<186xf32>, tensor<186xf32>) -> tensor<*xf32>
// WRAPPED-NEXT:       "tfl.yield"(%[[fq]]) : (tensor<*xf32>) -> ()
// WRAPPED-NEXT:   }) {device = "", narrow_range = true, num_bits = 8 : i64} : (tensor<*xf32>, tensor<186xf32>, tensor<186xf32>) -> tensor<*xf32>
}