• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5     http://www.apache.org/licenses/LICENSE-2.0
6 Unless required by applicable law or agreed to in writing, software
7 distributed under the License is distributed on an "AS IS" BASIS,
8 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 See the License for the specific language governing permissions and
10 limitations under the License.
11 ==============================================================================*/
12 
13 #include <vector>
14 
15 #include "tensorflow/compiler/xla/error_spec.h"
16 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
17 
18 namespace xla {
19 namespace gpu {
20 namespace {
21 
22 class ElementWiseRowVectorizationTest : public HloTestBase {};
23 
// Runs a kLoop fusion that broadcasts an f32[48] parameter along dimension 3
// of an f32[256,14,14,48] shape and adds it to a second parameter of that
// shape, then checks the result via RunAndCompare with ErrorSpec{1e-5, 1e-5}.
TEST_F(ElementWiseRowVectorizationTest, SimpleAddSmallRowBroadcastingTest) {
  const char* hlo_text = R"(
HloModule SimpleAddSmallRowBroadcasting

%fused_computation.0 {
  %param_0 = f32[48]{0} parameter(0)
  %broadcast = f32[256,14,14,48]{3,2,1,0} broadcast(%param_0), dimensions={3}
  %param_1 = f32[256,14,14,48]{3,2,1,0} parameter(1)
  ROOT %add = f32[256,14,14,48]{3,2,1,0} add(%broadcast, %param_1)
}

ENTRY main {
  %param_0 = f32[48]{0} parameter(0)
  %param_1 = f32[256,14,14,48]{3,2,1,0} parameter(1)

  ROOT %fusion.0_small = f32[256,14,14,48]{3,2,1,0} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.0
}
)";
  // Pass the HLO text straight to RunAndCompare; a separately parsed module
  // would go unused in this test.
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
45 
46 }  // namespace
47 }  // namespace gpu
48 }  // namespace xla
49