1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
17
18 #include "absl/algorithm/container.h"
19 #include "absl/memory/memory.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/client/lib/comparators.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/client/xla_computation.h"
24 #include "tensorflow/compiler/xla/comparison_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/literal_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
31 #include "tensorflow/compiler/xla/service/shape_inference.h"
32 #include "tensorflow/compiler/xla/util.h"
33
34 namespace xla {
35 using absl::StrCat;
36
MakeUnaryHlo(HloOpcode opcode,HloInstruction * operand)37 StatusOr<HloInstruction*> MakeUnaryHlo(HloOpcode opcode,
38 HloInstruction* operand) {
39 HloComputation* computation = operand->parent();
40 TF_ASSIGN_OR_RETURN(Shape unary_op_shape,
41 ShapeInference::InferUnaryOpShape(opcode, operand));
42 return computation->AddInstruction(
43 HloInstruction::CreateUnary(unary_op_shape, opcode, operand));
44 }
45
MakeBinaryHlo(HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs)46 StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
47 HloInstruction* rhs) {
48 HloComputation* computation = lhs->parent();
49 CHECK_EQ(computation, rhs->parent());
50 TF_ASSIGN_OR_RETURN(Shape binary_op_shape,
51 ShapeInference::InferBinaryOpShape(opcode, lhs, rhs));
52 return computation->AddInstruction(
53 HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs));
54 }
55
MakeCompareHlo(ComparisonDirection direction,HloInstruction * lhs,HloInstruction * rhs)56 StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction,
57 HloInstruction* lhs,
58 HloInstruction* rhs) {
59 HloComputation* computation = lhs->parent();
60 CHECK_EQ(computation, rhs->parent());
61 TF_ASSIGN_OR_RETURN(
62 Shape binary_op_shape,
63 ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs));
64 return computation->AddInstruction(
65 HloInstruction::CreateCompare(binary_op_shape, lhs, rhs, direction));
66 }
67
MakePadHlo(HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)68 StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
69 HloInstruction* padding_value,
70 const PaddingConfig& padding_config) {
71 HloComputation* computation = operand->parent();
72 CHECK_EQ(computation, padding_value->parent());
73 TF_ASSIGN_OR_RETURN(
74 Shape pad_shape,
75 ShapeInference::InferPadShape(operand->shape(), padding_value->shape(),
76 padding_config));
77 return computation->AddInstruction(HloInstruction::CreatePad(
78 pad_shape, operand, padding_value, padding_config));
79 }
80
MakeSliceHlo(HloInstruction * operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)81 StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
82 absl::Span<const int64> start_indices,
83 absl::Span<const int64> limit_indices,
84 absl::Span<const int64> strides) {
85 HloComputation* computation = operand->parent();
86 TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape(
87 operand->shape(), start_indices,
88 limit_indices, strides));
89 return computation->AddInstruction(HloInstruction::CreateSlice(
90 slice_shape, operand, start_indices, limit_indices, strides));
91 }
92
// Creates a convolution HLO in lhs's computation. The result shape is
// inferred from the operand shapes, grouping counts, window, dimension
// numbers, and (optionally) `preferred_element_type`. Note that
// `precision_config` is attached to the instruction but does not participate
// in shape inference.
StatusOr<HloInstruction*> MakeConvolveHlo(
    HloInstruction* lhs, HloInstruction* rhs, int64_t feature_group_count,
    int64_t batch_group_count, const Window& window,
    const ConvolutionDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  HloComputation* computation = lhs->parent();
  // Both operands must live in the same computation.
  CHECK_EQ(computation, rhs->parent());
  TF_ASSIGN_OR_RETURN(
      Shape convolve_shape,
      ShapeInference::InferConvolveShape(
          lhs->shape(), rhs->shape(), feature_group_count, batch_group_count,
          window, dimension_numbers, preferred_element_type));
  return computation->AddInstruction(HloInstruction::CreateConvolve(
      convolve_shape, lhs, rhs, feature_group_count, batch_group_count, window,
      dimension_numbers, precision_config));
}
110
MakeTransposeHlo(HloInstruction * operand,absl::Span<const int64> dimensions)111 StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
112 absl::Span<const int64> dimensions) {
113 HloComputation* computation = operand->parent();
114 TF_ASSIGN_OR_RETURN(
115 Shape transpose_shape,
116 ShapeInference::InferTransposeShape(operand->shape(), dimensions));
117 return computation->AddInstruction(
118 HloInstruction::CreateTranspose(transpose_shape, operand, dimensions));
119 }
120
MakeReshapeHlo(const Shape & result_shape,HloInstruction * operand)121 StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
122 HloInstruction* operand) {
123 HloComputation* computation = operand->parent();
124 return computation->AddInstruction(
125 HloInstruction::CreateReshape(result_shape, operand));
126 }
127
MakeReshapeHlo(absl::Span<const int64> result_shape_dim_bounds,HloInstruction * operand)128 StatusOr<HloInstruction*> MakeReshapeHlo(
129 absl::Span<const int64> result_shape_dim_bounds, HloInstruction* operand) {
130 Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
131 result_shape_dim_bounds);
132 return MakeReshapeHlo(new_shape, operand);
133 }
134
// Creates a dynamic-slice HLO where each entry of `start_indices` is a scalar
// instruction giving the slice start along one dimension and `slice_sizes`
// gives the extent per dimension.
// NOTE(review): `start_indices[0]` is read unconditionally to derive the index
// element type, so this presumably requires a non-empty `start_indices`
// (i.e. a rank >= 1 operand) — confirm against callers.
StatusOr<HloInstruction*> MakeDynamicSliceHlo(
    HloInstruction* operand, absl::Span<HloInstruction* const> start_indices,
    absl::Span<const int64> slice_sizes) {
  HloComputation* computation = operand->parent();
  // All start indices are scalars sharing the first index's element type.
  std::vector<Shape> scalar_start_indices_shapes(
      start_indices.size(),
      ShapeUtil::MakeShape(start_indices[0]->shape().element_type(), {}));
  TF_ASSIGN_OR_RETURN(
      Shape dynamic_slice_shape,
      ShapeInference::InferDynamicSliceShape(
          operand->shape(), scalar_start_indices_shapes, slice_sizes));
  return computation->AddInstruction(HloInstruction::CreateDynamicSlice(
      dynamic_slice_shape, operand, start_indices, slice_sizes));
}
149
// Creates a dynamic-slice HLO where `start_indices` is a single rank-1 vector
// of offsets. The vector is decomposed into one scalar per dimension (via a
// length-1 slice followed by a reshape-to-scalar) because kDynamicSlice takes
// scalar index operands.
StatusOr<HloInstruction*> MakeDynamicSliceHlo(
    HloInstruction* operand, HloInstruction* start_indices,
    absl::Span<const int64> slice_sizes) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(computation, start_indices->parent());
  // The length of the index vector is the number of sliced dimensions.
  int64_t rank = start_indices->shape().dimensions(0);
  std::vector<HloInstruction*> scalar_start_indices;
  for (int i = 0; i < rank; ++i) {
    // TODO(b/118437727): Update callers to provide scalars directly.
    auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
        ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
        start_indices, {i}, {i + 1}, {1}));
    scalar_start_indices.push_back(
        computation->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
            slice)));
  }
  std::vector<Shape> scalar_start_indices_shapes(
      rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
  TF_ASSIGN_OR_RETURN(
      Shape dynamic_slice_shape,
      ShapeInference::InferDynamicSliceShape(
          operand->shape(), scalar_start_indices_shapes, slice_sizes));
  return computation->AddInstruction(HloInstruction::CreateDynamicSlice(
      dynamic_slice_shape, operand, scalar_start_indices, slice_sizes));
}
176
// Creates a dynamic-update-slice HLO writing `update` into `operand` at the
// offsets held in the rank-1 vector `start_indices`. As in MakeDynamicSliceHlo
// above, the vector is decomposed into scalar indices (length-1 slice plus
// reshape per dimension) because kDynamicUpdateSlice takes scalar index
// operands.
StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
    HloInstruction* operand, HloInstruction* update,
    HloInstruction* start_indices) {
  HloComputation* computation = operand->parent();
  CHECK_EQ(computation, update->parent());
  CHECK_EQ(computation, start_indices->parent());
  // One start index per updated dimension.
  int64_t rank = start_indices->shape().dimensions(0);
  std::vector<HloInstruction*> scalar_start_indices;
  for (int i = 0; i < rank; ++i) {
    // TODO(b/118437727): Update callers to provide scalars directly.
    auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
        ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
        start_indices, {i}, {i + 1}, {1}));
    scalar_start_indices.push_back(
        computation->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
            slice)));
  }
  std::vector<Shape> scalar_start_indices_shapes(
      rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
  TF_ASSIGN_OR_RETURN(
      Shape dynamic_update_slice_shape,
      ShapeInference::InferDynamicUpdateSliceShape(
          operand->shape(), update->shape(), scalar_start_indices_shapes));
  return computation->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      dynamic_update_slice_shape, operand, update, scalar_start_indices));
}
204
MakeBroadcastHlo(HloInstruction * operand,absl::Span<const int64> broadcast_dimensions,absl::Span<const int64> result_shape_bounds)205 HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
206 absl::Span<const int64> broadcast_dimensions,
207 absl::Span<const int64> result_shape_bounds) {
208 HloComputation* computation = operand->parent();
209 Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
210 result_shape_bounds);
211
212 return computation->AddInstruction(HloInstruction::CreateBroadcast(
213 broadcast_shape, operand, broadcast_dimensions));
214 }
215
MakeBroadcastHlo(HloInstruction * operand,absl::Span<const int64> broadcast_dimensions,const Shape & shape)216 HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
217 absl::Span<const int64> broadcast_dimensions,
218 const Shape& shape) {
219 return MakeBroadcastHlo(operand, broadcast_dimensions, shape.dimensions());
220 }
221
MakeGetTupleElementHlo(HloInstruction * operand,int64_t index)222 StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
223 int64_t index) {
224 HloComputation* computation = operand->parent();
225
226 TF_ASSIGN_OR_RETURN(
227 Shape gte_shape,
228 ShapeInference::InferGetTupleElementShape(operand->shape(), index));
229 return computation->AddInstruction(
230 HloInstruction::CreateGetTupleElement(gte_shape, operand, index));
231 }
232
MakeConcatHlo(absl::Span<HloInstruction * const> operands,int64_t dimension)233 StatusOr<HloInstruction*> MakeConcatHlo(
234 absl::Span<HloInstruction* const> operands, int64_t dimension) {
235 CHECK_GT(operands.size(), 0);
236
237 HloComputation* computation = operands[0]->parent();
238 CHECK(absl::c_all_of(operands, [&](HloInstruction* instr) {
239 return instr->parent() == computation;
240 }));
241
242 std::vector<const Shape*> operand_shapes;
243 absl::c_transform(operands, std::back_inserter(operand_shapes),
244 [](HloInstruction* instr) { return &instr->shape(); });
245
246 TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape(
247 operand_shapes, dimension));
248 return computation->AddInstruction(
249 HloInstruction::CreateConcatenate(concat_shape, operands, dimension));
250 }
251
MakeConvertToHlo(HloInstruction * hlo,PrimitiveType type)252 HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type) {
253 if (hlo->shape().element_type() == type) {
254 return hlo;
255 }
256 Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
257 hlo =
258 hlo->parent()->AddInstruction(HloInstruction::CreateConvert(shape, hlo));
259 CHECK_EQ(hlo->shape().element_type(), type);
260 return hlo;
261 }
262
MakeBitcastConvertToHlo(HloInstruction * hlo,PrimitiveType type)263 HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo,
264 PrimitiveType type) {
265 if (hlo->shape().element_type() == type) {
266 return hlo;
267 }
268 Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
269 // PRED are stored as one byte, PRED have a BitWidth of 1, avoid this problem
270 // by using a convert instead of bitcast convert.
271 if (type == PRED || hlo->shape().element_type() == PRED) {
272 return MakeConvertToHlo(hlo, type);
273 }
274 hlo = hlo->parent()->AddInstruction(
275 HloInstruction::CreateBitcastConvert(shape, hlo));
276 CHECK_EQ(hlo->shape().element_type(), type);
277 return hlo;
278 }
279
MakeIotaHlo(HloComputation * computation,const Shape & shape,int64_t iota_dimension)280 HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
281 int64_t iota_dimension) {
282 return computation->AddInstruction(
283 HloInstruction::CreateIota(shape, iota_dimension));
284 }
285
MakeDotHlo(HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dim_numbers,const PrecisionConfig & precision_config,absl::optional<PrimitiveType> preferred_element_type)286 StatusOr<HloInstruction*> MakeDotHlo(
287 HloInstruction* lhs, HloInstruction* rhs,
288 const DotDimensionNumbers& dim_numbers,
289 const PrecisionConfig& precision_config,
290 absl::optional<PrimitiveType> preferred_element_type) {
291 HloComputation* computation = lhs->parent();
292 CHECK_EQ(computation, rhs->parent());
293 TF_ASSIGN_OR_RETURN(
294 Shape dot_shape,
295 ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers,
296 preferred_element_type));
297 return computation->AddInstruction(HloInstruction::CreateDot(
298 dot_shape, lhs, rhs, dim_numbers, precision_config));
299 }
300
// Creates a map HLO applying `map_computation` over `operands`. All operands
// must live in one computation; the map dimensions span 0..max_operand_rank-1.
StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
                                     HloComputation* map_computation) {
  CHECK(!operands.empty()) << "Map Hlo requires at least one operand.";
  HloComputation* computation = operands.front()->parent();
  std::vector<const Shape*> operand_shapes;
  int64_t max_operand_rank = 0;
  for (const HloInstruction* operand : operands) {
    CHECK_EQ(computation, operand->parent());
    operand_shapes.push_back(&operand->shape());
    max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
  }
  // Map over every dimension of the highest-rank operand: {0, 1, ..., r-1}.
  std::vector<int64> map_dims(max_operand_rank);
  std::iota(map_dims.begin(), map_dims.end(), 0);
  TF_ASSIGN_OR_RETURN(
      Shape map_shape,
      ShapeInference::InferMapShape(
          operand_shapes, map_computation->ComputeProgramShape(), map_dims));
  return computation->AddInstruction(
      HloInstruction::CreateMap(map_shape, operands, map_computation));
}
321
// Creates a reduce HLO folding `operand` over `dimensions` with the binary
// operation `binary_opcode` (e.g. kAdd), seeded by `init_value`. A fresh
// scalar sub-computation implementing the binary op is built and embedded in
// the enclosing module to serve as the reducer.
StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        absl::Span<const int64> dimensions,
                                        HloOpcode binary_opcode) {
  auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
  // The result keeps exactly the dimensions that are not reduced away.
  auto result_shape = ShapeUtil::FilterDimensions(
      [&](const int64_t dim) {
        return !absl::c_linear_search(dimensions, dim);
      },
      operand->shape());
  HloComputation* reduce_computation;
  {
    // Build the scalar reducer: (lhs, rhs) -> binary_opcode(lhs, rhs).
    HloComputation::Builder b(operand->name() + ".reduce_sub_computation");
    auto lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
    auto rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
    b.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs));
    // operand->parent()->parent() is the computation's module.
    reduce_computation =
        operand->parent()->parent()->AddEmbeddedComputation(b.Build());
  }

  return operand->parent()->AddInstruction(HloInstruction::CreateReduce(
      result_shape, operand, init_value, dimensions, reduce_computation));
}
348
// Creates a reduce HLO folding `operand` over ALL of its dimensions with
// `binary_opcode`, producing a scalar. The scalar reducer computation is
// embedded in `module`.
StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        HloOpcode binary_opcode,
                                        HloModule* module) {
  DCHECK_NE(nullptr, module);
  // Reduce over every dimension: {0, 1, ..., rank - 1}.
  std::vector<int64> all_dims(operand->shape().rank());
  std::iota(all_dims.begin(), all_dims.end(), 0);

  auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
  HloComputation* reduce_computation;
  {
    // Build the scalar reducer: (lhs, rhs) -> binary_opcode(lhs, rhs).
    HloComputation::Builder b(operand->name() + ".reduce_sub_computation");
    auto lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
    auto rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
    b.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs));
    reduce_computation = module->AddEmbeddedComputation(b.Build());
  }

  return operand->parent()->AddInstruction(HloInstruction::CreateReduce(
      scalar_shape, operand, init_value, all_dims, reduce_computation));
}
373
MakeReverseHlo(HloInstruction * operand,absl::Span<const int64> dimensions)374 StatusOr<HloInstruction*> MakeReverseHlo(HloInstruction* operand,
375 absl::Span<const int64> dimensions) {
376 HloComputation* computation = operand->parent();
377 TF_ASSIGN_OR_RETURN(Shape reverse_shape, ShapeInference::InferReverseShape(
378 operand->shape(), dimensions));
379 return computation->AddInstruction(
380 HloInstruction::CreateReverse(reverse_shape, operand, dimensions));
381 }
382
// Creates a select HLO choosing between `on_true` and `on_false` based on
// `pred`. A scalar predicate over a non-scalar array output is broadcast to
// the output shape (kSelect requires matching shapes); tuple-shaped outputs
// use kTupleSelect, which takes a scalar predicate. When `derived_from` is
// non-null, its metadata is propagated onto every instruction created here.
StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
                                        HloInstruction* on_true,
                                        HloInstruction* on_false,
                                        HloInstruction* derived_from) {
  HloComputation* computation = pred->parent();
  DCHECK_EQ(computation, on_true->parent());
  DCHECK_EQ(computation, on_false->parent());
  Shape op_shape = on_true->shape();
  if (ShapeUtil::IsScalar(pred->shape())) {
    if (!ShapeUtil::IsScalar(op_shape) && !op_shape.IsTuple()) {
      // If the output is not scalar, we need to broadcast the condition
      // to match the contract of kSelect. For tuples, we use kTupleSelect
      // which expects the condition to be a scalar.
      pred = computation->AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::ChangeElementType(op_shape, PrimitiveType::PRED), pred,
          {}));
      if (derived_from) {
        // Carry metadata onto the broadcast we just synthesized.
        derived_from->SetupDerivedInstruction(pred);
      }
    }
  }
  HloOpcode select_op_code =
      op_shape.IsTuple() ? HloOpcode::kTupleSelect : HloOpcode::kSelect;
  TF_ASSIGN_OR_RETURN(Shape select_shape,
                      ShapeInference::InferTernaryOpShape(select_op_code, pred,
                                                          on_true, on_false));
  HloInstruction* select =
      computation->AddInstruction(HloInstruction::CreateTernary(
          select_shape, select_op_code, pred, on_true, on_false));
  if (derived_from) {
    derived_from->SetupDerivedInstruction(select);
  }
  return select;
}
417
// Creates a sort HLO over `operands` along `dimension_to_sort`, using a
// default less-than comparator over the operands' element types. The
// comparator is built with the client XlaBuilder, materialized as HLO by
// round-tripping through a proto into a temporary module, and then
// deep-cloned into `module`. The sort instruction itself is added to
// `builder` rather than to an existing computation.
StatusOr<HloInstruction*> MakeSortHlo(
    const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
    int64_t dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
    HloModule* module) {
  CHECK(!operands.empty()) << "Sort Hlo requires at least one operand.";
  HloComputation* compare_computation;
  XlaBuilder b("Sort.Compare");
  // The comparator takes a pair of scalars per operand, one per element type.
  std::vector<PrimitiveType> operand_types(operands.size());
  for (int64_t i = 0; i < operands.size(); ++i) {
    operand_types[i] = operands[i]->shape().element_type();
  }
  XlaComputation comparator = CreateScalarLtComputation(operand_types, &b);
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
  HloModuleConfig config(program_shape);
  // Convert the client-built comparator into HLO so it can be cloned into the
  // target module.
  TF_ASSIGN_OR_RETURN(auto new_module,
                      HloModule::CreateFromProto(comparator.proto(), config));
  HloCloneContext context(module);
  compare_computation =
      module->DeepCloneComputation(new_module->entry_computation(), &context);
  return builder->AddInstruction(HloInstruction::CreateSort(
      sort_shape, dimension_to_sort, operands, compare_computation, is_stable));
}
440
CollapseFirstNDims(HloInstruction * operand,int64_t n)441 StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand,
442 int64_t n) {
443 CHECK_GT(n, 0);
444
445 const Shape& operand_shape = operand->shape();
446 CHECK_GE(operand_shape.dimensions_size(), n);
447 int64_t new_shape_leading_bound = 1;
448 for (int64_t i = 0; i < n; i++) {
449 new_shape_leading_bound *= operand_shape.dimensions(i);
450 }
451
452 std::vector<int64> new_shape_dims;
453 new_shape_dims.reserve(operand_shape.dimensions_size() - n + 1);
454 new_shape_dims.push_back(new_shape_leading_bound);
455
456 std::copy(operand_shape.dimensions().begin() + n,
457 operand_shape.dimensions().end(),
458 std::back_inserter(new_shape_dims));
459
460 Shape output_shape =
461 ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims);
462
463 return MakeReshapeHlo(output_shape, operand);
464 }
465
PrependDegenerateDims(HloInstruction * operand,int64_t n)466 StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
467 int64_t n) {
468 CHECK_GT(n, 0);
469 std::vector<int64> new_shape_dims;
470 const Shape& operand_shape = operand->shape();
471 new_shape_dims.reserve(n + operand_shape.dimensions_size());
472 new_shape_dims.insert(new_shape_dims.begin(), n, 1);
473 absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims));
474 return MakeReshapeHlo(new_shape_dims, operand);
475 }
476
ExpandFirstDimIntoNDims(HloInstruction * operand,absl::Span<const int64> expanded_dims)477 StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
478 HloInstruction* operand, absl::Span<const int64> expanded_dims) {
479 CHECK_GT(operand->shape().dimensions_size(), 0);
480 CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims));
481
482 std::vector<int64> expanded_shape_dim_bounds;
483 expanded_shape_dim_bounds.reserve(expanded_dims.size() +
484 operand->shape().dimensions_size() - 1);
485 absl::c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
486 std::copy(operand->shape().dimensions().begin() + 1,
487 operand->shape().dimensions().end(),
488 std::back_inserter(expanded_shape_dim_bounds));
489 Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
490 expanded_shape_dim_bounds);
491 return MakeReshapeHlo(new_shape, operand);
492 }
493
ElideDegenerateDims(HloInstruction * operand,absl::Span<const int64> dims_to_elide)494 StatusOr<HloInstruction*> ElideDegenerateDims(
495 HloInstruction* operand, absl::Span<const int64> dims_to_elide) {
496 return MakeReshapeHlo(ShapeUtil::FilterDimensions(
497 [&](int64_t dim) {
498 return !absl::c_linear_search(dims_to_elide, dim);
499 },
500 operand->shape()),
501 operand);
502 }
503
// Reshapes `operand` to insert size-1 dimensions at the positions listed in
// `dims_to_insert`. The indices refer to positions in the OUTPUT shape and
// must be sorted in ascending order.
StatusOr<HloInstruction*> InsertDegenerateDims(
    HloInstruction* operand, absl::Span<const int64> dims_to_insert) {
  CHECK(absl::c_is_sorted(dims_to_insert));

  const Shape& operand_shape = operand->shape();
  int64_t output_shape_rank =
      operand_shape.dimensions_size() + dims_to_insert.size();
  for (auto dim_to_insert : dims_to_insert) {
    CHECK_LT(dim_to_insert, output_shape_rank);
  }

  // Merge walk over the output positions: emit a 1 at each position named in
  // dims_to_insert, otherwise the next unconsumed operand bound.
  std::vector<int64> output_shape_dim_bounds;
  output_shape_dim_bounds.reserve(output_shape_rank);
  int64_t operand_dims_idx = 0;
  int64_t dims_to_insert_idx = 0;
  for (int64_t i = 0; i < output_shape_rank; ++i) {
    if (dims_to_insert_idx < dims_to_insert.size() &&
        i == dims_to_insert[dims_to_insert_idx]) {
      output_shape_dim_bounds.push_back(1);
      ++dims_to_insert_idx;
    } else {
      output_shape_dim_bounds.push_back(
          operand_shape.dimensions(operand_dims_idx));
      ++operand_dims_idx;
    }
  }

  Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(),
                                            output_shape_dim_bounds);
  return MakeReshapeHlo(output_shape, operand);
}
535
PadVectorWithZeros(HloInstruction * operand,int64_t zeros_to_prepend,int64_t zeros_to_append)536 StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
537 int64_t zeros_to_prepend,
538 int64_t zeros_to_append) {
539 HloComputation* computation = operand->parent();
540 CHECK_EQ(operand->shape().dimensions_size(), 1);
541 PaddingConfig padding_config;
542 PaddingConfig::PaddingConfigDimension padding_config_dim;
543 padding_config_dim.set_edge_padding_low(zeros_to_prepend);
544 padding_config_dim.set_edge_padding_high(zeros_to_append);
545 *padding_config.add_dimensions() = padding_config_dim;
546
547 HloInstruction* zero =
548 computation->AddInstruction(HloInstruction::CreateConstant(
549 LiteralUtil::Zero(operand->shape().element_type())));
550 return MakePadHlo(operand, zero, padding_config);
551 }
552
BroadcastZeros(HloComputation * computation,PrimitiveType element_type,absl::Span<const int64> broadcast_dimensions)553 HloInstruction* BroadcastZeros(HloComputation* computation,
554 PrimitiveType element_type,
555 absl::Span<const int64> broadcast_dimensions) {
556 HloInstruction* zero = computation->AddInstruction(
557 HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
558 return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
559 /*result_shape_bounds=*/broadcast_dimensions);
560 }
561
BroadcastOnes(HloComputation * computation,PrimitiveType element_type,absl::Span<const int64> broadcast_dimensions)562 HloInstruction* BroadcastOnes(HloComputation* computation,
563 PrimitiveType element_type,
564 absl::Span<const int64> broadcast_dimensions) {
565 HloInstruction* one = computation->AddInstruction(
566 HloInstruction::CreateConstant(LiteralUtil::One(element_type)));
567 return MakeBroadcastHlo(one, /*broadcast_dimensions=*/{},
568 /*result_shape_bounds=*/broadcast_dimensions);
569 }
570
MakeFusionInstruction(HloInstruction * fused,HloInstruction::FusionKind kind)571 StatusOr<HloInstruction*> MakeFusionInstruction(
572 HloInstruction* fused, HloInstruction::FusionKind kind) {
573 HloComputation* comp = fused->parent();
574 HloInstruction* fusion_instruction = comp->AddInstruction(
575 HloInstruction::CreateFusion(fused->shape(), kind, fused));
576 TF_RETURN_IF_ERROR(comp->ReplaceInstruction(fused, fusion_instruction));
577 return fusion_instruction;
578 }
579
580 // Recursively creates a dummy op given a shape. Leaf nodes are broadcasted zero
581 // while internal nodes are tuples.
CreateDummyOp(HloComputation::Builder * b,const Shape & shape)582 HloInstruction* CreateDummyOp(HloComputation::Builder* b, const Shape& shape) {
583 if (shape.IsArray()) {
584 auto zero = b->AddInstruction(HloInstruction::CreateConstant(
585 LiteralUtil::Zero(shape.element_type())));
586 return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
587 }
588 CHECK(shape.IsTuple());
589 std::vector<HloInstruction*> sub_instructions;
590 for (const Shape& subshape : shape.tuple_shapes()) {
591 sub_instructions.push_back(CreateDummyOp(b, subshape));
592 }
593 return b->AddInstruction(HloInstruction::CreateTuple(sub_instructions));
594 }
595
// Creates a new computation with the signature (domain) -> range. Parameters
// are named "param.0", "param.1", ... in order.
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
    absl::Span<const Shape* const> domain, const Shape& range,
    absl::string_view name) {
  HloComputation::Builder b{string(name)};
  int64_t param_idx = 0;
  for (const Shape* param_shape : domain) {
    b.AddInstruction(HloInstruction::CreateParameter(
        param_idx, *param_shape, StrCat("param.", param_idx)));
    param_idx++;
  }

  // We can't change the root type of a computation once it is created so create
  // a dummy root instruction to give the computation the right root shape. Use
  // a (recursive) broadcast here to avoid creating large constants.
  CreateDummyOp(&b, range);
  return b.Build();
}
613
614 } // namespace xla
615