• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_BINARY_OPS_TEST_H_
17 #define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_BINARY_OPS_TEST_H_
18 
19 #include "absl/container/inlined_vector.h"
20 #include "absl/strings/string_view.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "tensorflow/core/common_runtime/device.h"
23 #include "tensorflow/core/common_runtime/device_factory.h"
24 #include "tensorflow/core/framework/fake_input.h"
25 #include "tensorflow/core/framework/node_def_builder.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/framework/tensor_testutil.h"
29 #include "tensorflow/core/kernels/mlir_generated/base_ops_test.h"
30 #include "tensorflow/core/kernels/ops_testutil.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32 #include "tensorflow/core/platform/test.h"
33 
34 namespace tensorflow {
35 
36 // Base class for `BinaryOpsTest` fixture that has to be defined with a custom
37 // TF device if you want to use the test macros in this file.
38 class BinaryOpsTestBase : public OpsTestBase {
39  protected:
40   // This method should set the TF device, e.g. DEVICE_CPU, DEVICE_GPU.
41   void SetUp() override = 0;
42 
43   template <typename T, typename OutT>
SetOpKernel(const std::string & op_name,const TensorShape & lhs_shape,const absl::InlinedVector<T,10> & lhs_input,const TensorShape & rhs_shape,const absl::InlinedVector<T,10> & rhs_input,const test::OpsTestConfig & config)44   void SetOpKernel(const std::string& op_name, const TensorShape& lhs_shape,
45                    const absl::InlinedVector<T, 10>& lhs_input,
46                    const TensorShape& rhs_shape,
47                    const absl::InlinedVector<T, 10>& rhs_input,
48                    const test::OpsTestConfig& config) {
49     auto builder = NodeDefBuilder("some_name", op_name)
50                        .Input(FakeInput(DataTypeToEnum<T>::v()))
51                        .Input(FakeInput(DataTypeToEnum<T>::v()));
52     if (config.add_t) {
53       builder.Attr(config.input_attribute, DataTypeToEnum<T>::v());
54     }
55     if (config.add_tout) {
56       builder.Attr(config.output_attribute, DataTypeToEnum<OutT>::v());
57     }
58     TF_ASSERT_OK(builder.Finalize(node_def()));
59 
60     TF_ASSERT_OK(InitOp());
61     AddInputFromArray<T>(lhs_shape, lhs_input);
62     AddInputFromArray<T>(rhs_shape, rhs_input);
63   }
64 
65   // Run fully specified tests.
66 
67   template <typename T, typename OutT>
RunAndExpectResult(const std::string & op_name,const TensorShape & lhs_shape,const absl::InlinedVector<T,10> & lhs_input,const TensorShape & rhs_shape,const absl::InlinedVector<T,10> & rhs_input,const TensorShape & expected_shape,const absl::InlinedVector<OutT,10> & expected_output,const test::OpsTestConfig & config)68   void RunAndExpectResult(const std::string& op_name,
69                           const TensorShape& lhs_shape,
70                           const absl::InlinedVector<T, 10>& lhs_input,
71                           const TensorShape& rhs_shape,
72                           const absl::InlinedVector<T, 10>& rhs_input,
73                           const TensorShape& expected_shape,
74                           const absl::InlinedVector<OutT, 10>& expected_output,
75                           const test::OpsTestConfig& config) {
76     SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input,
77                          config);
78     TF_ASSERT_OK(RunOpKernel());
79 
80     // Compare output to expectation.
81     Tensor expected_tensor(allocator(), DataTypeToEnum<OutT>::value,
82                            expected_shape);
83     test::FillValues<OutT>(&expected_tensor, expected_output);
84     if (config.expect_strictly_equal) {
85       test::ExpectEqual(expected_tensor, *GetOutput(0),
86                         config.supress_tolerance ? test::Tolerance::kNone
87                                                  : test::Tolerance::kDefault);
88     } else {
89       test::ExpectClose(expected_tensor, *GetOutput(0), config.atol,
90                         config.rtol);
91     }
92   }
93 
94   template <typename T, typename OutT>
RunAndExpectInvalidArgument(const std::string & op_name,const TensorShape & lhs_shape,const absl::InlinedVector<T,10> & lhs_input,const TensorShape & rhs_shape,const absl::InlinedVector<T,10> & rhs_input,const test::OpsTestConfig & config)95   void RunAndExpectInvalidArgument(const std::string& op_name,
96                                    const TensorShape& lhs_shape,
97                                    const absl::InlinedVector<T, 10>& lhs_input,
98                                    const TensorShape& rhs_shape,
99                                    const absl::InlinedVector<T, 10>& rhs_input,
100                                    const test::OpsTestConfig& config) {
101     SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input,
102                          config);
103     auto status = RunOpKernel();
104     EXPECT_FALSE(status.ok());
105     EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
106   }
107 
108   // Run common test cases.
109 
110   template <typename T, typename OutT>
TestIncompatibleShapes(const std::string & op_name,const absl::InlinedVector<T,10> & lhs_input,const absl::InlinedVector<T,10> & rhs_input,const test::OpsTestConfig & config)111   void TestIncompatibleShapes(const std::string& op_name,
112                               const absl::InlinedVector<T, 10>& lhs_input,
113                               const absl::InlinedVector<T, 10>& rhs_input,
114                               const test::OpsTestConfig& config) {
115     // Prepare incompatibly shaped inputs.
116     TensorShape lhs_shape{3};
117     TensorShape rhs_shape{2};
118     auto repeated_lhs_input =
119         test::RepeatInputToMatchShape(lhs_input, lhs_shape.num_elements());
120     auto repeated_rhs_input =
121         test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements());
122 
123     RunAndExpectInvalidArgument<T, OutT>(op_name, lhs_shape, repeated_lhs_input,
124                                          rhs_shape, repeated_rhs_input, config);
125   }
126 
127   template <typename T, typename BaselineT, typename OutT,
128             typename BaselineOutT>
TestEqualShapes(const std::string & op_name,const TensorShape & shape,const absl::InlinedVector<T,10> & lhs_input,const absl::InlinedVector<T,10> & rhs_input,BaselineOutT (* baseline_callback)(BaselineT,BaselineT),const test::OpsTestConfig & config)129   void TestEqualShapes(const std::string& op_name, const TensorShape& shape,
130                        const absl::InlinedVector<T, 10>& lhs_input,
131                        const absl::InlinedVector<T, 10>& rhs_input,
132                        BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
133                        const test::OpsTestConfig& config) {
134     // Prepare inputs.
135     int input_size = shape.num_elements();
136     CHECK(lhs_input.size() <= input_size && rhs_input.size() <= input_size &&
137           "expect input shape to hold all input values");
138     auto repeated_lhs_input =
139         test::RepeatInputToMatchShape(lhs_input, input_size);
140     auto repeated_rhs_input =
141         test::RepeatInputToMatchShape(rhs_input, input_size);
142 
143     // Compute expected results.
144     absl::InlinedVector<OutT, 10> expected_output;
145     for (auto it_lhs = repeated_lhs_input.begin(),
146               it_rhs = repeated_rhs_input.begin(),
147               end = repeated_lhs_input.end();
148          it_lhs != end; ++it_lhs, ++it_rhs) {
149       auto lhs = static_cast<BaselineT>(*it_lhs);
150       auto rhs = static_cast<BaselineT>(*it_rhs);
151       auto result = static_cast<OutT>(baseline_callback(lhs, rhs));
152       expected_output.push_back(result);
153     }
154 
155     RunAndExpectResult<T, OutT>(op_name, shape, repeated_lhs_input, shape,
156                                 repeated_rhs_input, shape, expected_output,
157                                 config);
158   }
159 
160   template <typename T, typename BaselineT, typename OutT,
161             typename BaselineOutT>
TestOneScalar(const std::string & op_name,T scalar_input,const TensorShape & other_shape,const absl::InlinedVector<T,10> & other_input,BaselineOutT (* baseline_callback)(BaselineT,BaselineT),const test::OpsTestConfig & config)162   void TestOneScalar(const std::string& op_name, T scalar_input,
163                      const TensorShape& other_shape,
164                      const absl::InlinedVector<T, 10>& other_input,
165                      BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
166                      const test::OpsTestConfig& config) {
167     // Prepare inputs.
168     TensorShape scalar_shape{};
169     CHECK(other_input.size() <= other_shape.num_elements() &&
170           "expect other input shape to hold all input values");
171     auto repeated_other_input =
172         test::RepeatInputToMatchShape(other_input, other_shape.num_elements());
173 
174     // Compute expected results.
175     absl::InlinedVector<OutT, 10> expected_output;
176     for (auto it = repeated_other_input.begin(),
177               end = repeated_other_input.end();
178          it != end; ++it) {
179       auto scalar = static_cast<BaselineT>(scalar_input);
180       auto other_value = static_cast<BaselineT>(*it);
181       auto result = static_cast<OutT>(baseline_callback(scalar, other_value));
182       expected_output.push_back(result);
183     }
184 
185     auto scalar_input_vector = test::InputAsVector<T>({scalar_input});
186     RunAndExpectResult<T, OutT>(op_name, scalar_shape, scalar_input_vector,
187                                 other_shape, repeated_other_input,
188                                 /*expected_shape=*/other_shape, expected_output,
189                                 config);
190   }
191 
192   template <typename T, typename BaselineT, typename OutT,
193             typename BaselineOutT>
TestOneEffectiveScalar(const std::string & op_name,T scalar_input,const TensorShape & other_shape,const absl::InlinedVector<T,10> & other_input,BaselineOutT (* baseline_callback)(BaselineT,BaselineT),const test::OpsTestConfig & config)194   void TestOneEffectiveScalar(const std::string& op_name, T scalar_input,
195                               const TensorShape& other_shape,
196                               const absl::InlinedVector<T, 10>& other_input,
197                               BaselineOutT (*baseline_callback)(BaselineT,
198                                                                 BaselineT),
199                               const test::OpsTestConfig& config) {
200     // Prepare inputs.
201     TensorShape effective_scalar_shape{1, 1, 1, 1, 1, 1, 1};
202     CHECK(other_input.size() <= other_shape.num_elements() &&
203           "expect other input shape to hold all input values");
204     auto repeated_other_input =
205         test::RepeatInputToMatchShape(other_input, other_shape.num_elements());
206 
207     // Compute expected results.
208     absl::InlinedVector<OutT, 10> expected_output;
209     for (auto it = repeated_other_input.begin(),
210               end = repeated_other_input.end();
211          it != end; ++it) {
212       auto scalar = static_cast<BaselineT>(scalar_input);
213       auto other_value = static_cast<BaselineT>(*it);
214       auto result = static_cast<OutT>(baseline_callback(scalar, other_value));
215       expected_output.push_back(result);
216     }
217 
218     auto scalar_input_vector = test::InputAsVector<T>({scalar_input});
219     TensorShape expected_shape = other_shape;
220     while (expected_shape.dims() < effective_scalar_shape.dims()) {
221       expected_shape.InsertDim(0, 1);
222     }
223     RunAndExpectResult<T, OutT>(
224         op_name, effective_scalar_shape, scalar_input_vector, other_shape,
225         repeated_other_input, expected_shape, expected_output, config);
226   }
227 
228   template <typename T, typename BaselineT, typename OutT,
229             typename BaselineOutT>
TestBroadcastingExpand(const std::string & op_name,const absl::InlinedVector<T,10> & lhs_input,const absl::InlinedVector<T,10> & rhs_input,BaselineOutT (* baseline_callback)(BaselineT,BaselineT),const test::OpsTestConfig & config)230   void TestBroadcastingExpand(const std::string& op_name,
231                               const absl::InlinedVector<T, 10>& lhs_input,
232                               const absl::InlinedVector<T, 10>& rhs_input,
233                               BaselineOutT (*baseline_callback)(BaselineT,
234                                                                 BaselineT),
235                               const test::OpsTestConfig& config) {
236     // Prepare inputs.
237     TensorShape lhs_shape{1};
238     TensorShape rhs_shape{6};
239     auto repeated_lhs_input =
240         test::RepeatInputToMatchShape(lhs_input, lhs_shape.num_elements());
241     auto repeated_rhs_input =
242         test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements());
243 
244     // Compute expected results.
245     std::vector<int> lhs_indices = {0, 0, 0, 0, 0, 0};
246     std::vector<int> rhs_indices = {0, 1, 2, 3, 4, 5};
247     auto expected_output =
248         ComputeExpectedOutput<T, BaselineT, OutT, BaselineOutT>(
249             lhs_indices, repeated_lhs_input, rhs_indices, repeated_rhs_input,
250             baseline_callback);
251 
252     RunAndExpectResult<T, OutT>(
253         op_name, lhs_shape, repeated_lhs_input, rhs_shape, repeated_rhs_input,
254         /*expected_shape=*/rhs_shape, expected_output, config);
255   }
256 
257   template <typename T, typename BaselineT, typename OutT,
258             typename BaselineOutT>
TestBroadcastingInDim(const std::string & op_name,const absl::InlinedVector<T,10> & lhs_input,const absl::InlinedVector<T,10> & rhs_input,BaselineOutT (* baseline_callback)(BaselineT,BaselineT),const test::OpsTestConfig & config)259   void TestBroadcastingInDim(const std::string& op_name,
260                              const absl::InlinedVector<T, 10>& lhs_input,
261                              const absl::InlinedVector<T, 10>& rhs_input,
262                              BaselineOutT (*baseline_callback)(BaselineT,
263                                                                BaselineT),
264                              const test::OpsTestConfig& config) {
265     // Prepare inputs.
266     TensorShape lhs_shape{3};
267     TensorShape rhs_shape{2, 3};
268     auto repeated_lhs_input =
269         test::RepeatInputToMatchShape(lhs_input, lhs_shape.num_elements());
270     auto repeated_rhs_input =
271         test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements());
272 
273     // Compute expected results.
274     std::vector<int> lhs_indices = {0, 1, 2, 0, 1, 2};
275     std::vector<int> rhs_indices = {0, 1, 2, 3, 4, 5};
276     auto expected_output =
277         ComputeExpectedOutput<T, BaselineT, OutT, BaselineOutT>(
278             lhs_indices, repeated_lhs_input, rhs_indices, repeated_rhs_input,
279             baseline_callback);
280 
281     RunAndExpectResult<T, OutT>(
282         op_name, lhs_shape, repeated_lhs_input, rhs_shape, repeated_rhs_input,
283         /*expected_shape=*/rhs_shape, expected_output, config);
284   }
285 
286   template <typename T, typename BaselineT, typename OutT,
287             typename BaselineOutT>
TestBroadcasting(const std::string & op_name,const absl::InlinedVector<T,10> & lhs_input,const absl::InlinedVector<T,10> & rhs_input,BaselineOutT (* baseline_callback)(BaselineT,BaselineT),const test::OpsTestConfig & config)288   void TestBroadcasting(const std::string& op_name,
289                         const absl::InlinedVector<T, 10>& lhs_input,
290                         const absl::InlinedVector<T, 10>& rhs_input,
291                         BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
292                         const test::OpsTestConfig& config) {
293     // Prepare inputs.
294     TensorShape lhs_shape{2, 1};
295     TensorShape rhs_shape{3};
296     auto repeated_lhs_input =
297         test::RepeatInputToMatchShape(lhs_input, lhs_shape.num_elements());
298     auto repeated_rhs_input =
299         test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements());
300 
301     // Compute expected results.
302     TensorShape expected_shape{2, 3};
303     std::vector<int> lhs_indices = {0, 0, 0, 1, 1, 1};
304     std::vector<int> rhs_indices = {0, 1, 2, 0, 1, 2};
305     auto expected_output =
306         ComputeExpectedOutput<T, BaselineT, OutT, BaselineOutT>(
307             lhs_indices, repeated_lhs_input, rhs_indices, repeated_rhs_input,
308             baseline_callback);
309 
310     RunAndExpectResult<T, OutT>(op_name, lhs_shape, repeated_lhs_input,
311                                 rhs_shape, repeated_rhs_input, expected_shape,
312                                 expected_output, config);
313   }
314 
315   template <typename T, typename BaselineT, typename OutT,
316             typename BaselineOutT>
TestBroadcastingRank6(const std::string & op_name,const absl::InlinedVector<T,10> & lhs_input,const absl::InlinedVector<T,10> & rhs_input,BaselineOutT (* baseline_callback)(BaselineT,BaselineT),const test::OpsTestConfig & config)317   void TestBroadcastingRank6(const std::string& op_name,
318                              const absl::InlinedVector<T, 10>& lhs_input,
319                              const absl::InlinedVector<T, 10>& rhs_input,
320                              BaselineOutT (*baseline_callback)(BaselineT,
321                                                                BaselineT),
322                              const test::OpsTestConfig& config) {
323     // Prepare inputs.
324     TensorShape lhs_shape{1, 2, 3, 1, 2, 1};
325     TensorShape rhs_shape{1, 1, 1, 2, 3};
326     auto repeated_lhs_input =
327         test::RepeatInputToMatchShape(lhs_input, lhs_shape.num_elements());
328     auto repeated_rhs_input =
329         test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements());
330 
331     // Compute expected results.
332     TensorShape expected_shape{1, 2, 3, 1, 2, 3};
333     std::vector<int> lhs_indices = {0, 0, 0, 1, 1, 1, 2,  2,  2,  3,  3,  3,
334                                     4, 4, 4, 5, 5, 5, 6,  6,  6,  7,  7,  7,
335                                     8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11};
336     std::vector<int> rhs_indices = {
337         0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5,
338         0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5,
339     };
340     auto expected_output =
341         ComputeExpectedOutput<T, BaselineT, OutT, BaselineOutT>(
342             lhs_indices, repeated_lhs_input, rhs_indices, repeated_rhs_input,
343             baseline_callback);
344 
345     RunAndExpectResult<T, OutT>(op_name, lhs_shape, repeated_lhs_input,
346                                 rhs_shape, repeated_rhs_input, expected_shape,
347                                 expected_output, config);
348   }
349 
350   template <typename T, typename BaselineT, typename OutT,
351             typename BaselineOutT>
TestEmptyShapeBroadcasting(const std::string & op_name,const absl::InlinedVector<T,10> & lhs_input,const absl::InlinedVector<T,10> & rhs_input,const test::OpsTestConfig & config)352   void TestEmptyShapeBroadcasting(const std::string& op_name,
353                                   const absl::InlinedVector<T, 10>& lhs_input,
354                                   const absl::InlinedVector<T, 10>& rhs_input,
355                                   const test::OpsTestConfig& config) {
356     // Prepare inputs.
357     TensorShape lhs_shape{2, 0, 1};
358     TensorShape rhs_shape{2, 0, 5};
359     absl::InlinedVector<T, 10> empty_input = {};
360 
361     // Define expected result.
362     TensorShape expected_shape{2, 0, 5};
363     absl::InlinedVector<OutT, 10> expected_output = {};
364 
365     RunAndExpectResult<T, OutT>(op_name, lhs_shape, empty_input, rhs_shape,
366                                 empty_input, expected_shape, expected_output,
367                                 config);
368   }
369 
370  private:
371   template <typename T, typename BaselineT, typename OutT,
372             typename BaselineOutT>
ComputeExpectedOutput(std::vector<int> lhs_indices,absl::InlinedVector<T,10> lhs_input,std::vector<int> rhs_indices,absl::InlinedVector<T,10> rhs_input,BaselineOutT (* baseline_callback)(BaselineT,BaselineT))373   absl::InlinedVector<OutT, 10> ComputeExpectedOutput(
374       std::vector<int> lhs_indices, absl::InlinedVector<T, 10> lhs_input,
375       std::vector<int> rhs_indices, absl::InlinedVector<T, 10> rhs_input,
376       BaselineOutT (*baseline_callback)(BaselineT, BaselineT)) {
377     absl::InlinedVector<OutT, 10> expected_output;
378     for (int i = 0; i < lhs_indices.size(); i++) {
379       auto lhs = static_cast<BaselineT>(lhs_input[lhs_indices[i]]);
380       auto rhs = static_cast<BaselineT>(rhs_input[rhs_indices[i]]);
381       auto result = static_cast<OutT>(baseline_callback(lhs, rhs));
382       expected_output.push_back(result);
383     }
384     return expected_output;
385   }
386 };
387 
388 // Macros to easily generate common test cases. The macros use `BinaryOpsTest`
389 // fixture in order to share implementation across GPU and CPU platform tests.
390 // For specific inputs, please define your own test fixtures.
// Defines two `BinaryOpsTest` cases for `op_name` that do not involve
// broadcasting:
//   * `<op_name>EqShapes<test_name>` calls `TestEqualShapes` with both
//     operands shaped `test::DefaultInputShape()`.
//   * `<op_name>IncompatibleShapes<test_name>` calls `TestIncompatibleShapes`.
#define GENERATE_DEFAULT_NO_BROADCASTING_TESTS_2(                              \
    op_name, test_name, T, BaselineT, OutT, BaselineOutT, lhs_input,           \
    rhs_input, baseline_callback, config)                                      \
  TEST_F(BinaryOpsTest, op_name##EqShapes##test_name) {                        \
    TestEqualShapes<T, BaselineT, OutT, BaselineOutT>(                         \
        #op_name, test::DefaultInputShape(), lhs_input, rhs_input,             \
        baseline_callback, config);                                            \
  }                                                                            \
                                                                               \
  TEST_F(BinaryOpsTest, op_name##IncompatibleShapes##test_name) {              \
    TestIncompatibleShapes<T, OutT>(#op_name, lhs_input, rhs_input,            \
                                    config);                                   \
  }
