/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/batchnorm_expander.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

namespace {

using absl::optional;

// BatchNormExpanderVisitor traverses the HLO computation and rewrites
// BatchNorm operations into smaller operations.
class BatchNormExpanderVisitor : public DfsHloRewriteVisitor {
 public:
  Status HandleBatchNormTraining(HloInstruction* batch_norm) override;

  Status HandleBatchNormInference(HloInstruction* batch_norm) override;

  Status HandleBatchNormGrad(HloInstruction* batch_norm) override;

  // Runs the visitor on a computation.
  static bool Run(HloComputation* computation, bool rewrite_training_op,
                  bool rewrite_inference_op, bool rewrite_grad_op);

  ~BatchNormExpanderVisitor() override = default;

 private:
  explicit BatchNormExpanderVisitor(HloComputation* computation,
                                    bool rewrite_training_op,
                                    bool rewrite_inference_op,
                                    bool rewrite_grad_op)
      : computation_(computation),
        rewrite_training_op_(rewrite_training_op),
        rewrite_inference_op_(rewrite_inference_op),
        rewrite_grad_op_(rewrite_grad_op) {}

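  // Builds a scalar computation (lhs, rhs) -> lhs + rhs of the given element
  // type and embeds it in the module, for use as the reduction function of
  // the kReduce instructions created below. Note that despite the name, a new
  // embedded computation is created on every call.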
  HloComputation* GetOrCreateScalarAddComputation(
      PrimitiveType primitive_type) {
    HloComputation::Builder b("scalar_add_computation");
    Shape shape = ShapeUtil::MakeShape(primitive_type, {});
    auto scalar_lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
    auto scalar_rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
    auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
    return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
  }

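  // Returns 1 / sqrt(operand), computed elementwise over operand's shape. The
  // add_instruction callback is unused here but kept so that all the helpers
  // below share the same signature.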
  std::unique_ptr<HloInstruction> Rsqrt(
      HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          add_instruction) {
    return HloInstruction::CreateUnary(operand->shape(), HloOpcode::kRsqrt,
                                       operand);
  }

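  // Divides operand elementwise by the scalar element_count, broadcast to
  // operand's shape; used to turn the per-feature sums below into means.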
  std::unique_ptr<HloInstruction> Mean(
      HloInstruction* element_count, HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          add_instruction) {
    auto broadcast = add_instruction(
        HloInstruction::CreateBroadcast(operand->shape(), element_count, {}));
    return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kDivide,
                                        operand, broadcast);
  }

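  // Computes the per-feature element count N: the product of all dimension
  // sizes of operand except the feature dimension. GetDimensionSize is used
  // so that dynamic dimensions are handled; the S32 product is converted to
  // operand's element type at the end.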
  std::unique_ptr<HloInstruction> DynamicElementCountPerFeature(
      HloInstruction* operand, int64 feature_index,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          add_instruction) {
    auto elements_per_feature_s32 = add_instruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));

    for (int64 i = 0; i < operand->shape().rank(); ++i) {
      if (i == feature_index) {
        continue;
      }
      auto dynamic_dimension_size =
          add_instruction(HloInstruction::CreateGetDimensionSize(
              ShapeUtil::MakeShape(S32, {}), operand, i));
      elements_per_feature_s32 = add_instruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply,
          dynamic_dimension_size, elements_per_feature_s32));
    }

    return HloInstruction::CreateConvert(
        ShapeUtil::MakeShape(operand->shape().element_type(), {}),
        elements_per_feature_s32);
  }

  // Current HloComputation instance the BatchNormExpander is
  // traversing.
  HloComputation* computation_;

  bool rewrite_training_op_;
  bool rewrite_inference_op_;
  bool rewrite_grad_op_;
};

}  // namespace

bool BatchNormExpanderVisitor::Run(HloComputation* computation,
                                   bool rewrite_training_op,
                                   bool rewrite_inference_op,
                                   bool rewrite_grad_op) {
  BatchNormExpanderVisitor visitor(
      computation,
      /*rewrite_training_op=*/rewrite_training_op,
      /*rewrite_inference_op=*/rewrite_inference_op,
      /*rewrite_grad_op=*/rewrite_grad_op);
  TF_CHECK_OK(computation->Accept(&visitor));
  return visitor.changed();
}

Status BatchNormExpanderVisitor::HandleBatchNormTraining(
    HloInstruction* batch_norm) {
  if (!rewrite_training_op_) {
    return Status::OK();
  }

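  // Helper lambdas: every instruction added via `add` inherits the original
  // batch norm's metadata and is recorded so that sharding can be assigned to
  // it at the end.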
  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_inst->set_metadata(batch_norm->metadata());
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
                        HloInstruction* a, HloInstruction* b) {
    return add(HloInstruction::CreateBinary(shape, opcode, a, b));
  };
  int64 instruction_count_before = computation_->instruction_count();

  // Expand batch norm training into smaller HLO ops.
  HloInstruction* operand = batch_norm->mutable_operand(0);
  const Shape operand_shape = operand->shape();
  PrimitiveType ptype = operand_shape.element_type();
  int64 feature_index = batch_norm->feature_index();

  HloInstruction* scale = batch_norm->mutable_operand(1);
  HloInstruction* offset = batch_norm->mutable_operand(2);
  const Shape feature_shape = scale->shape();

  auto zero_literal = LiteralUtil::CreateR0(0.0f);
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
  auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));

  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
  auto epsilon = add(HloInstruction::CreateBroadcast(
      operand_shape,
      add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < operand_shape.rank(); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  auto elements_per_feature =
      add(DynamicElementCountPerFeature(operand, feature_index, add));

  auto scale_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));

  auto offset_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));

  HloComputation* add_reduce_computation =
      GetOrCreateScalarAddComputation(ptype);

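  // Compute the per-feature batch statistics from two reductions over the
  // non-feature dimensions: E[X] = Sum[X] / N and Var[X] = E[X^2] - E^2[X].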
  // X^2.
  auto operand_squared =
      add_binary(operand_shape, HloOpcode::kMultiply, operand, operand);
  // Sum[X].
  auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero,
                                              dimensions_without_feature,
                                              add_reduce_computation));

  // Sum[X^2].
  auto squared_sum = add(HloInstruction::CreateReduce(
      feature_shape, operand_squared, zero, dimensions_without_feature,
      add_reduce_computation));

  // E[X].
  auto mean = add(Mean(elements_per_feature, sum, add));

  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));

  // E[X^2].
  auto square_mean = add(Mean(elements_per_feature, squared_sum, add));

  // E^2[X].
  auto mean_square =
      add_binary(feature_shape, HloOpcode::kMultiply, mean, mean);

  // Var[X].
  auto var = add_binary(feature_shape, HloOpcode::kSubtract, square_mean,
                        mean_square);

  auto var_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));

  // Var[X] + epsilon.
  auto var_add_epsilon =
      add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon);

  // 1 / Sqrt[Var[X] + epsilon].
  auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add));

  // X - E[X].
  auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract,
                                       operand, mean_broadcasted);

  // (X - E[X]) / Sqrt[Var[X] + epsilon].
  auto normalized = add_binary(operand_shape, HloOpcode::kMultiply,
                               operand_minus_mean, rsqrt_var_add_epsilon);

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
  auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply,
                                      normalized, scale_broadcasted);

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
  auto shifted_normalized = add_binary(operand_shape, HloOpcode::kAdd,
                                       scaled_normalized, offset_broadcasted);

  auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var});

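  // If the batch norm op carried sharding, propagate it to the expansion:
  // operand-shaped instructions take the sharding of the op's first tuple
  // element, and everything else (scalars and per-feature values) is assigned
  // to the op's unique device if it has one, or replicated otherwise.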
  if (batch_norm->has_sharding()) {
    int64 instruction_count_after = computation_->instruction_count();
    CHECK_EQ(instruction_count_after,
             instruction_count_before + added_instructions.size());
    const HloSharding& sharding = batch_norm->sharding();
    HloSharding operand_sharding =
        sharding.GetAsShapeTree(batch_norm->shape()).element({0});
    optional<int64> unique_device = batch_norm->sharding_unique_device();
    HloSharding default_sharding =
        unique_device.has_value()
            ? HloSharding::AssignDevice(unique_device.value())
            : HloSharding::Replicate();
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
        inst->set_sharding(operand_sharding);
      } else {
        inst->set_sharding(default_sharding);
      }
    }
    tuple->set_sharding(sharding);
  }
  TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
  return Status::OK();
}

Status BatchNormExpanderVisitor::HandleBatchNormInference(
    HloInstruction* batch_norm) {
  if (!rewrite_inference_op_) {
    return Status::OK();
  }
  // Expand batch norm inference into smaller HLO ops.
  HloInstruction* operand = batch_norm->mutable_operand(0);
  const Shape operand_shape = operand->shape();
  int64 feature_index = batch_norm->feature_index();
  PrimitiveType ptype = operand_shape.element_type();

  HloInstruction* scale = batch_norm->mutable_operand(1);
  HloInstruction* offset = batch_norm->mutable_operand(2);
  HloInstruction* mean = batch_norm->mutable_operand(3);
  HloInstruction* var = batch_norm->mutable_operand(4);
  const Shape feature_shape = scale->shape();

  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
  auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
      feature_shape,
      computation_->AddInstruction(
          HloInstruction::CreateConstant(std::move(epsilon_literal))),
      {}));

  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < operand_shape.rank(); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_inst->set_metadata(batch_norm->metadata());
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
                        HloInstruction* a, HloInstruction* b) {
    return add(HloInstruction::CreateBinary(shape, opcode, a, b));
  };
  auto feature_broadcast = [&](HloInstruction* a) {
    return add(
        HloInstruction::CreateBroadcast(operand_shape, a, {feature_index}));
  };

  int64 instruction_count_before = computation_->instruction_count();
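  // Fold scale, mean, variance and epsilon into a single per-feature affine
  // transform:
  //   Y = X * (scale * rsqrt(var + epsilon))
  //         + (offset - mean * scale * rsqrt(var + epsilon)).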
  auto true_scale = add_binary(
      feature_shape, HloOpcode::kMultiply, scale,
      add(Rsqrt(add_binary(feature_shape, HloOpcode::kAdd, var, epsilon),
                add)));
  auto true_shift = add_binary(
      feature_shape, HloOpcode::kSubtract, offset,
      add_binary(feature_shape, HloOpcode::kMultiply, mean, true_scale));

  auto shifted_normalized =
      add_binary(operand_shape, HloOpcode::kAdd,
                 add_binary(operand_shape, HloOpcode::kMultiply, operand,
                            feature_broadcast(true_scale)),
                 feature_broadcast(true_shift));

  int64 instruction_count_after = computation_->instruction_count();
  CHECK_EQ(instruction_count_after,
           instruction_count_before + added_instructions.size());
  if (batch_norm->has_sharding()) {
    const HloSharding& sharding = batch_norm->sharding();
    optional<int64> unique_device = batch_norm->sharding_unique_device();
    HloSharding default_sharding =
        unique_device.has_value()
            ? HloSharding::AssignDevice(unique_device.value())
            : HloSharding::Replicate();
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
        inst->set_sharding(sharding);
      } else {
        inst->set_sharding(default_sharding);
      }
    }
    shifted_normalized->set_sharding(sharding);
  }
  TF_CHECK_OK(ReplaceInstruction(batch_norm, shifted_normalized));
  return Status::OK();
}

Status BatchNormExpanderVisitor::HandleBatchNormGrad(
    HloInstruction* batch_norm) {
  // Use the following formulas to calculate gradients:
  // scale_grad =
  //   sum(output_grad * (activation - mean(activation))) * rsqrt(var + epsilon)
  //
  // offset_grad =
  //   sum(output_grad)
  //
  // activation_grad =
  //   1/N * scale * rsqrt(var + epsilon) *
  //   (N * output_grad - sum(output_grad) - (activation - mean(activation)) *
  //   sum(output_grad * (activation - mean(activation))) / (variance +
  //   epsilon))
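  //
  // Here N is the per-feature element count (the product of the non-feature
  // dimensions, computed below by DynamicElementCountPerFeature), and mean
  // and var are the batch statistics passed in as operands 2 and 3.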
  if (!rewrite_grad_op_) {
    return Status::OK();
  }
  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_inst->set_metadata(batch_norm->metadata());
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
                        HloInstruction* a, HloInstruction* b) {
    return add(HloInstruction::CreateBinary(shape, opcode, a, b));
  };
  int64 instruction_count_before = computation_->instruction_count();

  HloInstruction* activation = batch_norm->mutable_operand(0);
  const Shape activation_shape = activation->shape();
  PrimitiveType ptype = activation_shape.element_type();
  HloInstruction* scale = batch_norm->mutable_operand(1);
  const Shape feature_shape = scale->shape();
  HloInstruction* mean = batch_norm->mutable_operand(2);
  HloInstruction* variance = batch_norm->mutable_operand(3);
  HloInstruction* grad_output = batch_norm->mutable_operand(4);

  int64 feature_index = batch_norm->feature_index();

  auto elements_per_feature =
      add(DynamicElementCountPerFeature(activation, feature_index, add));

  auto zero_literal = LiteralUtil::CreateR0(0.0f);
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
  auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));

  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
  auto epsilon_scalar =
      add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
  auto epsilon_activation = add(
      HloInstruction::CreateBroadcast(activation_shape, epsilon_scalar, {}));
  auto epsilon_feature =
      add(HloInstruction::CreateBroadcast(feature_shape, epsilon_scalar, {}));

  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < activation_shape.rank(); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  auto scale_broadcasted = add(HloInstruction::CreateBroadcast(
      activation_shape, scale, {feature_index}));
  auto variance_broadcasted = add(HloInstruction::CreateBroadcast(
      activation_shape, variance, {feature_index}));

  // E[X].
  auto mean_broadcasted = add(HloInstruction::CreateBroadcast(
      activation_shape, mean, {feature_index}));

  // rsqrt[Var[X] + epsilon].
  auto rsqrt_var_add_epsilon_broadcasted =
      add(Rsqrt(add_binary(activation_shape, HloOpcode::kAdd,
                           variance_broadcasted, epsilon_activation),
                add));

  auto rsqrt_var_add_epsilon = add(Rsqrt(
      add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature),
      add));

  // X - E[X].
  auto activation_minus_mean = add_binary(
      activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted);

  // Grad[Y] * (X - E[X]).
  auto grad_output_times_activation_minus_mean =
      add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
                 activation_minus_mean);

  HloComputation* add_reduce_computation =
      GetOrCreateScalarAddComputation(ptype);

  // sum(Grad[Y] * (X - E[X])).
  auto sum_grad_output_times_activation_minus_mean =
      add(HloInstruction::CreateReduce(
          feature_shape, grad_output_times_activation_minus_mean, zero,
          dimensions_without_feature, add_reduce_computation));

  // Grad[beta] = Sum(Grad[Y]).
  auto grad_beta = add(HloInstruction::CreateReduce(
      feature_shape, grad_output, zero, dimensions_without_feature,
      add_reduce_computation));

  // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]).
  auto grad_scale = add_binary(feature_shape, HloOpcode::kMultiply,
                               sum_grad_output_times_activation_minus_mean,
                               rsqrt_var_add_epsilon);

  // I2 = Sum(Grad[Y]).
  auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta,
                                                {feature_index}));

  // I3 = Sum(Grad[Y] * (X - E[X])).
  auto i3 = add(HloInstruction::CreateBroadcast(
      activation_shape, sum_grad_output_times_activation_minus_mean,
      {feature_index}));

  // I4 = (X - E[X]) * I3.
  auto i4 = add_binary(activation_shape, HloOpcode::kMultiply, i3,
                       activation_minus_mean);

  // I5 = I4 / (Var[X] + epsilon).
  auto i5 = add_binary(activation_shape, HloOpcode::kDivide, i4,
                       add_binary(activation_shape, HloOpcode::kAdd,
                                  variance_broadcasted, epsilon_activation));

  // scale * rsqrt[Var[X] + epsilon] * 1/N.
  auto scale_times_rsqrt_var_add_epsilon =
      add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted,
                 rsqrt_var_add_epsilon_broadcasted);

  scale_times_rsqrt_var_add_epsilon =
      add(Mean(elements_per_feature, scale_times_rsqrt_var_add_epsilon, add));

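  // I1 = N * Grad[Y].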
  auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
                       add(HloInstruction::CreateBroadcast(
                           activation_shape, elements_per_feature, {})));

  // I6 = I1 - I2 - I5.
  auto i6 = add_binary(
      activation_shape, HloOpcode::kSubtract,
      add_binary(activation_shape, HloOpcode::kSubtract, i1, i2), i5);

  // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6.
  auto grad_activation = add_binary(activation_shape, HloOpcode::kMultiply,
                                    scale_times_rsqrt_var_add_epsilon, i6);
  auto tuple =
      HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta});
  if (batch_norm->has_sharding()) {
    const HloSharding& sharding = batch_norm->sharding();
    int64 instruction_count_after = computation_->instruction_count();
    CHECK_EQ(instruction_count_after,
             instruction_count_before + added_instructions.size());
    HloSharding activation_sharding =
        sharding.GetAsShapeTree(batch_norm->shape()).element({0});
    auto unique_device = batch_norm->sharding_unique_device();
    HloSharding default_sharding =
        unique_device.has_value()
            ? HloSharding::AssignDevice(unique_device.value())
            : HloSharding::Replicate();
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), activation_shape)) {
        inst->set_sharding(activation_sharding);
      } else {
        inst->set_sharding(default_sharding);
      }
    }
    tuple->set_sharding(sharding);
  }

  TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));

  return Status::OK();
}

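// Expands batch norm ops in every non-fusion computation of the module,
// returning true if any computation changed.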
StatusOr<bool> BatchNormExpander::Run(HloModule* module) {
  XLA_VLOG_LINES(2, "BatchNormExpander::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_,
                                      rewrite_inference_op_,
                                      rewrite_grad_op_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(2, "BatchNormExpander::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla