/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
using absl::StrCat;

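// Creates a unary HLO instruction with the given opcode; the result shape is
// inferred via ShapeInference and the instruction is added to the operand's
// computation. Illustrative use (the variable names here are hypothetical):
//   TF_ASSIGN_OR_RETURN(HloInstruction* neg,
//                       MakeUnaryHlo(HloOpcode::kNegate, operand));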
StatusOr<HloInstruction*> MakeUnaryHlo(HloOpcode opcode,
                                       HloInstruction* operand,
                                       const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();
  TF_ASSIGN_OR_RETURN(Shape unary_op_shape,
                      ShapeInference::InferUnaryOpShape(opcode, operand));
  return computation->AddInstruction(
      HloInstruction::CreateUnary(unary_op_shape, opcode, operand), metadata);
}

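// Creates a kCopy of `from` with result shape `to`.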
HloInstruction* MakeCopyHlo(HloInstruction* from, const Shape& to) {
  return from->AddInstruction(
      HloInstruction::CreateUnary(to, HloOpcode::kCopy, from));
}

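// Creates a binary HLO instruction (e.g. kAdd) over `lhs` and `rhs`, which
// must belong to the same computation.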
StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
                                        HloInstruction* rhs,
                                        const OpMetadata* metadata) {
  HloComputation* computation = lhs->parent();
  CHECK_EQ(computation, rhs->parent());
  TF_ASSIGN_OR_RETURN(Shape binary_op_shape,
                      ShapeInference::InferBinaryOpShape(opcode, lhs, rhs));
  return computation->AddInstruction(
      HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs),
      metadata);
}

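// Creates a kCompare instruction comparing `lhs` against `rhs` in the given
// direction (e.g. ComparisonDirection::kLt).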
StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction,
                                         HloInstruction* lhs,
                                         HloInstruction* rhs,
                                         const OpMetadata* metadata) {
  HloComputation* computation = lhs->parent();
  CHECK_EQ(computation, rhs->parent());
  TF_ASSIGN_OR_RETURN(
      Shape binary_op_shape,
      ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs));
  return computation->AddInstruction(
      HloInstruction::CreateCompare(binary_op_shape, lhs, rhs, direction),
      metadata);
}

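// Pads `operand` with `padding_value` according to `padding_config`.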
StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
                                     HloInstruction* padding_value,
                                     const PaddingConfig& padding_config,
                                     const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(computation, padding_value->parent());
  TF_ASSIGN_OR_RETURN(
      Shape pad_shape,
      ShapeInference::InferPadShape(operand->shape(), padding_value->shape(),
                                    padding_config));
  return computation->AddInstruction(
      HloInstruction::CreatePad(pad_shape, operand, padding_value,
                                padding_config),
      metadata);
}

StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
                                       absl::Span<const int64_t> start_indices,
                                       absl::Span<const int64_t> limit_indices,
                                       absl::Span<const int64_t> strides,
                                       const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();
  TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape(
                                             operand->shape(), start_indices,
                                             limit_indices, strides));
  return computation->AddInstruction(
      HloInstruction::CreateSlice(slice_shape, operand, start_indices,
                                  limit_indices, strides),
      metadata);
}

StatusOr<HloInstruction*> MakeConvolveHlo(
    HloInstruction* lhs, HloInstruction* rhs, int64_t feature_group_count,
    int64_t batch_group_count, const Window& window,
    const ConvolutionDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config,
    std::optional<PrimitiveType> preferred_element_type,
    const OpMetadata* metadata) {
  HloComputation* computation = lhs->parent();
  CHECK_EQ(computation, rhs->parent());
  TF_ASSIGN_OR_RETURN(
      Shape convolve_shape,
      ShapeInference::InferConvolveShape(
          lhs->shape(), rhs->shape(), feature_group_count, batch_group_count,
          window, dimension_numbers, preferred_element_type));
  return computation->AddInstruction(
      HloInstruction::CreateConvolve(
          convolve_shape, lhs, rhs, feature_group_count, batch_group_count,
          window, dimension_numbers, precision_config),
      metadata);
}

StatusOr<HloInstruction*> MakeTransposeHlo(
    HloInstruction* operand, absl::Span<const int64_t> dimensions) {
  TF_ASSIGN_OR_RETURN(
      Shape transpose_shape,
      ShapeInference::InferTransposeShape(operand->shape(), dimensions));
  return operand->AddInstruction(
      HloInstruction::CreateTranspose(transpose_shape, operand, dimensions));
}

StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
                                         HloInstruction* operand) {
  return operand->AddInstruction(
      HloInstruction::CreateReshape(result_shape, operand));
}

StatusOr<HloInstruction*> MakeReshapeHlo(
    absl::Span<const int64_t> result_shape_dim_bounds,
    HloInstruction* operand) {
  Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                         result_shape_dim_bounds);
  return MakeReshapeHlo(new_shape, operand);
}

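// Creates a dynamic-slice with one scalar start index per operand dimension.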
StatusOr<HloInstruction*> MakeDynamicSliceHlo(
    HloInstruction* operand, absl::Span<HloInstruction* const> start_indices,
    absl::Span<const int64_t> slice_sizes, const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();
  std::vector<Shape> scalar_start_indices_shapes(
      start_indices.size(),
      ShapeUtil::MakeShape(start_indices[0]->shape().element_type(), {}));
  TF_ASSIGN_OR_RETURN(
      Shape dynamic_slice_shape,
      ShapeInference::InferDynamicSliceShape(
          operand->shape(), scalar_start_indices_shapes, slice_sizes));
  return computation->AddInstruction(
      HloInstruction::CreateDynamicSlice(dynamic_slice_shape, operand,
                                         start_indices, slice_sizes),
      metadata);
}

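// Overload that takes the start indices packed into a single rank-1 vector;
// each element is sliced out and reshaped to a scalar before building the
// dynamic-slice.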
StatusOr<HloInstruction*> MakeDynamicSliceHlo(
    HloInstruction* operand, HloInstruction* start_indices,
    absl::Span<const int64_t> slice_sizes, const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(computation, start_indices->parent());
  int64_t rank = start_indices->shape().dimensions(0);
  std::vector<HloInstruction*> scalar_start_indices;
  for (int i = 0; i < rank; ++i) {
    // TODO(b/118437727): Update callers to provide scalars directly.
    auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
        ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
        start_indices, {i}, {i + 1}, {1}));
    scalar_start_indices.push_back(
        computation->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
            slice)));
  }
  std::vector<Shape> scalar_start_indices_shapes(
      rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
  TF_ASSIGN_OR_RETURN(
      Shape dynamic_slice_shape,
      ShapeInference::InferDynamicSliceShape(
          operand->shape(), scalar_start_indices_shapes, slice_sizes));
  return computation->AddInstruction(
      HloInstruction::CreateDynamicSlice(dynamic_slice_shape, operand,
                                         scalar_start_indices, slice_sizes),
      metadata);
}

StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
    HloInstruction* operand, HloInstruction* update,
    HloInstruction* start_indices, const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(computation, update->parent());
  CHECK_EQ(computation, start_indices->parent());
  int64_t rank = start_indices->shape().dimensions(0);
  std::vector<HloInstruction*> scalar_start_indices;
  for (int i = 0; i < rank; ++i) {
    // TODO(b/118437727): Update callers to provide scalars directly.
    auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
        ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
        start_indices, {i}, {i + 1}, {1}));
    scalar_start_indices.push_back(
        computation->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
            slice)));
  }
  std::vector<Shape> scalar_start_indices_shapes(
      rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
  TF_ASSIGN_OR_RETURN(
      Shape dynamic_update_slice_shape,
      ShapeInference::InferDynamicUpdateSliceShape(
          operand->shape(), update->shape(), scalar_start_indices_shapes));
  return computation->AddInstruction(
      HloInstruction::CreateDynamicUpdateSlice(
          dynamic_update_slice_shape, operand, update, scalar_start_indices),
      metadata);
}

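// Broadcasts `operand` into a result whose dimensions are
// `result_shape_bounds`; `broadcast_dimensions` maps each operand dimension
// to a dimension of the result.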
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
                                 absl::Span<const int64_t> broadcast_dimensions,
                                 absl::Span<const int64_t> result_shape_bounds,
                                 const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();
  Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                               result_shape_bounds);

  return computation->AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, operand,
                                      broadcast_dimensions),
      metadata);
}

HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
                                 absl::Span<const int64_t> broadcast_dimensions,
                                 const Shape& shape,
                                 const OpMetadata* metadata) {
  return MakeBroadcastHlo(operand, broadcast_dimensions, shape.dimensions(),
                          metadata);
}

StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
                                                 int64_t index,
                                                 const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();

  TF_ASSIGN_OR_RETURN(
      Shape gte_shape,
      ShapeInference::InferGetTupleElementShape(operand->shape(), index));
  return computation->AddInstruction(
      HloInstruction::CreateGetTupleElement(gte_shape, operand, index),
      metadata);
}

StatusOr<HloInstruction*> MakeConcatHlo(
    absl::Span<HloInstruction* const> operands, int64_t dimension,
    const OpMetadata* metadata) {
  CHECK_GT(operands.size(), 0);

  HloComputation* computation = operands[0]->parent();
  CHECK(absl::c_all_of(operands, [&](HloInstruction* instr) {
    return instr->parent() == computation;
  }));

  std::vector<const Shape*> operand_shapes;
  absl::c_transform(operands, std::back_inserter(operand_shapes),
                    [](HloInstruction* instr) { return &instr->shape(); });

  TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape(
                                              operand_shapes, dimension));
  return computation->AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, operands, dimension),
      metadata);
}

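// Converts `hlo` to the given element type, or returns it unchanged if it
// already has that type.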
HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type,
                                 const OpMetadata* metadata) {
  if (hlo->shape().element_type() == type) {
    return hlo;
  }
  Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
  hlo = hlo->parent()->AddInstruction(HloInstruction::CreateConvert(shape, hlo),
                                      metadata);
  CHECK_EQ(hlo->shape().element_type(), type);
  return hlo;
}

HloInstruction* MakeBitcastHlo(HloInstruction* hlo, const Shape& shape,
                               const OpMetadata* metadata) {
  return hlo->parent()->AddInstruction(
      HloInstruction::CreateBitcast(shape, hlo), metadata);
}

HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo, PrimitiveType type,
                                        const OpMetadata* metadata) {
  if (hlo->shape().element_type() == type) {
    return hlo;
  }
  Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
  // PRED is stored as one byte but has a bit width of 1, so avoid the
  // mismatch by using a convert instead of a bitcast-convert.
  if (type == PRED || hlo->shape().element_type() == PRED) {
    return MakeConvertToHlo(hlo, type);
  }
  hlo = hlo->parent()->AddInstruction(
      HloInstruction::CreateBitcastConvert(shape, hlo), metadata);
  CHECK_EQ(hlo->shape().element_type(), type);
  return hlo;
}

HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
                            int64_t iota_dimension) {
  return computation->AddInstruction(
      HloInstruction::CreateIota(shape, iota_dimension));
}

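// Creates a dot instruction over `lhs` and `rhs` with the given contracting
// and batch dimension numbers.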
StatusOr<HloInstruction*> MakeDotHlo(
    HloInstruction* lhs, HloInstruction* rhs,
    const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config,
    std::optional<PrimitiveType> preferred_element_type,
    const OpMetadata* metadata) {
  HloComputation* computation = lhs->parent();
  CHECK_EQ(computation, rhs->parent());
  TF_ASSIGN_OR_RETURN(
      Shape dot_shape,
      ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers,
                                      preferred_element_type));
  return computation->AddInstruction(
      HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers,
                                precision_config),
      metadata);
}

StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
                                     HloComputation* map_computation,
                                     const OpMetadata* metadata) {
  CHECK(!operands.empty()) << "Map Hlo requires at least one operand.";
  HloComputation* computation = operands.front()->parent();
  std::vector<const Shape*> operand_shapes;
  int64_t max_operand_rank = 0;
  for (const HloInstruction* operand : operands) {
    CHECK_EQ(computation, operand->parent());
    operand_shapes.push_back(&operand->shape());
    max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
  }
  std::vector<int64_t> map_dims(max_operand_rank);
  std::iota(map_dims.begin(), map_dims.end(), 0);
  TF_ASSIGN_OR_RETURN(
      Shape map_shape,
      ShapeInference::InferMapShape(
          operand_shapes, map_computation->ComputeProgramShape(), map_dims));
  return computation->AddInstruction(
      HloInstruction::CreateMap(map_shape, operands, map_computation),
      metadata);
}

HloInstruction* MakeReducePrecisionHlo(HloInstruction* operand,
                                       int exponent_bits, int mantissa_bits,
                                       const OpMetadata* metadata) {
  return operand->parent()->AddInstruction(
      HloInstruction::CreateReducePrecision(operand->shape(), operand,
                                            exponent_bits, mantissa_bits),
      metadata);
}

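// Reduces `operand` over `dimensions` using `reduce_computation`; the reduced
// dimensions are deleted from the result shape (e.g. reducing f32[2,3,4] over
// {1} yields f32[2,4]).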
StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        absl::Span<const int64_t> dimensions,
                                        HloComputation* reduce_computation,
                                        const OpMetadata* metadata) {
  auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
  auto result_shape = ShapeUtil::DeleteDimensions(dimensions, operand->shape());

  return operand->parent()->AddInstruction(
      HloInstruction::CreateReduce(result_shape, operand, init_value,
                                   dimensions, reduce_computation),
      metadata);
}

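// Convenience overload that synthesizes a scalar `binary_opcode`
// sub-computation to reduce with.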
StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        absl::Span<const int64_t> dimensions,
                                        HloOpcode binary_opcode,
                                        const OpMetadata* metadata) {
  auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
  HloComputation* reduce_computation;
  {
    HloComputation::Builder b(operand->name() + ".reduce_sub_computation");
    auto lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
    auto rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
    b.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs));
    reduce_computation =
        operand->parent()->parent()->AddEmbeddedComputation(b.Build());
  }
  return MakeReduceHlo(operand, init_value, dimensions, reduce_computation,
                       metadata);
}

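// Reduces `operand` across all of its dimensions with `binary_opcode`,
// producing a scalar.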
StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        HloOpcode binary_opcode,
                                        HloModule* module,
                                        const OpMetadata* metadata) {
  DCHECK_NE(nullptr, module);
  std::vector<int64_t> all_dims(operand->shape().rank());
  std::iota(all_dims.begin(), all_dims.end(), 0);

  auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
  HloComputation* reduce_computation;
  {
    HloComputation::Builder b(operand->name() + ".reduce_sub_computation");
    auto lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
    auto rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
    b.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs));
    reduce_computation = module->AddEmbeddedComputation(b.Build());
  }
  return MakeReduceHlo(operand, init_value, all_dims, reduce_computation,
                       metadata);
}

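// Variadic reduce over several (operand, init_value) pairs; with more than
// one operand, the result and the root of `reduce_computation` are tuples.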
StatusOr<HloInstruction*> MakeReduceHlo(
    absl::Span<HloInstruction* const> operands,
    absl::Span<HloInstruction* const> init_values,
    absl::Span<const int64_t> dimensions, HloComputation* reduce_computation,
    const OpMetadata* metadata) {
  CHECK(!operands.empty());
  CHECK_EQ(operands.size(), init_values.size());
  auto root = reduce_computation->root_instruction();
  if (root->shape().IsTuple()) {
    CHECK_EQ(root->shape().tuple_shapes_size(), operands.size());
  } else {
    CHECK_EQ(operands.size(), 1);
  }

  std::vector<Shape> expected_shapes;
  for (auto operand : operands) {
    expected_shapes.push_back(ShapeUtil::FilterDimensions(
        [&](const int64_t dim) {
          return !absl::c_linear_search(dimensions, dim);
        },
        operand->shape()));
  }

  auto output_shape = ShapeUtil::MakeMaybeTupleShape(expected_shapes);

  return operands[0]->parent()->AddInstruction(
      HloInstruction::CreateReduce(output_shape, operands, init_values,
                                   dimensions, reduce_computation),
      metadata);
}

StatusOr<HloInstruction*> MakeReverseHlo(HloInstruction* operand,
                                         absl::Span<const int64_t> dimensions,
                                         const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();
  TF_ASSIGN_OR_RETURN(Shape reverse_shape, ShapeInference::InferReverseShape(
                                               operand->shape(), dimensions));
  return computation->AddInstruction(
      HloInstruction::CreateReverse(reverse_shape, operand, dimensions),
      metadata);
}

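// Creates a kSelect, broadcasting a scalar `pred` to the operand shape when
// needed; `derived_from`, if non-null, propagates its metadata to the new
// instructions.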
StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
                                        HloInstruction* on_true,
                                        HloInstruction* on_false,
                                        HloInstruction* derived_from) {
  HloComputation* computation = pred->parent();
  DCHECK_EQ(computation, on_true->parent());
  DCHECK_EQ(computation, on_false->parent());
  Shape op_shape = on_true->shape();
  if (ShapeUtil::IsScalar(pred->shape())) {
    if (!ShapeUtil::IsScalar(op_shape) && !op_shape.IsTuple()) {
      // If the output is not scalar, we need to broadcast the condition
      // to match the contract of kSelect.
      pred = computation->AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::ChangeElementType(op_shape, PrimitiveType::PRED), pred,
          {}));
      if (derived_from) {
        derived_from->SetupDerivedInstruction(pred);
      }
    }
  }
  TF_RET_CHECK(!op_shape.IsTuple());
  HloOpcode select_op_code = HloOpcode::kSelect;
  TF_ASSIGN_OR_RETURN(Shape select_shape,
                      ShapeInference::InferTernaryOpShape(select_op_code, pred,
                                                          on_true, on_false));
  HloInstruction* select =
      computation->AddInstruction(HloInstruction::CreateTernary(
          select_shape, select_op_code, pred, on_true, on_false));
  if (derived_from) {
    derived_from->SetupDerivedInstruction(select);
  }
  return select;
}

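// Returns the sole operand as-is, or wraps multiple operands in a kTuple.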
HloInstruction* MaybeMakeTuple(absl::Span<HloInstruction* const> operands) {
  CHECK(!operands.empty());
  if (operands.size() == 1) {
    return operands[0];
  }
  return operands[0]->parent()->AddInstruction(
      HloInstruction::CreateTuple(operands));
}

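// Creates a sort along `dimension_to_sort`, using a scalar less-than
// comparator built with XlaBuilder and deep-cloned into `module`.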
StatusOr<HloInstruction*> MakeSortHlo(
    const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
    int64_t dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
    HloModule* module, const OpMetadata* metadata) {
  CHECK(!operands.empty()) << "Sort Hlo requires at least one operand.";
  HloComputation* compare_computation;
  XlaBuilder b("Sort.Compare");
  if (metadata != nullptr) {
    b.SetOpMetadata(*metadata);
  }
  std::vector<PrimitiveType> operand_types(operands.size());
  for (int64_t i = 0; i < operands.size(); ++i) {
    operand_types[i] = operands[i]->shape().element_type();
  }
  XlaComputation comparator = CreateScalarLtComputation(operand_types, &b);
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
  HloModuleConfig config(program_shape);
  TF_ASSIGN_OR_RETURN(auto new_module,
                      HloModule::CreateFromProto(comparator.proto(), config));
  HloCloneContext context(module);
  compare_computation =
      module->DeepCloneComputation(new_module->entry_computation(), &context);
  return builder->AddInstruction(HloInstruction::CreateSort(
      sort_shape, dimension_to_sort, operands, compare_computation, is_stable));
}

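// Collapses the first n dimensions of `operand` into one via a reshape,
// e.g. n=2 turns f32[2,3,5] into f32[6,5].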
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand,
                                             int64_t n) {
  CHECK_GT(n, 0);

  const Shape& operand_shape = operand->shape();
  CHECK_GE(operand_shape.dimensions_size(), n);
  int64_t new_shape_leading_bound = 1;
  for (int64_t i = 0; i < n; i++) {
    new_shape_leading_bound *= operand_shape.dimensions(i);
  }

  std::vector<int64_t> new_shape_dims;
  new_shape_dims.reserve(operand_shape.dimensions_size() - n + 1);
  new_shape_dims.push_back(new_shape_leading_bound);

  std::copy(operand_shape.dimensions().begin() + n,
            operand_shape.dimensions().end(),
            std::back_inserter(new_shape_dims));

  Shape output_shape =
      ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims);

  return MakeReshapeHlo(output_shape, operand);
}

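// Prepends n degenerate (size-1) dimensions, e.g. n=2 turns f32[2,3] into
// f32[1,1,2,3].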
StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
                                                int64_t n) {
  CHECK_GT(n, 0);
  std::vector<int64_t> new_shape_dims;
  const Shape& operand_shape = operand->shape();
  new_shape_dims.reserve(n + operand_shape.dimensions_size());
  new_shape_dims.insert(new_shape_dims.begin(), n, 1);
  absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims));
  return MakeReshapeHlo(new_shape_dims, operand);
}

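// Splits the leading dimension into `expanded_dims`, whose product must equal
// the leading bound, e.g. expanding f32[6,5] with {2,3} gives f32[2,3,5].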
StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
    HloInstruction* operand, absl::Span<const int64_t> expanded_dims) {
  CHECK_GT(operand->shape().dimensions_size(), 0);
  CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims));

  std::vector<int64_t> expanded_shape_dim_bounds;
  expanded_shape_dim_bounds.reserve(expanded_dims.size() +
                                    operand->shape().dimensions_size() - 1);
  absl::c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
  std::copy(operand->shape().dimensions().begin() + 1,
            operand->shape().dimensions().end(),
            std::back_inserter(expanded_shape_dim_bounds));
  Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                         expanded_shape_dim_bounds);
  return MakeReshapeHlo(new_shape, operand);
}

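// Drops the size-1 dimensions listed in `dims_to_elide` via a reshape.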
StatusOr<HloInstruction*> ElideDegenerateDims(
    HloInstruction* operand, absl::Span<const int64_t> dims_to_elide) {
  return MakeReshapeHlo(ShapeUtil::FilterDimensions(
                            [&](int64_t dim) {
                              return !absl::c_linear_search(dims_to_elide, dim);
                            },
                            operand->shape()),
                        operand);
}

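// Inverse of ElideDegenerateDims: inserts size-1 dimensions at the (sorted)
// output positions listed in `dims_to_insert`.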
StatusOr<HloInstruction*> InsertDegenerateDims(
    HloInstruction* operand, absl::Span<const int64_t> dims_to_insert) {
  CHECK(absl::c_is_sorted(dims_to_insert));

  const Shape& operand_shape = operand->shape();
  int64_t output_shape_rank =
      operand_shape.dimensions_size() + dims_to_insert.size();
  for (auto dim_to_insert : dims_to_insert) {
    CHECK_LT(dim_to_insert, output_shape_rank);
  }

  std::vector<int64_t> output_shape_dim_bounds;
  output_shape_dim_bounds.reserve(output_shape_rank);
  int64_t operand_dims_idx = 0;
  int64_t dims_to_insert_idx = 0;
  for (int64_t i = 0; i < output_shape_rank; ++i) {
    if (dims_to_insert_idx < dims_to_insert.size() &&
        i == dims_to_insert[dims_to_insert_idx]) {
      output_shape_dim_bounds.push_back(1);
      ++dims_to_insert_idx;
    } else {
      output_shape_dim_bounds.push_back(
          operand_shape.dimensions(operand_dims_idx));
      ++operand_dims_idx;
    }
  }

  Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(),
                                            output_shape_dim_bounds);
  return MakeReshapeHlo(output_shape, operand);
}

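// Pads a rank-1 operand with the requested number of leading and trailing
// zeros.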
StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
                                             int64_t zeros_to_prepend,
                                             int64_t zeros_to_append) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(operand->shape().dimensions_size(), 1);
  PaddingConfig padding_config;
  PaddingConfig::PaddingConfigDimension padding_config_dim;
  padding_config_dim.set_edge_padding_low(zeros_to_prepend);
  padding_config_dim.set_edge_padding_high(zeros_to_append);
  *padding_config.add_dimensions() = padding_config_dim;

  HloInstruction* zero =
      computation->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(operand->shape().element_type())));
  return MakePadHlo(operand, zero, padding_config);
}

HloInstruction* BroadcastZeros(HloComputation* computation,
                               PrimitiveType element_type,
                               absl::Span<const int64_t> broadcast_dimensions) {
  HloInstruction* zero = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
  return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
                          /*result_shape_bounds=*/broadcast_dimensions);
}

HloInstruction* BroadcastOnes(HloComputation* computation,
                              PrimitiveType element_type,
                              absl::Span<const int64_t> broadcast_dimensions) {
  HloInstruction* one = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::One(element_type)));
  return MakeBroadcastHlo(one, /*broadcast_dimensions=*/{},
                          /*result_shape_bounds=*/broadcast_dimensions);
}

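// Wraps `fused` in a fusion instruction of the given kind and replaces it in
// its computation.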
StatusOr<HloInstruction*> MakeFusionInstruction(
    HloInstruction* fused, HloInstruction::FusionKind kind) {
  HloComputation* comp = fused->parent();
  HloInstruction* fusion_instruction = comp->AddInstruction(
      HloInstruction::CreateFusion(fused->shape(), kind, fused));
  TF_RETURN_IF_ERROR(comp->ReplaceInstruction(fused, fusion_instruction));
  return fusion_instruction;
}

// Recursively creates a dummy op given a shape. Leaf nodes are broadcast
// zeros, while internal nodes are tuples.
HloInstruction* CreateDummyOp(HloComputation::Builder* b, const Shape& shape) {
  if (shape.IsArray()) {
    auto zero = b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::Zero(shape.element_type())));
    return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
  }
  CHECK(shape.IsTuple());
  std::vector<HloInstruction*> sub_instructions;
  for (const Shape& subshape : shape.tuple_shapes()) {
    sub_instructions.push_back(CreateDummyOp(b, subshape));
  }
  return b->AddInstruction(HloInstruction::CreateTuple(sub_instructions));
}

StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
    absl::Span<const Shape* const> domain, const Shape& range,
    absl::string_view name) {
  HloComputation::Builder b{std::string(name)};
  int64_t param_idx = 0;
  for (const Shape* param_shape : domain) {
    b.AddInstruction(HloInstruction::CreateParameter(
        param_idx, *param_shape, StrCat("param.", param_idx)));
    param_idx++;
  }

  // We can't change the root type of a computation once it is created, so we
  // create a dummy root instruction to give the computation the right root
  // shape. Use a (recursive) broadcast here to avoid creating large constants.
  CreateDummyOp(&b, range);
  return b.Build();
}

}  // namespace xla