/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"

#include <algorithm>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/env_var.h"

namespace xla {
namespace gpu {

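// Illustrative sketch of the transformation implemented below (names such as
// fusion.0 and fusion.1 are hypothetical). Given small sibling kLoop fusions
// that feed the same consumer and produce outputs of the same element type T:
//
//   fusion.0 = T[N0] fusion(p0, p1), kind=kLoop
//   fusion.1 = T[N1] fusion(p2, p3), kind=kLoop
//
// the pass builds one horizontally fused computation that clones both fused
// computations, reshapes each output to 1-D, concatenates the reshapes into a
// single T[N0 + N1] array, and slices that array back apart. The resulting
// kInput fusion returns a tuple of the slices, and each original fusion is
// replaced by a bitcast of a get-tuple-element of the new fusion.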
namespace {

PrimitiveType GetUniqueOutputTypeOfFusible(const HloInstruction& fusible) {
  auto outputs = GetOutputsOfFusible(fusible);
  CHECK(!outputs.empty());
  PrimitiveType first_output_type = outputs[0]->shape().element_type();
  for (size_t i = 1; i < outputs.size(); ++i) {
    PrimitiveType cur_output_type = outputs[i]->shape().element_type();
    CHECK(first_output_type == cur_output_type)
        << "Output types are expected to be unique, but see "
        << PrimitiveType_Name(first_output_type) << " and "
        << PrimitiveType_Name(cur_output_type);
  }

  return first_output_type;
}

class HorizontalLoopFusionImpl {
 public:
  explicit HorizontalLoopFusionImpl(HloComputation* computation)
      : computation_(computation) {}

  ~HorizontalLoopFusionImpl() {}

  StatusOr<bool> Run();

 private:
  Status Fuse(absl::Span<HloInstruction*> fused_fusion_instrs);

  // Horizontally fuses `fused_fusion_instrs`. Each of `fused_fusion_instrs`
  // is required to be a kLoop fusion. Their numbers of outputs must also be
  // the same, so that each output is concatenated with the corresponding
  // outputs from the other fused fusion instructions, and all fused outputs
  // still have the same shapes for kernel generation.
  //
  // Returns the fused computation in `uniq_computation` and the operands that
  // are used by `uniq_computation`.
  Status CreateFusedComputation(
      absl::Span<HloInstruction*> fused_fusion_instrs,
      std::unique_ptr<HloComputation>* uniq_computation,
      std::vector<HloInstruction*>* bound_operands);

  // FusionCandidates collects profitable candidates for a given consumer
  // instruction. GetNextSpanOfFusions() can then be iteratively invoked to
  // acquire the next set of fusion candidates based on some heuristics.
  class FusionCandidates {
   public:
    explicit FusionCandidates(HloInstruction* consumer)
        : fusible_instrs_(), pos_(0) {
      Initialize(consumer);
    }

    // Gets a span of fusions to be fused.
    absl::Span<HloInstruction*> GetNextSpanOfFusions();

   private:
    void Initialize(HloInstruction*);

    std::vector<HloInstruction*> fusible_instrs_;
    // `pos_` points to the start position of the next span.
    size_t pos_;
  };

  HloComputation* computation_;
};  // HorizontalLoopFusionImpl

bool IsFusibleCandidate(const HloInstruction& instr) {
  // Require no further check for element-wise instructions.
  if (instr.IsElementwise() && instr.operand_count() > 0) {
    return true;
  }

  // Exclude fusions other than kLoop.
  if (!instr.IsLoopFusion()) {
    return false;
  }

  // Cannot support fusions that have multiple output types, because the
  // concatenate (inserted for horizontal fusion) requires the same type
  // for all of its operands.
  auto outputs = GetOutputsOfFusible(instr);
  CHECK(!outputs.empty());
  const HloInstruction* first_output = outputs[0];
  for (size_t i = 1; i < outputs.size(); ++i) {
    if (first_output->shape().element_type() !=
        outputs[i]->shape().element_type()) {
      return false;
    }
  }

  return true;
}

// Returns whether `instr` is a profitable candidate to be horizontally fused.
// Since the primary benefit of horizontal fusion comes from reducing kernel
// launch overhead, we want to exclude instructions whose launch overhead is
// insignificant, i.e. instructions whose computation latencies are longer
// than their launch latencies. We estimate the computation latency of a given
// instruction by its shapes and the instruction count in its fused
// computation. We roughly observe that if a fusion instruction has shapes
// smaller than `kShapeThreshold` and fewer instructions than
// `kInstrCountThreshold`, it is launch-latency-bound and profitable for
// horizontal fusion.
bool IsProfitableFusionCandidate(const HloInstruction& instr) {
  constexpr int64_t kShapeThreshold = 128 * 2048;
  constexpr int64_t kInstrCountThreshold = 30;
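  // For reference, 128 * 2048 = 262144 elements, i.e. roughly 1 MiB of output
  // data at f32 precision.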
  const HloInstruction* root = (instr.opcode() == HloOpcode::kFusion)
                                   ? instr.fused_expression_root()
                                   : &instr;

  // Shapes that are too large are unlikely to be profitable.
  if (root->opcode() == HloOpcode::kTuple) {
    // Since all output shapes are the same, use the first shape as the
    // representative.
    Shape shape = root->operand(0)->shape();
    if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
      return false;
    }
  } else {
    Shape shape = root->shape();
    if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
      return false;
    }
  }

  // Fusions with too many instructions are unlikely to be profitable.
  if (instr.opcode() == HloOpcode::kFusion &&
      instr.fused_instruction_count() > kInstrCountThreshold) {
    return false;
  }

  return true;
}

// Returns whether `instr` has only row-major layouts.
// Horizontal fusion excludes computations with non-row-major layouts, because
// fusing computations with different layouts can result in uncoalesced memory
// accesses and cause significant performance overhead.
bool HasOnlyRowMajorLayout(const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return LayoutUtil::IsMonotonicWithDim0Major(instr.shape().layout());
  }

  auto fused_instrs = instr.fused_instructions_computation()->instructions();
  for (HloInstruction* i : fused_instrs) {
    if (i->shape().layout().format() != DENSE) {
      continue;
    }
    if (!LayoutUtil::IsMonotonicWithDim0Major(i->shape().layout())) {
      return false;
    }
  }
  return true;
}

// Returns whether any operand of `instr` is a parameter instruction that
// is shared with `fusion_instrs`.
bool AnyOpndIsParamSharedAmongFusions(
    const HloInstruction* instr,
    const absl::flat_hash_set<HloInstruction*>& fusion_instrs) {
  return absl::c_any_of(instr->operands(), [&](const HloInstruction* opnd) {
    return opnd->opcode() == HloOpcode::kParameter &&
           absl::c_any_of(opnd->users(), [&](const HloInstruction* user) {
             return user != instr && fusion_instrs.contains(user);
           });
  });
}

void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
    HloInstruction* consumer) {
  // First, find all potential fusion candidates; unsupported and unprofitable
  // cases are filtered out below.
  absl::flat_hash_set<HloInstruction*> fusible_candidates;
  std::vector<HloInstruction*> ordered_fusible_candidates;
  for (HloInstruction* opnd : consumer->operands()) {
    HloInstruction* predecessor = opnd->LatestNonGteAncestor();
    // We currently support kLoop fusions and element-wise HLOs. We may extend
    // the supported set if the need arises.
    if (IsFusibleCandidate(*predecessor)) {
      if (fusible_candidates.insert(predecessor).second) {
        // Add unseen fusion to ordered list.
        ordered_fusible_candidates.push_back(predecessor);
      }
    }
  }

  for (HloInstruction* instr : ordered_fusible_candidates) {
    if (!IsConsumerTheOnlyNonRootUser(*instr, *consumer)) {
      VLOG(2) << "Reject maybe illegal instr " << instr->ToString()
              << "; including it may create cycles in HLO.";
      continue;
    } else if (!IsProfitableFusionCandidate(*instr)) {
      VLOG(2) << "Reject may-not-be profitable fusion instr "
              << instr->ToString();
      continue;
    } else if (!HasOnlyRowMajorLayout(*instr)) {
      VLOG(2) << "Reject non-row-major fusion instr " << instr->ToString();
      continue;
    } else if (AnyOpndIsParamSharedAmongFusions(instr, fusible_candidates)) {
      // Don't fuse fusions whose operands are parameter instructions that are
      // shared among fusions, because we cannot I/O-alias the produced
      // horizontal fusion due to the concat insertion.
      VLOG(2) << "Reject the fusion instr because it shares parameter with"
              << " other fusion candidates, instr: " << instr->ToString();
      continue;
    } else {
      VLOG(2) << "Find a fusion candidate " << instr->ToString();
      // Encapsulate it into a fusion computation for unified representation
      // for later processing.
      fusible_instrs_.push_back(instr);
    }
  }

  // Sort `fusible_instrs_` according to output types, numbers of outputs, and
  // instruction counts, because we only fuse instructions with the same
  // number/type of outputs and whose computations have the same instruction
  // count. Sorting groups compatible instructions together, so that
  // GetNextSpanOfFusions() can return contiguous spans.
  std::sort(
      fusible_instrs_.begin(), fusible_instrs_.end(),
      [&](const HloInstruction* a, const HloInstruction* b) {
        if (GetUniqueOutputTypeOfFusible(*a) !=
            GetUniqueOutputTypeOfFusible(*b)) {
          return GetUniqueOutputTypeOfFusible(*a) <
                 GetUniqueOutputTypeOfFusible(*b);
        } else if (GetOutputSizeOfFusible(*a) != GetOutputSizeOfFusible(*b)) {
          return GetOutputSizeOfFusible(*a) < GetOutputSizeOfFusible(*b);
        } else {
          return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b);
        }
      });
}

// Gets the next span of fusion instructions to be fused.
absl::Span<HloInstruction*>
HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() {
  if (pos_ >= fusible_instrs_.size()) {
    return absl::Span<HloInstruction*>();
  }

  // Fusing too many computations at a time may not be profitable and may
  // increase compile time due to large kernels, so we set a limit.
  constexpr int64_t kMaxFusionBatchSize = 32;
  // CUDA has a parameter size limit of ~4k bytes.
  constexpr int64_t kMaxCudaParamSize = 4000;
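  // Each operand and output of the fused kernel is presumably passed as an
  // 8-byte pointer argument, hence the `accum_io_size * 8` check below.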
  size_t accum_io_size = 0;
  auto reach_max_fusion_batch_size = [&](size_t left, size_t right) -> bool {
    if (right - left >= kMaxFusionBatchSize) {
      return true;
    }

    accum_io_size += fusible_instrs_.at(right)->operand_count() +
                     GetOutputSizeOfFusible(*fusible_instrs_.at(right));

    if (accum_io_size * 8 >= kMaxCudaParamSize) {
      return true;
    }

    return false;
  };

290 
291   size_t left = pos_;
292   size_t right = pos_ + 1;
293   size_t first_output_size = GetOutputSizeOfFusible(*fusible_instrs_[left]);
294   PrimitiveType first_output_type =
295       GetUniqueOutputTypeOfFusible(*fusible_instrs_[left]);
296   for (; right < fusible_instrs_.size(); ++right) {
297     PrimitiveType cur_output_type =
298         GetUniqueOutputTypeOfFusible(*fusible_instrs_[right]);
299     if (first_output_type != cur_output_type) {
300       // Cannot fuse computations who have multiple output types.
301       break;
302     } else if (first_output_size !=
303                GetOutputSizeOfFusible(*fusible_instrs_[right])) {
304       // Cannot fuse computations who have different numbers of outputs.
305       break;
306     } else if (GetInstrCountOfFusible(*fusible_instrs_[left]) !=
307                GetInstrCountOfFusible(*fusible_instrs_[right])) {
308       // Do not fuse computations of different instruction counts as it may
309       // introduce control divergence. This is a very simple heuristic to avoid
310       // fusing computations with too much discrepancy and we may improve it
311       // when the needs arise.
312       break;
313     } else if (reach_max_fusion_batch_size(left, right)) {
314       // Hit max fusion batch size.
315       break;
316     }
317   }
318 
319   pos_ = right;
320   return absl::MakeSpan(fusible_instrs_).subspan(left, right - left);
321 }
322 
Status HorizontalLoopFusionImpl::CreateFusedComputation(
    absl::Span<HloInstruction*> fused_fusion_instrs,
    std::unique_ptr<HloComputation>* uniq_computation,
    std::vector<HloInstruction*>* bound_operands) {
  // First, build a computation with only params.
  HloComputation::Builder b("horizontally_fused_computation");
  size_t fused_comp_param_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    auto old_params = fused_fusion_instrs[i]->fused_parameters();
    for (size_t j = 0; j < old_params.size(); ++j) {
      HloInstruction* bound_opnd = fused_fusion_instrs[i]->mutable_operand(j);
      // Parameters are named in the form param_i_j.
      b.AddInstruction(HloInstruction::CreateParameter(
          fused_comp_param_id++, bound_opnd->shape(),
          absl::StrCat("param_", i, "_", j)));
      bound_operands->push_back(bound_opnd);
    }
  }
  // Always create a dummy tuple instruction to serve as the root of the
  // computation, as HloComputation requires a root instruction. The real root
  // instruction will replace it below.
  HloInstruction* dummy_root = b.AddInstruction(
      HloInstruction::CreateTuple(std::vector<HloInstruction*>{}));
  *uniq_computation = b.Build(dummy_root);
  HloComputation* comp = uniq_computation->get();

  // Prepare clone_map, which maps an old instruction to its new counterpart.
  absl::flat_hash_map<const HloInstruction*, HloInstruction*> clone_map;
  size_t new_param_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    auto old_params = fused_fusion_instrs[i]->fused_parameters();
    for (size_t j = 0; j < old_params.size(); ++j) {
      HloInstruction* old_param = old_params[j];
      HloInstruction* new_param = comp->parameter_instruction(new_param_id++);
      clone_map.insert({old_param, new_param});
    }
  }

360 
361   // Clone every fused computation.
362   for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
363     auto def_to_use_order = fused_fusion_instrs[i]
364                                 ->fused_instructions_computation()
365                                 ->MakeInstructionPostOrder();
366     for (HloInstruction* old_instr : def_to_use_order) {
367       if (old_instr->opcode() == HloOpcode::kParameter) {
368         // Parameters have been created.
369         continue;
370       }
371       std::vector<HloInstruction*> new_opnds;
372       for (HloInstruction* old_opnd : old_instr->operands()) {
373         CHECK(clone_map.find(old_opnd) != clone_map.end());
374         new_opnds.push_back(clone_map[old_opnd]);
375       }
376       HloInstruction* new_instr = comp->AddInstruction(
377           old_instr->CloneWithNewOperands(old_instr->shape(), new_opnds));
378       clone_map.insert({old_instr, new_instr});
379     }
380   }
381 
382   std::vector<HloInstruction*> concated_outputs;
383   // Since we require each fusion to have the same number of outputs, we can
384   // simply use the first fusion as the representative for output size.
385   size_t fused_instr_output_size =
386       GetOutputSizeOfFusible(*fused_fusion_instrs[0]);
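  // For each output index, reshape every fused computation's output at that
  // index to 1-D and concatenate the reshapes into a single 1-D array.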
  for (size_t i = 0; i < fused_instr_output_size; ++i) {
    std::vector<HloInstruction*> reshapes(fused_fusion_instrs.size());
    for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
      const HloInstruction* old_output =
          GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
      HloInstruction* new_output = clone_map[old_output];
      TF_ASSIGN_OR_RETURN(
          reshapes[j],
          MakeReshapeHlo(ShapeUtil::MakeShapeWithLayout(
                             new_output->shape().element_type(),
                             {ShapeUtil::ElementsIn(new_output->shape())},
                             /*minor_to_major=*/std::vector<int64>(1, 0)),
                         new_output));
    }
    TF_ASSIGN_OR_RETURN(HloInstruction * concated_output,
                        MakeConcatHlo(reshapes, 0));
    concated_outputs.push_back(concated_output);
  }

405 
406   // Make slices of outputs.
407   std::vector<HloInstruction*> output_slices(concated_outputs.size() *
408                                              fused_fusion_instrs.size());
409   for (size_t i = 0; i < concated_outputs.size(); ++i) {
410     HloInstruction* concated_output = concated_outputs[i];
411     int64_t slice_start = 0;
412     // Create a slice per fused computation.
413     for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
414       const HloInstruction* old_output =
415           GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
416       Shape shape = old_output->shape();
417       int64_t slice_limit = slice_start + ShapeUtil::ElementsIn(shape);
418       TF_ASSIGN_OR_RETURN(
419           output_slices[concated_outputs.size() * j + i],
420           MakeSliceHlo(concated_output, {slice_start}, {slice_limit},
421                        /*strides=*/{1}));
422       slice_start = slice_limit;
423     }
424   }
425 
426   // Make a tuple of output_slices.
427   HloInstruction* tuple =
428       comp->AddInstruction(HloInstruction::CreateTuple(output_slices));
429   comp->set_root_instruction(tuple, /*accept_different_shape=*/true);
430   TF_RETURN_IF_ERROR(comp->RemoveInstruction(dummy_root));
431 
432   return Status::OK();
433 }
434 
Status HorizontalLoopFusionImpl::Fuse(
    absl::Span<HloInstruction*> fused_fusion_instrs) {
  // Fuse fused_fusion_instrs and replace them with the new fused computation.
  std::unique_ptr<HloComputation> uniq_computation;
  std::vector<HloInstruction*> bound_operands;
  TF_RETURN_IF_ERROR(CreateFusedComputation(
      fused_fusion_instrs, &uniq_computation, &bound_operands));
  HloComputation* fused_comp = computation_->parent()->AddEmbeddedComputation(
      std::move(uniq_computation));
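  // The new fusion is created with kind kInput, presumably so that its
  // tuple-of-slices root matches the slice input fusion pattern referred to
  // below.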
  HloInstruction* hori_fusion_instr =
      computation_->AddInstruction(HloInstruction::CreateFusion(
          fused_comp->root_instruction()->shape(),
          HloInstruction::FusionKind::kInput, bound_operands, fused_comp));
  fused_comp->SetFusionInstruction(hori_fusion_instr);

  // Insert bitcasts and replace corresponding users. Note that we do not
  // insert the bitcasts into the fused computation, as that would not fit the
  // slice input fusion pattern; inserting bitcasts outside the fused
  // computation incurs no performance cost.
  size_t total_output_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    std::vector<HloInstruction*> bitcasts;
    HloInstruction* fused_instr = fused_fusion_instrs[i];
    size_t num_outputs = GetOutputSizeOfFusible(*fused_instr);
    for (size_t j = 0; j < num_outputs; ++j) {
      const HloInstruction* output = GetOutputsOfFusible(*fused_instr)[j];
      TF_ASSIGN_OR_RETURN(
          HloInstruction * gep,
          MakeGetTupleElementHlo(hori_fusion_instr, total_output_id++));
      bitcasts.push_back(computation_->AddInstruction(
          HloInstruction::CreateBitcast(output->shape(), gep)));
    }
    HloInstruction* bitcast_or_tuple =
        (bitcasts.size() == 1) ? bitcasts.at(0)
                               : computation_->AddInstruction(
                                     HloInstruction::CreateTuple(bitcasts));
    TF_RETURN_IF_ERROR(
        computation_->ReplaceInstruction(fused_instr, bitcast_or_tuple));
  }

  return Status::OK();
}

StatusOr<bool> HorizontalLoopFusionImpl::Run() {
  bool changed = false;
  XLA_VLOG_LINES(3, computation_->ToString());

  // Traverse from use to def. Bitcasts are placed after horizontal fusions to
  // resolve shape mismatches, but those bitcasts could prevent future
  // horizontal fusion from happening. A bottom-up, use-to-def order is
  // therefore preferable. It also saves compiler iterations needed to reach
  // the fixed point.
  std::vector<HloInstruction*> use_to_def_order =
      computation_->MakeInstructionPostOrder();
  absl::c_reverse(use_to_def_order);
  for (size_t i = 0; i < use_to_def_order.size(); ++i) {
    HloInstruction* consumer = use_to_def_order[i];
    HorizontalLoopFusionImpl::FusionCandidates fusion_candidates(consumer);
    while (true) {
      auto fusibles = fusion_candidates.GetNextSpanOfFusions();
      if (fusibles.empty()) {
        break;
      } else if (fusibles.size() == 1) {
        // Skip; there is just one fused_instr.
        continue;
      }

      changed = true;
      // Convert fusibles into fusion instructions to simplify the
      // implementation of `Fuse()`.
      std::vector<HloInstruction*> fusion_instrs;
      for (HloInstruction* instr : fusibles) {
        if (instr->opcode() == HloOpcode::kFusion) {
          fusion_instrs.push_back(instr);
        } else {
          TF_ASSIGN_OR_RETURN(
              HloInstruction * fusion_instr,
              MakeFusionInstruction(instr, HloInstruction::FusionKind::kLoop));
          fusion_instrs.push_back(fusion_instr);
        }
      }
      TF_RETURN_IF_ERROR(Fuse(absl::MakeSpan(fusion_instrs)));
    }
  }

  return changed;
}

}  // namespace

StatusOr<bool> GpuHorizontalLoopFusion::RunOnComputation(
    HloComputation* computation) {
  HorizontalLoopFusionImpl horizontal_fusion_impl(computation);
  return horizontal_fusion_impl.Run();
}

StatusOr<bool> GpuHorizontalLoopFusion::Run(HloModule* module) {
  bool changed = false;
  VLOG(2) << "Run horizontal fusion.";

  // Running on the entry computation is enough.
  TF_ASSIGN_OR_RETURN(changed, RunOnComputation(module->entry_computation()));

  return changed;
}

}  // namespace gpu
}  // namespace xla