1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
17
18 #include <unordered_map>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/types/optional.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
27 #include "mlir/IR/Attributes.h" // from @llvm-project
28 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
29 #include "mlir/IR/Builders.h" // from @llvm-project
30 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
31 #include "mlir/IR/Identifier.h" // from @llvm-project
32 #include "mlir/IR/Location.h" // from @llvm-project
33 #include "mlir/IR/Region.h" // from @llvm-project
34 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
35 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
36 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
37 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
38 #include "tensorflow/compiler/xla/comparison_util.h"
39 #include "tensorflow/compiler/xla/protobuf_util.h"
40 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
41 #include "tensorflow/compiler/xla/service/hlo_computation.h"
42 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
43 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
44 #include "tensorflow/compiler/xla/service/hlo_module.h"
45 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
46 #include "tensorflow/compiler/xla/status_macros.h"
47 #include "tensorflow/compiler/xla/xla_data.pb.h"
48
49 using llvm::APInt;
50 using llvm::makeArrayRef;
51 using mlir::DenseIntElementsAttr;
52 using mlir::FuncOp;
53 using mlir::NamedAttribute;
54 using mlir::Operation;
55 using mlir::RankedTensorType;
56 using mlir::Type;
57 using mlir::Value;
58
59 namespace xla {
60
61 namespace {
62
63 // Note: This sanitization function causes an irreversible many-to-one mapping
64 // and any solution to mitigate this would cause issues with the reverse
65 // direction. Longterm solution is to add a function attribute to maintain the
66 // original HLO naming.
SanitizeFunctionName(llvm::StringRef name)67 string SanitizeFunctionName(llvm::StringRef name) {
68 string output(name);
69 llvm::for_each(output, [](char& x) { x = x == '-' ? '_' : x; });
70 return output;
71 }
72
73 // Returns whether the instruction is a default dot operation.
DotIsDefault(const HloInstruction * instruction)74 bool DotIsDefault(const HloInstruction* instruction) {
75 auto dnums = instruction->dot_dimension_numbers();
76 DotDimensionNumbers default_dimension_numbers;
77 default_dimension_numbers.add_lhs_contracting_dimensions(
78 instruction->operand(0)->shape().dimensions_size() == 1 ? 0 : 1);
79 default_dimension_numbers.add_rhs_contracting_dimensions(0);
80 return xla::protobuf_util::ProtobufEquals(dnums, default_dimension_numbers);
81 }
82
83 // Returns an MLIR Location generated from HLO Instruction. Uses instruction
84 // metadata if present or instruction name.
GenerateInstructionLocation(const HloInstruction * instruction,mlir::OpBuilder * func_builder)85 mlir::Location GenerateInstructionLocation(const HloInstruction* instruction,
86 mlir::OpBuilder* func_builder) {
87 const std::string& op_name = instruction->metadata().op_name();
88 if (op_name.empty()) {
89 return mlir::NameLoc::get(func_builder->getIdentifier(instruction->name()));
90 }
91
92 mlir::Location op_name_loc =
93 mlir::NameLoc::get(func_builder->getIdentifier(op_name));
94 const std::string& source_file = instruction->metadata().source_file();
95 if (source_file.empty()) {
96 return op_name_loc;
97 }
98
99 return func_builder->getFusedLoc(
100 {op_name_loc,
101 mlir::FileLineColLoc::get(func_builder->getContext(), source_file,
102 instruction->metadata().source_line(), 0)});
103 }
104 } // namespace
105
ImportAsFunc(const HloComputation & computation,mlir::ModuleOp module,std::unordered_map<const HloComputation *,FuncOp> * function_map,mlir::Builder * builder)106 Status HloFunctionImporter::ImportAsFunc(
107 const HloComputation& computation, mlir::ModuleOp module,
108 std::unordered_map<const HloComputation*, FuncOp>* function_map,
109 mlir::Builder* builder) {
110 HloFunctionImporter importer(module, function_map, builder);
111 return importer.ImportAsFunc(computation).status();
112 }
113
ImportAsRegion(const xla::HloComputation & computation,mlir::Region * region,mlir::Builder * builder)114 Status HloFunctionImporter::ImportAsRegion(
115 const xla::HloComputation& computation, mlir::Region* region,
116 mlir::Builder* builder) {
117 HloFunctionImporter importer(region->getParentOfType<mlir::ModuleOp>(), {},
118 builder);
119 return importer.ImportAsRegion(computation, region);
120 }
121
ImportAsFunc(const HloComputation & computation)122 StatusOr<mlir::FuncOp> HloFunctionImporter::ImportAsFunc(
123 const HloComputation& computation) {
124 auto& imported = (*function_map_)[&computation];
125 if (imported) return imported;
126 llvm::SmallVector<Type, 4> args, rets;
127 TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
128 TF_RETURN_IF_ERROR(GetMlirTypes({computation.root_instruction()}, &rets));
129 auto func_type = mlir::FunctionType::get(context_, args, rets);
130
131 string computation_name =
132 computation.parent()->entry_computation() == &computation
133 ? "main"
134 : SanitizeFunctionName(computation.name());
135
136 // Construct the MLIR function and map arguments.
137 llvm::ArrayRef<mlir::NamedAttribute> attrs;
138 auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
139 computation_name, func_type, attrs);
140 auto visibility = computation_name == "main" ? FuncOp::Visibility::Public
141 : FuncOp::Visibility::Private;
142 function.setVisibility(visibility);
143 module_.push_back(function);
144
145 // Add to the map right away for function calls.
146 imported = function;
147
148 mlir::Block* block = function.addEntryBlock();
149 TF_RETURN_IF_ERROR(ImportInstructions(computation, block));
150
151 return function;
152 }
153
ImportAsRegion(const HloComputation & computation,mlir::Region * region)154 tensorflow::Status HloFunctionImporter::ImportAsRegion(
155 const HloComputation& computation, mlir::Region* region) {
156 // TODO(hinsu): Store computation name as an attribute for round-trip.
157 auto* block = new mlir::Block;
158 region->push_back(block);
159
160 llvm::SmallVector<Type, 4> args;
161 TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
162 block->addArguments(args);
163
164 return ImportInstructions(computation, block);
165 }
166
ImportInstructionsImpl(const xla::HloComputation & computation,const llvm::SmallVectorImpl<Value> & arguments,mlir::OpBuilder * builder)167 StatusOr<Value> HloFunctionImporter::ImportInstructionsImpl(
168 const xla::HloComputation& computation,
169 const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
170 // Setup the input parameters.
171 const int num_parameters = computation.num_parameters();
172
173 if (arguments.size() != num_parameters)
174 return InvalidArgument("Caller vs callee argument sizes do not match");
175
176 for (int i = 0; i < num_parameters; i++) {
177 auto hlo_parameter = computation.parameter_instruction(i);
178 instruction_value_map_[hlo_parameter] = arguments[i];
179 }
180
181 for (auto instruction : computation.MakeInstructionPostOrder()) {
182 TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction));
183 TF_ASSIGN_OR_RETURN(
184 auto new_operation,
185 ImportInstructionWithLayout(instruction, operands, builder));
186 if (new_operation) {
187 instruction_value_map_[instruction] = new_operation->getResult(0);
188 }
189 }
190
191 // Setup the return type (HLO only supports a single return value).
192 return GetMlirValue(computation.root_instruction());
193 }
194
ImportInstructions(const HloComputation & computation,mlir::Block * block)195 Status HloFunctionImporter::ImportInstructions(
196 const HloComputation& computation, mlir::Block* block) {
197 llvm::SmallVector<Value, 4> arguments(block->args_begin(), block->args_end());
198 mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
199 TF_ASSIGN_OR_RETURN(Value result,
200 ImportInstructionsImpl(computation, arguments, &builder));
201
202 // TODO(suderman): Add location tracking details.
203 mlir::Location loc = builder.getUnknownLoc();
204
205 // Create terminator op depending on the parent op of this region.
206 if (llvm::isa<FuncOp>(block->getParentOp())) {
207 builder.create<mlir::ReturnOp>(loc, result);
208 } else {
209 builder.create<mlir::mhlo::ReturnOp>(loc, result);
210 }
211 return tensorflow::Status::OK();
212 }
213
ImportInstructions(const xla::HloComputation & computation,const llvm::SmallVectorImpl<Value> & arguments,mlir::OpBuilder * builder)214 StatusOr<Value> HloFunctionImporter::ImportInstructions(
215 const xla::HloComputation& computation,
216 const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
217 mlir::Block* block = builder->getBlock();
218 if (block == nullptr)
219 return InvalidArgument(
220 "ImportInstructions requires a valid block in the builder");
221
222 HloFunctionImporter importer(
223 block->getParent()->getParentOfType<mlir::ModuleOp>(), {}, builder);
224 return importer.ImportInstructionsImpl(computation, arguments, builder);
225 }
226
ImportInstruction(const xla::HloInstruction * instr,const llvm::SmallVectorImpl<mlir::Value> & operands,mlir::OpBuilder * builder)227 StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
228 const xla::HloInstruction* instr,
229 const llvm::SmallVectorImpl<mlir::Value>& operands,
230 mlir::OpBuilder* builder) {
231 mlir::Block* block = builder->getBlock();
232 if (block == nullptr)
233 return InvalidArgument(
234 "ImportInstructions requires a valid block in the builder");
235
236 HloFunctionImporter importer(
237 block->getParent()->getParentOfType<mlir::ModuleOp>(), {}, builder);
238
239 return importer.ImportInstructionWithLayout(instr, operands, builder);
240 }
241
ImportInstructionImpl(const HloInstruction * instruction,const llvm::SmallVectorImpl<mlir::Value> & operands,mlir::OpBuilder * func_builder)242 StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
243 const HloInstruction* instruction,
244 const llvm::SmallVectorImpl<mlir::Value>& operands,
245 mlir::OpBuilder* func_builder) {
246 TF_ASSIGN_OR_RETURN(auto result_type, ConvertShapeToType<RankedTensorType>(
247 instruction->shape(), *builder_));
248 mlir::Location loc = GenerateInstructionLocation(instruction, func_builder);
249
250 llvm::SmallVector<NamedAttribute, 10> attributes;
251 switch (instruction->opcode()) {
252 case HloOpcode::kParameter: {
253 return nullptr;
254 }
255 case HloOpcode::kConstant: {
256 const Literal& literal = instruction->literal();
257 auto attr = CreateDenseElementsAttrFromLiteral(literal, *builder_);
258 if (!attr.ok()) return attr.status();
259 mlir::Operation* new_operation =
260 func_builder->create<mlir::mhlo::ConstOp>(loc, attr.ValueOrDie());
261 for (auto attr : attributes) {
262 new_operation->setAttr(attr.first, attr.second);
263 }
264 return new_operation;
265 }
266 case HloOpcode::kIota: {
267 return func_builder
268 ->create<mlir::mhlo::IotaOp>(
269 loc, result_type,
270 func_builder->getI64IntegerAttr(
271 Cast<HloIotaInstruction>(instruction)->iota_dimension()))
272 .getOperation();
273 }
274 #define MakeAndReturn(mlir_op) \
275 { \
276 mlir::Operation* new_operation = \
277 func_builder->create<mlir::mhlo::mlir_op>(loc, result_type, operands, \
278 attributes); \
279 return new_operation; \
280 }
281 case HloOpcode::kBroadcast: {
282 // Note that the HLO broadcast is more powerful than the XLA broadcast
283 // op. BroadcastInDim offers a superset of the HLO op's functionality.
284 attributes.push_back(
285 builder_->getNamedAttr("broadcast_dimensions",
286 ConvertDimensions(instruction->dimensions())));
287 MakeAndReturn(BroadcastInDimOp);
288 }
289 #define MakeAndReturnBatchNormOp(batch_norm_op) \
290 { \
291 attributes.push_back(builder_->getNamedAttr( \
292 "epsilon", builder_->getF32FloatAttr(instruction->epsilon()))); \
293 attributes.push_back(builder_->getNamedAttr( \
294 "feature_index", \
295 builder_->getI64IntegerAttr(instruction->feature_index()))); \
296 MakeAndReturn(batch_norm_op); \
297 }
298 case HloOpcode::kBatchNormGrad:
299 MakeAndReturnBatchNormOp(BatchNormGradOp);
300 case HloOpcode::kBatchNormInference:
301 MakeAndReturnBatchNormOp(BatchNormInferenceOp);
302 case HloOpcode::kBatchNormTraining:
303 MakeAndReturnBatchNormOp(BatchNormTrainingOp);
304 #undef MakeAndReturnBatchNormOp
305
306 case HloOpcode::kDot: {
307 attributes.push_back(builder_->getNamedAttr(
308 "precision_config",
309 ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
310
311 // Consider consolidating DotOps together.
312 if (DotIsDefault(instruction)) {
313 MakeAndReturn(DotOp);
314 }
315
316 attributes.push_back(builder_->getNamedAttr(
317 "dot_dimension_numbers",
318 ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(),
319 builder_)));
320 MakeAndReturn(DotGeneralOp);
321 }
322 case HloOpcode::kCall: {
323 TF_ASSIGN_OR_RETURN(FuncOp function,
324 ImportAsFunc(*instruction->to_apply()));
325 mlir::Operation* new_operation =
326 func_builder->create<mlir::CallOp>(loc, function, operands);
327 return new_operation;
328 }
329 case HloOpcode::kCollectivePermute: {
330 attributes.push_back(ConvertSourceTargetPairs(
331 instruction->source_target_pairs(), builder_));
332 MakeAndReturn(CollectivePermuteOp);
333 }
334 case HloOpcode::kCustomCall: {
335 auto custom_call = Cast<HloCustomCallInstruction>(instruction);
336 TF_ASSIGN_OR_RETURN(
337 auto mlir_api_version,
338 ConvertCustomCallApiVersion(custom_call->api_version()));
339 attributes.push_back(builder_->getNamedAttr(
340 "call_target_name",
341 builder_->getStringAttr(custom_call->custom_call_target())));
342 attributes.push_back(builder_->getNamedAttr(
343 "has_side_effect",
344 builder_->getBoolAttr(custom_call->custom_call_has_side_effect())));
345 attributes.push_back(builder_->getNamedAttr(
346 "backend_config",
347 builder_->getStringAttr(custom_call->raw_backend_config_string())));
348 attributes.push_back(builder_->getNamedAttr(
349 "api_version", mlir::mhlo::CustomCallApiVersionAttr::get(
350 builder_->getContext(), mlir_api_version)));
351 MakeAndReturn(CustomCallOp);
352 }
353 case HloOpcode::kCompare: {
354 auto compare = Cast<HloCompareInstruction>(instruction);
355 attributes.push_back(ConvertComparisonDirection(compare->direction()));
356 auto default_type = Comparison::DefaultComparisonType(
357 compare->operand(0)->shape().element_type());
358 if (compare->type() != default_type)
359 attributes.push_back(ConvertComparisonType(compare->type()));
360 MakeAndReturn(CompareOp);
361 }
362 case HloOpcode::kCholesky: {
363 attributes.push_back(builder_->getNamedAttr(
364 "lower",
365 builder_->getBoolAttr(instruction->cholesky_options().lower())));
366 MakeAndReturn(CholeskyOp);
367 }
368 case HloOpcode::kGather: {
369 auto gather_instruction = Cast<HloGatherInstruction>(instruction);
370 attributes.push_back(builder_->getNamedAttr(
371 "dimension_numbers",
372 ConvertGatherDimensionNumbers(
373 gather_instruction->gather_dimension_numbers(), builder_)));
374
375 std::vector<int64_t> slice_sizes(
376 gather_instruction->gather_slice_sizes().begin(),
377 gather_instruction->gather_slice_sizes().end());
378 attributes.push_back(
379 builder_->getNamedAttr("slice_sizes", Convert(slice_sizes)));
380 attributes.push_back(builder_->getNamedAttr(
381 "indices_are_sorted",
382 builder_->getBoolAttr(gather_instruction->indices_are_sorted())));
383
384 MakeAndReturn(GatherOp);
385 }
386 case HloOpcode::kDynamicSlice: {
387 std::vector<int64_t> slice_sizes(
388 instruction->dynamic_slice_sizes().begin(),
389 instruction->dynamic_slice_sizes().end());
390 return func_builder
391 ->create<mlir::mhlo::DynamicSliceOp>(
392 loc, result_type, operands[0],
393 makeArrayRef(operands).drop_front(), Convert(slice_sizes))
394 .getOperation();
395 }
396 case HloOpcode::kDynamicUpdateSlice: {
397 return func_builder
398 ->create<mlir::mhlo::DynamicUpdateSliceOp>(
399 loc, result_type, operands[0], operands[1],
400 llvm::ArrayRef<Value>(operands.begin() + 2, operands.end()))
401 .getOperation();
402 }
403 case HloOpcode::kInfeed: {
404 attributes.push_back(builder_->getNamedAttr(
405 "infeed_config",
406 mlir::StringAttr::get(builder_->getContext(),
407 instruction->infeed_config())));
408 TF_ASSIGN_OR_RETURN(mlir::Attribute layout,
409 ConvertShapeToMlirLayout(instruction->shape()));
410 attributes.push_back(builder_->getNamedAttr("layout", layout));
411 MakeAndReturn(InfeedOp);
412 }
413 case HloOpcode::kOutfeed: {
414 attributes.push_back(builder_->getNamedAttr(
415 "outfeed_config",
416 mlir::StringAttr::get(builder_->getContext(),
417 instruction->outfeed_config())));
418 MakeAndReturn(OutfeedOp);
419 }
420 case HloOpcode::kPad: {
421 const auto& padding_config = instruction->padding_config();
422 llvm::SmallVector<int64_t, 4> edge_padding_low;
423 llvm::SmallVector<int64_t, 4> edge_padding_high;
424 llvm::SmallVector<int64_t, 4> interior_padding;
425 edge_padding_low.reserve(padding_config.dimensions_size());
426 edge_padding_high.reserve(padding_config.dimensions_size());
427 interior_padding.reserve(padding_config.dimensions_size());
428
429 for (const auto& dimension : padding_config.dimensions()) {
430 edge_padding_low.push_back(dimension.edge_padding_low());
431 edge_padding_high.push_back(dimension.edge_padding_high());
432 interior_padding.push_back(dimension.interior_padding());
433 }
434
435 return func_builder
436 ->create<mlir::mhlo::PadOp>(loc, result_type, operands[0],
437 operands[1], Convert(edge_padding_low),
438 Convert(edge_padding_high),
439 Convert(interior_padding))
440 .getOperation();
441 }
442 case HloOpcode::kScatter: {
443 auto scatter = Cast<HloScatterInstruction>(instruction);
444 attributes.push_back(builder_->getNamedAttr(
445 "scatter_dimension_numbers",
446 ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers(),
447 builder_)));
448 attributes.push_back(builder_->getNamedAttr(
449 "indices_are_sorted",
450 builder_->getBoolAttr(scatter->indices_are_sorted())));
451 attributes.push_back(builder_->getNamedAttr(
452 "unique_indices", builder_->getBoolAttr(scatter->unique_indices())));
453
454 auto scatter_op = func_builder->create<mlir::mhlo::ScatterOp>(
455 loc, result_type, operands, attributes);
456 TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(),
457 &scatter_op.update_computation()));
458 return scatter_op.getOperation();
459 }
460 case HloOpcode::kSelectAndScatter: {
461 auto select_scatter = Cast<HloSelectAndScatterInstruction>(instruction);
462 llvm::SmallVector<int64_t, 4> window_strides, window_dimensions;
463 llvm::SmallVector<int64_t, 8> padding;
464 for (const auto& dim : select_scatter->window().dimensions()) {
465 window_strides.push_back(dim.stride());
466 window_dimensions.push_back(dim.size());
467 padding.push_back(dim.padding_low());
468 padding.push_back(dim.padding_high());
469 }
470 attributes.push_back(
471 builder_->getNamedAttr("window_strides", Convert(window_strides)));
472 attributes.push_back(builder_->getNamedAttr("window_dimensions",
473 Convert(window_dimensions)));
474 attributes.push_back(ConvertPadding(padding));
475 auto select_scatter_op =
476 func_builder->create<mlir::mhlo::SelectAndScatterOp>(
477 loc, result_type, operands, attributes);
478 TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(),
479 &select_scatter_op.select()));
480 TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->scatter(),
481 &select_scatter_op.scatter()));
482 return select_scatter_op.getOperation();
483 }
484 case HloOpcode::kSetDimensionSize: {
485 attributes.push_back(builder_->getNamedAttr(
486 "dimension", builder_->getI64IntegerAttr(instruction->dimension())));
487 MakeAndReturn(SetDimensionSizeOp);
488 }
489 case HloOpcode::kSlice: {
490 return func_builder
491 ->create<mlir::mhlo::SliceOp>(
492 loc, result_type, operands[0],
493 ConvertDimensions(instruction->slice_starts()),
494 ConvertDimensions(instruction->slice_limits()),
495 ConvertDimensions(instruction->slice_strides()))
496 .getOperation();
497 }
498 case HloOpcode::kSort: {
499 auto sort_instruction = Cast<HloSortInstruction>(instruction);
500
501 llvm::SmallVector<Type, 4> return_types = {result_type};
502 if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
503 return_types = llvm::to_vector<6>(tuple_ty.getTypes());
504 }
505
506 auto sort_op = func_builder->create<mlir::mhlo::SortOp>(
507 loc, return_types, operands,
508 builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
509 builder_->getBoolAttr(sort_instruction->is_stable()));
510 TF_RETURN_IF_ERROR(
511 ImportAsRegion(*sort_instruction->to_apply(), &sort_op.comparator()));
512
513 // Check if the output needs to be tupled.
514 if (return_types.size() == 1 && return_types.front() == result_type) {
515 return sort_op.getOperation();
516 }
517
518 return func_builder
519 ->create<mlir::mhlo::TupleOp>(loc, result_type, sort_op.getResults())
520 .getOperation();
521 }
522 case HloOpcode::kConditional: {
523 llvm::SmallVector<Type, 4> rets;
524 mlir::Type pred_or_index_type =
525 operands[0].getType().cast<mlir::TensorType>().getElementType();
526 // It is a predicated conditional if first argument is a boolean and
527 // should be mapped to If op.
528 if (pred_or_index_type.isInteger(1)) {
529 TF_RETURN_IF_ERROR(GetMlirTypes(
530 {instruction->true_computation()->root_instruction()}, &rets));
531
532 auto op = func_builder->create<mlir::mhlo::IfOp>(loc, rets, operands,
533 attributes);
534 TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(),
535 &op.true_branch()));
536 TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->false_computation(),
537 &op.false_branch()));
538 return op.getOperation();
539 }
540
541 // Otherwise, it is a indexed conditional and should be mapped to Case
542 // op.
543 TF_RETURN_IF_ERROR(GetMlirTypes(
544 {instruction->branch_computation(0)->root_instruction()}, &rets));
545
546 int num_branches = instruction->branch_count();
547 auto op = func_builder->create<mlir::mhlo::CaseOp>(
548 loc, rets, operands, attributes, num_branches);
549 for (auto index_and_computation :
550 llvm::enumerate(instruction->branch_computations())) {
551 auto index = index_and_computation.index();
552 HloComputation* computation = index_and_computation.value();
553 TF_RETURN_IF_ERROR(ImportAsRegion(*computation, &op.branches()[index]));
554 }
555 return op.getOperation();
556 }
557 case HloOpcode::kConcatenate: {
558 // TODO(b/132057942): Support taking an uint64_t instead of an
559 // IntegerAttr for concatenate dimension.
560 return func_builder
561 ->create<mlir::mhlo::ConcatenateOp>(
562 loc, result_type, operands,
563 builder_->getI64IntegerAttr(instruction->concatenate_dimension()))
564 .getOperation();
565 }
566 case HloOpcode::kAllGather: {
567 auto all_gather = Cast<HloAllGatherInstruction>(instruction);
568 attributes.push_back(builder_->getNamedAttr(
569 "all_gather_dim",
570 builder_->getI64IntegerAttr(all_gather->all_gather_dimension())));
571 attributes.push_back(
572 ConvertReplicaGroups(all_gather->replica_groups(), builder_));
573 attributes.push_back(ConvertChannelHandle(all_gather->channel_id()));
574 MakeAndReturn(AllGatherOp);
575 }
576 case HloOpcode::kAllReduce: {
577 auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
578 attributes.push_back(
579 ConvertReplicaGroups(all_reduce->replica_groups(), builder_));
580 attributes.push_back(ConvertChannelHandle(all_reduce->channel_id()));
581 auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
582 loc, result_type, operands, attributes);
583 TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
584 &all_reduce_op.computation()));
585 return all_reduce_op.getOperation();
586 }
587 case HloOpcode::kReduce: {
588 // Operands in the first half are reduction inputs and the remaining
589 // operands are corresponding initial values.
590 size_t num_inputs = operands.size() / 2;
591 auto reduce = func_builder->create<mlir::mhlo::ReduceOp>(
592 loc, result_type, llvm::makeArrayRef(operands).take_front(num_inputs),
593 llvm::makeArrayRef(operands).drop_front(num_inputs),
594 ConvertDimensions(instruction->dimensions()));
595 TF_RETURN_IF_ERROR(
596 ImportAsRegion(*instruction->to_apply(), &reduce.body()));
597 return reduce.getOperation();
598 }
599 case HloOpcode::kReverse: {
600 return func_builder
601 ->create<mlir::mhlo::ReverseOp>(
602 loc, result_type, operands[0],
603 ConvertDimensions(instruction->dimensions()))
604 .getOperation();
605 }
606 case HloOpcode::kRng: {
607 auto shape = func_builder->create<mlir::ConstantOp>(
608 loc, Convert(result_type.cast<RankedTensorType>().getShape()));
609 switch (instruction->random_distribution()) {
610 case xla::RNG_UNIFORM:
611 return func_builder
612 ->create<mlir::mhlo::RngUniformOp>(loc, result_type, operands[0],
613 operands[1], shape)
614 .getOperation();
615
616 case xla::RNG_NORMAL:
617 return func_builder
618 ->create<mlir::mhlo::RngNormalOp>(loc, result_type, operands[0],
619 operands[1], shape)
620 .getOperation();
621
622 default:
623 return tensorflow::errors::InvalidArgument(absl::StrCat(
624 "Unsupported distribution: ",
625 RandomDistributionToString(instruction->random_distribution())));
626 }
627 }
628 case HloOpcode::kRngBitGenerator: {
629 auto rng_op = Cast<HloRngBitGeneratorInstruction>(instruction);
630 auto op = func_builder->create<mlir::mhlo::RngBitGeneratorOp>(
631 loc, result_type,
632 func_builder->getI32IntegerAttr(rng_op->algorithm()), operands[0]);
633 return op.getOperation();
634 }
635 case HloOpcode::kWhile: {
636 auto op = func_builder->create<mlir::mhlo::WhileOp>(
637 loc, operands[0].getType(), operands[0]);
638 TF_RETURN_IF_ERROR(
639 ImportAsRegion(*instruction->while_condition(), &op.cond()));
640 TF_RETURN_IF_ERROR(
641 ImportAsRegion(*instruction->while_body(), &op.body()));
642 return op.getOperation();
643 }
644 case HloOpcode::kGetTupleElement: {
645 attributes.push_back(builder_->getNamedAttr(
646 "index", builder_->getIntegerAttr(builder_->getIntegerType(32),
647 instruction->tuple_index())));
648 MakeAndReturn(GetTupleElementOp);
649 };
650 case HloOpcode::kGetDimensionSize: {
651 attributes.push_back(builder_->getNamedAttr(
652 "dimension", builder_->getI64IntegerAttr(instruction->dimension())));
653 MakeAndReturn(GetDimensionSizeOp);
654 };
655 case HloOpcode::kTranspose: {
656 attributes.push_back(builder_->getNamedAttr(
657 "permutation", ConvertDimensions(instruction->dimensions())));
658 MakeAndReturn(TransposeOp);
659 }
660 case HloOpcode::kTriangularSolve: {
661 attributes.push_back(builder_->getNamedAttr(
662 "left_side",
663 builder_->getBoolAttr(
664 instruction->triangular_solve_options().left_side())));
665 attributes.push_back(builder_->getNamedAttr(
666 "lower", builder_->getBoolAttr(
667 instruction->triangular_solve_options().lower())));
668 attributes.push_back(builder_->getNamedAttr(
669 "unit_diagonal",
670 builder_->getBoolAttr(
671 instruction->triangular_solve_options().unit_diagonal())));
672 auto transpose_a =
673 builder_->getStringAttr(TriangularSolveOptions::Transpose_Name(
674 instruction->triangular_solve_options().transpose_a()));
675 attributes.push_back(builder_->getNamedAttr("transpose_a", transpose_a));
676 MakeAndReturn(TriangularSolveOp);
677 }
678 case HloOpcode::kReduceWindow: {
679 llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
680 llvm::SmallVector<int64_t, 8> padding;
681 for (const auto& dim : instruction->window().dimensions()) {
682 sizes.push_back(dim.size());
683 strides.push_back(dim.stride());
684 base_dilations.push_back(dim.base_dilation());
685 win_dilations.push_back(dim.window_dilation());
686 padding.push_back(dim.padding_low());
687 padding.push_back(dim.padding_high());
688 }
689 attributes.push_back(builder_->getNamedAttr("window_dimensions",
690 ConvertDimensions(sizes)));
691 attributes.push_back(
692 builder_->getNamedAttr("window_strides", ConvertDimensions(strides)));
693 attributes.push_back(builder_->getNamedAttr(
694 "base_dilations", ConvertDimensions(base_dilations)));
695 attributes.push_back(builder_->getNamedAttr(
696 "window_dilations", ConvertDimensions(win_dilations)));
697 attributes.push_back(ConvertPadding(padding));
698 auto reduce = func_builder->create<mlir::mhlo::ReduceWindowOp>(
699 loc, result_type, operands, attributes);
700 TF_RETURN_IF_ERROR(
701 ImportAsRegion(*instruction->to_apply(), &reduce.body()));
702 return reduce.getOperation();
703 }
704 case HloOpcode::kMap: {
705 auto op = func_builder->create<mlir::mhlo::MapOp>(
706 loc, result_type, operands,
707 ConvertDimensions(instruction->dimensions()));
708 TF_RETURN_IF_ERROR(
709 ImportAsRegion(*instruction->to_apply(), &op.computation()));
710 return op.getOperation();
711 }
712 case HloOpcode::kConvolution: {
713 llvm::SmallVector<int64_t, 4> strides, lhs_dilations, rhs_dilations;
714 llvm::SmallVector<int64_t, 8> paddings;
715 for (const auto& dim : instruction->window().dimensions()) {
716 strides.push_back(dim.stride());
717 lhs_dilations.push_back(dim.base_dilation());
718 rhs_dilations.push_back(dim.window_dilation());
719 paddings.push_back(dim.padding_low());
720 paddings.push_back(dim.padding_high());
721 }
722
723 attributes.push_back(
724 builder_->getNamedAttr("window_strides", Convert(strides)));
725 attributes.push_back(ConvertPadding(paddings));
726 attributes.push_back(
727 builder_->getNamedAttr("lhs_dilation", Convert(lhs_dilations)));
728 attributes.push_back(
729 builder_->getNamedAttr("rhs_dilation", Convert(rhs_dilations)));
730 attributes.push_back(builder_->getNamedAttr(
731 "dimension_numbers",
732 ConvertConvDimensionNumbers(
733 instruction->convolution_dimension_numbers(), builder_)));
734 attributes.push_back(builder_->getNamedAttr(
735 "feature_group_count",
736 builder_->getI64IntegerAttr(instruction->feature_group_count())));
737 attributes.push_back(builder_->getNamedAttr(
738 "batch_group_count",
739 builder_->getI64IntegerAttr(instruction->batch_group_count())));
740 attributes.push_back(builder_->getNamedAttr(
741 "precision_config",
742 ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
743
744 MakeAndReturn(ConvOp);
745 }
746
747 case HloOpcode::kFft: {
748 auto fft_type =
749 builder_->getStringAttr(FftType_Name(instruction->fft_type()));
750
751 std::vector<int64_t> fft_length(instruction->fft_length().begin(),
752 instruction->fft_length().end());
753
754 attributes.push_back(builder_->getNamedAttr("fft_type", fft_type));
755 attributes.push_back(
756 builder_->getNamedAttr("fft_length", Convert(fft_length)));
757 MakeAndReturn(FftOp);
758 }
759
760 case HloOpcode::kAdd: {
761 // HLO add ops on PRED elements are actually boolean or, but MHLO dialect
762 // AddOps on i1 are just addition with overflow; so, we have to implement
763 // the special behavior of HLO add ops on PRED here by creating an OrOp
764 // instead.
765 if (instruction->shape().element_type() == PRED) {
766 MakeAndReturn(OrOp);
767 } else {
768 MakeAndReturn(AddOp);
769 }
770 }
771 #define NoAttributeCase(hlo_op_code, mlir_op) \
772 case HloOpcode::hlo_op_code: { \
773 MakeAndReturn(mlir_op); \
774 }
775
776 // broadcast dimensions are never added here because they don't exist as
777 // part of the HLO instruction. They are only a convenience in the XLA
778 // builder API.
779 NoAttributeCase(kAbs, AbsOp);
780 NoAttributeCase(kAfterAll, AfterAllOp);
781 NoAttributeCase(kAnd, AndOp);
782 NoAttributeCase(kAtan2, Atan2Op);
783 NoAttributeCase(kBitcastConvert, BitcastConvertOp);
784 NoAttributeCase(kCbrt, CbrtOp);
785 NoAttributeCase(kClz, ClzOp);
786 NoAttributeCase(kConvert, ConvertOp);
787 NoAttributeCase(kCeil, CeilOp);
788 NoAttributeCase(kClamp, ClampOp);
789 NoAttributeCase(kComplex, ComplexOp);
790 NoAttributeCase(kCos, CosOp);
791 NoAttributeCase(kDivide, DivOp);
792 NoAttributeCase(kExp, ExpOp);
793 NoAttributeCase(kExpm1, Expm1Op);
794 NoAttributeCase(kFloor, FloorOp);
795 NoAttributeCase(kIsFinite, IsFiniteOp);
796 NoAttributeCase(kImag, ImagOp);
797 NoAttributeCase(kLog, LogOp);
798 NoAttributeCase(kLog1p, Log1pOp);
799 NoAttributeCase(kMaximum, MaxOp);
800 NoAttributeCase(kMinimum, MinOp);
801 NoAttributeCase(kMultiply, MulOp);
802 NoAttributeCase(kNegate, NegOp);
803 NoAttributeCase(kNot, NotOp);
804 NoAttributeCase(kOr, OrOp);
805 NoAttributeCase(kPopulationCount, PopulationCountOp);
806 NoAttributeCase(kPower, PowOp);
807 NoAttributeCase(kReal, RealOp);
808 NoAttributeCase(kRemainder, RemOp);
809 NoAttributeCase(kReplicaId, ReplicaIdOp);
810 NoAttributeCase(kLogistic, LogisticOp);
811 // The dimensions attribute is not present on the HLO Reshape
812 // instruction. If dimensions are non-default, the XLA builder
813 // implements it as a separate transpose.
814 NoAttributeCase(kReshape, ReshapeOp);
815 NoAttributeCase(kRoundNearestAfz, RoundOp);
816 NoAttributeCase(kRsqrt, RsqrtOp);
817 NoAttributeCase(kSelect, SelectOp);
818 NoAttributeCase(kShiftLeft, ShiftLeftOp);
819 NoAttributeCase(kShiftRightArithmetic, ShiftRightArithmeticOp);
820 NoAttributeCase(kShiftRightLogical, ShiftRightLogicalOp);
821 NoAttributeCase(kSign, SignOp);
822 NoAttributeCase(kSin, SinOp);
823 NoAttributeCase(kSqrt, SqrtOp);
824 NoAttributeCase(kSubtract, SubOp);
825 NoAttributeCase(kTanh, TanhOp);
826 NoAttributeCase(kTuple, TupleOp);
827 NoAttributeCase(kXor, XorOp);
828 // TODO(b/129422361) Copy needs special handling because it is not
829 // defined in tensorflow/compiler/xla/client/xla_builder.h. See
830 // operation semantics in
831 // g3doc/platforms/xla/g3doc/internal/hlo_semantics#copy
832 NoAttributeCase(kCopy, CopyOp);
833 #undef NoAttributeCase
834 #undef MakeAndReturn
835 case HloOpcode::kFusion: {
836 auto fusion = func_builder->create<mlir::mhlo::FusionOp>(
837 loc, result_type, operands,
838 builder_->getStringAttr(xla::ToString(instruction->fusion_kind())));
839 TF_RETURN_IF_ERROR(
840 ImportAsRegion(*instruction->fused_instructions_computation(),
841 &fusion.fused_computation()));
842 return fusion.getOperation();
843 }
844 case HloOpcode::kBitcast: {
845 auto bitcast = func_builder->create<mlir::mhlo::BitcastOp>(
846 loc, result_type, operands, attributes);
847 // Store the source and result layout as attributes. Although the MHLO
848 // Bitcast operates on tensors, these layouts are relevant as they define
849 // the mapping between the elements of the source and result.
850 SetLayoutForMlir(bitcast, instruction->shape(), "result_layout");
851 SetLayoutForMlir(bitcast, instruction->operand(0)->shape(),
852 "source_layout");
853 return bitcast.getOperation();
854 }
855 case HloOpcode::kReducePrecision: {
856 auto op = func_builder->create<mlir::mhlo::ReducePrecisionOp>(
857 loc, result_type, operands[0], attributes);
858 op.exponent_bitsAttr(func_builder->getIntegerAttr(
859 func_builder->getI32Type(), instruction->exponent_bits()));
860 op.mantissa_bitsAttr(func_builder->getIntegerAttr(
861 func_builder->getI32Type(), instruction->mantissa_bits()));
862 return op.getOperation();
863 }
864 case HloOpcode::kAddDependency:
865 // Arbitrary op code that I suspect we will not implement for quite a
866 // while and allows testing handling of unknown ops. Selected because it
867 // is not mentioned in xla client anywhere or in the hlo of our sample
868 // models.
869 default: {
870 mlir::OperationState result(loc, "mhlo.unknown");
871 result.addOperands(operands);
872 result.addTypes(result_type);
873 for (auto attr : attributes) {
874 result.attributes.push_back(attr);
875 }
876
877 return func_builder->createOperation(result);
878 }
879 }
880 }
881
ImportInstructionWithLayout(const HloInstruction * instruction,const llvm::SmallVectorImpl<mlir::Value> & operands,mlir::OpBuilder * func_builder)882 StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionWithLayout(
883 const HloInstruction* instruction,
884 const llvm::SmallVectorImpl<mlir::Value>& operands,
885 mlir::OpBuilder* func_builder) {
886 TF_ASSIGN_OR_RETURN(
887 mlir::Operation * op,
888 ImportInstructionImpl(instruction, operands, func_builder));
889 if (op == nullptr) return op;
890
891 // See MlirToHloConversionOptions for more about layouts.
892 //
893 // Minor-to-major is a permutation of [0, rank), presenting tensor dimensions
894 // in physical minor-to-major order.
895 if (instruction->shape().IsArray() &&
896 !instruction->shape().layout().minor_to_major().empty() &&
897 instruction->shape().layout() !=
898 LayoutUtil::MakeDescendingLayout(
899 instruction->shape().dimensions().size())) {
900 SetLayoutForMlir(op, instruction->shape());
901 }
902 return op;
903 }
904
GetOperands(const HloInstruction * instruction)905 StatusOr<llvm::SmallVector<mlir::Value, 4>> HloFunctionImporter::GetOperands(
906 const HloInstruction* instruction) {
907 llvm::SmallVector<mlir::Value, 4> operands;
908 for (const auto& operand : instruction->operands()) {
909 auto input_it = instruction_value_map_.find(operand);
910 if (input_it == instruction_value_map_.end()) {
911 return tensorflow::errors::Internal(
912 absl::StrCat("Could not find input value: ", operand->name(),
913 " for instruction ", instruction->name()));
914 }
915 operands.push_back(input_it->second);
916 }
917 return operands;
918 }
919
GetMlirTypes(const std::vector<HloInstruction * > & instructions,llvm::SmallVectorImpl<mlir::Type> * types)920 tensorflow::Status HloFunctionImporter::GetMlirTypes(
921 const std::vector<HloInstruction*>& instructions,
922 llvm::SmallVectorImpl<mlir::Type>* types) {
923 for (auto instruction : instructions) {
924 TF_ASSIGN_OR_RETURN(auto ret_type, ConvertShapeToType<RankedTensorType>(
925 instruction->shape(), *builder_));
926 types->push_back(ret_type);
927 }
928 return tensorflow::Status::OK();
929 }
930
GetMlirValue(const HloInstruction * instruction)931 StatusOr<Value> HloFunctionImporter::GetMlirValue(
932 const HloInstruction* instruction) {
933 auto lookup = instruction_value_map_.find(instruction);
934 if (lookup != instruction_value_map_.end()) {
935 return lookup->second;
936 }
937
938 return tensorflow::errors::Internal(absl::StrCat(
939 "Unable to find value for input: ", instruction->ToString()));
940 }
941
ConvertComparisonDirection(ComparisonDirection direction)942 mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection(
943 ComparisonDirection direction) {
944 return builder_->getNamedAttr(
945 "comparison_direction",
946 builder_->getStringAttr(ComparisonDirectionToString(direction)));
947 }
948
ConvertComparisonType(Comparison::Type type)949 mlir::NamedAttribute HloFunctionImporter::ConvertComparisonType(
950 Comparison::Type type) {
951 return builder_->getNamedAttr(
952 "compare_type", builder_->getStringAttr(ComparisonTypeToString(type)));
953 }
954
ConvertDimensions(llvm::ArrayRef<int64> op_dimensions)955 mlir::DenseIntElementsAttr HloFunctionImporter::ConvertDimensions(
956 llvm::ArrayRef<int64> op_dimensions) {
957 llvm::SmallVector<APInt, 8> dimensions;
958 dimensions.reserve(op_dimensions.size());
959 for (auto value : op_dimensions) dimensions.emplace_back(APInt(64, value));
960
961 return DenseIntElementsAttr::get(
962 RankedTensorType::get(dimensions.size(), builder_->getIntegerType(64)),
963 dimensions);
964 }
965
Convert(llvm::ArrayRef<int64_t> elements)966 mlir::DenseIntElementsAttr HloFunctionImporter::Convert(
967 llvm::ArrayRef<int64_t> elements) {
968 return DenseIntElementsAttr::get(
969 RankedTensorType::get(elements.size(), builder_->getIntegerType(64)),
970 elements);
971 }
972
ConvertPadding(llvm::ArrayRef<int64_t> padding)973 mlir::NamedAttribute HloFunctionImporter::ConvertPadding(
974 llvm::ArrayRef<int64_t> padding) {
975 auto ty =
976 mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
977 builder_->getIntegerType(64));
978 auto attr = DenseIntElementsAttr::get(ty, padding);
979 return builder_->getNamedAttr("padding", attr);
980 }
981
ConvertSourceTargetPairs(const std::vector<std::pair<tensorflow::int64,tensorflow::int64>> & source_target_pairs,mlir::Builder * builder)982 mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
983 const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
984 source_target_pairs,
985 mlir::Builder* builder) {
986 std::vector<int64_t> attr(source_target_pairs.size() * 2);
987 for (auto p : llvm::enumerate(source_target_pairs)) {
988 attr[2 * p.index()] = p.value().first;
989 attr[2 * p.index() + 1] = p.value().second;
990 }
991 auto type = mlir::RankedTensorType::get(
992 {static_cast<int64_t>(attr.size() / 2), 2}, builder->getIntegerType(64));
993 return builder->getNamedAttr("source_target_pairs",
994 DenseIntElementsAttr::get(type, attr));
995 }
996
ConvertReplicaGroups(absl::Span<const ReplicaGroup> replica_groups,mlir::Builder * builder)997 mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
998 absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder) {
999 const int64_t num_groups = replica_groups.size();
1000 // Replica groups in HLO can be non-uniform in size, for example:
1001 // replica_groups={{0},{1,2},{3}}. Since we are representing them as a 2D
1002 // tensor, pad the smaller sized replica groups with -1.
1003 const int64_t group_size = absl::c_accumulate(
1004 replica_groups, int64_t(0), [](int64_t current, const ReplicaGroup& g) {
1005 return std::max<int64_t>(current, g.replica_ids_size());
1006 });
1007 // Initialize all elements to -1 to support non-uniform replica groups.
1008 std::vector<int64_t> attr(num_groups * group_size, -1);
1009 for (int i = 0; i < num_groups; ++i) {
1010 int index = i * group_size;
1011 for (const int64& id : replica_groups[i].replica_ids()) attr[index++] = id;
1012 }
1013 auto type = mlir::RankedTensorType::get({num_groups, group_size},
1014 builder->getIntegerType(64));
1015 return builder->getNamedAttr("replica_groups",
1016 DenseIntElementsAttr::get(type, attr));
1017 }
1018
ConvertChannelHandle(absl::optional<tensorflow::int64> channel_id)1019 mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
1020 absl::optional<tensorflow::int64> channel_id) {
1021 xla::ChannelHandle channel_handle;
1022 if (channel_id) channel_handle.set_handle(*channel_id);
1023 return ConvertChannelHandle(channel_handle);
1024 }
1025
ConvertChannelHandle(const xla::ChannelHandle & channel)1026 mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
1027 const xla::ChannelHandle& channel) {
1028 return builder_->getNamedAttr(
1029 "channel_handle",
1030 mlir::mhlo::ChannelHandle::get(
1031 builder_->getI64IntegerAttr(channel.handle()),
1032 builder_->getI64IntegerAttr(channel.type()), context_));
1033 }
1034
SetLayoutForMlir(mlir::Operation * op,const Shape & shape,llvm::StringRef attr_name)1035 void HloFunctionImporter::SetLayoutForMlir(mlir::Operation* op,
1036 const Shape& shape,
1037 llvm::StringRef attr_name) {
1038 llvm::SmallVector<int64_t, 4> minor_to_major(
1039 shape.layout().minor_to_major().begin(),
1040 shape.layout().minor_to_major().end());
1041 op->setAttr(
1042 attr_name,
1043 mlir::Builder(op->getContext()).getIndexTensorAttr(minor_to_major));
1044 }
1045
ConvertShapeToMlirLayout(const xla::Shape & shape)1046 StatusOr<mlir::Attribute> HloFunctionImporter::ConvertShapeToMlirLayout(
1047 const xla::Shape& shape) {
1048 if (shape.IsToken()) return builder_->getUnitAttr();
1049 if (shape.IsTuple()) {
1050 std::vector<mlir::Attribute> tuple_layouts;
1051 for (int64_t i = 0; i < shape.tuple_shapes_size(); i++) {
1052 TF_ASSIGN_OR_RETURN(mlir::Attribute layout,
1053 ConvertShapeToMlirLayout(shape.tuple_shapes(i)));
1054 tuple_layouts.push_back(layout);
1055 }
1056 llvm::ArrayRef<mlir::Attribute> array_ref(tuple_layouts);
1057 return builder_->getArrayAttr(array_ref);
1058 }
1059 if (shape.IsArray()) {
1060 const xla::Layout l = shape.layout();
1061 std::vector<mlir::Attribute> minor_to_major;
1062 for (int64_t i : l.minor_to_major()) {
1063 minor_to_major.push_back(builder_->getI64IntegerAttr(i));
1064 }
1065 llvm::ArrayRef<mlir::Attribute> array_ref(minor_to_major);
1066 return builder_->getArrayAttr(array_ref);
1067 }
1068 return tensorflow::errors::Internal("Couldn't convert layout.");
1069 }
1070
1071 } // namespace xla
1072