/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
using absl::StrCat;

StatusOr<HloInstruction*> MakeUnaryHlo(HloOpcode opcode,
                                       HloInstruction* operand) {
  HloComputation* computation = operand->parent();
  TF_ASSIGN_OR_RETURN(Shape unary_op_shape,
                      ShapeInference::InferUnaryOpShape(opcode, operand));
  return computation->AddInstruction(
      HloInstruction::CreateUnary(unary_op_shape, opcode, operand));
}

StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
                                        HloInstruction* rhs) {
  HloComputation* computation = lhs->parent();
  CHECK_EQ(computation, rhs->parent());
  TF_ASSIGN_OR_RETURN(Shape binary_op_shape,
                      ShapeInference::InferBinaryOpShape(opcode, lhs, rhs));
  return computation->AddInstruction(
      HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs));
}

StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction,
                                         HloInstruction* lhs,
                                         HloInstruction* rhs) {
  HloComputation* computation = lhs->parent();
  CHECK_EQ(computation, rhs->parent());
  TF_ASSIGN_OR_RETURN(
      Shape binary_op_shape,
      ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs));
  return computation->AddInstruction(
      HloInstruction::CreateCompare(binary_op_shape, lhs, rhs, direction));
}

StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
                                     HloInstruction* padding_value,
                                     const PaddingConfig& padding_config) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(computation, padding_value->parent());
  TF_ASSIGN_OR_RETURN(
      Shape pad_shape,
      ShapeInference::InferPadShape(operand->shape(), padding_value->shape(),
                                    padding_config));
  return computation->AddInstruction(HloInstruction::CreatePad(
      pad_shape, operand, padding_value, padding_config));
}

StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
                                       absl::Span<const int64> start_indices,
                                       absl::Span<const int64> limit_indices,
                                       absl::Span<const int64> strides) {
  HloComputation* computation = operand->parent();
  TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape(
                                             operand->shape(), start_indices,
                                             limit_indices, strides));
  return computation->AddInstruction(HloInstruction::CreateSlice(
      slice_shape, operand, start_indices, limit_indices, strides));
}

StatusOr<HloInstruction*> MakeConvolveHlo(
    HloInstruction* lhs, HloInstruction* rhs, int64_t feature_group_count,
    int64_t batch_group_count, const Window& window,
    const ConvolutionDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  HloComputation* computation = lhs->parent();
  CHECK_EQ(computation, rhs->parent());
  TF_ASSIGN_OR_RETURN(
      Shape convolve_shape,
      ShapeInference::InferConvolveShape(
          lhs->shape(), rhs->shape(), feature_group_count, batch_group_count,
          window, dimension_numbers, preferred_element_type));
  return computation->AddInstruction(HloInstruction::CreateConvolve(
      convolve_shape, lhs, rhs, feature_group_count, batch_group_count, window,
      dimension_numbers, precision_config));
}

StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
                                           absl::Span<const int64> dimensions) {
  HloComputation* computation = operand->parent();
  TF_ASSIGN_OR_RETURN(
      Shape transpose_shape,
      ShapeInference::InferTransposeShape(operand->shape(), dimensions));
  return computation->AddInstruction(
      HloInstruction::CreateTranspose(transpose_shape, operand, dimensions));
}

StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
                                         HloInstruction* operand) {
  HloComputation* computation = operand->parent();
  return computation->AddInstruction(
      HloInstruction::CreateReshape(result_shape, operand));
}

StatusOr<HloInstruction*> MakeReshapeHlo(
    absl::Span<const int64> result_shape_dim_bounds, HloInstruction* operand) {
  Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                         result_shape_dim_bounds);
  return MakeReshapeHlo(new_shape, operand);
}

StatusOr<HloInstruction*> MakeDynamicSliceHlo(
    HloInstruction* operand, absl::Span<HloInstruction* const> start_indices,
    absl::Span<const int64> slice_sizes) {
  HloComputation* computation = operand->parent();
  std::vector<Shape> scalar_start_indices_shapes(
      start_indices.size(),
      ShapeUtil::MakeShape(start_indices[0]->shape().element_type(), {}));
  TF_ASSIGN_OR_RETURN(
      Shape dynamic_slice_shape,
      ShapeInference::InferDynamicSliceShape(
          operand->shape(), scalar_start_indices_shapes, slice_sizes));
  return computation->AddInstruction(HloInstruction::CreateDynamicSlice(
      dynamic_slice_shape, operand, start_indices, slice_sizes));
}

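// Variant that takes the start indices as a single 1-D vector. Dynamic-slice
// expects one scalar index per operand dimension, so the vector is sliced
// and reshaped into scalars first.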
StatusOr<HloInstruction*> MakeDynamicSliceHlo(
    HloInstruction* operand, HloInstruction* start_indices,
    absl::Span<const int64> slice_sizes) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(computation, start_indices->parent());
  int64_t rank = start_indices->shape().dimensions(0);
  std::vector<HloInstruction*> scalar_start_indices;
  for (int i = 0; i < rank; ++i) {
    // TODO(b/118437727): Update callers to provide scalars directly.
    auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
        ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
        start_indices, {i}, {i + 1}, {1}));
    scalar_start_indices.push_back(
        computation->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
            slice)));
  }
  std::vector<Shape> scalar_start_indices_shapes(
      rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
  TF_ASSIGN_OR_RETURN(
      Shape dynamic_slice_shape,
      ShapeInference::InferDynamicSliceShape(
          operand->shape(), scalar_start_indices_shapes, slice_sizes));
  return computation->AddInstruction(HloInstruction::CreateDynamicSlice(
      dynamic_slice_shape, operand, scalar_start_indices, slice_sizes));
}

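// Creates a dynamic-update-slice, decomposing the 1-D start_indices vector
// into the per-dimension scalar indices the instruction expects.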
StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
    HloInstruction* operand, HloInstruction* update,
    HloInstruction* start_indices) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(computation, update->parent());
  CHECK_EQ(computation, start_indices->parent());
  int64_t rank = start_indices->shape().dimensions(0);
  std::vector<HloInstruction*> scalar_start_indices;
  for (int i = 0; i < rank; ++i) {
    // TODO(b/118437727): Update callers to provide scalars directly.
    auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
        ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
        start_indices, {i}, {i + 1}, {1}));
    scalar_start_indices.push_back(
        computation->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
            slice)));
  }
  std::vector<Shape> scalar_start_indices_shapes(
      rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
  TF_ASSIGN_OR_RETURN(
      Shape dynamic_update_slice_shape,
      ShapeInference::InferDynamicUpdateSliceShape(
          operand->shape(), update->shape(), scalar_start_indices_shapes));
  return computation->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      dynamic_update_slice_shape, operand, update, scalar_start_indices));
}

HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
                                 absl::Span<const int64> broadcast_dimensions,
                                 absl::Span<const int64> result_shape_bounds) {
  HloComputation* computation = operand->parent();
  Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                               result_shape_bounds);

  return computation->AddInstruction(HloInstruction::CreateBroadcast(
      broadcast_shape, operand, broadcast_dimensions));
}

HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
                                 absl::Span<const int64> broadcast_dimensions,
                                 const Shape& shape) {
  return MakeBroadcastHlo(operand, broadcast_dimensions, shape.dimensions());
}

StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
                                                 int64_t index) {
  HloComputation* computation = operand->parent();

  TF_ASSIGN_OR_RETURN(
      Shape gte_shape,
      ShapeInference::InferGetTupleElementShape(operand->shape(), index));
  return computation->AddInstruction(
      HloInstruction::CreateGetTupleElement(gte_shape, operand, index));
}

StatusOr<HloInstruction*> MakeConcatHlo(
    absl::Span<HloInstruction* const> operands, int64_t dimension) {
  CHECK_GT(operands.size(), 0);

  HloComputation* computation = operands[0]->parent();
  CHECK(absl::c_all_of(operands, [&](HloInstruction* instr) {
    return instr->parent() == computation;
  }));

  std::vector<const Shape*> operand_shapes;
  absl::c_transform(operands, std::back_inserter(operand_shapes),
                    [](HloInstruction* instr) { return &instr->shape(); });

  TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape(
                                              operand_shapes, dimension));
  return computation->AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, operands, dimension));
}

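// Converts hlo to the given element type; returns hlo unchanged if it
// already has that type.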
HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type) {
  if (hlo->shape().element_type() == type) {
    return hlo;
  }
  Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
  hlo =
      hlo->parent()->AddInstruction(HloInstruction::CreateConvert(shape, hlo));
  CHECK_EQ(hlo->shape().element_type(), type);
  return hlo;
}

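// Reinterprets the bits of hlo as the given element type, falling back to a
// plain convert when PRED is involved (see the comment in the body).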
HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo,
                                        PrimitiveType type) {
  if (hlo->shape().element_type() == type) {
    return hlo;
  }
  Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
  // PRED is stored as one byte but has a bit width of 1; avoid the mismatch
  // by using a convert instead of a bitcast convert.
  if (type == PRED || hlo->shape().element_type() == PRED) {
    return MakeConvertToHlo(hlo, type);
  }
  hlo = hlo->parent()->AddInstruction(
      HloInstruction::CreateBitcastConvert(shape, hlo));
  CHECK_EQ(hlo->shape().element_type(), type);
  return hlo;
}

HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
                            int64_t iota_dimension) {
  return computation->AddInstruction(
      HloInstruction::CreateIota(shape, iota_dimension));
}

StatusOr<HloInstruction*> MakeDotHlo(
    HloInstruction* lhs, HloInstruction* rhs,
    const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  HloComputation* computation = lhs->parent();
  CHECK_EQ(computation, rhs->parent());
  TF_ASSIGN_OR_RETURN(
      Shape dot_shape,
      ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers,
                                      preferred_element_type));
  return computation->AddInstruction(HloInstruction::CreateDot(
      dot_shape, lhs, rhs, dim_numbers, precision_config));
}

StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
                                     HloComputation* map_computation) {
  CHECK(!operands.empty()) << "Map Hlo requires at least one operand.";
  HloComputation* computation = operands.front()->parent();
  std::vector<const Shape*> operand_shapes;
  int64_t max_operand_rank = 0;
  for (const HloInstruction* operand : operands) {
    CHECK_EQ(computation, operand->parent());
    operand_shapes.push_back(&operand->shape());
    max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
  }
  std::vector<int64> map_dims(max_operand_rank);
  std::iota(map_dims.begin(), map_dims.end(), 0);
  TF_ASSIGN_OR_RETURN(
      Shape map_shape,
      ShapeInference::InferMapShape(
          operand_shapes, map_computation->ComputeProgramShape(), map_dims));
  return computation->AddInstruction(
      HloInstruction::CreateMap(map_shape, operands, map_computation));
}

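// Reduces operand over `dimensions` using a freshly built scalar computation
// that applies binary_opcode, e.g. kAdd with dimensions = {0} sums away
// dimension 0.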
StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        absl::Span<const int64> dimensions,
                                        HloOpcode binary_opcode) {
  auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
  auto result_shape = ShapeUtil::FilterDimensions(
      [&](const int64_t dim) {
        return !absl::c_linear_search(dimensions, dim);
      },
      operand->shape());
  HloComputation* reduce_computation;
  {
    HloComputation::Builder b(operand->name() + ".reduce_sub_computation");
    auto lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
    auto rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
    b.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs));
    reduce_computation =
        operand->parent()->parent()->AddEmbeddedComputation(b.Build());
  }

  return operand->parent()->AddInstruction(HloInstruction::CreateReduce(
      result_shape, operand, init_value, dimensions, reduce_computation));
}

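// Overload that reduces operand across all of its dimensions, yielding a
// scalar.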
StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        HloOpcode binary_opcode,
                                        HloModule* module) {
  DCHECK_NE(nullptr, module);
  std::vector<int64> all_dims(operand->shape().rank());
  std::iota(all_dims.begin(), all_dims.end(), 0);

  auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
  HloComputation* reduce_computation;
  {
    HloComputation::Builder b(operand->name() + ".reduce_sub_computation");
    auto lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
    auto rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
    b.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs));
    reduce_computation = module->AddEmbeddedComputation(b.Build());
  }

  return operand->parent()->AddInstruction(HloInstruction::CreateReduce(
      scalar_shape, operand, init_value, all_dims, reduce_computation));
}

StatusOr<HloInstruction*> MakeReverseHlo(HloInstruction* operand,
                                         absl::Span<const int64> dimensions) {
  HloComputation* computation = operand->parent();
  TF_ASSIGN_OR_RETURN(Shape reverse_shape, ShapeInference::InferReverseShape(
                                               operand->shape(), dimensions));
  return computation->AddInstruction(
      HloInstruction::CreateReverse(reverse_shape, operand, dimensions));
}

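// Creates a select (or tuple-select for tuple shapes), broadcasting a scalar
// predicate to the operand shape when the output is a non-scalar array.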
StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
                                        HloInstruction* on_true,
                                        HloInstruction* on_false,
                                        HloInstruction* derived_from) {
  HloComputation* computation = pred->parent();
  DCHECK_EQ(computation, on_true->parent());
  DCHECK_EQ(computation, on_false->parent());
  Shape op_shape = on_true->shape();
  if (ShapeUtil::IsScalar(pred->shape())) {
    if (!ShapeUtil::IsScalar(op_shape) && !op_shape.IsTuple()) {
      // If the output is not scalar, we need to broadcast the condition
      // to match the contract of kSelect. For tuples, we use kTupleSelect
      // which expects the condition to be a scalar.
      pred = computation->AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::ChangeElementType(op_shape, PrimitiveType::PRED), pred,
          {}));
      if (derived_from) {
        derived_from->SetupDerivedInstruction(pred);
      }
    }
  }
  HloOpcode select_op_code =
      op_shape.IsTuple() ? HloOpcode::kTupleSelect : HloOpcode::kSelect;
  TF_ASSIGN_OR_RETURN(Shape select_shape,
                      ShapeInference::InferTernaryOpShape(select_op_code, pred,
                                                          on_true, on_false));
  HloInstruction* select =
      computation->AddInstruction(HloInstruction::CreateTernary(
          select_shape, select_op_code, pred, on_true, on_false));
  if (derived_from) {
    derived_from->SetupDerivedInstruction(select);
  }
  return select;
}

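// Creates a sort instruction with an ascending (less-than) comparator, built
// via XlaBuilder and deep-cloned into the target module.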
StatusOr<HloInstruction*> MakeSortHlo(
    const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
    int64_t dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
    HloModule* module) {
  CHECK(!operands.empty()) << "Sort Hlo requires at least one operand.";
  HloComputation* compare_computation;
  XlaBuilder b("Sort.Compare");
  std::vector<PrimitiveType> operand_types(operands.size());
  for (int64_t i = 0; i < operands.size(); ++i) {
    operand_types[i] = operands[i]->shape().element_type();
  }
  XlaComputation comparator = CreateScalarLtComputation(operand_types, &b);
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
  HloModuleConfig config(program_shape);
  TF_ASSIGN_OR_RETURN(auto new_module,
                      HloModule::CreateFromProto(comparator.proto(), config));
  HloCloneContext context(module);
  compare_computation =
      module->DeepCloneComputation(new_module->entry_computation(), &context);
  return builder->AddInstruction(HloInstruction::CreateSort(
      sort_shape, dimension_to_sort, operands, compare_computation, is_stable));
}

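// Reshapes operand so its first n dimensions collapse into one, e.g. n = 2
// turns f32[2,3,4] into f32[6,4].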
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand,
                                             int64_t n) {
  CHECK_GT(n, 0);

  const Shape& operand_shape = operand->shape();
  CHECK_GE(operand_shape.dimensions_size(), n);
  int64_t new_shape_leading_bound = 1;
  for (int64_t i = 0; i < n; i++) {
    new_shape_leading_bound *= operand_shape.dimensions(i);
  }

  std::vector<int64> new_shape_dims;
  new_shape_dims.reserve(operand_shape.dimensions_size() - n + 1);
  new_shape_dims.push_back(new_shape_leading_bound);

  std::copy(operand_shape.dimensions().begin() + n,
            operand_shape.dimensions().end(),
            std::back_inserter(new_shape_dims));

  Shape output_shape =
      ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims);

  return MakeReshapeHlo(output_shape, operand);
}

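// Reshapes operand to prepend n size-1 dimensions, e.g. n = 2 turns f32[3,4]
// into f32[1,1,3,4].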
StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
                                                int64_t n) {
  CHECK_GT(n, 0);
  std::vector<int64> new_shape_dims;
  const Shape& operand_shape = operand->shape();
  new_shape_dims.reserve(n + operand_shape.dimensions_size());
  new_shape_dims.insert(new_shape_dims.begin(), n, 1);
  absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims));
  return MakeReshapeHlo(new_shape_dims, operand);
}

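// Reshapes operand so its first dimension splits into expanded_dims, e.g.
// expanded_dims = {2,3} turns f32[6,4] into f32[2,3,4].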
StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
    HloInstruction* operand, absl::Span<const int64> expanded_dims) {
  CHECK_GT(operand->shape().dimensions_size(), 0);
  CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims));

  std::vector<int64> expanded_shape_dim_bounds;
  expanded_shape_dim_bounds.reserve(expanded_dims.size() +
                                    operand->shape().dimensions_size() - 1);
  absl::c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
  std::copy(operand->shape().dimensions().begin() + 1,
            operand->shape().dimensions().end(),
            std::back_inserter(expanded_shape_dim_bounds));
  Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                         expanded_shape_dim_bounds);
  return MakeReshapeHlo(new_shape, operand);
}

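// Reshapes away the listed size-1 dimensions, e.g. dims_to_elide = {0,2}
// turns f32[1,3,1,4] into f32[3,4].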
StatusOr<HloInstruction*> ElideDegenerateDims(
    HloInstruction* operand, absl::Span<const int64> dims_to_elide) {
  return MakeReshapeHlo(ShapeUtil::FilterDimensions(
                            [&](int64_t dim) {
                              return !absl::c_linear_search(dims_to_elide, dim);
                            },
                            operand->shape()),
                        operand);
}

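// Inverse of ElideDegenerateDims: inserts size-1 dimensions at the (sorted)
// positions in dims_to_insert, e.g. {0,2} turns f32[3,4] into f32[1,3,1,4].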
StatusOr<HloInstruction*> InsertDegenerateDims(
    HloInstruction* operand, absl::Span<const int64> dims_to_insert) {
  CHECK(absl::c_is_sorted(dims_to_insert));

  const Shape& operand_shape = operand->shape();
  int64_t output_shape_rank =
      operand_shape.dimensions_size() + dims_to_insert.size();
  for (auto dim_to_insert : dims_to_insert) {
    CHECK_LT(dim_to_insert, output_shape_rank);
  }

  std::vector<int64> output_shape_dim_bounds;
  output_shape_dim_bounds.reserve(output_shape_rank);
  int64_t operand_dims_idx = 0;
  int64_t dims_to_insert_idx = 0;
  for (int64_t i = 0; i < output_shape_rank; ++i) {
    if (dims_to_insert_idx < dims_to_insert.size() &&
        i == dims_to_insert[dims_to_insert_idx]) {
      output_shape_dim_bounds.push_back(1);
      ++dims_to_insert_idx;
    } else {
      output_shape_dim_bounds.push_back(
          operand_shape.dimensions(operand_dims_idx));
      ++operand_dims_idx;
    }
  }

  Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(),
                                            output_shape_dim_bounds);
  return MakeReshapeHlo(output_shape, operand);
}

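// Pads a rank-1 operand with the requested number of zeros on each end.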
StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
                                             int64_t zeros_to_prepend,
                                             int64_t zeros_to_append) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(operand->shape().dimensions_size(), 1);
  PaddingConfig padding_config;
  PaddingConfig::PaddingConfigDimension padding_config_dim;
  padding_config_dim.set_edge_padding_low(zeros_to_prepend);
  padding_config_dim.set_edge_padding_high(zeros_to_append);
  *padding_config.add_dimensions() = padding_config_dim;

  HloInstruction* zero =
      computation->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(operand->shape().element_type())));
  return MakePadHlo(operand, zero, padding_config);
}

HloInstruction* BroadcastZeros(HloComputation* computation,
                               PrimitiveType element_type,
                               absl::Span<const int64> broadcast_dimensions) {
  HloInstruction* zero = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
  return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
                          /*result_shape_bounds=*/broadcast_dimensions);
}

HloInstruction* BroadcastOnes(HloComputation* computation,
                              PrimitiveType element_type,
                              absl::Span<const int64> broadcast_dimensions) {
  HloInstruction* one = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::One(element_type)));
  return MakeBroadcastHlo(one, /*broadcast_dimensions=*/{},
                          /*result_shape_bounds=*/broadcast_dimensions);
}

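// Wraps fused in a fusion instruction of the given kind and replaces it in
// its computation.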
StatusOr<HloInstruction*> MakeFusionInstruction(
    HloInstruction* fused, HloInstruction::FusionKind kind) {
  HloComputation* comp = fused->parent();
  HloInstruction* fusion_instruction = comp->AddInstruction(
      HloInstruction::CreateFusion(fused->shape(), kind, fused));
  TF_RETURN_IF_ERROR(comp->ReplaceInstruction(fused, fusion_instruction));
  return fusion_instruction;
}

// Recursively creates a dummy op for the given shape: leaf (array) shapes
// become broadcasted zeros, while tuple shapes become tuples of dummy ops.
HloInstruction* CreateDummyOp(HloComputation::Builder* b, const Shape& shape) {
  if (shape.IsArray()) {
    auto zero = b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::Zero(shape.element_type())));
    return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
  }
  CHECK(shape.IsTuple());
  std::vector<HloInstruction*> sub_instructions;
  for (const Shape& subshape : shape.tuple_shapes()) {
    sub_instructions.push_back(CreateDummyOp(b, subshape));
  }
  return b->AddInstruction(HloInstruction::CreateTuple(sub_instructions));
}

StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
    absl::Span<const Shape* const> domain, const Shape& range,
    absl::string_view name) {
  HloComputation::Builder b{string(name)};
  int64_t param_idx = 0;
  for (const Shape* param_shape : domain) {
    b.AddInstruction(HloInstruction::CreateParameter(
        param_idx, *param_shape, StrCat("param.", param_idx)));
    param_idx++;
  }

  // We can't change the root type of a computation once it is created, so
  // create a dummy root instruction to give the computation the right root
  // shape. Use a (recursive) broadcast here to avoid creating large constants.
  CreateDummyOp(&b, range);
  return b.Build();
}

}  // namespace xla