1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
17
18 #include "absl/algorithm/container.h"
19 #include "absl/memory/memory.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/client/lib/comparators.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/client/xla_computation.h"
24 #include "tensorflow/compiler/xla/comparison_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/literal_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
31 #include "tensorflow/compiler/xla/service/shape_inference.h"
32 #include "tensorflow/compiler/xla/util.h"
33
34 namespace xla {
35 using absl::StrCat;
36
MakeUnaryHlo(HloOpcode opcode,HloInstruction * operand)37 StatusOr<HloInstruction*> MakeUnaryHlo(HloOpcode opcode,
38 HloInstruction* operand) {
39 HloComputation* computation = operand->parent();
40 TF_ASSIGN_OR_RETURN(Shape unary_op_shape,
41 ShapeInference::InferUnaryOpShape(opcode, operand));
42 return computation->AddInstruction(
43 HloInstruction::CreateUnary(unary_op_shape, opcode, operand));
44 }
45
MakeBinaryHlo(HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs)46 StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
47 HloInstruction* rhs) {
48 HloComputation* computation = lhs->parent();
49 CHECK_EQ(computation, rhs->parent());
50 TF_ASSIGN_OR_RETURN(Shape binary_op_shape,
51 ShapeInference::InferBinaryOpShape(opcode, lhs, rhs));
52 return computation->AddInstruction(
53 HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs));
54 }
55
// Creates a kCompare HLO comparing `lhs` and `rhs` with `direction` and adds
// it to their (shared) computation. The result shape (PRED with the operands'
// broadcasted dimensions) is inferred via the generic binary-op inference.
StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction,
                                         HloInstruction* lhs,
                                         HloInstruction* rhs) {
  HloComputation* computation = lhs->parent();
  CHECK_EQ(computation, rhs->parent());
  TF_ASSIGN_OR_RETURN(
      Shape binary_op_shape,
      ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs));
  return computation->AddInstruction(
      HloInstruction::CreateCompare(binary_op_shape, lhs, rhs, direction));
}
67
// Creates a kPad HLO padding `operand` with `padding_value` according to
// `padding_config`, and adds it to the operand's computation. `operand` and
// `padding_value` must belong to the same computation.
StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
                                     HloInstruction* padding_value,
                                     const PaddingConfig& padding_config) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(computation, padding_value->parent());
  TF_ASSIGN_OR_RETURN(
      Shape pad_shape,
      ShapeInference::InferPadShape(operand->shape(), padding_value->shape(),
                                    padding_config));
  return computation->AddInstruction(HloInstruction::CreatePad(
      pad_shape, operand, padding_value, padding_config));
}
80
// Creates a static kSlice HLO over `operand` with the given per-dimension
// start/limit indices and strides, and adds it to the operand's computation.
StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
                                       absl::Span<const int64> start_indices,
                                       absl::Span<const int64> limit_indices,
                                       absl::Span<const int64> strides) {
  HloComputation* computation = operand->parent();
  TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape(
                                             operand->shape(), start_indices,
                                             limit_indices, strides));
  return computation->AddInstruction(HloInstruction::CreateSlice(
      slice_shape, operand, start_indices, limit_indices, strides));
}
92
// Creates a kConvolution HLO between `lhs` and `rhs` and adds it to their
// (shared) computation. `preferred_element_type`, if set, influences only the
// inferred result shape; the created instruction carries the window, dimension
// numbers, group counts and precision config verbatim.
StatusOr<HloInstruction*> MakeConvolveHlo(
    HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
    int64 batch_group_count, const Window& window,
    const ConvolutionDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  HloComputation* computation = lhs->parent();
  CHECK_EQ(computation, rhs->parent());
  TF_ASSIGN_OR_RETURN(
      Shape convolve_shape,
      ShapeInference::InferConvolveShape(
          lhs->shape(), rhs->shape(), feature_group_count, batch_group_count,
          window, dimension_numbers, preferred_element_type));
  return computation->AddInstruction(HloInstruction::CreateConvolve(
      convolve_shape, lhs, rhs, feature_group_count, batch_group_count, window,
      dimension_numbers, precision_config));
}
110
MakeTransposeHlo(HloInstruction * operand,absl::Span<const int64> dimensions)111 StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
112 absl::Span<const int64> dimensions) {
113 HloComputation* computation = operand->parent();
114 TF_ASSIGN_OR_RETURN(
115 Shape transpose_shape,
116 ShapeInference::InferTransposeShape(operand->shape(), dimensions));
117 return computation->AddInstruction(
118 HloInstruction::CreateTranspose(transpose_shape, operand, dimensions));
119 }
120
MakeReshapeHlo(const Shape & result_shape,HloInstruction * operand)121 StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
122 HloInstruction* operand) {
123 HloComputation* computation = operand->parent();
124 return computation->AddInstruction(
125 HloInstruction::CreateReshape(result_shape, operand));
126 }
127
MakeReshapeHlo(absl::Span<const int64> result_shape_dim_bounds,HloInstruction * operand)128 StatusOr<HloInstruction*> MakeReshapeHlo(
129 absl::Span<const int64> result_shape_dim_bounds, HloInstruction* operand) {
130 Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
131 result_shape_dim_bounds);
132 return MakeReshapeHlo(new_shape, operand);
133 }
134
MakeDynamicSliceHlo(HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64> slice_sizes)135 StatusOr<HloInstruction*> MakeDynamicSliceHlo(
136 HloInstruction* operand, absl::Span<HloInstruction* const> start_indices,
137 absl::Span<const int64> slice_sizes) {
138 HloComputation* computation = operand->parent();
139 std::vector<Shape> scalar_start_indices_shapes(
140 start_indices.size(),
141 ShapeUtil::MakeShape(start_indices[0]->shape().element_type(), {}));
142 TF_ASSIGN_OR_RETURN(
143 Shape dynamic_slice_shape,
144 ShapeInference::InferDynamicSliceShape(
145 operand->shape(), scalar_start_indices_shapes, slice_sizes));
146 return computation->AddInstruction(HloInstruction::CreateDynamicSlice(
147 dynamic_slice_shape, operand, start_indices, slice_sizes));
148 }
149
// Creates a kDynamicSlice HLO where the start indices arrive packed in a
// single rank-1 `start_indices` instruction (one element per operand
// dimension). Each element is extracted with a 1-element kSlice and reshaped
// to a scalar, since dynamic-slice expects scalar start operands.
StatusOr<HloInstruction*> MakeDynamicSliceHlo(
    HloInstruction* operand, HloInstruction* start_indices,
    absl::Span<const int64> slice_sizes) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(computation, start_indices->parent());
  // The vector's length gives the operand rank being sliced.
  int64 rank = start_indices->shape().dimensions(0);
  std::vector<HloInstruction*> scalar_start_indices;
  for (int i = 0; i < rank; ++i) {
    // TODO(b/118437727): Update callers to provide scalars directly.
    // Extract element i as a length-1 vector ...
    auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
        ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
        start_indices, {i}, {i + 1}, {1}));
    // ... then flatten it to the scalar shape dynamic-slice requires.
    scalar_start_indices.push_back(
        computation->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
            slice)));
  }
  std::vector<Shape> scalar_start_indices_shapes(
      rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
  TF_ASSIGN_OR_RETURN(
      Shape dynamic_slice_shape,
      ShapeInference::InferDynamicSliceShape(
          operand->shape(), scalar_start_indices_shapes, slice_sizes));
  return computation->AddInstruction(HloInstruction::CreateDynamicSlice(
      dynamic_slice_shape, operand, scalar_start_indices, slice_sizes));
}
176
// Creates a kDynamicUpdateSlice HLO writing `update` into `operand` at the
// offsets packed in the rank-1 `start_indices` instruction. As in the
// dynamic-slice helper above, each packed index is unpacked into a scalar via
// a 1-element kSlice followed by a kReshape.
StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
    HloInstruction* operand, HloInstruction* update,
    HloInstruction* start_indices) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(computation, update->parent());
  CHECK_EQ(computation, start_indices->parent());
  // One packed index per operand dimension.
  int64 rank = start_indices->shape().dimensions(0);
  std::vector<HloInstruction*> scalar_start_indices;
  for (int i = 0; i < rank; ++i) {
    // TODO(b/118437727): Update callers to provide scalars directly.
    auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
        ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
        start_indices, {i}, {i + 1}, {1}));
    scalar_start_indices.push_back(
        computation->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
            slice)));
  }
  std::vector<Shape> scalar_start_indices_shapes(
      rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
  TF_ASSIGN_OR_RETURN(
      Shape dynamic_update_slice_shape,
      ShapeInference::InferDynamicUpdateSliceShape(
          operand->shape(), update->shape(), scalar_start_indices_shapes));
  return computation->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      dynamic_update_slice_shape, operand, update, scalar_start_indices));
}
204
MakeBroadcastHlo(HloInstruction * operand,absl::Span<const int64> broadcast_dimensions,absl::Span<const int64> result_shape_bounds)205 HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
206 absl::Span<const int64> broadcast_dimensions,
207 absl::Span<const int64> result_shape_bounds) {
208 HloComputation* computation = operand->parent();
209 Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
210 result_shape_bounds);
211
212 return computation->AddInstruction(HloInstruction::CreateBroadcast(
213 broadcast_shape, operand, broadcast_dimensions));
214 }
215
// Convenience overload: broadcasts `operand` to the dimensions of `shape`
// (only the dimension bounds of `shape` are used, not its element type).
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
                                 absl::Span<const int64> broadcast_dimensions,
                                 const Shape& shape) {
  return MakeBroadcastHlo(operand, broadcast_dimensions, shape.dimensions());
}
221
MakeGetTupleElementHlo(HloInstruction * operand,int64 index)222 StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
223 int64 index) {
224 HloComputation* computation = operand->parent();
225
226 TF_ASSIGN_OR_RETURN(
227 Shape gte_shape,
228 ShapeInference::InferGetTupleElementShape(operand->shape(), index));
229 return computation->AddInstruction(
230 HloInstruction::CreateGetTupleElement(gte_shape, operand, index));
231 }
232
// Creates a kConcatenate HLO joining `operands` along `dimension` and adds it
// to their (shared) computation. Requires at least one operand; all operands
// must live in the same computation.
StatusOr<HloInstruction*> MakeConcatHlo(
    absl::Span<HloInstruction* const> operands, int64 dimension) {
  CHECK_GT(operands.size(), 0);

  HloComputation* computation = operands[0]->parent();
  CHECK(absl::c_all_of(operands, [&](HloInstruction* instr) {
    return instr->parent() == computation;
  }));

  // Shape inference wants the operand shapes, not the instructions.
  std::vector<const Shape*> operand_shapes;
  absl::c_transform(operands, std::back_inserter(operand_shapes),
                    [](HloInstruction* instr) { return &instr->shape(); });

  TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape(
                                              operand_shapes, dimension));
  return computation->AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, operands, dimension));
}
251
MakeConvertToHlo(HloInstruction * hlo,PrimitiveType type)252 HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type) {
253 if (hlo->shape().element_type() == type) {
254 return hlo;
255 }
256 Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
257 hlo =
258 hlo->parent()->AddInstruction(HloInstruction::CreateConvert(shape, hlo));
259 CHECK_EQ(hlo->shape().element_type(), type);
260 return hlo;
261 }
262
// Bitcast-converts `hlo` to element type `type`, or returns `hlo` unchanged
// when the element type already matches. PRED is special-cased (see below).
HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo,
                                        PrimitiveType type) {
  if (hlo->shape().element_type() == type) {
    return hlo;
  }
  Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
  // PRED are stored as one byte, PRED have a BitWidth of 1, avoid this problem
  // by using a convert instead of bitcast convert.
  if (type == PRED || hlo->shape().element_type() == PRED) {
    return MakeConvertToHlo(hlo, type);
  }
  hlo = hlo->parent()->AddInstruction(
      HloInstruction::CreateBitcastConvert(shape, hlo));
  CHECK_EQ(hlo->shape().element_type(), type);
  return hlo;
}
279
// Creates a kIota HLO of `shape` counting along `iota_dimension` and adds it
// to `computation`.
HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
                            int64 iota_dimension) {
  return computation->AddInstruction(
      HloInstruction::CreateIota(shape, iota_dimension));
}
285
// Creates a kDot HLO between `lhs` and `rhs` and adds it to their (shared)
// computation. `preferred_element_type`, if set, influences only the inferred
// result shape.
StatusOr<HloInstruction*> MakeDotHlo(
    HloInstruction* lhs, HloInstruction* rhs,
    const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  HloComputation* computation = lhs->parent();
  CHECK_EQ(computation, rhs->parent());
  TF_ASSIGN_OR_RETURN(
      Shape dot_shape,
      ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers,
                                      preferred_element_type));
  return computation->AddInstruction(HloInstruction::CreateDot(
      dot_shape, lhs, rhs, dim_numbers, precision_config));
}
300
// Creates a kMap HLO applying `map_computation` elementwise over `operands`
// and adds it to their (shared) computation. The map dimensions cover every
// dimension up to the largest operand rank.
StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
                                     HloComputation* map_computation) {
  CHECK(!operands.empty()) << "Map Hlo requires at least one operand.";
  HloComputation* computation = operands.front()->parent();
  std::vector<const Shape*> operand_shapes;
  int64 max_operand_rank = 0;
  for (const HloInstruction* operand : operands) {
    CHECK_EQ(computation, operand->parent());
    operand_shapes.push_back(&operand->shape());
    max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
  }
  // Map over dimensions [0, max_operand_rank).
  std::vector<int64> map_dims(max_operand_rank);
  std::iota(map_dims.begin(), map_dims.end(), 0);
  TF_ASSIGN_OR_RETURN(
      Shape map_shape,
      ShapeInference::InferMapShape(
          operand_shapes, map_computation->ComputeProgramShape(), map_dims));
  return computation->AddInstruction(
      HloInstruction::CreateMap(map_shape, operands, map_computation));
}
321
// Creates a kReduce HLO reducing `operand` over `dimensions` with a freshly
// built two-parameter scalar computation applying `binary_opcode`. The helper
// computation is added to the operand's module as an embedded computation.
StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        absl::Span<const int64> dimensions,
                                        HloOpcode binary_opcode) {
  auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
  // Result keeps exactly the dimensions NOT being reduced.
  auto result_shape = ShapeUtil::FilterDimensions(
      [&](const int64 dim) { return !absl::c_linear_search(dimensions, dim); },
      operand->shape());
  HloComputation* reduce_computation;
  {
    // Build `lhs <binary_opcode> rhs` over two scalar parameters.
    HloComputation::Builder b(operand->name() + ".reduce_sub_computation");
    auto lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
    auto rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
    b.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs));
    // parent()->parent() is the HloModule owning the operand's computation.
    reduce_computation =
        operand->parent()->parent()->AddEmbeddedComputation(b.Build());
  }

  return operand->parent()->AddInstruction(HloInstruction::CreateReduce(
      result_shape, operand, init_value, dimensions, reduce_computation));
}
346
// Creates a kReduce HLO reducing `operand` over ALL of its dimensions down to
// a scalar, with a freshly built scalar computation applying `binary_opcode`.
// The helper computation is embedded into `module`.
StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        HloOpcode binary_opcode,
                                        HloModule* module) {
  DCHECK_NE(nullptr, module);
  // Reduce over every dimension [0, rank).
  std::vector<int64> all_dims(operand->shape().rank());
  std::iota(all_dims.begin(), all_dims.end(), 0);

  auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
  HloComputation* reduce_computation;
  {
    // Build `lhs <binary_opcode> rhs` over two scalar parameters.
    HloComputation::Builder b(operand->name() + ".reduce_sub_computation");
    auto lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
    auto rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
    b.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs));
    reduce_computation = module->AddEmbeddedComputation(b.Build());
  }

  return operand->parent()->AddInstruction(HloInstruction::CreateReduce(
      scalar_shape, operand, init_value, all_dims, reduce_computation));
}
371
MakeReverseHlo(HloInstruction * operand,absl::Span<const int64> dimensions)372 StatusOr<HloInstruction*> MakeReverseHlo(HloInstruction* operand,
373 absl::Span<const int64> dimensions) {
374 HloComputation* computation = operand->parent();
375 TF_ASSIGN_OR_RETURN(Shape reverse_shape, ShapeInference::InferReverseShape(
376 operand->shape(), dimensions));
377 return computation->AddInstruction(
378 HloInstruction::CreateReverse(reverse_shape, operand, dimensions));
379 }
380
// Creates a kSelect (or kTupleSelect for tuple-shaped branches) HLO choosing
// between `on_true` and `on_false` by `pred`. A scalar `pred` is broadcast to
// the branch shape when the result is a non-tuple array. If `derived_from` is
// non-null, metadata is propagated onto the new instructions.
StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
                                        HloInstruction* on_true,
                                        HloInstruction* on_false,
                                        HloInstruction* derived_from) {
  HloComputation* computation = pred->parent();
  DCHECK_EQ(computation, on_true->parent());
  DCHECK_EQ(computation, on_false->parent());
  Shape op_shape = on_true->shape();
  if (ShapeUtil::IsScalar(pred->shape())) {
    if (!ShapeUtil::IsScalar(op_shape) && !op_shape.IsTuple()) {
      // If the output is not scalar, we need to broadcast the condition
      // to match the contract of kSelect. For tuples, we use kTupleSelect
      // which expects the condition to be a scalar.
      pred = computation->AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::ChangeElementType(op_shape, PrimitiveType::PRED), pred,
          {}));
      if (derived_from) {
        derived_from->SetupDerivedInstruction(pred);
      }
    }
  }
  // Tuple-shaped branches require kTupleSelect; arrays use kSelect.
  HloOpcode select_op_code =
      op_shape.IsTuple() ? HloOpcode::kTupleSelect : HloOpcode::kSelect;
  TF_ASSIGN_OR_RETURN(Shape select_shape,
                      ShapeInference::InferTernaryOpShape(select_op_code, pred,
                                                          on_true, on_false));
  HloInstruction* select =
      computation->AddInstruction(HloInstruction::CreateTernary(
          select_shape, select_op_code, pred, on_true, on_false));
  if (derived_from) {
    derived_from->SetupDerivedInstruction(select);
  }
  return select;
}
415
MakeSortHlo(const Shape & sort_shape,absl::Span<HloInstruction * const> operands,int64 dimension_to_sort,bool is_stable,HloComputation::Builder * builder,HloModule * module)416 StatusOr<HloInstruction*> MakeSortHlo(
417 const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
418 int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
419 HloModule* module) {
420 CHECK(!operands.empty()) << "Sort Hlo requires at least one operand.";
421 HloComputation* compare_computation;
422 XlaBuilder b("Sort.Compare");
423 std::vector<PrimitiveType> operand_types(operands.size());
424 for (int64 i = 0; i < operands.size(); ++i) {
425 operand_types[i] = operands[i]->shape().element_type();
426 }
427 XlaComputation comparator = CreateScalarLtComputation(operand_types, &b);
428 TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
429 HloModuleConfig config(program_shape);
430 TF_ASSIGN_OR_RETURN(auto new_module,
431 HloModule::CreateFromProto(comparator.proto(), config));
432 HloCloneContext context(module);
433 compare_computation =
434 module->DeepCloneComputation(new_module->entry_computation(), &context);
435 return builder->AddInstruction(HloInstruction::CreateSort(
436 sort_shape, dimension_to_sort, operands, compare_computation, is_stable));
437 }
438
// Reshapes `operand` so its first `n` dimensions are collapsed into a single
// leading dimension whose bound is their product; trailing dimensions are
// preserved.
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) {
  CHECK_GT(n, 0);

  const Shape& operand_shape = operand->shape();
  CHECK_GE(operand_shape.dimensions_size(), n);
  // Product of the first n dimension bounds becomes the new leading bound.
  int64 new_shape_leading_bound = 1;
  for (int64 i = 0; i < n; i++) {
    new_shape_leading_bound *= operand_shape.dimensions(i);
  }

  std::vector<int64> new_shape_dims;
  new_shape_dims.reserve(operand_shape.dimensions_size() - n + 1);
  new_shape_dims.push_back(new_shape_leading_bound);

  // Keep dimensions n..rank-1 unchanged.
  std::copy(operand_shape.dimensions().begin() + n,
            operand_shape.dimensions().end(),
            std::back_inserter(new_shape_dims));

  Shape output_shape =
      ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims);

  return MakeReshapeHlo(output_shape, operand);
}
462
PrependDegenerateDims(HloInstruction * operand,int64 n)463 StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
464 int64 n) {
465 CHECK_GT(n, 0);
466 std::vector<int64> new_shape_dims;
467 const Shape& operand_shape = operand->shape();
468 new_shape_dims.reserve(n + operand_shape.dimensions_size());
469 new_shape_dims.insert(new_shape_dims.begin(), n, 1);
470 absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims));
471 return MakeReshapeHlo(new_shape_dims, operand);
472 }
473
// Reshapes `operand` so its first dimension is split into `expanded_dims`
// (whose product must equal the original leading bound); trailing dimensions
// are preserved.
StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
    HloInstruction* operand, absl::Span<const int64> expanded_dims) {
  CHECK_GT(operand->shape().dimensions_size(), 0);
  CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims));

  std::vector<int64> expanded_shape_dim_bounds;
  expanded_shape_dim_bounds.reserve(expanded_dims.size() +
                                    operand->shape().dimensions_size() - 1);
  absl::c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
  // Dimensions 1..rank-1 pass through unchanged.
  std::copy(operand->shape().dimensions().begin() + 1,
            operand->shape().dimensions().end(),
            std::back_inserter(expanded_shape_dim_bounds));
  Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                         expanded_shape_dim_bounds);
  return MakeReshapeHlo(new_shape, operand);
}
490
ElideDegenerateDims(HloInstruction * operand,absl::Span<const int64> dims_to_elide)491 StatusOr<HloInstruction*> ElideDegenerateDims(
492 HloInstruction* operand, absl::Span<const int64> dims_to_elide) {
493 return MakeReshapeHlo(
494 ShapeUtil::FilterDimensions(
495 [&](int64 dim) { return !absl::c_linear_search(dims_to_elide, dim); },
496 operand->shape()),
497 operand);
498 }
499
InsertDegenerateDims(HloInstruction * operand,absl::Span<const int64> dims_to_insert)500 StatusOr<HloInstruction*> InsertDegenerateDims(
501 HloInstruction* operand, absl::Span<const int64> dims_to_insert) {
502 CHECK(absl::c_is_sorted(dims_to_insert));
503
504 const Shape& operand_shape = operand->shape();
505 int64 output_shape_rank =
506 operand_shape.dimensions_size() + dims_to_insert.size();
507 for (auto dim_to_insert : dims_to_insert) {
508 CHECK_LT(dim_to_insert, output_shape_rank);
509 }
510
511 std::vector<int64> output_shape_dim_bounds;
512 output_shape_dim_bounds.reserve(output_shape_rank);
513 int64 operand_dims_idx = 0;
514 int64 dims_to_insert_idx = 0;
515 for (int64 i = 0; i < output_shape_rank; ++i) {
516 if (dims_to_insert_idx < dims_to_insert.size() &&
517 i == dims_to_insert[dims_to_insert_idx]) {
518 output_shape_dim_bounds.push_back(1);
519 ++dims_to_insert_idx;
520 } else {
521 output_shape_dim_bounds.push_back(
522 operand_shape.dimensions(operand_dims_idx));
523 ++operand_dims_idx;
524 }
525 }
526
527 Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(),
528 output_shape_dim_bounds);
529 return MakeReshapeHlo(output_shape, operand);
530 }
531
// Pads the rank-1 `operand` with `zeros_to_prepend` leading and
// `zeros_to_append` trailing zeros via a kPad with a freshly created scalar
// zero constant.
StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
                                             int64 zeros_to_prepend,
                                             int64 zeros_to_append) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(operand->shape().dimensions_size(), 1);
  PaddingConfig padding_config;
  PaddingConfig::PaddingConfigDimension padding_config_dim;
  padding_config_dim.set_edge_padding_low(zeros_to_prepend);
  padding_config_dim.set_edge_padding_high(zeros_to_append);
  *padding_config.add_dimensions() = padding_config_dim;

  // Zero of the operand's element type serves as the padding value.
  HloInstruction* zero =
      computation->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(operand->shape().element_type())));
  return MakePadHlo(operand, zero, padding_config);
}
548
BroadcastZeros(HloComputation * computation,PrimitiveType element_type,absl::Span<const int64> broadcast_dimensions)549 HloInstruction* BroadcastZeros(HloComputation* computation,
550 PrimitiveType element_type,
551 absl::Span<const int64> broadcast_dimensions) {
552 HloInstruction* zero = computation->AddInstruction(
553 HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
554 return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
555 /*result_shape_bounds=*/broadcast_dimensions);
556 }
557
BroadcastOnes(HloComputation * computation,PrimitiveType element_type,absl::Span<const int64> broadcast_dimensions)558 HloInstruction* BroadcastOnes(HloComputation* computation,
559 PrimitiveType element_type,
560 absl::Span<const int64> broadcast_dimensions) {
561 HloInstruction* one = computation->AddInstruction(
562 HloInstruction::CreateConstant(LiteralUtil::One(element_type)));
563 return MakeBroadcastHlo(one, /*broadcast_dimensions=*/{},
564 /*result_shape_bounds=*/broadcast_dimensions);
565 }
566
567 // Recursively creates a dummy op given a shape. Leaf nodes are broadcasted zero
568 // while internal nodes are tuples.
// Recursively creates a dummy op given a shape. Leaf (array) shapes become a
// broadcasted scalar zero; tuple shapes become a kTuple of recursively built
// dummies. All instructions are added through `b`.
HloInstruction* CreateDummyOp(HloComputation::Builder* b, const Shape& shape) {
  if (shape.IsArray()) {
    auto zero = b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::Zero(shape.element_type())));
    return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
  }
  // Only arrays and tuples are expected here.
  CHECK(shape.IsTuple());
  std::vector<HloInstruction*> sub_instructions;
  for (const Shape& subshape : shape.tuple_shapes()) {
    sub_instructions.push_back(CreateDummyOp(b, subshape));
  }
  return b->AddInstruction(HloInstruction::CreateTuple(sub_instructions));
}
582
// Builds a computation named `name` with one parameter per shape in `domain`
// (named "param.0", "param.1", ...) whose root has shape `range`. The root is
// a dummy (zero-broadcast / tuple-of-dummies) placeholder.
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
    absl::Span<const Shape* const> domain, const Shape& range,
    absl::string_view name) {
  HloComputation::Builder b{string(name)};
  int64 param_idx = 0;
  for (const Shape* param_shape : domain) {
    b.AddInstruction(HloInstruction::CreateParameter(
        param_idx, *param_shape, StrCat("param.", param_idx)));
    param_idx++;
  }

  // We can't change the root type of a computation once it is created so create
  // a dummy root instruction to give the computation the right root shape. Use
  // a (recursive) broadcast here to avoid creating large constants.
  CreateDummyOp(&b, range);
  return b.Build();
}
600
601 } // namespace xla
602