1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
16
17 #include <initializer_list>
18 #include <memory>
19 #include <string>
20 #include <tuple>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_format.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/permutation_util.h"
29 #include "tensorflow/compiler/xla/reference_util.h"
30 #include "tensorflow/compiler/xla/service/hlo_computation.h"
31 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
32 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/status.h"
35 #include "tensorflow/compiler/xla/status_macros.h"
36 #include "tensorflow/compiler/xla/statusor.h"
37 #include "tensorflow/compiler/xla/test.h"
38 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
39 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
40 #include "tensorflow/compiler/xla/tests/test_utils.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/compiler/xla/xla_data.pb.h"
44 #include "tensorflow/core/lib/core/status.h"
45 #include "tensorflow/core/lib/core/status_test_util.h"
46 #include "tensorflow/core/platform/test.h"
47 #include "tensorflow/core/platform/test_benchmark.h"
48 #include "tensorflow/core/platform/types.h"
49
50 namespace xla {
51 namespace {
52
// Parameter values for the bf16-parameterized test suite: each TEST_P runs
// once with bf16 conversion enabled (true) and once without (false).
// The list is fixed, so it is a compile-time constant rather than mutable
// namespace-scope state.
static constexpr std::array<bool, 2> use_bf16_params{true, false};
54
55 // Test fixture for the HloEvaluator.
56 //
57 // In bf16 mode, all f32 shapes are converted to bf16 before running.
58 class HloEvaluatorTest : public HloTestBase {
59 public:
HloEvaluatorTest()60 HloEvaluatorTest() : use_bfloat16_(false) { InitializeFftData(); }
61
Evaluate(absl::Span<const Literal * const> arg_literals={})62 StatusOr<Literal> Evaluate(
63 absl::Span<const Literal* const> arg_literals = {}) {
64 if (use_bfloat16_) {
65 HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie();
66 }
67 return evaluator_.Evaluate(*m_->entry_computation(), arg_literals);
68 }
69
70 // Evaluate function that takes in a local module instead of using m_
71 // that is in HloTestBase. Once m_ in HloTestBase is
72 // removed, this should be the default Evaluate function.
EvaluateWithModule(HloModule * module,absl::Span<const Literal * const> arg_literals={})73 Literal EvaluateWithModule(
74 HloModule* module, absl::Span<const Literal* const> arg_literals = {}) {
75 if (use_bfloat16_) {
76 HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie();
77 }
78 return evaluator_.Evaluate(*module->entry_computation(), arg_literals)
79 .ConsumeValueOrDie();
80 }
81
TestUnaryOp(HloOpcode opcode,Literal expected,Literal input,float aabs=0)82 void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input,
83 float aabs = 0) {
84 HloComputation::Builder b(TestName());
85 auto c1 =
86 b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
87 b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1));
88 m_->AddEntryComputation(b.Build());
89
90 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
91
92 auto element_type = expected.shape().element_type();
93 if (element_type == F32 || element_type == F64) {
94 ErrorSpec error(aabs);
95 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error));
96 } else {
97 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
98 }
99 }
100
TestBinaryOp(HloOpcode opcode,Literal expected,Literal lhs,Literal rhs)101 void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs,
102 Literal rhs) {
103 HloComputation::Builder b(TestName());
104 auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
105 auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
106 b.AddInstruction(
107 HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2));
108 m_->AddEntryComputation(b.Build());
109
110 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
111
112 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
113 }
114
TestTernaryOp(HloOpcode opcode,Literal expected,Literal src0,Literal src1,Literal src2)115 void TestTernaryOp(HloOpcode opcode, Literal expected, Literal src0,
116 Literal src1, Literal src2) {
117 HloComputation::Builder b(TestName());
118 auto operand0 =
119 b.AddInstruction(HloInstruction::CreateConstant(std::move(src0)));
120 auto operand1 =
121 b.AddInstruction(HloInstruction::CreateConstant(std::move(src1)));
122 auto operand2 =
123 b.AddInstruction(HloInstruction::CreateConstant(std::move(src2)));
124 b.AddInstruction(HloInstruction::CreateTernary(
125 expected.shape(), opcode, operand0, operand1, operand2));
126 m_->AddEntryComputation(b.Build());
127
128 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
129
130 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
131 }
132
MaxComputationScalarF32()133 std::unique_ptr<HloComputation> MaxComputationScalarF32() {
134 HloComputation::Builder max_computation("max");
135 Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
136 auto param_lhs = max_computation.AddInstruction(
137 HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
138 auto param_rhs = max_computation.AddInstruction(
139 HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
140 max_computation.AddInstruction(HloInstruction::CreateBinary(
141 scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs));
142 return max_computation.Build();
143 }
144
ReduceWindowMaxIotaTest(int window_size,int padding,int stride,int window_dilation,int base_dilation,const Literal & expected)145 void ReduceWindowMaxIotaTest(int window_size, int padding, int stride,
146 int window_dilation, int base_dilation,
147 const Literal& expected) {
148 HloComputation::Builder b(TestName());
149
150 // arg:
151 // f32[4,4] {
152 // { 0, 1, 2, 3 },
153 // { 4, 5, 6, 7 },
154 // { 8, 9, 10, 11 },
155 // { 12, 13, 14, 15 }
156 // }
157 auto arg_array = absl::make_unique<Array2D<float>>(4, 4);
158 arg_array->FillIota(0);
159 auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
160
161 HloInstruction* arg_instruction = b.AddInstruction(
162 HloInstruction::CreateConstant(std::move(arg_literal)));
163 auto init_value = b.AddInstruction(
164 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
165 auto max_func = m_->AddEmbeddedComputation(MaxComputationScalarF32());
166
167 Window window;
168 WindowDimension dim;
169 dim.set_size(window_size);
170 dim.set_stride(stride);
171 dim.set_padding_low(padding);
172 dim.set_padding_high(padding);
173 dim.set_window_dilation(window_dilation);
174 dim.set_base_dilation(base_dilation);
175 *window.add_dimensions() = dim;
176 *window.add_dimensions() = dim;
177
178 int dim0 = expected.shape().dimensions(0);
179 int dim1 = expected.shape().dimensions(1);
180 Shape shape = ShapeUtil::MakeShape(F32, {dim0, dim1});
181 b.AddInstruction(HloInstruction::CreateReduceWindow(
182 shape, arg_instruction, init_value, window, max_func));
183
184 m_->AddEntryComputation(b.Build());
185 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
186 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
187 }
188
189 protected:
HloEvaluatorTest(bool use_bfloat16)190 explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) {
191 InitializeFftData();
192 }
193
194 // Initializes data sets used in FFT tests below.
195 void InitializeFftData();
196
197 HloEvaluator evaluator_;
198
199 const bool use_bfloat16_;
200 std::unique_ptr<HloModule> m_ = CreateNewVerifiedModule();
201
202 // Data sets used in FFT tests below.
203 ErrorSpec fft_error_ = ErrorSpec(1e-4, 1e-5);
204 Literal fft_c64x2x4x8_;
205 Literal fft_c64x2x4x8_1d_;
206 Literal fft_c64x2x4x8_2d_;
207 Literal fft_c64x2x4x8_3d_;
208 };
209
// Lets you write TEST_Ps that run twice, once with and once without bf16.
//
// The bool test parameter is forwarded to the base fixture as its
// use_bfloat16 flag.
class HloEvaluatorBf16Test : public ::testing::WithParamInterface<bool>,
                             public HloEvaluatorTest {
 protected:
  HloEvaluatorBf16Test() : HloEvaluatorTest(/*use_bfloat16=*/GetParam()) {}
};
216
// Instantiates every HloEvaluatorBf16Test TEST_P once per entry of
// use_bf16_params.
INSTANTIATE_TEST_SUITE_P(HloEvaluatorTest_Instantiation, HloEvaluatorBf16Test,
                         ::testing::ValuesIn(use_bf16_params));
219
220 // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp
221 // with 3 operands.
TEST_P(HloEvaluatorBf16Test,DoesClamp)222 TEST_P(HloEvaluatorBf16Test, DoesClamp) {
223 auto low = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
224 auto value = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
225 auto high = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
226
227 Shape shape = low.shape();
228 HloComputation::Builder b(TestName());
229 auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
230 auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
231 auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
232 b.AddInstruction(
233 HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
234 m_->AddEntryComputation(b.Build());
235
236 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
237
238 auto expected = LiteralUtil::CreateR2<float>({{0, 4}, {2, 4}});
239
240 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
241 }
242
243 // Verifies that clamping of int64 does not cause loss of precision
TEST_P(HloEvaluatorBf16Test,DoesClampInt64)244 TEST_P(HloEvaluatorBf16Test, DoesClampInt64) {
245 auto ones = [](int bits) { return (int64{1} << bits) - 1; };
246
247 auto low =
248 LiteralUtil::CreateR2<int64>({{0, ones(54)}, {ones(54), ones(58)}});
249 auto value = LiteralUtil::CreateR2<int64>({{0, ones(56)}, {0, ones(58)}});
250 auto high = LiteralUtil::CreateR2<int64>(
251 {{ones(54), ones(55)}, {ones(56), ones(58)}});
252
253 Shape shape = low.shape();
254 HloComputation::Builder b(TestName());
255 auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
256 auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
257 auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
258 b.AddInstruction(
259 HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
260 m_->AddEntryComputation(b.Build());
261
262 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
263
264 auto expected =
265 LiteralUtil::CreateR2<int64>({{0, ones(55)}, {ones(54), ones(58)}});
266
267 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
268 }
269
// (Disabled.) Clamp with scalar low/high bounds and a 2x2 value operand; the
// result takes the shape of the value operand.
TEST_P(HloEvaluatorBf16Test, DISABLED_DoesClampSpecialBroadcast) {
  Literal lower = LiteralUtil::CreateR0<float>(0.f);
  Literal operand = LiteralUtil::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
  Literal upper = LiteralUtil::CreateR0<float>(1.f);

  // Take the shape before the literals are moved into constants.
  Shape result_shape = operand.shape();

  HloComputation::Builder builder(TestName());
  HloInstruction* lower_inst =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(lower)));
  HloInstruction* operand_inst = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(operand)));
  HloInstruction* upper_inst =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(upper)));
  builder.AddInstruction(HloInstruction::CreateTernary(
      result_shape, HloOpcode::kClamp, lower_inst, operand_inst, upper_inst));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  Literal want = LiteralUtil::CreateR2<float>({{0, 0}, {1, 1}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
290
291 // Verifies that HloEvaluator evaluates a HLO instruction that performs select
292 // with 3 operands.
TEST_P(HloEvaluatorBf16Test,DoesSelect)293 TEST_P(HloEvaluatorBf16Test, DoesSelect) {
294 auto pred = LiteralUtil::CreateR2<bool>({{true, false}, {false, true}});
295 auto on_true = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
296 auto on_false = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
297
298 Shape shape = on_true.shape();
299 HloComputation::Builder b(TestName());
300 auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred)));
301 auto c2 =
302 b.AddInstruction(HloInstruction::CreateConstant(std::move(on_true)));
303 auto c3 =
304 b.AddInstruction(HloInstruction::CreateConstant(std::move(on_false)));
305 b.AddInstruction(
306 HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3));
307 m_->AddEntryComputation(b.Build());
308
309 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
310
311 auto expected = LiteralUtil::CreateR2<float>({{2, 5}, {0, 4}});
312
313 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
314 }
315
316 // Verifies that HloEvaluator evaluates a HLO instruction that performs
317 // element-wise addition with 2 operands.
TEST_F(HloEvaluatorTest,DoesAdd)318 TEST_F(HloEvaluatorTest, DoesAdd) {
319 auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
320 auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
321 auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-96, 8}});
322 TestBinaryOp(HloOpcode::kAdd, std::move(expected), std::move(lhs),
323 std::move(rhs));
324 }
325 // Verifies that HloEvaluator evaluates a HLO instruction that performs
326 // element-wise and with 2 operands.
TEST_P(HloEvaluatorBf16Test,DoesAnd)327 TEST_P(HloEvaluatorBf16Test, DoesAnd) {
328 auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
329 auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
330 auto expected = LiteralUtil::CreateR2<int64>({{0, 0}, {4, 4}});
331 TestBinaryOp(HloOpcode::kAnd, std::move(expected), std::move(lhs),
332 std::move(rhs));
333 }
334 // Verifies that HloEvaluator evaluates a HLO instruction that performs
335 // element-wise or with 2 operands.
TEST_F(HloEvaluatorTest,DoesOr)336 TEST_F(HloEvaluatorTest, DoesOr) {
337 auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
338 auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
339 auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-100, 4}});
340 TestBinaryOp(HloOpcode::kOr, std::move(expected), std::move(lhs),
341 std::move(rhs));
342 }
343 // Verifies that HloEvaluator evaluates a HLO instruction that performs
344 // element-wise or with 2 operands.
TEST_F(HloEvaluatorTest,DoesXor)345 TEST_F(HloEvaluatorTest, DoesXor) {
346 auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
347 auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
348 auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-104, 0}});
349 TestBinaryOp(HloOpcode::kXor, std::move(expected), std::move(lhs),
350 std::move(rhs));
351 }
352 // Verifies that HloEvaluator evaluates a HLO instruction that performs
353 // element-wise multiply with 2 operands.
TEST_F(HloEvaluatorTest,DoesMultiply)354 TEST_F(HloEvaluatorTest, DoesMultiply) {
355 auto lhs = LiteralUtil::CreateR2<int32>({{-1, 0}, {-100, 4}});
356 auto rhs = LiteralUtil::CreateR2<int32>(
357 {{std::numeric_limits<int32>::min(), 4}, {4, 4}});
358 auto expected = LiteralUtil::CreateR2<int32>(
359 {{std::numeric_limits<int32>::min(), 0}, {-400, 16}});
360 TestBinaryOp(HloOpcode::kMultiply, std::move(expected), std::move(lhs),
361 std::move(rhs));
362 }
363 // Verifies that HloEvaluator evaluates a HLO instruction that performs
364 // element-wise divide with 2 operands.
TEST_F(HloEvaluatorTest,DoesDivideInt64)365 TEST_F(HloEvaluatorTest, DoesDivideInt64) {
366 auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
367 auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
368 auto expected = LiteralUtil::CreateR2<int64>({{0, 0}, {-25, 1}});
369 TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs),
370 std::move(rhs));
371 }
372
// Checks kClamp(low, value, high) on rank-1 int64 operands with large
// magnitudes.
TEST_F(HloEvaluatorTest, DoesClampS64) {
  Literal lower = LiteralUtil::CreateR1<int64>(
      {-8616761059752331528LL, 6780561065411491190LL, -8616761059752331528LL});
  Literal operand = LiteralUtil::CreateR1<int64>(
      {-6780561065411491190LL, 6780561065411491180LL, 4241131823772864090LL});
  Literal upper = LiteralUtil::CreateR1<int64>(
      {-6780561065411491180LL, 8616761059752331528LL, 3832151243857508051LL});
  Literal want = LiteralUtil::CreateR1<int64>(
      {-6780561065411491190LL, 6780561065411491190LL, 3832151243857508051LL});

  TestTernaryOp(HloOpcode::kClamp, std::move(want), std::move(lower),
                std::move(operand), std::move(upper));
}
385
// Checks element-wise kDivide of two 2x2 f64 operands.
TEST_P(HloEvaluatorBf16Test, DoesDivideDouble) {
  TestBinaryOp(
      HloOpcode::kDivide,
      /*expected=*/
      LiteralUtil::CreateR2<double>({{0.45454545454545453, 0}, {-25, 1}}),
      /*lhs=*/LiteralUtil::CreateR2<double>({{1.0, 0.0}, {-100.0, 4.0}}),
      /*rhs=*/LiteralUtil::CreateR2<double>({{2.2, 4.0}, {4.0, 4.0}}));
}
394
395 // Verifies that HloEvaluator evaluates a HLO instruction that performs
396 // element-wise abs op with 1 operand.
TEST_F(HloEvaluatorTest,DoesAbsR2)397 TEST_F(HloEvaluatorTest, DoesAbsR2) {
398 auto operand = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
399 auto expected = LiteralUtil::CreateR2<int64>({{1, 20}, {100, 4}});
400 TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
401 }
// Checks kAbs on a scalar f32 operand.
TEST_P(HloEvaluatorBf16Test, DoesAbsR0) {
  TestUnaryOp(HloOpcode::kAbs,
              /*expected=*/LiteralUtil::CreateR0<float>(1.0f),
              /*input=*/LiteralUtil::CreateR0<float>(-1.0f));
}
// Checks kAbs on an empty rank-1 f32 operand.
TEST_P(HloEvaluatorBf16Test, DoesAbsR1WithZeroSize) {
  TestUnaryOp(HloOpcode::kAbs,
              /*expected=*/LiteralUtil::CreateR1<float>({}),
              /*input=*/LiteralUtil::CreateR1<float>({}));
}
412
// Checks kAbs on a scalar complex128 operand (1 + 2i); the f64 result is
// compared with an absolute tolerance of 3e-06.
TEST_F(HloEvaluatorTest, DoesAbsC128) {
  Literal input = LiteralUtil::CreateR0<complex128>({1, 2});
  Literal want_magnitude = LiteralUtil::CreateR0<double>(2.23607);
  TestUnaryOp(HloOpcode::kAbs, std::move(want_magnitude), std::move(input),
              3e-06);
}
418
// Checks element-wise kNegate on a 2x2 int32 operand; negating the minimum
// value is expected to give the minimum value back.
TEST_F(HloEvaluatorTest, DoesNegateR2) {
  Literal input = LiteralUtil::CreateR2<int32>(
      {{0, std::numeric_limits<int32>::min()}, {-1, 4}});
  Literal want = LiteralUtil::CreateR2<int32>(
      {{0, std::numeric_limits<int>::min()}, {1, -4}});
  TestUnaryOp(HloOpcode::kNegate, std::move(want), std::move(input));
}
// Checks element-wise kCos at 0, +/-pi and 2*pi, with a looser tolerance in
// bf16 mode.
TEST_P(HloEvaluatorBf16Test, DoesCosR2) {
  const float tolerance = use_bfloat16_ ? 0.031250 : 9.5367431640625E-7;
  Literal angles =
      LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
  Literal want = LiteralUtil::CreateR2<float>({{1, -1}, {-1, 1}});
  TestUnaryOp(HloOpcode::kCos, std::move(want), std::move(angles), tolerance);
}
// Checks element-wise kSin at 0, +/-pi and 2*pi, with a looser tolerance in
// bf16 mode.
TEST_P(HloEvaluatorBf16Test, DoesSinR2) {
  const float tolerance = use_bfloat16_ ? 0.031250 : 9.5367431640625E-7;
  Literal angles =
      LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
  Literal want = LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}});
  TestUnaryOp(HloOpcode::kSin, std::move(want), std::move(angles), tolerance);
}
// Checks element-wise kNot on a 2x2 int32 operand.
TEST_F(HloEvaluatorTest, DoesNotR2) {
  const int kMin = std::numeric_limits<int>::min();
  const int kMax = std::numeric_limits<int>::max();

  Literal input = LiteralUtil::CreateR2<int32>({{0, kMin}, {-1, kMax}});
  Literal want = LiteralUtil::CreateR2<int32>({{-1, kMax}, {0, kMin}});
  TestUnaryOp(HloOpcode::kNot, std::move(want), std::move(input));
}
447
// Checks that kReal extracts the real parts of a rank-1 complex128 operand.
TEST_F(HloEvaluatorTest, DoesRealC128) {
  Literal input = LiteralUtil::CreateR1<complex128>({{1, 0}, {-100, 4}});
  Literal want_real = LiteralUtil::CreateR1<double>({1, -100});
  TestUnaryOp(HloOpcode::kReal, std::move(want_real), std::move(input));
}
453
// Checks that kImag extracts the imaginary parts of a rank-1 complex128
// operand.
TEST_F(HloEvaluatorTest, DoesImagC128) {
  Literal input = LiteralUtil::CreateR1<complex128>({{1, 0}, {-100, 4}});
  Literal want_imag = LiteralUtil::CreateR1<double>({0, 4});
  TestUnaryOp(HloOpcode::kImag, std::move(want_imag), std::move(input));
}
459
460 // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor
461 // constant operands.
TEST_F(HloEvaluatorTest,DoesTraverseInstructions)462 TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
463 auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
464 auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
465 auto rhs2 = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
466 std::vector<const Literal*> args = {&lhs, &rhs, &rhs2};
467
468 Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
469
470 HloComputation::Builder b(TestName());
471 auto param_lhs =
472 b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs"));
473 auto param_rhs =
474 b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs"));
475 auto lhs_instruction = b.AddInstruction(HloInstruction::CreateBinary(
476 shape, HloOpcode::kAdd, param_lhs, param_rhs));
477
478 auto param_rhs2 =
479 b.AddInstruction(HloInstruction::CreateParameter(2, shape, "rhs2"));
480 b.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kAdd,
481 lhs_instruction, param_rhs2));
482 m_->AddEntryComputation(b.Build());
483
484 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate(args));
485
486 auto expected = LiteralUtil::CreateR2<int64>({{4, -16}, {-196, 12}});
487
488 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
489 }
490
491 // Verifies Reshape operation is correctly evaluated.
TEST_F(HloEvaluatorTest,DoesReshape)492 TEST_F(HloEvaluatorTest, DoesReshape) {
493 HloComputation::Builder b(TestName());
494 const int64 dimensions[] = {11, 8, 7, 5, 9};
495 TF_ASSERT_OK_AND_ASSIGN(auto literal,
496 LiteralUtil::CreateRandomLiteral<F32>(
497 ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
498 auto literal_clone = literal.Clone();
499 HloInstruction* literal_instruction =
500 b.AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
501
502 Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
503 const int64 permutation[] = {1, 2, 0, 4, 3};
504 b.AddInstruction(
505 HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
506 m_->AddEntryComputation(b.Build());
507
508 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
509
510 using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
511 result.EachCell<NativeT>([&](absl::Span<const int64> indices, NativeT value) {
512 std::vector<int64> rindexes = PermuteInverse(indices, permutation);
513 EXPECT_NEAR(value, literal_clone.Get<NativeT>(rindexes), 0.031250);
514 });
515 }
516
517 // Verifies Broadcast operation is correctly evaluated.
TEST_F(HloEvaluatorTest,DoesBroadcast)518 TEST_F(HloEvaluatorTest, DoesBroadcast) {
519 HloComputation::Builder b(TestName());
520 auto input_literal = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
521 auto output_literal = LiteralUtil::CreateR3<int32>(
522 {{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}});
523 HloInstruction* literal_instruction = b.AddInstruction(
524 HloInstruction::CreateConstant(std::move(input_literal)));
525 b.AddInstruction(HloInstruction::CreateBroadcast(
526 output_literal.shape(), literal_instruction, {1, 2}));
527 m_->AddEntryComputation(b.Build());
528
529 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
530
531 EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
532 }
533
// Checks that a scalar int32 operand is broadcast into a 6x2 result.
TEST_F(HloEvaluatorTest, DoesBroadcastScalar) {
  Literal scalar = LiteralUtil::CreateR0<int32>(111);
  Literal want = LiteralUtil::CreateR2<int32>(
      {{111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}});

  HloComputation::Builder builder(TestName());
  HloInstruction* scalar_inst =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(scalar)));
  // Broadcast dimension should be empty in the case of scalars.
  builder.AddInstruction(
      HloInstruction::CreateBroadcast(want.shape(), scalar_inst,
                                      /*broadcast_dimensions=*/{}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({}));

  EXPECT_TRUE(LiteralTestUtil::Equal(actual, want));
}
552
// Checks concatenation of two 2x2 int64 operands along dimension 0.
TEST_F(HloEvaluatorTest, DoesConcatenateSimple) {
  HloComputation::Builder builder(TestName());

  HloInstruction* first = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<int64>({{-1, -2}, {100, 200}})));
  HloInstruction* second =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<int64>({{-2, -3}, {-100, -200}})));
  std::vector<HloInstruction*> pieces = {first, second};

  Shape concat_shape = ShapeUtil::MakeShape(S64, {4, 2});
  builder.AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, pieces, 0));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  Literal want = LiteralUtil::CreateR2<int64>(
      {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
574
// Checks concatenation when one of the rank-1 operands has zero elements.
TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
  HloComputation::Builder builder(TestName());

  HloInstruction* non_empty = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({100, 200})));
  HloInstruction* empty = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({})));
  std::vector<HloInstruction*> pieces = {non_empty, empty};

  Shape concat_shape = ShapeUtil::MakeShape(S64, {2});
  builder.AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, pieces, 0));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  Literal want = LiteralUtil::CreateR1<int64>({100, 200});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
