/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/batchnorm_expander.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

namespace {

using absl::optional;

// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
// operations into smaller operations.
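// For example, a batch-norm-training op is conceptually rewritten into the
// following per-feature math (a sketch of the expansion, not literal HLO):
//   E[X]   = Sum[X] / N
//   Var[X] = Sum[X^2] / N - E[X]^2
//   Y      = (X - E[X]) * rsqrt(Var[X] + epsilon) * scale + offset
// where N is the number of elements contributing to each feature.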
class BatchNormExpanderVisitor : public DfsHloRewriteVisitor {
 public:
  Status HandleBatchNormTraining(HloInstruction* batch_norm) override;

  Status HandleBatchNormInference(HloInstruction* batch_norm) override;

  Status HandleBatchNormGrad(HloInstruction* batch_norm) override;

  // Runs the visitor on a computation.
  static bool Run(HloComputation* computation, bool rewrite_training_op,
                  bool rewrite_inference_op, bool rewrite_grad_op);

  ~BatchNormExpanderVisitor() override = default;

 private:
  explicit BatchNormExpanderVisitor(HloComputation* computation,
                                    bool rewrite_training_op,
                                    bool rewrite_inference_op,
                                    bool rewrite_grad_op)
      : computation_(computation),
        rewrite_training_op_(rewrite_training_op),
        rewrite_inference_op_(rewrite_inference_op),
        rewrite_grad_op_(rewrite_grad_op) {}

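  // Builds a scalar { lhs, rhs -> lhs + rhs } computation of the given
  // element type and registers it on the module. It is used as the reducer
  // for the kReduce instructions created by the handlers below.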
  HloComputation* GetOrCreateScalarAddComputation(
      PrimitiveType primitive_type) {
    HloComputation::Builder b("scalar_add_computation");
    Shape shape = ShapeUtil::MakeShape(primitive_type, {});
    auto scalar_lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
    auto scalar_rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
    auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
    return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
  }

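  // Returns an HLO that computes 1 / sqrt(operand) elementwise. The
  // add_instruction callback is accepted for symmetry with the other helpers
  // but is not needed here.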
  std::unique_ptr<HloInstruction> Rsqrt(
      HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          add_instruction) {
    return HloInstruction::CreateUnary(operand->shape(), HloOpcode::kRsqrt,
                                       operand);
  }

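  // Returns an HLO computing operand / element_count, with element_count
  // broadcast to operand's shape. Used to turn per-feature sums into means.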
  std::unique_ptr<HloInstruction> Mean(
      HloInstruction* element_count, HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          add_instruction) {
    auto broadcast = add_instruction(
        HloInstruction::CreateBroadcast(operand->shape(), element_count, {}));
    return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kDivide,
                                        operand, broadcast);
  }

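  // Builds HLO computing the number of elements that contribute to each
  // feature: the product of all dimension sizes of `operand` except the
  // feature dimension. GetDimensionSize is used so dynamic dimensions are
  // handled; the product is accumulated in S32 and converted to the operand's
  // element type. For example (illustrative shapes only), an f32[8,16,16,64]
  // operand with feature_index 3 yields a scalar f32 holding 8*16*16 = 2048.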
  std::unique_ptr<HloInstruction> DynamicElementCountPerFeature(
      HloInstruction* operand, int64 feature_index,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          add_instruction) {
    auto elements_per_feature_s32 = add_instruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));

    for (int64 i = 0; i < operand->shape().rank(); ++i) {
      if (i == feature_index) {
        continue;
      }
      auto dynamic_dimension_size =
          add_instruction(HloInstruction::CreateGetDimensionSize(
              ShapeUtil::MakeShape(S32, {}), operand, i));
      elements_per_feature_s32 = add_instruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply,
          dynamic_dimension_size, elements_per_feature_s32));
    }

    return HloInstruction::CreateConvert(
        ShapeUtil::MakeShape(operand->shape().element_type(), {}),
        elements_per_feature_s32);
  }

  // Current HloComputation instance the BatchNormExpander is
  // traversing.
  HloComputation* computation_;

  bool rewrite_training_op_;
  bool rewrite_inference_op_;
  bool rewrite_grad_op_;
};

}  // namespace

bool BatchNormExpanderVisitor::Run(HloComputation* computation,
                                   bool rewrite_training_op,
                                   bool rewrite_inference_op,
                                   bool rewrite_grad_op) {
  BatchNormExpanderVisitor visitor(
      computation,
      /*rewrite_training_op=*/rewrite_training_op,
      /*rewrite_inference_op=*/rewrite_inference_op,
      /*rewrite_grad_op=*/rewrite_grad_op);
  TF_CHECK_OK(computation->Accept(&visitor));
  return visitor.changed();
}

Status BatchNormExpanderVisitor::HandleBatchNormTraining(
    HloInstruction* batch_norm) {
  if (!rewrite_training_op_) {
    return Status::OK();
  }

  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_inst->set_metadata(batch_norm->metadata());
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
                        HloInstruction* a, HloInstruction* b) {
    return add(HloInstruction::CreateBinary(shape, opcode, a, b));
  };
  int64 instruction_count_before = computation_->instruction_count();

  // Expand batch norm training into smaller HLO ops.
  HloInstruction* operand = batch_norm->mutable_operand(0);
  const Shape operand_shape = operand->shape();
  PrimitiveType ptype = operand_shape.element_type();
  int64 feature_index = batch_norm->feature_index();

  HloInstruction* scale = batch_norm->mutable_operand(1);
  HloInstruction* offset = batch_norm->mutable_operand(2);
  const Shape feature_shape = scale->shape();

  auto zero_literal = LiteralUtil::CreateR0(0.0f);
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
  auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));

  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
  auto epsilon = add(HloInstruction::CreateBroadcast(
      operand_shape,
      add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < operand_shape.rank(); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  auto elements_per_feature =
      add(DynamicElementCountPerFeature(operand, feature_index, add));

  auto scale_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));

  auto offset_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));

  HloComputation* add_reduce_computation =
      GetOrCreateScalarAddComputation(ptype);

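  // Compute the per-feature mean and variance in one pass over the data:
  //   E[X]   = Sum[X] / N
  //   Var[X] = Sum[X^2] / N - E[X]^2
  // Both reductions use the scalar add computation created above as the
  // reducer.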
  // X^2.
  auto operand_squared =
      add_binary(operand_shape, HloOpcode::kMultiply, operand, operand);
  // Sum[X].
  auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero,
                                              dimensions_without_feature,
                                              add_reduce_computation));

  // Sum[X^2].
  auto squared_sum = add(HloInstruction::CreateReduce(
      feature_shape, operand_squared, zero, dimensions_without_feature,
      add_reduce_computation));

  // E[X].
  auto mean = add(Mean(elements_per_feature, sum, add));

  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));

  // E[X^2].
  auto square_mean = add(Mean(elements_per_feature, squared_sum, add));

  // E^2[X].
  auto mean_square =
      add_binary(feature_shape, HloOpcode::kMultiply, mean, mean);

  // Var[X].
  auto var =
      add_binary(feature_shape, HloOpcode::kSubtract, square_mean, mean_square);

  auto var_broadcasted =
      add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));

  // Var[X] + epsilon.
  auto var_add_epsilon =
      add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon);

  // 1 / Sqrt[Var[X] + epsilon].
  auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add));

  // X - E[X].
  auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract,
                                       operand, mean_broadcasted);

  // (X - E[X]) / Sqrt[Var[X] + epsilon].
  auto normalized = add_binary(operand_shape, HloOpcode::kMultiply,
                               operand_minus_mean, rsqrt_var_add_epsilon);

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
  auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply,
                                      normalized, scale_broadcasted);

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
  auto shifted_normalized = add_binary(operand_shape, HloOpcode::kAdd,
                                       scaled_normalized, offset_broadcasted);

  auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var});

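  // If the original batch norm carried sharding, propagate it: instructions
  // with the operand's shape inherit the sharding of tuple element {0} of the
  // original output, and everything else (per-feature and scalar values) gets
  // a device-assigned or replicated default.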
  if (batch_norm->has_sharding()) {
    int64 instruction_count_after = computation_->instruction_count();
    CHECK_EQ(instruction_count_after,
             instruction_count_before + added_instructions.size());
    const HloSharding& sharding = batch_norm->sharding();
    HloSharding operand_sharding =
        sharding.GetAsShapeTree(batch_norm->shape()).element({0});
    optional<int64> unique_device = batch_norm->sharding_unique_device();
    HloSharding default_sharding =
        unique_device.has_value()
            ? HloSharding::AssignDevice(unique_device.value())
            : HloSharding::Replicate();
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
        inst->set_sharding(operand_sharding);
      } else {
        inst->set_sharding(default_sharding);
      }
    }
    tuple->set_sharding(sharding);
  }
  TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
  return Status::OK();
}

Status BatchNormExpanderVisitor::HandleBatchNormInference(
    HloInstruction* batch_norm) {
  if (!rewrite_inference_op_) {
    return Status::OK();
  }
  // Expand batch norm inference into smaller HLO ops.
  HloInstruction* operand = batch_norm->mutable_operand(0);
  const Shape operand_shape = operand->shape();
  int64 feature_index = batch_norm->feature_index();
  PrimitiveType ptype = operand_shape.element_type();

  HloInstruction* scale = batch_norm->mutable_operand(1);
  HloInstruction* offset = batch_norm->mutable_operand(2);
  HloInstruction* mean = batch_norm->mutable_operand(3);
  HloInstruction* var = batch_norm->mutable_operand(4);
  const Shape feature_shape = scale->shape();

  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
  auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
      feature_shape,
      computation_->AddInstruction(
          HloInstruction::CreateConstant(std::move(epsilon_literal))),
      {}));

  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < operand_shape.rank(); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_inst->set_metadata(batch_norm->metadata());
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
                        HloInstruction* a, HloInstruction* b) {
    return add(HloInstruction::CreateBinary(shape, opcode, a, b));
  };
  auto feature_broadcast = [&](HloInstruction* a) {
    return add(
        HloInstruction::CreateBroadcast(operand_shape, a, {feature_index}));
  };

  int64 instruction_count_before = computation_->instruction_count();
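  // Fold the per-feature constants so the expansion applies a single affine
  // transform to the input:
  //   true_scale = scale * rsqrt(var + epsilon)
  //   true_shift = offset - mean * true_scale
  //   output     = operand * true_scale + true_shift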
  auto true_scale = add_binary(
      feature_shape, HloOpcode::kMultiply, scale,
      add(Rsqrt(add_binary(feature_shape, HloOpcode::kAdd, var, epsilon),
                add)));
  auto true_shift = add_binary(
      feature_shape, HloOpcode::kSubtract, offset,
      add_binary(feature_shape, HloOpcode::kMultiply, mean, true_scale));

  auto shifted_normalized =
      add_binary(operand_shape, HloOpcode::kAdd,
                 add_binary(operand_shape, HloOpcode::kMultiply, operand,
                            feature_broadcast(true_scale)),
                 feature_broadcast(true_shift));

  int64 instruction_count_after = computation_->instruction_count();
  CHECK_EQ(instruction_count_after,
           instruction_count_before + added_instructions.size());
  if (batch_norm->has_sharding()) {
    const HloSharding& sharding = batch_norm->sharding();
    optional<int64> unique_device = batch_norm->sharding_unique_device();
    HloSharding default_sharding =
        unique_device.has_value()
            ? HloSharding::AssignDevice(unique_device.value())
            : HloSharding::Replicate();
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
        inst->set_sharding(sharding);
      } else {
        inst->set_sharding(default_sharding);
      }
    }
    shifted_normalized->set_sharding(sharding);
  }
  TF_CHECK_OK(ReplaceInstruction(batch_norm, shifted_normalized));
  return Status::OK();
}

Status BatchNormExpanderVisitor::HandleBatchNormGrad(
    HloInstruction* batch_norm) {
  // Use the following formulas to calculate gradients:
  // scale_grad =
  //   sum(output_grad * (activation - mean(activation))) * rsqrt(var + epsilon)
  //
  // offset_grad =
  //   sum(output_grad)
  //
  // activation_grad =
  //   1/N * scale * rsqrt(var + epsilon) *
  //   (N * output_grad - sum(output_grad) - (activation - mean(activation)) *
  //   sum(output_grad * (activation - mean(activation))) / (variance +
  //   epsilon))
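  //
  // The code below builds these expressions out of intermediate terms named
  // I1..I6:
  //   I1 = N * output_grad
  //   I2 = sum(output_grad)                        (broadcast per feature)
  //   I3 = sum(output_grad * (activation - mean))  (broadcast per feature)
  //   I4 = (activation - mean) * I3
  //   I5 = I4 / (variance + epsilon)
  //   I6 = I1 - I2 - I5
  //   activation_grad = 1/N * scale * rsqrt(var + epsilon) * I6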
  if (!rewrite_grad_op_) {
    return Status::OK();
  }
  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_inst->set_metadata(batch_norm->metadata());
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
                        HloInstruction* a, HloInstruction* b) {
    return add(HloInstruction::CreateBinary(shape, opcode, a, b));
  };
  int64 instruction_count_before = computation_->instruction_count();

  HloInstruction* activation = batch_norm->mutable_operand(0);
  const Shape activation_shape = activation->shape();
  PrimitiveType ptype = activation_shape.element_type();
  HloInstruction* scale = batch_norm->mutable_operand(1);
  const Shape feature_shape = scale->shape();
  HloInstruction* mean = batch_norm->mutable_operand(2);
  HloInstruction* variance = batch_norm->mutable_operand(3);
  HloInstruction* grad_output = batch_norm->mutable_operand(4);

  int64 feature_index = batch_norm->feature_index();

  auto elements_per_feature =
      add(DynamicElementCountPerFeature(activation, feature_index, add));

  auto zero_literal = LiteralUtil::CreateR0(0.0f);
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
  auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));

  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
  auto epsilon_scalar =
      add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
  auto epsilon_activation = add(
      HloInstruction::CreateBroadcast(activation_shape, epsilon_scalar, {}));
  auto epsilon_feature =
      add(HloInstruction::CreateBroadcast(feature_shape, epsilon_scalar, {}));

  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < activation_shape.rank(); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  auto scale_broadcasted = add(HloInstruction::CreateBroadcast(
      activation_shape, scale, {feature_index}));
  auto variance_broadcasted = add(HloInstruction::CreateBroadcast(
      activation_shape, variance, {feature_index}));

  // E[X].
  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index}));

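  // rsqrt(Var[X] + epsilon) is needed in two shapes: broadcast to the
  // activation shape for the activation gradient, and in the feature shape
  // for Grad[scale].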
  // rsqrt[Var[X] + epsilon].
  auto rsqrt_var_add_epsilon_broadcasted =
      add(Rsqrt(add_binary(activation_shape, HloOpcode::kAdd,
                           variance_broadcasted, epsilon_activation),
                add));

  auto rsqrt_var_add_epsilon = add(Rsqrt(
      add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature),
      add));

  // X - E[X].
  auto activation_minus_mean = add_binary(
      activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted);

  // Grad[Y] * (X - E[X]).
  auto grad_output_times_activation_minus_mean =
      add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
                 activation_minus_mean);

  HloComputation* add_reduce_computation =
      GetOrCreateScalarAddComputation(ptype);

  // sum(Grad[Y] * (X - E[X])).
  auto sum_grad_output_times_activation_minus_mean =
      add(HloInstruction::CreateReduce(
          feature_shape, grad_output_times_activation_minus_mean, zero,
          dimensions_without_feature, add_reduce_computation));

  // Grad[beta] = Sum(Grad[Y]).
  auto grad_beta = add(HloInstruction::CreateReduce(
      feature_shape, grad_output, zero, dimensions_without_feature,
      add_reduce_computation));

  // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]).
  auto grad_scale = add_binary(feature_shape, HloOpcode::kMultiply,
                               sum_grad_output_times_activation_minus_mean,
                               rsqrt_var_add_epsilon);

  // I2 = Sum(Grad[Y])
  auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta,
                                                {feature_index}));

  // I3 = Sum(Grad[Y] * (X - E[X]))
  auto i3 = add(HloInstruction::CreateBroadcast(
      activation_shape, sum_grad_output_times_activation_minus_mean,
      {feature_index}));

  // I4 = (X - E[X]) * I3
  auto i4 = add_binary(activation_shape, HloOpcode::kMultiply, i3,
                       activation_minus_mean);

  // I5 = I4 / (Var[X] + epsilon)
  auto i5 = add_binary(activation_shape, HloOpcode::kDivide, i4,
                       add_binary(activation_shape, HloOpcode::kAdd,
                                  variance_broadcasted, epsilon_activation));

  // scale * rsqrt[Var[X] + epsilon] * 1/N
  auto scale_times_rsqrt_var_add_epsilon =
      add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted,
                 rsqrt_var_add_epsilon_broadcasted);

  scale_times_rsqrt_var_add_epsilon =
      add(Mean(elements_per_feature, scale_times_rsqrt_var_add_epsilon, add));

  auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
                       add(HloInstruction::CreateBroadcast(
                           activation_shape, elements_per_feature, {})));

  // I6 = I1 - I2 - I5
  auto i6 = add_binary(
      activation_shape, HloOpcode::kSubtract,
      add_binary(activation_shape, HloOpcode::kSubtract, i1, i2), i5);

  // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6.
  auto grad_activation = add_binary(activation_shape, HloOpcode::kMultiply,
                                    scale_times_rsqrt_var_add_epsilon, i6);
  auto tuple =
      HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta});
  if (batch_norm->has_sharding()) {
    const HloSharding& sharding = batch_norm->sharding();
    int64 instruction_count_after = computation_->instruction_count();
    CHECK_EQ(instruction_count_after,
             instruction_count_before + added_instructions.size());
    HloSharding activation_sharding =
        sharding.GetAsShapeTree(batch_norm->shape()).element({0});
    auto unique_device = batch_norm->sharding_unique_device();
    HloSharding default_sharding =
        unique_device.has_value()
            ? HloSharding::AssignDevice(unique_device.value())
            : HloSharding::Replicate();
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), activation_shape)) {
        inst->set_sharding(activation_sharding);
      } else {
        inst->set_sharding(default_sharding);
      }
    }
    tuple->set_sharding(sharding);
  }

  TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));

  return Status::OK();
}

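// A minimal usage sketch (assuming the BatchNormExpander constructor declared
// in batchnorm_expander.h takes the three rewrite flags in this order):
//
//   HloPassPipeline pipeline("expand-batch-norm");
//   pipeline.AddPass<BatchNormExpander>(
//       /*rewrite_training_op=*/true,
//       /*rewrite_inference_op=*/true,
//       /*rewrite_grad_op=*/true);
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));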
StatusOr<bool> BatchNormExpander::Run(HloModule* module) {
  XLA_VLOG_LINES(2, "BatchNormExpander::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_,
                                      rewrite_inference_op_,
                                      rewrite_grad_op_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(2, "BatchNormExpander::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla