/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/batchnorm_expander.h"

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace {

class BatchNormExpanderTest : public HloTestBase {
 protected:
  // The expanded BatchNorm should use a dynamically sized divisor (obtained
  // via get-dimension-size) for its mean computations; this helper counts
  // those instructions. A hedged companion sketch, CountBatchNormOps, follows
  // the class.
  int64_t CountGetDimensionSize(const HloModule& module) {
    int64_t count = 0;
    for (HloComputation* comp : module.computations()) {
      for (HloInstruction* inst : comp->instructions()) {
        if (inst->opcode() == HloOpcode::kGetDimensionSize) {
          count++;
        }
      }
    }
    return count;
  }
};
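
// Hedged addition (not in the original test file): a companion helper sketch
// that counts any batch-norm ops remaining after the expander runs. The tests
// below only check the root opcode and the number of get-dimension-size
// instructions; asserting CountBatchNormOps(*module) == 0 would make the
// "fully expanded" expectation explicit. The helper name is illustrative and
// the body uses only APIs already exercised elsewhere in this file.
int64_t CountBatchNormOps(const HloModule& module) {
  int64_t count = 0;
  for (HloComputation* comp : module.computations()) {
    for (HloInstruction* inst : comp->instructions()) {
      if (inst->opcode() == HloOpcode::kBatchNormTraining ||
          inst->opcode() == HloOpcode::kBatchNormInference ||
          inst->opcode() == HloOpcode::kBatchNormGrad) {
        count++;
      }
    }
  }
  return count;
}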

// Test that we expand BatchNormTraining.
TEST_F(BatchNormExpanderTest, BatchNormTraining) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});
  Shape scale_shape = ShapeUtil::MakeShape(F32, {2});
  Shape offset_shape = ShapeUtil::MakeShape(F32, {2});

  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "activation"));

  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scale_shape, "scale"));

  HloInstruction* param2 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, offset_shape, "offset"));

  builder.AddInstruction(HloInstruction::CreateBatchNormTraining(
      ShapeUtil::MakeTupleShape({input_shape, scale_shape, offset_shape}),
      param0, param1, param2,
      /*epsilon=*/0.001, /*feature_index=*/3));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining);
  BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                             /*rewrite_inference_op=*/true,
                             /*rewrite_grad_op=*/true);
  ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
  root = computation->root_instruction();
  EXPECT_EQ(CountGetDimensionSize(*module), 3);
  // Make sure this operation is expanded.
  EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
}

// Test that we expand BatchNormGrad.
TEST_F(BatchNormExpanderTest, BatchNormGrad) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});
  Shape scale_shape = ShapeUtil::MakeShape(F32, {2});
  Shape mean_shape = ShapeUtil::MakeShape(F32, {2});
  Shape var_shape = ShapeUtil::MakeShape(F32, {2});
  Shape grad_output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});

  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "activation"));

  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scale_shape, "scale"));

  HloInstruction* param2 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, mean_shape, "mean"));

  HloInstruction* param3 = builder.AddInstruction(
      HloInstruction::CreateParameter(3, var_shape, "var"));

  HloInstruction* param4 = builder.AddInstruction(
      HloInstruction::CreateParameter(4, grad_output_shape, "grad_output"));

  builder.AddInstruction(HloInstruction::CreateBatchNormGrad(
      ShapeUtil::MakeTupleShape({input_shape, scale_shape, mean_shape}), param0,
      param1, param2, param3, param4,
      /*epsilon=*/0.001, /*feature_index=*/3));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormGrad);
  BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                             /*rewrite_inference_op=*/true,
                             /*rewrite_grad_op=*/true);
  ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
  root = computation->root_instruction();
  EXPECT_EQ(CountGetDimensionSize(*module), 3);
  // Make sure this operation is expanded.
  EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
}
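
// Test that sharding specified on the batch-norm op is preserved on the
// expanded instructions.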
TEST_F(BatchNormExpanderTest, BatchNormTrainingSharding) {
  const char* module_str = R"(
HloModule module
ENTRY entry {
  %param.0 = f32[8,4] parameter(0)
  %param.1 = f32[4] parameter(1)
  %param.2 = f32[4] parameter(2)
  ROOT %batch-norm-training = (f32[8,4], f32[4], f32[4])
    batch-norm-training(f32[8,4] %param.0, f32[4] %param.1, f32[4] %param.2),
    epsilon=0.001, feature_index=1, sharding={maximal device=1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
  BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                             /*rewrite_inference_op=*/true,
                             /*rewrite_grad_op=*/true);
  ASSERT_TRUE(rewriter.Run(m.get()).ValueOrDie());

  for (auto* instruction : m->entry_computation()->instructions()) {
    if (instruction->opcode() == HloOpcode::kParameter) {
      continue;
    }
    auto device = instruction->sharding_unique_device();
    ASSERT_TRUE(device);
    EXPECT_EQ(*device, 1);
  }
}
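
// Test that a module containing batch-norm-training executes and matches the
// reference result within the given error bounds.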
TEST_F(BatchNormExpanderTest, Execution) {
  const char* module_str = R"(
HloModule module
ENTRY entry {
  %param.0 = f32[8,4] parameter(0)
  %param.1 = f32[4] parameter(1)
  %param.2 = f32[4] parameter(2)
  ROOT %batch-norm-training = (f32[8,4], f32[4], f32[4])
    batch-norm-training(f32[8,4] %param.0, f32[4] %param.1, f32[4] %param.2),
    epsilon=0.001, feature_index=1, sharding={maximal device=1}
})";
  EXPECT_TRUE(RunAndCompare(module_str, ErrorSpec{1e-4, 1e-4}));
}
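
// Hedged sketch (not part of the original file): the expander is always
// constructed above with rewrite_inference_op=true, yet no test covers
// batch-norm-inference. An analogous check could look like the test below.
// The HLO string and test name are illustrative; treat this as a sketch
// rather than a verified test case.
TEST_F(BatchNormExpanderTest, BatchNormInferenceSketch) {
  const char* module_str = R"(
HloModule module
ENTRY entry {
  %param.0 = f32[8,4] parameter(0)
  %param.1 = f32[4] parameter(1)
  %param.2 = f32[4] parameter(2)
  %param.3 = f32[4] parameter(3)
  %param.4 = f32[4] parameter(4)
  ROOT %batch-norm-inference = f32[8,4]
    batch-norm-inference(f32[8,4] %param.0, f32[4] %param.1, f32[4] %param.2,
                         f32[4] %param.3, f32[4] %param.4),
    epsilon=0.001, feature_index=1
})";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
  BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                             /*rewrite_inference_op=*/true,
                             /*rewrite_grad_op=*/true);
  ASSERT_TRUE(rewriter.Run(m.get()).ValueOrDie());
  // After the rewrite, no batch-norm ops should remain in the module.
  EXPECT_EQ(CountBatchNormOps(*m), 0);
}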

}  // namespace
}  // namespace xla