595
// Checks kConvert from int32 to f32 when source and destination shapes have
// equal layouts.
TEST_P(HloEvaluatorBf16Test, ConvertWithSameLayout) {
  Literal source = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
  Literal want =
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
  // Precondition of this test: both literals share the same layout.
  ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(source.shape(), want.shape()));

  HloComputation::Builder builder(TestName());
  HloInstruction* source_inst =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(source)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(want.shape(), source_inst));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(actual, want));
}
614
// Checks kConvert from int32 to f32 when source ({0, 1}) and destination
// ({1, 0}) layouts differ.
TEST_P(HloEvaluatorBf16Test, ConvertWithDifferentLayout) {
  Literal source = LiteralUtil::CreateR2WithLayout<int32>(
      {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
  Literal want = LiteralUtil::CreateR2WithLayout<float>(
      {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0}));
  // Precondition of this test: the two layouts really are different.
  ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(source.shape(), want.shape()));

  HloComputation::Builder builder(TestName());
  HloInstruction* source_inst =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(source)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(want.shape(), source_inst));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(actual, want));
}
634
CreatePaddingConfig(std::initializer_list<std::array<int64,3>> padding_dimensions)635 PaddingConfig CreatePaddingConfig(
636 std::initializer_list<std::array<int64, 3>> padding_dimensions) {
637 PaddingConfig padding_config;
638
639 for (auto& paddings_per_dim : padding_dimensions) {
640 auto dimension = padding_config.add_dimensions();
641 dimension->set_edge_padding_low(paddings_per_dim[0]);
642 dimension->set_edge_padding_high(paddings_per_dim[1]);
643 dimension->set_interior_padding(paddings_per_dim[2]);
644 }
645 return padding_config;
646 }
647
// Pads an int32 operand that has a zero-sized dimension; since the operand
// contributes no elements, the 5x2 result consists of the pad value only.
TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
  constexpr int32 kPadValue = 10;

  HloComputation::Builder builder(TestName());
  HloInstruction* empty_operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<int32>({{}, {}})));
  HloInstruction* pad_scalar =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<int32>(kPadValue)));

  // Per dimension: {edge_padding_low, edge_padding_high, interior_padding}.
  PaddingConfig config = CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}});
  Shape padded_shape = ShapeUtil::MakeShape(S32, {5, 2});
  builder.AddInstruction(HloInstruction::CreatePad(padded_shape, empty_operand,
                                                   pad_scalar, config));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  Literal want = LiteralUtil::CreateR2<int32>(
      {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
672
// Pads a 3x2x1x1 f32 operand with edge and interior padding on dimensions 0
// and 1, producing an 8x5x1x1 result.
TEST_P(HloEvaluatorBf16Test, Pad4DFloatArrayWithInteriorPadding) {
  constexpr float kPadValue = 1.5;

  HloComputation::Builder builder(TestName());

  Array4D<float> source_values(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
  HloInstruction* source_inst =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(source_values)));
  HloInstruction* pad_scalar =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<float>(kPadValue)));

  // Per dimension: {edge_padding_low, edge_padding_high, interior_padding}.
  PaddingConfig config =
      CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}});
  Shape padded_shape = ShapeUtil::MakeShape(F32, {8, 5, 1, 1});
  builder.AddInstruction(HloInstruction::CreatePad(padded_shape, source_inst,
                                                   pad_scalar, config));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  // Start from an all-padding array, then place the six source values at the
  // positions they are expected to land on.
  Array4D<float> want_values(8, 5, 1, 1);
  want_values.Fill(kPadValue);
  want_values(1, 0, 0, 0) = 1.0f;
  want_values(1, 2, 0, 0) = 2.0f;
  want_values(4, 0, 0, 0) = 3.0f;
  want_values(4, 2, 0, 0) = 4.0f;
  want_values(7, 0, 0, 0) = 5.0f;
  want_values(7, 2, 0, 0) = 6.0f;

  Literal want = LiteralUtil::CreateR4FromArray4D<float>(want_values);
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
707
// Pads a 4x3 f32 operand using negative edge padding, producing a 1x5 result
// that is compared with an absolute tolerance of 0.031250.
TEST_P(HloEvaluatorBf16Test, NegativePadding2D) {
  HloComputation::Builder builder(TestName());

  // Source contents, as noted by the original author:
  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  Array2D<float> source_values(4, 3);
  source_values.FillUnique(1.0f);
  HloInstruction* source_inst =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(source_values)));

  HloInstruction* pad_scalar = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.718f)));

  // Per dimension: {edge_padding_low, edge_padding_high, interior_padding}.
  PaddingConfig config = CreatePaddingConfig({{{-1, -2, 0}}, {{-2, 4, 0}}});
  Shape padded_shape = ShapeUtil::MakeShape(F32, {1, 5});
  builder.AddInstruction(HloInstruction::CreatePad(padded_shape, source_inst,
                                                   pad_scalar, config));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
  Array2D<float> want_values(1, 5);
  want_values(0, 0) = 7.0f;
  for (int col = 1; col < 5; ++col) {
    want_values(0, col) = 2.718f;
  }
  Literal want = LiteralUtil::CreateR2FromArray2D<float>(want_values);

  EXPECT_TRUE(LiteralTestUtil::Near(want, actual, ErrorSpec(0.031250)));
}
749
// Combines negative edge padding with interior padding on an f32[4,3]
// constant so that the result has a zero-sized leading dimension, and checks
// that the evaluator returns an empty f32[0,9] literal.
TEST_P(HloEvaluatorBf16Test, NegativeAndInteriorPadding2D) {
  HloComputation::Builder b(TestName());

  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  auto input_array = absl::make_unique<Array2D<float>>(4, 3);
  input_array->FillUnique(1.0f);
  auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
  HloInstruction* input_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));

  auto pad_value_instruction = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.718f)));

  // Negative padding that results in zero dimensions.
  auto r2_padding_on_dim0_dim1 =
      CreatePaddingConfig({{{-2, -5, 1}}, {{-2, 4, 2}}});

  Shape shape = ShapeUtil::MakeShape(F32, {0, 9});
  b.AddInstruction(HloInstruction::CreatePad(shape, input_instruction,
                                             pad_value_instruction,
                                             r2_padding_on_dim0_dim1));

  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // The expected literal has no elements; only its shape is meaningful.
  auto expected_array = absl::make_unique<Array2D<float>>(0, 9);
  auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
788
// Evaluates a dot that contracts dimension 1 of a 4x1 column with dimension 0
// of a single-row {{1, 2}} operand and checks the resulting f32[4,2] product.
TEST_P(HloEvaluatorBf16Test, DotRank2AndRank1) {
  HloComputation::Builder builder(TestName());

  // Left operand:
  // f32[4,1] {
  //  { 1 },
  //  { 2 },
  //  { 3 },
  //  { 4 },
  // }
  Array2D<float> lhs_values(4, 1);
  lhs_values.FillUnique(1.0f);
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(lhs_values)));

  // Right operand: built with CreateR2 as one row holding { 1, 2 }.
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({{1, 2}})));

  DotDimensionNumbers contraction;
  contraction.add_lhs_contracting_dimensions(1);
  contraction.add_rhs_contracting_dimensions(0);

  const Shape product_shape = ShapeUtil::MakeShape(F32, {4, 2});
  builder.AddInstruction(HloInstruction::CreateDot(
      product_shape, lhs, rhs, contraction, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Row i of the product is (i + 1) * { 1, 2 }.
  // clang-format off
  const Array2D<float> expected_values({
      {1.f, 2.f},
      {2.f, 4.f},
      {3.f, 6.f},
      {4.f, 8.f},
  });
  // clang-format on

  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR2FromArray2D<float>(expected_values), result));
}
834
// Evaluates a dot that contracts dimension 0 of an f32[3] vector with
// dimension 0 of an f32[3,2] matrix and checks the f32[2] result { 22, 28 }.
TEST_P(HloEvaluatorBf16Test, DotRank1AndRank2) {
  HloComputation::Builder builder(TestName());

  // Left operand:
  // f32[3] { 1, 2, 3 }
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1, 2, 3})));

  // Right operand:
  // f32[3,2] {
  //  { 1, 2 },
  //  { 3, 4 },
  //  { 5, 6 },
  // }
  Array2D<float> rhs_values(3, 2);
  rhs_values.FillUnique(1.0f);
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(rhs_values)));

  DotDimensionNumbers contraction;
  contraction.add_lhs_contracting_dimensions(0);
  contraction.add_rhs_contracting_dimensions(0);

  const Shape product_shape = ShapeUtil::MakeShape(F32, {2});
  builder.AddInstruction(HloInstruction::CreateDot(
      product_shape, lhs, rhs, contraction, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // 1*1 + 2*3 + 3*5 = 22 and 1*2 + 2*4 + 3*6 = 28.
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR1<float>({22.f, 28.f}), result));
}
872
// Evaluates a plain matrix product: an f32[4,3] operand contracted on
// dimension 1 against dimension 0 of an f32[3,2] operand, giving f32[4,2].
TEST_P(HloEvaluatorBf16Test, DotRank2AndRank2) {
  HloComputation::Builder builder(TestName());

  // Left operand:
  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  Array2D<float> lhs_values(4, 3);
  lhs_values.FillUnique(1.0f);
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(lhs_values)));

  // Right operand:
  // f32[3,2] {
  //  { 1, 2 },
  //  { 3, 4 },
  //  { 5, 6 },
  // }
  Array2D<float> rhs_values(3, 2);
  rhs_values.FillUnique(1.0f);
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(rhs_values)));

  DotDimensionNumbers contraction;
  contraction.add_lhs_contracting_dimensions(1);
  contraction.add_rhs_contracting_dimensions(0);

  const Shape product_shape = ShapeUtil::MakeShape(F32, {4, 2});
  builder.AddInstruction(HloInstruction::CreateDot(
      product_shape, lhs, rhs, contraction, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  const Array2D<float> expected_values({
      {22.f, 28.f},
      {58.f, 76.f},
      {94.f, 124.f},
      {130.f, 172.f},
  });

  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR2FromArray2D<float>(expected_values), result));
}
922
// Evaluates a dot of two f32[2,2,3,1] operands with dimension 0 as the batch
// dimension and dimensions 1 and 2 contracted, and checks the f32[2,1,1]
// result against sums computed directly in the test.
TEST_P(HloEvaluatorBf16Test, DotRank4AndRank4) {
  HloComputation::Builder b(TestName());

  auto lhs_array = absl::make_unique<Array4D<float>>(2, 2, 3, 1);
  lhs_array->FillIota(1.0f);
  auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(*lhs_array);
  HloInstruction* lhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));

  auto rhs_array = absl::make_unique<Array4D<float>>(2, 2, 3, 1);
  rhs_array->FillIota(2.0f);
  auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(*rhs_array);
  HloInstruction* rhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));

  Shape shape = ShapeUtil::MakeShape(F32, {2, 1, 1});
  DotDimensionNumbers dot_dnums;

  dot_dnums.add_lhs_batch_dimensions(0);
  dot_dnums.add_rhs_batch_dimensions(0);
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_lhs_contracting_dimensions(2);
  dot_dnums.add_rhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(2);
  b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
                                             rhs_instruction, dot_dnums,
                                             DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Returns the sum of i * i + i (i.e. i * (i + 1)) for i in [first, last).
  // An integer counter is used instead of a floating-point one; every term and
  // partial sum here is a small integer, so the float total is exact.
  // NOTE(review): this presumes FillIota(1.0f)/FillIota(2.0f) produce
  // consecutive values so that each rhs element is its lhs partner plus one --
  // confirm against Array::FillIota.
  auto sum_of_products = [](int first, int last) {
    float total = 0;
    for (int i = first; i < last; ++i) {
      total += static_cast<float>(i * i + i);
    }
    return total;
  };
  const float expected_1 = sum_of_products(1, 7);   // Batch 0: i = 1..6.
  const float expected_2 = sum_of_products(7, 13);  // Batch 1: i = 7..12.

  auto expected_array = Array3D<float>({{{expected_1}}, {{expected_2}}});
  auto expected = LiteralUtil::CreateR3FromArray3D<float>(expected_array);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
