/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"

#include <unordered_map>

#include "absl/algorithm/container.h"
#include "absl/types/optional.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

using llvm::APInt;
using llvm::makeArrayRef;
using mlir::DenseIntElementsAttr;
using mlir::FuncOp;
using mlir::NamedAttribute;
using mlir::Operation;
using mlir::RankedTensorType;
using mlir::Type;
using mlir::Value;

namespace xla {

namespace {

// Note: This sanitization function causes an irreversible many-to-one
// mapping, and any attempt to mitigate this would cause issues in the reverse
// direction. The long-term solution is to add a function attribute that
// preserves the original HLO naming.
string SanitizeFunctionName(llvm::StringRef name) {
  string output(name);
  llvm::for_each(output, [](char& x) { x = x == '-' ? '_' : x; });
  return output;
}

// Returns whether the instruction is a default dot operation.
bool DotIsDefault(const HloInstruction* instruction) {
  auto dnums = instruction->dot_dimension_numbers();
  DotDimensionNumbers default_dimension_numbers;
  default_dimension_numbers.add_lhs_contracting_dimensions(
      instruction->operand(0)->shape().dimensions_size() == 1 ? 0 : 1);
  default_dimension_numbers.add_rhs_contracting_dimensions(0);
  return xla::protobuf_util::ProtobufEquals(dnums, default_dimension_numbers);
}
82 
83 // Returns an MLIR Location generated from HLO Instruction. Uses instruction
84 // metadata if present or instruction name.
GenerateInstructionLocation(const HloInstruction * instruction,mlir::OpBuilder * func_builder)85 mlir::Location GenerateInstructionLocation(const HloInstruction* instruction,
86                                            mlir::OpBuilder* func_builder) {
87   const std::string& op_name = instruction->metadata().op_name();
88   if (op_name.empty()) {
89     return mlir::NameLoc::get(func_builder->getIdentifier(instruction->name()));
90   }
91 
92   mlir::Location op_name_loc =
93       mlir::NameLoc::get(func_builder->getIdentifier(op_name));
94   const std::string& source_file = instruction->metadata().source_file();
95   if (source_file.empty()) {
96     return op_name_loc;
97   }
98 
99   return func_builder->getFusedLoc(
100       {op_name_loc,
101        mlir::FileLineColLoc::get(func_builder->getContext(), source_file,
102                                  instruction->metadata().source_line(), 0)});
103 }
}  // namespace

Status HloFunctionImporter::ImportAsFunc(
    const HloComputation& computation, mlir::ModuleOp module,
    std::unordered_map<const HloComputation*, FuncOp>* function_map,
    mlir::Builder* builder) {
  HloFunctionImporter importer(module, function_map, builder);
  return importer.ImportAsFunc(computation).status();
}
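
// Example use of the entry point above (a sketch; `hlo_module`, `mlir_module`,
// and `builder` are illustrative names assumed to exist in the caller):
//
//   std::unordered_map<const xla::HloComputation*, mlir::FuncOp> function_map;
//   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsFunc(
//       *hlo_module->entry_computation(), mlir_module, &function_map,
//       &builder));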

Status HloFunctionImporter::ImportAsRegion(
    const xla::HloComputation& computation, mlir::Region* region,
    mlir::Builder* builder) {
  HloFunctionImporter importer(region->getParentOfType<mlir::ModuleOp>(), {},
                               builder);
  return importer.ImportAsRegion(computation, region);
}

// Imports `computation` as a FuncOp in the module. The entry computation is
// emitted as a public function named "main"; other computations get private
// functions with sanitized names. Results are cached in `function_map_`.
StatusOr<mlir::FuncOp> HloFunctionImporter::ImportAsFunc(
    const HloComputation& computation) {
  auto& imported = (*function_map_)[&computation];
  if (imported) return imported;
  llvm::SmallVector<Type, 4> args, rets;
  TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
  TF_RETURN_IF_ERROR(GetMlirTypes({computation.root_instruction()}, &rets));
  auto func_type = mlir::FunctionType::get(context_, args, rets);

  string computation_name =
      computation.parent()->entry_computation() == &computation
          ? "main"
          : SanitizeFunctionName(computation.name());

  // Construct the MLIR function and map arguments.
  llvm::ArrayRef<mlir::NamedAttribute> attrs;
  auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
                                       computation_name, func_type, attrs);
  auto visibility = computation_name == "main" ? FuncOp::Visibility::Public
                                               : FuncOp::Visibility::Private;
  function.setVisibility(visibility);
  module_.push_back(function);

  // Add to the map right away for function calls.
  imported = function;

  mlir::Block* block = function.addEntryBlock();
  TF_RETURN_IF_ERROR(ImportInstructions(computation, block));

  return function;
}

tensorflow::Status HloFunctionImporter::ImportAsRegion(
    const HloComputation& computation, mlir::Region* region) {
  // TODO(hinsu): Store computation name as an attribute for round-trip.
  auto* block = new mlir::Block;
  region->push_back(block);

  llvm::SmallVector<Type, 4> args;
  TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
  block->addArguments(args);

  return ImportInstructions(computation, block);
}

// Imports the computation's instructions using the given builder. Instructions
// are visited in post order, so an instruction's operands are always imported
// before the instruction itself.
StatusOr<Value> HloFunctionImporter::ImportInstructionsImpl(
    const xla::HloComputation& computation,
    const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
  // Set up the input parameters.
  const int num_parameters = computation.num_parameters();

  if (arguments.size() != num_parameters)
    return InvalidArgument("Caller vs callee argument sizes do not match");

  for (int i = 0; i < num_parameters; i++) {
    auto hlo_parameter = computation.parameter_instruction(i);
    instruction_value_map_[hlo_parameter] = arguments[i];
  }

  for (auto instruction : computation.MakeInstructionPostOrder()) {
    TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction));
    TF_ASSIGN_OR_RETURN(
        auto new_operation,
        ImportInstructionWithLayout(instruction, operands, builder));
    if (new_operation) {
      instruction_value_map_[instruction] = new_operation->getResult(0);
    }
  }

  // Return the root instruction's value (HLO only supports a single return
  // value).
  return GetMlirValue(computation.root_instruction());
}

Status HloFunctionImporter::ImportInstructions(
    const HloComputation& computation, mlir::Block* block) {
  llvm::SmallVector<Value, 4> arguments(block->args_begin(), block->args_end());
  mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
  TF_ASSIGN_OR_RETURN(Value result,
                      ImportInstructionsImpl(computation, arguments, &builder));

  // TODO(suderman): Add location tracking details.
  mlir::Location loc = builder.getUnknownLoc();

  // Create terminator op depending on the parent op of this region.
  if (llvm::isa<FuncOp>(block->getParentOp())) {
    builder.create<mlir::ReturnOp>(loc, result);
  } else {
    builder.create<mlir::mhlo::ReturnOp>(loc, result);
  }
  return tensorflow::Status::OK();
}

StatusOr<Value> HloFunctionImporter::ImportInstructions(
    const xla::HloComputation& computation,
    const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
  mlir::Block* block = builder->getBlock();
  if (block == nullptr)
    return InvalidArgument(
        "ImportInstructions requires a valid block in the builder");

  HloFunctionImporter importer(
      block->getParent()->getParentOfType<mlir::ModuleOp>(), {}, builder);
  return importer.ImportInstructionsImpl(computation, arguments, builder);
}

StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
    const xla::HloInstruction* instr,
    const llvm::SmallVectorImpl<mlir::Value>& operands,
    mlir::OpBuilder* builder) {
  mlir::Block* block = builder->getBlock();
  if (block == nullptr)
    return InvalidArgument(
        "ImportInstruction requires a valid block in the builder");

  HloFunctionImporter importer(
      block->getParent()->getParentOfType<mlir::ModuleOp>(), {}, builder);

  return importer.ImportInstructionWithLayout(instr, operands, builder);
}

StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
    const HloInstruction* instruction,
    const llvm::SmallVectorImpl<mlir::Value>& operands,
    mlir::OpBuilder* func_builder) {
  TF_ASSIGN_OR_RETURN(auto result_type, ConvertShapeToType<RankedTensorType>(
                                            instruction->shape(), *builder_));
  mlir::Location loc = GenerateInstructionLocation(instruction, func_builder);

  llvm::SmallVector<NamedAttribute, 10> attributes;
  switch (instruction->opcode()) {
    case HloOpcode::kParameter: {
      return nullptr;
    }
    case HloOpcode::kConstant: {
      const Literal& literal = instruction->literal();
      auto attr = CreateDenseElementsAttrFromLiteral(literal, *builder_);
      if (!attr.ok()) return attr.status();
      mlir::Operation* new_operation =
          func_builder->create<mlir::mhlo::ConstOp>(loc, attr.ValueOrDie());
      for (auto attr : attributes) {
        new_operation->setAttr(attr.first, attr.second);
      }
      return new_operation;
    }
    case HloOpcode::kIota: {
      return func_builder
          ->create<mlir::mhlo::IotaOp>(
              loc, result_type,
              func_builder->getI64IntegerAttr(
                  Cast<HloIotaInstruction>(instruction)->iota_dimension()))
          .getOperation();
    }
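// Shorthand used by many cases below: creates the named mhlo op with the
// result type, operands, and attributes accumulated so far, and returns it.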
#define MakeAndReturn(mlir_op)                                                \
  {                                                                           \
    mlir::Operation* new_operation =                                          \
        func_builder->create<mlir::mhlo::mlir_op>(loc, result_type, operands, \
                                                  attributes);                \
    return new_operation;                                                     \
  }
    case HloOpcode::kBroadcast: {
      // Note that the HLO broadcast is more powerful than the XLA broadcast
      // op. BroadcastInDim offers a superset of the HLO op's functionality.
      attributes.push_back(
          builder_->getNamedAttr("broadcast_dimensions",
                                 ConvertDimensions(instruction->dimensions())));
      MakeAndReturn(BroadcastInDimOp);
    }
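// Shared by the three batch-norm cases that follow: attaches the epsilon and
// feature_index attributes before emitting the op via MakeAndReturn.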
#define MakeAndReturnBatchNormOp(batch_norm_op)                         \
  {                                                                     \
    attributes.push_back(builder_->getNamedAttr(                        \
        "epsilon", builder_->getF32FloatAttr(instruction->epsilon()))); \
    attributes.push_back(builder_->getNamedAttr(                        \
        "feature_index",                                                \
        builder_->getI64IntegerAttr(instruction->feature_index())));    \
    MakeAndReturn(batch_norm_op);                                       \
  }
    case HloOpcode::kBatchNormGrad:
      MakeAndReturnBatchNormOp(BatchNormGradOp);
    case HloOpcode::kBatchNormInference:
      MakeAndReturnBatchNormOp(BatchNormInferenceOp);
    case HloOpcode::kBatchNormTraining:
      MakeAndReturnBatchNormOp(BatchNormTrainingOp);
#undef MakeAndReturnBatchNormOp

    case HloOpcode::kDot: {
      attributes.push_back(builder_->getNamedAttr(
          "precision_config",
          ConvertPrecisionConfig(&instruction->precision_config(), builder_)));

      // Consider consolidating DotOps together.
      if (DotIsDefault(instruction)) {
        MakeAndReturn(DotOp);
      }

      attributes.push_back(builder_->getNamedAttr(
          "dot_dimension_numbers",
          ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(),
                                     builder_)));
      MakeAndReturn(DotGeneralOp);
    }
    case HloOpcode::kCall: {
      TF_ASSIGN_OR_RETURN(FuncOp function,
                          ImportAsFunc(*instruction->to_apply()));
      mlir::Operation* new_operation =
          func_builder->create<mlir::CallOp>(loc, function, operands);
      return new_operation;
    }
    case HloOpcode::kCollectivePermute: {
      attributes.push_back(ConvertSourceTargetPairs(
          instruction->source_target_pairs(), builder_));
      MakeAndReturn(CollectivePermuteOp);
    }
    case HloOpcode::kCustomCall: {
      auto custom_call = Cast<HloCustomCallInstruction>(instruction);
      TF_ASSIGN_OR_RETURN(
          auto mlir_api_version,
          ConvertCustomCallApiVersion(custom_call->api_version()));
      attributes.push_back(builder_->getNamedAttr(
          "call_target_name",
          builder_->getStringAttr(custom_call->custom_call_target())));
      attributes.push_back(builder_->getNamedAttr(
          "has_side_effect",
          builder_->getBoolAttr(custom_call->custom_call_has_side_effect())));
      attributes.push_back(builder_->getNamedAttr(
          "backend_config",
          builder_->getStringAttr(custom_call->raw_backend_config_string())));
      attributes.push_back(builder_->getNamedAttr(
          "api_version", mlir::mhlo::CustomCallApiVersionAttr::get(
                             builder_->getContext(), mlir_api_version)));
      MakeAndReturn(CustomCallOp);
    }
    case HloOpcode::kCompare: {
      auto compare = Cast<HloCompareInstruction>(instruction);
      attributes.push_back(ConvertComparisonDirection(compare->direction()));
      auto default_type = Comparison::DefaultComparisonType(
          compare->operand(0)->shape().element_type());
      if (compare->type() != default_type)
        attributes.push_back(ConvertComparisonType(compare->type()));
      MakeAndReturn(CompareOp);
    }
    case HloOpcode::kCholesky: {
      attributes.push_back(builder_->getNamedAttr(
          "lower",
          builder_->getBoolAttr(instruction->cholesky_options().lower())));
      MakeAndReturn(CholeskyOp);
    }
    case HloOpcode::kGather: {
      auto gather_instruction = Cast<HloGatherInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "dimension_numbers",
          ConvertGatherDimensionNumbers(
              gather_instruction->gather_dimension_numbers(), builder_)));

      std::vector<int64_t> slice_sizes(
          gather_instruction->gather_slice_sizes().begin(),
          gather_instruction->gather_slice_sizes().end());
      attributes.push_back(
          builder_->getNamedAttr("slice_sizes", Convert(slice_sizes)));
      attributes.push_back(builder_->getNamedAttr(
          "indices_are_sorted",
          builder_->getBoolAttr(gather_instruction->indices_are_sorted())));

      MakeAndReturn(GatherOp);
    }
    case HloOpcode::kDynamicSlice: {
      std::vector<int64_t> slice_sizes(
          instruction->dynamic_slice_sizes().begin(),
          instruction->dynamic_slice_sizes().end());
      return func_builder
          ->create<mlir::mhlo::DynamicSliceOp>(
              loc, result_type, operands[0],
              makeArrayRef(operands).drop_front(), Convert(slice_sizes))
          .getOperation();
    }
    case HloOpcode::kDynamicUpdateSlice: {
      return func_builder
          ->create<mlir::mhlo::DynamicUpdateSliceOp>(
              loc, result_type, operands[0], operands[1],
              llvm::ArrayRef<Value>(operands.begin() + 2, operands.end()))
          .getOperation();
    }
    case HloOpcode::kInfeed: {
      attributes.push_back(builder_->getNamedAttr(
          "infeed_config",
          mlir::StringAttr::get(builder_->getContext(),
                                instruction->infeed_config())));
      TF_ASSIGN_OR_RETURN(mlir::Attribute layout,
                          ConvertShapeToMlirLayout(instruction->shape()));
      attributes.push_back(builder_->getNamedAttr("layout", layout));
      MakeAndReturn(InfeedOp);
    }
    case HloOpcode::kOutfeed: {
      attributes.push_back(builder_->getNamedAttr(
          "outfeed_config",
          mlir::StringAttr::get(builder_->getContext(),
                                instruction->outfeed_config())));
      MakeAndReturn(OutfeedOp);
    }
    case HloOpcode::kPad: {
      const auto& padding_config = instruction->padding_config();
      llvm::SmallVector<int64_t, 4> edge_padding_low;
      llvm::SmallVector<int64_t, 4> edge_padding_high;
      llvm::SmallVector<int64_t, 4> interior_padding;
      edge_padding_low.reserve(padding_config.dimensions_size());
      edge_padding_high.reserve(padding_config.dimensions_size());
      interior_padding.reserve(padding_config.dimensions_size());

      for (const auto& dimension : padding_config.dimensions()) {
        edge_padding_low.push_back(dimension.edge_padding_low());
        edge_padding_high.push_back(dimension.edge_padding_high());
        interior_padding.push_back(dimension.interior_padding());
      }

      return func_builder
          ->create<mlir::mhlo::PadOp>(loc, result_type, operands[0],
                                      operands[1], Convert(edge_padding_low),
                                      Convert(edge_padding_high),
                                      Convert(interior_padding))
          .getOperation();
    }
    case HloOpcode::kScatter: {
      auto scatter = Cast<HloScatterInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "scatter_dimension_numbers",
          ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers(),
                                         builder_)));
      attributes.push_back(builder_->getNamedAttr(
          "indices_are_sorted",
          builder_->getBoolAttr(scatter->indices_are_sorted())));
      attributes.push_back(builder_->getNamedAttr(
          "unique_indices", builder_->getBoolAttr(scatter->unique_indices())));

      auto scatter_op = func_builder->create<mlir::mhlo::ScatterOp>(
          loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(),
                                        &scatter_op.update_computation()));
      return scatter_op.getOperation();
    }
    case HloOpcode::kSelectAndScatter: {
      auto select_scatter = Cast<HloSelectAndScatterInstruction>(instruction);
      llvm::SmallVector<int64_t, 4> window_strides, window_dimensions;
      llvm::SmallVector<int64_t, 8> padding;
      for (const auto& dim : select_scatter->window().dimensions()) {
        window_strides.push_back(dim.stride());
        window_dimensions.push_back(dim.size());
        padding.push_back(dim.padding_low());
        padding.push_back(dim.padding_high());
      }
      attributes.push_back(
          builder_->getNamedAttr("window_strides", Convert(window_strides)));
      attributes.push_back(builder_->getNamedAttr("window_dimensions",
                                                  Convert(window_dimensions)));
      attributes.push_back(ConvertPadding(padding));
      auto select_scatter_op =
          func_builder->create<mlir::mhlo::SelectAndScatterOp>(
              loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(),
                                        &select_scatter_op.select()));
      TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->scatter(),
                                        &select_scatter_op.scatter()));
      return select_scatter_op.getOperation();
    }
    case HloOpcode::kSetDimensionSize: {
      attributes.push_back(builder_->getNamedAttr(
          "dimension", builder_->getI64IntegerAttr(instruction->dimension())));
      MakeAndReturn(SetDimensionSizeOp);
    }
    case HloOpcode::kSlice: {
      return func_builder
          ->create<mlir::mhlo::SliceOp>(
              loc, result_type, operands[0],
              ConvertDimensions(instruction->slice_starts()),
              ConvertDimensions(instruction->slice_limits()),
              ConvertDimensions(instruction->slice_strides()))
          .getOperation();
    }
    case HloOpcode::kSort: {
      auto sort_instruction = Cast<HloSortInstruction>(instruction);

      llvm::SmallVector<Type, 4> return_types = {result_type};
      if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
        return_types = llvm::to_vector<6>(tuple_ty.getTypes());
      }

      auto sort_op = func_builder->create<mlir::mhlo::SortOp>(
          loc, return_types, operands,
          builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
          builder_->getBoolAttr(sort_instruction->is_stable()));
      TF_RETURN_IF_ERROR(
          ImportAsRegion(*sort_instruction->to_apply(), &sort_op.comparator()));

      // Check if the output needs to be tupled.
      if (return_types.size() == 1 && return_types.front() == result_type) {
        return sort_op.getOperation();
      }

      return func_builder
          ->create<mlir::mhlo::TupleOp>(loc, result_type, sort_op.getResults())
          .getOperation();
    }
    case HloOpcode::kConditional: {
      llvm::SmallVector<Type, 4> rets;
      mlir::Type pred_or_index_type =
          operands[0].getType().cast<mlir::TensorType>().getElementType();
      // It is a predicated conditional if the first argument is a boolean, and
      // it should be mapped to an If op.
      if (pred_or_index_type.isInteger(1)) {
        TF_RETURN_IF_ERROR(GetMlirTypes(
            {instruction->true_computation()->root_instruction()}, &rets));

        auto op = func_builder->create<mlir::mhlo::IfOp>(loc, rets, operands,
                                                         attributes);
        TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(),
                                          &op.true_branch()));
        TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->false_computation(),
                                          &op.false_branch()));
        return op.getOperation();
      }

      // Otherwise, it is an indexed conditional and should be mapped to a
      // Case op.
      TF_RETURN_IF_ERROR(GetMlirTypes(
          {instruction->branch_computation(0)->root_instruction()}, &rets));

      int num_branches = instruction->branch_count();
      auto op = func_builder->create<mlir::mhlo::CaseOp>(
          loc, rets, operands, attributes, num_branches);
      for (auto index_and_computation :
           llvm::enumerate(instruction->branch_computations())) {
        auto index = index_and_computation.index();
        HloComputation* computation = index_and_computation.value();
        TF_RETURN_IF_ERROR(ImportAsRegion(*computation, &op.branches()[index]));
      }
      return op.getOperation();
    }
    case HloOpcode::kConcatenate: {
      // TODO(b/132057942): Support taking a uint64_t instead of an
      // IntegerAttr for the concatenate dimension.
      return func_builder
          ->create<mlir::mhlo::ConcatenateOp>(
              loc, result_type, operands,
              builder_->getI64IntegerAttr(instruction->concatenate_dimension()))
          .getOperation();
    }
    case HloOpcode::kAllGather: {
      auto all_gather = Cast<HloAllGatherInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "all_gather_dim",
          builder_->getI64IntegerAttr(all_gather->all_gather_dimension())));
      attributes.push_back(
          ConvertReplicaGroups(all_gather->replica_groups(), builder_));
      attributes.push_back(ConvertChannelHandle(all_gather->channel_id()));
      MakeAndReturn(AllGatherOp);
    }
    case HloOpcode::kAllReduce: {
      auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
      attributes.push_back(
          ConvertReplicaGroups(all_reduce->replica_groups(), builder_));
      attributes.push_back(ConvertChannelHandle(all_reduce->channel_id()));
      auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
          loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
                                        &all_reduce_op.computation()));
      return all_reduce_op.getOperation();
    }
    case HloOpcode::kReduce: {
      // Operands in the first half are reduction inputs, and the remaining
      // operands are the corresponding initial values.
      size_t num_inputs = operands.size() / 2;
      auto reduce = func_builder->create<mlir::mhlo::ReduceOp>(
          loc, result_type, llvm::makeArrayRef(operands).take_front(num_inputs),
          llvm::makeArrayRef(operands).drop_front(num_inputs),
          ConvertDimensions(instruction->dimensions()));
      TF_RETURN_IF_ERROR(
          ImportAsRegion(*instruction->to_apply(), &reduce.body()));
      return reduce.getOperation();
    }
    case HloOpcode::kReverse: {
      return func_builder
          ->create<mlir::mhlo::ReverseOp>(
              loc, result_type, operands[0],
              ConvertDimensions(instruction->dimensions()))
          .getOperation();
    }
    case HloOpcode::kRng: {
      auto shape = func_builder->create<mlir::ConstantOp>(
          loc, Convert(result_type.cast<RankedTensorType>().getShape()));
      switch (instruction->random_distribution()) {
        case xla::RNG_UNIFORM:
          return func_builder
              ->create<mlir::mhlo::RngUniformOp>(loc, result_type, operands[0],
                                                 operands[1], shape)
              .getOperation();

        case xla::RNG_NORMAL:
          return func_builder
              ->create<mlir::mhlo::RngNormalOp>(loc, result_type, operands[0],
                                                operands[1], shape)
              .getOperation();

        default:
          return tensorflow::errors::InvalidArgument(absl::StrCat(
              "Unsupported distribution: ",
              RandomDistributionToString(instruction->random_distribution())));
      }
    }
    case HloOpcode::kRngBitGenerator: {
      auto rng_op = Cast<HloRngBitGeneratorInstruction>(instruction);
      auto op = func_builder->create<mlir::mhlo::RngBitGeneratorOp>(
          loc, result_type,
          func_builder->getI32IntegerAttr(rng_op->algorithm()), operands[0]);
      return op.getOperation();
    }
    case HloOpcode::kWhile: {
      auto op = func_builder->create<mlir::mhlo::WhileOp>(
          loc, operands[0].getType(), operands[0]);
      TF_RETURN_IF_ERROR(
          ImportAsRegion(*instruction->while_condition(), &op.cond()));
      TF_RETURN_IF_ERROR(
          ImportAsRegion(*instruction->while_body(), &op.body()));
      return op.getOperation();
    }
    case HloOpcode::kGetTupleElement: {
      attributes.push_back(builder_->getNamedAttr(
          "index", builder_->getIntegerAttr(builder_->getIntegerType(32),
                                            instruction->tuple_index())));
      MakeAndReturn(GetTupleElementOp);
    }
    case HloOpcode::kGetDimensionSize: {
      attributes.push_back(builder_->getNamedAttr(
          "dimension", builder_->getI64IntegerAttr(instruction->dimension())));
      MakeAndReturn(GetDimensionSizeOp);
    }
    case HloOpcode::kTranspose: {
      attributes.push_back(builder_->getNamedAttr(
          "permutation", ConvertDimensions(instruction->dimensions())));
      MakeAndReturn(TransposeOp);
    }
    case HloOpcode::kTriangularSolve: {
      attributes.push_back(builder_->getNamedAttr(
          "left_side",
          builder_->getBoolAttr(
              instruction->triangular_solve_options().left_side())));
      attributes.push_back(builder_->getNamedAttr(
          "lower", builder_->getBoolAttr(
                       instruction->triangular_solve_options().lower())));
      attributes.push_back(builder_->getNamedAttr(
          "unit_diagonal",
          builder_->getBoolAttr(
              instruction->triangular_solve_options().unit_diagonal())));
      auto transpose_a =
          builder_->getStringAttr(TriangularSolveOptions::Transpose_Name(
              instruction->triangular_solve_options().transpose_a()));
      attributes.push_back(builder_->getNamedAttr("transpose_a", transpose_a));
      MakeAndReturn(TriangularSolveOp);
    }
    case HloOpcode::kReduceWindow: {
      llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
      llvm::SmallVector<int64_t, 8> padding;
      for (const auto& dim : instruction->window().dimensions()) {
        sizes.push_back(dim.size());
        strides.push_back(dim.stride());
        base_dilations.push_back(dim.base_dilation());
        win_dilations.push_back(dim.window_dilation());
        padding.push_back(dim.padding_low());
        padding.push_back(dim.padding_high());
      }
      attributes.push_back(builder_->getNamedAttr("window_dimensions",
                                                  ConvertDimensions(sizes)));
      attributes.push_back(
          builder_->getNamedAttr("window_strides", ConvertDimensions(strides)));
      attributes.push_back(builder_->getNamedAttr(
          "base_dilations", ConvertDimensions(base_dilations)));
      attributes.push_back(builder_->getNamedAttr(
          "window_dilations", ConvertDimensions(win_dilations)));
      attributes.push_back(ConvertPadding(padding));
      auto reduce = func_builder->create<mlir::mhlo::ReduceWindowOp>(
          loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(
          ImportAsRegion(*instruction->to_apply(), &reduce.body()));
      return reduce.getOperation();
    }
    case HloOpcode::kMap: {
      auto op = func_builder->create<mlir::mhlo::MapOp>(
          loc, result_type, operands,
          ConvertDimensions(instruction->dimensions()));
      TF_RETURN_IF_ERROR(
          ImportAsRegion(*instruction->to_apply(), &op.computation()));
      return op.getOperation();
    }
    case HloOpcode::kConvolution: {
      llvm::SmallVector<int64_t, 4> strides, lhs_dilations, rhs_dilations;
      llvm::SmallVector<int64_t, 8> paddings;
      for (const auto& dim : instruction->window().dimensions()) {
        strides.push_back(dim.stride());
        lhs_dilations.push_back(dim.base_dilation());
        rhs_dilations.push_back(dim.window_dilation());
        paddings.push_back(dim.padding_low());
        paddings.push_back(dim.padding_high());
      }

      attributes.push_back(
          builder_->getNamedAttr("window_strides", Convert(strides)));
      attributes.push_back(ConvertPadding(paddings));
      attributes.push_back(
          builder_->getNamedAttr("lhs_dilation", Convert(lhs_dilations)));
      attributes.push_back(
          builder_->getNamedAttr("rhs_dilation", Convert(rhs_dilations)));
      attributes.push_back(builder_->getNamedAttr(
          "dimension_numbers",
          ConvertConvDimensionNumbers(
              instruction->convolution_dimension_numbers(), builder_)));
      attributes.push_back(builder_->getNamedAttr(
          "feature_group_count",
          builder_->getI64IntegerAttr(instruction->feature_group_count())));
      attributes.push_back(builder_->getNamedAttr(
          "batch_group_count",
          builder_->getI64IntegerAttr(instruction->batch_group_count())));
      attributes.push_back(builder_->getNamedAttr(
          "precision_config",
          ConvertPrecisionConfig(&instruction->precision_config(), builder_)));

      MakeAndReturn(ConvOp);
    }

    case HloOpcode::kFft: {
      auto fft_type =
          builder_->getStringAttr(FftType_Name(instruction->fft_type()));

      std::vector<int64_t> fft_length(instruction->fft_length().begin(),
                                      instruction->fft_length().end());

      attributes.push_back(builder_->getNamedAttr("fft_type", fft_type));
      attributes.push_back(
          builder_->getNamedAttr("fft_length", Convert(fft_length)));
      MakeAndReturn(FftOp);
    }

    case HloOpcode::kAdd: {
      // HLO add ops on PRED elements are actually boolean or, while MHLO
      // AddOps on i1 are just addition with overflow, so the special PRED
      // behavior of HLO add is implemented here by creating an OrOp instead.
      if (instruction->shape().element_type() == PRED) {
        MakeAndReturn(OrOp);
      } else {
        MakeAndReturn(AddOp);
      }
    }
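// The cases emitted through the macro below map 1:1 onto mhlo ops and carry
// no additional attributes.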
#define NoAttributeCase(hlo_op_code, mlir_op) \
  case HloOpcode::hlo_op_code: {              \
    MakeAndReturn(mlir_op);                   \
  }

      // broadcast dimensions are never added here because they don't exist as
      // part of the HLO instruction. They are only a convenience in the XLA
      // builder API.
      NoAttributeCase(kAbs, AbsOp);
      NoAttributeCase(kAfterAll, AfterAllOp);
      NoAttributeCase(kAnd, AndOp);
      NoAttributeCase(kAtan2, Atan2Op);
      NoAttributeCase(kBitcastConvert, BitcastConvertOp);
      NoAttributeCase(kCbrt, CbrtOp);
      NoAttributeCase(kClz, ClzOp);
      NoAttributeCase(kConvert, ConvertOp);
      NoAttributeCase(kCeil, CeilOp);
      NoAttributeCase(kClamp, ClampOp);
      NoAttributeCase(kComplex, ComplexOp);
      NoAttributeCase(kCos, CosOp);
      NoAttributeCase(kDivide, DivOp);
      NoAttributeCase(kExp, ExpOp);
      NoAttributeCase(kExpm1, Expm1Op);
      NoAttributeCase(kFloor, FloorOp);
      NoAttributeCase(kIsFinite, IsFiniteOp);
      NoAttributeCase(kImag, ImagOp);
      NoAttributeCase(kLog, LogOp);
      NoAttributeCase(kLog1p, Log1pOp);
      NoAttributeCase(kMaximum, MaxOp);
      NoAttributeCase(kMinimum, MinOp);
      NoAttributeCase(kMultiply, MulOp);
      NoAttributeCase(kNegate, NegOp);
      NoAttributeCase(kNot, NotOp);
      NoAttributeCase(kOr, OrOp);
      NoAttributeCase(kPopulationCount, PopulationCountOp);
      NoAttributeCase(kPower, PowOp);
      NoAttributeCase(kReal, RealOp);
      NoAttributeCase(kRemainder, RemOp);
      NoAttributeCase(kReplicaId, ReplicaIdOp);
      NoAttributeCase(kLogistic, LogisticOp);
      // The dimensions attribute is not present on the HLO Reshape
      // instruction. If dimensions are non-default, the XLA builder
      // implements it as a separate transpose.
      NoAttributeCase(kReshape, ReshapeOp);
      NoAttributeCase(kRoundNearestAfz, RoundOp);
      NoAttributeCase(kRsqrt, RsqrtOp);
      NoAttributeCase(kSelect, SelectOp);
      NoAttributeCase(kShiftLeft, ShiftLeftOp);
      NoAttributeCase(kShiftRightArithmetic, ShiftRightArithmeticOp);
      NoAttributeCase(kShiftRightLogical, ShiftRightLogicalOp);
      NoAttributeCase(kSign, SignOp);
      NoAttributeCase(kSin, SinOp);
      NoAttributeCase(kSqrt, SqrtOp);
      NoAttributeCase(kSubtract, SubOp);
      NoAttributeCase(kTanh, TanhOp);
      NoAttributeCase(kTuple, TupleOp);
      NoAttributeCase(kXor, XorOp);
      // TODO(b/129422361) Copy needs special handling because it is not
      // defined in tensorflow/compiler/xla/client/xla_builder.h. See
      // operation semantics in
      // g3doc/platforms/xla/g3doc/internal/hlo_semantics#copy
      NoAttributeCase(kCopy, CopyOp);
#undef NoAttributeCase
#undef MakeAndReturn
    case HloOpcode::kFusion: {
      auto fusion = func_builder->create<mlir::mhlo::FusionOp>(
          loc, result_type, operands,
          builder_->getStringAttr(xla::ToString(instruction->fusion_kind())));
      TF_RETURN_IF_ERROR(
          ImportAsRegion(*instruction->fused_instructions_computation(),
                         &fusion.fused_computation()));
      return fusion.getOperation();
    }
    case HloOpcode::kBitcast: {
      auto bitcast = func_builder->create<mlir::mhlo::BitcastOp>(
          loc, result_type, operands, attributes);
      // Store the source and result layout as attributes. Although the MHLO
      // Bitcast operates on tensors, these layouts are relevant as they define
      // the mapping between the elements of the source and result.
      SetLayoutForMlir(bitcast, instruction->shape(), "result_layout");
      SetLayoutForMlir(bitcast, instruction->operand(0)->shape(),
                       "source_layout");
      return bitcast.getOperation();
    }
    case HloOpcode::kReducePrecision: {
      auto op = func_builder->create<mlir::mhlo::ReducePrecisionOp>(
          loc, result_type, operands[0], attributes);
      op.exponent_bitsAttr(func_builder->getIntegerAttr(
          func_builder->getI32Type(), instruction->exponent_bits()));
      op.mantissa_bitsAttr(func_builder->getIntegerAttr(
          func_builder->getI32Type(), instruction->mantissa_bits()));
      return op.getOperation();
    }
    case HloOpcode::kAddDependency:
      // An arbitrary opcode that is unlikely to be implemented for quite a
      // while; it allows testing the handling of unknown ops. Selected
      // because it is not mentioned anywhere in the XLA client or in the HLO
      // of our sample models.
    default: {
      mlir::OperationState result(loc, "mhlo.unknown");
      result.addOperands(operands);
      result.addTypes(result_type);
      for (auto attr : attributes) {
        result.attributes.push_back(attr);
      }

      return func_builder->createOperation(result);
    }
  }
}

StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionWithLayout(
    const HloInstruction* instruction,
    const llvm::SmallVectorImpl<mlir::Value>& operands,
    mlir::OpBuilder* func_builder) {
  TF_ASSIGN_OR_RETURN(
      mlir::Operation * op,
      ImportInstructionImpl(instruction, operands, func_builder));
  if (op == nullptr) return op;

  // See MlirToHloConversionOptions for more about layouts.
  //
  // Minor-to-major is a permutation of [0, rank), presenting tensor dimensions
  // in physical minor-to-major order.
  if (instruction->shape().IsArray() &&
      !instruction->shape().layout().minor_to_major().empty() &&
      instruction->shape().layout() !=
          LayoutUtil::MakeDescendingLayout(
              instruction->shape().dimensions().size())) {
    SetLayoutForMlir(op, instruction->shape());
  }
  return op;
}

StatusOr<llvm::SmallVector<mlir::Value, 4>> HloFunctionImporter::GetOperands(
    const HloInstruction* instruction) {
  llvm::SmallVector<mlir::Value, 4> operands;
  for (const auto& operand : instruction->operands()) {
    auto input_it = instruction_value_map_.find(operand);
    if (input_it == instruction_value_map_.end()) {
      return tensorflow::errors::Internal(
          absl::StrCat("Could not find input value: ", operand->name(),
                       " for instruction ", instruction->name()));
    }
    operands.push_back(input_it->second);
  }
  return operands;
}

tensorflow::Status HloFunctionImporter::GetMlirTypes(
    const std::vector<HloInstruction*>& instructions,
    llvm::SmallVectorImpl<mlir::Type>* types) {
  for (auto instruction : instructions) {
    TF_ASSIGN_OR_RETURN(auto ret_type, ConvertShapeToType<RankedTensorType>(
                                           instruction->shape(), *builder_));
    types->push_back(ret_type);
  }
  return tensorflow::Status::OK();
}

StatusOr<Value> HloFunctionImporter::GetMlirValue(
    const HloInstruction* instruction) {
  auto lookup = instruction_value_map_.find(instruction);
  if (lookup != instruction_value_map_.end()) {
    return lookup->second;
  }

  return tensorflow::errors::Internal(absl::StrCat(
      "Unable to find value for input: ", instruction->ToString()));
}

mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection(
    ComparisonDirection direction) {
  return builder_->getNamedAttr(
      "comparison_direction",
      builder_->getStringAttr(ComparisonDirectionToString(direction)));
}

mlir::NamedAttribute HloFunctionImporter::ConvertComparisonType(
    Comparison::Type type) {
  return builder_->getNamedAttr(
      "compare_type", builder_->getStringAttr(ComparisonTypeToString(type)));
}

mlir::DenseIntElementsAttr HloFunctionImporter::ConvertDimensions(
    llvm::ArrayRef<int64> op_dimensions) {
  llvm::SmallVector<APInt, 8> dimensions;
  dimensions.reserve(op_dimensions.size());
  for (auto value : op_dimensions) dimensions.emplace_back(APInt(64, value));

  return DenseIntElementsAttr::get(
      RankedTensorType::get(dimensions.size(), builder_->getIntegerType(64)),
      dimensions);
}
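
// For example, dimensions {1, 0} become a dense tensor<2xi64> attribute with
// values [1, 0].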

mlir::DenseIntElementsAttr HloFunctionImporter::Convert(
    llvm::ArrayRef<int64_t> elements) {
  return DenseIntElementsAttr::get(
      RankedTensorType::get(elements.size(), builder_->getIntegerType(64)),
      elements);
}

mlir::NamedAttribute HloFunctionImporter::ConvertPadding(
    llvm::ArrayRef<int64_t> padding) {
  auto ty =
      mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
                                  builder_->getIntegerType(64));
  auto attr = DenseIntElementsAttr::get(ty, padding);
  return builder_->getNamedAttr("padding", attr);
}
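
// For example, interleaved padding values {lo0, hi0, lo1, hi1} become the
// 2x2 tensor [[lo0, hi0], [lo1, hi1]] (names illustrative).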

mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
    const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
        source_target_pairs,
    mlir::Builder* builder) {
  std::vector<int64_t> attr(source_target_pairs.size() * 2);
  for (auto p : llvm::enumerate(source_target_pairs)) {
    attr[2 * p.index()] = p.value().first;
    attr[2 * p.index() + 1] = p.value().second;
  }
  auto type = mlir::RankedTensorType::get(
      {static_cast<int64_t>(attr.size() / 2), 2}, builder->getIntegerType(64));
  return builder->getNamedAttr("source_target_pairs",
                               DenseIntElementsAttr::get(type, attr));
}
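
// For example, source-target pairs {{0,1},{2,3}} are flattened row-major into
// the 2x2 i64 tensor [[0, 1], [2, 3]].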

mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
    absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder) {
  const int64_t num_groups = replica_groups.size();
  // Replica groups in HLO can be non-uniform in size, for example:
  // replica_groups={{0},{1,2},{3}}. Since we are representing them as a 2D
  // tensor, pad the smaller sized replica groups with -1.
  const int64_t group_size = absl::c_accumulate(
      replica_groups, int64_t(0), [](int64_t current, const ReplicaGroup& g) {
        return std::max<int64_t>(current, g.replica_ids_size());
      });
  // Initialize all elements to -1 to support non-uniform replica groups.
  std::vector<int64_t> attr(num_groups * group_size, -1);
  for (int i = 0; i < num_groups; ++i) {
    int index = i * group_size;
    for (const int64& id : replica_groups[i].replica_ids()) attr[index++] = id;
  }
  auto type = mlir::RankedTensorType::get({num_groups, group_size},
                                          builder->getIntegerType(64));
  return builder->getNamedAttr("replica_groups",
                               DenseIntElementsAttr::get(type, attr));
}
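
// For example, replica_groups={{0},{1,2},{3}} from the comment above is
// encoded as the 3x2 tensor [[0, -1], [1, 2], [3, -1]].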

mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
    absl::optional<tensorflow::int64> channel_id) {
  xla::ChannelHandle channel_handle;
  if (channel_id) channel_handle.set_handle(*channel_id);
  return ConvertChannelHandle(channel_handle);
}

mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
    const xla::ChannelHandle& channel) {
  return builder_->getNamedAttr(
      "channel_handle",
      mlir::mhlo::ChannelHandle::get(
          builder_->getI64IntegerAttr(channel.handle()),
          builder_->getI64IntegerAttr(channel.type()), context_));
}

void HloFunctionImporter::SetLayoutForMlir(mlir::Operation* op,
                                           const Shape& shape,
                                           llvm::StringRef attr_name) {
  llvm::SmallVector<int64_t, 4> minor_to_major(
      shape.layout().minor_to_major().begin(),
      shape.layout().minor_to_major().end());
  op->setAttr(
      attr_name,
      mlir::Builder(op->getContext()).getIndexTensorAttr(minor_to_major));
}
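
// For example, a shape whose layout has minor_to_major = {1, 0} receives an
// index tensor attribute printed roughly as dense<[1, 0]> : tensor<2xindex>
// (assuming the usual textual form of dense index attributes).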

StatusOr<mlir::Attribute> HloFunctionImporter::ConvertShapeToMlirLayout(
    const xla::Shape& shape) {
  if (shape.IsToken()) return builder_->getUnitAttr();
  if (shape.IsTuple()) {
    std::vector<mlir::Attribute> tuple_layouts;
    for (int64_t i = 0; i < shape.tuple_shapes_size(); i++) {
      TF_ASSIGN_OR_RETURN(mlir::Attribute layout,
                          ConvertShapeToMlirLayout(shape.tuple_shapes(i)));
      tuple_layouts.push_back(layout);
    }
    llvm::ArrayRef<mlir::Attribute> array_ref(tuple_layouts);
    return builder_->getArrayAttr(array_ref);
  }
  if (shape.IsArray()) {
    const xla::Layout l = shape.layout();
    std::vector<mlir::Attribute> minor_to_major;
    for (int64_t i : l.minor_to_major()) {
      minor_to_major.push_back(builder_->getI64IntegerAttr(i));
    }
    llvm::ArrayRef<mlir::Attribute> array_ref(minor_to_major);
    return builder_->getArrayAttr(array_ref);
  }
  return tensorflow::errors::Internal("Couldn't convert layout.");
}
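
// For example, a tuple of a token and a rank-2 array with layout {1, 0} would
// convert to an ArrayAttr of the form [unit, [1, 0]].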

}  // namespace xla