1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
17 #include "absl/algorithm/container.h"
18 #include "absl/memory/memory.h"
19 #include "absl/strings/str_cat.h"
20 #include "tensorflow/compiler/xla/client/lib/comparators.h"
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/client/xla_computation.h"
23 #include "tensorflow/compiler/xla/literal.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
29 #include "tensorflow/compiler/xla/service/shape_inference.h"
30 #include "tensorflow/compiler/xla/util.h"
31
32 namespace xla {
33 using absl::StrCat;
34
MakeBinaryHlo(HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs)35 StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
36 HloInstruction* rhs) {
37 HloComputation* computation = lhs->parent();
38 CHECK_EQ(computation, rhs->parent());
39 TF_ASSIGN_OR_RETURN(Shape binary_op_shape,
40 ShapeInference::InferBinaryOpShape(opcode, lhs, rhs));
41 return computation->AddInstruction(
42 HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs));
43 }
44
MakeCompareHlo(ComparisonDirection direction,HloInstruction * lhs,HloInstruction * rhs)45 StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction,
46 HloInstruction* lhs,
47 HloInstruction* rhs) {
48 HloComputation* computation = lhs->parent();
49 CHECK_EQ(computation, rhs->parent());
50 TF_ASSIGN_OR_RETURN(
51 Shape binary_op_shape,
52 ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs));
53 return computation->AddInstruction(
54 HloInstruction::CreateCompare(binary_op_shape, lhs, rhs, direction));
55 }
56
MakePadHlo(HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)57 StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
58 HloInstruction* padding_value,
59 const PaddingConfig& padding_config) {
60 HloComputation* computation = operand->parent();
61 CHECK_EQ(computation, padding_value->parent());
62 TF_ASSIGN_OR_RETURN(
63 Shape pad_shape,
64 ShapeInference::InferPadShape(operand->shape(), padding_value->shape(),
65 padding_config));
66 return computation->AddInstruction(HloInstruction::CreatePad(
67 pad_shape, operand, padding_value, padding_config));
68 }
69
MakeSliceHlo(HloInstruction * operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)70 StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
71 absl::Span<const int64> start_indices,
72 absl::Span<const int64> limit_indices,
73 absl::Span<const int64> strides) {
74 HloComputation* computation = operand->parent();
75 TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape(
76 operand->shape(), start_indices,
77 limit_indices, strides));
78 return computation->AddInstruction(HloInstruction::CreateSlice(
79 slice_shape, operand, start_indices, limit_indices, strides));
80 }
81
MakeConvolveHlo(HloInstruction * lhs,HloInstruction * rhs,int64 feature_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)82 StatusOr<HloInstruction*> MakeConvolveHlo(
83 HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
84 const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
85 const PrecisionConfig& precision_config) {
86 HloComputation* computation = lhs->parent();
87 CHECK_EQ(computation, rhs->parent());
88 TF_ASSIGN_OR_RETURN(Shape convolve_shape,
89 ShapeInference::InferConvolveShape(
90 lhs->shape(), rhs->shape(), feature_group_count, 1,
91 window, dimension_numbers));
92 return computation->AddInstruction(HloInstruction::CreateConvolve(
93 convolve_shape, lhs, rhs, feature_group_count, 1, window,
94 dimension_numbers, precision_config));
95 }
96
MakeTransposeHlo(HloInstruction * operand,absl::Span<const int64> dimensions)97 StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
98 absl::Span<const int64> dimensions) {
99 HloComputation* computation = operand->parent();
100 TF_ASSIGN_OR_RETURN(
101 Shape transpose_shape,
102 ShapeInference::InferTransposeShape(operand->shape(), dimensions));
103 return computation->AddInstruction(
104 HloInstruction::CreateTranspose(transpose_shape, operand, dimensions));
105 }
106
MakeReshapeHlo(const Shape & result_shape,HloInstruction * operand)107 StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
108 HloInstruction* operand) {
109 HloComputation* computation = operand->parent();
110 return computation->AddInstruction(
111 HloInstruction::CreateReshape(result_shape, operand));
112 }
113
MakeReshapeHlo(absl::Span<const int64> result_shape_dim_bounds,HloInstruction * operand)114 StatusOr<HloInstruction*> MakeReshapeHlo(
115 absl::Span<const int64> result_shape_dim_bounds, HloInstruction* operand) {
116 Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
117 result_shape_dim_bounds);
118 return MakeReshapeHlo(new_shape, operand);
119 }
120
MakeDynamicSliceHlo(HloInstruction * operand,HloInstruction * start_indices,absl::Span<const int64> slice_sizes)121 StatusOr<HloInstruction*> MakeDynamicSliceHlo(
122 HloInstruction* operand, HloInstruction* start_indices,
123 absl::Span<const int64> slice_sizes) {
124 HloComputation* computation = operand->parent();
125 CHECK_EQ(computation, start_indices->parent());
126 int64 rank = start_indices->shape().dimensions(0);
127 std::vector<HloInstruction*> scalar_start_indices;
128 for (int i = 0; i < rank; ++i) {
129 // TODO(b/118437727): Update callers to provide scalars directly.
130 auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
131 ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
132 start_indices, {i}, {i + 1}, {1}));
133 scalar_start_indices.push_back(
134 computation->AddInstruction(HloInstruction::CreateReshape(
135 ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
136 slice)));
137 }
138 std::vector<Shape> scalar_start_indices_shapes(
139 rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
140 TF_ASSIGN_OR_RETURN(
141 Shape dynamic_slice_shape,
142 ShapeInference::InferDynamicSliceShape(
143 operand->shape(), scalar_start_indices_shapes, slice_sizes));
144 return computation->AddInstruction(HloInstruction::CreateDynamicSlice(
145 dynamic_slice_shape, operand, scalar_start_indices, slice_sizes));
146 }
147
MakeDynamicUpdateSliceHlo(HloInstruction * operand,HloInstruction * update,HloInstruction * start_indices)148 StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
149 HloInstruction* operand, HloInstruction* update,
150 HloInstruction* start_indices) {
151 HloComputation* computation = operand->parent();
152 CHECK_EQ(computation, update->parent());
153 CHECK_EQ(computation, start_indices->parent());
154 int64 rank = start_indices->shape().dimensions(0);
155 std::vector<HloInstruction*> scalar_start_indices;
156 for (int i = 0; i < rank; ++i) {
157 // TODO(b/118437727): Update callers to provide scalars directly.
158 auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
159 ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
160 start_indices, {i}, {i + 1}, {1}));
161 scalar_start_indices.push_back(
162 computation->AddInstruction(HloInstruction::CreateReshape(
163 ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
164 slice)));
165 }
166 std::vector<Shape> scalar_start_indices_shapes(
167 rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
168 TF_ASSIGN_OR_RETURN(
169 Shape dynamic_update_slice_shape,
170 ShapeInference::InferDynamicUpdateSliceShape(
171 operand->shape(), update->shape(), scalar_start_indices_shapes));
172 return computation->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
173 dynamic_update_slice_shape, operand, update, scalar_start_indices));
174 }
175
MakeBroadcastHlo(HloInstruction * operand,absl::Span<const int64> broadcast_dimensions,absl::Span<const int64> result_shape_bounds)176 HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
177 absl::Span<const int64> broadcast_dimensions,
178 absl::Span<const int64> result_shape_bounds) {
179 HloComputation* computation = operand->parent();
180 Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
181 result_shape_bounds);
182
183 return computation->AddInstruction(HloInstruction::CreateBroadcast(
184 broadcast_shape, operand, broadcast_dimensions));
185 }
186
MakeGetTupleElementHlo(HloInstruction * operand,int64 index)187 StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
188 int64 index) {
189 HloComputation* computation = operand->parent();
190
191 TF_ASSIGN_OR_RETURN(
192 Shape gte_shape,
193 ShapeInference::InferGetTupleElementShape(operand->shape(), index));
194 return computation->AddInstruction(
195 HloInstruction::CreateGetTupleElement(gte_shape, operand, index));
196 }
197
MakeConcatHlo(absl::Span<HloInstruction * const> operands,int64 dimension)198 StatusOr<HloInstruction*> MakeConcatHlo(
199 absl::Span<HloInstruction* const> operands, int64 dimension) {
200 CHECK_GT(operands.size(), 0);
201
202 HloComputation* computation = operands[0]->parent();
203 CHECK(absl::c_all_of(operands, [&](HloInstruction* instr) {
204 return instr->parent() == computation;
205 }));
206
207 std::vector<const Shape*> operand_shapes;
208 absl::c_transform(operands, std::back_inserter(operand_shapes),
209 [](HloInstruction* instr) { return &instr->shape(); });
210
211 TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape(
212 operand_shapes, dimension));
213 return computation->AddInstruction(
214 HloInstruction::CreateConcatenate(concat_shape, operands, dimension));
215 }
216
MakeDotHlo(HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dim_numbers,const PrecisionConfig & precision_config)217 StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
218 const DotDimensionNumbers& dim_numbers,
219 const PrecisionConfig& precision_config) {
220 HloComputation* computation = lhs->parent();
221 CHECK_EQ(computation, rhs->parent());
222 TF_ASSIGN_OR_RETURN(
223 Shape dot_shape,
224 ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers));
225 return computation->AddInstruction(HloInstruction::CreateDot(
226 dot_shape, lhs, rhs, dim_numbers, precision_config));
227 }
228
MakeMapHlo(absl::Span<HloInstruction * const> operands,HloComputation * map_computation)229 StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
230 HloComputation* map_computation) {
231 CHECK(!operands.empty()) << "Map Hlo requires at least one operand.";
232 HloComputation* computation = operands.front()->parent();
233 std::vector<const Shape*> operand_shapes;
234 int64 max_operand_rank = 0;
235 for (const HloInstruction* operand : operands) {
236 CHECK_EQ(computation, operand->parent());
237 operand_shapes.push_back(&operand->shape());
238 max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
239 }
240 std::vector<int64> map_dims(max_operand_rank);
241 std::iota(map_dims.begin(), map_dims.end(), 0);
242 TF_ASSIGN_OR_RETURN(
243 Shape map_shape,
244 ShapeInference::InferMapShape(
245 operand_shapes, map_computation->ComputeProgramShape(), map_dims));
246 return computation->AddInstruction(
247 HloInstruction::CreateMap(map_shape, operands, map_computation));
248 }
249
MakeReduceHlo(HloInstruction * operand,HloInstruction * init_value,HloOpcode binary_opcode,HloModule * module)250 StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
251 HloInstruction* init_value,
252 HloOpcode binary_opcode,
253 HloModule* module) {
254 DCHECK_NE(nullptr, module);
255 std::vector<int64> all_dims(operand->shape().rank());
256 std::iota(all_dims.begin(), all_dims.end(), 0);
257
258 auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
259 HloComputation* reduce_computation;
260 {
261 HloComputation::Builder b(operand->name() + ".reduce_sub_computation");
262 auto lhs = b.AddInstruction(
263 HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
264 auto rhs = b.AddInstruction(
265 HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
266 b.AddInstruction(
267 HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs));
268 reduce_computation = module->AddEmbeddedComputation(b.Build());
269 }
270
271 return operand->parent()->AddInstruction(HloInstruction::CreateReduce(
272 scalar_shape, operand, init_value, all_dims, reduce_computation));
273 }
274
MakeSelectHlo(HloInstruction * pred,HloInstruction * on_true,HloInstruction * on_false)275 StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
276 HloInstruction* on_true,
277 HloInstruction* on_false) {
278 HloComputation* computation = pred->parent();
279 DCHECK_EQ(computation, on_true->parent());
280 DCHECK_EQ(computation, on_false->parent());
281 TF_ASSIGN_OR_RETURN(Shape select_shape,
282 ShapeInference::InferTernaryOpShape(
283 HloOpcode::kSelect, pred, on_true, on_false));
284 return computation->AddInstruction(HloInstruction::CreateTernary(
285 select_shape, HloOpcode::kSelect, pred, on_true, on_false));
286 }
287
MakeSortHlo(const Shape & sort_shape,absl::Span<HloInstruction * const> operands,int64 dimension_to_sort,bool is_stable,HloComputation::Builder * builder,HloModule * module)288 StatusOr<HloInstruction*> MakeSortHlo(
289 const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
290 int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
291 HloModule* module) {
292 CHECK(!operands.empty()) << "Sort Hlo requires at least one operand.";
293 HloComputation* compare_computation;
294 XlaBuilder b("Sort.Compare");
295 std::vector<PrimitiveType> operand_types(operands.size());
296 for (int64 i = 0; i < operands.size(); ++i) {
297 operand_types[i] = operands[i]->shape().element_type();
298 }
299 XlaComputation comparator = CreateScalarLtComputation(operand_types, &b);
300 TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
301 HloModuleConfig config(program_shape);
302 TF_ASSIGN_OR_RETURN(auto new_module,
303 HloModule::CreateFromProto(comparator.proto(), config));
304 HloCloneContext context(module);
305 compare_computation =
306 module->DeepCloneComputation(new_module->entry_computation(), &context);
307 return builder->AddInstruction(HloInstruction::CreateSort(
308 sort_shape, dimension_to_sort, operands, compare_computation, is_stable));
309 }
310
CollapseFirstNDims(HloInstruction * operand,int64 n)311 StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) {
312 CHECK_GT(n, 0);
313
314 const Shape& operand_shape = operand->shape();
315 CHECK_GE(operand_shape.dimensions_size(), n);
316 int64 new_shape_leading_bound = 1;
317 for (int64 i = 0; i < n; i++) {
318 new_shape_leading_bound *= operand_shape.dimensions(i);
319 }
320
321 std::vector<int64> new_shape_dims;
322 new_shape_dims.reserve(operand_shape.dimensions_size() - n + 1);
323 new_shape_dims.push_back(new_shape_leading_bound);
324
325 std::copy(operand_shape.dimensions().begin() + n,
326 operand_shape.dimensions().end(),
327 std::back_inserter(new_shape_dims));
328
329 Shape output_shape =
330 ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims);
331
332 return MakeReshapeHlo(output_shape, operand);
333 }
334
PrependDegenerateDims(HloInstruction * operand,int64 n)335 StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
336 int64 n) {
337 CHECK_GT(n, 0);
338 std::vector<int64> new_shape_dims;
339 const Shape& operand_shape = operand->shape();
340 new_shape_dims.reserve(n + operand_shape.dimensions_size());
341 new_shape_dims.insert(new_shape_dims.begin(), n, 1);
342 absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims));
343 return MakeReshapeHlo(new_shape_dims, operand);
344 }
345
ExpandFirstDimIntoNDims(HloInstruction * operand,absl::Span<const int64> expanded_dims)346 StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
347 HloInstruction* operand, absl::Span<const int64> expanded_dims) {
348 CHECK_GT(operand->shape().dimensions_size(), 0);
349 CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims));
350
351 std::vector<int64> expanded_shape_dim_bounds;
352 expanded_shape_dim_bounds.reserve(expanded_dims.size() +
353 operand->shape().dimensions_size() - 1);
354 absl::c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
355 std::copy(operand->shape().dimensions().begin() + 1,
356 operand->shape().dimensions().end(),
357 std::back_inserter(expanded_shape_dim_bounds));
358 Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
359 expanded_shape_dim_bounds);
360 return MakeReshapeHlo(new_shape, operand);
361 }
362
ElideDegenerateDims(HloInstruction * operand,absl::Span<const int64> dims_to_elide)363 StatusOr<HloInstruction*> ElideDegenerateDims(
364 HloInstruction* operand, absl::Span<const int64> dims_to_elide) {
365 CHECK(absl::c_is_sorted(dims_to_elide));
366
367 const Shape& input_shape = operand->shape();
368 // First accumulate in reverse
369 std::vector<int64> new_shape_dim_bounds;
370 new_shape_dim_bounds.reserve(input_shape.dimensions_size() -
371 dims_to_elide.size());
372 int64 dims_to_elide_idx = dims_to_elide.size() - 1;
373 for (int64 i = input_shape.dimensions_size() - 1; i >= 0; i--) {
374 if (dims_to_elide_idx >= 0 && i == dims_to_elide[dims_to_elide_idx]) {
375 CHECK_EQ(input_shape.dimensions(i), 1);
376 dims_to_elide_idx--;
377 } else {
378 new_shape_dim_bounds.push_back(input_shape.dimensions(i));
379 }
380 }
381
382 absl::c_reverse(new_shape_dim_bounds);
383 Shape output_shape =
384 ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds);
385 return MakeReshapeHlo(output_shape, operand);
386 }
387
InsertDegenerateDims(HloInstruction * operand,absl::Span<const int64> dims_to_insert)388 StatusOr<HloInstruction*> InsertDegenerateDims(
389 HloInstruction* operand, absl::Span<const int64> dims_to_insert) {
390 CHECK(absl::c_is_sorted(dims_to_insert));
391
392 const Shape& operand_shape = operand->shape();
393 int64 output_shape_rank =
394 operand_shape.dimensions_size() + dims_to_insert.size();
395 for (auto dim_to_insert : dims_to_insert) {
396 CHECK_LT(dim_to_insert, output_shape_rank);
397 }
398
399 std::vector<int64> output_shape_dim_bounds;
400 output_shape_dim_bounds.reserve(output_shape_rank);
401 int64 operand_dims_idx = 0;
402 int64 dims_to_insert_idx = 0;
403 for (int64 i = 0; i < output_shape_rank; ++i) {
404 if (dims_to_insert_idx < dims_to_insert.size() &&
405 i == dims_to_insert[dims_to_insert_idx]) {
406 output_shape_dim_bounds.push_back(1);
407 ++dims_to_insert_idx;
408 } else {
409 output_shape_dim_bounds.push_back(
410 operand_shape.dimensions(operand_dims_idx));
411 ++operand_dims_idx;
412 }
413 }
414
415 Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(),
416 output_shape_dim_bounds);
417 return MakeReshapeHlo(output_shape, operand);
418 }
419
PadVectorWithZeros(HloInstruction * operand,int64 zeros_to_prepend,int64 zeros_to_append)420 StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
421 int64 zeros_to_prepend,
422 int64 zeros_to_append) {
423 HloComputation* computation = operand->parent();
424 CHECK_EQ(operand->shape().dimensions_size(), 1);
425 PaddingConfig padding_config;
426 PaddingConfig::PaddingConfigDimension padding_config_dim;
427 padding_config_dim.set_edge_padding_low(zeros_to_prepend);
428 padding_config_dim.set_edge_padding_high(zeros_to_append);
429 *padding_config.add_dimensions() = padding_config_dim;
430
431 HloInstruction* zero =
432 computation->AddInstruction(HloInstruction::CreateConstant(
433 LiteralUtil::Zero(operand->shape().element_type())));
434 return MakePadHlo(operand, zero, padding_config);
435 }
436
BroadcastZeros(HloComputation * computation,PrimitiveType element_type,absl::Span<const int64> broadcast_dimensions)437 HloInstruction* BroadcastZeros(HloComputation* computation,
438 PrimitiveType element_type,
439 absl::Span<const int64> broadcast_dimensions) {
440 HloInstruction* zero = computation->AddInstruction(
441 HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
442 return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
443 /*result_shape_bounds=*/broadcast_dimensions);
444 }
445
CreateComputationWithSignature(absl::Span<const Shape * const> domain,const Shape & range,absl::string_view name)446 StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
447 absl::Span<const Shape* const> domain, const Shape& range,
448 absl::string_view name) {
449 HloComputation::Builder b{string(name)};
450 int64 param_idx = 0;
451 for (const Shape* param_shape : domain) {
452 b.AddInstruction(HloInstruction::CreateParameter(
453 param_idx, *param_shape, StrCat("param.", param_idx)));
454 param_idx++;
455 }
456
457 // We can't change the root type of a computation once it is created so create
458 // a dummy root instruction to give the computation the right root shape. In
459 // the future we may want to use a (recursive) broadcast here to avoid
460 // creating large constants.
461 b.AddInstruction(
462 HloInstruction::CreateConstant(Literal::CreateFromShape(range)));
463
464 return b.Build();
465 }
466
467 } // namespace xla
468