967
// Convolves the 1x1x3 input {1, 2, 3} with the 1x1x2 kernel {3, 4} using a
// size-2, stride-1 window padded by one element on the high side, and checks
// the f32[1,1,3] result {11, 18, 9}.
TEST_P(HloEvaluatorBf16Test, SimpleConv1D) {
  HloComputation::Builder builder(TestName());

  Array3D<float> input_values = {{{1, 2, 3}}};
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR3FromArray3D<float>(input_values)));

  Array3D<float> kernel_values = {{{3.f, 4.f}}};
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR3FromArray3D<float>(kernel_values)));

  // The single spatial dimension: no dilation, padding only on the high side.
  WindowDimension spatial_dim;
  spatial_dim.set_size(2);
  spatial_dim.set_stride(1);
  spatial_dim.set_padding_low(0);
  spatial_dim.set_padding_high(1);
  spatial_dim.set_window_dilation(1);
  spatial_dim.set_base_dilation(1);
  Window window;
  *window.add_dimensions() = spatial_dim;

  // Input and output: batch = 0, feature = 1, spatial = 2.
  // Kernel: output feature = 0, input feature = 1, spatial = 2.
  ConvolutionDimensionNumbers conv_dims;
  conv_dims.set_input_batch_dimension(0);
  conv_dims.set_input_feature_dimension(1);
  conv_dims.add_input_spatial_dimensions(2);
  conv_dims.set_output_batch_dimension(0);
  conv_dims.set_output_feature_dimension(1);
  conv_dims.add_output_spatial_dimensions(2);
  conv_dims.set_kernel_output_feature_dimension(0);
  conv_dims.set_kernel_input_feature_dimension(1);
  conv_dims.add_kernel_spatial_dimensions(2);

  const Shape output_shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      output_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, conv_dims, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // 1*3 + 2*4 = 11, 2*3 + 3*4 = 18, 3*3 + (padding) = 9.
  Array3D<float> expected_values = {{{11.f, 18.f, 9.f}}};

  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR3FromArray3D<float>(expected_values), result));
}
1016
// Convolves a 1x1x4x4 input with a 1x1x2x2 kernel. Both spatial dimensions use
// the same size-2, stride-1 window with one element of high padding, so the
// output keeps the 4x4 spatial extent.
TEST_P(HloEvaluatorBf16Test, Simple4x4Conv2DWith2x2Kernel) {
  HloComputation::Builder builder(TestName());

  Array4D<float> input_values(1, 1, 4, 4);
  // clang-format off
  input_values.FillWithYX(Array2D<float>({
    {1,  2,  3,  4 },
    {5,  6,  7,  8 },
    {9,  10, 11, 12},
    {13, 14, 15, 16},
  }));
  // clang-format on
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  Array4D<float> kernel_values(1, 1, 2, 2);
  // clang-format off
  kernel_values.FillWithYX(Array2D<float>({
    {5, 6},
    {7, 8},
  }));
  // clang-format on
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(kernel_values)));

  // One window description shared by both spatial dimensions.
  WindowDimension spatial_dim;
  spatial_dim.set_size(2);
  spatial_dim.set_stride(1);
  spatial_dim.set_padding_low(0);
  spatial_dim.set_padding_high(1);
  spatial_dim.set_window_dilation(1);
  spatial_dim.set_base_dilation(1);
  Window window;
  *window.add_dimensions() = spatial_dim;
  *window.add_dimensions() = spatial_dim;

  const ConvolutionDimensionNumbers conv_dims =
      XlaBuilder::CreateDefaultConvDimensionNumbers(2);

  const Shape output_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      output_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, conv_dims, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Array4D<float> expected_values(1, 1, 4, 4);
  // clang-format off
  expected_values.FillWithYX(Array2D<float>({
    {100, 126, 152,  76},
    {204, 230, 256, 124},
    {308, 334, 360, 172},
    {149, 160, 171,  80},
  }));
  // clang-format on

  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR4FromArray4D<float>(expected_values), result));
}
1079
// Convolution with non-default dimension numbers in which the kernel is first
// passed through a reverse over dimensions {3, 1} and the window is marked
// with window_reversal. The expected values differ between the f32 and bf16
// parameterizations of the test.
TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensionsReversed) {
  HloComputation::Builder builder(TestName());

  // clang-format off
  // Input dimensions: [feature=2, height=3, batch=1, width=4]
  Array4D<float> input_values({
    {{{1, 2, 3, 4}},
     {{5, 6, 7, 8}},
     {{9, 10, 11, 12}}},
    {{{13, 14, 15, 16}},
     {{17, 18, 19, 20}},
     {{21, 22, 23, 24}}}
  });
  // Weight dimensions:
  // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3]
  Array4D<float> weight_values({{
    {{1, 7, 13},
     {4, 10, 16}},
    {{2, 8, 14},
     {5, 11, 17}},
    {{3, 9, 15},
     {6, 12, 18}}
  }});
  // clang-format on

  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  HloInstruction* weights =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(weight_values)));
  // Feed the convolution a copy of the weights reversed along both kernel
  // spatial dimensions.
  HloInstruction* reversed_weights = builder.AddInstruction(
      HloInstruction::CreateReverse(weights->shape(), weights, {3, 1}));

  // Both spatial dimensions share one window description.
  WindowDimension spatial_dim;
  spatial_dim.set_size(3);
  spatial_dim.set_stride(1);
  spatial_dim.set_padding_low(0);
  spatial_dim.set_padding_high(0);
  spatial_dim.set_window_dilation(1);
  spatial_dim.set_base_dilation(1);
  spatial_dim.set_window_reversal(true);
  Window window;
  *window.add_dimensions() = spatial_dim;
  *window.add_dimensions() = spatial_dim;

  ConvolutionDimensionNumbers conv_dims;
  // Input layout.
  conv_dims.set_input_batch_dimension(2);
  conv_dims.set_input_feature_dimension(0);
  conv_dims.add_input_spatial_dimensions(1);
  conv_dims.add_input_spatial_dimensions(3);
  // Output layout mirrors the input layout.
  conv_dims.set_output_batch_dimension(2);
  conv_dims.set_output_feature_dimension(0);
  conv_dims.add_output_spatial_dimensions(1);
  conv_dims.add_output_spatial_dimensions(3);
  // Kernel layout.
  conv_dims.set_kernel_output_feature_dimension(0);
  conv_dims.set_kernel_input_feature_dimension(2);
  conv_dims.add_kernel_spatial_dimensions(3);
  conv_dims.add_kernel_spatial_dimensions(1);

  const Shape output_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      output_shape, input, reversed_weights, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, conv_dims, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // clang-format off
  // Result dimensions: [feature=1, height=1, batch=1, width=2]
  Array4D<float> expected_f32({{{{2514, 2685}}}});
  Array4D<float> expected_bf16({{{{2512, 2688}}}});
  // clang-format on
  const Array4D<float>& expected_values =
      use_bfloat16_ ? expected_bf16 : expected_f32;

  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR4FromArray4D<float>(expected_values), result));
}
1160
// Convolution with non-default dimension numbers: the input is laid out as
// [feature, height, batch, width] and the kernel as
// [output_feature, width, input_feature, height]. The expected values differ
// between the f32 and bf16 parameterizations of the test.
TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensions) {
  HloComputation::Builder builder(TestName());

  // clang-format off
  // Input dimensions: [feature=2, height=3, batch=1, width=4]
  Array4D<float> input_values({
    {{{1, 2, 3, 4}},
     {{5, 6, 7, 8}},
     {{9, 10, 11, 12}}},
    {{{13, 14, 15, 16}},
     {{17, 18, 19, 20}},
     {{21, 22, 23, 24}}}
  });
  // Weight dimensions:
  // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3]
  Array4D<float> weight_values({{
    {{1, 7, 13},
     {4, 10, 16}},
    {{2, 8, 14},
     {5, 11, 17}},
    {{3, 9, 15},
     {6, 12, 18}}
  }});
  // clang-format on

  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  HloInstruction* weights =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(weight_values)));

  // Both spatial dimensions share one window description.
  WindowDimension spatial_dim;
  spatial_dim.set_size(3);
  spatial_dim.set_stride(1);
  spatial_dim.set_padding_low(0);
  spatial_dim.set_padding_high(0);
  spatial_dim.set_window_dilation(1);
  spatial_dim.set_base_dilation(1);
  Window window;
  *window.add_dimensions() = spatial_dim;
  *window.add_dimensions() = spatial_dim;

  ConvolutionDimensionNumbers conv_dims;
  // Input layout.
  conv_dims.set_input_batch_dimension(2);
  conv_dims.set_input_feature_dimension(0);
  conv_dims.add_input_spatial_dimensions(1);
  conv_dims.add_input_spatial_dimensions(3);
  // Output layout mirrors the input layout.
  conv_dims.set_output_batch_dimension(2);
  conv_dims.set_output_feature_dimension(0);
  conv_dims.add_output_spatial_dimensions(1);
  conv_dims.add_output_spatial_dimensions(3);
  // Kernel layout.
  conv_dims.set_kernel_output_feature_dimension(0);
  conv_dims.set_kernel_input_feature_dimension(2);
  conv_dims.add_kernel_spatial_dimensions(3);
  conv_dims.add_kernel_spatial_dimensions(1);

  const Shape output_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      output_shape, input, weights, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, conv_dims, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // clang-format off
  // Result dimensions: [feature=1, height=1, batch=1, width=2]
  Array4D<float> expected_f32({{{{2514, 2685}}}});
  Array4D<float> expected_bf16({{{{2512, 2688}}}});
  // clang-format on
  const Array4D<float>& expected_values =
      use_bfloat16_ ? expected_bf16 : expected_f32;

  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR4FromArray4D<float>(expected_values), result));
}
1238
// Convolves a 1x1x4x4 input with a 1x1x2x2 kernel where both spatial
// dimensions use base dilation 2 and one element of high padding, and checks
// the f32[1,1,7,7] result.
TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithHighPadding) {
  HloComputation::Builder builder(TestName());

  Array4D<float> input_values(1, 1, 4, 4);
  // clang-format off
  input_values.FillWithYX(Array2D<float>({
    {1,  2,  3,  4 },
    {5,  6,  7,  8 },
    {9,  10, 11, 12},
    {13, 14, 15, 16},
  }));
  // clang-format on
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  Array4D<float> kernel_values(1, 1, 2, 2);
  // clang-format off
  kernel_values.FillWithYX(Array2D<float>({
    {5, 6},
    {7, 8},
  }));
  // clang-format on
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(kernel_values)));

  // One window description shared by both spatial dimensions.
  WindowDimension spatial_dim;
  spatial_dim.set_size(2);
  spatial_dim.set_stride(1);
  spatial_dim.set_padding_low(0);
  spatial_dim.set_padding_high(1);
  spatial_dim.set_window_dilation(1);
  spatial_dim.set_base_dilation(2);
  Window window;
  *window.add_dimensions() = spatial_dim;
  *window.add_dimensions() = spatial_dim;

  const ConvolutionDimensionNumbers conv_dims =
      XlaBuilder::CreateDefaultConvDimensionNumbers(2);

  const Shape output_shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      output_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, conv_dims, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Array4D<float> expected_values(1, 1, 7, 7);
  expected_values.FillWithYX(Array2D<float>({
      {5, 12, 10, 18, 15, 24, 20},
      {35, 48, 42, 56, 49, 64, 56},
      {25, 36, 30, 42, 35, 48, 40},
      {63, 80, 70, 88, 77, 96, 84},
      {45, 60, 50, 66, 55, 72, 60},
      {91, 112, 98, 120, 105, 128, 112},
      {65, 84, 70, 90, 75, 96, 80},
  }));

  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR4FromArray4D<float>(expected_values), result));
}
1302
// Convolves a 1x1x4x4 input with a 1x1x2x2 kernel where both spatial
// dimensions use base dilation 2 and one element of padding on each side, and
// checks the f32[1,1,8,8] result.
TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithLowAndHighPadding) {
  HloComputation::Builder builder(TestName());

  Array4D<float> input_values(1, 1, 4, 4);
  // clang-format off
  input_values.FillWithYX(Array2D<float>({
    {1,  2,  3,  4 },
    {5,  6,  7,  8 },
    {9,  10, 11, 12},
    {13, 14, 15, 16},
  }));
  // clang-format on
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  Array4D<float> kernel_values(1, 1, 2, 2);
  // clang-format off
  kernel_values.FillWithYX(Array2D<float>({
    {5, 6},
    {7, 8},
  }));
  // clang-format on
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(kernel_values)));

  // One window description shared by both spatial dimensions.
  WindowDimension spatial_dim;
  spatial_dim.set_size(2);
  spatial_dim.set_stride(1);
  spatial_dim.set_padding_low(1);
  spatial_dim.set_padding_high(1);
  spatial_dim.set_window_dilation(1);
  spatial_dim.set_base_dilation(2);
  Window window;
  *window.add_dimensions() = spatial_dim;
  *window.add_dimensions() = spatial_dim;

  const ConvolutionDimensionNumbers conv_dims =
      XlaBuilder::CreateDefaultConvDimensionNumbers(2);

  const Shape output_shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      output_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, conv_dims, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Array4D<float> expected_values(1, 1, 8, 8);
  expected_values.FillWithYX(Array2D<float>({
      {8, 7, 16, 14, 24, 21, 32, 28},
      {6, 5, 12, 10, 18, 15, 24, 20},
      {40, 35, 48, 42, 56, 49, 64, 56},
      {30, 25, 36, 30, 42, 35, 48, 40},
      {72, 63, 80, 70, 88, 77, 96, 84},
      {54, 45, 60, 50, 66, 55, 72, 60},
      {104, 91, 112, 98, 120, 105, 128, 112},
      {78, 65, 84, 70, 90, 75, 96, 80},
  }));

  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR4FromArray4D<float>(expected_values), result));
}
1367
// Convolves a 1x1x4x4 input with a 1x1x2x3 kernel where the two spatial
// dimensions use different window settings: different sizes, strides,
// paddings (including a negative high padding) and window/base dilations.
// Checks the f32[1,1,9,3] result.
TEST_P(HloEvaluatorBf16Test,
       DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides) {
  HloComputation::Builder builder(TestName());

  Array4D<float> input_values(1, 1, 4, 4);
  // clang-format off
  input_values.FillWithYX(Array2D<float>({
    {1,  2,  3,  4 },
    {5,  6,  7,  8 },
    {9,  10, 11, 12},
    {13, 14, 15, 16},
  }));
  // clang-format on
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  Array4D<float> kernel_values(1, 1, 2, 3);
  // clang-format off
  kernel_values.FillWithYX(Array2D<float>({
    {5, 6, 7},
    {8, 9, 10},
  }));
  // clang-format on
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(kernel_values)));

  // First spatial dimension: dilated window and dilated base, symmetric
  // padding.
  WindowDimension first_dim;
  first_dim.set_size(2);
  first_dim.set_stride(1);
  first_dim.set_padding_low(2);
  first_dim.set_padding_high(2);
  first_dim.set_window_dilation(2);
  first_dim.set_base_dilation(2);

  // Second spatial dimension: strided, base-dilated, negative high padding.
  WindowDimension second_dim;
  second_dim.set_size(3);
  second_dim.set_stride(3);
  second_dim.set_padding_low(2);
  second_dim.set_padding_high(-1);
  second_dim.set_window_dilation(1);
  second_dim.set_base_dilation(3);

  Window window;
  *window.add_dimensions() = first_dim;
  *window.add_dimensions() = second_dim;

  const ConvolutionDimensionNumbers conv_dims =
      XlaBuilder::CreateDefaultConvDimensionNumbers(2);

  const Shape output_shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      output_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, conv_dims, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Array4D<float> expected_values(1, 1, 9, 3);
  expected_values.FillWithYX(Array2D<float>({
      {10, 20, 30},
      {0, 0, 0},
      {57, 74, 91},
      {0, 0, 0},
      {125, 142, 159},
      {0, 0, 0},
      {193, 210, 227},
      {0, 0, 0},
      {91, 98, 105},
  }));

  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR4FromArray4D<float>(expected_values), result));
}
1440
// Grouped convolution (feature_group_count = 2) of a {1, 2, 2, 4} input with a
// {2, 2, 2, 8} filter. Both operands are filled with consecutive values
// (starting at -7 and -31 respectively); the f32[1,1,1,8] result is checked
// against fixed expected values.
TEST_P(HloEvaluatorBf16Test, Conv2DGroupedConvolution) {
  HloComputation::Builder builder(TestName());

  const std::vector<int64> input_dims = {1, 2, 2, 4};
  const std::vector<int64> filter_dims = {2, 2, 2, 8};
  const Shape input_shape = ShapeUtil::MakeShapeWithType<float>(input_dims);
  const Shape filter_shape = ShapeUtil::MakeShapeWithType<float>(filter_dims);

  // Tensorflow dimension numbers for 2D convolution.
  ConvolutionDimensionNumbers conv_dims;
  // Input: batch = 0, spatial = {1, 2}, feature = 3.
  conv_dims.set_input_batch_dimension(0);
  conv_dims.add_input_spatial_dimensions(1);
  conv_dims.add_input_spatial_dimensions(2);
  conv_dims.set_input_feature_dimension(3);
  // Output uses the same positions as the input.
  conv_dims.set_output_batch_dimension(0);
  conv_dims.add_output_spatial_dimensions(1);
  conv_dims.add_output_spatial_dimensions(2);
  conv_dims.set_output_feature_dimension(3);
  // Kernel: spatial = {0, 1}, input feature = 2, output feature = 3.
  conv_dims.add_kernel_spatial_dimensions(0);
  conv_dims.add_kernel_spatial_dimensions(1);
  conv_dims.set_kernel_input_feature_dimension(2);
  conv_dims.set_kernel_output_feature_dimension(3);

  // One window description shared by both spatial dimensions.
  WindowDimension spatial_dim;
  spatial_dim.set_size(2);
  spatial_dim.set_stride(1);
  spatial_dim.set_padding_low(0);
  spatial_dim.set_padding_high(0);
  spatial_dim.set_window_dilation(1);
  spatial_dim.set_base_dilation(1);
  Window window;
  *window.add_dimensions() = spatial_dim;
  *window.add_dimensions() = spatial_dim;

  // Input: consecutive values starting at -7, reshaped to input_dims.
  std::vector<float> input_flat(ShapeUtil::ElementsIn(input_shape));
  std::iota(input_flat.begin(), input_flat.end(), -7);
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>(input_flat)
          .Reshape(input_dims)
          .ConsumeValueOrDie()));

  // Filter: consecutive values starting at -31, reshaped to filter_dims.
  std::vector<float> filter_flat(ShapeUtil::ElementsIn(filter_shape));
  std::iota(filter_flat.begin(), filter_flat.end(), -31);
  HloInstruction* filter =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>(filter_flat)
              .Reshape(filter_dims)
              .ConsumeValueOrDie()));

  const Shape output_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      output_shape, input, filter,
      /*feature_group_count=*/2, /*batch_group_count=*/1, window, conv_dims,
      DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Array4D<float> expected_values(1, 1, 1, 8);
  expected_values.FillWithYX(
      Array2D<float>({{668, 664, 660, 656, 668, 680, 692, 704}}));

  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR4FromArray4D<float>(expected_values), result));
}
1502
1503 // Initialization of data sets for FFT tests:
1504
// Assigns the four complex64 rank-3 literals held by the fixture as FFT test
// data. Each literal is written out as 2 groups of 4 rows of 8 complex values
// (real, imaginary).
// NOTE(review): going by the member names, fft_c64x2x4x8_ looks like the input
// and the _1d_/_2d_/_3d_ members look like the expected transforms of that
// input over its innermost one/two/three dimensions -- confirm against the
// tests that consume them.
void HloEvaluatorTest::InitializeFftData() {
  // clang-format off
  // Base data set.
  fft_c64x2x4x8_ = LiteralUtil::CreateR3<complex64>({
    {{{0.0, 0.0}, {1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0},
      {4.0, 0.0}, {5.0, 0.0}, {6.0, 0.0}, {7.0, 0.0}},
     {{0.0, 0.0}, {0.0, 1.0}, {0.0, 2.0}, {0.0, 3.0},
      {0.0, 4.0}, {0.0, 5.0}, {0.0, 6.0}, {0.0, 7.0}},
     {{0.0, 7.0}, {1.0, 6.0}, {2.0, 5.0}, {3.0, 4.0},
      {4.0, 3.0}, {5.0, 2.0}, {6.0, 1.0}, {7.0, 0.0}},
     {{7.0, 0.0}, {6.0, 1.0}, {5.0, 2.0}, {4.0, 3.0},
      {3.0, 4.0}, {2.0, 5.0}, {1.0, 6.0}, {0.0, 7.0}}},
    {{{-4.0, 0.0}, {-3.0, 0.0}, {-2.0, 0.0}, {-1.0, 0.0},
      {1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}},
     {{0.0, -4.0}, {0.0, -3.0}, {0.0, -2.0}, {0.0, -1.0},
      {0.0, 1.0}, {0.0, 2.0}, {0.0, 3.0}, {0.0, 4.0}},
     {{3.5, 3.5}, {-1.707107, -0.707107}, {-1.0, -0.0}, {-0.707107, 0.292893},
      {-0.5, 0.5}, {-0.292893, 0.707107}, {0.0, 1.0}, {0.707107, 1.707107}},
     {{3.5, 3.5}, {1.707107, 0.707107}, {1.0, 0.0}, {0.707107, -0.292893},
      {0.5, -0.5}, {0.292893, -0.707107}, {-0.0, -1.0}, {-0.707107, -1.707107}}}
  });
  // Data set with the "_1d_" suffix.
  fft_c64x2x4x8_1d_ = LiteralUtil::CreateR3<complex64>({
    {{{28.0, 0.0}, {-4.0, 9.656854}, {-4.0, 4.0}, {-4.0, 1.656854},
      {-4.0, 0.0}, {-4.0, -1.656854}, {-4.0, -4.0}, {-4.0, -9.656854}},
     {{0.0, 28.0}, {-9.656854, -4.0}, {-4.0, -4.0}, {-1.656854, -4.0},
      {0.0, -4.0}, {1.656854, -4.0}, {4.0, -4.0}, {9.656854, -4.0}},
     {{28.0, 28.0}, {5.656854, 13.656854}, {0.0, 8.0}, {-2.343146, 5.656854},
      {-4.0, 4.0}, {-5.656854, 2.343146}, {-8.0, -0.0}, {-13.656854, -5.656854}},  // NOLINT
     {{28.0, 28.0}, {-5.656854, -13.656854}, {-0.0, -8.0}, {2.343146, -5.656854},  // NOLINT
      {4.0, -4.0}, {5.656854, -2.343146}, {8.0, 0.0}, {13.656854, 5.656854}}},
    {{{0.0, 0.0}, {-5.0, 12.071068}, {-4.0, 4.0}, {-5.0, 2.071068},
      {-4.0, 0.0}, {-5.0, -2.071068}, {-4.0, -4.0}, {-5.0, -12.071068}},
     {{0.0, 0.0}, {-12.071068, -5.0}, {-4.0, -4.0}, {-2.071068, -5.0},
      {0.0, -4.0}, {2.071068, -5.0}, {4.0, -4.0}, {12.071068, -5.0}},
     {{0.0, 7.0}, {1.0, 6.0}, {2.0, 5.0}, {3.0, 4.0},
      {4.0, 3.0}, {5.0, 2.0}, {6.0, 1.0}, {7.0, 0.0}},
     {{7.0, 0.0}, {6.0, 1.0}, {5.0, 2.0}, {4.0, 3.0},
      {3.0, 4.0}, {2.0, 5.0}, {1.0, 6.0}, {0.0, 7.0}}}
  });
  // Data set with the "_2d_" suffix.
  fft_c64x2x4x8_2d_ = LiteralUtil::CreateR3<complex64>({
    {{{84.0, 84.0}, {-13.656854, 5.656854}, {-8.0, 0.0}, {-5.656854, -2.343146},
      {-4.0, -4.0}, {-2.343146, -5.656854}, {0.0, -8.0}, {5.656854, -13.656854}},  // NOLINT
     {{0.0, 0.0}, {0.0, -0.0}, {0.0, 0.0}, {0.0, 0.0},
      {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
     {{28.0, -28.0}, {16.970562, 40.970562}, {0.0, 24.0}, {-7.029438, 16.970562},  // NOLINT
      {-12.0, 12.0}, {-16.970562, 7.029438}, {-24.0, 0.0}, {-40.970562, -16.970562}},  // NOLINT
     {{0.0, -56.0}, {-19.313708, -8.0}, {-8.0, -8.0}, {-3.313708, -8.0},
      {0.0, -8.0}, {3.313708, -8.0}, {8.0, -8.0}, {19.313708, -8.0}}},
    {{{7.0, 7.0}, {-10.071068, 14.071068}, {-1.0, 7.0}, {-0.071068, 4.071068},
      {3.0, 3.0}, {4.071068, -0.071068}, {7.0, -1.0}, {14.071068, -10.071068}},
     {{0.0, 0.0}, {-12.0, 24.142136}, {-12.0, 8.0}, {-16.0, 4.142136},
      {-16.0, 0.0}, {-20.0, -4.142136}, {-20.0, -8.0}, {-24.0, -24.142136}},
     {{-7.0, 7.0}, {2.071068, 22.071068}, {-3.0, 11.0}, {-3.928932, 8.071068},
      {-3.0, 3.0}, {-4.071068, -0.071068}, {-3.0, -5.0}, {-10.071068, -14.071068}},  // NOLINT
     {{0.0, -14.0}, {0.0, -12.0}, {0.0, -10.0}, {0.0, -8.0},
      {0.0, -6.0}, {0.0, -4.0}, {0.0, -2.0}, {0.0, 0.0}}}
  });
  // Data set with the "_3d_" suffix.
  fft_c64x2x4x8_3d_ = LiteralUtil::CreateR3<complex64>({
    {{{91.0, 91.0}, {-23.727922, 19.727922}, {-9.0, 7.0}, {-5.727922, 1.727922},
      {-1.0, -1.0}, {1.727922, -5.727922}, {7.0, -9}, {19.727922, -23.727922}},
     {{0.0, 0.0}, {-12.0, 24.142136}, {-12.0, 8.0}, {-16.0, 4.142136},
      {-16.0, 0.0}, {-20.0, -4.142136}, {-20.0, -8.0}, {-24.0, -24.142136}},
     {{21.0, -21.0}, {19.041630, 63.041630}, {-3.0, 35.0}, {-10.958370, 25.041630},  // NOLINT
      {-15.0, 15.0}, {-21.041630, 6.958370}, {-27.0, -5.0}, {-51.041630, -31.041630}},  // NOLINT
     {{0.0, -70.0}, {-19.313708, -20.0}, {-8.0, -18.0}, {-3.313708, -16.0},
      {0.0, -14.0}, {3.313708, -12.0}, {8.0, -10.0}, {19.313708, -8.0}}},
    {{{77.0, 77.0}, {-3.585786, -8.414214}, {-7.0, -7.0}, {-5.585786, -6.414214},  // NOLINT
      {-7.0, -7.0}, {-6.414214, -5.585786}, {-7.0, -7.0}, {-8.414214, -3.585786}},  // NOLINT
     {{0.0, 0.0}, {12.0, -24.142136}, {12.0, -8.0}, {16.0, -4.142136},
      {16.0, 0.0}, {20.0, 4.142136}, {20.0, 8.0}, {24.0, 24.142136}},
     {{35.0, -35.0}, {14.899494, 18.899494}, {3.0, 13.0}, {-3.100506, 8.899494},
      {-9.0, 9.0}, {-12.899494, 7.100506}, {-21.0, 5.0}, {-30.899494, -2.899494}},  // NOLINT
     {{0.0, -42.0}, {-19.313708, 4.0}, {-8.0, 2.0}, {-3.313708, 0.0},
      {0.0, -2.0}, {3.313708, -4.0}, {8.0, -6.0}, {19.313708, -8.0}}}
  });
  // clang-format on
}
1581
1582 // Simple FFT tests:
1583
1584 TEST_F(HloEvaluatorTest, 1D_FFT_4_on_c64x4) {
1585 const char* hlo_text = R"(
1586 HloModule Fft
1587
1588 ENTRY main {
1589 operand = c64[4] parameter(0)
1590 ROOT fft = c64[4] fft(operand), fft_type=FFT, fft_length={4}
1591 }
1592 )";
1593 auto input = LiteralUtil::CreateR1<complex64>(
1594 {{1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}});
1595 auto expected = LiteralUtil::CreateR1<complex64>(
1596 {{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}, {-2.0, -2.0}});
1597 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1598 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1599 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1600 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1601 }
1602
1603 TEST_F(HloEvaluatorTest, 1D_IFFT_4_on_c64x4) {
1604 const char* hlo_text = R"(
1605 HloModule Fft
1606
1607 ENTRY main {
1608 operand = c64[4] parameter(0)
1609 ROOT ifft = c64[4] fft(operand), fft_type=IFFT, fft_length={4}
1610 }
1611 )";
1612 auto input = LiteralUtil::CreateR1<complex64>(
1613 {{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}, {-2.0, -2.0}});
1614 auto expected = LiteralUtil::CreateR1<complex64>(
1615 {{1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}});
1616 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1617 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1618 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1619 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1620 }
1621
1622 TEST_F(HloEvaluatorTest, 1D_RFFT_4_on_f32x4) {
1623 const char* hlo_text = R"(
1624 HloModule Fft
1625
1626 ENTRY main {
1627 operand = f32[4] parameter(0)
1628 ROOT rfft = c64[3] fft(operand), fft_type=RFFT, fft_length={4}
1629 }
1630 )";
1631 auto input = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
1632 auto expected =
1633 LiteralUtil::CreateR1<complex64>({{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}});
1634 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1635 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1636 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1637 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1638 }
1639
1640 TEST_F(HloEvaluatorTest, 1D_IRFFT_4_on_c64x3) {
1641 const char* hlo_text = R"(
1642 HloModule Fft
1643
1644 ENTRY main {
1645 operand = c64[3] parameter(0)
1646 ROOT irfft = f32[4] fft(operand), fft_type=IRFFT, fft_length={4}
1647 }
1648 )";
1649 auto input =
1650 LiteralUtil::CreateR1<complex64>({{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}});
1651 auto expected = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
1652 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1653 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1654 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1655 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1656 }
1657
1658 // 1D FFT tests:
1659
1660 TEST_F(HloEvaluatorTest, 1D_FFT_8_on_c64x2x4x8) {
1661 const char* hlo_text = R"(
1662 HloModule Fft
1663
1664 ENTRY main {
1665 operand = c64[2, 4, 8] parameter(0)
1666 ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, fft_length={8}
1667 }
1668 )";
1669 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1670 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_}));
1671 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_1d_.shape()));
1672 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_1d_, result, fft_error_));
1673 }
1674
1675 TEST_F(HloEvaluatorTest, 1D_IFFT_8_on_c64x2x4x8) {
1676 const char* hlo_text = R"(
1677 HloModule Fft
1678
1679 ENTRY main {
1680 operand = c64[2, 4, 8] parameter(0)
1681 ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={8}
1682 }
1683 )";
1684 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1685 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_1d_}));
1686 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape()));
1687 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_));
1688 }
1689
1690 TEST_F(HloEvaluatorTest, 1D_RFFT_8_on_f32x8) {
1691 const char* hlo_text = R"(
1692 HloModule Fft
1693
1694 ENTRY main {
1695 operand = f32[8] parameter(0)
1696 ROOT rfft = c64[5] fft(operand), fft_type=RFFT, fft_length={8}
1697 }
1698 )";
1699 auto input =
1700 LiteralUtil::CreateR1<float>({1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1});
1701 auto expected = LiteralUtil::CreateR1<complex64>({{39.6, 0.0},
1702 {-3.6, 8.691169},
1703 {-3.6, 3.6},
1704 {-3.6, 1.491169},
1705 {-3.6, 0.0}});
1706 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1707 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1708 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1709 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1710 }
1711
1712 TEST_F(HloEvaluatorTest, 1D_IRFFT_8_on_c64x5) {
1713 const char* hlo_text = R"(
1714 HloModule Fft
1715
1716 ENTRY main {
1717 operand = c64[5] parameter(0)
1718 ROOT irfft = f32[8] fft(operand), fft_type=IRFFT, fft_length={8}
1719 }
1720 )";
1721 auto input = LiteralUtil::CreateR1<complex64>({{39.6, 0.0},
1722 {-3.6, 8.691169},
1723 {-3.6, 3.6},
1724 {-3.6, 1.491169},
1725 {-3.6, 0.0}});
1726 auto expected =
1727 LiteralUtil::CreateR1<float>({1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1});
1728 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1729 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1730 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1731 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1732 }
1733
1734 TEST_F(HloEvaluatorTest, 1D_RFFT_9_on_f32x9) {
1735 const char* hlo_text = R"(
1736 HloModule Fft
1737
1738 ENTRY main {
1739 operand = f32[9] parameter(0)
1740 ROOT rfft = c64[5] fft(operand), fft_type=RFFT, fft_length={9}
1741 }
1742 )";
1743 auto input = LiteralUtil::CreateR1<float>(
1744 {1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1, 9.9});
1745 auto expected = LiteralUtil::CreateR1<complex64>({{49.5, 0.0},
1746 {-3.360560, 11.705792},
1747 {-3.893717, 5.712929},
1748 {-4.5, 3.117691},
1749 {-4.895723, 1.021942}});
1750 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1751 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1752 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1753 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1754 }
1755
1756 TEST_F(HloEvaluatorTest, 1D_IRFFT_9_on_c64x5) {
1757 const char* hlo_text = R"(
1758 HloModule Fft
1759
1760 ENTRY main {
1761 operand = c64[5] parameter(0)
1762 ROOT irfft = f32[9] fft(operand), fft_type=IRFFT, fft_length={9}
1763 }
1764 )";
1765 auto input = LiteralUtil::CreateR1<complex64>({{49.5, 0.0},
1766 {-3.360560, 11.705792},
1767 {-3.893717, 5.712929},
1768 {-4.5, 3.117691},
1769 {-4.895723, 1.021942}});
1770 auto expected = LiteralUtil::CreateR1<float>(
1771 {1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1, 9.9});
1772 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1773 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1774 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1775 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1776 }
1777
1778 // 2D FFT tests:
1779
1780 TEST_F(HloEvaluatorTest, 2D_FFT_4x8_on_c64x2x4x8) {
1781 const char* hlo_text = R"(
1782 HloModule Fft
1783
1784 ENTRY main {
1785 operand = c64[2, 4, 8] parameter(0)
1786 ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, fft_length={4, 8}
1787 }
1788 )";
1789 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1790 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_}));
1791 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_2d_.shape()));
1792 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_2d_, result, fft_error_));
1793 }
1794
1795 TEST_F(HloEvaluatorTest, 2D_IFFT_4x8_on_c64x2x4x8) {
1796 const char* hlo_text = R"(
1797 HloModule Fft
1798
1799 ENTRY main {
1800 operand = c64[2, 4, 8] parameter(0)
1801 ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={4, 8}
1802 }
1803 )";
1804 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1805 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_2d_}));
1806 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape()));
1807 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_));
1808 }
1809
1810 TEST_F(HloEvaluatorTest, 2D_RFFT_3x8_on_f32x3x8) {
1811 const char* hlo_text = R"(
1812 HloModule Fft
1813
1814 ENTRY main {
1815 operand = f32[3, 8] parameter(0)
1816 ROOT rfft = c64[3, 5] fft(operand), fft_type=RFFT, fft_length={3, 8}
1817 }
1818 )";
1819 auto input =
1820 LiteralUtil::CreateR2<float>({{1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1},
1821 {8.1, 7.2, 6.3, 5.4, 4.5, 3.6, 2.7, 1.8},
1822 {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8}});
1823 auto expected = LiteralUtil::CreateR2<complex64>({{{118.8, 0.0},
1824 {-4.4, 10.622540},
1825 {-4.4, 4.4},
1826 {-4.4, 1.822540},
1827 {-4.4, 0.0}},
1828 {{0.0, 0.0},
1829 {-19.926162, 0.797280},
1830 {-10.128203, -3.728203},
1831 {-6.069756, -5.602720},
1832 {-3.2, -6.928203}},
1833 {{0.0, 0.0},
1834 {13.526162, 14.653687},
1835 {3.728203, 10.128203},
1836 {-0.330244, 8.253687},
1837 {-3.2, 6.928203}}});
1838 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1839 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1840 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1841 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1842 }
1843
1844 TEST_F(HloEvaluatorTest, 2D_IRFFT_3x8_on_c64x3x5) {
1845 const char* hlo_text = R"(
1846 HloModule Fft
1847
1848 ENTRY main {
1849 operand = c64[3, 5] parameter(0)
1850 ROOT irfft = f32[3, 8] fft(operand), fft_type=IRFFT, fft_length={3, 8}
1851 }
1852 )";
1853 auto input = LiteralUtil::CreateR2<complex64>({{{118.8, 0.0},
1854 {-4.4, 10.622540},
1855 {-4.4, 4.4},
1856 {-4.4, 1.822540},
1857 {-4.4, 0.0}},
1858 {{0.0, 0.0},
1859 {-19.926162, 0.797280},
1860 {-10.128203, -3.728203},
1861 {-6.069756, -5.602720},
1862 {-3.2, -6.928203}},
1863 {{0.0, 0.0},
1864 {13.526162, 14.653687},
1865 {3.728203, 10.128203},
1866 {-0.330244, 8.253687},
1867 {-3.2, 6.928203}}});
1868 auto expected =
1869 LiteralUtil::CreateR2<float>({{1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1},
1870 {8.1, 7.2, 6.3, 5.4, 4.5, 3.6, 2.7, 1.8},
1871 {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8}});
1872 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1873 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1874 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1875 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1876 }
1877
1878 TEST_F(HloEvaluatorTest, 2D_RFFT_3x9_on_f32x3x9) {
1879 const char* hlo_text = R"(
1880 HloModule Fft
1881
1882 ENTRY main {
1883 operand = f32[3, 9] parameter(0)
1884 ROOT rfft = c64[3, 5] fft(operand), fft_type=RFFT, fft_length={3, 9}
1885 }
1886 )";
1887 auto input = LiteralUtil::CreateR2<float>(
1888 {{1.9, 2.8, 3.7, 4.6, 5.5, 6.4, 7.3, 8.2, 9.1},
1889 {9.1, 8.2, 7.3, 6.4, 5.5, 4.6, 3.7, 2.8, 1.9},
1890 {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9}});
1891 auto expected = LiteralUtil::CreateR2<complex64>({{{148.5, 0.0},
1892 {-4.95, 13.600013},
1893 {-4.95, 5.899180},
1894 {-4.95, 2.857884},
1895 {-4.95, 0.872819}},
1896 {{0.0, 0.0},
1897 {-25.014467, 2.096690},
1898 {-12.888800, -3.503916},
1899 {-8.1, -5.715768},
1900 {-4.974333, -7.159452}},
1901 {{0.0, 0.0},
1902 {17.814467, 17.685147},
1903 {5.688800, 12.084542},
1904 {0.9, 9.872690},
1905 {-2.225667, 8.429006}}});
1906 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1907 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1908 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1909 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1910 }
1911
1912 TEST_F(HloEvaluatorTest, 2D_IRFFT_3x9_on_c64x3x5) {
1913 const char* hlo_text = R"(
1914 HloModule Fft
1915
1916 ENTRY main {
1917 operand = c64[3, 5] parameter(0)
1918 ROOT irfft = f32[3, 9] fft(operand), fft_type=IRFFT, fft_length={3, 9}
1919 }
1920 )";
1921 auto input = LiteralUtil::CreateR2<complex64>({{{148.5, 0.0},
1922 {-4.95, 13.600013},
1923 {-4.95, 5.899180},
1924 {-4.95, 2.857884},
1925 {-4.95, 0.872819}},
1926 {{0.0, 0.0},
1927 {-25.014467, 2.096690},
1928 {-12.888800, -3.503916},
1929 {-8.1, -5.715768},
1930 {-4.974333, -7.159452}},
1931 {{0.0, 0.0},
1932 {17.814467, 17.685147},
1933 {5.688800, 12.084542},
1934 {0.9, 9.872690},
1935 {-2.225667, 8.429006}}});
1936 auto expected = LiteralUtil::CreateR2<float>(
1937 {{1.9, 2.8, 3.7, 4.6, 5.5, 6.4, 7.3, 8.2, 9.1},
1938 {9.1, 8.2, 7.3, 6.4, 5.5, 4.6, 3.7, 2.8, 1.9},
1939 {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9}});
1940 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1941 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1942 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1943 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1944 }
1945
1946 // 3D FFT tests:
1947
1948 TEST_F(HloEvaluatorTest, 3D_FFT_2x4x8_on_c64x2x4x8) {
1949 const char* hlo_text = R"(
1950 HloModule Fft
1951
1952 ENTRY main {
1953 operand = c64[2, 4, 8] parameter(0)
1954 ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, fft_length={2, 4, 8}
1955 }
1956 )";
1957 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1958 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_}));
1959 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_3d_.shape()));
1960 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_3d_, result, fft_error_));
1961 }
1962
1963 TEST_F(HloEvaluatorTest, 3D_IFFT_2x4x8_on_c64x2x4x8) {
1964 const char* hlo_text = R"(
1965 HloModule Fft
1966
1967 ENTRY main {
1968 operand = c64[2, 4, 8] parameter(0)
1969 ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={2, 4, 8}
1970 }
1971 )";
1972 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1973 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_3d_}));
1974 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape()));
1975 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_));
1976 }
1977
1978 TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x4_on_f32x3x3x4) {
1979 const char* hlo_text = R"(
1980 HloModule Fft
1981
1982 ENTRY main {
1983 operand = f32[3, 3, 4] parameter(0)
1984 ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 4}
1985 }
1986 )";
1987 auto input = LiteralUtil::CreateR3<float>(
1988 {{{1.8, 2.7, 3.6, 4.5}, {8.1, 7.2, 6.3, 5.4}, {1.1, 2.2, 3.3, 4.4}},
1989 {{5.4, 6.3, 7.2, 8.1}, {4.5, 3.6, 2.7, 1.8}, {5.5, 6.6, 7.7, 8.8}},
1990 {{-1.8, -2.7, -3.6, -4.5},
1991 {-5.4, -6.3, -7.2, -8.1},
1992 {1.9, 2.9, 3.9, 4.9}}});
1993 auto expected = LiteralUtil::CreateR3<complex64>(
1994 {{{{92.8, 0.0}, {-2.8, 2.8}, {-2.8, 0.0}},
1995 {{-5.9, 35.160631}, {-11.519100, -8.919100}, {-1.3, -10.219100}},
1996 {{-5.9, -35.160631}, {8.919100, 11.519100}, {-1.3, 10.219100}}},
1997 {{{29.5, -81.579593}, {1.390897, 5.190897}, {-1.9, 3.290897}},
1998 {{-25.1, -49.017038}, {1.044486, 4.844486}, {-1.9, 2.944486}},
1999 {{11.8, 27.712813}, {1.517691, 4.717691}, {-1.6, 3.117691}}},
2000 {{{29.5, 81.579593}, {-5.190897, -1.390897}, {-1.9, -3.290897}},
2001 {{11.8, -27.712813}, {-4.717691, -1.517691}, {-1.6, -3.117691}},
2002 {{-25.1, 49.017038}, {-4.844486, -1.044486}, {-1.9, -2.944486}}}});
2003 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2004 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2005 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2006 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2007 }
2008
2009 TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x4_on_c64x3x3x3) {
2010 const char* hlo_text = R"(
2011 HloModule Fft
2012
2013 ENTRY main {
2014 operand = c64[3, 3, 3] parameter(0)
2015 ROOT irfft = f32[3, 3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 3, 4}
2016 }
2017 )";
2018 auto input = LiteralUtil::CreateR3<complex64>(
2019 {{{{92.8, 0.0}, {-2.8, 2.8}, {-2.8, 0.0}},
2020 {{-5.9, 35.160631}, {-11.519100, -8.919100}, {-1.3, -10.219100}},
2021 {{-5.9, -35.160631}, {8.919100, 11.519100}, {-1.3, 10.219100}}},
2022 {{{29.5, -81.579593}, {1.390897, 5.190897}, {-1.9, 3.290897}},
2023 {{-25.1, -49.017038}, {1.044486, 4.844486}, {-1.9, 2.944486}},
2024 {{11.8, 27.712813}, {1.517691, 4.717691}, {-1.6, 3.117691}}},
2025 {{{29.5, 81.579593}, {-5.190897, -1.390897}, {-1.9, -3.290897}},
2026 {{11.8, -27.712813}, {-4.717691, -1.517691}, {-1.6, -3.117691}},
2027 {{-25.1, 49.017038}, {-4.844486, -1.044486}, {-1.9, -2.944486}}}});
2028 auto expected = LiteralUtil::CreateR3<float>(
2029 {{{1.8, 2.7, 3.6, 4.5}, {8.1, 7.2, 6.3, 5.4}, {1.1, 2.2, 3.3, 4.4}},
2030 {{5.4, 6.3, 7.2, 8.1}, {4.5, 3.6, 2.7, 1.8}, {5.5, 6.6, 7.7, 8.8}},
2031 {{-1.8, -2.7, -3.6, -4.5},
2032 {-5.4, -6.3, -7.2, -8.1},
2033 {1.9, 2.9, 3.9, 4.9}}});
2034 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2035 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2036 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2037 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2038 }
2039
2040 TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x5_on_f32x3x3x5) {
2041 const char* hlo_text = R"(
2042 HloModule Fft
2043
2044 ENTRY main {
2045 operand = f32[3, 3, 5] parameter(0)
2046 ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 5}
2047 }
2048 )";
2049 auto input = LiteralUtil::CreateR3<float>({{{1.8, 2.7, 3.6, 4.5, 5.4},
2050 {8.1, 7.2, 6.3, 5.4, 4.5},
2051 {1.1, 2.2, 3.3, 4.4, 5.5}},
2052 {{5.4, 6.3, 7.2, 8.1, 9.0},
2053 {4.5, 3.6, 2.7, 1.8, 0.9},
2054 {5.5, 6.6, 7.7, 8.8, 9.9}},
2055 {{-1.8, -2.7, -3.6, -4.5, -5.4},
2056 {-5.4, -6.3, -7.2, -8.1, -9.0},
2057 {1.9, 2.9, 3.9, 4.9, 5.9}}});
2058 auto expected = LiteralUtil::CreateR3<complex64>(
2059 {{{{119.5, 0.0}, {-3.5, 4.817337}, {-3.5, 1.137219}},
2060 {{-5.75, 56.724664}, {-19.206730, -10.537254}, {-5.775483, -12.245880}},
2061 {{-5.75, -56.724664}, {15.956730, 15.010495}, {2.525483, 13.301869}}},
2062 {{{39.25, -106.088112}, {3.286913, 7.382528}, {-1.038404, 4.885305}},
2063 {{-29.0, -64.951905}, {2.690922, 6.949515}, {-1.179098, 4.452292}},
2064 {{16.75, 30.743902}, {3.363918, 6.649878}, {-0.733751, 4.546954}}},
2065 {{{39.25, 106.088112}, {-8.036913, -0.844714}, {-3.711596, -3.341936}},
2066 {{16.75, -30.743902}, {-7.363918, -1.144350}, {-3.266249, -3.247275}},
2067 {{-29.0, 64.951905}, {-7.440922, -0.411701}, {-3.570902, -2.908924}}}});
2068 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2069 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2070 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2071 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2072 }
2073
2074 TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x5_on_c64x3x3x3) {
2075 const char* hlo_text = R"(
2076 HloModule Fft
2077
2078 ENTRY main {
2079 operand = c64[3, 3, 3] parameter(0)
2080 ROOT irfft = f32[3, 3, 5] fft(operand), fft_type=IRFFT, fft_length={3, 3, 5}
2081 }
2082 )";
2083 auto input = LiteralUtil::CreateR3<complex64>(
2084 {{{{119.5, 0.0}, {-3.5, 4.817337}, {-3.5, 1.137219}},
2085 {{-5.75, 56.724664}, {-19.206730, -10.537254}, {-5.775483, -12.245880}},
2086 {{-5.75, -56.724664}, {15.956730, 15.010495}, {2.525483, 13.301869}}},
2087 {{{39.25, -106.088112}, {3.286913, 7.382528}, {-1.038404, 4.885305}},
2088 {{-29.0, -64.951905}, {2.690922, 6.949515}, {-1.179098, 4.452292}},
2089 {{16.75, 30.743902}, {3.363918, 6.649878}, {-0.733751, 4.546954}}},
2090 {{{39.25, 106.088112}, {-8.036913, -0.844714}, {-3.711596, -3.341936}},
2091 {{16.75, -30.743902}, {-7.363918, -1.144350}, {-3.266249, -3.247275}},
2092 {{-29.0, 64.951905}, {-7.440922, -0.411701}, {-3.570902, -2.908924}}}});
2093 auto expected = LiteralUtil::CreateR3<float>({{{1.8, 2.7, 3.6, 4.5, 5.4},
2094 {8.1, 7.2, 6.3, 5.4, 4.5},
2095 {1.1, 2.2, 3.3, 4.4, 5.5}},
2096 {{5.4, 6.3, 7.2, 8.1, 9.0},
2097 {4.5, 3.6, 2.7, 1.8, 0.9},
2098 {5.5, 6.6, 7.7, 8.8, 9.9}},
2099 {{-1.8, -2.7, -3.6, -4.5, -5.4},
2100 {-5.4, -6.3, -7.2, -8.1, -9.0},
2101 {1.9, 2.9, 3.9, 4.9, 5.9}}});
2102 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2103 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2104 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2105 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2106 }
2107
2108 // FFT tests with non-default data layout:
2109
2110 TEST_F(HloEvaluatorTest, 1D_FFT_8_on_c64x2x4x8_with_layout) {
2111 const char* hlo_text = R"(
2112 HloModule Fft
2113
2114 ENTRY main {
2115 operand = c64[2, 4, 8]{0, 2, 1} parameter(0)
2116 ROOT fft = c64[2, 4, 8]{1, 2, 0} fft(operand), fft_type=FFT, fft_length={8}
2117 }
2118 )";
2119 auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({0, 2, 1}));
2120 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2121 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2122 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_1d_.shape()));
2123 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_1d_, result, fft_error_));
2124 }
2125
2126 TEST_F(HloEvaluatorTest, 2D_FFT_4x8_on_c64x2x4x8_with_layout) {
2127 const char* hlo_text = R"(
2128 HloModule Fft
2129
2130 ENTRY main {
2131 operand = c64[2, 4, 8]{2, 0, 1} parameter(0)
2132 ROOT fft = c64[2, 4, 8]{1, 0, 2} fft(operand), fft_type=FFT, fft_length={4, 8}
2133 }
2134 )";
2135 auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({2, 0, 1}));
2136 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2137 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2138 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_2d_.shape()));
2139 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_2d_, result, fft_error_));
2140 }
2141
2142 TEST_F(HloEvaluatorTest, 3D_FFT_2x4x8_on_c64x2x4x8_with_layout) {
2143 const char* hlo_text = R"(
2144 HloModule Fft
2145
2146 ENTRY main {
2147 operand = c64[2, 4, 8]{1, 2, 0} parameter(0)
2148 ROOT fft =
2149 c64[2, 4, 8]{0, 2, 1} fft(operand), fft_type=FFT, fft_length={2, 4, 8}
2150 }
2151 )";
2152 auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({1, 2, 0}));
2153 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2154 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2155 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_3d_.shape()));
2156 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_3d_, result, fft_error_));
2157 }
2158
2159 // FFT tests with unusual parameters:
2160
2161 // Zero-length transform.
2162 TEST_F(HloEvaluatorTest, 1D_FFT_0_on_c64x1x1x1x1) {
2163 const char* hlo_text = R"(
2164 HloModule Fft
2165
2166 ENTRY main {
2167 operand = c64[1, 1, 1, 1] parameter(0)
2168 ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={0}
2169 }
2170 )";
2171 auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}}}});
2172 auto expected = LiteralUtil::CreateR4<complex64>({{{{{0.0, 0.0}}}}});
2173 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2174 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2175 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2176 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2177 }
2178
2179 // Zero-length axis.
2180 TEST_F(HloEvaluatorTest, 1D_FFT_1_on_c64x1x1x1x0) {
2181 const char* hlo_text = R"(
2182 HloModule Fft
2183
2184 ENTRY main {
2185 operand = c64[1, 1, 1, 0] parameter(0)
2186 ROOT fft = c64[1, 1, 1, 0] fft(operand), fft_type=FFT, fft_length={1}
2187 }
2188 )";
2189 TF_ASSERT_OK_AND_ASSIGN(
2190 auto input,
2191 LiteralUtil::CreateR4<complex64>({{{{}}}}).Reshape({1, 1, 1, 0}));
2192 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2193 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2194 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2195 EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2196 }
2197
2198 // Some/all dimensions have length 1.
2199 TEST_F(HloEvaluatorTest, 1D_FFT_1_on_c64x1x1x1x1) {
2200 const char* hlo_text = R"(
2201 HloModule Fft
2202
2203 ENTRY main {
2204 operand = c64[1, 1, 1, 1] parameter(0)
2205 ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1}
2206 }
2207 )";
2208 auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}}}});
2209 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2210 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2211 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2212 EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2213 }
2214
2215 // Zero-length transform.
2216 TEST_F(HloEvaluatorTest, 3D_FFT_1x0x1_on_c64x1x1x1x1) {
2217 const char* hlo_text = R"(
2218 HloModule Fft
2219
2220 ENTRY main {
2221 operand = c64[1, 1, 1, 1] parameter(0)
2222 ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1, 0, 1}
2223 }
2224 )";
2225 auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}}}});
2226 auto expected = LiteralUtil::CreateR4<complex64>({{{{{0.0, 0.0}}}}});
2227 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2228 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2229 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2230 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2231 }
2232
2233 // Zero-length axis.
2234 TEST_F(HloEvaluatorTest, 3D_FFT_1x1x1_on_c64x0x1x0x1) {
2235 const char* hlo_text = R"(
2236 HloModule Fft
2237
2238 ENTRY main {
2239 operand = c64[0, 1, 0, 1] parameter(0)
2240 ROOT fft = c64[0, 1, 0, 1] fft(operand), fft_type=FFT, fft_length={1, 1, 1}
2241 }
2242 )";
2243 TF_ASSERT_OK_AND_ASSIGN(
2244 auto input,
2245 LiteralUtil::CreateR4<complex64>({{{{}}}}).Reshape({0, 1, 0, 1}));
2246 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2247 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2248 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2249 EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2250 }
2251
2252 // Some/all dimensions have length 1.
2253 TEST_F(HloEvaluatorTest, 3D_FFT_1x1x1_on_c64x1x1x1x1) {
2254 const char* hlo_text = R"(
2255 HloModule Fft
2256
2257 ENTRY main {
2258 operand = c64[1, 1, 1, 1] parameter(0)
2259 ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1, 1, 1}
2260 }
2261 )";
2262 auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}}}});
2263 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2264 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2265 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2266 EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2267 }
2268
2269 // Some/all dimensions have length 1.
2270 TEST_F(HloEvaluatorTest, 3D_FFT_3x1x1_on_c64x1x3x1x1) {
2271 const char* hlo_text = R"(
2272 HloModule Fft
2273
2274 ENTRY main {
2275 operand = c64[1, 3, 1, 1] parameter(0)
2276 ROOT fft = c64[1, 3, 1, 1] fft(operand), fft_type=FFT, fft_length={3, 1, 1}
2277 }
2278 )";
2279 auto input = LiteralUtil::CreateR4<complex64>(
2280 {{{{{42.24, 24.42}}}, {{{-42.24, 24.42}}}, {{{42.24, -24.42}}}}});
2281 auto expected =
2282 LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}},
2283 {{{84.5367, 97.5818}}},
2284 {{{-0.0566792, -48.7418}}}}});
2285 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2286 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2287 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2288 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2289 }
2290
2291 // Some/all dimensions have length 1.
2292 TEST_F(HloEvaluatorTest, 3D_IFFT_3x1x1_on_c64x1x3x1x1) {
2293 const char* hlo_text = R"(
2294 HloModule Fft
2295
2296 ENTRY main {
2297 operand = c64[1, 3, 1, 1] parameter(0)
2298 ROOT ifft = c64[1, 3, 1, 1] fft(operand), fft_type=IFFT, fft_length={3, 1, 1}
2299 }
2300 )";
2301 auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}},
2302 {{{84.5367, 97.5818}}},
2303 {{{-0.0566792, -48.7418}}}}});
2304 auto expected = LiteralUtil::CreateR4<complex64>(
2305 {{{{{42.24, 24.42}}}, {{{-42.24, 24.42}}}, {{{42.24, -24.42}}}}});
2306 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2307 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2308 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2309 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2310 }
2311
2312 // Odd transform length.
2313 TEST_F(HloEvaluatorTest, 1D_FFT_5_on_c64x5) {
2314 const char* hlo_text = R"(
2315 HloModule Fft
2316
2317 ENTRY main {
2318 operand = c64[5] parameter(0)
2319 ROOT fft = c64[5] fft(operand), fft_type=FFT, fft_length={5}
2320 }
2321 )";
2322 auto input = LiteralUtil::CreateR1<complex64>(
2323 {{1.0, 5.0}, {2.0, 4.0}, {3.0, 3.0}, {4.0, 2.0}, {5.0, 1.0}});
2324 auto expected = LiteralUtil::CreateR1<complex64>({{15.0, 15.0},
2325 {0.940955, 5.94095},
2326 {-1.6877, 3.3123},
2327 {-3.3123, 1.6877},
2328 {-5.94095, -0.940955}});
2329 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2330 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2331 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2332 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2333 }
2334
2335 // Odd transform length.
2336 TEST_F(HloEvaluatorTest, 1D_IFFT_5_on_c64x5) {
2337 const char* hlo_text = R"(
2338 HloModule Fft
2339
2340 ENTRY main {
2341 operand = c64[5] parameter(0)
2342 ROOT ifft = c64[5] fft(operand), fft_type=IFFT, fft_length={5}
2343 }
2344 )";
2345 auto input = LiteralUtil::CreateR1<complex64>({{15.0, 15.0},
2346 {0.940955, 5.94095},
2347 {-1.6877, 3.3123},
2348 {-3.3123, 1.6877},
2349 {-5.94095, -0.940955}});
2350 auto expected = LiteralUtil::CreateR1<complex64>(
2351 {{1.0, 5.0}, {2.0, 4.0}, {3.0, 3.0}, {4.0, 2.0}, {5.0, 1.0}});
2352 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2353 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2354 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2355 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2356 }
2357
2358 // All input values are zero.
2359 TEST_F(HloEvaluatorTest, 1D_FFT_4_on_zero_c64x4) {
2360 const char* hlo_text = R"(
2361 HloModule Fft
2362
2363 ENTRY main {
2364 operand = c64[4] parameter(0)
2365 ROOT fft = c64[4] fft(operand), fft_type=FFT, fft_length={4}
2366 }
2367 )";
2368 auto input = LiteralUtil::CreateR1<complex64>(
2369 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}});
2370 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2371 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2372 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2373 EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2374 }
2375
2376 // All input values are zero.
2377 TEST_F(HloEvaluatorTest, 3D_FFT_3x3x4_on_zero_c64x3x3x4) {
2378 const char* hlo_text = R"(
2379 HloModule Fft
2380
2381 ENTRY main {
2382 operand = c64[3, 3, 4] parameter(0)
2383 ROOT fft = c64[3, 3, 4] fft(operand), fft_type=FFT, fft_length={3, 3, 4}
2384 }
2385 )";
2386 auto input = LiteralUtil::CreateR3<complex64>(
2387 {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2388 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2389 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2390 {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2391 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2392 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2393 {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2394 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2395 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}});
2396 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2397 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2398 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2399 EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2400 }
2401
2402 // All input values are zero.
2403 TEST_F(HloEvaluatorTest, 3D_IFFT_3x3x4_on_zero_c64x3x3x4) {
2404 const char* hlo_text = R"(
2405 HloModule Fft
2406
2407 ENTRY main {
2408 operand = c64[3, 3, 4] parameter(0)
2409 ROOT ifft = c64[3, 3, 4] fft(operand), fft_type=IFFT, fft_length={3, 3, 4}
2410 }
2411 )";
2412 auto input = LiteralUtil::CreateR3<complex64>(
2413 {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2414 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2415 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2416 {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2417 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2418 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2419 {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2420 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2421 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}});
2422 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2423 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2424 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2425 EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2426 }
2427
2428 // All input values are zero.
2429 TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x4_on_zero_f32x3x3x4) {
2430 const char* hlo_text = R"(
2431 HloModule Fft
2432
2433 ENTRY main {
2434 operand = f32[3, 3, 4] parameter(0)
2435 ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 4}
2436 }
2437 )";
2438 auto input = LiteralUtil::CreateR3<float>(
2439 {{{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}},
2440 {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}},
2441 {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}});
2442 auto expected = LiteralUtil::CreateR3<complex64>(
2443 {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2444 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2445 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2446 {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2447 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2448 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2449 {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2450 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2451 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}});
2452 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2453 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2454 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2455 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2456 }
2457
2458 // All input values are zero.
2459 TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x4_on_zero_c64x3x3x3) {
2460 const char* hlo_text = R"(
2461 HloModule Fft
2462
2463 ENTRY main {
2464 operand = c64[3, 3, 3] parameter(0)
2465 ROOT irfft = f32[3, 3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 3, 4}
2466 }
2467 )";
2468 auto input = LiteralUtil::CreateR3<complex64>(
2469 {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2470 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2471 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2472 {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2473 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2474 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2475 {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2476 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2477 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}});
2478 auto expected = LiteralUtil::CreateR3<float>(
2479 {{{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}},
2480 {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}},
2481 {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}});
2482 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2483 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2484 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2485 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2486 }
2487
2488 // Input values, for which IRFFT discards non-zero imaginary parts.
2489 TEST_F(HloEvaluatorTest, 2D_IRFFT_3x4_on_c64x3x3) {
2490 const char* hlo_text = R"(
2491 HloModule Fft
2492
2493 ENTRY main {
2494 operand = c64[3, 3] parameter(0)
2495 ROOT irfft = f32[3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 4}
2496 }
2497 )";
2498 auto input =
2499 LiteralUtil::CreateR2<complex64>({{{0.0, 0.0}, {1.0, 0.0}, {2.0, 0.0}},
2500 {{3.0, 0.0}, {4.0, 0.0}, {5.0, 0.0}},
2501 {{6.0, 0.0}, {7.0, 0.0}, {8.0, 0.0}}});
2502 auto expected =
2503 LiteralUtil::CreateR2<float>({{4.0, -0.5, 0.0, -0.5},
2504 {-1.5, 0.433013, 0.0, -0.433013},
2505 {-1.5, -0.433013, 0.0, 0.433013}});
2506 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2507 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2508 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2509 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2510 }
2511
// Test fixture for the reduce-precision test; it adds no members of its own on
// top of HloTestBase.
class HloEvaluatorPreciseReduceTest : public HloTestBase {};
2513
2514 // Tests that Reduce doesn't lose precision when adding many numbers (because
2515 // it accumulates its result in a double).
TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // Operand: 2^25 ones, to be summed starting from an initial value of zero.
  constexpr int kNumElements = 1 << 25;  // float += 1 saturates at 1<<24
  std::vector<float> v(kNumElements, 1.0f);
  HloInstruction* arg_instruction = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
  HloInstruction* init_value = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));

  // Scalar F32 reducer computing lhs + rhs.
  HloComputation::Builder add_computation("add");
  Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  auto param_lhs = add_computation.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
  auto param_rhs = add_computation.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
  add_computation.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
  auto add_func = m->AddEmbeddedComputation(add_computation.Build());

  HloInstruction* reduce_instruction = b.AddInstruction(
      HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value,
                                   /*dimensions_to_reduce=*/{0}, add_func));
  m->AddEntryComputation(b.Build());

  HloEvaluator hlo_eval;
  // Report an evaluation error as an ordinary test failure (with the status
  // message) instead of aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(Literal result,
                          hlo_eval.Evaluate(reduce_instruction));
  LiteralTestUtil::ExpectR0Equal<float>(kNumElements, result);
}
2546
2547 // Reducing many numbers should be fast because it doesn't create
2548 // intermediate Literals; the microbenchmark should finish in < 1 msec.
BM_ReducePrecisely(::testing::benchmark::State & state)2549 void BM_ReducePrecisely(::testing::benchmark::State& state) {
2550 HloComputation::Builder b("BM_ReducePrecisely");
2551 HloModuleConfig config;
2552 config.set_debug_options(GetDebugOptionsFromFlags());
2553 HloModule module("BM_ReducePrecisely", config);
2554
2555 constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24
2556 std::vector<float> v(kNumElements, 1.0f);
2557 HloInstruction* arg_instruction = b.AddInstruction(
2558 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
2559 auto init_value = b.AddInstruction(
2560 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
2561
2562 HloComputation::Builder add_computation("add");
2563 Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
2564 auto param_lhs = add_computation.AddInstruction(
2565 HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
2566 auto param_rhs = add_computation.AddInstruction(
2567 HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
2568 add_computation.AddInstruction(HloInstruction::CreateBinary(
2569 scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
2570 auto add_func = module.AddEmbeddedComputation(add_computation.Build());
2571
2572 HloInstruction* reduce_instruction = b.AddInstruction(
2573 HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value,
2574 /*dimensions_to_reduce=*/{0}, add_func));
2575 module.AddEntryComputation(b.Build());
2576
2577 // Benchmark loop
2578 for (auto s : state) {
2579 HloEvaluator hlo_eval;
2580 hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
2581 }
2582 }
2583
2584 BENCHMARK(BM_ReducePrecisely);
2585
TEST_P(HloEvaluatorBf16Test, ReduceAdd) {
  HloComputation::Builder b(TestName());

  // Operand, f32[2,3]:
  //   { 1, 2, 3 },
  //   { 5, 6, 7 }
  Array2D<float> operand_values(2, 3);
  operand_values.FillUnique(1.0f);
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(operand_values)));
  HloInstruction* zero = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));

  // Scalar F32 reducer computing lhs + rhs.
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder reducer("add");
  HloInstruction* lhs = reducer.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "lhs"));
  HloInstruction* rhs = reducer.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "rhs"));
  reducer.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kAdd, lhs, rhs));
  HloComputation* add_func = m_->AddEmbeddedComputation(reducer.Build());

  // Reduce along dimension 1, leaving one sum per row.
  const Shape row_sums_shape = ShapeUtil::MakeShape(F32, {2});
  b.AddInstruction(
      HloInstruction::CreateReduce(row_sums_shape, operand, zero,
                                   /*dimensions_to_reduce=*/{1}, add_func));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());

  Literal want = LiteralUtil::CreateR1<float>({6, 18});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2627
TEST_P(HloEvaluatorBf16Test, ReduceWindowMax) {
  HloComputation::Builder b(TestName());

  // Operand, f32[2,3]:
  //   { 1, 2, 3 },
  //   { 5, 6, 7 }
  Array2D<float> operand_values(2, 3);
  operand_values.FillUnique(1.0f);
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(operand_values)));
  HloInstruction* init = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
  HloComputation* max_func =
      m_->AddEmbeddedComputation(MaxComputationScalarF32());

  // Both dimensions use the same window: size 2, stride 1, no padding, no
  // dilation.
  Window window;
  for (int i = 0; i < 2; ++i) {
    WindowDimension* d = window.add_dimensions();
    d->set_size(2);
    d->set_stride(1);
    d->set_padding_low(0);
    d->set_padding_high(0);
    d->set_window_dilation(1);
    d->set_base_dilation(1);
  }

  const Shape out_shape = ShapeUtil::MakeShape(F32, {1, 2});
  b.AddInstruction(HloInstruction::CreateReduceWindow(out_shape, operand, init,
                                                      window, max_func));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());

  Literal want = LiteralUtil::CreateR2<float>({{6, 7}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2669
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaWindowDilation) {
  // Window dilation of 2; stride and base dilation are 1, with no padding.
  constexpr int kWindowSize = 2;
  constexpr int kPadding = 0;
  constexpr int kStride = 1;
  constexpr int kWindowDilation = 2;
  constexpr int kBaseDilation = 1;

  Literal want = LiteralUtil::CreateR2<float>({{10, 11}, {14, 15}});
  ReduceWindowMaxIotaTest(kWindowSize, kPadding, kStride, kWindowDilation,
                          kBaseDilation, want);
}
2680
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaStrideWindowDilation) {
  // Stride 2 combined with window dilation 2; no padding or base dilation.
  constexpr int kWindowSize = 2;
  constexpr int kPadding = 0;
  constexpr int kStride = 2;
  constexpr int kWindowDilation = 2;
  constexpr int kBaseDilation = 1;

  Literal want = LiteralUtil::CreateR2<float>({{10}});
  ReduceWindowMaxIotaTest(kWindowSize, kPadding, kStride, kWindowDilation,
                          kBaseDilation, want);
}
2691
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaBaseDilation) {
  // Base dilation of 2; stride and window dilation are 1, with no padding.
  constexpr int kWindowSize = 2;
  constexpr int kPadding = 0;
  constexpr int kStride = 1;
  constexpr int kWindowDilation = 1;
  constexpr int kBaseDilation = 2;

  Literal want = LiteralUtil::CreateR2<float>({{0, 1, 1, 2, 2, 3},
                                               {4, 5, 5, 6, 6, 7},
                                               {4, 5, 5, 6, 6, 7},
                                               {8, 9, 9, 10, 10, 11},
                                               {8, 9, 9, 10, 10, 11},
                                               {12, 13, 13, 14, 14, 15}});
  ReduceWindowMaxIotaTest(kWindowSize, kPadding, kStride, kWindowDilation,
                          kBaseDilation, want);
}
2707
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaStrideBaseDilation) {
  // Stride 2 combined with base dilation 2; no padding or window dilation.
  constexpr int kWindowSize = 2;
  constexpr int kPadding = 0;
  constexpr int kStride = 2;
  constexpr int kWindowDilation = 1;
  constexpr int kBaseDilation = 2;

  Literal want =
      LiteralUtil::CreateR2<float>({{0, 1, 2}, {4, 5, 6}, {8, 9, 10}});
  ReduceWindowMaxIotaTest(kWindowSize, kPadding, kStride, kWindowDilation,
                          kBaseDilation, want);
}
2719
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaStrideBothDilation) {
  // Stride, window dilation and base dilation are all 2; no padding.
  constexpr int kWindowSize = 2;
  constexpr int kPadding = 0;
  constexpr int kStride = 2;
  constexpr int kWindowDilation = 2;
  constexpr int kBaseDilation = 2;

  Literal want =
      LiteralUtil::CreateR2<float>({{5, 6, 7}, {9, 10, 11}, {13, 14, 15}});
  ReduceWindowMaxIotaTest(kWindowSize, kPadding, kStride, kWindowDilation,
                          kBaseDilation, want);
}
2731
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaPaddingStrideBaseDilation) {
  constexpr int kWindowSize = 3;
  constexpr int kPadding = 1;
  constexpr int kStride = 3;
  constexpr int kWindowDilation = 1;
  constexpr int kBaseDilation = 2;

  // The base is dilated first, and then padding is applied, hence this result.
  Literal want =
      LiteralUtil::CreateR2<float>({{0, 2, 3}, {8, 10, 11}, {12, 14, 15}});
  ReduceWindowMaxIotaTest(kWindowSize, kPadding, kStride, kWindowDilation,
                          kBaseDilation, want);
}
2744
TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd) {
  HloComputation::Builder b(TestName());

  // Operand, f32[2,3]:
  //   { 1, 2, 3 },
  //   { 5, 6, 7 }
  Array2D<float> operand_values(2, 3);
  operand_values.FillUnique(1.0f);
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(operand_values)));
  HloInstruction* zero = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));

  // Scalar F32 reducer computing lhs + rhs.
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder reducer("add");
  HloInstruction* lhs = reducer.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "lhs"));
  HloInstruction* rhs = reducer.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "rhs"));
  reducer.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kAdd, lhs, rhs));
  HloComputation* add_func = m_->AddEmbeddedComputation(reducer.Build());

  // Appends one window dimension with stride 1, no high padding and no
  // dilation; only the size and the low padding vary between dimensions.
  Window window;
  auto append_dimension = [&window](int64 size, int64 padding_low) {
    WindowDimension* d = window.add_dimensions();
    d->set_size(size);
    d->set_stride(1);
    d->set_padding_low(padding_low);
    d->set_padding_high(0);
    d->set_window_dilation(1);
    d->set_base_dilation(1);
  };
  append_dimension(/*size=*/1, /*padding_low=*/0);
  append_dimension(/*size=*/2, /*padding_low=*/1);

  const Shape out_shape = ShapeUtil::MakeShape(F32, {2, 3});
  b.AddInstruction(HloInstruction::CreateReduceWindow(out_shape, operand, zero,
                                                      window, add_func));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());

  Literal want = LiteralUtil::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2801
TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd6D) {
  HloComputation::Builder b(TestName());

  // Operand: f32[4,4,4,4,4,4] full of ones. Using small dims to limit
  // run-time.
  const std::vector<int64> operand_dims = {4, 4, 4, 4, 4, 4};
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateFullWithDescendingLayout<float>(operand_dims, 1.0f)));
  HloInstruction* zero = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));

  // Scalar F32 reducer computing lhs + rhs.
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder reducer("add");
  HloInstruction* lhs = reducer.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "lhs"));
  HloInstruction* rhs = reducer.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "rhs"));
  reducer.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kAdd, lhs, rhs));
  HloComputation* add_func = m_->AddEmbeddedComputation(reducer.Build());

  // Every dimension has stride 1, no padding and no dilation. Dimensions 1-3
  // use a window of size 2; the remaining dimensions use size 1.
  Window window;
  for (const int64 window_size :
       std::initializer_list<int64>{1, 2, 2, 2, 1, 1}) {
    WindowDimension* d = window.add_dimensions();
    d->set_size(window_size);
    d->set_stride(1);
    d->set_padding_low(0);
    d->set_padding_high(0);
    d->set_window_dilation(1);
    d->set_base_dilation(1);
  }

  const Shape out_shape = ShapeUtil::MakeShape(F32, {4, 3, 3, 3, 4, 4});
  b.AddInstruction(HloInstruction::CreateReduceWindow(out_shape, operand, zero,
                                                      window, add_func));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());

  // Each output element sums a 2x2x2 window of ones.
  const std::vector<int64> want_dims = {4, 3, 3, 3, 4, 4};
  Literal want =
      LiteralUtil::CreateFullWithDescendingLayout<float>(want_dims, 8.0f);
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2864
TEST_P(HloEvaluatorBf16Test, Min3In5Stride2Tuple) {
  // Variadic reduce-window over two identical F32 operands: an elementwise
  // minimum with window size 3 and stride 2 over five elements.
  HloComputation::Builder builder("main");
  HloInstruction* input1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));
  HloInstruction* input2 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));

  // Reducer (x0, x1, y0, y1) -> (min(x0, y0), min(x1, y1)).
  HloComputation::Builder bcompute("ComputeFunction");
  const Shape scalar1 = ShapeUtil::MakeShape(F32, {});
  const Shape scalar2 = ShapeUtil::MakeShape(F32, {});
  HloInstruction* x0 = bcompute.AddInstruction(
      HloInstruction::CreateParameter(0, scalar1, "x0"));
  HloInstruction* x1 = bcompute.AddInstruction(
      HloInstruction::CreateParameter(1, scalar2, "x1"));
  HloInstruction* y0 = bcompute.AddInstruction(
      HloInstruction::CreateParameter(2, scalar1, "y0"));
  HloInstruction* y1 = bcompute.AddInstruction(
      HloInstruction::CreateParameter(3, scalar2, "y1"));
  HloInstruction* min0 = bcompute.AddInstruction(
      HloInstruction::CreateBinary(scalar1, HloOpcode::kMinimum, x0, y0));
  HloInstruction* min1 = bcompute.AddInstruction(
      HloInstruction::CreateBinary(scalar2, HloOpcode::kMinimum, x1, y1));
  bcompute.AddInstruction(HloInstruction::CreateTuple({min0, min1}));
  HloComputation* compute_tuple =
      m_->AddEmbeddedComputation(bcompute.Build());

  HloInstruction* init1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
  HloInstruction* init2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
  std::vector<HloInstruction*> operands = {input1, input2};
  std::vector<HloInstruction*> init_values = {init1, init2};

  std::pair<int64, int64> no_padding(0, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto window,
                          ShapeInference::InferWindowFromDimensions(
                              {3}, {2}, absl::MakeSpan(&no_padding, 1),
                              /*lhs_dilation=*/{},
                              /*rhs_dilation=*/{}));

  std::vector<const Shape*> operand_shapes = {&input1->shape(),
                                              &input2->shape()};
  std::vector<const Shape*> init_value_shapes = {&init1->shape(),
                                                 &init2->shape()};
  TF_ASSERT_OK_AND_ASSIGN(Shape out_shape,
                          ShapeInference::InferReduceWindowShape(
                              operand_shapes, init_value_shapes, window,
                              compute_tuple->ComputeProgramShape()));
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      out_shape, operands, init_values, window, compute_tuple));
  m_->AddEntryComputation(builder.Build());

  Literal want_element = LiteralUtil::CreateR1<float>({100, 1});
  Literal want = LiteralUtil::MakeTuple({&want_element, &want_element});

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2915
TEST_P(HloEvaluatorBf16Test, Min3In5Stride2TupleDiffInput) {
  // Variadic reduce-window over one F32 and one S32 operand: an elementwise
  // minimum with window size 3 and stride 2 over five elements.
  HloComputation::Builder builder("main");
  HloInstruction* input1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));
  HloInstruction* input2 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<int>({15, 28, 300, 107, 12})));

  // Reducer (x0, x1, y0, y1) -> (min(x0, y0), min(x1, y1)).
  HloComputation::Builder bcompute("ComputeFunction");
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  const Shape s32_scalar = ShapeUtil::MakeShape(S32, {});
  HloInstruction* x0 = bcompute.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "x0"));
  HloInstruction* x1 = bcompute.AddInstruction(
      HloInstruction::CreateParameter(1, s32_scalar, "x1"));
  HloInstruction* y0 = bcompute.AddInstruction(
      HloInstruction::CreateParameter(2, f32_scalar, "y0"));
  HloInstruction* y1 = bcompute.AddInstruction(
      HloInstruction::CreateParameter(3, s32_scalar, "y1"));
  HloInstruction* min0 = bcompute.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kMinimum, x0, y0));
  HloInstruction* min1 = bcompute.AddInstruction(
      HloInstruction::CreateBinary(s32_scalar, HloOpcode::kMinimum, x1, y1));
  bcompute.AddInstruction(HloInstruction::CreateTuple({min0, min1}));
  HloComputation* compute_tuple =
      m_->AddEmbeddedComputation(bcompute.Build());

  HloInstruction* init1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
  HloInstruction* init2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(S32)));
  std::vector<HloInstruction*> operands = {input1, input2};
  std::vector<HloInstruction*> init_values = {init1, init2};

  std::pair<int64, int64> no_padding(0, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto window,
                          ShapeInference::InferWindowFromDimensions(
                              {3}, {2}, absl::MakeSpan(&no_padding, 1),
                              /*lhs_dilation=*/{},
                              /*rhs_dilation=*/{}));

  std::vector<const Shape*> operand_shapes = {&input1->shape(),
                                              &input2->shape()};
  std::vector<const Shape*> init_value_shapes = {&init1->shape(),
                                                 &init2->shape()};
  TF_ASSERT_OK_AND_ASSIGN(Shape out_shape,
                          ShapeInference::InferReduceWindowShape(
                              operand_shapes, init_value_shapes, window,
                              compute_tuple->ComputeProgramShape()));
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      out_shape, operands, init_values, window, compute_tuple));
  m_->AddEntryComputation(builder.Build());

  Literal want_f32 = LiteralUtil::CreateR1<float>({100, 1});
  Literal want_s32 = LiteralUtil::CreateR1<int>({15, 12});
  Literal want = LiteralUtil::MakeTuple({&want_f32, &want_s32});

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2967
TEST_P(HloEvaluatorBf16Test, StridedSlice) {
  HloComputation::Builder b(TestName());

  // Operand, f32[3,5]:
  //   {  1,  2,  3,  4,  5 },
  //   {  9, 10, 11, 12, 13 },
  //   { 17, 18, 19, 20, 21 }
  Array2D<float> operand_values(3, 5);
  operand_values.FillUnique(1.0f);
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(operand_values)));

  // Rows [0, 3) with stride 2 and columns [2, 5) with stride 3.
  const Shape out_shape = ShapeUtil::MakeShape(F32, {2, 1});
  b.AddInstruction(HloInstruction::CreateSlice(out_shape, operand,
                                               /*start_indices=*/{0, 2},
                                               /*limit_indices=*/{3, 5},
                                               /*strides=*/{2, 3}));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());

  Literal want = LiteralUtil::CreateR2<float>({{3}, {19}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3001
TEST_P(HloEvaluatorBf16Test, DynamicSlice) {
  HloComputation::Builder b(TestName());

  // Operand, f32[2,4]:
  //   { 1, 2, 3, 4 },
  //   { 5, 6, 7, 8 }
  Array2D<float> operand_values(2, 4);
  operand_values.FillUnique(1.0f);
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(operand_values)));

  // Start indices (row 0, column 1), supplied as S32 scalar constants.
  HloInstruction* row_start = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
  HloInstruction* col_start = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));

  const Shape out_shape = ShapeUtil::MakeShape(F32, {2, 3});
  b.AddInstruction(HloInstruction::CreateDynamicSlice(
      out_shape, operand, {row_start, col_start}, {2, 3}));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());

  Literal want = LiteralUtil::CreateR2<float>({{2, 3, 4}, {6, 7, 8}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3037
3038 // Verifies that the HloEvaluator's implementation goes along with existing
3039 // backends' behavior, although this is not required by the spec.
TEST_P(HloEvaluatorBf16Test, DynamicSliceModSlice) {
  HloComputation::Builder b(TestName());

  // Operand, f32[2,4]:
  //   { 1, 2, 3, 4 },
  //   { 5, 6, 7, 8 }
  Array2D<float> operand_values(2, 4);
  operand_values.FillUnique(1.0f);
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(operand_values)));

  // The row start index (2) lies outside the two-row operand; the expected
  // result below is the slice that starts at row 0.
  HloInstruction* row_start = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* col_start = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));

  const Shape out_shape = ShapeUtil::MakeShape(F32, {2, 3});
  b.AddInstruction(HloInstruction::CreateDynamicSlice(
      out_shape, operand, {row_start, col_start}, {2, 3}));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());

  Literal want = LiteralUtil::CreateR2<float>({{2, 3, 4}, {6, 7, 8}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3075
TEST_P(HloEvaluatorBf16Test, DynamicSliceUpdate) {
  HloComputation::Builder b(TestName());

  // Operand, f64[2,3]:
  //   { 1, 2, 3 },
  //   { 5, 6, 7 }
  Array2D<double> operand_values(2, 3);
  operand_values.FillUnique(1.0);
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<double>(operand_values)));

  // The update is written starting at (row 0, column 1).
  HloInstruction* row_start = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
  HloInstruction* col_start = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));

  HloInstruction* update = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<double>({{-2.0, -3.0}, {-6.0, -7.0}})));

  const Shape out_shape = ShapeUtil::MakeShape(F64, {2, 3});
  b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      out_shape, operand, update, {row_start, col_start}));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());

  Literal want = LiteralUtil::CreateR2<double>({{1, -2, -3}, {5, -6, -7}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3114
TEST_P(HloEvaluatorBf16Test, SetAndGetTuples) {
  HloComputation::Builder b(TestName());

  // Second tuple element, f64[2,3]:
  //   { 1, 2, 3 },
  //   { 5, 6, 7 }
  Array2D<double> matrix_values(2, 3);
  matrix_values.FillUnique(1.0);
  HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<double>(matrix_values)));
  // First tuple element, s64[2].
  HloInstruction* operand1 = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 1})));

  HloInstruction* tuple =
      b.AddInstruction(HloInstruction::CreateTuple({operand1, operand2}));

  // Read back element 1, i.e. the f64 matrix.
  const Shape element_shape = ShapeUtil::MakeShape(F64, {2, 3});
  b.AddInstruction(
      HloInstruction::CreateGetTupleElement(element_shape, tuple, 1));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());

  Literal want = LiteralUtil::CreateR2<double>({{1, 2, 3}, {5, 6, 7}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3150
// Builds a tuple of two tuples and checks that get-tuple-element(1) on the
// outer tuple evaluates to the inner (matrix, matrix) tuple.
TEST_P(HloEvaluatorBf16Test, SetAndGetNestedTuples) {
  HloComputation::Builder builder(TestName());

  // Source values for the f64[2,3] matrix constant; they are reused below to
  // build the expected tuple.
  Array2D<double> matrix_values(2, 3);
  matrix_values.FillUnique(1.0);

  HloInstruction* matrix_element =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<double>(matrix_values)));
  HloInstruction* vector_element = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 1})));

  HloInstruction* mixed_pair = builder.AddInstruction(
      HloInstruction::CreateTuple({vector_element, matrix_element}));
  HloInstruction* matrix_pair = builder.AddInstruction(
      HloInstruction::CreateTuple({matrix_element, matrix_element}));

  HloInstruction* nested = builder.AddInstruction(
      HloInstruction::CreateTuple({mixed_pair, matrix_pair}));

  // The root instruction selects the second inner tuple.
  builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(matrix_pair->shape(), nested, 1));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  // Expected value: a tuple holding the same matrix literal twice.
  Literal matrix_literal =
      LiteralUtil::CreateR2FromArray2D<double>(matrix_values);
  Literal expected =
      LiteralUtil::MakeTuple({&matrix_literal, &matrix_literal});

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3189
// Evaluates a reverse over dimensions {0, 1} of an f32[4,3,2,1] constant and
// compares against a hand-written reversed array.
TEST_P(HloEvaluatorBf16Test, Reverse) {
  HloComputation::Builder builder(TestName());

  // Input shape is float[4x3x2x1].
  // clang-format off
  Array4D<float> source({
    {{{1.0f}, {2.0f}},
     {{3.0f}, {4.0f}},
     {{5.0f}, {6.0f}}},
    {{{7.0f}, {8.0f}},
     {{9.0f}, {10.0f}},
     {{11.0f}, {12.0f}}},
    {{{13.0f}, {14.0f}},
     {{15.0f}, {16.0f}},
     {{17.0f}, {18.0f}}},
    {{{19.0f}, {20.0f}},
     {{21.0f}, {22.0f}},
     {{23.0f}, {24.0f}}},
  });
  // clang-format on
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(source)));

  const Shape result_shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1});
  builder.AddInstruction(
      HloInstruction::CreateReverse(result_shape, operand, {0, 1}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  // Same data with the two leading dimensions flipped; the innermost two
  // dimensions keep their order.
  // clang-format off
  const Array4D<float> flipped({
    {{{23.0f}, {24.0f}},
     {{21.0f}, {22.0f}},
     {{19.0f}, {20.0f}}},

    {{{17.0f}, {18.0f}},
     {{15.0f}, {16.0f}},
     {{13.0f}, {14.0f}}},

    {{{11.0f}, {12.0f}},
     {{9.0f}, {10.0f}},
     {{7.0f}, {8.0f}}},

    {{{5.0f}, {6.0f}},
     {{3.0f}, {4.0f}},
     {{1.0f}, {2.0f}}},
  });
  // clang-format on
  const Literal expected = LiteralUtil::CreateR4FromArray4D<float>(flipped);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3242
// Checks that EvaluateWithSubstitutions evaluates `add` using the literals
// supplied for its operands instead of evaluating those operands itself.
TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutions) {
  HloComputation::Builder b(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {4});

  // Graph: add(param0, multiply(param0, param0)).
  HloInstruction* param0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param0"));
  HloInstruction* square = b.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kMultiply, param0, param0));
  HloInstruction* add = b.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, square));

  // Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}.
  // The substituted `square` value is deliberately not param0 * param0, so the
  // expected sum below can only come from the substitution being honoured.
  HloEvaluator evaluator;
  Literal param0_literal = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
  Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result,
      evaluator.EvaluateWithSubstitutions(
          add, {{param0, &param0_literal}, {square, &square_literal}}));
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result));
}
3265
3266 // Check that EvaluateWithSubstitutions works if one of the operands to the op
3267 // we're evaluating is a constant.
TEST_P(HloEvaluatorBf16Test,EvaluateWithSubstitutionsWithConstantOperand)3268 TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutionsWithConstantOperand) {
3269 HloComputation::Builder b(TestName());
3270 Shape shape = ShapeUtil::MakeShape(F32, {4});
3271
3272 HloInstruction* param0 =
3273 b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param0"));
3274 HloInstruction* square = b.AddInstruction(HloInstruction::CreateBinary(
3275 shape, HloOpcode::kMultiply, param0, param0));
3276 HloInstruction* constant = b.AddInstruction(HloInstruction::CreateConstant(
3277 LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
3278 HloInstruction* add = b.AddInstruction(
3279 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, constant, square));
3280
3281 // Evaluate add with square = {10, 20, 30, 40}.
3282 HloEvaluator evaluator;
3283 Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
3284 TF_ASSERT_OK_AND_ASSIGN(
3285 Literal result,
3286 evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}}));
3287 EXPECT_TRUE(LiteralTestUtil::Equal(
3288 LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result));
3289 }
3290
// Gather in the TensorFlow GatherV1 configuration: indices {0, 2} are expected
// to pick rows 0 and 2 of the 3x3 operand.
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
  const char* module_text = R"(
