/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
using absl::StrCat;

StatusOr<HloInstruction*> MakeUnaryHlo(HloOpcode opcode,
                                       HloInstruction* operand,
                                       const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();
  TF_ASSIGN_OR_RETURN(Shape unary_op_shape,
                      ShapeInference::InferUnaryOpShape(opcode, operand));
  return computation->AddInstruction(
      HloInstruction::CreateUnary(unary_op_shape, opcode, operand), metadata);
}
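// Usage sketch (illustrative, not from this file): negate an instruction `x`
// already in a computation. `x` is hypothetical; assumes the header declares
// `metadata` with a nullptr default.
//   TF_ASSIGN_OR_RETURN(HloInstruction* neg,
//                       MakeUnaryHlo(HloOpcode::kNegate, x));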

HloInstruction* MakeCopyHlo(HloInstruction* from, const Shape& to) {
  return from->AddInstruction(
      HloInstruction::CreateUnary(to, HloOpcode::kCopy, from));
}

StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
                                        HloInstruction* rhs,
                                        const OpMetadata* metadata) {
  HloComputation* computation = lhs->parent();
  CHECK_EQ(computation, rhs->parent());
  TF_ASSIGN_OR_RETURN(Shape binary_op_shape,
                      ShapeInference::InferBinaryOpShape(opcode, lhs, rhs));
  return computation->AddInstruction(
      HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs),
      metadata);
}
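// Usage sketch (illustrative): elementwise a + b for hypothetical
// instructions `a` and `b` in the same computation; assumes `metadata`
// defaults to nullptr in the header.
//   TF_ASSIGN_OR_RETURN(HloInstruction* sum,
//                       MakeBinaryHlo(HloOpcode::kAdd, a, b));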

StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction,
                                         HloInstruction* lhs,
                                         HloInstruction* rhs,
                                         const OpMetadata* metadata) {
  HloComputation* computation = lhs->parent();
  CHECK_EQ(computation, rhs->parent());
  TF_ASSIGN_OR_RETURN(
      Shape binary_op_shape,
      ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs));
  return computation->AddInstruction(
      HloInstruction::CreateCompare(binary_op_shape, lhs, rhs, direction),
      metadata);
}
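// Usage sketch (illustrative): a PRED-shaped a < b comparison of hypothetical
// instructions `a` and `b`; assumes `metadata` defaults to nullptr.
//   TF_ASSIGN_OR_RETURN(HloInstruction* lt,
//                       MakeCompareHlo(ComparisonDirection::kLt, a, b));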

StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
                                     HloInstruction* padding_value,
                                     const PaddingConfig& padding_config,
                                     const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(computation, padding_value->parent());
  TF_ASSIGN_OR_RETURN(
      Shape pad_shape,
      ShapeInference::InferPadShape(operand->shape(), padding_value->shape(),
                                    padding_config));
  return computation->AddInstruction(
      HloInstruction::CreatePad(pad_shape, operand, padding_value,
                                padding_config),
      metadata);
}
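// Usage sketch (illustrative): prepend two zeros to a hypothetical rank-1
// instruction `v`, with `zero` a scalar constant of v's element type (see
// PadVectorWithZeros below for the in-file pattern this mirrors).
//   PaddingConfig config;
//   config.add_dimensions()->set_edge_padding_low(2);
//   TF_ASSIGN_OR_RETURN(HloInstruction* padded, MakePadHlo(v, zero, config));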

StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
                                       absl::Span<const int64_t> start_indices,
                                       absl::Span<const int64_t> limit_indices,
                                       absl::Span<const int64_t> strides,
                                       const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();
  TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape(
                                             operand->shape(), start_indices,
                                             limit_indices, strides));
  return computation->AddInstruction(
      HloInstruction::CreateSlice(slice_shape, operand, start_indices,
                                  limit_indices, strides),
      metadata);
}
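// Usage sketch (illustrative): elements [2, 5) of a hypothetical rank-1
// instruction `v`, stride 1; assumes `metadata` defaults to nullptr.
//   TF_ASSIGN_OR_RETURN(HloInstruction* s, MakeSliceHlo(v, {2}, {5}, {1}));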

StatusOr<HloInstruction*> MakeConvolveHlo(
    HloInstruction* lhs, HloInstruction* rhs, int64_t feature_group_count,
    int64_t batch_group_count, const Window& window,
    const ConvolutionDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config,
    std::optional<PrimitiveType> preferred_element_type,
    const OpMetadata* metadata) {
  HloComputation* computation = lhs->parent();
  CHECK_EQ(computation, rhs->parent());
  TF_ASSIGN_OR_RETURN(
      Shape convolve_shape,
      ShapeInference::InferConvolveShape(
          lhs->shape(), rhs->shape(), feature_group_count, batch_group_count,
          window, dimension_numbers, preferred_element_type));
  return computation->AddInstruction(
      HloInstruction::CreateConvolve(
          convolve_shape, lhs, rhs, feature_group_count, batch_group_count,
          window, dimension_numbers, precision_config),
      metadata);
}

StatusOr<HloInstruction*> MakeTransposeHlo(
    HloInstruction* operand, absl::Span<const int64_t> dimensions) {
  TF_ASSIGN_OR_RETURN(
      Shape transpose_shape,
      ShapeInference::InferTransposeShape(operand->shape(), dimensions));
  return operand->AddInstruction(
      HloInstruction::CreateTranspose(transpose_shape, operand, dimensions));
}
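// Usage sketch (illustrative): swap the dimensions of a hypothetical rank-2
// instruction `m`.
//   TF_ASSIGN_OR_RETURN(HloInstruction* mt, MakeTransposeHlo(m, {1, 0}));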

StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
                                         HloInstruction* operand) {
  return operand->AddInstruction(
      HloInstruction::CreateReshape(result_shape, operand));
}

StatusOr<HloInstruction*> MakeReshapeHlo(
    absl::Span<const int64_t> result_shape_dim_bounds,
    HloInstruction* operand) {
  Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                         result_shape_dim_bounds);
  return MakeReshapeHlo(new_shape, operand);
}
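// Usage sketch (illustrative): reshape a hypothetical [4, 6] instruction `m`
// to [2, 12]; the element count must match.
//   TF_ASSIGN_OR_RETURN(HloInstruction* r, MakeReshapeHlo({2, 12}, m));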

StatusOr<HloInstruction*> MakeDynamicSliceHlo(
    HloInstruction* operand, absl::Span<HloInstruction* const> start_indices,
    absl::Span<const int64_t> slice_sizes, const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();
  std::vector<Shape> scalar_start_indices_shapes(
      start_indices.size(),
      ShapeUtil::MakeShape(start_indices[0]->shape().element_type(), {}));
  TF_ASSIGN_OR_RETURN(
      Shape dynamic_slice_shape,
      ShapeInference::InferDynamicSliceShape(
          operand->shape(), scalar_start_indices_shapes, slice_sizes));
  return computation->AddInstruction(
      HloInstruction::CreateDynamicSlice(dynamic_slice_shape, operand,
                                         start_indices, slice_sizes),
      metadata);
}
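// Usage sketch (illustrative): a [1, 128] window of a hypothetical [8, 128]
// instruction `m`, starting at hypothetical scalar index instructions `row`
// and `zero`; assumes `metadata` defaults to nullptr.
//   TF_ASSIGN_OR_RETURN(HloInstruction* window,
//                       MakeDynamicSliceHlo(m, {row, zero}, {1, 128}));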

StatusOr<HloInstruction*> MakeDynamicSliceHlo(
    HloInstruction* operand, HloInstruction* start_indices,
    absl::Span<const int64_t> slice_sizes, const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(computation, start_indices->parent());
  int64_t rank = start_indices->shape().dimensions(0);
  std::vector<HloInstruction*> scalar_start_indices;
  for (int i = 0; i < rank; ++i) {
    // TODO(b/118437727): Update callers to provide scalars directly.
    auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
        ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
        start_indices, {i}, {i + 1}, {1}));
    scalar_start_indices.push_back(
        computation->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
            slice)));
  }
  std::vector<Shape> scalar_start_indices_shapes(
      rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
  TF_ASSIGN_OR_RETURN(
      Shape dynamic_slice_shape,
      ShapeInference::InferDynamicSliceShape(
          operand->shape(), scalar_start_indices_shapes, slice_sizes));
  return computation->AddInstruction(
      HloInstruction::CreateDynamicSlice(dynamic_slice_shape, operand,
                                         scalar_start_indices, slice_sizes),
      metadata);
}

StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
    HloInstruction* operand, HloInstruction* update,
    HloInstruction* start_indices, const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(computation, update->parent());
  CHECK_EQ(computation, start_indices->parent());
  int64_t rank = start_indices->shape().dimensions(0);
  std::vector<HloInstruction*> scalar_start_indices;
  for (int i = 0; i < rank; ++i) {
    // TODO(b/118437727): Update callers to provide scalars directly.
    auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
        ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
        start_indices, {i}, {i + 1}, {1}));
    scalar_start_indices.push_back(
        computation->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
            slice)));
  }
  std::vector<Shape> scalar_start_indices_shapes(
      rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
  TF_ASSIGN_OR_RETURN(
      Shape dynamic_update_slice_shape,
      ShapeInference::InferDynamicUpdateSliceShape(
          operand->shape(), update->shape(), scalar_start_indices_shapes));
  return computation->AddInstruction(
      HloInstruction::CreateDynamicUpdateSlice(
          dynamic_update_slice_shape, operand, update, scalar_start_indices),
      metadata);
}

HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
                                 absl::Span<const int64_t> broadcast_dimensions,
                                 absl::Span<const int64_t> result_shape_bounds,
                                 const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();
  Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                               result_shape_bounds);

  return computation->AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, operand,
                                      broadcast_dimensions),
      metadata);
}
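// Usage sketch (illustrative): broadcast a hypothetical scalar `s` to a
// [2, 3] array; scalars use an empty broadcast_dimensions list.
//   HloInstruction* b = MakeBroadcastHlo(s, /*broadcast_dimensions=*/{},
//                                        /*result_shape_bounds=*/{2, 3});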

HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
                                 absl::Span<const int64_t> broadcast_dimensions,
                                 const Shape& shape,
                                 const OpMetadata* metadata) {
  return MakeBroadcastHlo(operand, broadcast_dimensions, shape.dimensions(),
                          metadata);
}

StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
                                                 int64_t index,
                                                 const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();

  TF_ASSIGN_OR_RETURN(
      Shape gte_shape,
      ShapeInference::InferGetTupleElementShape(operand->shape(), index));
  return computation->AddInstruction(
      HloInstruction::CreateGetTupleElement(gte_shape, operand, index),
      metadata);
}

StatusOr<HloInstruction*> MakeConcatHlo(
    absl::Span<HloInstruction* const> operands, int64_t dimension,
    const OpMetadata* metadata) {
  CHECK_GT(operands.size(), 0);

  HloComputation* computation = operands[0]->parent();
  CHECK(absl::c_all_of(operands, [&](HloInstruction* instr) {
    return instr->parent() == computation;
  }));

  std::vector<const Shape*> operand_shapes;
  absl::c_transform(operands, std::back_inserter(operand_shapes),
                    [](HloInstruction* instr) { return &instr->shape(); });

  TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape(
                                              operand_shapes, dimension));
  return computation->AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, operands, dimension),
      metadata);
}
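// Usage sketch (illustrative): concatenate hypothetical [2, 3] instructions
// `a` and `b` along dimension 0 into a [4, 3] result; assumes `metadata`
// defaults to nullptr.
//   TF_ASSIGN_OR_RETURN(HloInstruction* cat, MakeConcatHlo({a, b}, 0));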

HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type,
                                 const OpMetadata* metadata) {
  if (hlo->shape().element_type() == type) {
    return hlo;
  }
  Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
  hlo = hlo->parent()->AddInstruction(HloInstruction::CreateConvert(shape, hlo),
                                      metadata);
  CHECK_EQ(hlo->shape().element_type(), type);
  return hlo;
}
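// Usage sketch (illustrative): convert a hypothetical F32 instruction `x` to
// BF16; returns `x` unchanged if it is already BF16.
//   HloInstruction* bf16_x = MakeConvertToHlo(x, BF16);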

HloInstruction* MakeBitcastHlo(HloInstruction* hlo, const Shape& shape,
                               const OpMetadata* metadata) {
  return hlo->parent()->AddInstruction(
      HloInstruction::CreateBitcast(shape, hlo), metadata);
}

HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo, PrimitiveType type,
                                        const OpMetadata* metadata) {
  if (hlo->shape().element_type() == type) {
    return hlo;
  }
  Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
  // PRED is stored as one byte but has a BitWidth of 1, so a bitcast convert
  // to or from PRED would be ill-defined; use a convert instead.
  if (type == PRED || hlo->shape().element_type() == PRED) {
    return MakeConvertToHlo(hlo, type);
  }
  hlo = hlo->parent()->AddInstruction(
      HloInstruction::CreateBitcastConvert(shape, hlo), metadata);
  CHECK_EQ(hlo->shape().element_type(), type);
  return hlo;
}

HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
                            int64_t iota_dimension) {
  return computation->AddInstruction(
      HloInstruction::CreateIota(shape, iota_dimension));
}

StatusOr<HloInstruction*> MakeDotHlo(
    HloInstruction* lhs, HloInstruction* rhs,
    const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config,
    std::optional<PrimitiveType> preferred_element_type,
    const OpMetadata* metadata) {
  HloComputation* computation = lhs->parent();
  CHECK_EQ(computation, rhs->parent());
  TF_ASSIGN_OR_RETURN(
      Shape dot_shape,
      ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers,
                                      preferred_element_type));
  return computation->AddInstruction(
      HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers,
                                precision_config),
      metadata);
}
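// Usage sketch (illustrative): a plain [M, K] x [K, N] matmul of hypothetical
// instructions `a` and `b`; a default-constructed PrecisionConfig is assumed
// to be acceptable here, and `metadata` to default to nullptr.
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_contracting_dimensions(1);
//   dnums.add_rhs_contracting_dimensions(0);
//   TF_ASSIGN_OR_RETURN(HloInstruction* dot,
//                       MakeDotHlo(a, b, dnums, PrecisionConfig(),
//                                  /*preferred_element_type=*/std::nullopt));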

StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
                                     HloComputation* map_computation,
                                     const OpMetadata* metadata) {
  CHECK(!operands.empty()) << "Map Hlo requires at least one operand.";
  HloComputation* computation = operands.front()->parent();
  std::vector<const Shape*> operand_shapes;
  int64_t max_operand_rank = 0;
  for (const HloInstruction* operand : operands) {
    CHECK_EQ(computation, operand->parent());
    operand_shapes.push_back(&operand->shape());
    max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
  }
  std::vector<int64_t> map_dims(max_operand_rank);
  std::iota(map_dims.begin(), map_dims.end(), 0);
  TF_ASSIGN_OR_RETURN(
      Shape map_shape,
      ShapeInference::InferMapShape(
          operand_shapes, map_computation->ComputeProgramShape(), map_dims));
  return computation->AddInstruction(
      HloInstruction::CreateMap(map_shape, operands, map_computation),
      metadata);
}

HloInstruction* MakeReducePrecisionHlo(HloInstruction* operand,
                                       int exponent_bits, int mantissa_bits,
                                       const OpMetadata* metadata) {
  return operand->parent()->AddInstruction(
      HloInstruction::CreateReducePrecision(operand->shape(), operand,
                                            exponent_bits, mantissa_bits),
      metadata);
}

StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        absl::Span<const int64_t> dimensions,
                                        HloComputation* reduce_computation,
                                        const OpMetadata* metadata) {
  auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
  auto result_shape = ShapeUtil::DeleteDimensions(dimensions, operand->shape());

  return operand->parent()->AddInstruction(
      HloInstruction::CreateReduce(result_shape, operand, init_value,
                                   dimensions, reduce_computation),
      metadata);
}

StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        absl::Span<const int64_t> dimensions,
                                        HloOpcode binary_opcode,
                                        const OpMetadata* metadata) {
  auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
  HloComputation* reduce_computation;
  {
    HloComputation::Builder b(operand->name() + ".reduce_sub_computation");
    auto lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
    auto rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
    b.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs));
    reduce_computation =
        operand->parent()->parent()->AddEmbeddedComputation(b.Build());
  }
  return MakeReduceHlo(operand, init_value, dimensions, reduce_computation,
                       metadata);
}
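// Usage sketch (illustrative): row sums of a hypothetical [2, 3] F32
// instruction `m`, reducing dimension 1 with a hypothetical scalar F32 `zero`
// init value; assumes `metadata` defaults to nullptr.
//   TF_ASSIGN_OR_RETURN(HloInstruction* row_sums,
//                       MakeReduceHlo(m, zero, {1}, HloOpcode::kAdd));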

StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        HloOpcode binary_opcode,
                                        HloModule* module,
                                        const OpMetadata* metadata) {
  DCHECK_NE(nullptr, module);
  std::vector<int64_t> all_dims(operand->shape().rank());
  std::iota(all_dims.begin(), all_dims.end(), 0);

  auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
  HloComputation* reduce_computation;
  {
    HloComputation::Builder b(operand->name() + ".reduce_sub_computation");
    auto lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
    auto rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
    b.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs));
    reduce_computation = module->AddEmbeddedComputation(b.Build());
  }
  return MakeReduceHlo(operand, init_value, all_dims, reduce_computation,
                       metadata);
}

StatusOr<HloInstruction*> MakeReduceHlo(
    absl::Span<HloInstruction* const> operands,
    absl::Span<HloInstruction* const> init_values,
    absl::Span<const int64_t> dimensions, HloComputation* reduce_computation,
    const OpMetadata* metadata) {
  CHECK(!operands.empty());
  CHECK_EQ(operands.size(), init_values.size());
  auto root = reduce_computation->root_instruction();
  if (root->shape().IsTuple()) {
    CHECK_EQ(root->shape().tuple_shapes_size(), operands.size());
  } else {
    CHECK_EQ(operands.size(), 1);
  }

  std::vector<Shape> expected_shapes;
  for (auto operand : operands) {
    expected_shapes.push_back(ShapeUtil::FilterDimensions(
        [&](const int64_t dim) {
          return !absl::c_linear_search(dimensions, dim);
        },
        operand->shape()));
  }

  auto output_shape = ShapeUtil::MakeMaybeTupleShape(expected_shapes);

  return operands[0]->parent()->AddInstruction(
      HloInstruction::CreateReduce(output_shape, operands, init_values,
                                   dimensions, reduce_computation),
      metadata);
}

StatusOr<HloInstruction*> MakeReverseHlo(HloInstruction* operand,
                                         absl::Span<const int64_t> dimensions,
                                         const OpMetadata* metadata) {
  HloComputation* computation = operand->parent();
  TF_ASSIGN_OR_RETURN(Shape reverse_shape, ShapeInference::InferReverseShape(
                                               operand->shape(), dimensions));
  return computation->AddInstruction(
      HloInstruction::CreateReverse(reverse_shape, operand, dimensions),
      metadata);
}

StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
                                        HloInstruction* on_true,
                                        HloInstruction* on_false,
                                        HloInstruction* derived_from) {
  HloComputation* computation = pred->parent();
  DCHECK_EQ(computation, on_true->parent());
  DCHECK_EQ(computation, on_false->parent());
  Shape op_shape = on_true->shape();
  if (ShapeUtil::IsScalar(pred->shape())) {
    if (!ShapeUtil::IsScalar(op_shape) && !op_shape.IsTuple()) {
      // If the output is not scalar, we need to broadcast the condition
      // to match the contract of kSelect.
      pred = computation->AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::ChangeElementType(op_shape, PrimitiveType::PRED), pred,
          {}));
      if (derived_from) {
        derived_from->SetupDerivedInstruction(pred);
      }
    }
  }
  TF_RET_CHECK(!op_shape.IsTuple());
  HloOpcode select_op_code = HloOpcode::kSelect;
  TF_ASSIGN_OR_RETURN(Shape select_shape,
                      ShapeInference::InferTernaryOpShape(select_op_code, pred,
                                                          on_true, on_false));
  HloInstruction* select =
      computation->AddInstruction(HloInstruction::CreateTernary(
          select_shape, select_op_code, pred, on_true, on_false));
  if (derived_from) {
    derived_from->SetupDerivedInstruction(select);
  }
  return select;
}
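// Usage sketch (illustrative): elementwise select between hypothetical
// same-shaped instructions `t` and `f` under a scalar PRED `cond`; the
// scalar predicate is broadcast internally as described above.
//   TF_ASSIGN_OR_RETURN(HloInstruction* sel,
//                       MakeSelectHlo(cond, t, f, /*derived_from=*/nullptr));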

HloInstruction* MaybeMakeTuple(absl::Span<HloInstruction* const> operands) {
  CHECK(!operands.empty());
  if (operands.size() == 1) {
    return operands[0];
  }
  return operands[0]->parent()->AddInstruction(
      HloInstruction::CreateTuple(operands));
}

StatusOr<HloInstruction*> MakeSortHlo(
    const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
    int64_t dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
    HloModule* module, const OpMetadata* metadata) {
  CHECK(!operands.empty()) << "Sort Hlo requires at least one operand.";
  HloComputation* compare_computation;
  XlaBuilder b("Sort.Compare");
  if (metadata != nullptr) {
    b.SetOpMetadata(*metadata);
  }
  std::vector<PrimitiveType> operand_types(operands.size());
  for (int64_t i = 0; i < operands.size(); ++i) {
    operand_types[i] = operands[i]->shape().element_type();
  }
  XlaComputation comparator = CreateScalarLtComputation(operand_types, &b);
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
  HloModuleConfig config(program_shape);
  TF_ASSIGN_OR_RETURN(auto new_module,
                      HloModule::CreateFromProto(comparator.proto(), config));
  HloCloneContext context(module);
  compare_computation =
      module->DeepCloneComputation(new_module->entry_computation(), &context);
  return builder->AddInstruction(HloInstruction::CreateSort(
      sort_shape, dimension_to_sort, operands, compare_computation, is_stable));
}

StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand,
                                             int64_t n) {
  CHECK_GT(n, 0);

  const Shape& operand_shape = operand->shape();
  CHECK_GE(operand_shape.dimensions_size(), n);
  int64_t new_shape_leading_bound = 1;
  for (int64_t i = 0; i < n; i++) {
    new_shape_leading_bound *= operand_shape.dimensions(i);
  }

  std::vector<int64_t> new_shape_dims;
  new_shape_dims.reserve(operand_shape.dimensions_size() - n + 1);
  new_shape_dims.push_back(new_shape_leading_bound);

  std::copy(operand_shape.dimensions().begin() + n,
            operand_shape.dimensions().end(),
            std::back_inserter(new_shape_dims));

  Shape output_shape =
      ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims);

  return MakeReshapeHlo(output_shape, operand);
}
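// Usage sketch (illustrative): collapsing the first two dimensions of a
// hypothetical [2, 3, 4] instruction `t` yields a [6, 4] reshape.
//   TF_ASSIGN_OR_RETURN(HloInstruction* collapsed, CollapseFirstNDims(t, 2));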

StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
                                                int64_t n) {
  CHECK_GT(n, 0);
  std::vector<int64_t> new_shape_dims;
  const Shape& operand_shape = operand->shape();
  new_shape_dims.reserve(n + operand_shape.dimensions_size());
  new_shape_dims.insert(new_shape_dims.begin(), n, 1);
  absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims));
  return MakeReshapeHlo(new_shape_dims, operand);
}

StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
    HloInstruction* operand, absl::Span<const int64_t> expanded_dims) {
  CHECK_GT(operand->shape().dimensions_size(), 0);
  CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims));

  std::vector<int64_t> expanded_shape_dim_bounds;
  expanded_shape_dim_bounds.reserve(expanded_dims.size() +
                                    operand->shape().dimensions_size() - 1);
  absl::c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
  std::copy(operand->shape().dimensions().begin() + 1,
            operand->shape().dimensions().end(),
            std::back_inserter(expanded_shape_dim_bounds));
  Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                         expanded_shape_dim_bounds);
  return MakeReshapeHlo(new_shape, operand);
}

StatusOr<HloInstruction*> ElideDegenerateDims(
    HloInstruction* operand, absl::Span<const int64_t> dims_to_elide) {
  return MakeReshapeHlo(ShapeUtil::FilterDimensions(
                            [&](int64_t dim) {
                              return !absl::c_linear_search(dims_to_elide, dim);
                            },
                            operand->shape()),
                        operand);
}

StatusOr<HloInstruction*> InsertDegenerateDims(
    HloInstruction* operand, absl::Span<const int64_t> dims_to_insert) {
  CHECK(absl::c_is_sorted(dims_to_insert));

  const Shape& operand_shape = operand->shape();
  int64_t output_shape_rank =
      operand_shape.dimensions_size() + dims_to_insert.size();
  for (auto dim_to_insert : dims_to_insert) {
    CHECK_LT(dim_to_insert, output_shape_rank);
  }

  std::vector<int64_t> output_shape_dim_bounds;
  output_shape_dim_bounds.reserve(output_shape_rank);
  int64_t operand_dims_idx = 0;
  int64_t dims_to_insert_idx = 0;
  for (int64_t i = 0; i < output_shape_rank; ++i) {
    if (dims_to_insert_idx < dims_to_insert.size() &&
        i == dims_to_insert[dims_to_insert_idx]) {
      output_shape_dim_bounds.push_back(1);
      ++dims_to_insert_idx;
    } else {
      output_shape_dim_bounds.push_back(
          operand_shape.dimensions(operand_dims_idx));
      ++operand_dims_idx;
    }
  }

  Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(),
                                            output_shape_dim_bounds);
  return MakeReshapeHlo(output_shape, operand);
}
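// Usage sketch (illustrative): inserting degenerate dimensions 0 and 2 into
// a hypothetical [5, 7] instruction `t` yields a [1, 5, 1, 7] reshape.
//   TF_ASSIGN_OR_RETURN(HloInstruction* expanded,
//                       InsertDegenerateDims(t, {0, 2}));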

StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
                                             int64_t zeros_to_prepend,
                                             int64_t zeros_to_append) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(operand->shape().dimensions_size(), 1);
  PaddingConfig padding_config;
  PaddingConfig::PaddingConfigDimension padding_config_dim;
  padding_config_dim.set_edge_padding_low(zeros_to_prepend);
  padding_config_dim.set_edge_padding_high(zeros_to_append);
  *padding_config.add_dimensions() = padding_config_dim;

  HloInstruction* zero =
      computation->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(operand->shape().element_type())));
  return MakePadHlo(operand, zero, padding_config);
}

HloInstruction* BroadcastZeros(HloComputation* computation,
                               PrimitiveType element_type,
                               absl::Span<const int64_t> broadcast_dimensions) {
  HloInstruction* zero = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
  return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
                          /*result_shape_bounds=*/broadcast_dimensions);
}
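// Usage sketch (illustrative): an all-zeros [4, 4] F32 array in a
// hypothetical computation `comp`.
//   HloInstruction* zeros = BroadcastZeros(comp, F32, {4, 4});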

HloInstruction* BroadcastOnes(HloComputation* computation,
                              PrimitiveType element_type,
                              absl::Span<const int64_t> broadcast_dimensions) {
  HloInstruction* one = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::One(element_type)));
  return MakeBroadcastHlo(one, /*broadcast_dimensions=*/{},
                          /*result_shape_bounds=*/broadcast_dimensions);
}

StatusOr<HloInstruction*> MakeFusionInstruction(
    HloInstruction* fused, HloInstruction::FusionKind kind) {
  HloComputation* comp = fused->parent();
  HloInstruction* fusion_instruction = comp->AddInstruction(
      HloInstruction::CreateFusion(fused->shape(), kind, fused));
  TF_RETURN_IF_ERROR(comp->ReplaceInstruction(fused, fusion_instruction));
  return fusion_instruction;
}

// Recursively creates a dummy op for the given shape: leaf (array) nodes are
// broadcasted zeros, while internal nodes are tuples.
HloInstruction* CreateDummyOp(HloComputation::Builder* b, const Shape& shape) {
  if (shape.IsArray()) {
    auto zero = b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::Zero(shape.element_type())));
    return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
  }
  CHECK(shape.IsTuple());
  std::vector<HloInstruction*> sub_instructions;
  for (const Shape& subshape : shape.tuple_shapes()) {
    sub_instructions.push_back(CreateDummyOp(b, subshape));
  }
  return b->AddInstruction(HloInstruction::CreateTuple(sub_instructions));
}

StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
    absl::Span<const Shape* const> domain, const Shape& range,
    absl::string_view name) {
  HloComputation::Builder b{std::string(name)};
  int64_t param_idx = 0;
  for (const Shape* param_shape : domain) {
    b.AddInstruction(HloInstruction::CreateParameter(
        param_idx, *param_shape, StrCat("param.", param_idx)));
    param_idx++;
  }

  // We can't change the root type of a computation once it is created, so
  // create a dummy root instruction to give the computation the right root
  // shape. Use a (recursive) broadcast here to avoid creating large constants.
  CreateDummyOp(&b, range);
  return b.Build();
}

}  // namespace xla