1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5 http://www.apache.org/licenses/LICENSE-2.0
6 Unless required by applicable law or agreed to in writing, software
7 distributed under the License is distributed on an "AS IS" BASIS,
8 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 See the License for the specific language governing permissions and
10 limitations under the License.
11 ==============================================================================*/
12
13 #include <vector>
14
15 #include "tensorflow/compiler/xla/error_spec.h"
16 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
17
18 namespace xla {
19 namespace gpu {
20 namespace {
21
22 class ElementWiseRowVectorizationTest : public HloTestBase {};
23
TEST_F(ElementWiseRowVectorizationTest,SimpleAddSmallRowBroadcastingTest)24 TEST_F(ElementWiseRowVectorizationTest, SimpleAddSmallRowBroadcastingTest) {
25 const char* hlo_text = R"(
26 HloModule SimpleAddSmallRowBroadcasting
27
28 %fused_computation.0 {
29 %param_0 = f32[48]{0} parameter(0)
30 %broadcast = f32[256,14,14,48]{3,2,1,0} broadcast(%param_0), dimensions={3}
31 %param_1 = f32[256,14,14,48]{3,2,1,0} parameter(1)
32 ROOT %add = f32[256,14,14,48]{3,2,1,0} add(%broadcast, %param_1)
33 }
34
35 ENTRY main {
36 %param_0 = f32[48]{0} parameter(0)
37 %param_1 = f32[256,14,14,48]{3,2,1,0} parameter(1)
38
39 ROOT %fusion.0_small = f32[256,14,14,48]{3,2,1,0} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.0
40 }
41 )";
42 auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
43 EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
44 }
45
46 } // namespace
47 } // namespace gpu
48 } // namespace xla
49