HloModule TensorFlowGatherV1

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[2,3] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1, 3}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({0, 2});
  const Literal expected =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &indices}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3314
// Gather in the TensorFlow GatherV2 configuration: indices {0, 2} are expected
// to pick columns 0 and 2 of the 3x3 operand.
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
  const char* module_text = R"(
HloModule TensorFlowGatherV2

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[3,2] gather(operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({0, 2});
  const Literal expected =
      LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &indices}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3338
// Column gather driven by a 2x2 index matrix, producing an s32[2,3,2] result
// with two batch dimensions.
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
  const char* module_text = R"(
HloModule TensorFlowGatherMultipleBatchDims

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,3,2] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=2,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
  const Literal expected = LiteralUtil::CreateR3<int32>(
      {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &indices}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3364
// GatherNd-style gather: each row of the index matrix addresses the first two
// operand dimensions and the whole innermost dimension is returned.
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
  const char* module_text = R"(
HloModule TensorFlowGatherNd

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,2] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0,1},
      start_index_map={0,1},
      index_vector_dim=1,
      slice_sizes={1,1,2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  const Literal expected = LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &indices}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3390
// Same operand and index data as the GatherNd test, but with
// index_vector_dim=0, so the expected elements differ.
TEST_F(HloEvaluatorTest,
       EvaluateGather_TensorFlowGatherNdNonDefaultIndexVectorDim) {
  const char* module_text = R"(
HloModule TensorFlowGatherNd

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,2] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0,1},
      start_index_map={0,1},
      index_vector_dim=0,
      slice_sizes={1,1,2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  const Literal expected = LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &indices}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3417
