// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=false" -verify-diagnostics %s | FileCheck --check-prefix NO_FALLBACK %s
// RUN: tf-opt "-xla-legalize-tf=use-tf2xla-fallback=true device-type=XLA_CPU_JIT" -verify-diagnostics %s | FileCheck --check-prefix SUPPORTED_FALLBACK_DEVICE %s
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true" %s | FileCheck --check-prefix UNSPECIFIED_FALLBACK_DEVICE %s
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true device-type=INVALID_DEVICE_TYPE" %s | FileCheck --check-prefix UNSUPPORTED_FALLBACK_DEVICE %s

// We run this test four times:
// 1) Legalize without using the TF2XLA fallback (the ops cannot be legalized).
// 2) Use the fallback with a device that supports all ops (the ops can be legalized).
// 3) Use the fallback with an unspecified device (the ops cannot be legalized).
// 4) Use the fallback with a specified but unsupported device (the ops cannot be legalized).
//
// Note: For 3) and 4) we do not use `-verify-diagnostics` because these cases
// produce remarks that do not occur for 1) and 2), and there is no way to check
// the remarks only for 3) and 4) (short of splitting this into two files).
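//
// As far as we understand the pass options used above: `allow-partial-conversion`
// lets ops that fail to legalize remain in the output (so FileCheck can match
// them), while `use-tf2xla-fallback` and `device-type` control whether the
// TF2XLA fallback patterns are registered and for which device they are built.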

module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {

// CHECK-LABEL: non_max_suppression_v4
func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<2xi32> {
  %max_size = mhlo.constant dense<2> : tensor<i32>
  // NO_FALLBACK: tf.NonMaxSuppressionV4
  // SUPPORTED_FALLBACK_DEVICE-NOT: tf.NonMaxSuppressionV4
  // UNSPECIFIED_FALLBACK_DEVICE: tf.NonMaxSuppressionV4
  // UNSUPPORTED_FALLBACK_DEVICE: tf.NonMaxSuppressionV4
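  // (The TF2XLA fallback kernel for NonMaxSuppressionV4 appears to require
  // pad_to_max_output_size = true, hence that attribute on the op below.)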
  %0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %max_size, %arg2, %arg3) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
  return %0#0 : tensor<2xi32>
}

// CHECK-LABEL: mirror_pad
func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
  %0 = mhlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>
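  // Shape note: with SYMMETRIC mirror padding of [1, 1] on the rows and
  // [2, 2] on the columns, the 2x3 input grows to (2+1+1)x(3+2+2) = 4x7.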
  // NO_FALLBACK: tf.MirrorPad
  // SUPPORTED_FALLBACK_DEVICE-NOT: tf.MirrorPad
  // UNSPECIFIED_FALLBACK_DEVICE: tf.MirrorPad
  // UNSUPPORTED_FALLBACK_DEVICE: tf.MirrorPad
  %1 = "tf.MirrorPad"(%arg0, %0) {mode = "SYMMETRIC"} : (tensor<2x3xcomplex<f64>>, tensor<2x2xi32>) -> tensor<4x7xcomplex<f64>>
  return %1 : tensor<4x7xcomplex<f64>>
}

// BatchMatMulV2 has both native and fallback lowering patterns available.
// The fallback pattern uses dot_general without broadcasting the operands and
// then transposes the output, which is faster. However, unlike the native
// lowering, the fallback pattern does not support dynamically shaped operands.
// Verify that the fallback lowering is preferred for statically shaped
// operands when it is available.

// CHECK-LABEL: batchmatmulv2
func @batchmatmulv2(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> {
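  // Shape note: the batch dimensions 1 and 3 broadcast to 3, and each 4x2
  // times 2x4 matmul yields a 4x4 block, so the result is tensor<3x4x4xf32>.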
  // NO_FALLBACK: mhlo.dynamic_broadcast_in_dim
  // NO_FALLBACK: mhlo.dot_general

  // SUPPORTED_FALLBACK_DEVICE: mhlo.reduce
  // SUPPORTED_FALLBACK_DEVICE: mhlo.dot_general
  // SUPPORTED_FALLBACK_DEVICE: mhlo.transpose

  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
  return %0 : tensor<3x4x4xf32>
}

}