• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
17 
18 #include <unordered_map>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/types/optional.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
27 #include "mlir/IR/Attributes.h"  // from @llvm-project
28 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
29 #include "mlir/IR/Builders.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
31 #include "mlir/IR/Identifier.h"  // from @llvm-project
32 #include "mlir/IR/Location.h"  // from @llvm-project
33 #include "mlir/IR/Region.h"  // from @llvm-project
34 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
35 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
36 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
37 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
38 #include "tensorflow/compiler/xla/comparison_util.h"
39 #include "tensorflow/compiler/xla/protobuf_util.h"
40 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
41 #include "tensorflow/compiler/xla/service/hlo_computation.h"
42 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
43 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
44 #include "tensorflow/compiler/xla/service/hlo_module.h"
45 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
46 #include "tensorflow/compiler/xla/status_macros.h"
47 #include "tensorflow/compiler/xla/xla_data.pb.h"
48 
49 using llvm::APInt;
50 using llvm::makeArrayRef;
51 using mlir::DenseIntElementsAttr;
52 using mlir::FuncOp;
53 using mlir::NamedAttribute;
54 using mlir::Operation;
55 using mlir::RankedTensorType;
56 using mlir::Type;
57 using mlir::Value;
58 
59 namespace xla {
60 
61 namespace {
62 
63 // Note: This sanitization function causes an irreversible many-to-one mapping
64 // and any solution to mitigate this would cause issues with the reverse
65 // direction. Longterm solution is to add a function attribute to maintain the
66 // original HLO naming.
SanitizeFunctionName(llvm::StringRef name)67 string SanitizeFunctionName(llvm::StringRef name) {
68   string output(name);
69   llvm::for_each(output, [](char& x) { x = x == '-' ? '_' : x; });
70   return output;
71 }
72 
73 // Returns whether the instruction is a default dot operation.
DotIsDefault(const HloInstruction * instruction)74 bool DotIsDefault(const HloInstruction* instruction) {
75   auto dnums = instruction->dot_dimension_numbers();
76   DotDimensionNumbers default_dimension_numbers;
77   default_dimension_numbers.add_lhs_contracting_dimensions(
78       instruction->operand(0)->shape().dimensions_size() == 1 ? 0 : 1);
79   default_dimension_numbers.add_rhs_contracting_dimensions(0);
80   return xla::protobuf_util::ProtobufEquals(dnums, default_dimension_numbers);
81 }
82 
83 // Returns an MLIR Location generated from HLO Instruction. Uses instruction
84 // metadata if present or instruction name.
GenerateInstructionLocation(HloInstruction * instruction,mlir::OpBuilder * func_builder)85 mlir::Location GenerateInstructionLocation(HloInstruction* instruction,
86                                            mlir::OpBuilder* func_builder) {
87   const std::string& op_name = instruction->metadata().op_name();
88   if (op_name.empty()) {
89     return mlir::NameLoc::get(func_builder->getIdentifier(instruction->name()),
90                               func_builder->getContext());
91   }
92 
93   mlir::Location op_name_loc = mlir::NameLoc::get(
94       func_builder->getIdentifier(op_name), func_builder->getContext());
95   const std::string& source_file = instruction->metadata().source_file();
96   if (source_file.empty()) {
97     return op_name_loc;
98   }
99 
100   return mlir::FusedLoc::get(
101       {op_name_loc, mlir::FileLineColLoc::get(
102                         source_file, instruction->metadata().source_line(), 0,
103                         func_builder->getContext())},
104       func_builder->getContext());
105 }
106 }  // namespace
107 
ImportAsFunc(const HloComputation & computation,mlir::ModuleOp module,std::unordered_map<const HloComputation *,FuncOp> * function_map,mlir::Builder * builder)108 Status HloFunctionImporter::ImportAsFunc(
109     const HloComputation& computation, mlir::ModuleOp module,
110     std::unordered_map<const HloComputation*, FuncOp>* function_map,
111     mlir::Builder* builder) {
112   HloFunctionImporter importer(module, function_map, builder);
113   return importer.ImportAsFunc(computation).status();
114 }
115 
ImportAsRegion(const xla::HloComputation & computation,mlir::Region * region,mlir::Builder * builder)116 Status HloFunctionImporter::ImportAsRegion(
117     const xla::HloComputation& computation, mlir::Region* region,
118     mlir::Builder* builder) {
119   HloFunctionImporter importer(region->getParentOfType<mlir::ModuleOp>(), {},
120                                builder);
121   return importer.ImportAsRegion(computation, region);
122 }
123 
ImportAsFunc(const HloComputation & computation)124 StatusOr<mlir::FuncOp> HloFunctionImporter::ImportAsFunc(
125     const HloComputation& computation) {
126   auto& imported = (*function_map_)[&computation];
127   if (imported) return imported;
128   llvm::SmallVector<Type, 4> args, rets;
129   TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
130   TF_RETURN_IF_ERROR(GetMlirTypes({computation.root_instruction()}, &rets));
131   auto func_type = mlir::FunctionType::get(context_, args, rets);
132 
133   string computation_name =
134       computation.parent()->entry_computation() == &computation
135           ? "main"
136           : SanitizeFunctionName(computation.name());
137 
138   // Construct the MLIR function and map arguments.
139   llvm::ArrayRef<mlir::NamedAttribute> attrs;
140   auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
141                                        computation_name, func_type, attrs);
142   auto visibility = computation_name == "main" ? FuncOp::Visibility::Public
143                                                : FuncOp::Visibility::Private;
144   function.setVisibility(visibility);
145   module_.push_back(function);
146 
147   // Add to the map right away for function calls.
148   imported = function;
149 
150   mlir::Block* block = function.addEntryBlock();
151   TF_RETURN_IF_ERROR(ImportInstructions(computation, block));
152 
153   return function;
154 }
155 
ImportAsRegion(const HloComputation & computation,mlir::Region * region)156 tensorflow::Status HloFunctionImporter::ImportAsRegion(
157     const HloComputation& computation, mlir::Region* region) {
158   // TODO(hinsu): Store computation name as an attribute for round-trip.
159   auto* block = new mlir::Block;
160   region->push_back(block);
161 
162   llvm::SmallVector<Type, 4> args;
163   TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
164   block->addArguments(args);
165 
166   return ImportInstructions(computation, block);
167 }
168 
ImportInstructionsImpl(const xla::HloComputation & computation,const llvm::SmallVectorImpl<Value> & arguments,mlir::OpBuilder * builder)169 StatusOr<Value> HloFunctionImporter::ImportInstructionsImpl(
170     const xla::HloComputation& computation,
171     const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
172   // Setup the input parameters.
173   const int num_parameters = computation.num_parameters();
174 
175   if (arguments.size() != num_parameters)
176     return InvalidArgument("Caller vs callee argument sizes do not match");
177 
178   for (int i = 0; i < num_parameters; i++) {
179     auto hlo_parameter = computation.parameter_instruction(i);
180     instruction_value_map_[hlo_parameter] = arguments[i];
181   }
182 
183   for (auto instruction : computation.MakeInstructionPostOrder()) {
184     TF_ASSIGN_OR_RETURN(auto new_operation,
185                         ImportInstruction(instruction, builder));
186     if (new_operation) {
187       instruction_value_map_[instruction] = new_operation->getResult(0);
188     }
189   }
190 
191   // Setup the return type (HLO only supports a single return value).
192   return GetMlirValue(computation.root_instruction());
193 }
194 
ImportInstructions(const HloComputation & computation,mlir::Block * block)195 Status HloFunctionImporter::ImportInstructions(
196     const HloComputation& computation, mlir::Block* block) {
197   llvm::SmallVector<Value, 4> arguments(block->args_begin(), block->args_end());
198   mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
199   TF_ASSIGN_OR_RETURN(Value result,
200                       ImportInstructionsImpl(computation, arguments, &builder));
201 
202   // TODO(suderman): Add location tracking details.
203   mlir::Location loc = builder.getUnknownLoc();
204 
205   // Create terminator op depending on the parent op of this region.
206   if (llvm::isa<FuncOp>(block->getParentOp())) {
207     builder.create<mlir::ReturnOp>(loc, result);
208   } else {
209     builder.create<mlir::mhlo::ReturnOp>(loc, result);
210   }
211   return tensorflow::Status::OK();
212 }
213 
ImportInstructions(const xla::HloComputation & computation,const llvm::SmallVectorImpl<Value> & arguments,mlir::OpBuilder * builder)214 StatusOr<Value> HloFunctionImporter::ImportInstructions(
215     const xla::HloComputation& computation,
216     const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
217   mlir::Block* block = builder->getBlock();
218   if (block == nullptr)
219     return InvalidArgument(
220         "ImportInstructions requires a valid block in the builder");
221 
222   HloFunctionImporter importer(
223       block->getParent()->getParentOfType<mlir::ModuleOp>(), {}, builder);
224   return importer.ImportInstructionsImpl(computation, arguments, builder);
225 }
226 
ImportInstructionImpl(HloInstruction * instruction,mlir::OpBuilder * func_builder)227 StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
228     HloInstruction* instruction, mlir::OpBuilder* func_builder) {
229   TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction));
230   TF_ASSIGN_OR_RETURN(auto result_type, ConvertShapeToType<RankedTensorType>(
231                                             instruction->shape(), *builder_));
232   mlir::Location loc = GenerateInstructionLocation(instruction, func_builder);
233 
234   llvm::SmallVector<NamedAttribute, 10> attributes;
235   switch (instruction->opcode()) {
236     case HloOpcode::kParameter: {
237       return nullptr;
238     }
239     case HloOpcode::kConstant: {
240       const Literal& literal = instruction->literal();
241       auto attr = CreateDenseElementsAttrFromLiteral(literal, *builder_);
242       if (!attr.ok()) return attr.status();
243       mlir::Operation* new_operation =
244           func_builder->create<mlir::mhlo::ConstOp>(loc, attr.ValueOrDie());
245       for (auto attr : attributes) {
246         new_operation->setAttr(attr.first, attr.second);
247       }
248       return new_operation;
249     }
250     case HloOpcode::kIota: {
251       return func_builder
252           ->create<mlir::mhlo::IotaOp>(
253               loc, result_type,
254               func_builder->getI64IntegerAttr(
255                   Cast<HloIotaInstruction>(instruction)->iota_dimension()))
256           .getOperation();
257     }
258 #define MakeAndReturn(mlir_op)                                                \
259   {                                                                           \
260     mlir::Operation* new_operation =                                          \
261         func_builder->create<mlir::mhlo::mlir_op>(loc, result_type, operands, \
262                                                   attributes);                \
263     return new_operation;                                                     \
264   }
265     case HloOpcode::kBroadcast: {
266       // Note that the HLO broadcast is more powerful than the XLA broadcast
267       // op. BroadcastInDim offers a superset of the HLO op's functionality.
268       attributes.push_back(
269           builder_->getNamedAttr("broadcast_dimensions",
270                                  ConvertDimensions(instruction->dimensions())));
271       MakeAndReturn(BroadcastInDimOp);
272     }
273 #define MakeAndReturnBatchNormOp(batch_norm_op)                         \
274   {                                                                     \
275     attributes.push_back(builder_->getNamedAttr(                        \
276         "epsilon", builder_->getF32FloatAttr(instruction->epsilon()))); \
277     attributes.push_back(builder_->getNamedAttr(                        \
278         "feature_index",                                                \
279         builder_->getI64IntegerAttr(instruction->feature_index())));    \
280     MakeAndReturn(batch_norm_op);                                       \
281   }
282     case HloOpcode::kBatchNormGrad:
283       MakeAndReturnBatchNormOp(BatchNormGradOp);
284     case HloOpcode::kBatchNormInference:
285       MakeAndReturnBatchNormOp(BatchNormInferenceOp);
286     case HloOpcode::kBatchNormTraining:
287       MakeAndReturnBatchNormOp(BatchNormTrainingOp);
288 #undef MakeAndReturnBatchNormOp
289 
290     case HloOpcode::kDot: {
291       attributes.push_back(builder_->getNamedAttr(
292           "precision_config",
293           ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
294 
295       // Consider consolidating DotOps together.
296       if (DotIsDefault(instruction)) {
297         MakeAndReturn(DotOp);
298       }
299 
300       attributes.push_back(builder_->getNamedAttr(
301           "dot_dimension_numbers",
302           ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(),
303                                      builder_)));
304       MakeAndReturn(DotGeneralOp);
305     }
306     case HloOpcode::kCall: {
307       TF_ASSIGN_OR_RETURN(FuncOp function,
308                           ImportAsFunc(*instruction->to_apply()));
309       mlir::Operation* new_operation =
310           func_builder->create<mlir::CallOp>(loc, function, operands);
311       return new_operation;
312     }
313     case HloOpcode::kCollectivePermute: {
314       attributes.push_back(ConvertSourceTargetPairs(
315           instruction->source_target_pairs(), builder_));
316       MakeAndReturn(CollectivePermuteOp);
317     }
318     case HloOpcode::kCustomCall: {
319       auto custom_call = Cast<HloCustomCallInstruction>(instruction);
320       attributes.push_back(builder_->getNamedAttr(
321           "call_target_name",
322           builder_->getStringAttr(custom_call->custom_call_target())));
323       attributes.push_back(builder_->getNamedAttr(
324           "has_side_effect",
325           builder_->getBoolAttr(custom_call->custom_call_has_side_effect())));
326       attributes.push_back(builder_->getNamedAttr(
327           "backend_config",
328           builder_->getStringAttr(custom_call->raw_backend_config_string())));
329       MakeAndReturn(CustomCallOp);
330     }
331     case HloOpcode::kCompare: {
332       auto compare = Cast<HloCompareInstruction>(instruction);
333       attributes.push_back(ConvertComparisonDirection(compare->direction()));
334       auto default_type = Comparison::DefaultComparisonType(
335           compare->operand(0)->shape().element_type());
336       if (compare->type() != default_type)
337         attributes.push_back(ConvertComparisonType(compare->type()));
338       MakeAndReturn(CompareOp);
339     }
340     case HloOpcode::kCholesky: {
341       attributes.push_back(builder_->getNamedAttr(
342           "lower",
343           builder_->getBoolAttr(instruction->cholesky_options().lower())));
344       MakeAndReturn(CholeskyOp);
345     }
346     case HloOpcode::kGather: {
347       auto gather_instruction = Cast<HloGatherInstruction>(instruction);
348       attributes.push_back(builder_->getNamedAttr(
349           "dimension_numbers",
350           ConvertGatherDimensionNumbers(
351               gather_instruction->gather_dimension_numbers(), builder_)));
352 
353       std::vector<int64_t> slice_sizes(
354           gather_instruction->gather_slice_sizes().begin(),
355           gather_instruction->gather_slice_sizes().end());
356       attributes.push_back(
357           builder_->getNamedAttr("slice_sizes", Convert(slice_sizes)));
358       attributes.push_back(builder_->getNamedAttr(
359           "indices_are_sorted",
360           builder_->getBoolAttr(gather_instruction->indices_are_sorted())));
361 
362       MakeAndReturn(GatherOp);
363     }
364     case HloOpcode::kDynamicSlice: {
365       std::vector<int64_t> slice_sizes(
366           instruction->dynamic_slice_sizes().begin(),
367           instruction->dynamic_slice_sizes().end());
368       return func_builder
369           ->create<mlir::mhlo::DynamicSliceOp>(
370               loc, result_type, operands[0],
371               makeArrayRef(operands).drop_front(), Convert(slice_sizes))
372           .getOperation();
373     }
374     case HloOpcode::kDynamicUpdateSlice: {
375       return func_builder
376           ->create<mlir::mhlo::DynamicUpdateSliceOp>(
377               loc, result_type, operands[0], operands[1],
378               llvm::ArrayRef<Value>(operands.begin() + 2, operands.end()))
379           .getOperation();
380     }
381     case HloOpcode::kInfeed: {
382       attributes.push_back(builder_->getNamedAttr(
383           "infeed_config",
384           mlir::StringAttr::get(builder_->getContext(),
385                                 instruction->infeed_config())));
386       // TODO(kramm): Support tuples and tokens.
387       if (instruction->shape().IsArray()) {
388         const xla::Layout l = instruction->shape().layout();
389         absl::Span<const int64> minor_to_major = l.minor_to_major();
390         std::vector<mlir::Attribute> v;
391         for (int64 i : minor_to_major) {
392           v.push_back(builder_->getI32IntegerAttr(i));
393         }
394         llvm::ArrayRef<mlir::Attribute> array_ref(v);
395         mlir::ArrayAttr layout = builder_->getArrayAttr(array_ref);
396         attributes.push_back(builder_->getNamedAttr("layout", layout));
397       }
398 
399       MakeAndReturn(InfeedOp);
400     }
401     case HloOpcode::kOutfeed: {
402       attributes.push_back(builder_->getNamedAttr(
403           "outfeed_config",
404           mlir::StringAttr::get(builder_->getContext(),
405                                 instruction->outfeed_config())));
406       MakeAndReturn(OutfeedOp);
407     }
408     case HloOpcode::kPad: {
409       const auto& padding_config = instruction->padding_config();
410       llvm::SmallVector<int64_t, 4> edge_padding_low;
411       llvm::SmallVector<int64_t, 4> edge_padding_high;
412       llvm::SmallVector<int64_t, 4> interior_padding;
413       edge_padding_low.reserve(padding_config.dimensions_size());
414       edge_padding_high.reserve(padding_config.dimensions_size());
415       interior_padding.reserve(padding_config.dimensions_size());
416 
417       for (const auto& dimension : padding_config.dimensions()) {
418         edge_padding_low.push_back(dimension.edge_padding_low());
419         edge_padding_high.push_back(dimension.edge_padding_high());
420         interior_padding.push_back(dimension.interior_padding());
421       }
422 
423       return func_builder
424           ->create<mlir::mhlo::PadOp>(loc, result_type, operands[0],
425                                       operands[1], Convert(edge_padding_low),
426                                       Convert(edge_padding_high),
427                                       Convert(interior_padding))
428           .getOperation();
429     }
430     case HloOpcode::kScatter: {
431       auto scatter = Cast<HloScatterInstruction>(instruction);
432       attributes.push_back(builder_->getNamedAttr(
433           "scatter_dimension_numbers",
434           ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers(),
435                                          builder_)));
436       attributes.push_back(builder_->getNamedAttr(
437           "indices_are_sorted",
438           builder_->getBoolAttr(scatter->indices_are_sorted())));
439       attributes.push_back(builder_->getNamedAttr(
440           "unique_indices", builder_->getBoolAttr(scatter->unique_indices())));
441 
442       auto scatter_op = func_builder->create<mlir::mhlo::ScatterOp>(
443           loc, result_type, operands, attributes);
444       TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(),
445                                         &scatter_op.update_computation()));
446       return scatter_op.getOperation();
447     }
448     case HloOpcode::kSelectAndScatter: {
449       auto select_scatter = Cast<HloSelectAndScatterInstruction>(instruction);
450       llvm::SmallVector<int64_t, 4> window_strides, window_dimensions;
451       llvm::SmallVector<int64_t, 8> padding;
452       for (const auto& dim : select_scatter->window().dimensions()) {
453         window_strides.push_back(dim.stride());
454         window_dimensions.push_back(dim.size());
455         padding.push_back(dim.padding_low());
456         padding.push_back(dim.padding_high());
457       }
458       attributes.push_back(
459           builder_->getNamedAttr("window_strides", Convert(window_strides)));
460       attributes.push_back(builder_->getNamedAttr("window_dimensions",
461                                                   Convert(window_dimensions)));
462       attributes.push_back(ConvertPadding(padding));
463       auto select_scatter_op =
464           func_builder->create<mlir::mhlo::SelectAndScatterOp>(
465               loc, result_type, operands, attributes);
466       TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(),
467                                         &select_scatter_op.select()));
468       TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->scatter(),
469                                         &select_scatter_op.scatter()));
470       return select_scatter_op.getOperation();
471     }
472     case HloOpcode::kSetDimensionSize: {
473       attributes.push_back(builder_->getNamedAttr(
474           "dimension", builder_->getI64IntegerAttr(instruction->dimension())));
475       MakeAndReturn(SetDimensionSizeOp);
476     }
477     case HloOpcode::kSlice: {
478       return func_builder
479           ->create<mlir::mhlo::SliceOp>(
480               loc, result_type, operands[0],
481               ConvertDimensions(instruction->slice_starts()),
482               ConvertDimensions(instruction->slice_limits()),
483               ConvertDimensions(instruction->slice_strides()))
484           .getOperation();
485     }
486     case HloOpcode::kSort: {
487       auto sort_instruction = Cast<HloSortInstruction>(instruction);
488 
489       llvm::SmallVector<Type, 4> return_types = {result_type};
490       if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
491         return_types = llvm::to_vector<6>(tuple_ty.getTypes());
492       }
493 
494       auto sort_op = func_builder->create<mlir::mhlo::SortOp>(
495           loc, return_types, operands,
496           builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
497           builder_->getBoolAttr(sort_instruction->is_stable()));
498       TF_RETURN_IF_ERROR(
499           ImportAsRegion(*sort_instruction->to_apply(), &sort_op.comparator()));
500 
501       // Check if the output needs to be tupled.
502       if (return_types.size() == 1 && return_types.front() == result_type) {
503         return sort_op.getOperation();
504       }
505 
506       return func_builder
507           ->create<mlir::mhlo::TupleOp>(loc, result_type, sort_op.getResults())
508           .getOperation();
509     }
510     case HloOpcode::kConditional: {
511       llvm::SmallVector<Type, 4> rets;
512       mlir::Type pred_or_index_type =
513           operands[0].getType().cast<mlir::TensorType>().getElementType();
514       // It is a predicated conditional if first argument is a boolean and
515       // should be mapped to If op.
516       if (pred_or_index_type.isInteger(1)) {
517         TF_RETURN_IF_ERROR(GetMlirTypes(
518             {instruction->true_computation()->root_instruction()}, &rets));
519 
520         auto op = func_builder->create<mlir::mhlo::IfOp>(loc, rets, operands,
521                                                          attributes);
522         TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(),
523                                           &op.true_branch()));
524         TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->false_computation(),
525                                           &op.false_branch()));
526         return op.getOperation();
527       }
528 
529       // Otherwise, it is a indexed conditional and should be mapped to Case
530       // op.
531       TF_RETURN_IF_ERROR(GetMlirTypes(
532           {instruction->branch_computation(0)->root_instruction()}, &rets));
533 
534       int num_branches = instruction->branch_count();
535       auto op = func_builder->create<mlir::mhlo::CaseOp>(
536           loc, rets, operands, attributes, num_branches);
537       for (auto index_and_computation :
538            llvm::enumerate(instruction->branch_computations())) {
539         auto index = index_and_computation.index();
540         HloComputation* computation = index_and_computation.value();
541         TF_RETURN_IF_ERROR(ImportAsRegion(*computation, &op.branches()[index]));
542       }
543       return op.getOperation();
544     }
545     case HloOpcode::kConcatenate: {
546       // TODO(b/132057942): Support taking an uint64_t instead of an
547       // IntegerAttr for concatenate dimension.
548       return func_builder
549           ->create<mlir::mhlo::ConcatenateOp>(
550               loc, result_type, operands,
551               builder_->getI64IntegerAttr(instruction->concatenate_dimension()))
552           .getOperation();
553     }
554     case HloOpcode::kAllReduce: {
555       auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
556       attributes.push_back(
557           ConvertReplicaGroups(all_reduce->replica_groups(), builder_));
558       attributes.push_back(ConvertChannelHandle(all_reduce->channel_id()));
559       auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
560           loc, result_type, operands, attributes);
561       TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
562                                         &all_reduce_op.computation()));
563       return all_reduce_op.getOperation();
564     }
565     case HloOpcode::kReduce: {
566       // Operands in the first half are reduction inputs and the remaining
567       // operands are corresponding initial values.
568       size_t num_inputs = operands.size() / 2;
569       auto reduce = func_builder->create<mlir::mhlo::ReduceOp>(
570           loc, result_type, llvm::makeArrayRef(operands).take_front(num_inputs),
571           llvm::makeArrayRef(operands).drop_front(num_inputs),
572           ConvertDimensions(instruction->dimensions()));
573       TF_RETURN_IF_ERROR(
574           ImportAsRegion(*instruction->to_apply(), &reduce.body()));
575       return reduce.getOperation();
576     }
577     case HloOpcode::kReverse: {
578       return func_builder
579           ->create<mlir::mhlo::ReverseOp>(
580               loc, result_type, operands[0],
581               ConvertDimensions(instruction->dimensions()))
582           .getOperation();
583     }
584     case HloOpcode::kRng: {
585       auto shape = func_builder->create<mlir::ConstantOp>(
586           loc, Convert(result_type.cast<RankedTensorType>().getShape()));
587       switch (instruction->random_distribution()) {
588         case xla::RNG_UNIFORM:
589           return func_builder
590               ->create<mlir::mhlo::RngUniformOp>(loc, result_type, operands[0],
591                                                  operands[1], shape)
592               .getOperation();
593 
594         case xla::RNG_NORMAL:
595           return func_builder
596               ->create<mlir::mhlo::RngNormalOp>(loc, result_type, operands[0],
597                                                 operands[1], shape)
598               .getOperation();
599 
600         default:
601           return tensorflow::errors::InvalidArgument(absl::StrCat(
602               "Unsupported distribution: ",
603               RandomDistributionToString(instruction->random_distribution())));
604       }
605     }
606     case HloOpcode::kRngBitGenerator: {
607       auto rng_op = Cast<HloRngBitGeneratorInstruction>(instruction);
608       auto op = func_builder->create<mlir::mhlo::RngBitGeneratorOp>(
609           loc, result_type,
610           func_builder->getI32IntegerAttr(rng_op->algorithm()), operands[0]);
611       return op.getOperation();
612     }
613     case HloOpcode::kWhile: {
614       auto op = func_builder->create<mlir::mhlo::WhileOp>(
615           loc, operands[0].getType(), operands[0]);
616       TF_RETURN_IF_ERROR(
617           ImportAsRegion(*instruction->while_condition(), &op.cond()));
618       TF_RETURN_IF_ERROR(
619           ImportAsRegion(*instruction->while_body(), &op.body()));
620       return op.getOperation();
621     }
622     case HloOpcode::kGetTupleElement: {
623       attributes.push_back(builder_->getNamedAttr(
624           "index", builder_->getIntegerAttr(builder_->getIntegerType(32),
625                                             instruction->tuple_index())));
626       MakeAndReturn(GetTupleElementOp);
627     };
628     case HloOpcode::kGetDimensionSize: {
629       attributes.push_back(builder_->getNamedAttr(
630           "dimension", builder_->getI64IntegerAttr(instruction->dimension())));
631       MakeAndReturn(GetDimensionSizeOp);
632     };
633     case HloOpcode::kTranspose: {
634       attributes.push_back(builder_->getNamedAttr(
635           "permutation", ConvertDimensions(instruction->dimensions())));
636       MakeAndReturn(TransposeOp);
637     }
638     case HloOpcode::kTriangularSolve: {
639       attributes.push_back(builder_->getNamedAttr(
640           "left_side",
641           builder_->getBoolAttr(
642               instruction->triangular_solve_options().left_side())));
643       attributes.push_back(builder_->getNamedAttr(
644           "lower", builder_->getBoolAttr(
645                        instruction->triangular_solve_options().lower())));
646       attributes.push_back(builder_->getNamedAttr(
647           "unit_diagonal",
648           builder_->getBoolAttr(
649               instruction->triangular_solve_options().unit_diagonal())));
650       auto transpose_a =
651           builder_->getStringAttr(TriangularSolveOptions::Transpose_Name(
652               instruction->triangular_solve_options().transpose_a()));
653       attributes.push_back(builder_->getNamedAttr("transpose_a", transpose_a));
654       MakeAndReturn(TriangularSolveOp);
655     }
656     case HloOpcode::kReduceWindow: {
657       llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
658       llvm::SmallVector<int64_t, 8> padding;
659       for (const auto& dim : instruction->window().dimensions()) {
660         sizes.push_back(dim.size());
661         strides.push_back(dim.stride());
662         base_dilations.push_back(dim.base_dilation());
663         win_dilations.push_back(dim.window_dilation());
664         padding.push_back(dim.padding_low());
665         padding.push_back(dim.padding_high());
666       }
667       attributes.push_back(builder_->getNamedAttr("window_dimensions",
668                                                   ConvertDimensions(sizes)));
669       attributes.push_back(
670           builder_->getNamedAttr("window_strides", ConvertDimensions(strides)));
671       attributes.push_back(builder_->getNamedAttr(
672           "base_dilations", ConvertDimensions(base_dilations)));
673       attributes.push_back(builder_->getNamedAttr(
674           "window_dilations", ConvertDimensions(win_dilations)));
675       attributes.push_back(ConvertPadding(padding));
676       auto reduce = func_builder->create<mlir::mhlo::ReduceWindowOp>(
677           loc, result_type, operands, attributes);
678       TF_RETURN_IF_ERROR(
679           ImportAsRegion(*instruction->to_apply(), &reduce.body()));
680       return reduce.getOperation();
681     }
682     case HloOpcode::kMap: {
683       auto op = func_builder->create<mlir::mhlo::MapOp>(
684           loc, result_type, operands,
685           ConvertDimensions(instruction->dimensions()));
686       TF_RETURN_IF_ERROR(
687           ImportAsRegion(*instruction->to_apply(), &op.computation()));
688       return op.getOperation();
689     }
690     case HloOpcode::kConvolution: {
691       llvm::SmallVector<int64_t, 4> strides, lhs_dilations, rhs_dilations;
692       llvm::SmallVector<int64_t, 8> paddings;
693       for (const auto& dim : instruction->window().dimensions()) {
694         strides.push_back(dim.stride());
695         lhs_dilations.push_back(dim.base_dilation());
696         rhs_dilations.push_back(dim.window_dilation());
697         paddings.push_back(dim.padding_low());
698         paddings.push_back(dim.padding_high());
699       }
700 
701       attributes.push_back(
702           builder_->getNamedAttr("window_strides", Convert(strides)));
703       attributes.push_back(ConvertPadding(paddings));
704       attributes.push_back(
705           builder_->getNamedAttr("lhs_dilation", Convert(lhs_dilations)));
706       attributes.push_back(
707           builder_->getNamedAttr("rhs_dilation", Convert(rhs_dilations)));
708       attributes.push_back(builder_->getNamedAttr(
709           "dimension_numbers",
710           ConvertConvDimensionNumbers(
711               instruction->convolution_dimension_numbers(), builder_)));
712       attributes.push_back(builder_->getNamedAttr(
713           "feature_group_count",
714           builder_->getI64IntegerAttr(instruction->feature_group_count())));
715       attributes.push_back(builder_->getNamedAttr(
716           "batch_group_count",
717           builder_->getI64IntegerAttr(instruction->batch_group_count())));
718       attributes.push_back(builder_->getNamedAttr(
719           "precision_config",
720           ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
721 
722       MakeAndReturn(ConvOp);
723     }
724 
725     case HloOpcode::kFft: {
726       auto fft_type =
727           builder_->getStringAttr(FftType_Name(instruction->fft_type()));
728 
729       std::vector<int64_t> fft_length(instruction->fft_length().begin(),
730                                       instruction->fft_length().end());
731 
732       attributes.push_back(builder_->getNamedAttr("fft_type", fft_type));
733       attributes.push_back(
734           builder_->getNamedAttr("fft_length", Convert(fft_length)));
735       MakeAndReturn(FftOp);
736     }
737 #define NoAttributeCase(hlo_op_code, mlir_op) \
738   case HloOpcode::hlo_op_code: {              \
739     MakeAndReturn(mlir_op);                   \
740   }
741 
742       // broadcast dimensions are never added here because they don't exist as
743       // part of the HLO instruction. They are only a convenience in the XLA
744       // builder API.
745       NoAttributeCase(kAbs, AbsOp);
746       NoAttributeCase(kAdd, AddOp);
747       NoAttributeCase(kAfterAll, AfterAllOp);
748       NoAttributeCase(kAnd, AndOp);
749       NoAttributeCase(kAtan2, Atan2Op);
750       NoAttributeCase(kBitcastConvert, BitcastConvertOp);
751       NoAttributeCase(kCbrt, CbrtOp);
752       NoAttributeCase(kConvert, ConvertOp);
753       NoAttributeCase(kCeil, CeilOp);
754       NoAttributeCase(kClamp, ClampOp);
755       NoAttributeCase(kComplex, ComplexOp);
756       NoAttributeCase(kCos, CosOp);
757       NoAttributeCase(kDivide, DivOp);
758       NoAttributeCase(kExp, ExpOp);
759       NoAttributeCase(kExpm1, Expm1Op);
760       NoAttributeCase(kFloor, FloorOp);
761       NoAttributeCase(kIsFinite, IsFiniteOp);
762       NoAttributeCase(kImag, ImagOp);
763       NoAttributeCase(kLog, LogOp);
764       NoAttributeCase(kLog1p, Log1pOp);
765       NoAttributeCase(kMaximum, MaxOp);
766       NoAttributeCase(kMinimum, MinOp);
767       NoAttributeCase(kMultiply, MulOp);
768       NoAttributeCase(kNegate, NegOp);
769       NoAttributeCase(kNot, NotOp);
770       NoAttributeCase(kOr, OrOp);
771       NoAttributeCase(kPopulationCount, PopulationCountOp);
772       NoAttributeCase(kPower, PowOp);
773       NoAttributeCase(kReal, RealOp);
774       NoAttributeCase(kRemainder, RemOp);
775       NoAttributeCase(kReplicaId, ReplicaIdOp);
776       // The dimensions attribute is not present on the HLO Reshape
777       // instruction. If dimensions are non-default, the XLA builder
778       // implements it as a separate transpose.
779       NoAttributeCase(kReshape, ReshapeOp);
780       NoAttributeCase(kRoundNearestAfz, RoundOp);
781       NoAttributeCase(kRsqrt, RsqrtOp);
782       NoAttributeCase(kSelect, SelectOp);
783       NoAttributeCase(kShiftLeft, ShiftLeftOp);
784       NoAttributeCase(kShiftRightArithmetic, ShiftRightArithmeticOp);
785       NoAttributeCase(kShiftRightLogical, ShiftRightLogicalOp);
786       NoAttributeCase(kSign, SignOp);
787       NoAttributeCase(kSin, SinOp);
788       NoAttributeCase(kSqrt, SqrtOp);
789       NoAttributeCase(kSubtract, SubOp);
790       NoAttributeCase(kTanh, TanhOp);
791       NoAttributeCase(kTuple, TupleOp);
792       NoAttributeCase(kXor, XorOp);
793       // TODO(b/129422361) Copy needs special handling because it is not
794       // defined in tensorflow/compiler/xla/client/xla_builder.h. See
795       // operation semantics in
796       // g3doc/platforms/xla/g3doc/internal/hlo_semantics#copy
797       NoAttributeCase(kCopy, CopyOp);
798 #undef NoAttributeCase
799 #undef MakeAndReturn
800     case HloOpcode::kFusion: {
801       auto fusion = func_builder->create<mlir::mhlo::FusionOp>(
802           loc, result_type, operands,
803           builder_->getStringAttr(xla::ToString(instruction->fusion_kind())));
804       TF_RETURN_IF_ERROR(
805           ImportAsRegion(*instruction->fused_instructions_computation(),
806                          &fusion.fused_computation()));
807       return fusion.getOperation();
808     }
809     case HloOpcode::kBitcast:
810       return func_builder
811           ->create<mlir::mhlo::BitcastOp>(loc, result_type, operands,
812                                           attributes)
813           .getOperation();
814     case HloOpcode::kReducePrecision: {
815       auto op = func_builder->create<mlir::mhlo::ReducePrecisionOp>(
816           loc, result_type, operands[0], attributes);
817       op.exponent_bitsAttr(func_builder->getIntegerAttr(
818           func_builder->getI32Type(), instruction->exponent_bits()));
819       op.mantissa_bitsAttr(func_builder->getIntegerAttr(
820           func_builder->getI32Type(), instruction->mantissa_bits()));
821       return op.getOperation();
822     }
823     case HloOpcode::kAddDependency:
824       // Arbitrary op code that I suspect we will not implement for quite a
825       // while and allows testing handling of unknown ops. Selected because it
826       // is not mentioned in xla client anywhere or in the hlo of our sample
827       // models.
828     default: {
829       mlir::OperationState result(loc, "mhlo.unknown");
830       result.addOperands(operands);
831       result.addTypes(result_type);
832       for (auto attr : attributes) {
833         result.attributes.push_back(attr);
834       }
835 
836       return func_builder->createOperation(result);
837     }
838   }
839 }
840 
ImportInstruction(HloInstruction * instruction,mlir::OpBuilder * func_builder)841 StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
842     HloInstruction* instruction, mlir::OpBuilder* func_builder) {
843   TF_ASSIGN_OR_RETURN(mlir::Operation * op,
844                       ImportInstructionImpl(instruction, func_builder));
845   if (op == nullptr) return op;
846 
847   // See MlirToHloConversionOptions for more about layouts.
848   //
849   // Minor-to-major is a permutation of [0, rank), presenting tensor dimensions
850   // in physical minor-to-major order.
851   if (instruction->shape().IsArray() &&
852       !instruction->shape().layout().minor_to_major().empty() &&
853       instruction->shape().layout() !=
854           LayoutUtil::MakeDescendingLayout(
855               instruction->shape().dimensions().size())) {
856     SetLayoutForMlir(op, instruction->shape());
857   }
858   return op;
859 }
860 
GetOperands(HloInstruction * instruction)861 StatusOr<llvm::SmallVector<mlir::Value, 4>> HloFunctionImporter::GetOperands(
862     HloInstruction* instruction) {
863   llvm::SmallVector<mlir::Value, 4> operands;
864   for (const auto& operand : instruction->operands()) {
865     auto input_it = instruction_value_map_.find(operand);
866     if (input_it == instruction_value_map_.end()) {
867       return tensorflow::errors::Internal(
868           absl::StrCat("Could not find input value: ", operand->name(),
869                        " for instruction ", instruction->name()));
870     }
871     operands.push_back(input_it->second);
872   }
873   return operands;
874 }
875 
GetMlirTypes(const std::vector<HloInstruction * > & instructions,llvm::SmallVectorImpl<mlir::Type> * types)876 tensorflow::Status HloFunctionImporter::GetMlirTypes(
877     const std::vector<HloInstruction*>& instructions,
878     llvm::SmallVectorImpl<mlir::Type>* types) {
879   for (auto instruction : instructions) {
880     TF_ASSIGN_OR_RETURN(auto ret_type, ConvertShapeToType<RankedTensorType>(
881                                            instruction->shape(), *builder_));
882     types->push_back(ret_type);
883   }
884   return tensorflow::Status::OK();
885 }
886 
GetMlirValue(HloInstruction * instruction)887 StatusOr<Value> HloFunctionImporter::GetMlirValue(HloInstruction* instruction) {
888   auto lookup = instruction_value_map_.find(instruction);
889   if (lookup != instruction_value_map_.end()) {
890     return lookup->second;
891   }
892 
893   return tensorflow::errors::Internal(absl::StrCat(
894       "Unable to find value for input: ", instruction->ToString()));
895 }
896 
ConvertComparisonDirection(ComparisonDirection direction)897 mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection(
898     ComparisonDirection direction) {
899   return builder_->getNamedAttr(
900       "comparison_direction",
901       builder_->getStringAttr(ComparisonDirectionToString(direction)));
902 }
903 
ConvertComparisonType(Comparison::Type type)904 mlir::NamedAttribute HloFunctionImporter::ConvertComparisonType(
905     Comparison::Type type) {
906   return builder_->getNamedAttr(
907       "compare_type", builder_->getStringAttr(ComparisonTypeToString(type)));
908 }
909 
ConvertDimensions(llvm::ArrayRef<int64> op_dimensions)910 mlir::DenseIntElementsAttr HloFunctionImporter::ConvertDimensions(
911     llvm::ArrayRef<int64> op_dimensions) {
912   llvm::SmallVector<APInt, 8> dimensions;
913   dimensions.reserve(op_dimensions.size());
914   for (auto value : op_dimensions) dimensions.emplace_back(APInt(64, value));
915 
916   return DenseIntElementsAttr::get(
917       RankedTensorType::get(dimensions.size(), builder_->getIntegerType(64)),
918       dimensions);
919 }
920 
Convert(llvm::ArrayRef<int64_t> elements)921 mlir::DenseIntElementsAttr HloFunctionImporter::Convert(
922     llvm::ArrayRef<int64_t> elements) {
923   return DenseIntElementsAttr::get(
924       RankedTensorType::get(elements.size(), builder_->getIntegerType(64)),
925       elements);
926 }
927 
ConvertPadding(llvm::ArrayRef<int64_t> padding)928 mlir::NamedAttribute HloFunctionImporter::ConvertPadding(
929     llvm::ArrayRef<int64_t> padding) {
930   auto ty =
931       mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
932                                   builder_->getIntegerType(64));
933   auto attr = DenseIntElementsAttr::get(ty, padding);
934   return builder_->getNamedAttr("padding", attr);
935 }
936 
ConvertSourceTargetPairs(const std::vector<std::pair<tensorflow::int64,tensorflow::int64>> & source_target_pairs,mlir::Builder * builder)937 mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
938     const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
939         source_target_pairs,
940     mlir::Builder* builder) {
941   std::vector<int64_t> attr(source_target_pairs.size() * 2);
942   for (auto p : llvm::enumerate(source_target_pairs)) {
943     attr[2 * p.index()] = p.value().first;
944     attr[2 * p.index() + 1] = p.value().second;
945   }
946   auto type = mlir::RankedTensorType::get(
947       {static_cast<int64_t>(attr.size() / 2), 2}, builder->getIntegerType(64));
948   return builder->getNamedAttr("source_target_pairs",
949                                DenseIntElementsAttr::get(type, attr));
950 }
951 
ConvertReplicaGroups(const std::vector<ReplicaGroup> & replica_groups,mlir::Builder * builder)952 mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
953     const std::vector<ReplicaGroup>& replica_groups, mlir::Builder* builder) {
954   const int64_t num_groups = replica_groups.size();
955   // Replica groups in HLO can be non-uniform in size, for example:
956   // replica_groups={{0},{1,2},{3}}. Since we are representing them as a 2D
957   // tensor, pad the smaller sized replica groups with -1.
958   const int64_t group_size = absl::c_accumulate(
959       replica_groups, int64_t(0), [](int64_t current, const ReplicaGroup& g) {
960         return std::max<int64_t>(current, g.replica_ids_size());
961       });
962   // Initialize all elements to -1 to support non-uniform replica groups.
963   std::vector<int64_t> attr(num_groups * group_size, -1);
964   for (int i = 0; i < num_groups; ++i) {
965     int index = i * group_size;
966     for (const int64& id : replica_groups[i].replica_ids()) attr[index++] = id;
967   }
968   auto type = mlir::RankedTensorType::get({num_groups, group_size},
969                                           builder->getIntegerType(64));
970   return builder->getNamedAttr("replica_groups",
971                                DenseIntElementsAttr::get(type, attr));
972 }
973 
ConvertChannelHandle(absl::optional<tensorflow::int64> channel_id)974 mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
975     absl::optional<tensorflow::int64> channel_id) {
976   xla::ChannelHandle channel_handle;
977   if (channel_id) channel_handle.set_handle(*channel_id);
978   return ConvertChannelHandle(channel_handle);
979 }
980 
ConvertChannelHandle(const xla::ChannelHandle & channel)981 mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
982     const xla::ChannelHandle& channel) {
983   return builder_->getNamedAttr(
984       "channel_handle",
985       mlir::mhlo::ChannelHandle::get(
986           builder_->getI64IntegerAttr(channel.handle()),
987           builder_->getI64IntegerAttr(channel.type()), context_));
988 }
989 
SetLayoutForMlir(mlir::Operation * op,const Shape & shape)990 void HloFunctionImporter::SetLayoutForMlir(mlir::Operation* op,
991                                            const Shape& shape) {
992   llvm::SmallVector<int64_t, 4> minor_to_major(
993       shape.layout().minor_to_major().begin(),
994       shape.layout().minor_to_major().end());
995   op->setAttr(
996       "minor_to_major",
997       mlir::Builder(op->getContext()).getIndexTensorAttr(minor_to_major));
998 }
999 
1000 }  // namespace xla
1001