// Gather configured like a dynamic slice: a single 1x1 window is read at start
// index (1, 1) of the 3x3 operand.
TEST_F(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
  const char* module_text = R"(
HloModule DynamicSlice

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[1,1] gather(operand, indices),
      offset_dims={0,1},
      collapsed_slice_dims={},
      start_index_map={0,1},
      index_vector_dim=0,
      slice_sizes={1,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({1, 1});
  const Literal expected = LiteralUtil::CreateR2<int32>({{5}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &indices}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3441
// Batched variant of the dynamic-slice gather: two 1x1 windows are read, with
// index vectors laid out along dimension 0 of the index matrix.
TEST_F(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
  const char* module_text = R"(
HloModule BatchDynamicSlice

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,1,1] gather(operand, indices),
      offset_dims={1,2},
      collapsed_slice_dims={},
      start_index_map={0,1},
      index_vector_dim=0,
      slice_sizes={1,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
  const Literal expected = LiteralUtil::CreateR3<int32>({{{8}}, {{5}}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &indices}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3465
// Row gather on an operand whose second dimension has size zero; the result is
// expected to be an empty s32[2,0] literal.
TEST_F(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
  const char* module_text = R"(
HloModule TensorFlowGatherV1

ENTRY main {
  operand = s32[3,0] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[2,0] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1, 0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR2<int32>({{}, {}, {}});
  Literal indices = LiteralUtil::CreateR1<int32>({0, 2});
  const Literal expected = LiteralUtil::CreateR2<int32>({{}, {}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &indices}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3488
// Gather with an empty offset_dims list: every output element is a single
// operand element selected by the matching index entry.
TEST_F(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
  const string module_text = R"(
HloModule GatherXd

ENTRY main {
  operand = s32[3] parameter(0)
  indices = s32[2,2,1] parameter(1)
  ROOT gather = s32[2,2] gather(operand, indices),
      offset_dims={},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR1<int32>({0, 1, 2});
  Literal indices = LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
  const Literal expected = LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &indices}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3513
// Scatter whose combiner returns its second parameter (plain overwrite): rows
// 0 and 2 of the operand are expected to be replaced by the update rows.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
  const char* module_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal new_rows =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  const Literal expected =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &indices, &new_rows}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3546
// Overwriting scatter along dimension 1: columns 0 and 2 of the operand are
// expected to be replaced by the update columns.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) {
  const char* module_text = R"(
HloModule TensorFlowScatterV2

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[3,2] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={0},
      inserted_window_dims={1},
      scatter_dims_to_operand_dims={1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal new_columns =
      LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
  const Literal expected =
      LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &indices, &new_columns}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3580
// Scatter with an additive combiner: the update rows are expected to be added
// onto rows 0 and 2 of the operand.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) {
  const char* module_text = R"(
HloModule TensorFlowScatter

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal addends =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  const Literal expected =
      LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &indices, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3614
// Scatter with a multiplicative combiner: rows 0 and 2 of the operand are
// expected to be multiplied element-wise by the update rows.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) {
  const char* module_text = R"(
HloModule TensorFlowScatter

mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=mul_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal factors =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  const Literal expected = LiteralUtil::CreateR2<int32>(
      {{10, 40, 90}, {4, 5, 6}, {490, 640, 810}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &indices, &factors}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3648
// Additive scatter on f32 data. The comparison uses LiteralTestUtil::Near with
// ErrorSpec{0.1, 0.01} rather than exact equality.
TEST_P(HloEvaluatorBf16Test, EvaluateScatter_TensorFlowScatter_F32) {
  const char* module_text = R"(
HloModule TensorFlowScatter

add_f32 (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(f32[] lhs, f32[] rhs)
}

ENTRY main {
  operand = f32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = f32[2,3] parameter(2)
  ROOT scatter = f32[3,3] scatter(operand, indices, updates),
      to_apply=add_f32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base = LiteralUtil::CreateR2<float>(
      {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
  // Update row 0 targets operand row 2 and update row 1 targets operand row 1.
  Literal indices = LiteralUtil::CreateR1<int32>({2, 1});
  Literal addends =
      LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
  const Literal expected = LiteralUtil::CreateR2<float>(
      {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &indices, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec{0.1, 0.01}));
}
3684
// Additive scatter where both updates target row 1, so both update rows are
// expected to accumulate into that single operand row.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) {
  const char* module_text = R"(
HloModule TensorFlowScatter

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({1, 1});
  Literal addends =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  const Literal expected =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &indices, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3718
// Additive column scatter driven by a 2x2 index matrix and an s32[2,3,2]
// update tensor.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) {
  const char* module_text = R"(
HloModule TensorFlowScatterMultipleBatchDims

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,3,2] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={1},
      scatter_dims_to_operand_dims={1},
      index_vector_dim=2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
  Literal addends = LiteralUtil::CreateR3<int32>(
      {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
  const Literal expected = LiteralUtil::CreateR2<int32>(
      {{11, 7, 38}, {44, 10, 71}, {77, 13, 104}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &indices, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3753
// ScatterNd-style overwrite: each index row addresses the first two operand
// dimensions and the matching update row replaces the innermost vector there.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) {
  const char* module_text = R"(
HloModule TensorFlowScatterNd

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0,1},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  Literal new_vectors = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &indices, &new_vectors}));

  // Positions (0, 0) and (1, 0) are the only ones expected to change.
  const Literal expected =
      LiteralUtil::CreateR3<int32>({{{-10, 10}, {-2, 2}, {-3, 3}},  //
                                    {{-40, 40}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3790
// ScatterNd-style overwrite with index_vector_dim=0; the index matrix is the
// same as in the default-dim test, but different positions are expected to be
// written.
TEST_F(HloEvaluatorTest,
       EvaluateScatter_TensorFlowScatterNd_NonDefaultIndexVectorDim) {
  const char* module_text = R"(
HloModule TensorFlowScatterNdNonDefaultIndexVectorDim

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0,1},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  Literal new_vectors = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &indices, &new_vectors}));

  // Positions (0, 1) and (0, 0) are the only ones expected to change.
  const Literal expected =
      LiteralUtil::CreateR3<int32>({{{-20, 20}, {-10, 10}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},      //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3828
// Scatter configured like a dynamic-update-slice: a single 1x1 update is
// written at index (1, 1) of the 3x3 operand.
TEST_F(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) {
  const char* module_text = R"(
HloModule DynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[1,1] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={0,1},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({1, 1});
  Literal patch = LiteralUtil::CreateR2<int32>({{10}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &indices, &patch}));

  const Literal expected =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3861
// Batched dynamic-update-slice style scatter: two 1x1 updates are written,
// with index vectors laid out along dimension 0 of the index matrix.
TEST_F(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) {
  const char* module_text = R"(
HloModule BatchDynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,1,1] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
  Literal patches = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &indices, &patches}));

  const Literal expected =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3894
// Scatter on an operand whose second dimension has size zero; the result is
// expected to equal the (empty) operand.
TEST_F(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) {
  const char* module_text = R"(
HloModule TensorFlowScatter_ZeroDimBounds

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,0] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,0] parameter(2)
  ROOT scatter = s32[3,0] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base = LiteralUtil::CreateR2<int32>({{}, {}, {}});
  Literal indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal empty_updates = LiteralUtil::CreateR2<int32>({{}, {}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &indices, &empty_updates}));
  EXPECT_TRUE(LiteralTestUtil::Equal(base, actual));
}
3924
// Additive scatter with an empty update_window_dims list: every update element
// is a scalar added to the operand element named by the matching index entry.
TEST_F(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) {
  const string module_text = R"(
HloModule Scatter_NoUpdateWindowDims

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3] parameter(0)
  indices = s32[2,2,1] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base = LiteralUtil::CreateR1<int32>({0, 1, 2});
  Literal indices = LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
  Literal addends = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &indices, &addends}));

  // Index 1 is targeted twice (by 20 and by 40), hence 1 + 20 + 40 = 61.
  const Literal expected = LiteralUtil::CreateR1<int32>({10, 61, 32});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3958
// Additive scatter where one index is negative. The update addressed by -1 is
// expected to be dropped, while the update for row 2 is applied.
TEST_F(HloEvaluatorTest, EvaluateScatter_NegativeIndices) {
  const char* module_text = R"(
HloModule TensorFlowScatter_NegativeIndices

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  // This test parses into a local module and evaluates it through
  // EvaluateWithModule instead of storing it in m_.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  // No updates should happen for the negative indices.
  Literal indices = LiteralUtil::CreateR1<int32>({-1, 2});
  Literal addends =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  const Literal expected =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {77, 88, 99}});

  EXPECT_TRUE(LiteralTestUtil::Equal(
      expected, EvaluateWithModule(parsed_module.get(),
                                   {&base, &indices, &addends})));
}
3993
TEST_F(HloEvaluatorTest, EvaluateScatter_OobIndices) {
  // Element-wise scatter (the combiner simply returns the update value) where
  // several of the index pairs fall outside the 3x3 operand.
  const string module_text = R"(
HloModule BatchDynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3]{1,0} parameter(0)
  indices = s32[6,2]{1,0} parameter(1)
  updates = s32[6,1,1]{2,1,0} parameter(2)
  ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed,
                          ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  // {2, 7}, {5, 1} and {2147483647, 1} lie outside the operand, so their
  // updates (10, 40 and 50) must not be applied.
  Literal index_pairs = LiteralUtil::CreateR2<int32>(
      {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
  Literal update_values = LiteralUtil::CreateR3<int32>(
      {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});

  Literal want =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 30, 60}, {7, 20, 9}});
  auto&& actual = EvaluateWithModule(parsed.get(),
                                     {&base, &index_pairs, &update_values});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
4029
TEST_F(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) {
  // The start index is in bounds, but the update window extends past the end
  // of the operand.
  const char* module_text = R"(
HloModule TensorFlowScatterNd_OobUpdateWindow

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[1,2] parameter(1)
  updates = s32[1,2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed,
                          ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal start_index = LiteralUtil::CreateR2<int32>({{0, 2}});
  Literal window_update =
      LiteralUtil::CreateR3<int32>({{{-10, 10}, {-40, 40}}});

  // With a 2x2 update window starting at index (0, 2) the window runs out of
  // bounds, so the result must be identical to the untouched operand.
  Literal want = base.Clone();
  auto&& actual = EvaluateWithModule(parsed.get(),
                                     {&base, &start_index, &window_update});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
4066
4067 // Verifies that HloEvaluator evaluates a HLO instruction that performs
4068 // element-wise comparison with 2 bfloat16 operands.
TEST_F(HloEvaluatorTest, DoesCompareBF16) {
  // Evaluates the element-wise predicate `lhs >= rhs` on two bf16 constants.
  auto lhs_values = LiteralUtil::CreateR2<bfloat16>(
      {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)},
       {bfloat16(-0.25), bfloat16(-0.35), bfloat16(-0.125)}});
  auto rhs_values = LiteralUtil::CreateR2<bfloat16>(
      {{bfloat16(0.5), bfloat16(0.125), bfloat16(0.125)},
       {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}});
  auto want =
      LiteralUtil::CreateR2<bool>({{false, true, true}, {false, true, true}});

  // Build `compare(constant(lhs), constant(rhs))` as the entry computation.
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs_const = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(lhs_values)));
  HloInstruction* rhs_const = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(rhs_values)));
  builder.AddInstruction(HloInstruction::CreateCompare(
      want.shape(), lhs_const, rhs_const, ComparisonDirection::kGe));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
