1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/batchnorm_expander.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "tensorflow/compiler/xla/layout_util.h"
22 #include "tensorflow/compiler/xla/literal.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/hlo_parser.h"
28 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/test.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34
35 namespace xla {
36 namespace {
37
38 class BatchNormExpanderTest : public HloTestBase {
39 protected:
40 // BatchNorm should have a dynamic sized divider for mean operations.
CountGetDimensionSize(const HloModule & module)41 int64_t CountGetDimensionSize(const HloModule& module) {
42 int64_t count = 0;
43 for (HloComputation* comp : module.computations()) {
44 for (HloInstruction* inst : comp->instructions()) {
45 if (inst->opcode() == HloOpcode::kGetDimensionSize) {
46 count++;
47 }
48 }
49 }
50 return count;
51 }
52 };
53
54 // Test that we expand BatchNormTraining.
// Verifies that a batch-norm-training instruction is rewritten away by the
// expander and replaced with an explicit tuple of primitive HLO ops.
TEST_F(BatchNormExpanderTest, BatchNormTraining) {
  Shape activation_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});
  Shape scale_shape = ShapeUtil::MakeShape(F32, {2});
  Shape offset_shape = ShapeUtil::MakeShape(F32, {2});

  // Build a module whose entry root is a single batch-norm-training op over
  // three parameters.
  HloComputation::Builder builder(TestName());
  HloInstruction* activation = builder.AddInstruction(
      HloInstruction::CreateParameter(0, activation_shape, "activation"));
  HloInstruction* scale = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scale_shape, "scale"));
  HloInstruction* offset = builder.AddInstruction(
      HloInstruction::CreateParameter(2, offset_shape, "offset"));
  builder.AddInstruction(HloInstruction::CreateBatchNormTraining(
      ShapeUtil::MakeTupleShape({activation_shape, scale_shape, offset_shape}),
      activation, scale, offset,
      /*epsilon=*/0.001, /*feature_index=*/3));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  ASSERT_EQ(computation->root_instruction()->opcode(),
            HloOpcode::kBatchNormTraining);

  BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                             /*rewrite_inference_op=*/true,
                             /*rewrite_grad_op=*/true);
  ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());

  // The mean computation must divide by dynamically-obtained dimension sizes.
  EXPECT_EQ(CountGetDimensionSize(*module), 3);
  // The batch-norm op itself must be gone, leaving a tuple as the root.
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kTuple);
}
88
89 // Test that we expand BatchNormGrad.
// Verifies that a batch-norm-grad instruction is rewritten away by the
// expander and replaced with an explicit tuple of primitive HLO ops.
TEST_F(BatchNormExpanderTest, BatchNormGrad) {
  Shape activation_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});
  Shape scale_shape = ShapeUtil::MakeShape(F32, {2});
  Shape mean_shape = ShapeUtil::MakeShape(F32, {2});
  Shape var_shape = ShapeUtil::MakeShape(F32, {2});
  Shape grad_output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});

  // Build a module whose entry root is a single batch-norm-grad op over five
  // parameters.
  HloComputation::Builder builder(TestName());
  HloInstruction* activation = builder.AddInstruction(
      HloInstruction::CreateParameter(0, activation_shape, "activation"));
  HloInstruction* scale = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scale_shape, "scale"));
  HloInstruction* mean = builder.AddInstruction(
      HloInstruction::CreateParameter(2, mean_shape, "mean"));
  HloInstruction* var = builder.AddInstruction(
      HloInstruction::CreateParameter(3, var_shape, "var"));
  HloInstruction* grad_output = builder.AddInstruction(
      HloInstruction::CreateParameter(4, grad_output_shape, "grad_output"));
  builder.AddInstruction(HloInstruction::CreateBatchNormGrad(
      ShapeUtil::MakeTupleShape({activation_shape, scale_shape, mean_shape}),
      activation, scale, mean, var, grad_output,
      /*epsilon=*/0.001, /*feature_index=*/3));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  ASSERT_EQ(computation->root_instruction()->opcode(),
            HloOpcode::kBatchNormGrad);

  BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                             /*rewrite_inference_op=*/true,
                             /*rewrite_grad_op=*/true);
  ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());

  // The gradient computation must divide by dynamically-obtained dimension
  // sizes.
  EXPECT_EQ(CountGetDimensionSize(*module), 3);
  // The batch-norm op itself must be gone, leaving a tuple as the root.
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kTuple);
}
131
// Verifies that the sharding annotation on the original batch-norm-training
// op is propagated to every instruction the expander materializes.
TEST_F(BatchNormExpanderTest, BatchNormTrainingSharding) {
  const char* module_str = R"(
HloModule module
ENTRY entry {
  %param.0 = f32[8,4] parameter(0)
  %param.1 = f32[4] parameter(1)
  %param.2 = f32[4] parameter(2)
  ROOT %batch-norm-training = (f32[8,4], f32[4], f32[4])
    batch-norm-training(f32[8,4] %param.0, f32[4] %param.1, f32[4] %param.2),
    epsilon=0.001, feature_index=1, sharding={maximal device=1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
  BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                             /*rewrite_inference_op=*/true,
                             /*rewrite_grad_op=*/true);
  ASSERT_TRUE(rewriter.Run(m.get()).ValueOrDie());

  // Parameters carry no sharding of their own; every other instruction that
  // the rewrite produced must be assigned to device 1.
  for (HloInstruction* instruction : m->entry_computation()->instructions()) {
    if (instruction->opcode() == HloOpcode::kParameter) {
      continue;
    }
    auto device = instruction->sharding_unique_device();
    ASSERT_TRUE(device);
    EXPECT_EQ(*device, 1);
  }
}
159
// End-to-end check: the expanded batch-norm-training module must produce
// numerically equivalent results to the reference interpreter.
TEST_F(BatchNormExpanderTest, Execution) {
  const char* module_str = R"(
HloModule module
ENTRY entry {
  %param.0 = f32[8,4] parameter(0)
  %param.1 = f32[4] parameter(1)
  %param.2 = f32[4] parameter(2)
  ROOT %batch-norm-training = (f32[8,4], f32[4], f32[4])
    batch-norm-training(f32[8,4] %param.0, f32[4] %param.1, f32[4] %param.2),
    epsilon=0.001, feature_index=1, sharding={maximal device=1}
})";
  // Tolerances are loose enough for the f32 mean/variance reassociation the
  // expansion introduces.
  EXPECT_TRUE(RunAndCompare(module_str, ErrorSpec{1e-4, 1e-4}));
}
173
174 } // namespace
175 } // namespace xla
176