1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
17
18 #include <unordered_map>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/types/optional.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
27 #include "mlir/IR/Attributes.h" // from @llvm-project
28 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
29 #include "mlir/IR/Builders.h" // from @llvm-project
30 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
31 #include "mlir/IR/Identifier.h" // from @llvm-project
32 #include "mlir/IR/Location.h" // from @llvm-project
33 #include "mlir/IR/Region.h" // from @llvm-project
34 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
35 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
36 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
37 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
38 #include "tensorflow/compiler/xla/comparison_util.h"
39 #include "tensorflow/compiler/xla/protobuf_util.h"
40 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
41 #include "tensorflow/compiler/xla/service/hlo_computation.h"
42 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
43 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
44 #include "tensorflow/compiler/xla/service/hlo_module.h"
45 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
46 #include "tensorflow/compiler/xla/status_macros.h"
47 #include "tensorflow/compiler/xla/xla_data.pb.h"
48
49 using llvm::APInt;
50 using llvm::makeArrayRef;
51 using mlir::DenseIntElementsAttr;
52 using mlir::FuncOp;
53 using mlir::NamedAttribute;
54 using mlir::Operation;
55 using mlir::RankedTensorType;
56 using mlir::Type;
57 using mlir::Value;
58
59 namespace xla {
60
61 namespace {
62
63 // Note: This sanitization function causes an irreversible many-to-one mapping
64 // and any solution to mitigate this would cause issues with the reverse
65 // direction. Longterm solution is to add a function attribute to maintain the
66 // original HLO naming.
SanitizeFunctionName(llvm::StringRef name)67 string SanitizeFunctionName(llvm::StringRef name) {
68 string output(name);
69 llvm::for_each(output, [](char& x) { x = x == '-' ? '_' : x; });
70 return output;
71 }
72
73 // Returns whether the instruction is a default dot operation.
DotIsDefault(const HloInstruction * instruction)74 bool DotIsDefault(const HloInstruction* instruction) {
75 auto dnums = instruction->dot_dimension_numbers();
76 DotDimensionNumbers default_dimension_numbers;
77 default_dimension_numbers.add_lhs_contracting_dimensions(
78 instruction->operand(0)->shape().dimensions_size() == 1 ? 0 : 1);
79 default_dimension_numbers.add_rhs_contracting_dimensions(0);
80 return xla::protobuf_util::ProtobufEquals(dnums, default_dimension_numbers);
81 }
82
// Returns an MLIR Location generated from HLO Instruction. Uses instruction
// metadata if present or instruction name.
mlir::Location GenerateInstructionLocation(HloInstruction* instruction,
                                           mlir::OpBuilder* func_builder) {
  const std::string& op_name = instruction->metadata().op_name();
  // No frontend metadata: fall back to the HLO instruction's own name.
  if (op_name.empty()) {
    return mlir::NameLoc::get(func_builder->getIdentifier(instruction->name()),
                              func_builder->getContext());
  }

  mlir::Location op_name_loc = mlir::NameLoc::get(
      func_builder->getIdentifier(op_name), func_builder->getContext());
  const std::string& source_file = instruction->metadata().source_file();
  // Metadata has an op name but no source file: use the name-only location.
  if (source_file.empty()) {
    return op_name_loc;
  }

  // Fuse the op name with its source file/line; the column is unknown, so 0.
  return mlir::FusedLoc::get(
      {op_name_loc, mlir::FileLineColLoc::get(
                        source_file, instruction->metadata().source_line(), 0,
                        func_builder->getContext())},
      func_builder->getContext());
}
106 } // namespace
107
ImportAsFunc(const HloComputation & computation,mlir::ModuleOp module,std::unordered_map<const HloComputation *,FuncOp> * function_map,mlir::Builder * builder)108 Status HloFunctionImporter::ImportAsFunc(
109 const HloComputation& computation, mlir::ModuleOp module,
110 std::unordered_map<const HloComputation*, FuncOp>* function_map,
111 mlir::Builder* builder) {
112 HloFunctionImporter importer(module, function_map, builder);
113 return importer.ImportAsFunc(computation).status();
114 }
115
ImportAsRegion(const xla::HloComputation & computation,mlir::Region * region,mlir::Builder * builder)116 Status HloFunctionImporter::ImportAsRegion(
117 const xla::HloComputation& computation, mlir::Region* region,
118 mlir::Builder* builder) {
119 HloFunctionImporter importer(region->getParentOfType<mlir::ModuleOp>(), {},
120 builder);
121 return importer.ImportAsRegion(computation, region);
122 }
123
ImportAsFunc(const HloComputation & computation)124 StatusOr<mlir::FuncOp> HloFunctionImporter::ImportAsFunc(
125 const HloComputation& computation) {
126 auto& imported = (*function_map_)[&computation];
127 if (imported) return imported;
128 llvm::SmallVector<Type, 4> args, rets;
129 TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
130 TF_RETURN_IF_ERROR(GetMlirTypes({computation.root_instruction()}, &rets));
131 auto func_type = mlir::FunctionType::get(context_, args, rets);
132
133 string computation_name =
134 computation.parent()->entry_computation() == &computation
135 ? "main"
136 : SanitizeFunctionName(computation.name());
137
138 // Construct the MLIR function and map arguments.
139 llvm::ArrayRef<mlir::NamedAttribute> attrs;
140 auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
141 computation_name, func_type, attrs);
142 auto visibility = computation_name == "main" ? FuncOp::Visibility::Public
143 : FuncOp::Visibility::Private;
144 function.setVisibility(visibility);
145 module_.push_back(function);
146
147 // Add to the map right away for function calls.
148 imported = function;
149
150 mlir::Block* block = function.addEntryBlock();
151 TF_RETURN_IF_ERROR(ImportInstructions(computation, block));
152
153 return function;
154 }
155
ImportAsRegion(const HloComputation & computation,mlir::Region * region)156 tensorflow::Status HloFunctionImporter::ImportAsRegion(
157 const HloComputation& computation, mlir::Region* region) {
158 // TODO(hinsu): Store computation name as an attribute for round-trip.
159 auto* block = new mlir::Block;
160 region->push_back(block);
161
162 llvm::SmallVector<Type, 4> args;
163 TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
164 block->addArguments(args);
165
166 return ImportInstructions(computation, block);
167 }
168
ImportInstructionsImpl(const xla::HloComputation & computation,const llvm::SmallVectorImpl<Value> & arguments,mlir::OpBuilder * builder)169 StatusOr<Value> HloFunctionImporter::ImportInstructionsImpl(
170 const xla::HloComputation& computation,
171 const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
172 // Setup the input parameters.
173 const int num_parameters = computation.num_parameters();
174
175 if (arguments.size() != num_parameters)
176 return InvalidArgument("Caller vs callee argument sizes do not match");
177
178 for (int i = 0; i < num_parameters; i++) {
179 auto hlo_parameter = computation.parameter_instruction(i);
180 instruction_value_map_[hlo_parameter] = arguments[i];
181 }
182
183 for (auto instruction : computation.MakeInstructionPostOrder()) {
184 TF_ASSIGN_OR_RETURN(auto new_operation,
185 ImportInstruction(instruction, builder));
186 if (new_operation) {
187 instruction_value_map_[instruction] = new_operation->getResult(0);
188 }
189 }
190
191 // Setup the return type (HLO only supports a single return value).
192 return GetMlirValue(computation.root_instruction());
193 }
194
ImportInstructions(const HloComputation & computation,mlir::Block * block)195 Status HloFunctionImporter::ImportInstructions(
196 const HloComputation& computation, mlir::Block* block) {
197 llvm::SmallVector<Value, 4> arguments(block->args_begin(), block->args_end());
198 mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
199 TF_ASSIGN_OR_RETURN(Value result,
200 ImportInstructionsImpl(computation, arguments, &builder));
201
202 // TODO(suderman): Add location tracking details.
203 mlir::Location loc = builder.getUnknownLoc();
204
205 // Create terminator op depending on the parent op of this region.
206 if (llvm::isa<FuncOp>(block->getParentOp())) {
207 builder.create<mlir::ReturnOp>(loc, result);
208 } else {
209 builder.create<mlir::mhlo::ReturnOp>(loc, result);
210 }
211 return tensorflow::Status::OK();
212 }
213
ImportInstructions(const xla::HloComputation & computation,const llvm::SmallVectorImpl<Value> & arguments,mlir::OpBuilder * builder)214 StatusOr<Value> HloFunctionImporter::ImportInstructions(
215 const xla::HloComputation& computation,
216 const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
217 mlir::Block* block = builder->getBlock();
218 if (block == nullptr)
219 return InvalidArgument(
220 "ImportInstructions requires a valid block in the builder");
221
222 HloFunctionImporter importer(
223 block->getParent()->getParentOfType<mlir::ModuleOp>(), {}, builder);
224 return importer.ImportInstructionsImpl(computation, arguments, builder);
225 }
226
ImportInstructionImpl(HloInstruction * instruction,mlir::OpBuilder * func_builder)227 StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
228 HloInstruction* instruction, mlir::OpBuilder* func_builder) {
229 TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction));
230 TF_ASSIGN_OR_RETURN(auto result_type, ConvertShapeToType<RankedTensorType>(
231 instruction->shape(), *builder_));
232 mlir::Location loc = GenerateInstructionLocation(instruction, func_builder);
233
234 llvm::SmallVector<NamedAttribute, 10> attributes;
235 switch (instruction->opcode()) {
236 case HloOpcode::kParameter: {
237 return nullptr;
238 }
239 case HloOpcode::kConstant: {
240 const Literal& literal = instruction->literal();
241 auto attr = CreateDenseElementsAttrFromLiteral(literal, *builder_);
242 if (!attr.ok()) return attr.status();
243 mlir::Operation* new_operation =
244 func_builder->create<mlir::mhlo::ConstOp>(loc, attr.ValueOrDie());
245 for (auto attr : attributes) {
246 new_operation->setAttr(attr.first, attr.second);
247 }
248 return new_operation;
249 }
250 case HloOpcode::kIota: {
251 return func_builder
252 ->create<mlir::mhlo::IotaOp>(
253 loc, result_type,
254 func_builder->getI64IntegerAttr(
255 Cast<HloIotaInstruction>(instruction)->iota_dimension()))
256 .getOperation();
257 }
258 #define MakeAndReturn(mlir_op) \
259 { \
260 mlir::Operation* new_operation = \
261 func_builder->create<mlir::mhlo::mlir_op>(loc, result_type, operands, \
262 attributes); \
263 return new_operation; \
264 }
265 case HloOpcode::kBroadcast: {
266 // Note that the HLO broadcast is more powerful than the XLA broadcast
267 // op. BroadcastInDim offers a superset of the HLO op's functionality.
268 attributes.push_back(
269 builder_->getNamedAttr("broadcast_dimensions",
270 ConvertDimensions(instruction->dimensions())));
271 MakeAndReturn(BroadcastInDimOp);
272 }
273 #define MakeAndReturnBatchNormOp(batch_norm_op) \
274 { \
275 attributes.push_back(builder_->getNamedAttr( \
276 "epsilon", builder_->getF32FloatAttr(instruction->epsilon()))); \
277 attributes.push_back(builder_->getNamedAttr( \
278 "feature_index", \
279 builder_->getI64IntegerAttr(instruction->feature_index()))); \
280 MakeAndReturn(batch_norm_op); \
281 }
282 case HloOpcode::kBatchNormGrad:
283 MakeAndReturnBatchNormOp(BatchNormGradOp);
284 case HloOpcode::kBatchNormInference:
285 MakeAndReturnBatchNormOp(BatchNormInferenceOp);
286 case HloOpcode::kBatchNormTraining:
287 MakeAndReturnBatchNormOp(BatchNormTrainingOp);
288 #undef MakeAndReturnBatchNormOp
289
290 case HloOpcode::kDot: {
291 attributes.push_back(builder_->getNamedAttr(
292 "precision_config",
293 ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
294
295 // Consider consolidating DotOps together.
296 if (DotIsDefault(instruction)) {
297 MakeAndReturn(DotOp);
298 }
299
300 attributes.push_back(builder_->getNamedAttr(
301 "dot_dimension_numbers",
302 ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(),
303 builder_)));
304 MakeAndReturn(DotGeneralOp);
305 }
306 case HloOpcode::kCall: {
307 TF_ASSIGN_OR_RETURN(FuncOp function,
308 ImportAsFunc(*instruction->to_apply()));
309 mlir::Operation* new_operation =
310 func_builder->create<mlir::CallOp>(loc, function, operands);
311 return new_operation;
312 }
313 case HloOpcode::kCollectivePermute: {
314 attributes.push_back(ConvertSourceTargetPairs(
315 instruction->source_target_pairs(), builder_));
316 MakeAndReturn(CollectivePermuteOp);
317 }
318 case HloOpcode::kCustomCall: {
319 auto custom_call = Cast<HloCustomCallInstruction>(instruction);
320 attributes.push_back(builder_->getNamedAttr(
321 "call_target_name",
322 builder_->getStringAttr(custom_call->custom_call_target())));
323 attributes.push_back(builder_->getNamedAttr(
324 "has_side_effect",
325 builder_->getBoolAttr(custom_call->custom_call_has_side_effect())));
326 attributes.push_back(builder_->getNamedAttr(
327 "backend_config",
328 builder_->getStringAttr(custom_call->raw_backend_config_string())));
329 MakeAndReturn(CustomCallOp);
330 }
331 case HloOpcode::kCompare: {
332 auto compare = Cast<HloCompareInstruction>(instruction);
333 attributes.push_back(ConvertComparisonDirection(compare->direction()));
334 auto default_type = Comparison::DefaultComparisonType(
335 compare->operand(0)->shape().element_type());
336 if (compare->type() != default_type)
337 attributes.push_back(ConvertComparisonType(compare->type()));
338 MakeAndReturn(CompareOp);
339 }
340 case HloOpcode::kCholesky: {
341 attributes.push_back(builder_->getNamedAttr(
342 "lower",
343 builder_->getBoolAttr(instruction->cholesky_options().lower())));
344 MakeAndReturn(CholeskyOp);
345 }
346 case HloOpcode::kGather: {
347 auto gather_instruction = Cast<HloGatherInstruction>(instruction);
348 attributes.push_back(builder_->getNamedAttr(
349 "dimension_numbers",
350 ConvertGatherDimensionNumbers(
351 gather_instruction->gather_dimension_numbers(), builder_)));
352
353 std::vector<int64_t> slice_sizes(
354 gather_instruction->gather_slice_sizes().begin(),
355 gather_instruction->gather_slice_sizes().end());
356 attributes.push_back(
357 builder_->getNamedAttr("slice_sizes", Convert(slice_sizes)));
358 attributes.push_back(builder_->getNamedAttr(
359 "indices_are_sorted",
360 builder_->getBoolAttr(gather_instruction->indices_are_sorted())));
361
362 MakeAndReturn(GatherOp);
363 }
364 case HloOpcode::kDynamicSlice: {
365 std::vector<int64_t> slice_sizes(
366 instruction->dynamic_slice_sizes().begin(),
367 instruction->dynamic_slice_sizes().end());
368 return func_builder
369 ->create<mlir::mhlo::DynamicSliceOp>(
370 loc, result_type, operands[0],
371 makeArrayRef(operands).drop_front(), Convert(slice_sizes))
372 .getOperation();
373 }
374 case HloOpcode::kDynamicUpdateSlice: {
375 return func_builder
376 ->create<mlir::mhlo::DynamicUpdateSliceOp>(
377 loc, result_type, operands[0], operands[1],
378 llvm::ArrayRef<Value>(operands.begin() + 2, operands.end()))
379 .getOperation();
380 }
381 case HloOpcode::kInfeed: {
382 attributes.push_back(builder_->getNamedAttr(
383 "infeed_config",
384 mlir::StringAttr::get(builder_->getContext(),
385 instruction->infeed_config())));
386 // TODO(kramm): Support tuples and tokens.
387 if (instruction->shape().IsArray()) {
388 const xla::Layout l = instruction->shape().layout();
389 absl::Span<const int64> minor_to_major = l.minor_to_major();
390 std::vector<mlir::Attribute> v;
391 for (int64 i : minor_to_major) {
392 v.push_back(builder_->getI32IntegerAttr(i));
393 }
394 llvm::ArrayRef<mlir::Attribute> array_ref(v);
395 mlir::ArrayAttr layout = builder_->getArrayAttr(array_ref);
396 attributes.push_back(builder_->getNamedAttr("layout", layout));
397 }
398
399 MakeAndReturn(InfeedOp);
400 }
401 case HloOpcode::kOutfeed: {
402 attributes.push_back(builder_->getNamedAttr(
403 "outfeed_config",
404 mlir::StringAttr::get(builder_->getContext(),
405 instruction->outfeed_config())));
406 MakeAndReturn(OutfeedOp);
407 }
408 case HloOpcode::kPad: {
409 const auto& padding_config = instruction->padding_config();
410 llvm::SmallVector<int64_t, 4> edge_padding_low;
411 llvm::SmallVector<int64_t, 4> edge_padding_high;
412 llvm::SmallVector<int64_t, 4> interior_padding;
413 edge_padding_low.reserve(padding_config.dimensions_size());
414 edge_padding_high.reserve(padding_config.dimensions_size());
415 interior_padding.reserve(padding_config.dimensions_size());
416
417 for (const auto& dimension : padding_config.dimensions()) {
418 edge_padding_low.push_back(dimension.edge_padding_low());
419 edge_padding_high.push_back(dimension.edge_padding_high());
420 interior_padding.push_back(dimension.interior_padding());
421 }
422
423 return func_builder
424 ->create<mlir::mhlo::PadOp>(loc, result_type, operands[0],
425 operands[1], Convert(edge_padding_low),
426 Convert(edge_padding_high),
427 Convert(interior_padding))
428 .getOperation();
429 }
430 case HloOpcode::kScatter: {
431 auto scatter = Cast<HloScatterInstruction>(instruction);
432 attributes.push_back(builder_->getNamedAttr(
433 "scatter_dimension_numbers",
434 ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers(),
435 builder_)));
436 attributes.push_back(builder_->getNamedAttr(
437 "indices_are_sorted",
438 builder_->getBoolAttr(scatter->indices_are_sorted())));
439 attributes.push_back(builder_->getNamedAttr(
440 "unique_indices", builder_->getBoolAttr(scatter->unique_indices())));
441
442 auto scatter_op = func_builder->create<mlir::mhlo::ScatterOp>(
443 loc, result_type, operands, attributes);
444 TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(),
445 &scatter_op.update_computation()));
446 return scatter_op.getOperation();
447 }
448 case HloOpcode::kSelectAndScatter: {
449 auto select_scatter = Cast<HloSelectAndScatterInstruction>(instruction);
450 llvm::SmallVector<int64_t, 4> window_strides, window_dimensions;
451 llvm::SmallVector<int64_t, 8> padding;
452 for (const auto& dim : select_scatter->window().dimensions()) {
453 window_strides.push_back(dim.stride());
454 window_dimensions.push_back(dim.size());
455 padding.push_back(dim.padding_low());
456 padding.push_back(dim.padding_high());
457 }
458 attributes.push_back(
459 builder_->getNamedAttr("window_strides", Convert(window_strides)));
460 attributes.push_back(builder_->getNamedAttr("window_dimensions",
461 Convert(window_dimensions)));
462 attributes.push_back(ConvertPadding(padding));
463 auto select_scatter_op =
464 func_builder->create<mlir::mhlo::SelectAndScatterOp>(
465 loc, result_type, operands, attributes);
466 TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(),
467 &select_scatter_op.select()));
468 TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->scatter(),
469 &select_scatter_op.scatter()));
470 return select_scatter_op.getOperation();
471 }
472 case HloOpcode::kSetDimensionSize: {
473 attributes.push_back(builder_->getNamedAttr(
474 "dimension", builder_->getI64IntegerAttr(instruction->dimension())));
475 MakeAndReturn(SetDimensionSizeOp);
476 }
477 case HloOpcode::kSlice: {
478 return func_builder
479 ->create<mlir::mhlo::SliceOp>(
480 loc, result_type, operands[0],
481 ConvertDimensions(instruction->slice_starts()),
482 ConvertDimensions(instruction->slice_limits()),
483 ConvertDimensions(instruction->slice_strides()))
484 .getOperation();
485 }
486 case HloOpcode::kSort: {
487 auto sort_instruction = Cast<HloSortInstruction>(instruction);
488
489 llvm::SmallVector<Type, 4> return_types = {result_type};
490 if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
491 return_types = llvm::to_vector<6>(tuple_ty.getTypes());
492 }
493
494 auto sort_op = func_builder->create<mlir::mhlo::SortOp>(
495 loc, return_types, operands,
496 builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
497 builder_->getBoolAttr(sort_instruction->is_stable()));
498 TF_RETURN_IF_ERROR(
499 ImportAsRegion(*sort_instruction->to_apply(), &sort_op.comparator()));
500
501 // Check if the output needs to be tupled.
502 if (return_types.size() == 1 && return_types.front() == result_type) {
503 return sort_op.getOperation();
504 }
505
506 return func_builder
507 ->create<mlir::mhlo::TupleOp>(loc, result_type, sort_op.getResults())
508 .getOperation();
509 }
510 case HloOpcode::kConditional: {
511 llvm::SmallVector<Type, 4> rets;
512 mlir::Type pred_or_index_type =
513 operands[0].getType().cast<mlir::TensorType>().getElementType();
514 // It is a predicated conditional if first argument is a boolean and
515 // should be mapped to If op.
516 if (pred_or_index_type.isInteger(1)) {
517 TF_RETURN_IF_ERROR(GetMlirTypes(
518 {instruction->true_computation()->root_instruction()}, &rets));
519
520 auto op = func_builder->create<mlir::mhlo::IfOp>(loc, rets, operands,
521 attributes);
522 TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(),
523 &op.true_branch()));
524 TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->false_computation(),
525 &op.false_branch()));
526 return op.getOperation();
527 }
528
529 // Otherwise, it is a indexed conditional and should be mapped to Case
530 // op.
531 TF_RETURN_IF_ERROR(GetMlirTypes(
532 {instruction->branch_computation(0)->root_instruction()}, &rets));
533
534 int num_branches = instruction->branch_count();
535 auto op = func_builder->create<mlir::mhlo::CaseOp>(
536 loc, rets, operands, attributes, num_branches);
537 for (auto index_and_computation :
538 llvm::enumerate(instruction->branch_computations())) {
539 auto index = index_and_computation.index();
540 HloComputation* computation = index_and_computation.value();
541 TF_RETURN_IF_ERROR(ImportAsRegion(*computation, &op.branches()[index]));
542 }
543 return op.getOperation();
544 }
545 case HloOpcode::kConcatenate: {
546 // TODO(b/132057942): Support taking an uint64_t instead of an
547 // IntegerAttr for concatenate dimension.
548 return func_builder
549 ->create<mlir::mhlo::ConcatenateOp>(
550 loc, result_type, operands,
551 builder_->getI64IntegerAttr(instruction->concatenate_dimension()))
552 .getOperation();
553 }
554 case HloOpcode::kAllReduce: {
555 auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
556 attributes.push_back(
557 ConvertReplicaGroups(all_reduce->replica_groups(), builder_));
558 attributes.push_back(ConvertChannelHandle(all_reduce->channel_id()));
559 auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
560 loc, result_type, operands, attributes);
561 TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
562 &all_reduce_op.computation()));
563 return all_reduce_op.getOperation();
564 }
565 case HloOpcode::kReduce: {
566 // Operands in the first half are reduction inputs and the remaining
567 // operands are corresponding initial values.
568 size_t num_inputs = operands.size() / 2;
569 auto reduce = func_builder->create<mlir::mhlo::ReduceOp>(
570 loc, result_type, llvm::makeArrayRef(operands).take_front(num_inputs),
571 llvm::makeArrayRef(operands).drop_front(num_inputs),
572 ConvertDimensions(instruction->dimensions()));
573 TF_RETURN_IF_ERROR(
574 ImportAsRegion(*instruction->to_apply(), &reduce.body()));
575 return reduce.getOperation();
576 }
577 case HloOpcode::kReverse: {
578 return func_builder
579 ->create<mlir::mhlo::ReverseOp>(
580 loc, result_type, operands[0],
581 ConvertDimensions(instruction->dimensions()))
582 .getOperation();
583 }
584 case HloOpcode::kRng: {
585 auto shape = func_builder->create<mlir::ConstantOp>(
586 loc, Convert(result_type.cast<RankedTensorType>().getShape()));
587 switch (instruction->random_distribution()) {
588 case xla::RNG_UNIFORM:
589 return func_builder
590 ->create<mlir::mhlo::RngUniformOp>(loc, result_type, operands[0],
591 operands[1], shape)
592 .getOperation();
593
594 case xla::RNG_NORMAL:
595 return func_builder
596 ->create<mlir::mhlo::RngNormalOp>(loc, result_type, operands[0],
597 operands[1], shape)
598 .getOperation();
599
600 default:
601 return tensorflow::errors::InvalidArgument(absl::StrCat(
602 "Unsupported distribution: ",
603 RandomDistributionToString(instruction->random_distribution())));
604 }
605 }
606 case HloOpcode::kRngBitGenerator: {
607 auto rng_op = Cast<HloRngBitGeneratorInstruction>(instruction);
608 auto op = func_builder->create<mlir::mhlo::RngBitGeneratorOp>(
609 loc, result_type,
610 func_builder->getI32IntegerAttr(rng_op->algorithm()), operands[0]);
611 return op.getOperation();
612 }
613 case HloOpcode::kWhile: {
614 auto op = func_builder->create<mlir::mhlo::WhileOp>(
615 loc, operands[0].getType(), operands[0]);
616 TF_RETURN_IF_ERROR(
617 ImportAsRegion(*instruction->while_condition(), &op.cond()));
618 TF_RETURN_IF_ERROR(
619 ImportAsRegion(*instruction->while_body(), &op.body()));
620 return op.getOperation();
621 }
622 case HloOpcode::kGetTupleElement: {
623 attributes.push_back(builder_->getNamedAttr(
624 "index", builder_->getIntegerAttr(builder_->getIntegerType(32),
625 instruction->tuple_index())));
626 MakeAndReturn(GetTupleElementOp);
627 };
628 case HloOpcode::kGetDimensionSize: {
629 attributes.push_back(builder_->getNamedAttr(
630 "dimension", builder_->getI64IntegerAttr(instruction->dimension())));
631 MakeAndReturn(GetDimensionSizeOp);
632 };
633 case HloOpcode::kTranspose: {
634 attributes.push_back(builder_->getNamedAttr(
635 "permutation", ConvertDimensions(instruction->dimensions())));
636 MakeAndReturn(TransposeOp);
637 }
638 case HloOpcode::kTriangularSolve: {
639 attributes.push_back(builder_->getNamedAttr(
640 "left_side",
641 builder_->getBoolAttr(
642 instruction->triangular_solve_options().left_side())));
643 attributes.push_back(builder_->getNamedAttr(
644 "lower", builder_->getBoolAttr(
645 instruction->triangular_solve_options().lower())));
646 attributes.push_back(builder_->getNamedAttr(
647 "unit_diagonal",
648 builder_->getBoolAttr(
649 instruction->triangular_solve_options().unit_diagonal())));
650 auto transpose_a =
651 builder_->getStringAttr(TriangularSolveOptions::Transpose_Name(
652 instruction->triangular_solve_options().transpose_a()));
653 attributes.push_back(builder_->getNamedAttr("transpose_a", transpose_a));
654 MakeAndReturn(TriangularSolveOp);
655 }
656 case HloOpcode::kReduceWindow: {
657 llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
658 llvm::SmallVector<int64_t, 8> padding;
659 for (const auto& dim : instruction->window().dimensions()) {
660 sizes.push_back(dim.size());
661 strides.push_back(dim.stride());
662 base_dilations.push_back(dim.base_dilation());
663 win_dilations.push_back(dim.window_dilation());
664 padding.push_back(dim.padding_low());
665 padding.push_back(dim.padding_high());
666 }
667 attributes.push_back(builder_->getNamedAttr("window_dimensions",
668 ConvertDimensions(sizes)));
669 attributes.push_back(
670 builder_->getNamedAttr("window_strides", ConvertDimensions(strides)));
671 attributes.push_back(builder_->getNamedAttr(
672 "base_dilations", ConvertDimensions(base_dilations)));
673 attributes.push_back(builder_->getNamedAttr(
674 "window_dilations", ConvertDimensions(win_dilations)));
675 attributes.push_back(ConvertPadding(padding));
676 auto reduce = func_builder->create<mlir::mhlo::ReduceWindowOp>(
677 loc, result_type, operands, attributes);
678 TF_RETURN_IF_ERROR(
679 ImportAsRegion(*instruction->to_apply(), &reduce.body()));
680 return reduce.getOperation();
681 }
682 case HloOpcode::kMap: {
683 auto op = func_builder->create<mlir::mhlo::MapOp>(
684 loc, result_type, operands,
685 ConvertDimensions(instruction->dimensions()));
686 TF_RETURN_IF_ERROR(
687 ImportAsRegion(*instruction->to_apply(), &op.computation()));
688 return op.getOperation();
689 }
690 case HloOpcode::kConvolution: {
691 llvm::SmallVector<int64_t, 4> strides, lhs_dilations, rhs_dilations;
692 llvm::SmallVector<int64_t, 8> paddings;
693 for (const auto& dim : instruction->window().dimensions()) {
694 strides.push_back(dim.stride());
695 lhs_dilations.push_back(dim.base_dilation());
696 rhs_dilations.push_back(dim.window_dilation());
697 paddings.push_back(dim.padding_low());
698 paddings.push_back(dim.padding_high());
699 }
700
701 attributes.push_back(
702 builder_->getNamedAttr("window_strides", Convert(strides)));
703 attributes.push_back(ConvertPadding(paddings));
704 attributes.push_back(
705 builder_->getNamedAttr("lhs_dilation", Convert(lhs_dilations)));
706 attributes.push_back(
707 builder_->getNamedAttr("rhs_dilation", Convert(rhs_dilations)));
708 attributes.push_back(builder_->getNamedAttr(
709 "dimension_numbers",
710 ConvertConvDimensionNumbers(
711 instruction->convolution_dimension_numbers(), builder_)));
712 attributes.push_back(builder_->getNamedAttr(
713 "feature_group_count",
714 builder_->getI64IntegerAttr(instruction->feature_group_count())));
715 attributes.push_back(builder_->getNamedAttr(
716 "batch_group_count",
717 builder_->getI64IntegerAttr(instruction->batch_group_count())));
718 attributes.push_back(builder_->getNamedAttr(
719 "precision_config",
720 ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
721
722 MakeAndReturn(ConvOp);
723 }
724
725 case HloOpcode::kFft: {
726 auto fft_type =
727 builder_->getStringAttr(FftType_Name(instruction->fft_type()));
728
729 std::vector<int64_t> fft_length(instruction->fft_length().begin(),
730 instruction->fft_length().end());
731
732 attributes.push_back(builder_->getNamedAttr("fft_type", fft_type));
733 attributes.push_back(
734 builder_->getNamedAttr("fft_length", Convert(fft_length)));
735 MakeAndReturn(FftOp);
736 }
737 #define NoAttributeCase(hlo_op_code, mlir_op) \
738 case HloOpcode::hlo_op_code: { \
739 MakeAndReturn(mlir_op); \
740 }
741
742 // broadcast dimensions are never added here because they don't exist as
743 // part of the HLO instruction. They are only a convenience in the XLA
744 // builder API.
745 NoAttributeCase(kAbs, AbsOp);
746 NoAttributeCase(kAdd, AddOp);
747 NoAttributeCase(kAfterAll, AfterAllOp);
748 NoAttributeCase(kAnd, AndOp);
749 NoAttributeCase(kAtan2, Atan2Op);
750 NoAttributeCase(kBitcastConvert, BitcastConvertOp);
751 NoAttributeCase(kCbrt, CbrtOp);
752 NoAttributeCase(kConvert, ConvertOp);
753 NoAttributeCase(kCeil, CeilOp);
754 NoAttributeCase(kClamp, ClampOp);
755 NoAttributeCase(kComplex, ComplexOp);
756 NoAttributeCase(kCos, CosOp);
757 NoAttributeCase(kDivide, DivOp);
758 NoAttributeCase(kExp, ExpOp);
759 NoAttributeCase(kExpm1, Expm1Op);
760 NoAttributeCase(kFloor, FloorOp);
761 NoAttributeCase(kIsFinite, IsFiniteOp);
762 NoAttributeCase(kImag, ImagOp);
763 NoAttributeCase(kLog, LogOp);
764 NoAttributeCase(kLog1p, Log1pOp);
765 NoAttributeCase(kMaximum, MaxOp);
766 NoAttributeCase(kMinimum, MinOp);
767 NoAttributeCase(kMultiply, MulOp);
768 NoAttributeCase(kNegate, NegOp);
769 NoAttributeCase(kNot, NotOp);
770 NoAttributeCase(kOr, OrOp);
771 NoAttributeCase(kPopulationCount, PopulationCountOp);
772 NoAttributeCase(kPower, PowOp);
773 NoAttributeCase(kReal, RealOp);
774 NoAttributeCase(kRemainder, RemOp);
775 NoAttributeCase(kReplicaId, ReplicaIdOp);
776 // The dimensions attribute is not present on the HLO Reshape
777 // instruction. If dimensions are non-default, the XLA builder
778 // implements it as a separate transpose.
779 NoAttributeCase(kReshape, ReshapeOp);
780 NoAttributeCase(kRoundNearestAfz, RoundOp);
781 NoAttributeCase(kRsqrt, RsqrtOp);
782 NoAttributeCase(kSelect, SelectOp);
783 NoAttributeCase(kShiftLeft, ShiftLeftOp);
784 NoAttributeCase(kShiftRightArithmetic, ShiftRightArithmeticOp);
785 NoAttributeCase(kShiftRightLogical, ShiftRightLogicalOp);
786 NoAttributeCase(kSign, SignOp);
787 NoAttributeCase(kSin, SinOp);
788 NoAttributeCase(kSqrt, SqrtOp);
789 NoAttributeCase(kSubtract, SubOp);
790 NoAttributeCase(kTanh, TanhOp);
791 NoAttributeCase(kTuple, TupleOp);
792 NoAttributeCase(kXor, XorOp);
793 // TODO(b/129422361) Copy needs special handling because it is not
794 // defined in tensorflow/compiler/xla/client/xla_builder.h. See
795 // operation semantics in
796 // g3doc/platforms/xla/g3doc/internal/hlo_semantics#copy
797 NoAttributeCase(kCopy, CopyOp);
798 #undef NoAttributeCase
799 #undef MakeAndReturn
800 case HloOpcode::kFusion: {
801 auto fusion = func_builder->create<mlir::mhlo::FusionOp>(
802 loc, result_type, operands,
803 builder_->getStringAttr(xla::ToString(instruction->fusion_kind())));
804 TF_RETURN_IF_ERROR(
805 ImportAsRegion(*instruction->fused_instructions_computation(),
806 &fusion.fused_computation()));
807 return fusion.getOperation();
808 }
809 case HloOpcode::kBitcast:
810 return func_builder
811 ->create<mlir::mhlo::BitcastOp>(loc, result_type, operands,
812 attributes)
813 .getOperation();
814 case HloOpcode::kReducePrecision: {
815 auto op = func_builder->create<mlir::mhlo::ReducePrecisionOp>(
816 loc, result_type, operands[0], attributes);
817 op.exponent_bitsAttr(func_builder->getIntegerAttr(
818 func_builder->getI32Type(), instruction->exponent_bits()));
819 op.mantissa_bitsAttr(func_builder->getIntegerAttr(
820 func_builder->getI32Type(), instruction->mantissa_bits()));
821 return op.getOperation();
822 }
823 case HloOpcode::kAddDependency:
824 // Arbitrary op code that I suspect we will not implement for quite a
825 // while and allows testing handling of unknown ops. Selected because it
826 // is not mentioned in xla client anywhere or in the hlo of our sample
827 // models.
828 default: {
829 mlir::OperationState result(loc, "mhlo.unknown");
830 result.addOperands(operands);
831 result.addTypes(result_type);
832 for (auto attr : attributes) {
833 result.attributes.push_back(attr);
834 }
835
836 return func_builder->createOperation(result);
837 }
838 }
839 }
840
ImportInstruction(HloInstruction * instruction,mlir::OpBuilder * func_builder)841 StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
842 HloInstruction* instruction, mlir::OpBuilder* func_builder) {
843 TF_ASSIGN_OR_RETURN(mlir::Operation * op,
844 ImportInstructionImpl(instruction, func_builder));
845 if (op == nullptr) return op;
846
847 // See MlirToHloConversionOptions for more about layouts.
848 //
849 // Minor-to-major is a permutation of [0, rank), presenting tensor dimensions
850 // in physical minor-to-major order.
851 if (instruction->shape().IsArray() &&
852 !instruction->shape().layout().minor_to_major().empty() &&
853 instruction->shape().layout() !=
854 LayoutUtil::MakeDescendingLayout(
855 instruction->shape().dimensions().size())) {
856 SetLayoutForMlir(op, instruction->shape());
857 }
858 return op;
859 }
860
GetOperands(HloInstruction * instruction)861 StatusOr<llvm::SmallVector<mlir::Value, 4>> HloFunctionImporter::GetOperands(
862 HloInstruction* instruction) {
863 llvm::SmallVector<mlir::Value, 4> operands;
864 for (const auto& operand : instruction->operands()) {
865 auto input_it = instruction_value_map_.find(operand);
866 if (input_it == instruction_value_map_.end()) {
867 return tensorflow::errors::Internal(
868 absl::StrCat("Could not find input value: ", operand->name(),
869 " for instruction ", instruction->name()));
870 }
871 operands.push_back(input_it->second);
872 }
873 return operands;
874 }
875
GetMlirTypes(const std::vector<HloInstruction * > & instructions,llvm::SmallVectorImpl<mlir::Type> * types)876 tensorflow::Status HloFunctionImporter::GetMlirTypes(
877 const std::vector<HloInstruction*>& instructions,
878 llvm::SmallVectorImpl<mlir::Type>* types) {
879 for (auto instruction : instructions) {
880 TF_ASSIGN_OR_RETURN(auto ret_type, ConvertShapeToType<RankedTensorType>(
881 instruction->shape(), *builder_));
882 types->push_back(ret_type);
883 }
884 return tensorflow::Status::OK();
885 }
886
GetMlirValue(HloInstruction * instruction)887 StatusOr<Value> HloFunctionImporter::GetMlirValue(HloInstruction* instruction) {
888 auto lookup = instruction_value_map_.find(instruction);
889 if (lookup != instruction_value_map_.end()) {
890 return lookup->second;
891 }
892
893 return tensorflow::errors::Internal(absl::StrCat(
894 "Unable to find value for input: ", instruction->ToString()));
895 }
896
ConvertComparisonDirection(ComparisonDirection direction)897 mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection(
898 ComparisonDirection direction) {
899 return builder_->getNamedAttr(
900 "comparison_direction",
901 builder_->getStringAttr(ComparisonDirectionToString(direction)));
902 }
903
ConvertComparisonType(Comparison::Type type)904 mlir::NamedAttribute HloFunctionImporter::ConvertComparisonType(
905 Comparison::Type type) {
906 return builder_->getNamedAttr(
907 "compare_type", builder_->getStringAttr(ComparisonTypeToString(type)));
908 }
909
ConvertDimensions(llvm::ArrayRef<int64> op_dimensions)910 mlir::DenseIntElementsAttr HloFunctionImporter::ConvertDimensions(
911 llvm::ArrayRef<int64> op_dimensions) {
912 llvm::SmallVector<APInt, 8> dimensions;
913 dimensions.reserve(op_dimensions.size());
914 for (auto value : op_dimensions) dimensions.emplace_back(APInt(64, value));
915
916 return DenseIntElementsAttr::get(
917 RankedTensorType::get(dimensions.size(), builder_->getIntegerType(64)),
918 dimensions);
919 }
920
Convert(llvm::ArrayRef<int64_t> elements)921 mlir::DenseIntElementsAttr HloFunctionImporter::Convert(
922 llvm::ArrayRef<int64_t> elements) {
923 return DenseIntElementsAttr::get(
924 RankedTensorType::get(elements.size(), builder_->getIntegerType(64)),
925 elements);
926 }
927
ConvertPadding(llvm::ArrayRef<int64_t> padding)928 mlir::NamedAttribute HloFunctionImporter::ConvertPadding(
929 llvm::ArrayRef<int64_t> padding) {
930 auto ty =
931 mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
932 builder_->getIntegerType(64));
933 auto attr = DenseIntElementsAttr::get(ty, padding);
934 return builder_->getNamedAttr("padding", attr);
935 }
936
ConvertSourceTargetPairs(const std::vector<std::pair<tensorflow::int64,tensorflow::int64>> & source_target_pairs,mlir::Builder * builder)937 mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
938 const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
939 source_target_pairs,
940 mlir::Builder* builder) {
941 std::vector<int64_t> attr(source_target_pairs.size() * 2);
942 for (auto p : llvm::enumerate(source_target_pairs)) {
943 attr[2 * p.index()] = p.value().first;
944 attr[2 * p.index() + 1] = p.value().second;
945 }
946 auto type = mlir::RankedTensorType::get(
947 {static_cast<int64_t>(attr.size() / 2), 2}, builder->getIntegerType(64));
948 return builder->getNamedAttr("source_target_pairs",
949 DenseIntElementsAttr::get(type, attr));
950 }
951
ConvertReplicaGroups(const std::vector<ReplicaGroup> & replica_groups,mlir::Builder * builder)952 mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
953 const std::vector<ReplicaGroup>& replica_groups, mlir::Builder* builder) {
954 const int64_t num_groups = replica_groups.size();
955 // Replica groups in HLO can be non-uniform in size, for example:
956 // replica_groups={{0},{1,2},{3}}. Since we are representing them as a 2D
957 // tensor, pad the smaller sized replica groups with -1.
958 const int64_t group_size = absl::c_accumulate(
959 replica_groups, int64_t(0), [](int64_t current, const ReplicaGroup& g) {
960 return std::max<int64_t>(current, g.replica_ids_size());
961 });
962 // Initialize all elements to -1 to support non-uniform replica groups.
963 std::vector<int64_t> attr(num_groups * group_size, -1);
964 for (int i = 0; i < num_groups; ++i) {
965 int index = i * group_size;
966 for (const int64& id : replica_groups[i].replica_ids()) attr[index++] = id;
967 }
968 auto type = mlir::RankedTensorType::get({num_groups, group_size},
969 builder->getIntegerType(64));
970 return builder->getNamedAttr("replica_groups",
971 DenseIntElementsAttr::get(type, attr));
972 }
973
ConvertChannelHandle(absl::optional<tensorflow::int64> channel_id)974 mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
975 absl::optional<tensorflow::int64> channel_id) {
976 xla::ChannelHandle channel_handle;
977 if (channel_id) channel_handle.set_handle(*channel_id);
978 return ConvertChannelHandle(channel_handle);
979 }
980
ConvertChannelHandle(const xla::ChannelHandle & channel)981 mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
982 const xla::ChannelHandle& channel) {
983 return builder_->getNamedAttr(
984 "channel_handle",
985 mlir::mhlo::ChannelHandle::get(
986 builder_->getI64IntegerAttr(channel.handle()),
987 builder_->getI64IntegerAttr(channel.type()), context_));
988 }
989
SetLayoutForMlir(mlir::Operation * op,const Shape & shape)990 void HloFunctionImporter::SetLayoutForMlir(mlir::Operation* op,
991 const Shape& shape) {
992 llvm::SmallVector<int64_t, 4> minor_to_major(
993 shape.layout().minor_to_major().begin(),
994 shape.layout().minor_to_major().end());
995 op->setAttr(
996 "minor_to_major",
997 mlir::Builder(op->getContext()).getIndexTensorAttr(minor_to_major));
998 }
999
1000 } // namespace xla
1001