4090
TEST_P(HloEvaluatorBf16Test, Bf16Reduction) {
  // Sum-reduces a four-element bf16 vector down to a bf16 scalar.
  const string module_text = R"(
HloModule Bf16Reduction

add_bf16 (lhs: bf16[], rhs: bf16[]) -> bf16[] {
  lhs = bf16[] parameter(0)
  rhs = bf16[] parameter(1)
  ROOT add = bf16[] add(bf16[] lhs, bf16[] rhs)
}

ENTRY main {
  arg0 = bf16[4]{0} parameter(0)
  init = bf16[] constant(0)
  ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_bf16
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR1<bfloat16>(
      {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input}));

  // 1 + 3 - 2 + 42 = 44.
  Literal want = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
4115
TEST_F(HloEvaluatorTest, MixedPrecisionReduction) {
  // The reduce consumes f32 inputs through an f32 combiner, but its declared
  // result type is bf16.
  const string module_text = R"(
HloModule MixedPrecisionReduction

add_f32 {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY main {
  arg0 = f32[4]{0} parameter(0)
  init = f32[] constant(0)
  ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_f32
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR1<float>({1.0f, 3.0f, -2.0f, 42.0f});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input}));

  // 1 + 3 - 2 + 42 = 44, delivered as a bf16 scalar.
  Literal want = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
4139
TEST_F(HloEvaluatorTest, DontFailOnCallUnimplementedOps) {
  // The callee contains an infeed, which triggers an unimplemented error
  // inside HandleCall. Evaluation is expected to report a non-OK status.
  const string module_text = R"(
HloModule DontFailOnCall

call {
  token0 = token[] after-all()
  ROOT infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
}

ENTRY main {
  ROOT result = ((u32[3]{0}, pred[]), token[]) call(), to_apply=call
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  auto evaluation = Evaluate();
  EXPECT_FALSE(evaluation.status().ok());
}
4159
TEST_F(HloEvaluatorTest, DontFailOnFusionWithUnimplementedOps) {
  // The fused computation contains an infeed, which triggers an unimplemented
  // error inside HandleFusion. Evaluation is expected to report a non-OK
  // status.
  const string module_text = R"(
HloModule DontFailOnFusion

fused_computation {
  token0 = token[] after-all()
  ROOT infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
}

ENTRY main {
  ROOT result = ((u32[3]{0}, pred[]), token[]) fusion(), kind=kLoop, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  auto evaluation = Evaluate();
  EXPECT_FALSE(evaluation.status().ok());
}
4179
TEST_P(HloEvaluatorBf16Test, SliceWithDifferentLayout) {
  // Regression test for b/114735354: a full-range slice whose result layout
  // differs from the operand layout.
  const string module_text = R"(
HloModule SliceWithDifferentLayout

ENTRY main {
  arg = f32[2,2,2]{0,1,2} parameter(0)
  ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR3WithLayout<float>(
      {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
      LayoutUtil::MakeLayout({0, 1, 2}));
  TF_ASSERT_OK_AND_ASSIGN(Literal sliced, Evaluate({&input}));

  // The slice covers the whole operand, so the result must compare equal to
  // the input.
  EXPECT_TRUE(LiteralTestUtil::Equal(input, sliced));
}
4198
TEST_P(HloEvaluatorBf16Test, Bitcast) {
  // Regression test for b/114735354.
  // The element type is substituted into the module text depending on whether
  // this parameterization runs with bf16 or f32.
  const absl::string_view module_template = R"(
HloModule Bitcast

ENTRY main {
  param = %s[32,121]{1,0} parameter(0)
  ROOT bitcast = %s[121,32,1]{0,1,2} bitcast(%s[32,121]{1,0} param)
}
)";
  const string module_text =
      use_bfloat16_
          ? absl::StrFormat(module_template, "bf16", "bf16", "bf16")
          : absl::StrFormat(module_template, "f32", "f32", "f32");
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  auto fake_args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(Literal bitcasted, Evaluate({&fake_args[0]}));

  // The flat element data of the result must match that of the input.
  if (!use_bfloat16_) {
    EXPECT_TRUE(
        absl::c_equal(fake_args[0].data<float>(), bitcasted.data<float>()));
  } else {
    EXPECT_TRUE(absl::c_equal(fake_args[0].data<bfloat16>(),
                              bitcasted.data<bfloat16>()));
  }
}
4225
4226 // Check that s32 under/overflow doesn't trigger a ubsan failure.
TEST_F(HloEvaluatorTest, Int32Overflow) {
  // Evaluates an add, a subtract and a multiply whose mathematical results do
  // not fit in s32, and checks that each result is the wrapped (mod 2^32)
  // value.
  const absl::string_view hlo_text = R"(
HloModule Test

ENTRY main {
  c1 = s32[] constant(1073741824)  // 2^30
  sum = s32[] add(c1, c1)  // 2^31, i.e. INT_MIN

  c2 = s32[] constant(-2147483648)  // -2^31
  sub = s32[] subtract(c2, c1)  // -2^31 - 2^30, underflows

  mul = s32[] multiply(c1, c1)
  ROOT tuple = (s32[], s32[], s32[]) tuple(sum, sub, mul)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
  TF_ASSERT_OK_AND_ASSIGN(auto literal, Evaluate({}));
  std::vector<Literal> actual = literal.DecomposeTuple();
  ASSERT_EQ(actual.size(), 3);

  // The expected values are computed in uint32, where wraparound is well
  // defined, and then converted to int32.
  uint32 pow30 = uint32{1} << 30;
  uint32 pow31 = uint32{1} << 31;
  EXPECT_EQ(actual[0].GetFirstElement<int32>(), static_cast<int32>(pow31));
  EXPECT_EQ(actual[1].GetFirstElement<int32>(),
            static_cast<int32>(-(pow31 + pow30)));
  // mul is c1 * c1 = 2^30 * 2^30, so mirror that product here. (The previous
  // expectation multiplied pow31 by itself, which only matched because both
  // products wrap to 0.)
  EXPECT_EQ(actual[2].GetFirstElement<int32>(),
            static_cast<int32>(pow30 * pow30));
}
4255
TEST_F(HloEvaluatorTest, GetDimensionSize) {
  // Evaluates get-dimension-size on a value derived from a parameter that has
  // a dynamic-parameter binding, with dynamic dimension inference installed
  // on the evaluator.
  const absl::string_view module_text = R"(
HloModule Test

ENTRY main {
  size = s32[] parameter(0)

  data = s32[4] parameter(1)

  sum = s32[4] add(data, data)

  ROOT dynamic_size = s32[] get-dimension-size(sum), dimensions={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal size_value = LiteralUtil::CreateR0<int32>(3);
  Literal data_value = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});

  // Set up the dynamic parameter binding between parameter 0 and parameter 1.
  // NOTE(review): presumably this marks parameter 0 as the runtime size of
  // dimension 0 of parameter 1; confirm against DynamicParameterBinding.
  TF_CHECK_OK(m_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{0, {}},
      DynamicParameterBinding::DynamicDimension{1, {}, 0}));

  // The inference result has to be computed after the binding is in place and
  // handed to the evaluator before evaluation starts.
  TF_ASSERT_OK_AND_ASSIGN(DynamicDimensionInference inference,
                          DynamicDimensionInference::Run(m_.get()));
  evaluator_.set_dynamic_dimension_inference(&inference);

  TF_ASSERT_OK_AND_ASSIGN(Literal reported_size,
                          Evaluate({&size_value, &data_value}));

  // The reported size equals the value passed for `size`, not the static 4.
  EXPECT_EQ(reported_size.GetFirstElement<int32>(), static_cast<int32>(3));
}
4288
4289 // Check that we get a useful error if we pass inputs of the wrong shape.
TEST_F(HloEvaluatorTest, EvaluateWithWrongInputShapes) {
  // The entry computation expects s32[1]; feeding it s32[2] must yield a
  // descriptive error from both the module and the computation overloads.
  const absl::string_view module_text = R"(
HloModule Test

ENTRY main {
  p0 = s32[1] parameter(0)
  ROOT sum = s32[1] add(p0, p0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  Literal two_element_input = LiteralUtil::CreateR1<int32>({0, 1});

  const char* const expected_error =
      "Shape mismatch at parameter 0. Computation expected s32[1]{0}, "
      "but arg was s32[2]{0}.";

  // Evaluate via the module.
  const auto module_status =
      HloEvaluator().Evaluate(*m_, {&two_element_input}).status();
  EXPECT_EQ(module_status.error_message(), expected_error);

  // Evaluate via the entry computation.
  const auto computation_status =
      HloEvaluator()
          .Evaluate(*m_->entry_computation(), {&two_element_input})
          .status();
  EXPECT_EQ(computation_status.error_message(), expected_error);
}
4315
// Check that we get a useful error if we pass the wrong number of inputs
// (only the too-many-inputs case is exercised here).
TEST_F(HloEvaluatorTest, EvaluateWithWrongNumberOfInputs) {
  // The entry computation takes one parameter; passing two arguments must
  // yield a descriptive error from both the module and the computation
  // overloads.
  const absl::string_view module_text = R"(
HloModule Test

ENTRY main {
  p0 = s32[1] parameter(0)
  ROOT sum = s32[1] add(p0, p0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  Literal arg = LiteralUtil::CreateR1<int32>({0});

  const char* const expected_error = "Expected 1 argument, but got 2.";

  // Evaluate via the module.
  const auto module_status =
      HloEvaluator().Evaluate(*m_, {&arg, &arg}).status();
  EXPECT_EQ(module_status.error_message(), expected_error);

  // Evaluate via the entry computation.
  const auto computation_status =
      HloEvaluator()
          .Evaluate(*m_->entry_computation(), {&arg, &arg})
          .status();
  EXPECT_EQ(computation_status.error_message(), expected_error);
}
4338
TEST_F(HloEvaluatorTest, PreserveFusionInputLayout) {
  // The fusion parameter uses the {0,1} layout and is bitcast to {1,0} inside
  // the fused computation.
  const absl::string_view module_text = R"(
    HloModule FusionInputLayout

    fused_computation {
      param_0 = f32[20,20]{0,1} parameter(0)
      ROOT bitcast = f32[20,20]{1,0} bitcast(param_0)
    }

    ENTRY kernel_entry {
      parameter.0 = f32[20,20]{0,1} parameter(0)
      ROOT fusion = f32[20,20]{1,0} fusion(parameter.0),
        kind=kLoop, calls=fused_computation
    })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  auto fake_args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(Literal fused_result, Evaluate({&fake_args[0]}));

  // The flat element data of the result must match that of the input.
  const auto input_data = fake_args[0].data<float>();
  const auto output_data = fused_result.data<float>();
  EXPECT_TRUE(absl::c_equal(input_data, output_data));
}
4360
TEST_F(HloEvaluatorTest, PreserveFusionOutputLayout) {
  // The fusion parameter uses the {1,0} layout and the fusion result is
  // declared with the {0,1} layout.
  const absl::string_view module_text = R"(
    HloModule FusionOutputLayout

    fused_computation {
      param_0 = f32[20,20]{1,0} parameter(0)
      ROOT bitcast = f32[20,20]{0,1} bitcast(param_0)
    }

    ENTRY kernel_entry {
      parameter.0 = f32[20,20]{1,0} parameter(0)
      ROOT fusion = f32[20,20]{0,1} fusion(parameter.0),
        kind=kLoop, calls=fused_computation
    })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  auto fake_args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(Literal fused_result, Evaluate({&fake_args[0]}));

  // The flat element data of the result must match that of the input.
  const auto input_data = fake_args[0].data<float>();
  const auto output_data = fused_result.data<float>();
  EXPECT_TRUE(absl::c_equal(input_data, output_data));
}
4381
TEST_F(HloEvaluatorTest, PreserveMOFusionOutputLayout) {
  // Same layout change as the single-output variant, but the fusion root is a
  // tuple wrapping the bitcast.
  const absl::string_view module_text = R"(
    HloModule MOFusionOutputLayout

    fused_computation {
      param_0 = f32[20,20]{1,0} parameter(0)
      bitcast = f32[20,20]{0,1} bitcast(param_0)
      ROOT tuple = (f32[20,20]{0,1}) tuple(bitcast)
    }

    ENTRY kernel_entry {
      parameter.0 = f32[20,20]{1,0} parameter(0)
      ROOT fusion = (f32[20,20]{0,1}) fusion(parameter.0),
        kind=kLoop, calls=fused_computation
    })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  auto fake_args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(Literal result_tuple, Evaluate({&fake_args[0]}));

  // Unpack the tuple and compare the flat data of its only element with the
  // input.
  std::vector<Literal> tuple_elements = result_tuple.DecomposeTuple();
  const auto input_data = fake_args[0].data<float>();
  const auto output_data = tuple_elements[0].data<float>();
  EXPECT_TRUE(absl::c_equal(input_data, output_data));
}
4405
4406 // Tests that custom_calls fail to evaluate when no handler is specified.
TEST_F(HloEvaluatorTest, EvaluateCustomCall_NoHandler) {
  // A freshly constructed evaluator has no custom-call handler installed, so
  // evaluating a custom-call is expected to report UNIMPLEMENTED.
  const absl::string_view module_text = R"(
    HloModule EvaluateCustomCall_NoHandler
    ENTRY kernel_entry {
      parameter.0 = u32[2,2]{1,0} parameter(0)
      ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0),
          custom_call_target="_my_custom_call"
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  auto fake_args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();

  const auto status = HloEvaluator().Evaluate(*m_, {&fake_args[0]}).status();
  EXPECT_EQ(status.code(), ::tensorflow::error::UNIMPLEMENTED);
}
4422
4423 // Tests when a custom_call handler returns an error.
TEST_F(HloEvaluatorTest, EvaluateCustomCall_HandlerError) {
  // Installs a handler that always returns an internal error and checks that
  // the error code is surfaced by Evaluate.
  const absl::string_view module_text = R"(
    HloModule EvaluateCustomCall_HandlerError
    ENTRY kernel_entry {
      parameter.0 = u32[2,2]{1,0} parameter(0)
      ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0),
          custom_call_target="_my_custom_call"
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  auto fake_args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();

  HloEvaluator failing_evaluator;
  failing_evaluator.set_custom_call_handler(
      [](HloInstruction* custom_call, absl::Span<const Literal*> operands) {
        return InternalError("Test error");
      });

  const auto status =
      failing_evaluator.Evaluate(*m_, {&fake_args[0]}).status();
  EXPECT_EQ(status.code(), ::tensorflow::error::INTERNAL);
}
4444
4445 // Tests the custom_call handler on calls with many inputs.
4446 // We sum the operands so that we can verify the operand and output literals
4447 // are properly mapped for access.
TEST_F(HloEvaluatorTest, EvaluateCustomCall_ManyInputs) {
  // The handler adds its two operands, which lets the test confirm that both
  // the operand literals and the output literal are wired up correctly.
  const absl::string_view module_text = R"(
    HloModule EvaluateCustomCall_ManyInputs
    ENTRY kernel_entry {
      parameter.0 = u32[1]{0} parameter(0)
      parameter.1 = u32[1]{0} parameter(1)
      ROOT test_root = u32[1]{0} custom-call(parameter.0, parameter.1),
          custom_call_target="_my_custom_call"
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  auto fake_args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();

  HloEvaluator summing_evaluator;
  summing_evaluator.set_custom_call_handler(
      [](HloInstruction* custom_call, absl::Span<const Literal*> operands) {
        // Sanity-check what the evaluator hands to the handler.
        EXPECT_EQ(HloOpcode::kCustomCall, custom_call->opcode());
        EXPECT_EQ("_my_custom_call", custom_call->custom_call_target());
        EXPECT_EQ(2, custom_call->operand_count());
        EXPECT_EQ(2, operands.size());

        // Produce a literal of the custom-call's shape holding the sum of the
        // first element of each operand.
        auto sum_literal = Literal::CreateFromShape(custom_call->shape());
        auto first = operands[0]->data<uint32>();
        auto second = operands[1]->data<uint32>();
        auto sum_data = sum_literal.data<uint32>();
        sum_data[0] = first[0] + second[0];
        return sum_literal;
      });

  TF_ASSERT_OK_AND_ASSIGN(
      Literal handler_result,
      summing_evaluator.Evaluate(*m_->entry_computation(),
                                 {&fake_args[0], &fake_args[1]}));

  // Recompute the sum outside the handler and compare.
  auto first_arg = fake_args[0].data<uint32>();
  auto second_arg = fake_args[1].data<uint32>();
  std::vector<uint32> want = {first_arg[0] + second_arg[0]};
  EXPECT_TRUE(absl::c_equal(want, handler_result.data<uint32>()));
}
4483
TEST_F(HloEvaluatorTest, IsFiniteF16) {
  // is-finite on an f16 constant containing NaNs, infinities and ordinary
  // values.
  const absl::string_view module_text = R"(
  HloModule test

  ENTRY IsFiniteTest {
    c = f16[6] constant({nan, 7, nan, -1, inf, -inf})
    ROOT is-finite = pred[6] is-finite(c)
  })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal finite_mask,
      HloEvaluator().Evaluate(*m_->entry_computation(), {}));

  // Only 7 and -1 are finite.
  EXPECT_THAT(finite_mask.data<bool>(),
              ::testing::ElementsAre(false, true, false, true, false, false));
}
4500
TEST_F(HloEvaluatorTest, IsFiniteBf16) {
  // is-finite on a bf16 constant containing NaNs, infinities and ordinary
  // values.
  const absl::string_view module_text = R"(
  HloModule test

  ENTRY IsFiniteTest {
    c = bf16[6] constant({nan, 7, nan, -1, inf, -inf})
    ROOT is-finite = pred[6] is-finite(c)
  })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal finite_mask,
      HloEvaluator().Evaluate(*m_->entry_computation(), {}));

  // Only 7 and -1 are finite.
  EXPECT_THAT(finite_mask.data<bool>(),
              ::testing::ElementsAre(false, true, false, true, false, false));
}
4517
4518 // Check that evaluating `f32[<huge>, 0] iota` doesn't oom (it's an empty
4519 // array!).
TEST_F(HloEvaluatorTest, ZeroSizedIotaWithHugeDimension) {
  // One dimension is enormous but the other is 0, so the array has no
  // elements and the evaluated literal must be empty.
  const absl::string_view module_text = R"(
  HloModule test
  ENTRY t {
    ROOT i = f32[1000000000000, 0] iota(), iota_dimension=0
  })";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal iota_result,
      HloEvaluator().Evaluate(*m_->entry_computation(), {}));
  EXPECT_THAT(iota_result.data<float>(), ::testing::IsEmpty());
}
4532
TEST_F(HloEvaluatorTest, CopyStartCopyDone) {
  // A copy-start/copy-done pair applied to a scalar constant must hand back
  // the same scalar.
  const absl::string_view module_text = R"(
  HloModule test
  ENTRY CopyStartCopyDone {
    init = f32[] constant(42.0)
    copy-start = (f32[]{:S(1)}, f32[], u32[]) copy-start(init)
    ROOT copy-done = f32[] copy-done(copy-start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal copied, HloEvaluator().Evaluate(*m_->entry_computation(), {}));

  Literal want = LiteralUtil::CreateR0<float>(42.0f);
  EXPECT_TRUE(LiteralTestUtil::Equal(want, copied));
}
4548
TEST_F(HloEvaluatorTest, MapBF16) {
  // Maps `convert(p + p)` over a bf16 vector, producing f32 results.
  const absl::string_view module_text = R"(
  HloModule test

  map_computation {
    p = bf16[] parameter(0)
    add = bf16[] add(p, p)
    ROOT conv = f32[] convert(add)
  }

  ENTRY CopyStartCopyDone {
    c = bf16[3] constant({1, 2, 3})
    ROOT map = f32[3] map(c), to_apply=map_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal mapped, HloEvaluator().Evaluate(*m_->entry_computation(), {}));

  // Each input element is doubled.
  Literal want = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, mapped));
}
4570
TEST_F(HloEvaluatorTest, MapS16) {
  // Maps `convert(p + p)` over an s16 vector, producing f32 results.
  const absl::string_view module_text = R"(
  HloModule test

  map_computation {
    p = s16[] parameter(0)
    add = s16[] add(p, p)
    ROOT conv = f32[] convert(add)
  }

  ENTRY CopyStartCopyDone {
    c = s16[3] constant({1, 2, 3})
    ROOT map = f32[3] map(c), to_apply=map_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal mapped, HloEvaluator().Evaluate(*m_->entry_computation(), {}));

  // Each input element is doubled.
  Literal want = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, mapped));
}
4592
TEST_F(HloEvaluatorTest, MapU16) {
  // Maps `convert(p + p)` over a u16 vector, producing f32 results.
  const absl::string_view module_text = R"(
  HloModule test

  map_computation {
    p = u16[] parameter(0)
    add = u16[] add(p, p)
    ROOT conv = f32[] convert(add)
  }

  ENTRY CopyStartCopyDone {
    c = u16[3] constant({1, 2, 3})
    ROOT map = f32[3] map(c), to_apply=map_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal mapped, HloEvaluator().Evaluate(*m_->entry_computation(), {}));

  // Each input element is doubled.
  Literal want = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, mapped));
}
4614
TEST_F(HloEvaluatorTest, DotUpcast) {
  // Dot of an s16 matrix with an s8 matrix whose result type is s32.
  const absl::string_view module_text = R"(
  HloModule test
  ENTRY DotUpcast {
    l = s16[4,3]{1,0} parameter(0)
    r = s8[3,2]{1,0} parameter(1)
    ROOT result = s32[4,2] dot(l, r), lhs_contracting_dims={1},
                                      rhs_contracting_dims={0}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  // Left operand:
  // s16[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  Array2D<int16> lhs_values(4, 3);
  lhs_values.FillUnique(1);
  auto lhs_literal = LiteralUtil::CreateR2FromArray2D<int16>(lhs_values);

  // Right operand:
  // s8[3,2] {
  //  { 1, 2 },
  //  { 3, 4 },
  //  { 5, 6 },
  // }
  Array2D<int8> rhs_values(3, 2);
  rhs_values.FillUnique(1);
  auto rhs_literal = LiteralUtil::CreateR2FromArray2D<int8>(rhs_values);

  TF_ASSERT_OK_AND_ASSIGN(Literal product,
                          Evaluate({&lhs_literal, &rhs_literal}));

  // E.g. the first row: 1*1 + 2*3 + 3*5 = 22 and 1*2 + 2*4 + 3*6 = 28.
  Array2D<int32> want_values({{22, 28}, {58, 76}, {94, 124}, {130, 172}});
  auto want = LiteralUtil::CreateR2FromArray2D<int32>(want_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(want, product));
}
4655
TEST_F(HloEvaluatorTest, SortC64) {
  // Sorts a c64 vector with a comparator that orders elements by their real
  // part.
  const absl::string_view module_text = R"(
  HloModule m

  sort_lt_comparator {
    parameter.0 = c64[] parameter(0)
    real.0 = f32[] real(parameter.0)
    parameter.1 = c64[] parameter(1)
    real.1 = f32[] real(parameter.1)
    ROOT compare = pred[] compare(real.0, real.1), direction=LT
  }

  ENTRY main {
    c = c64[3] constant({(2, 0), (4, 0), (6, 0)})
    ROOT sort = c64[3]{0} sort(c), dimensions={0}, to_apply=sort_lt_comparator
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal sorted, HloEvaluator().Evaluate(*m_->entry_computation(), {}));

  // The input is already ascending by real part, so the output is unchanged.
  Literal want =
      LiteralUtil::CreateR1<std::complex<float>>({2.f, 4.f, 6.f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, sorted));
}
4680
4681 } // namespace
4682 } // namespace xla
4683