402 
// Defines the complete default set of `BinaryOpsTest` cases for `op_name`.
// It starts with the two cases of `GENERATE_DEFAULT_NO_BROADCASTING_TESTS_2`.
// It then adds one case each for `OneScalar`, `TestOneEffectiveScalar`,
// `BroadcastingExpand`, `BroadcastingInDim`, `Broadcasting`,
// `BroadcastingRank6` and `EmptyShapeBroadcasting`.
// The two scalar cases take `lhs_input.front()` as the scalar operand and
// `rhs_input` (shaped `test::DefaultInputShape()`) as the other operand.
#define GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, BaselineT, OutT,       \
                                 BaselineOutT, lhs_input, rhs_input,           \
                                 baseline_callback, config)                    \
  GENERATE_DEFAULT_NO_BROADCASTING_TESTS_2(                                    \
      op_name, test_name, T, BaselineT, OutT, BaselineOutT, lhs_input,         \
      rhs_input, baseline_callback, config)                                    \
                                                                               \
  TEST_F(BinaryOpsTest, op_name##OneScalar##test_name) {                       \
    TestOneScalar<T, BaselineT, OutT, BaselineOutT>(                           \
        #op_name, lhs_input.front(), test::DefaultInputShape(), rhs_input,     \
        baseline_callback, config);                                            \
  }                                                                            \
                                                                               \
  TEST_F(BinaryOpsTest, op_name##TestOneEffectiveScalar##test_name) {          \
    TestOneEffectiveScalar<T, BaselineT, OutT, BaselineOutT>(                  \
        #op_name, lhs_input.front(), test::DefaultInputShape(), rhs_input,     \
        baseline_callback, config);                                            \
  }                                                                            \
                                                                               \
  TEST_F(BinaryOpsTest, op_name##BroadcastingExpand##test_name) {              \
    TestBroadcastingExpand<T, BaselineT, OutT, BaselineOutT>(                  \
        #op_name, lhs_input, rhs_input, baseline_callback, config);            \
  }                                                                            \
                                                                               \
  TEST_F(BinaryOpsTest, op_name##BroadcastingInDim##test_name) {               \
    TestBroadcastingInDim<T, BaselineT, OutT, BaselineOutT>(                   \
        #op_name, lhs_input, rhs_input, baseline_callback, config);            \
  }                                                                            \
                                                                               \
  TEST_F(BinaryOpsTest, op_name##Broadcasting##test_name) {                    \
    TestBroadcasting<T, BaselineT, OutT, BaselineOutT>(                        \
        #op_name, lhs_input, rhs_input, baseline_callback, config);            \
  }                                                                            \
                                                                               \
  TEST_F(BinaryOpsTest, op_name##BroadcastingRank6##test_name) {               \
    TestBroadcastingRank6<T, BaselineT, OutT, BaselineOutT>(                   \
        #op_name, lhs_input, rhs_input, baseline_callback, config);            \
  }                                                                            \
                                                                               \
  TEST_F(BinaryOpsTest, op_name##EmptyShapeBroadcasting##test_name) {          \
    TestEmptyShapeBroadcasting<T, BaselineT, OutT, BaselineOutT>(              \
        #op_name, lhs_input, rhs_input, config);                               \
  }
449 
// Shorthand for `GENERATE_DEFAULT_TESTS_2`. The baseline types are the same
// as the op types (`T` in, `OutT` out). Both operands use
// `test::DefaultInput<T>()`.
#define GENERATE_DEFAULT_TESTS(op_name, test_name, T, OutT, baseline_callback, \
                               config)                                         \
  GENERATE_DEFAULT_TESTS_2(                                                    \
      op_name, test_name, /*T=*/T, /*BaselineT=*/T, /*OutT=*/OutT,             \
      /*BaselineOutT=*/OutT, /*lhs_input=*/test::DefaultInput<T>(),            \
      /*rhs_input=*/test::DefaultInput<T>(), baseline_callback, config)
455 
// Like `GENERATE_DEFAULT_TESTS`, but with caller-provided operand values
// `lhs_input` and `rhs_input`. The baseline types are the same as the op
// types.
#define GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(                    \
    op_name, test_name, T, OutT, lhs_input, rhs_input, baseline_callback,     \
    config)                                                                   \
  GENERATE_DEFAULT_TESTS_2(                                                   \
      op_name, test_name, /*T=*/T, /*BaselineT=*/T, /*OutT=*/OutT,            \
      /*BaselineOutT=*/OutT, lhs_input, rhs_input, baseline_callback, config)
461 
// Shorthand for `GENERATE_DEFAULT_NO_BROADCASTING_TESTS_2`. The baseline
// types are the same as the op types, and both operands use
// `test::DefaultInput<T>()`. Results are always compared with
// `test::OpsTestConfig().ExpectStrictlyEqual()`.
#define GENERATE_DEFAULT_NO_BROADCASTING_TESTS(op_name, test_name, T, OutT,   \
                                               baseline_callback)             \
  GENERATE_DEFAULT_NO_BROADCASTING_TESTS_2(                                   \
      op_name, test_name, /*T=*/T, /*BaselineT=*/T, /*OutT=*/OutT,            \
      /*BaselineOutT=*/OutT, /*lhs_input=*/test::DefaultInput<T>(),           \
      /*rhs_input=*/test::DefaultInput<T>(), baseline_callback,               \
      /*config=*/test::OpsTestConfig().ExpectStrictlyEqual())
468 
469 }  // namespace tensorflow
470 
471 #endif  // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_BINARY_OPS_TEST_H_
472