1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
17
18 #include <climits>
19 #include <memory>
20 #include <tuple>
21
22 #include "absl/algorithm/container.h"
23 #include "absl/types/optional.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
27 #include "mlir/IR/AffineExpr.h" // from @llvm-project
28 #include "mlir/IR/AffineMap.h" // from @llvm-project
29 #include "mlir/IR/Attributes.h" // from @llvm-project
30 #include "mlir/IR/Builders.h" // from @llvm-project
31 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
32 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
33 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
34 #include "mlir/IR/Dialect.h" // from @llvm-project
35 #include "mlir/IR/Location.h" // from @llvm-project
36 #include "mlir/IR/MLIRContext.h" // from @llvm-project
37 #include "mlir/IR/OpDefinition.h" // from @llvm-project
38 #include "mlir/IR/Operation.h" // from @llvm-project
39 #include "mlir/IR/PatternMatch.h" // from @llvm-project
40 #include "mlir/IR/SymbolTable.h" // from @llvm-project
41 #include "mlir/IR/Verifier.h" // from @llvm-project
42 #include "mlir/Pass/Pass.h" // from @llvm-project
43 #include "mlir/Pass/PassOptions.h" // from @llvm-project
44 #include "mlir/Translation.h" // from @llvm-project
45 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
46 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
47 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
48 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
49 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
50 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
51 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
52 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
53 #include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h"
54 #include "tensorflow/compiler/xla/debug_options_flags.h"
55 #include "tensorflow/compiler/xla/service/backend.h"
56 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
57 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
58 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
59 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
60 #include "tensorflow/compiler/xla/service/hlo_computation.h"
61 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
62 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
63 #include "tensorflow/compiler/xla/service/hlo_module.h"
64 #include "tensorflow/compiler/xla/service/hlo_parser.h"
65 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
66 #include "tensorflow/compiler/xla/shape_util.h"
67 #include "tensorflow/compiler/xla/statusor.h"
68 #include "tensorflow/compiler/xla/util.h"
69 #include "tensorflow/compiler/xla/window_util.h"
70 #include "tensorflow/compiler/xla/xla_data.pb.h"
71
72 using xla::BufferAllocation;
73 using xla::BufferAssignment;
74 using xla::HloComputation;
75 using xla::HloCustomCallInstruction;
76 using xla::HloInfeedInstruction;
77 using xla::HloInstruction;
78 using xla::HloModule;
79 using xla::HloModuleProto;
80 using xla::HloOutfeedInstruction;
81 using xla::HloProto;
82 using xla::Shape;
83 using xla::StatusOr;
84
85 namespace mlir {
86 namespace {
87
88 absl::string_view StringRefToView(llvm::StringRef ref) {
89 return {ref.data(), ref.size()};
90 }
91
92 StatusOr<std::unique_ptr<HloModule>> HloModuleFromProto(
93 const HloProto& hlo_proto) {
94 const HloModuleProto& module_proto = hlo_proto.hlo_module();
95 TF_ASSIGN_OR_RETURN(const xla::HloModuleConfig module_config,
96 HloModule::CreateModuleConfigFromProto(
97 module_proto, xla::GetDebugOptionsFromFlags()));
98 return HloModule::CreateFromProto(module_proto, module_config);
99 }
100
101 bool AllocationShouldLowerToTypedArg(const BufferAllocation* alloc) {
102 return alloc->is_entry_computation_parameter() && !alloc->maybe_live_out();
103 }
104
105 } // namespace
106
107 // Convert the MLIR `module` from HLO dialect to LHLO dialect using XLA for the
108 // given platform.
109 Status OptimizeAndConvertHloToLmhlo(std::unique_ptr<HloModule> hlo_module,
110 ModuleOp module, StringRef platform_name) {
111 auto platform = xla::se::MultiPlatformManager::PlatformWithName(
112 StringRefToView(platform_name));
113 if (!platform.ok()) {
114 std::string error_msg;
115 llvm::raw_string_ostream os(error_msg);
116 os << "failed to get platform: " << platform.status().ToString()
117 << " (available Platform: ";
118 std::vector<std::string> available_platforms;
119 (void)xla::se::MultiPlatformManager::PlatformsWithFilter(
120 [&](const stream_executor::Platform* p) {
121 available_platforms.push_back(p->Name());
122 return false;
123 });
124 llvm::interleaveComma(available_platforms, os);
125 os << ")";
126 return xla::InvalidArgument("%s", os.str().c_str());
127 }
128
129 xla::BackendOptions backend_options;
130 backend_options.set_platform(platform.ValueOrDie());
131 auto backend_or_err = xla::Backend::CreateBackend(backend_options);
132 TF_RETURN_WITH_CONTEXT_IF_ERROR(backend_or_err.status(),
133 "failed to create XLA Backend ");
134 auto backend = std::move(backend_or_err.ValueOrDie());
135
136 // Run all HLO passes to produce an optimized module.
137 auto result_or = backend->compiler()->RunHloPassesAndBufferAssignement(
138 std::move(hlo_module), backend->default_stream_executor(),
139 optimize_xla_hlo, {backend->memory_allocator()});
140 TF_RETURN_WITH_CONTEXT_IF_ERROR(result_or.status(),
141 "running XLA pass pipeline");
142 std::unique_ptr<HloModule> optimized_hlo_module =
143 std::move(std::get<0>(result_or.ValueOrDie()));
144 std::unique_ptr<BufferAssignment> assignment =
145 std::move(std::get<1>(result_or.ValueOrDie()));
146
147 // Clear the module before populating it back with the result of the
148 // conversion.
149 module.getBody()->clear();
150 OpBuilder builder(module);
151
152 TF_RETURN_WITH_CONTEXT_IF_ERROR(
153 HloToLhloModule(*assignment, *optimized_hlo_module, module),
154 "converting HLO to LHLO");
155
156 return Status::OK();
157 }
158
159 namespace {
160 // This pass takes an MLIR HLO module, converts it to XLA to perform the HLO
161 // optimization pipeline for the required platform, and then converts it back to
162 // MLIR LHLO.
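// Illustrative usage only (the tool name and exact flag syntax below are
// assumptions, not defined in this file): when registered with an MLIR
// opt-style driver, the pass could be invoked as
//
//   some-opt --xla-hlo-to-lhlo-with-xla="platform=CUDA" input.mlir
//
// where the pass argument and the "platform" option correspond to
// getArgument() and the Option<> member declared below.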
163 class XlaHloToLhloPass
164 : public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> {
165   void getDependentDialects(DialectRegistry& registry) const override {
166 registry
167 .insert<StandardOpsDialect, memref::MemRefDialect, mhlo::MhloDialect,
168 lmhlo::LmhloDialect, lmhlo_gpu::LmhloGpuDialect>();
169 }
170
171 public:
172 XlaHloToLhloPass() = default;
173   XlaHloToLhloPass(const XlaHloToLhloPass&) {}
174   StringRef getArgument() const final { return "xla-hlo-to-lhlo-with-xla"; }
175   StringRef getDescription() const final {
176 return "Emit LHLO from HLO using the existing XLA implementation";
177 }
178
179 private:
180   void runOnOperation() final {
181 ModuleOp module = getOperation();
182
183 auto status = [&module, this]() -> Status {
184 SymbolTable symbol_table(module);
185 if (!symbol_table.lookup("main")) {
186 return xla::InvalidArgument(
187 "conversion to HLO module failed: missing main()");
188 }
189 HloProto hlo_proto;
190 TF_RETURN_WITH_CONTEXT_IF_ERROR(
191 ConvertMlirHloToHlo(module, &hlo_proto,
192 /*use_tuple_args=*/false,
193 /*return_tuple=*/false,
194 /*shape_representation_fn=*/nullptr),
195 "conversion to XLA HLO proto failed");
196
197 auto statusOrHloModule = HloModuleFromProto(hlo_proto);
198 TF_RETURN_WITH_CONTEXT_IF_ERROR(statusOrHloModule.status(),
199 "parsing HLO proto to HLO module failed");
200 std::unique_ptr<HloModule> hlo_module =
201 std::move(statusOrHloModule.ValueOrDie());
202
203 return OptimizeAndConvertHloToLmhlo(std::move(hlo_module), module,
204 platform_);
205 }();
206 if (!status.ok()) {
207 module.emitError() << status.ToString();
208 return signalPassFailure();
209 }
210 }
211
212 Option<std::string> platform_{
213 *this, "platform",
214 llvm::cl::desc("The platform to use for the XLA optimization pipeline."),
215 llvm::cl::init("Host")};
216 };
217
218 } // namespace
219
220 // Creates MLIR operands corresponding to operands and results of the XLA HLO
221 // instruction. If `num_operands` is valid, then only the first `num_operands`
222 // operands of the HLO instruction will be considered.
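// Worked example (illustrative): for an HLO instruction with two array
// operands and one array result, `operands` ends up holding the two operand
// buffers followed by the result buffer, so num_arguments == 2 and
// num_results == 1 on return.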
223 Status LhloDialectEmitter::CreateOperands(
224 const HloInstruction* instr, absl::optional<xla::int64> num_operands,
225 TokenLoweringMode token_mode, llvm::SmallVectorImpl<Value>& operands,
226 size_t& num_arguments, size_t& num_results) {
227 if (num_operands.value_or(0) > instr->operand_count())
228 return xla::InvalidArgument("num_operands must be <= operand count");
229 for (xla::int64 i = 0; i < num_operands.value_or(instr->operand_count());
230 ++i) {
231 TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands,
232 /*result_subset=*/{}, token_mode));
233 }
234 num_arguments = operands.size();
235 TF_RETURN_IF_ERROR(
236 GetOrCreateView(instr, &operands, /*result_subset=*/{}, token_mode));
237 num_results = operands.size() - num_arguments;
238 return Status::OK();
239 }
240
241 template <typename OpType>
242 OpType LhloDialectEmitter::CreateOpWithoutAttrs(const HloInstruction* instr,
243 ValueRange operands) {
244 Location loc = getLocation(instr);
245 return builder_.create<OpType>(loc, llvm::None, operands,
246 llvm::ArrayRef<NamedAttribute>{});
247 }
248
249 template <typename OpType>
250 StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
251 const HloInstruction* instr, size_t& num_arguments, size_t& num_results,
252 absl::optional<xla::int64> num_operands) {
253 llvm::SmallVector<Value, 4> operands;
254 TF_RETURN_IF_ERROR(CreateOperands(instr, num_operands,
255 TokenLoweringMode::kFailToLower, operands,
256 num_arguments, num_results));
257 return CreateOpWithoutAttrs<OpType>(instr, operands);
258 }
259
260 StatusOr<mlir::Operation*> LhloDialectEmitter::CreateOpInFusion(
261 const HloInstruction* instr, ValueRange buffer_operands,
262 size_t num_arguments, size_t num_results) {
263 Location loc = getLocation(instr);
264 std::vector<Value> buffers(buffer_operands.begin(), buffer_operands.end());
265 absl::Span<Value> arguments =
266 absl::MakeSpan(buffers).subspan(0, num_arguments);
267 absl::Span<Value> results =
268 absl::MakeSpan(buffers).subspan(num_arguments, num_results);
269
270 mlir::lmhlo::FusionOp fusion = builder_.create<mlir::lmhlo::FusionOp>(loc);
271 mlir::OpBuilder b(&fusion.region());
272
273 llvm::SmallVector<mlir::Value, 4> loads;
274 for (Value arg : arguments) {
275 auto load = b.create<mlir::memref::TensorLoadOp>(loc, arg);
276 Shape shape = xla::TypeToShape(arg.getType());
277 TF_RET_CHECK(shape.IsArray());
278 if (shape.layout() !=
279 xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) {
280 load->setAttr("minor_to_major", GetLayoutAttribute(shape.layout(), &b));
281 }
282 loads.push_back(load);
283 }
284 mlir::Operation* op = nullptr;
285 if (instr->opcode() == xla::HloOpcode::kReduce) {
286 TF_RET_CHECK(loads.size() % 2 == 0);
287 std::vector<int64_t> dimensions(instr->dimensions().begin(),
288 instr->dimensions().end());
289 auto reduce_op = b.create<mhlo::ReduceOp>(
290 loc, llvm::makeArrayRef(loads).take_front(loads.size() / 2),
291 llvm::makeArrayRef(loads).drop_front(loads.size() / 2),
292 GetI64DenseElementsAttr(dimensions));
293
294 TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
295 *instr->called_computations()[0], &reduce_op.body(), &builder_));
296 op = reduce_op;
297 } else {
298 TF_ASSIGN_OR_RETURN(
299 op, xla::HloFunctionImporter::ImportInstruction(instr, loads, &b));
300 }
301 TF_RET_CHECK(op->getNumResults() == num_results);
302 for (int i = 0; i < results.size(); i++) {
303 b.create<mlir::memref::TensorStoreOp>(loc, op->getResult(i), results[i]);
304 }
305 return op;
306 }
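// Schematic of the IR produced by the helper above for a simple element-wise
// instruction (illustrative only; value names and the exact printed form are
// invented):
//
//   "lmhlo.fusion"() ({
//     %0 = memref.tensor_load %lhs_buffer
//     %1 = memref.tensor_load %rhs_buffer
//     %2 = mhlo.add %0, %1
//     memref.tensor_store %2, %out_buffer
//   })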
307
308 StatusOr<mlir::Operation*> LhloDialectEmitter::CreateOpInFusion(
309 const HloInstruction* instr) {
310 llvm::SmallVector<Value, 4> operands;
311 size_t num_arguments, num_results;
312 TF_RETURN_IF_ERROR(CreateOperands(instr, absl::nullopt,
313 TokenLoweringMode::kFailToLower, operands,
314 num_arguments, num_results));
315 TF_ASSIGN_OR_RETURN(
316 auto op, CreateOpInFusion(instr, operands, num_arguments, num_results));
317 return op->getParentOp();
318 }
319
320 StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
321 const HloInstruction* instr) {
322 using xla::HloOpcode;
323 switch (instr->opcode()) {
324 case HloOpcode::kAddDependency:
325 return nullptr;
326 case HloOpcode::kAfterAll:
327 // LMHLO is already ordered. This assumption may be broken after
328 // introducing async regions and partial orders.
329 return nullptr;
330 case HloOpcode::kAllToAll:
331 return EmitAllToAllOp(instr);
332 case HloOpcode::kAllGather:
333 return EmitAllGatherOp(instr);
334 case HloOpcode::kAllReduce:
335 return EmitAllReduceOp(instr);
336 case HloOpcode::kAllReduceStart:
337 return EmitAllReduceStartOp(instr);
338 case HloOpcode::kAllReduceDone:
339 return EmitAllReduceDoneOp(instr);
340 case HloOpcode::kReduceScatter:
341 return EmitReduceScatterOp(instr);
342 case HloOpcode::kBitcast:
343 return EmitBitcast(instr);
344 case HloOpcode::kCollectivePermute:
345 return EmitCollectivePermuteOp(instr);
346 case HloOpcode::kConditional:
347 return EmitCaseOp(instr);
348 case HloOpcode::kFft:
349 return EmitFftOp(instr);
350 case HloOpcode::kGetTupleElement:
351 return nullptr;
352 case HloOpcode::kInfeed:
353 return EmitInfeedOp(instr);
354 case HloOpcode::kOutfeed:
355 return EmitOutfeedOp(instr);
356 case HloOpcode::kPartitionId:
357 return CreateOpWithoutAttrs<lmhlo::PartitionIdOp>(instr);
358 case HloOpcode::kReplicaId:
359 return CreateOpWithoutAttrs<lmhlo::ReplicaIdOp>(instr);
360 case HloOpcode::kTriangularSolve:
361 return EmitTriangularSolveOp(instr);
362 case HloOpcode::kTuple:
363 return nullptr;
364 case HloOpcode::kSort:
365 return EmitSortOp(instr);
366 case HloOpcode::kFusion:
367 return EmitFusionOp(instr);
368 case HloOpcode::kScatter:
369 return EmitScatterOp(instr);
370 case HloOpcode::kSelectAndScatter:
371 return EmitSelectAndScatterOp(instr);
372 case HloOpcode::kCustomCall:
373 return EmitCustomCallOp(instr);
374 case HloOpcode::kConstant:
375 return EmitConstant(instr);
376 case HloOpcode::kRngGetAndUpdateState:
377 return EmitRngGetAndUpdateStateOp(instr);
378 case HloOpcode::kWhile:
379 return EmitWhileOp(instr);
380
381 case HloOpcode::kAbs:
382 case HloOpcode::kAdd:
383 case HloOpcode::kAnd:
384 case HloOpcode::kAtan2:
385 case HloOpcode::kBitcastConvert:
386 case HloOpcode::kBroadcast:
387 case HloOpcode::kCeil:
388 case HloOpcode::kCbrt:
389 case HloOpcode::kClamp:
390 case HloOpcode::kClz:
391 case HloOpcode::kCompare:
392 case HloOpcode::kComplex:
393 case HloOpcode::kConcatenate:
394 case HloOpcode::kConvert:
395 case HloOpcode::kCos:
396 case HloOpcode::kDivide:
397 case HloOpcode::kDot:
398 case HloOpcode::kDynamicSlice:
399 case HloOpcode::kDynamicUpdateSlice:
400 case HloOpcode::kExp:
401 case HloOpcode::kExpm1:
402 case HloOpcode::kFloor:
403 case HloOpcode::kGather:
404 case HloOpcode::kImag:
405 case HloOpcode::kIota:
406 case HloOpcode::kIsFinite:
407 case HloOpcode::kLog:
408 case HloOpcode::kLog1p:
409 case HloOpcode::kMap:
410 case HloOpcode::kMaximum:
411 case HloOpcode::kMinimum:
412 case HloOpcode::kMultiply:
413 case HloOpcode::kNegate:
414 case HloOpcode::kNot:
415 case HloOpcode::kOr:
416 case HloOpcode::kPad:
417 case HloOpcode::kPopulationCount:
418 case HloOpcode::kPower:
419 case HloOpcode::kReal:
420 case HloOpcode::kReshape:
421 case HloOpcode::kReducePrecision:
422 case HloOpcode::kReduceWindow:
423 case HloOpcode::kRemainder:
424 case HloOpcode::kReverse:
425 case HloOpcode::kRoundNearestAfz:
426 case HloOpcode::kRsqrt:
427 case HloOpcode::kSelect:
428 case HloOpcode::kShiftLeft:
429 case HloOpcode::kShiftRightLogical:
430 case HloOpcode::kShiftRightArithmetic:
431 case HloOpcode::kSign:
432 case HloOpcode::kSin:
433 case HloOpcode::kSlice:
434 case HloOpcode::kSqrt:
435 case HloOpcode::kSubtract:
436 case HloOpcode::kTanh:
437 case HloOpcode::kTranspose:
438 case HloOpcode::kXor:
439 case HloOpcode::kCopy:
440 case HloOpcode::kReduce:
441 return CreateOpInFusion(instr);
442 default:
443 llvm::errs() << instr->ToString();
444 return tensorflow::errors::Internal(
445 absl::StrCat("LHLO opcode ", xla::HloOpcodeString(instr->opcode()),
446 " is not supported."));
447 }
448 }
449
450 Status LhloDialectEmitter::DefaultAction(const HloInstruction* instr) {
451 return EmitOp(instr).status();
452 }
453
454 StatusOr<lmhlo::SortOp> LhloDialectEmitter::EmitSortOp(
455 const HloInstruction* instr) {
456 TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs<lmhlo::SortOp>(instr));
457 auto* sort_instr = xla::Cast<xla::HloSortInstruction>(instr);
458 sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension()));
459 sort.is_stableAttr(builder_.getBoolAttr(sort_instr->is_stable()));
460 TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
461 *sort_instr->called_computations()[0], &sort.comparator(), &builder_));
462 return sort;
463 }
464
465 // Walks MHLO::TupleOp recursively.
466 Status WalkTuplePostOrder(Value v,
467 const std::function<Status(Value)>& visitor) {
468 if (auto* op = v.getDefiningOp()) {
469 if (auto tuple = dyn_cast<mhlo::TupleOp>(op)) {
470 for (Value sub_v : tuple.val()) {
471 TF_RETURN_IF_ERROR(WalkTuplePostOrder(sub_v, visitor));
472 }
473 return Status::OK();
474 }
475 }
476 return visitor(v);
477 }
478
479 StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
480 const HloInstruction* root, const Shape& shape,
481 xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) {
482 if (shape.IsTuple()) {
483 llvm::SmallVector<Value, 4> values;
484 for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
485 shape_index->push_back(i);
486 TF_ASSIGN_OR_RETURN(
487 auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index,
488 b, loc));
489 values.push_back(v);
490 shape_index->pop_back();
491 }
492 return Value(b->create<mhlo::TupleOp>(loc, values));
493 }
494 TF_ASSIGN_OR_RETURN(Value memref,
495 GetOrCreateArrayView(root, shape, *shape_index));
496 auto load = b->create<memref::TensorLoadOp>(loc, memref);
497 if (shape.layout() !=
498 xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) {
499 llvm::SmallVector<int64_t, 4> minor_to_major(
500 shape.layout().minor_to_major().begin(),
501 shape.layout().minor_to_major().end());
502 load->setAttr("minor_to_major", GetLayoutAttribute(shape.layout(), b));
503 }
504 return load.getResult();
505 }
506
507 // Emit a lmhlo.fusion based on XLA HLO fusion. Structurally they are not neatly
508 // equivalent. Specifically, XLA HLO fusion:
509 // fused_computation {
510 // %p0 = parameter(0)
511 // %p1 = parameter(1)
512 // ...
513 // ROOT %ret = ...
514 // }
515 // will be converted to
516 // lmhlo.fusion() { // no explicit operands
517 // // capturing outside buffers
518 // %p0 = tensor_load(%arg0) : memref<...> -> tensor<...>
519 // %p1 = tensor_load(%arg1) : memref<...> -> tensor<...>
520 // ...
521 // tensor_store ..., %ret // store a tensor to a memref
522 // }
523 StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp(
524 const HloInstruction* instr) {
525 Location loc = getLocation(instr);
526
527 auto* fusion_instr = xla::Cast<xla::HloFusionInstruction>(instr);
528
529 auto fusion = builder_.create<lmhlo::FusionOp>(getLocation(instr));
530 auto after_fusion = builder_.saveInsertionPoint();
531 auto reverter = xla::MakeCleanup(
532 [this, after_fusion] { builder_.restoreInsertionPoint(after_fusion); });
533 builder_ = mlir::OpBuilder(fusion);
534
535 auto region_builder = OpBuilder::atBlockBegin(&fusion.region().front());
536
537 llvm::SmallVector<Value, 8> arguments;
538 for (int i = 0; i < instr->operands().size(); ++i) {
539 const HloInstruction* operand = instr->operand(i);
540 xla::ShapeIndex shape_index;
541 TF_ASSIGN_OR_RETURN(
542 auto arg, RewriteFusionOperand(operand, operand->shape(), &shape_index,
543                                        &region_builder, loc));
544 arguments.push_back(arg);
545 }
546
547 TF_ASSIGN_OR_RETURN(Value result,
548 xla::HloFunctionImporter::ImportInstructions(
549 *fusion_instr->fused_instructions_computation(),
550                          arguments, &region_builder));
551 {
552 int i = 0;
553 llvm::SmallVector<Value, 4> output;
554 TF_RETURN_IF_ERROR(GetOrCreateView(instr, &output));
555 TF_RETURN_IF_ERROR(WalkTuplePostOrder(result, [&](Value v) mutable {
556 region_builder.create<memref::TensorStoreOp>(loc, v, output[i++]);
557 return Status::OK();
558 }));
559 if (i != output.size()) {
560 return xla::InternalError("output sizes don't match");
561 }
562 }
563
564 // Fold GTE/Tuple pairs.
565 //
566 // Since the fused region refers to values in its parent region, we can't
567 // call applyPatternAndFoldGreedily. We optimize it manually.
568 //
569 // Only walk once, because post-ordering is exactly what we need for GTE
570 // optimizations.
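  // For example (illustrative):
  //   %t = "mhlo.tuple"(%a, %b)
  //   %x = "mhlo.get_tuple_element"(%t) {index = 0}
  // folds so that uses of %x become uses of %a; the now-dead tuple is cleaned
  // up by the DCE below.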
571 fusion.region().walk([](mhlo::GetTupleElementOp gte) {
572 SmallVector<Value, 4> folded_values;
573 if (succeeded(OpBuilder(gte).tryFold(gte, folded_values))) {
574 gte.replaceAllUsesWith(folded_values[0]);
575 }
576 });
577
578 // Effectively a DCE on the region.
579 {
580 llvm::SmallVector<mlir::Operation*, 4> ops;
581 fusion.region().walk([&](mlir::Operation* op) { ops.push_back(op); });
582 // Visit the user first.
583 std::reverse(ops.begin(), ops.end());
584 for (auto op : ops) {
585 if (isOpTriviallyDead(op)) op->erase();
586 }
587 }
588
589 return fusion;
590 }
591
592 StatusOr<mhlo::ScatterDimensionNumbers>
593 LhloDialectEmitter::GetScatterDimensionNumbers(const HloInstruction* instr,
594 mlir::MLIRContext* context) {
595 auto* scatter_instr = xla::Cast<xla::HloScatterInstruction>(instr);
596
597 const xla::ScatterDimensionNumbers& xla_scatter_dim =
598 scatter_instr->scatter_dimension_numbers();
599
600 mlir::Builder builder(context);
601 auto get_i64_array_attr =
602 [builder](absl::Span<const xla::int64> container) mutable {
603 return builder.getI64TensorAttr(
604 {container.data(), static_cast<size_t>(container.size())});
605 };
606 auto scatter_dimension_numbers = mhlo::ScatterDimensionNumbers::get(
607 get_i64_array_attr(xla_scatter_dim.update_window_dims()),
608 get_i64_array_attr(xla_scatter_dim.inserted_window_dims()),
609 get_i64_array_attr(xla_scatter_dim.scatter_dims_to_operand_dims()),
610 builder.getI64IntegerAttr(xla_scatter_dim.index_vector_dim()), context);
611 return scatter_dimension_numbers;
612 }
613
614 StatusOr<lmhlo::ScatterOp> LhloDialectEmitter::EmitScatterOp(
615 const HloInstruction* instr) {
616 TF_ASSIGN_OR_RETURN(auto scatter,
617 CreateOpWithoutAttrs<lmhlo::ScatterOp>(instr));
618
619 // copy attributes
620 auto* scatter_instr = xla::Cast<xla::HloScatterInstruction>(instr);
621
622 TF_ASSIGN_OR_RETURN(auto scatter_dimension_numbers,
623 GetScatterDimensionNumbers(instr, builder_.getContext()));
624 scatter.scatter_dimension_numbersAttr(scatter_dimension_numbers);
625 scatter.indices_are_sortedAttr(
626 builder_.getBoolAttr(scatter_instr->indices_are_sorted()));
627 scatter.unique_indicesAttr(
628 builder_.getBoolAttr(scatter_instr->unique_indices()));
629
630 // import update computation as region
631 TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
632 *scatter_instr->called_computations()[0], &scatter.update_computation(),
633 &builder_));
634
635 return scatter;
636 }
637
638 StatusOr<lmhlo::SelectAndScatterOp> LhloDialectEmitter::EmitSelectAndScatterOp(
639 const HloInstruction* instr) {
640 TF_ASSIGN_OR_RETURN(auto select_and_scatter,
641 CreateOpWithoutAttrs<lmhlo::SelectAndScatterOp>(instr));
642
643 // copy attributes
644 auto* select_and_scatter_instr =
645 xla::Cast<xla::HloSelectAndScatterInstruction>(instr);
646 const xla::Window& window = select_and_scatter_instr->window();
647
648 if (xla::window_util::HasDilation(window)) {
649 return xla::Unimplemented("Dilation for SelectAndScatter is not supported");
650 }
651
652 select_and_scatter.window_dimensionsAttr(
653 GetWindowElements(window, [](const xla::WindowDimension& dim) {
654 return static_cast<int64_t>(dim.size());
655 }));
656 select_and_scatter.window_stridesAttr(
657 GetWindowElements(window, [](const xla::WindowDimension& dim) {
658 return static_cast<int64_t>(dim.stride());
659 }));
660 select_and_scatter.paddingAttr(
661 GetWindowElements(window, [](const xla::WindowDimension& dim) {
662 return static_cast<int64_t>(dim.padding_low());
663 }));
664
665 // import select and scatter computation as region
666 TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
667 *select_and_scatter_instr->select(), &select_and_scatter.select(),
668 &builder_));
669 TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
670 *select_and_scatter_instr->scatter(), &select_and_scatter.scatter(),
671 &builder_));
672 return select_and_scatter;
673 }
674
675 StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp(
676 const HloInstruction* instr) {
677 auto* custom_call_instr = xla::Cast<xla::HloCustomCallInstruction>(instr);
678
679 if (xla::gpu::IsCustomCallToCusolver(*instr)) {
680 return EmitCholesky(custom_call_instr);
681 }
682
683 if (xla::gpu::IsCublasGemm(*instr)) {
684 return EmitGemm(custom_call_instr);
685 }
686
687 if (xla::gpu::IsCustomCallToDnnConvolution(*instr)) {
688 return EmitDnnConvolution(custom_call_instr);
689 }
690
691 if (xla::gpu::IsCustomCallToDnnBatchNorm(*instr)) {
692 return EmitDnnBatchNorm(custom_call_instr);
693 }
694
695 // For custom call, if there are any token operands or results, they will not
696 // be represented in LHLO so we need to remember the mapping. First create
697 // operands where each token is replaced with a null Value.
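  // Worked example (illustrative): for operands (buffer0, token, buffer1) and
  // a single non-token result, the lowered op keeps only buffer0, buffer1 and
  // the result buffer, with arg_to_target_arg_mapping = [0, 2] and
  // result_to_target_result_mapping = [0].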
698 llvm::SmallVector<Value, 4> operands;
699 size_t num_arguments, num_results;
700 TF_RETURN_IF_ERROR(CreateOperands(instr, /*num_operands=*/absl::nullopt,
701 TokenLoweringMode::kUseNull, operands,
702 num_arguments, num_results));
703
704 // Now check if any of the operands is Null, which would indicate the presence
705 // of a token in the input or output.
706 bool has_token = llvm::any_of(operands, [](Value v) { return !v; });
707
708 lmhlo::CustomCallTargetArgMapping target_mapping;
709 if (has_token) {
710 // If there was a token, squeeze all the non-token arguments and results
711 // (in-place) and remember the mapping.
712 int next_index = 0;
713 llvm::SmallVector<int64_t> arg_to_target_arg_mapping;
714 for (int i = 0; i < num_arguments; ++i) {
715 if (operands[i]) {
716 arg_to_target_arg_mapping.push_back(i);
717 operands[next_index++] = operands[i];
718 }
719 }
720 // Size of arg_to_target_arg_mapping is the number of arguments in LHLO.
721 llvm::SmallVector<int64_t> result_to_target_result_mapping;
722 for (int i = num_arguments; i < operands.size(); ++i) {
723 if (operands[i]) {
724 result_to_target_result_mapping.push_back(i - num_arguments);
725 operands[next_index++] = operands[i];
726 }
727 }
728
729 // Build the mapping attribute.
730 target_mapping = lmhlo::CustomCallTargetArgMapping::get(
731 builder_.getI64IntegerAttr(num_arguments),
732 builder_.getI64IntegerAttr(num_results),
733 builder_.getI64ArrayAttr(arg_to_target_arg_mapping),
734 builder_.getI64ArrayAttr(result_to_target_result_mapping),
735 builder_.getContext());
736
737 // Drop the remaining operands and adjust num_arguments and num_results
738 // for LMHLO creation.
739 operands.resize(next_index);
740 num_arguments = arg_to_target_arg_mapping.size();
741 num_results = result_to_target_result_mapping.size();
742 }
743
744 auto custom_call = CreateOpWithoutAttrs<lmhlo::CustomCallOp>(instr, operands);
745 TF_ASSIGN_OR_RETURN(
746 auto mlir_api_version,
747 ConvertCustomCallApiVersion(custom_call_instr->api_version()));
748 custom_call.call_target_nameAttr(
749 builder_.getStringAttr(custom_call_instr->custom_call_target()));
750 custom_call.backend_configAttr(
751 builder_.getStringAttr(custom_call_instr->opaque()));
752 custom_call.api_versionAttr(mhlo::CustomCallApiVersionAttr::get(
753 builder_.getContext(), mlir_api_version));
754 const int32_t segments[2] = {static_cast<int32_t>(num_arguments),
755 static_cast<int32_t>(num_results)};
756 custom_call->setAttr(lmhlo::CustomCallOp::getOperandSegmentSizeAttr(),
757 builder_.getI32VectorAttr(segments));
758 if (target_mapping) custom_call.target_arg_mappingAttr(target_mapping);
759 return custom_call.getOperation();
760 }
761
762 StatusOr<lmhlo_gpu::CholeskyOp> LhloDialectEmitter::EmitCholesky(
763 const HloCustomCallInstruction* custom_call) {
764 TF_ASSIGN_OR_RETURN(auto cholesky_op,
765 CreateOpWithoutAttrs<lmhlo_gpu::CholeskyOp>(custom_call));
766 TF_ASSIGN_OR_RETURN(xla::CholeskyOptions options,
767 custom_call->backend_config<xla::CholeskyOptions>());
768 cholesky_op.is_lowerAttr(builder_.getBoolAttr(options.lower()));
769 return cholesky_op;
770 }
771
772 StatusOr<Operation*> LhloDialectEmitter::EmitGemm(
773 const HloCustomCallInstruction* custom_call) {
774 TF_ASSIGN_OR_RETURN(
775 auto const config,
776 custom_call->backend_config<xla::gpu::GemmBackendConfig>());
777
778 auto set_common_attributes = [&](auto op) -> Operation* {
779 auto hlo_dims = config.dot_dimension_numbers();
780 auto mlir_dims = mhlo::DotDimensionNumbers::get(
781 GetI64DenseElementsAttr(hlo_dims.lhs_batch_dimensions()),
782 GetI64DenseElementsAttr(hlo_dims.rhs_batch_dimensions()),
783 GetI64DenseElementsAttr(hlo_dims.lhs_contracting_dimensions()),
784 GetI64DenseElementsAttr(hlo_dims.rhs_contracting_dimensions()),
785 builder_.getContext());
786 op.dot_dimension_numbersAttr(mlir_dims);
787 op.alpha_realAttr(builder_.getF64FloatAttr(config.alpha_real()));
788 op.alpha_imagAttr(builder_.getF64FloatAttr(config.alpha_imag()));
789 op.batch_sizeAttr(builder_.getI64IntegerAttr(config.batch_size()));
790 op.lhs_strideAttr(builder_.getI64IntegerAttr(config.lhs_stride()));
791 op.rhs_strideAttr(builder_.getI64IntegerAttr(config.rhs_stride()));
792 if (config.algorithm_case() ==
793 xla::gpu::GemmBackendConfig::kSelectedAlgorithm) {
794 op.algorithmAttr(builder_.getI64IntegerAttr(config.selected_algorithm()));
795 }
796 return op.getOperation();
797 };
798
799 if (custom_call->operand_count() == 2) {
800 TF_ASSIGN_OR_RETURN(auto gemm,
801 CreateOpWithoutAttrs<lmhlo_gpu::GEMMOp>(custom_call));
802 return set_common_attributes(gemm);
803 }
804
805 if (custom_call->operand_count() == 3) {
806 TF_ASSIGN_OR_RETURN(
807 auto gemm_bias,
808 CreateOpWithoutAttrs<lmhlo_gpu::GEMM_BiasOp>(custom_call));
809 gemm_bias.betaAttr(builder_.getF64FloatAttr(config.beta()));
810 return set_common_attributes(gemm_bias);
811 }
812
813 return xla::InvalidArgument("GEMM custom call should have 2 or 3 operands");
814 }
815
816 static StatusOr<mlir::lmhlo_gpu::Activation> GetLHLOActivation(
817 stream_executor::dnn::ActivationMode activation) {
818 switch (activation) {
819 case stream_executor::dnn::kNone:
820 return mlir::lmhlo_gpu::Activation::None;
821 case stream_executor::dnn::kSigmoid:
822 return mlir::lmhlo_gpu::Activation::Sigmoid;
823 case stream_executor::dnn::kRelu:
824 return mlir::lmhlo_gpu::Activation::Relu;
825 case stream_executor::dnn::kRelu6:
826 return mlir::lmhlo_gpu::Activation::Relu6;
827 case stream_executor::dnn::kReluX:
828 return mlir::lmhlo_gpu::Activation::ReluX;
829 case stream_executor::dnn::kTanh:
830 return mlir::lmhlo_gpu::Activation::Tanh;
831 case stream_executor::dnn::kBandPass:
832 return mlir::lmhlo_gpu::Activation::BandPass;
833 default:
834 return xla::InternalError("Unknown activation");
835 }
836 }
837
838 StatusOr<Operation*> LhloDialectEmitter::EmitDnnConvolution(
839 const HloCustomCallInstruction* custom_call) {
840 TF_ASSIGN_OR_RETURN(
841 auto const backend_config,
842 custom_call->backend_config<xla::gpu::CudnnConvBackendConfig>());
843
844 TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnConvKind kind,
845 xla::gpu::GetCudnnConvKind(custom_call));
846
847 auto get_layout_attribute = [&](const xla::Layout& layout) {
848 std::vector<int64_t> minor_to_major(layout.minor_to_major_size());
849 absl::c_transform(layout.minor_to_major(), minor_to_major.begin(),
850 [](xla::int64 x) { return static_cast<int64_t>(x); });
851 return builder_.getI64ArrayAttr(minor_to_major);
852 };
853
854 auto set_common_conv_attributes = [&, this](auto op) -> Operation* {
855 const xla::Window& window = custom_call->window();
856     // Window size for Cudnn Conv is the same as the kernel size.
857 op.window_stridesAttr(
858 GetWindowElements(window, [](const xla::WindowDimension& dim) {
859 return static_cast<int64_t>(dim.stride());
860 }));
861 // Cudnn Conv requires low and high padding to be equal.
862 op.paddingAttr(
863 GetWindowElements(window, [](const xla::WindowDimension& dim) {
864 return static_cast<int64_t>(dim.padding_low());
865 }));
866 // LHS dilation is encoded in base_dilation of the backend config.
867 // RHS dilation is encoded in window_dilation of the backend config.
868 op.lhs_dilationAttr(
869 GetWindowElements(window, [](const xla::WindowDimension& dim) {
870 return static_cast<int64_t>(dim.base_dilation());
871 }));
872 op.rhs_dilationAttr(
873 GetWindowElements(window, [](const xla::WindowDimension& dim) {
874 return static_cast<int64_t>(dim.window_dilation());
875 }));
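    // Illustrative example (values invented): a 2-D convolution with strides
    // {2, 2}, symmetric padding {1, 1} and no dilation yields
    //   window_strides = dense<[2, 2]>, padding = dense<[1, 1]>,
    //   lhs_dilation = dense<[1, 1]>, rhs_dilation = dense<[1, 1]>.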
876 // Setup window reversal.
877 auto window_reversal = llvm::to_vector<4>(llvm::map_range(
878 window.dimensions(),
879 [](const xla::WindowDimension& dim) { return dim.window_reversal(); }));
880 auto type = RankedTensorType::get(op.window_strides()->getType().getShape(),
881 builder_.getIntegerType(/*width=*/1));
882 op.window_reversalAttr(DenseElementsAttr::get(type, window_reversal));
883
884 op.dimension_numbersAttr(xla::ConvertConvDimensionNumbers(
885 custom_call->convolution_dimension_numbers(), &builder_));
886 op.feature_group_countAttr(
887 builder_.getI64IntegerAttr(custom_call->feature_group_count()));
888 op.batch_group_countAttr(
889 builder_.getI64IntegerAttr(custom_call->batch_group_count()));
890 op.precision_configAttr(xla::ConvertPrecisionConfig(
891 &custom_call->precision_config(), &builder_));
892 op.result_scaleAttr(
893 builder_.getF64FloatAttr(backend_config.conv_result_scale()));
894 auto config = mlir::lmhlo_gpu::ConvolutionBackendConfig::get(
895 builder_.getI64IntegerAttr(backend_config.algorithm()),
896 builder_.getBoolAttr(backend_config.tensor_ops_enabled()),
897 get_layout_attribute(custom_call->operand(0)->shape().layout()),
898 get_layout_attribute(custom_call->operand(1)->shape().layout()),
899 get_layout_attribute(custom_call->shape().tuple_shapes(0).layout()),
900 builder_.getContext());
901 op.backend_configAttr(config);
902
903 return op.getOperation();
904 };
905
906 auto set_activation = [&, this](auto op) -> Status {
907 auto se_activation = static_cast<stream_executor::dnn::ActivationMode>(
908 backend_config.activation_mode());
909 TF_ASSIGN_OR_RETURN(mlir::lmhlo_gpu::Activation activation,
910 GetLHLOActivation(se_activation));
911 StringAttr activation_attr = builder_.getStringAttr(
912 mlir::lmhlo_gpu::stringifyActivation(activation));
913 op.activation_modeAttr(activation_attr);
914 return Status::OK();
915 };
916
917 switch (kind) {
918 case xla::gpu::CudnnConvKind::kForward: {
919 TF_ASSIGN_OR_RETURN(
920 auto cnn_forward,
921 CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardOp>(custom_call));
922 return set_common_conv_attributes(cnn_forward);
923 }
924 case xla::gpu::CudnnConvKind::kBackwardInput: {
925 TF_ASSIGN_OR_RETURN(
926 auto cnn_backward,
927 CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardInputOp>(custom_call));
928 return set_common_conv_attributes(cnn_backward);
929 }
930 case xla::gpu::CudnnConvKind::kBackwardFilter: {
931 TF_ASSIGN_OR_RETURN(
932 auto cnn_backward,
933 CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardFilterOp>(custom_call));
934 return set_common_conv_attributes(cnn_backward);
935 }
936 case xla::gpu::CudnnConvKind::kForwardActivation: {
937 // Fused conv can be either with side input or without.
938 if (custom_call->operand_count() == 3) {
939 TF_ASSIGN_OR_RETURN(
940 auto cnn_fused,
941 CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedOp>(custom_call));
942 TF_RETURN_IF_ERROR(set_activation(cnn_fused));
943 return set_common_conv_attributes(cnn_fused);
944 }
945
946 TF_RET_CHECK(custom_call->operand_count() == 4);
947 TF_ASSIGN_OR_RETURN(
948 auto cnn_fused_side_input,
949 CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedSideInputOp>(
950 custom_call));
951 cnn_fused_side_input.side_input_scaleAttr(
952 builder_.getF64FloatAttr(backend_config.side_input_scale()));
953 TF_RETURN_IF_ERROR(set_activation(cnn_fused_side_input));
954 return set_common_conv_attributes(cnn_fused_side_input);
955 }
956 }
957 }
958
959 StatusOr<Operation*> LhloDialectEmitter::EmitDnnBatchNorm(
960 const HloCustomCallInstruction* custom_call) {
961 const xla::int64 num_operands = custom_call->operand_count();
962 auto set_batchnorm_attributes = [&](auto op) -> StatusOr<Operation*> {
963 // The last 2 operands of a custom call for batch norm are the epsilon and
964 // feature_index.
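    // For example (illustrative): a forward-training batch norm custom call
    // with operands (operand, scale, offset, epsilon, feature_index) is
    // lowered from just the first num_operands - 2 operands; epsilon and
    // feature_index become attributes instead.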
965 const HloInstruction* epsilon = custom_call->operand(num_operands - 2);
966 TF_RET_CHECK(epsilon->IsConstant());
967 float epsilon_value = epsilon->literal().Get<float>({});
968
969 const HloInstruction* feature_index =
970 custom_call->operand(num_operands - 1);
971 TF_RET_CHECK(feature_index->IsConstant());
972 xla::int64 feature_index_value =
973 feature_index->literal().Get<xla::int64>({});
974
975 op.epsilonAttr(builder_.getF32FloatAttr(epsilon_value));
976 op.feature_indexAttr(builder_.getI64IntegerAttr(feature_index_value));
977 return op.getOperation();
978 };
979
980 const std::string& target = custom_call->custom_call_target();
981 if (target == xla::gpu::kCudnnBatchNormForwardTrainingCallTarget) {
982 TF_ASSIGN_OR_RETURN(auto fwd_training,
983 CreateOpWithoutAttrs<lmhlo_gpu::BatchNormTrainingOp>(
984 custom_call, num_operands - 2));
985 return set_batchnorm_attributes(fwd_training);
986 }
987
988 if (target == xla::gpu::kCudnnBatchNormBackwardCallTarget) {
989 TF_ASSIGN_OR_RETURN(auto backward,
990 CreateOpWithoutAttrs<lmhlo_gpu::BatchNormGradOp>(
991 custom_call, num_operands - 2));
992 return set_batchnorm_attributes(backward);
993 }
994
995 if (target == xla::gpu::kCudnnBatchNormForwardInferenceCallTarget) {
996 TF_ASSIGN_OR_RETURN(auto fwd_inference,
997 CreateOpWithoutAttrs<lmhlo_gpu::BatchNormInferenceOp>(
998 custom_call, num_operands - 2));
999 return set_batchnorm_attributes(fwd_inference);
1000 }
1001
1002 return xla::Unimplemented("Unsupported batch norm operation");
1003 }
1004
1005 // Convert an XLA HLO constant to a global_memref + get_global_memref pair.
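// Schematically (illustrative; the concrete name and type are invented), the
// emitted IR looks like:
//
//   memref.global "private" constant @constant_name : memref<4xf32> = dense<...>
//   %0 = memref.get_global @constant_name : memref<4xf32>
//
// plus an "lmhlo.alloc" attribute on the global recording which entry
// argument backs the constant's buffer.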
1006 StatusOr<mlir::memref::GetGlobalOp> LhloDialectEmitter::EmitConstant(
1007 const HloInstruction* instr) {
1008 // Insert a global_memref in the module.
1009 Location loc = getLocation(instr);
1010
1011 auto const_instr = xla::Cast<xla::HloConstantInstruction>(instr);
1012 TF_RET_CHECK(const_instr->shape().IsArray() &&
1013 const_instr->shape().is_static());
1014 TF_ASSIGN_OR_RETURN(Type type, xla::ConvertShapeToType<MemRefType>(
1015 const_instr->shape(), builder_));
1016 auto memref_type = type.dyn_cast<MemRefType>();
1017 TF_RET_CHECK(memref_type != nullptr);
1018
1019 TF_ASSIGN_OR_RETURN(
1020 DenseElementsAttr initial_value,
1021 CreateDenseElementsAttrFromLiteral(const_instr->literal(), builder_));
1022
1023 std::string constant_name = xla::llvm_ir::ConstantNameToGlobalName(
1024 xla::llvm_ir::SanitizeConstantName(instr->name()));
1025
1026 // Insert the global memref at the top level.
1027 {
1028 OpBuilder::InsertionGuard guard(builder_);
1029 builder_.clearInsertionPoint();
1030 auto global_var = builder_.create<memref::GlobalOp>(
1031 loc, constant_name, builder_.getStringAttr("private"), memref_type,
1032 initial_value, true);
1033 SymbolTable(module_).insert(global_var);
1034 global_var.getOperation()->moveBefore(&module_.front());
1035
1036 // For operations that do not fold this constant value in their codegen, we
1037 // still need to materialize it into a buffer. Since buffer allocation is
1038 // already done, annotate the global_memref with the information to get to
1039 // the allocated buffer slice for this constant if need be.
1040 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
1041 assignment_.GetUniqueTopLevelSlice(instr));
1042 global_var->setAttr(
1043 "lmhlo.alloc",
1044 builder_.getIndexAttr(allocations_.find(slice.allocation())
1045 ->second.cast<BlockArgument>()
1046 .getArgNumber()));
1047 TF_RET_CHECK(slice.offset() == 0)
1048 << "Each constant should have its own allocation from BufferAssignment";
1049 TF_RET_CHECK(slice.allocation()->size() == slice.size())
1050 << "Each constant should have its own allocation from BufferAssignment";
1051 }
1052
1053 auto get_global_memref =
1054 builder_.create<memref::GetGlobalOp>(loc, memref_type, constant_name);
1055
1056 // Update the cache to remember this value.
1057 auto& cached_value = slices_[std::make_pair(instr, xla::ShapeIndex())];
1058 TF_RET_CHECK(cached_value == nullptr);
1059 cached_value = get_global_memref;
1060 return get_global_memref;
1061 }
1062
1063 namespace {
1064 template <typename OpT>
1065 void SetupChannelIdAttribute(OpT op, const xla::HloChannelInstruction* instr,
1066 mlir::Builder builder) {
1067 if (instr->channel_id().has_value()) {
1068 op.channel_idAttr(mlir::mhlo::ChannelHandle::get(
1069 builder.getI64IntegerAttr(*instr->channel_id()),
1070 builder.getI64IntegerAttr(0), builder.getContext()));
1071 }
1072 }
1073
1074 template <typename OpT>
1075 Status SetupCommonCollectiveOpAttributes(OpT op, const HloInstruction* instr,
1076 mlir::OpBuilder& builder) {
1077 auto* collective = xla::Cast<xla::HloCollectiveInstruction>(instr);
1078 auto replica_groups_attr = xla::HloFunctionImporter::ConvertReplicaGroups(
1079 collective->replica_groups(), &builder);
1080 op->setAttr(replica_groups_attr.first, replica_groups_attr.second);
1081 op.constrain_layoutAttr(builder.getBoolAttr(collective->constrain_layout()));
1082 SetupChannelIdAttribute(op, collective, builder);
1083 return Status::OK();
1084 }
1085 } // namespace
1086
1087 StatusOr<lmhlo::AllToAllOp> LhloDialectEmitter::EmitAllToAllOp(
1088 const HloInstruction* instr) {
1089 TF_ASSIGN_OR_RETURN(auto all_to_all_op,
1090 CreateOpWithoutAttrs<lmhlo::AllToAllOp>(instr));
1091 auto* all_to_all = xla::Cast<xla::HloAllToAllInstruction>(instr);
1092 TF_RETURN_IF_ERROR(
1093 SetupCommonCollectiveOpAttributes(all_to_all_op, instr, builder_));
1094 if (all_to_all->split_dimension().has_value()) {
1095 all_to_all_op.split_dimensionAttr(
1096 builder_.getI64IntegerAttr(*all_to_all->split_dimension()));
1097 }
1098 return all_to_all_op;
1099 }
1100
1101 StatusOr<lmhlo::AllGatherOp> LhloDialectEmitter::EmitAllGatherOp(
1102 const HloInstruction* instr) {
1103 TF_ASSIGN_OR_RETURN(auto all_gather_op,
1104 CreateOpWithoutAttrs<lmhlo::AllGatherOp>(instr));
1105 auto* all_gather = xla::Cast<xla::HloAllGatherInstruction>(instr);
1106 TF_RETURN_IF_ERROR(
1107 SetupCommonCollectiveOpAttributes(all_gather_op, instr, builder_));
1108 all_gather_op.use_global_device_idsAttr(
1109 builder_.getBoolAttr(all_gather->use_global_device_ids()));
1110 all_gather_op.all_gather_dimensionAttr(
1111 builder_.getI64IntegerAttr(all_gather->all_gather_dimension()));
1112 return all_gather_op;
1113 }
1114
1115 StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
1116 const HloInstruction* instr) {
1117 TF_ASSIGN_OR_RETURN(auto all_reduce_op,
1118 CreateOpWithoutAttrs<lmhlo::AllReduceOp>(instr));
1119 auto* all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
1120 TF_RETURN_IF_ERROR(
1121 SetupCommonCollectiveOpAttributes(all_reduce_op, instr, builder_));
1122 all_reduce_op.use_global_device_idsAttr(
1123 builder_.getBoolAttr(all_reduce->use_global_device_ids()));
1124 TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
1125 *instr->called_computations()[0], &all_reduce_op.computation(),
1126 &builder_));
1127 return all_reduce_op;
1128 }
1129
1130 StatusOr<lmhlo_gpu::AllReduceStartOp> LhloDialectEmitter::EmitAllReduceStartOp(
1131 const HloInstruction* instr) {
1132 llvm::SmallVector<Value, 4> operands;
1133 for (const HloInstruction* operand : instr->operands()) {
1134 TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands));
1135 }
1136 // Only include result index {1}. {0} always aliases the inputs.
1137 TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{1}));
1138
1139 Location loc = getLocation(instr);
1140 mlir::Type token_type = mlir::mhlo::TokenType::get(builder_.getContext());
1141 std::array<mlir::Type, 1> result_types = {token_type};
1142 lmhlo_gpu::AllReduceStartOp all_reduce_start_op =
1143 builder_.create<lmhlo_gpu::AllReduceStartOp>(loc, result_types, operands);
1144
1145 auto* all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
1146 TF_RETURN_IF_ERROR(
1147 SetupCommonCollectiveOpAttributes(all_reduce_start_op, instr, builder_));
1148 all_reduce_start_op.use_global_device_idsAttr(
1149 builder_.getBoolAttr(all_reduce->use_global_device_ids()));
1150 TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
1151 *instr->called_computations()[0], &all_reduce_start_op.computation(),
1152 &builder_));
1153
1154 TF_RET_CHECK(all_reduce_start_ops_.emplace(instr, all_reduce_start_op).second)
1155 << "all-reduce-start already lowered";
1156 return all_reduce_start_op;
1157 }
1158
1159 StatusOr<lmhlo_gpu::AllReduceDoneOp> LhloDialectEmitter::EmitAllReduceDoneOp(
1160 const HloInstruction* instr) {
1161 auto it = all_reduce_start_ops_.find(instr->operand(0));
1162 TF_RET_CHECK(it != all_reduce_start_ops_.end())
1163 << "didn't find all-reduce-start op";
1164
1165 llvm::SmallVector<Value, 4> operands;
1166 operands.push_back(it->second.token());
1167 all_reduce_start_ops_.erase(it);
1168
1169 for (const HloInstruction* operand : instr->operands()) {
1170 TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands));
1171 }
1172 // We don't need to add buffers for the outputs, as these always alias inputs.
1173 return builder_.create<lmhlo_gpu::AllReduceDoneOp>(
1174 getLocation(instr), /*resultTypes=*/llvm::None, operands);
1175 }
1176
1177 StatusOr<lmhlo::ReduceScatterOp> LhloDialectEmitter::EmitReduceScatterOp(
1178 const HloInstruction* instr) {
1179 TF_ASSIGN_OR_RETURN(auto reduce_scatter_op,
1180 CreateOpWithoutAttrs<lmhlo::ReduceScatterOp>(instr));
1181 auto* ars = xla::Cast<xla::HloReduceScatterInstruction>(instr);
1182 TF_RETURN_IF_ERROR(
1183 SetupCommonCollectiveOpAttributes(reduce_scatter_op, instr, builder_));
1184 reduce_scatter_op.use_global_device_idsAttr(
1185 builder_.getBoolAttr(ars->use_global_device_ids()));
1186 TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
1187 *instr->called_computations()[0], &reduce_scatter_op.computation(),
1188 &builder_));
1189 reduce_scatter_op.scatter_dimensionAttr(
1190 builder_.getI64IntegerAttr(ars->scatter_dimension()));
1191 return reduce_scatter_op;
1192 }
1193
1194 StatusOr<lmhlo::CollectivePermuteOp>
1195 LhloDialectEmitter::EmitCollectivePermuteOp(const HloInstruction* instr) {
1196 TF_ASSIGN_OR_RETURN(auto permute_op,
1197 CreateOpWithoutAttrs<lmhlo::CollectivePermuteOp>(instr));
1198 auto* permute = xla::Cast<xla::HloCollectivePermuteInstruction>(instr);
1199 SetupChannelIdAttribute(permute_op, permute, builder_);
1200 mlir::NamedAttribute source_target_pairs_attr =
1201 xla::HloFunctionImporter::ConvertSourceTargetPairs(
1202 permute->source_target_pairs(), &builder_);
1203 permute_op->setAttr(source_target_pairs_attr.first,
1204 source_target_pairs_attr.second);
1205 return permute_op;
1206 }
1207
1208 StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp(
1209 const HloInstruction* instr) {
1210 const HloInfeedInstruction* infeed = xla::Cast<HloInfeedInstruction>(instr);
1211 // HLO Infeed instruction has a single operand of token type and a tuple
1212 // with buffers and a token as its output. LMHLO Infeed operation does not
1213 // need the token operand or result, so drop it.
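  // For example (illustrative): an HLO infeed producing ((f32[3]), token[])
  // becomes an lmhlo.infeed that takes just the f32[3] output buffer, with the
  // infeed config string attached as an attribute.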
1214 SmallVector<Value, 2> operands;
1215 TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{0}));
1216 auto infeed_op = CreateOpWithoutAttrs<lmhlo::InfeedOp>(instr, operands);
1217 infeed_op.configAttr(builder_.getStringAttr(infeed->infeed_config()));
1218 return infeed_op;
1219 }
1220
1221 StatusOr<lmhlo::OutfeedOp> LhloDialectEmitter::EmitOutfeedOp(
1222 const HloInstruction* instr) {
1223 const HloOutfeedInstruction* outfeed =
1224 xla::Cast<HloOutfeedInstruction>(instr);
1225 // HLO outfeed instruction has 2 operands, the source and a token, and a
1226 // single token output. LMHLO Outfeed does not need the token operand and
1227   // result, so drop it.
1228 SmallVector<Value, 2> operands;
1229 TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(0), &operands));
1230 auto outfeed_op = CreateOpWithoutAttrs<lmhlo::OutfeedOp>(instr, operands);
1231 outfeed_op.configAttr(builder_.getStringAttr(outfeed->outfeed_config()));
1232 return outfeed_op;
1233 }
1234
1235 xla::StatusOr<lmhlo::RngGetAndUpdateStateOp>
1236 LhloDialectEmitter::EmitRngGetAndUpdateStateOp(
1237 const xla::HloInstruction* instr) {
1238 TF_ASSIGN_OR_RETURN(
1239 auto rng, CreateOpWithoutAttrs<lmhlo::RngGetAndUpdateStateOp>(instr));
1240 auto hlo_rng = xla::Cast<xla::HloRngGetAndUpdateStateInstruction>(instr);
1241 rng.deltaAttr(builder_.getI64IntegerAttr(hlo_rng->delta()));
1242 return rng;
1243 }
1244
1245 xla::StatusOr<lmhlo::FftOp> LhloDialectEmitter::EmitFftOp(
1246 const HloInstruction* instr) {
1247 auto hlo_fft = xla::Cast<xla::HloFftInstruction>(instr);
1248 TF_ASSIGN_OR_RETURN(auto fft, CreateOpWithoutAttrs<lmhlo::FftOp>(instr));
1249 TF_ASSIGN_OR_RETURN(mlir::mhlo::FftType fft_type,
1250 xla::ConvertFftType(hlo_fft->fft_type()));
1251 StringAttr fft_type_attr =
1252 builder_.getStringAttr(mlir::mhlo::stringifyFftType(fft_type));
1253 fft.fft_typeAttr(fft_type_attr);
1254 fft.fft_lengthAttr(GetI64DenseElementsAttr(instr->fft_length()));
1255 return fft;
1256 }
1257
1258 xla::StatusOr<lmhlo::TriangularSolveOp>
1259 LhloDialectEmitter::EmitTriangularSolveOp(const xla::HloInstruction* instr) {
1260 auto hlo_triangular_solve =
1261 xla::Cast<xla::HloTriangularSolveInstruction>(instr);
1262 TF_ASSIGN_OR_RETURN(auto triangular_solve,
1263 CreateOpWithoutAttrs<lmhlo::TriangularSolveOp>(instr));
1264 const xla::TriangularSolveOptions& options =
1265 hlo_triangular_solve->triangular_solve_options();
1266 triangular_solve.left_sideAttr(builder_.getBoolAttr(options.left_side()));
1267 triangular_solve.lowerAttr(builder_.getBoolAttr(options.lower()));
1268 triangular_solve.unit_diagonalAttr(
1269 builder_.getBoolAttr(options.unit_diagonal()));
1270 TF_ASSIGN_OR_RETURN(mlir::mhlo::Transpose transpose,
1271 xla::ConvertTranspose(options.transpose_a()));
1272 triangular_solve.transpose_aAttr(
1273 builder_.getStringAttr(mlir::mhlo::stringifyTranspose(transpose)));
1274 triangular_solve.layout_aAttr(
1275 GetLayoutAttribute(instr->operand(0)->shape().layout(), &builder_));
1276 triangular_solve.layout_bAttr(
1277 GetLayoutAttribute(instr->operand(1)->shape().layout(), &builder_));
1278 triangular_solve.layout_outputAttr(
1279 GetLayoutAttribute(instr->shape().layout(), &builder_));
1280 return triangular_solve;
1281 }
1282
1283 xla::StatusOr<Operation*> LhloDialectEmitter::EmitBitcast(
1284 const xla::HloInstruction* instr) {
1285 // XLA buffer assignment should assign the same slice to a bitcast input and
1286 // output.
1287 const xla::ShapeIndex top_index;
1288 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
1289 assignment_.GetUniqueSlice(instr, top_index));
1290 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
1291 assignment_.GetUniqueSlice(instr->operand(0), top_index));
1292
1293 if (input_slice != result_slice) {
1294 return xla::InvalidArgument(
1295 "Bitcast input and result slice should be same");
1296 }
1297 return nullptr;
1298 }
1299
1300 mlir::DenseIntElementsAttr LhloDialectEmitter::GetLayoutAttribute(
1301 const xla::Layout& layout, Builder* builder) {
1302 llvm::SmallVector<int64_t, 4> minor_to_major(layout.minor_to_major().begin(),
1303 layout.minor_to_major().end());
1304 return builder->getIndexTensorAttr(minor_to_major);
1305 }
1306
1307 Status LhloDialectEmitter::ImportAsLmhloRegion(xla::HloComputation* computation,
1308 mlir::Region* region) {
1309 auto after = builder_.saveInsertionPoint();
1310 auto reverter = xla::MakeCleanup(
1311 [this, after] { builder_.restoreInsertionPoint(after); });
1312
1313 builder_ = OpBuilder(region);
1314 const xla::HloInstructionSequence* schedule =
1315 assignment_.hlo_ordering().SequentialOrder(*computation);
1316 if (!schedule)
1317 return xla::Unimplemented("Missing sequential order for the computation");
1318 TF_RETURN_IF_ERROR(
1319 computation->AcceptOrdered(this, schedule->instructions()));
1320 builder_.create<lmhlo::TerminatorOp>(builder_.getUnknownLoc());
1321 return Status::OK();
1322 }
1323
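// Lowers an HLO Conditional to lmhlo::CaseOp: the first operand view becomes
// the op's index operand, and each called computation is imported into the
// corresponding branch region.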
StatusOr<lmhlo::CaseOp> LhloDialectEmitter::EmitCaseOp(
    const HloInstruction* instr) {
  Location loc = getLocation(instr);
  llvm::SmallVector<Value, 4> operands;
  size_t num_arguments, num_results;
  TF_RETURN_IF_ERROR(CreateOperands(instr, 1, TokenLoweringMode::kUseNull,
                                    operands, num_arguments, num_results));

  auto case_op =
      builder_.create<lmhlo::CaseOp>(loc, operands[0], instr->branch_count());

  for (int i = 0; i < instr->branch_count(); i++) {
    case_op.branches()[i].push_back(new mlir::Block());
    TF_RETURN_IF_ERROR(ImportAsLmhloRegion(instr->called_computations()[i],
                                           &case_op.branches()[i]));
  }

  return case_op;
}

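// Lowers an HLO While to lmhlo::WhileOp. The op's operand is the buffer
// holding the condition's root value; called_computations()[1] is imported
// into the cond region and called_computations()[0] into the body region.
// A known trip count from the backend config is attached when available.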
xla::StatusOr<lmhlo::WhileOp> LhloDialectEmitter::EmitWhileOp(
    const xla::HloInstruction* instr) {
  Location loc = getLocation(instr);
  SmallVector<Value, 1> operands;
  TF_RETURN_IF_ERROR(GetOrCreateView(
      instr->called_computations()[1]->root_instruction(), &operands));
  TF_RET_CHECK(operands.size() == 1);

  TF_ASSIGN_OR_RETURN(auto config,
                      instr->backend_config<xla::WhileLoopBackendConfig>());
  mlir::IntegerAttr trip_count;
  if (config.has_known_trip_count()) {
    trip_count = builder_.getI64IntegerAttr(config.known_trip_count().n());
  }
  lmhlo::WhileOp while_op =
      builder_.create<lmhlo::WhileOp>(loc, operands[0], trip_count);

  while_op.cond().push_back(new mlir::Block());
  while_op.body().push_back(new mlir::Block());
  TF_RETURN_IF_ERROR(
      ImportAsLmhloRegion(instr->called_computations()[1], &while_op.cond()));

  TF_RETURN_IF_ERROR(
      ImportAsLmhloRegion(instr->called_computations()[0], &while_op.body()));

  return while_op;
}

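// Returns (and caches) a memref Value for the array at `shape_index` of
// `instr`, built from the instruction's buffer slice. Allocations that lower
// to typed arguments are used directly; otherwise a memref.view (and, if the
// layout requires it, a memref.reinterpret_cast) is created over the raw
// byte buffer.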
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
    const xla::HloInstruction* instr, const xla::Shape& current_shape,
    const xla::ShapeIndex& shape_index) {
  // Cache generated ViewOp and StaticMemRefCastOp by (instruction,
  // shape_index).
  auto& cached_value = slices_[std::make_pair(instr, shape_index)];
  if (cached_value) {
    return cached_value;
  }

  if (instr->IsConstant() && shape_index.empty()) {
    TF_ASSIGN_OR_RETURN(Value constant_memref, EmitConstant(instr));
    return cached_value = constant_memref;
  }

  // If the shape happens to have dynamic dimensions, create the memref using
  // the underlying static shape.
  // TODO(jurahul): Revisit this when we can model memrefs with dynamic shape
  // but static bounds in MLIR.
  const Shape static_shape = xla::ShapeUtil::MakeStaticShape(current_shape);

  TF_ASSIGN_OR_RETURN(Type out_type, xla::ConvertShapeToType<MemRefType>(
                                         static_shape, builder_));
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
                      assignment_.GetUniqueSlice(instr, shape_index));
  Value alloc = allocations_[slice.allocation()];

  // TODO(timshen): revisit location handling.
  Location loc = builder_.getUnknownLoc();

  Value result;
  if (AllocationShouldLowerToTypedArg(slice.allocation())) {
    TF_RET_CHECK(slice.offset() == 0);
    TF_RET_CHECK(slice.size() == slice.allocation()->size());
    result = alloc;
  } else {
    Value byte_shift =
        builder_.create<ConstantIndexOp>(alloc.getLoc(), slice.offset());

    xla::Shape physical_shape =
        xla::ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
            static_shape);
    TF_ASSIGN_OR_RETURN(
        Type physical_out_type,
        xla::ConvertShapeToType<MemRefType>(physical_shape, builder_));

    // ViewOp only takes memrefs without affine maps (layouts). Let ViewOp
    // produce the physical shape (where dimensions are ordered in major to
    // minor) first, then follow up with a MemRefReinterpretCast to cast the
    // resulting memref to the original layout.
    result = builder_.create<memref::ViewOp>(loc, physical_out_type, alloc,
                                             byte_shift,
                                             /*sizes=*/ValueRange{});
  }
  if (result.getType() != out_type) {
    int64_t out_offset;
    SmallVector<int64_t, 4> out_strides;
    auto out_memref_type = out_type.dyn_cast<MemRefType>();
    if (!out_memref_type)
      return tensorflow::errors::Internal(
          "Expected memref type when creating a view for leaf type of a "
          "tuple.");
    if (failed(getStridesAndOffset(out_memref_type, out_strides, out_offset)))
      return tensorflow::errors::Internal(
          "Failed to get strides and offset from the output type.");
    result = builder_.create<memref::ReinterpretCastOp>(
        loc, out_memref_type, result, out_offset, out_memref_type.getShape(),
        out_strides);
  }
  return cached_value = result;
}

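// Recursive helper for GetOrCreateView: walks tuple shapes, emits one array
// view per leaf array shape, and handles tokens according to `token_mode`.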
Status LhloDialectEmitter::GetOrCreateViewImpl(
    const HloInstruction* instr, const Shape& current_shape,
    xla::ShapeIndex* current_shape_index, SmallVectorImpl<Value>* values,
    TokenLoweringMode token_mode) {
  if (current_shape.IsTuple()) {
    for (int i = 0; i < current_shape.tuple_shapes().size(); ++i) {
      current_shape_index->push_back(i);
      TF_RETURN_IF_ERROR(
          GetOrCreateViewImpl(instr, current_shape.tuple_shapes(i),
                              current_shape_index, values, token_mode));
      current_shape_index->pop_back();
    }
    return Status::OK();
  }
  if (current_shape.IsArray()) {
    TF_ASSIGN_OR_RETURN(auto v, GetOrCreateArrayView(instr, current_shape,
                                                     *current_shape_index));
    values->push_back(v);
    return Status::OK();
  }
  if (current_shape.IsToken()) {
    switch (token_mode) {
      case TokenLoweringMode::kFailToLower:
        return xla::InternalError(
            "Unexpected token kind for %s and shape index %s",
            instr->ToString(), current_shape_index->ToString());

      case TokenLoweringMode::kUseNull:
        values->push_back(Value{});
        return Status::OK();
    }
  }
  return xla::InternalError("Unexpected shape kind for %s and shape index %s",
                            instr->ToString(),
                            current_shape_index->ToString());
}

// Returns a view for the result of an instruction.
// We first get a view for the slice in the allocation, and then may need to
// create another view to adjust the slice for the shape of the instruction.
Status LhloDialectEmitter::GetOrCreateView(const HloInstruction* instr,
                                           SmallVectorImpl<Value>* values,
                                           const xla::ShapeIndex& result_subset,
                                           TokenLoweringMode token_mode) {
  xla::ShapeIndex shape_index = result_subset;
  const Shape& sub_shape =
      xla::ShapeUtil::GetSubshape(instr->shape(), shape_index);
  return GetOrCreateViewImpl(instr, sub_shape, &shape_index, values,
                             token_mode);
}

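// Creates the entry FuncOp for the computation: one memref argument per
// non-thread-local buffer allocation, ordered so that entry parameters come
// first, with lmhlo.* argument attributes describing parameters, constants,
// output indices, and aliasing.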
Status LhloDialectEmitter::Initialize() {
  TF_RET_CHECK(computation_.IsEntryComputation());

  mlir::IntegerAttr unique_id =
      builder_.getI32IntegerAttr(computation_.parent()->unique_id());
  module_->setAttr("hlo.unique_id", unique_id);
  std::string function_name =
      computation_.name().empty() ? "__compute" : computation_.name();

  // Create the function as () -> (); we'll compute the arguments from the
  // buffer allocations and update the type afterwards.
  auto func_op = FuncOp::create(builder_.getUnknownLoc(), function_name,
                                builder_.getFunctionType({}, {}));

  {
    // This is an optional attribute used by the XLA backend. If the resulting
    // LMHLO doesn't go through XLA, this is not needed.
    const Shape& shape = computation_.root_instruction()->shape();
    func_op->setAttr(
        "result_xla_shape",
        builder_.getStringAttr(shape.ToString(/*print_layout=*/true)));
  }
  Block* block = func_op.addEntryBlock();

  llvm::SmallVector<const BufferAllocation*, 8> ordered_allocations;
  for (const BufferAllocation& alloc : assignment_.Allocations())
    ordered_allocations.push_back(&alloc);

  if (computation_.IsEntryComputation()) {
    // Sort the rather arbitrarily ordered allocations to match the
    // input/output parameters. Specifically, we want to sort buffer
    // allocations in the following order:
    // * Parameters always order before non-parameters.
    // * Different parameters order by parameter number.
    // * Different allocations for the same parameter order by the shape index.
    //
    // TODO(timshen): there should be only one non-parameter buffer, the temp
    // buffer. Check on that.
    const auto allocation_comparator = [](const BufferAllocation* lhs,
                                          const BufferAllocation* rhs) {
      if (lhs->is_entry_computation_parameter() !=
          rhs->is_entry_computation_parameter()) {
        return lhs->is_entry_computation_parameter() >
               rhs->is_entry_computation_parameter();
      }
      if (lhs->is_entry_computation_parameter()) {
        return std::tuple<int, const xla::ShapeIndex&>(
                   lhs->parameter_number(), lhs->param_shape_index()) <
               std::tuple<int, const xla::ShapeIndex&>(
                   rhs->parameter_number(), rhs->param_shape_index());
      }
      return false;
    };

    std::stable_sort(ordered_allocations.begin(), ordered_allocations.end(),
                     allocation_comparator);
  }

  absl::flat_hash_map<const BufferAllocation*,
                      std::pair<const Shape*, xla::ShapeIndex>>
      allocation_to_output_info;
  TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
      computation_.root_instruction()->shape(),
      [&](const Shape& sub_shape, xla::ShapeIndex index) -> Status {
        TF_ASSIGN_OR_RETURN(
            auto slice,
            assignment_.GetUniqueSlice(computation_.root_instruction(), index));
        const BufferAllocation* alloc = slice.allocation();
        TF_RET_CHECK(slice.offset() == 0);
        TF_RET_CHECK(slice.size() == alloc->size());
        allocation_to_output_info[alloc] = std::make_pair(&sub_shape, index);
        return Status::OK();
      }));

  // The function signature will be composed of:
  // - one memref for each of the parameters.
  // - one memref for each other buffer allocation.
  llvm::SmallVector<DictionaryAttr, 8> args_attrs;
  for (const BufferAllocation* alloc : ordered_allocations) {
    if (alloc->is_thread_local()) {
      continue;
    }

    // There are optional attributes to help the program run through XLA. XLA
    // defines ExecutionInput and ExecutionOutput structures to carry
    // input-output type and buffer information, so any information they need
    // (mainly the type structure, potentially containing tuples) has to be
    // preserved. They are not needed if the generated LMHLO is not sent to
    // XLA.
    NamedAttrList arg_attr_list;
    mlir::Type arg_type;
    if (AllocationShouldLowerToTypedArg(alloc)) {
      xla::Shape buffer_shape = xla::ShapeUtil::GetSubshape(
          computation_.parameter_instruction(alloc->parameter_number())
              ->shape(),
          alloc->param_shape_index());

      if (buffer_shape.IsTuple()) {
        arg_type = MemRefType::get({alloc->size()}, i8_type_);
      } else {
        // TODO(jurahul): Revisit this when we can model memrefs with dynamic
        // shape but static bounds in MLIR.
        const Shape static_shape =
            xla::ShapeUtil::MakeStaticShape(buffer_shape);
        TF_ASSIGN_OR_RETURN(arg_type, xla::ConvertShapeToType<MemRefType>(
                                          static_shape, builder_));
      }
    } else {
      arg_type = MemRefType::get({alloc->size()}, i8_type_);
    }

    if (alloc->is_entry_computation_parameter()) {
      arg_attr_list.set("lmhlo.params",
                        builder_.getIndexAttr(alloc->parameter_number()));
      if (!alloc->param_shape_index().empty()) {
        arg_attr_list.set("lmhlo.param_shape_index",
                          builder_.getI64TensorAttr(llvm::makeArrayRef(
                              alloc->param_shape_index().begin(),
                              alloc->param_shape_index().end())));
      }
    }
    // Optional: an attribute for optimization. If a kernel uses this
    // allocation, but the allocation has lmhlo.constant_name, then the kernel
    // will instead use the global value indicated by the name for potentially
    // more optimizations (e.g. constant propagation).
    if (alloc->is_constant()) {
      arg_attr_list.set(
          "lmhlo.constant_name",
          builder_.getStringAttr(
              xla::llvm_ir::ConstantBufferAllocationToGlobalName(*alloc)));
    }
    auto iter = allocation_to_output_info.find(alloc);
    if (iter != allocation_to_output_info.end()) {
      const Shape* sub_shape = iter->second.first;
      const xla::ShapeIndex& shape_index = iter->second.second;
      if (!sub_shape->IsArray()) {
        continue;
      }
      arg_attr_list.set("lmhlo.output_index",
                        builder_.getI64TensorAttr(llvm::makeArrayRef(
                            shape_index.begin(), shape_index.end())));
      if (auto alias = computation_.parent()
                           ->input_output_alias_config()
                           .GetAliasedParameter(shape_index)) {
        if (alias->must_alias()) {
          arg_attr_list.set("lmhlo.must_alias", builder_.getUnitAttr());
        }
      }
    }
    block->addArgument(arg_type);
    allocations_[alloc] = block->getArguments().back();
    args_attrs.push_back(arg_attr_list.getDictionary(builder_.getContext()));
  }

  FunctionType function_type =
      builder_.getFunctionType(block->getArgumentTypes(), {});
  func_op.setType(function_type);
  func_op.setAllArgAttrs(args_attrs);

  SymbolTable symbol_table(module_);
  symbol_table.insert(func_op);
  builder_.setInsertionPointToEnd(block);

  auto return_op =
      builder_.create<lmhlo::TerminatorOp>(builder_.getUnknownLoc());
  builder_ = OpBuilder(return_op);

  return Status::OK();
}

std::unique_ptr<OperationPass<ModuleOp>> createXlaHloToLhloWithXlaPass() {
  return std::make_unique<XlaHloToLhloPass>();
}

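// Converts the entry computation of `hlo_module` into LMHLO inside `module`,
// using `assignment` for buffer information, and verifies the result.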
Status HloToLhloModule(const BufferAssignment& assignment,
                       const HloModule& hlo_module, ModuleOp module) {
  module.getContext()
      ->loadDialect<StandardOpsDialect, memref::MemRefDialect,
                    mhlo::MhloDialect, lmhlo::LmhloDialect,
                    lmhlo_gpu::LmhloGpuDialect>();

  module->setLoc(mlir::NameLoc::get(
      mlir::Identifier::get(hlo_module.name(), module.getContext())));

  const HloComputation* computation = hlo_module.entry_computation();

  LhloDialectEmitter emitter(assignment, *computation, module);
  TF_RETURN_IF_ERROR(emitter.Initialize());

  const xla::HloInstructionSequence* schedule =
      assignment.hlo_ordering().SequentialOrder(*computation);
  if (!schedule)
    return xla::Unimplemented("Missing sequential order for the computation");
  const std::vector<HloInstruction*>& ordering = schedule->instructions();
  TF_RETURN_IF_ERROR(computation->AcceptOrdered(&emitter, ordering));
  TF_RET_CHECK(succeeded(mlir::verify(module)));
  return Status::OK();
}

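// Translation entry point for HLO text: parses the module and converts it to
// LMHLO via OptimizeAndConvertHloToLmhlo for the "Host" platform.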
OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input,
                                               MLIRContext* context) {
  StatusOr<std::unique_ptr<HloModule>> maybe_module =
      xla::ParseAndReturnUnverifiedModule(
          absl::string_view(input.data(), input.size()));
  TF_CHECK_OK(maybe_module.status());

  OwningModuleRef module = ModuleOp::create(UnknownLoc::get(context));

  TF_CHECK_OK(OptimizeAndConvertHloToLmhlo(maybe_module.ConsumeValueOrDie(),
                                           module.get(), "Host"));

  return module;
}

static PassRegistration<XlaHloToLhloPass> registration;

}  // namespace mlir