/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"

#include <climits>
#include <memory>
#include <tuple>

#include "absl/algorithm/container.h"
#include "absl/types/optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/AffineExpr.h"  // from @llvm-project
#include "mlir/IR/AffineMap.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Verifier.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassOptions.h"  // from @llvm-project
#include "mlir/Translation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

using xla::BufferAllocation;
using xla::BufferAssignment;
using xla::HloComputation;
using xla::HloCustomCallInstruction;
using xla::HloInfeedInstruction;
using xla::HloInstruction;
using xla::HloModule;
using xla::HloModuleProto;
using xla::HloOutfeedInstruction;
using xla::HloProto;
using xla::Shape;
using xla::StatusOr;

namespace mlir {
namespace {

absl::string_view StringRefToView(llvm::StringRef ref) {
  return {ref.data(), ref.size()};
}

StatusOr<std::unique_ptr<HloModule>> HloModuleFromProto(
    const HloProto& hlo_proto) {
  const HloModuleProto& module_proto = hlo_proto.hlo_module();
  TF_ASSIGN_OR_RETURN(const xla::HloModuleConfig module_config,
                      HloModule::CreateModuleConfigFromProto(
                          module_proto, xla::GetDebugOptionsFromFlags()));
  return HloModule::CreateFromProto(module_proto, module_config);
}

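// Returns true for allocations that back entry-computation parameters and are
// never live-out. Only those allocations are lowered to typed (shaped)
// arguments, as the name suggests; their shape is fully determined by the
// entry computation's signature.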
bool AllocationShouldLowerToTypedArg(const BufferAllocation* alloc) {
  return alloc->is_entry_computation_parameter() && !alloc->maybe_live_out();
}

}  // namespace

// Convert the MLIR `module` from HLO dialect to LHLO dialect using XLA for the
// given platform.
Status OptimizeAndConvertHloToLmhlo(std::unique_ptr<HloModule> hlo_module,
                                    ModuleOp module, StringRef platform_name) {
  auto platform = xla::se::MultiPlatformManager::PlatformWithName(
      StringRefToView(platform_name));
  if (!platform.ok()) {
    std::string error_msg;
    llvm::raw_string_ostream os(error_msg);
    os << "failed to get platform: " << platform.status().ToString()
       << " (available Platform: ";
    std::vector<std::string> available_platforms;
    (void)xla::se::MultiPlatformManager::PlatformsWithFilter(
        [&](const stream_executor::Platform* p) {
          available_platforms.push_back(p->Name());
          return false;
        });
    llvm::interleaveComma(available_platforms, os);
    os << ")";
    return xla::InvalidArgument("%s", os.str().c_str());
  }

  xla::BackendOptions backend_options;
  backend_options.set_platform(platform.ValueOrDie());
  auto backend_or_err = xla::Backend::CreateBackend(backend_options);
  TF_RETURN_WITH_CONTEXT_IF_ERROR(backend_or_err.status(),
                                  "failed to create XLA Backend ");
  auto backend = std::move(backend_or_err.ValueOrDie());

  // Run all HLO passes to produce an optimized module.
  auto result_or = backend->compiler()->RunHloPassesAndBufferAssignement(
      std::move(hlo_module), backend->default_stream_executor(),
      optimize_xla_hlo, {backend->memory_allocator()});
  TF_RETURN_WITH_CONTEXT_IF_ERROR(result_or.status(),
                                  "running XLA pass pipeline");
  std::unique_ptr<HloModule> optimized_hlo_module =
      std::move(std::get<0>(result_or.ValueOrDie()));
  std::unique_ptr<BufferAssignment> assignment =
      std::move(std::get<1>(result_or.ValueOrDie()));

  // Clear the module before populating it back with the result of the
  // conversion.
  module.getBody()->clear();
  OpBuilder builder(module);

  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      HloToLhloModule(*assignment, *optimized_hlo_module, module),
      "converting HLO to LHLO");

  return Status::OK();
}

namespace {
// This pass takes an MLIR HLO module, converts it to XLA to perform the HLO
// optimization pipeline for the required platform, and then converts it back to
// MLIR LHLO.
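//
// Note: once registered with the MLIR pass registry, a pass like this can
// typically be driven from an opt-style tool via the argument and option
// declared below, e.g. something like
//   <opt-tool> --xla-hlo-to-lhlo-with-xla="platform=CUDA" input.mlir
// The exact driver binary depends on the build; "xla-hlo-to-lhlo-with-xla"
// and "platform" come from getArgument() and the Option member below.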
class XlaHloToLhloPass
    : public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry
        .insert<StandardOpsDialect, memref::MemRefDialect, mhlo::MhloDialect,
                lmhlo::LmhloDialect, lmhlo_gpu::LmhloGpuDialect>();
  }

 public:
  XlaHloToLhloPass() = default;
  XlaHloToLhloPass(const XlaHloToLhloPass&) {}
  StringRef getArgument() const final { return "xla-hlo-to-lhlo-with-xla"; }
  StringRef getDescription() const final {
    return "Emit LHLO from HLO using the existing XLA implementation";
  }

 private:
  void runOnOperation() final {
    ModuleOp module = getOperation();

    auto status = [&module, this]() -> Status {
      SymbolTable symbol_table(module);
      if (!symbol_table.lookup("main")) {
        return xla::InvalidArgument(
            "conversion to HLO module failed: missing main()");
      }
      HloProto hlo_proto;
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          ConvertMlirHloToHlo(module, &hlo_proto,
                              /*use_tuple_args=*/false,
                              /*return_tuple=*/false,
                              /*shape_representation_fn=*/nullptr),
          "conversion to XLA HLO proto failed");

      auto statusOrHloModule = HloModuleFromProto(hlo_proto);
      TF_RETURN_WITH_CONTEXT_IF_ERROR(statusOrHloModule.status(),
                                      "parsing HLO proto to HLO module failed");
      std::unique_ptr<HloModule> hlo_module =
          std::move(statusOrHloModule.ValueOrDie());

      return OptimizeAndConvertHloToLmhlo(std::move(hlo_module), module,
                                          platform_);
    }();
    if (!status.ok()) {
      module.emitError() << status.ToString();
      return signalPassFailure();
    }
  }

  Option<std::string> platform_{
      *this, "platform",
      llvm::cl::desc("The platform to use for the XLA optimization pipeline."),
      llvm::cl::init("Host")};
};

}  // namespace

// Creates MLIR operands corresponding to operands and results of the XLA HLO
// instruction. If `num_operands` is valid, then only the first `num_operands`
// operands of the HLO instruction will be considered.
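// On success, `operands` holds the buffer views for the instruction's operands
// followed by the views for its results; `num_arguments` and `num_results`
// record how many entries belong to each group.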
Status LhloDialectEmitter::CreateOperands(
    const HloInstruction* instr, absl::optional<xla::int64> num_operands,
    TokenLoweringMode token_mode, llvm::SmallVectorImpl<Value>& operands,
    size_t& num_arguments, size_t& num_results) {
  if (num_operands.value_or(0) > instr->operand_count())
    return xla::InvalidArgument("num_operands must be <= operand count");
  for (xla::int64 i = 0; i < num_operands.value_or(instr->operand_count());
       ++i) {
    TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands,
                                       /*result_subset=*/{}, token_mode));
  }
  num_arguments = operands.size();
  TF_RETURN_IF_ERROR(
      GetOrCreateView(instr, &operands, /*result_subset=*/{}, token_mode));
  num_results = operands.size() - num_arguments;
  return Status::OK();
}

template <typename OpType>
OpType LhloDialectEmitter::CreateOpWithoutAttrs(const HloInstruction* instr,
                                                ValueRange operands) {
  Location loc = getLocation(instr);
  return builder_.create<OpType>(loc, llvm::None, operands,
                                 llvm::ArrayRef<NamedAttribute>{});
}

template <typename OpType>
StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
    const HloInstruction* instr, size_t& num_arguments, size_t& num_results,
    absl::optional<xla::int64> num_operands) {
  llvm::SmallVector<Value, 4> operands;
  TF_RETURN_IF_ERROR(CreateOperands(instr, num_operands,
                                    TokenLoweringMode::kFailToLower, operands,
                                    num_arguments, num_results));
  return CreateOpWithoutAttrs<OpType>(instr, operands);
}

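// Emits `instr` inside a freshly created lmhlo.fusion region: each buffer
// operand is wrapped in a memref.tensor_load, the instruction is imported in
// tensor form, and every result is written back to its result buffer with a
// memref.tensor_store.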
StatusOr<mlir::Operation*> LhloDialectEmitter::CreateOpInFusion(
    const HloInstruction* instr, ValueRange buffer_operands,
    size_t num_arguments, size_t num_results) {
  Location loc = getLocation(instr);
  std::vector<Value> buffers(buffer_operands.begin(), buffer_operands.end());
  absl::Span<Value> arguments =
      absl::MakeSpan(buffers).subspan(0, num_arguments);
  absl::Span<Value> results =
      absl::MakeSpan(buffers).subspan(num_arguments, num_results);

  mlir::lmhlo::FusionOp fusion = builder_.create<mlir::lmhlo::FusionOp>(loc);
  mlir::OpBuilder b(&fusion.region());

  llvm::SmallVector<mlir::Value, 4> loads;
  for (Value arg : arguments) {
    auto load = b.create<mlir::memref::TensorLoadOp>(loc, arg);
    Shape shape = xla::TypeToShape(arg.getType());
    TF_RET_CHECK(shape.IsArray());
    if (shape.layout() !=
        xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) {
      load->setAttr("minor_to_major", GetLayoutAttribute(shape.layout(), &b));
    }
    loads.push_back(load);
  }
  mlir::Operation* op = nullptr;
  if (instr->opcode() == xla::HloOpcode::kReduce) {
    TF_RET_CHECK(loads.size() % 2 == 0);
    std::vector<int64_t> dimensions(instr->dimensions().begin(),
                                    instr->dimensions().end());
    auto reduce_op = b.create<mhlo::ReduceOp>(
        loc, llvm::makeArrayRef(loads).take_front(loads.size() / 2),
        llvm::makeArrayRef(loads).drop_front(loads.size() / 2),
        GetI64DenseElementsAttr(dimensions));

    TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
        *instr->called_computations()[0], &reduce_op.body(), &builder_));
    op = reduce_op;
  } else {
    TF_ASSIGN_OR_RETURN(
        op, xla::HloFunctionImporter::ImportInstruction(instr, loads, &b));
  }
  TF_RET_CHECK(op->getNumResults() == num_results);
  for (int i = 0; i < results.size(); i++) {
    b.create<mlir::memref::TensorStoreOp>(loc, op->getResult(i), results[i]);
  }
  return op;
}

StatusOr<mlir::Operation*> LhloDialectEmitter::CreateOpInFusion(
    const HloInstruction* instr) {
  llvm::SmallVector<Value, 4> operands;
  size_t num_arguments, num_results;
  TF_RETURN_IF_ERROR(CreateOperands(instr, absl::nullopt,
                                    TokenLoweringMode::kFailToLower, operands,
                                    num_arguments, num_results));
  TF_ASSIGN_OR_RETURN(
      auto op, CreateOpInFusion(instr, operands, num_arguments, num_results));
  return op->getParentOp();
}

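// Dispatches on the HLO opcode: ops with dedicated LMHLO (or LMHLO GPU)
// lowerings get their own emitters, purely structural ops (tuple,
// get-tuple-element, after-all, add-dependency) lower to nothing, and the
// remaining element-wise/loop-fusible ops are wrapped in a single-op
// lmhlo.fusion via CreateOpInFusion.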
StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
    const HloInstruction* instr) {
  using xla::HloOpcode;
  switch (instr->opcode()) {
    case HloOpcode::kAddDependency:
      return nullptr;
    case HloOpcode::kAfterAll:
      // LMHLO is already ordered. This assumption may be broken after
      // introducing async regions and partial orders.
      return nullptr;
    case HloOpcode::kAllToAll:
      return EmitAllToAllOp(instr);
    case HloOpcode::kAllGather:
      return EmitAllGatherOp(instr);
    case HloOpcode::kAllReduce:
      return EmitAllReduceOp(instr);
    case HloOpcode::kAllReduceStart:
      return EmitAllReduceStartOp(instr);
    case HloOpcode::kAllReduceDone:
      return EmitAllReduceDoneOp(instr);
    case HloOpcode::kReduceScatter:
      return EmitReduceScatterOp(instr);
    case HloOpcode::kBitcast:
      return EmitBitcast(instr);
    case HloOpcode::kCollectivePermute:
      return EmitCollectivePermuteOp(instr);
    case HloOpcode::kConditional:
      return EmitCaseOp(instr);
    case HloOpcode::kFft:
      return EmitFftOp(instr);
    case HloOpcode::kGetTupleElement:
      return nullptr;
    case HloOpcode::kInfeed:
      return EmitInfeedOp(instr);
    case HloOpcode::kOutfeed:
      return EmitOutfeedOp(instr);
    case HloOpcode::kPartitionId:
      return CreateOpWithoutAttrs<lmhlo::PartitionIdOp>(instr);
    case HloOpcode::kReplicaId:
      return CreateOpWithoutAttrs<lmhlo::ReplicaIdOp>(instr);
    case HloOpcode::kTriangularSolve:
      return EmitTriangularSolveOp(instr);
    case HloOpcode::kTuple:
      return nullptr;
    case HloOpcode::kSort:
      return EmitSortOp(instr);
    case HloOpcode::kFusion:
      return EmitFusionOp(instr);
    case HloOpcode::kScatter:
      return EmitScatterOp(instr);
    case HloOpcode::kSelectAndScatter:
      return EmitSelectAndScatterOp(instr);
    case HloOpcode::kCustomCall:
      return EmitCustomCallOp(instr);
    case HloOpcode::kConstant:
      return EmitConstant(instr);
    case HloOpcode::kRngGetAndUpdateState:
      return EmitRngGetAndUpdateStateOp(instr);
    case HloOpcode::kWhile:
      return EmitWhileOp(instr);

    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kBroadcast:
    case HloOpcode::kCeil:
    case HloOpcode::kCbrt:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConvert:
    case HloOpcode::kCos:
    case HloOpcode::kDivide:
    case HloOpcode::kDot:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kGather:
    case HloOpcode::kImag:
    case HloOpcode::kIota:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kMap:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kPad:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kReshape:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kRemainder:
    case HloOpcode::kReverse:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSelect:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSlice:
    case HloOpcode::kSqrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kTranspose:
    case HloOpcode::kXor:
    case HloOpcode::kCopy:
    case HloOpcode::kReduce:
      return CreateOpInFusion(instr);
    default:
      llvm::errs() << instr->ToString();
      return tensorflow::errors::Internal(
          absl::StrCat("LHLO opcode ", xla::HloOpcodeString(instr->opcode()),
                       " is not supported."));
  }
}

Status LhloDialectEmitter::DefaultAction(const HloInstruction* instr) {
  return EmitOp(instr).status();
}

StatusOr<lmhlo::SortOp> LhloDialectEmitter::EmitSortOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs<lmhlo::SortOp>(instr));
  auto* sort_instr = xla::Cast<xla::HloSortInstruction>(instr);
  sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension()));
  sort.is_stableAttr(builder_.getBoolAttr(sort_instr->is_stable()));
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      *sort_instr->called_computations()[0], &sort.comparator(), &builder_));
  return sort;
}

// Walks MHLO::TupleOp recursively.
Status WalkTuplePostOrder(Value v,
                          const std::function<Status(Value)>& visitor) {
  if (auto* op = v.getDefiningOp()) {
    if (auto tuple = dyn_cast<mhlo::TupleOp>(op)) {
      for (Value sub_v : tuple.val()) {
        TF_RETURN_IF_ERROR(WalkTuplePostOrder(sub_v, visitor));
      }
      return Status::OK();
    }
  }
  return visitor(v);
}

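// Rewrites one fusion operand into tensor form: tuple shapes are rebuilt
// recursively with mhlo.tuple, array shapes become a memref view wrapped in a
// memref.tensor_load, and non-default (non-descending) layouts are recorded on
// the load via a "minor_to_major" attribute.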
StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
    const HloInstruction* root, const Shape& shape,
    xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) {
  if (shape.IsTuple()) {
    llvm::SmallVector<Value, 4> values;
    for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
      shape_index->push_back(i);
      TF_ASSIGN_OR_RETURN(
          auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index,
                                       b, loc));
      values.push_back(v);
      shape_index->pop_back();
    }
    return Value(b->create<mhlo::TupleOp>(loc, values));
  }
  TF_ASSIGN_OR_RETURN(Value memref,
                      GetOrCreateArrayView(root, shape, *shape_index));
  auto load = b->create<memref::TensorLoadOp>(loc, memref);
  if (shape.layout() !=
      xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) {
    llvm::SmallVector<int64_t, 4> minor_to_major(
        shape.layout().minor_to_major().begin(),
        shape.layout().minor_to_major().end());
    load->setAttr("minor_to_major", GetLayoutAttribute(shape.layout(), b));
  }
  return load.getResult();
}

// Emit a lmhlo.fusion based on XLA HLO fusion. Structurally they are not neatly
// equivalent. Specifically, XLA HLO fusion:
//     fused_computation {
//       %p0 = parameter(0)
//       %p1 = parameter(1)
//       ...
//       ROOT %ret = ...
//     }
// will be converted to
//     lmhlo.fusion() {  // no explicit operands
//       // capturing outside buffers
//       %p0 = tensor_load(%arg0) : memref<...> -> tensor<...>
//       %p1 = tensor_load(%arg1) : memref<...> -> tensor<...>
//       ...
//       tensor_store ..., %ret // store a tensor to a memref
//     }
StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp(
    const HloInstruction* instr) {
  Location loc = getLocation(instr);

  auto* fusion_instr = xla::Cast<xla::HloFusionInstruction>(instr);

  auto fusion = builder_.create<lmhlo::FusionOp>(getLocation(instr));
  auto after_fusion = builder_.saveInsertionPoint();
  auto reverter = xla::MakeCleanup(
      [this, after_fusion] { builder_.restoreInsertionPoint(after_fusion); });
  builder_ = mlir::OpBuilder(fusion);

  auto region_builder = OpBuilder::atBlockBegin(&fusion.region().front());

  llvm::SmallVector<Value, 8> arguments;
  for (int i = 0; i < instr->operands().size(); ++i) {
    const HloInstruction* operand = instr->operand(i);
    xla::ShapeIndex shape_index;
    TF_ASSIGN_OR_RETURN(
        auto arg, RewriteFusionOperand(operand, operand->shape(), &shape_index,
                                       &region_builder, loc));
    arguments.push_back(arg);
  }

  TF_ASSIGN_OR_RETURN(Value result,
                      xla::HloFunctionImporter::ImportInstructions(
                          *fusion_instr->fused_instructions_computation(),
                          arguments, &region_builder));
  {
    int i = 0;
    llvm::SmallVector<Value, 4> output;
    TF_RETURN_IF_ERROR(GetOrCreateView(instr, &output));
    TF_RETURN_IF_ERROR(WalkTuplePostOrder(result, [&](Value v) mutable {
      region_builder.create<memref::TensorStoreOp>(loc, v, output[i++]);
      return Status::OK();
    }));
    if (i != output.size()) {
      return xla::InternalError("output sizes don't match");
    }
  }

  // Fold GTE/Tuple pairs.
  //
  // Since the fused region refers to values in its parent region, we can't
  // call applyPatternAndFoldGreedily. We optimize it manually.
  //
  // Only walk once, because post-ordering is exactly what we need for GTE
  // optimizations.
  fusion.region().walk([](mhlo::GetTupleElementOp gte) {
    SmallVector<Value, 4> folded_values;
    if (succeeded(OpBuilder(gte).tryFold(gte, folded_values))) {
      gte.replaceAllUsesWith(folded_values[0]);
    }
  });

  // Effectively a DCE on the region.
  {
    llvm::SmallVector<mlir::Operation*, 4> ops;
    fusion.region().walk([&](mlir::Operation* op) { ops.push_back(op); });
    // Visit the user first.
    std::reverse(ops.begin(), ops.end());
    for (auto op : ops) {
      if (isOpTriviallyDead(op)) op->erase();
    }
  }

  return fusion;
}

StatusOr<mhlo::ScatterDimensionNumbers>
LhloDialectEmitter::GetScatterDimensionNumbers(const HloInstruction* instr,
                                               mlir::MLIRContext* context) {
  auto* scatter_instr = xla::Cast<xla::HloScatterInstruction>(instr);

  const xla::ScatterDimensionNumbers& xla_scatter_dim =
      scatter_instr->scatter_dimension_numbers();

  mlir::Builder builder(context);
  auto get_i64_array_attr =
      [builder](absl::Span<const xla::int64> container) mutable {
        return builder.getI64TensorAttr(
            {container.data(), static_cast<size_t>(container.size())});
      };
  auto scatter_dimension_numbers = mhlo::ScatterDimensionNumbers::get(
      get_i64_array_attr(xla_scatter_dim.update_window_dims()),
      get_i64_array_attr(xla_scatter_dim.inserted_window_dims()),
      get_i64_array_attr(xla_scatter_dim.scatter_dims_to_operand_dims()),
      builder.getI64IntegerAttr(xla_scatter_dim.index_vector_dim()), context);
  return scatter_dimension_numbers;
}

StatusOr<lmhlo::ScatterOp> LhloDialectEmitter::EmitScatterOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto scatter,
                      CreateOpWithoutAttrs<lmhlo::ScatterOp>(instr));

  // copy attributes
  auto* scatter_instr = xla::Cast<xla::HloScatterInstruction>(instr);

  TF_ASSIGN_OR_RETURN(auto scatter_dimension_numbers,
                      GetScatterDimensionNumbers(instr, builder_.getContext()));
  scatter.scatter_dimension_numbersAttr(scatter_dimension_numbers);
  scatter.indices_are_sortedAttr(
      builder_.getBoolAttr(scatter_instr->indices_are_sorted()));
  scatter.unique_indicesAttr(
      builder_.getBoolAttr(scatter_instr->unique_indices()));

  // import update computation as region
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      *scatter_instr->called_computations()[0], &scatter.update_computation(),
      &builder_));

  return scatter;
}

StatusOr<lmhlo::SelectAndScatterOp> LhloDialectEmitter::EmitSelectAndScatterOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto select_and_scatter,
                      CreateOpWithoutAttrs<lmhlo::SelectAndScatterOp>(instr));

  // copy attributes
  auto* select_and_scatter_instr =
      xla::Cast<xla::HloSelectAndScatterInstruction>(instr);
  const xla::Window& window = select_and_scatter_instr->window();

  if (xla::window_util::HasDilation(window)) {
    return xla::Unimplemented("Dilation for SelectAndScatter is not supported");
  }

  select_and_scatter.window_dimensionsAttr(
      GetWindowElements(window, [](const xla::WindowDimension& dim) {
        return static_cast<int64_t>(dim.size());
      }));
  select_and_scatter.window_stridesAttr(
      GetWindowElements(window, [](const xla::WindowDimension& dim) {
        return static_cast<int64_t>(dim.stride());
      }));
  select_and_scatter.paddingAttr(
      GetWindowElements(window, [](const xla::WindowDimension& dim) {
        return static_cast<int64_t>(dim.padding_low());
      }));

  // import select and scatter computation as region
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      *select_and_scatter_instr->select(), &select_and_scatter.select(),
      &builder_));
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      *select_and_scatter_instr->scatter(), &select_and_scatter.scatter(),
      &builder_));
  return select_and_scatter;
}

StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp(
    const HloInstruction* instr) {
  auto* custom_call_instr = xla::Cast<xla::HloCustomCallInstruction>(instr);

  if (xla::gpu::IsCustomCallToCusolver(*instr)) {
    return EmitCholesky(custom_call_instr);
  }

  if (xla::gpu::IsCublasGemm(*instr)) {
    return EmitGemm(custom_call_instr);
  }

  if (xla::gpu::IsCustomCallToDnnConvolution(*instr)) {
    return EmitDnnConvolution(custom_call_instr);
  }

  if (xla::gpu::IsCustomCallToDnnBatchNorm(*instr)) {
    return EmitDnnBatchNorm(custom_call_instr);
  }

  // For custom call, if there are any token operands or results, they will not
  // be represented in LHLO so we need to remember the mapping. First create
  // operands where each token is replaced with a null Value.
  llvm::SmallVector<Value, 4> operands;
  size_t num_arguments, num_results;
  TF_RETURN_IF_ERROR(CreateOperands(instr, /*num_operands=*/absl::nullopt,
                                    TokenLoweringMode::kUseNull, operands,
                                    num_arguments, num_results));

  // Now check if any of the operands is Null, which would indicate the presence
  // of a token in the input or output.
  bool has_token = llvm::any_of(operands, [](Value v) { return !v; });

  lmhlo::CustomCallTargetArgMapping target_mapping;
  if (has_token) {
    // If there was a token, squeeze all the non-token arguments and results
    // (in-place) and remember the mapping.
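    // For example, if the HLO custom call takes (buffer0, token, buffer1), the
    // lowered op keeps only (buffer0, buffer1) and arg_to_target_arg_mapping
    // becomes [0, 2], i.e. the original operand indices of the kept values.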
    int next_index = 0;
    llvm::SmallVector<int64_t> arg_to_target_arg_mapping;
    for (int i = 0; i < num_arguments; ++i) {
      if (operands[i]) {
        arg_to_target_arg_mapping.push_back(i);
        operands[next_index++] = operands[i];
      }
    }
    // Size of arg_to_target_arg_mapping is the number of arguments in LHLO.
    llvm::SmallVector<int64_t> result_to_target_result_mapping;
    for (int i = num_arguments; i < operands.size(); ++i) {
      if (operands[i]) {
        result_to_target_result_mapping.push_back(i - num_arguments);
        operands[next_index++] = operands[i];
      }
    }

    // Build the mapping attribute.
    target_mapping = lmhlo::CustomCallTargetArgMapping::get(
        builder_.getI64IntegerAttr(num_arguments),
        builder_.getI64IntegerAttr(num_results),
        builder_.getI64ArrayAttr(arg_to_target_arg_mapping),
        builder_.getI64ArrayAttr(result_to_target_result_mapping),
        builder_.getContext());

    // Drop the remaining operands and adjust num_arguments and num_results
    // for LMHLO creation.
    operands.resize(next_index);
    num_arguments = arg_to_target_arg_mapping.size();
    num_results = result_to_target_result_mapping.size();
  }

  auto custom_call = CreateOpWithoutAttrs<lmhlo::CustomCallOp>(instr, operands);
  TF_ASSIGN_OR_RETURN(
      auto mlir_api_version,
      ConvertCustomCallApiVersion(custom_call_instr->api_version()));
  custom_call.call_target_nameAttr(
      builder_.getStringAttr(custom_call_instr->custom_call_target()));
  custom_call.backend_configAttr(
      builder_.getStringAttr(custom_call_instr->opaque()));
  custom_call.api_versionAttr(mhlo::CustomCallApiVersionAttr::get(
      builder_.getContext(), mlir_api_version));
  const int32_t segments[2] = {static_cast<int32_t>(num_arguments),
                               static_cast<int32_t>(num_results)};
  custom_call->setAttr(lmhlo::CustomCallOp::getOperandSegmentSizeAttr(),
                       builder_.getI32VectorAttr(segments));
  if (target_mapping) custom_call.target_arg_mappingAttr(target_mapping);
  return custom_call.getOperation();
}

StatusOr<lmhlo_gpu::CholeskyOp> LhloDialectEmitter::EmitCholesky(
    const HloCustomCallInstruction* custom_call) {
  TF_ASSIGN_OR_RETURN(auto cholesky_op,
                      CreateOpWithoutAttrs<lmhlo_gpu::CholeskyOp>(custom_call));
  TF_ASSIGN_OR_RETURN(xla::CholeskyOptions options,
                      custom_call->backend_config<xla::CholeskyOptions>());
  cholesky_op.is_lowerAttr(builder_.getBoolAttr(options.lower()));
  return cholesky_op;
}

StatusOr<Operation*> LhloDialectEmitter::EmitGemm(
    const HloCustomCallInstruction* custom_call) {
  TF_ASSIGN_OR_RETURN(
      auto const config,
      custom_call->backend_config<xla::gpu::GemmBackendConfig>());

  auto set_common_attributes = [&](auto op) -> Operation* {
    auto hlo_dims = config.dot_dimension_numbers();
    auto mlir_dims = mhlo::DotDimensionNumbers::get(
        GetI64DenseElementsAttr(hlo_dims.lhs_batch_dimensions()),
        GetI64DenseElementsAttr(hlo_dims.rhs_batch_dimensions()),
        GetI64DenseElementsAttr(hlo_dims.lhs_contracting_dimensions()),
        GetI64DenseElementsAttr(hlo_dims.rhs_contracting_dimensions()),
        builder_.getContext());
    op.dot_dimension_numbersAttr(mlir_dims);
    op.alpha_realAttr(builder_.getF64FloatAttr(config.alpha_real()));
    op.alpha_imagAttr(builder_.getF64FloatAttr(config.alpha_imag()));
    op.batch_sizeAttr(builder_.getI64IntegerAttr(config.batch_size()));
    op.lhs_strideAttr(builder_.getI64IntegerAttr(config.lhs_stride()));
    op.rhs_strideAttr(builder_.getI64IntegerAttr(config.rhs_stride()));
    if (config.algorithm_case() ==
        xla::gpu::GemmBackendConfig::kSelectedAlgorithm) {
      op.algorithmAttr(builder_.getI64IntegerAttr(config.selected_algorithm()));
    }
    return op.getOperation();
  };

  if (custom_call->operand_count() == 2) {
    TF_ASSIGN_OR_RETURN(auto gemm,
                        CreateOpWithoutAttrs<lmhlo_gpu::GEMMOp>(custom_call));
    return set_common_attributes(gemm);
  }

  if (custom_call->operand_count() == 3) {
    TF_ASSIGN_OR_RETURN(
        auto gemm_bias,
        CreateOpWithoutAttrs<lmhlo_gpu::GEMM_BiasOp>(custom_call));
    gemm_bias.betaAttr(builder_.getF64FloatAttr(config.beta()));
    return set_common_attributes(gemm_bias);
  }

  return xla::InvalidArgument("GEMM custom call should have 2 or 3 operands");
}

static StatusOr<mlir::lmhlo_gpu::Activation> GetLHLOActivation(
    stream_executor::dnn::ActivationMode activation) {
  switch (activation) {
    case stream_executor::dnn::kNone:
      return mlir::lmhlo_gpu::Activation::None;
    case stream_executor::dnn::kSigmoid:
      return mlir::lmhlo_gpu::Activation::Sigmoid;
    case stream_executor::dnn::kRelu:
      return mlir::lmhlo_gpu::Activation::Relu;
    case stream_executor::dnn::kRelu6:
      return mlir::lmhlo_gpu::Activation::Relu6;
    case stream_executor::dnn::kReluX:
      return mlir::lmhlo_gpu::Activation::ReluX;
    case stream_executor::dnn::kTanh:
      return mlir::lmhlo_gpu::Activation::Tanh;
    case stream_executor::dnn::kBandPass:
      return mlir::lmhlo_gpu::Activation::BandPass;
    default:
      return xla::InternalError("Unknown activation");
  }
}

StatusOr<Operation*> LhloDialectEmitter::EmitDnnConvolution(
    const HloCustomCallInstruction* custom_call) {
  TF_ASSIGN_OR_RETURN(
      auto const backend_config,
      custom_call->backend_config<xla::gpu::CudnnConvBackendConfig>());

  TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnConvKind kind,
                      xla::gpu::GetCudnnConvKind(custom_call));

  auto get_layout_attribute = [&](const xla::Layout& layout) {
    std::vector<int64_t> minor_to_major(layout.minor_to_major_size());
    absl::c_transform(layout.minor_to_major(), minor_to_major.begin(),
                      [](xla::int64 x) { return static_cast<int64_t>(x); });
    return builder_.getI64ArrayAttr(minor_to_major);
  };

  auto set_common_conv_attributes = [&, this](auto op) -> Operation* {
    const xla::Window& window = custom_call->window();
    // Window size for Cudnn Conv is the same as the kernel size.
857     op.window_stridesAttr(
858         GetWindowElements(window, [](const xla::WindowDimension& dim) {
859           return static_cast<int64_t>(dim.stride());
860         }));
861     // Cudnn Conv requires low and high padding to be equal.
862     op.paddingAttr(
863         GetWindowElements(window, [](const xla::WindowDimension& dim) {
864           return static_cast<int64_t>(dim.padding_low());
865         }));
866     // LHS dilation is encoded in base_dilation of the backend config.
867     // RHS dilation is encoded in window_dilation of the backend config.
868     op.lhs_dilationAttr(
869         GetWindowElements(window, [](const xla::WindowDimension& dim) {
870           return static_cast<int64_t>(dim.base_dilation());
871         }));
872     op.rhs_dilationAttr(
873         GetWindowElements(window, [](const xla::WindowDimension& dim) {
874           return static_cast<int64_t>(dim.window_dilation());
875         }));
876     // Setup window reversal.
877     auto window_reversal = llvm::to_vector<4>(llvm::map_range(
878         window.dimensions(),
879         [](const xla::WindowDimension& dim) { return dim.window_reversal(); }));
880     auto type = RankedTensorType::get(op.window_strides()->getType().getShape(),
881                                       builder_.getIntegerType(/*width=*/1));
882     op.window_reversalAttr(DenseElementsAttr::get(type, window_reversal));
883 
884     op.dimension_numbersAttr(xla::ConvertConvDimensionNumbers(
885         custom_call->convolution_dimension_numbers(), &builder_));
886     op.feature_group_countAttr(
887         builder_.getI64IntegerAttr(custom_call->feature_group_count()));
888     op.batch_group_countAttr(
889         builder_.getI64IntegerAttr(custom_call->batch_group_count()));
890     op.precision_configAttr(xla::ConvertPrecisionConfig(
891         &custom_call->precision_config(), &builder_));
892     op.result_scaleAttr(
893         builder_.getF64FloatAttr(backend_config.conv_result_scale()));
894     auto config = mlir::lmhlo_gpu::ConvolutionBackendConfig::get(
895         builder_.getI64IntegerAttr(backend_config.algorithm()),
896         builder_.getBoolAttr(backend_config.tensor_ops_enabled()),
897         get_layout_attribute(custom_call->operand(0)->shape().layout()),
898         get_layout_attribute(custom_call->operand(1)->shape().layout()),
899         get_layout_attribute(custom_call->shape().tuple_shapes(0).layout()),
900         builder_.getContext());
901     op.backend_configAttr(config);
902 
903     return op.getOperation();
904   };
905 
906   auto set_activation = [&, this](auto op) -> Status {
907     auto se_activation = static_cast<stream_executor::dnn::ActivationMode>(
908         backend_config.activation_mode());
909     TF_ASSIGN_OR_RETURN(mlir::lmhlo_gpu::Activation activation,
910                         GetLHLOActivation(se_activation));
911     StringAttr activation_attr = builder_.getStringAttr(
912         mlir::lmhlo_gpu::stringifyActivation(activation));
913     op.activation_modeAttr(activation_attr);
914     return Status::OK();
915   };
916 
917   switch (kind) {
918     case xla::gpu::CudnnConvKind::kForward: {
919       TF_ASSIGN_OR_RETURN(
920           auto cnn_forward,
921           CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardOp>(custom_call));
922       return set_common_conv_attributes(cnn_forward);
923     }
924     case xla::gpu::CudnnConvKind::kBackwardInput: {
925       TF_ASSIGN_OR_RETURN(
926           auto cnn_backward,
927           CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardInputOp>(custom_call));
928       return set_common_conv_attributes(cnn_backward);
929     }
930     case xla::gpu::CudnnConvKind::kBackwardFilter: {
931       TF_ASSIGN_OR_RETURN(
932           auto cnn_backward,
933           CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardFilterOp>(custom_call));
934       return set_common_conv_attributes(cnn_backward);
935     }
936     case xla::gpu::CudnnConvKind::kForwardActivation: {
937       // Fused conv can be either with side input or without.
938       if (custom_call->operand_count() == 3) {
939         TF_ASSIGN_OR_RETURN(
940             auto cnn_fused,
941             CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedOp>(custom_call));
942         TF_RETURN_IF_ERROR(set_activation(cnn_fused));
943         return set_common_conv_attributes(cnn_fused);
944       }
945 
946       TF_RET_CHECK(custom_call->operand_count() == 4);
947       TF_ASSIGN_OR_RETURN(
948           auto cnn_fused_side_input,
949           CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedSideInputOp>(
950               custom_call));
951       cnn_fused_side_input.side_input_scaleAttr(
952           builder_.getF64FloatAttr(backend_config.side_input_scale()));
953       TF_RETURN_IF_ERROR(set_activation(cnn_fused_side_input));
954       return set_common_conv_attributes(cnn_fused_side_input);
955     }
956   }
957 }
958 
EmitDnnBatchNorm(const HloCustomCallInstruction * custom_call)959 StatusOr<Operation*> LhloDialectEmitter::EmitDnnBatchNorm(
960     const HloCustomCallInstruction* custom_call) {
961   const xla::int64 num_operands = custom_call->operand_count();
962   auto set_batchnorm_attributes = [&](auto op) -> StatusOr<Operation*> {
963     // The last 2 operands of a custom call for batch norm are the epsilon and
964     // feature_index.
965     const HloInstruction* epsilon = custom_call->operand(num_operands - 2);
966     TF_RET_CHECK(epsilon->IsConstant());
967     float epsilon_value = epsilon->literal().Get<float>({});
968 
969     const HloInstruction* feature_index =
970         custom_call->operand(num_operands - 1);
971     TF_RET_CHECK(feature_index->IsConstant());
972     xla::int64 feature_index_value =
973         feature_index->literal().Get<xla::int64>({});
974 
975     op.epsilonAttr(builder_.getF32FloatAttr(epsilon_value));
976     op.feature_indexAttr(builder_.getI64IntegerAttr(feature_index_value));
977     return op.getOperation();
978   };
979 
980   const std::string& target = custom_call->custom_call_target();
981   if (target == xla::gpu::kCudnnBatchNormForwardTrainingCallTarget) {
982     TF_ASSIGN_OR_RETURN(auto fwd_training,
983                         CreateOpWithoutAttrs<lmhlo_gpu::BatchNormTrainingOp>(
984                             custom_call, num_operands - 2));
985     return set_batchnorm_attributes(fwd_training);
986   }
987 
988   if (target == xla::gpu::kCudnnBatchNormBackwardCallTarget) {
989     TF_ASSIGN_OR_RETURN(auto backward,
990                         CreateOpWithoutAttrs<lmhlo_gpu::BatchNormGradOp>(
991                             custom_call, num_operands - 2));
992     return set_batchnorm_attributes(backward);
993   }
994 
995   if (target == xla::gpu::kCudnnBatchNormForwardInferenceCallTarget) {
996     TF_ASSIGN_OR_RETURN(auto fwd_inference,
997                         CreateOpWithoutAttrs<lmhlo_gpu::BatchNormInferenceOp>(
998                             custom_call, num_operands - 2));
999     return set_batchnorm_attributes(fwd_inference);
1000   }
1001 
1002   return xla::Unimplemented("Unsupported batch norm operation");
1003 }
1004 
1005 // Convert an XLA HLO constant to a global_memref + get_global_memref pair.
EmitConstant(const HloInstruction * instr)1006 StatusOr<mlir::memref::GetGlobalOp> LhloDialectEmitter::EmitConstant(
1007     const HloInstruction* instr) {
1008   // Insert a global_memref in the module.
1009   Location loc = getLocation(instr);
1010 
1011   auto const_instr = xla::Cast<xla::HloConstantInstruction>(instr);
1012   TF_RET_CHECK(const_instr->shape().IsArray() &&
1013                const_instr->shape().is_static());
1014   TF_ASSIGN_OR_RETURN(Type type, xla::ConvertShapeToType<MemRefType>(
1015                                      const_instr->shape(), builder_));
1016   auto memref_type = type.dyn_cast<MemRefType>();
1017   TF_RET_CHECK(memref_type != nullptr);
1018 
1019   TF_ASSIGN_OR_RETURN(
1020       DenseElementsAttr initial_value,
1021       CreateDenseElementsAttrFromLiteral(const_instr->literal(), builder_));
1022 
1023   std::string constant_name = xla::llvm_ir::ConstantNameToGlobalName(
1024       xla::llvm_ir::SanitizeConstantName(instr->name()));
1025 
1026   // Insert the global memref at the top level.
1027   {
1028     OpBuilder::InsertionGuard guard(builder_);
1029     builder_.clearInsertionPoint();
1030     auto global_var = builder_.create<memref::GlobalOp>(
1031         loc, constant_name, builder_.getStringAttr("private"), memref_type,
1032         initial_value, true);
1033     SymbolTable(module_).insert(global_var);
1034     global_var.getOperation()->moveBefore(&module_.front());
1035 
1036     // For operations that do not fold this constant value in their codegen, we
1037     // still need to materialize it into a buffer. Since buffer allocation is
1038     // already done, annotate the global_memref with the information to get to
1039     // the allocated buffer slice for this constant if need be.
1040     TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
1041                         assignment_.GetUniqueTopLevelSlice(instr));
1042     global_var->setAttr(
1043         "lmhlo.alloc",
1044         builder_.getIndexAttr(allocations_.find(slice.allocation())
1045                                   ->second.cast<BlockArgument>()
1046                                   .getArgNumber()));
1047     TF_RET_CHECK(slice.offset() == 0)
1048         << "Each constant should have its own allocation from BufferAssignment";
1049     TF_RET_CHECK(slice.allocation()->size() == slice.size())
1050         << "Each constant should have its own allocation from BufferAssignment";
1051   }
1052 
1053   auto get_global_memref =
1054       builder_.create<memref::GetGlobalOp>(loc, memref_type, constant_name);
1055 
1056   // Update the cache to remember this value.
1057   auto& cached_value = slices_[std::make_pair(instr, xla::ShapeIndex())];
1058   TF_RET_CHECK(cached_value == nullptr);
1059   cached_value = get_global_memref;
1060   return get_global_memref;
1061 }
1062 
1063 namespace {
1064 template <typename OpT>
SetupChannelIdAttribute(OpT op,const xla::HloChannelInstruction * instr,mlir::Builder builder)1065 void SetupChannelIdAttribute(OpT op, const xla::HloChannelInstruction* instr,
1066                              mlir::Builder builder) {
1067   if (instr->channel_id().has_value()) {
1068     op.channel_idAttr(mlir::mhlo::ChannelHandle::get(
1069         builder.getI64IntegerAttr(*instr->channel_id()),
1070         builder.getI64IntegerAttr(0), builder.getContext()));
1071   }
1072 }
1073 
1074 template <typename OpT>
SetupCommonCollectiveOpAttributes(OpT op,const HloInstruction * instr,mlir::OpBuilder & builder)1075 Status SetupCommonCollectiveOpAttributes(OpT op, const HloInstruction* instr,
1076                                          mlir::OpBuilder& builder) {
1077   auto* collective = xla::Cast<xla::HloCollectiveInstruction>(instr);
1078   auto replica_groups_attr = xla::HloFunctionImporter::ConvertReplicaGroups(
1079       collective->replica_groups(), &builder);
1080   op->setAttr(replica_groups_attr.first, replica_groups_attr.second);
1081   op.constrain_layoutAttr(builder.getBoolAttr(collective->constrain_layout()));
1082   SetupChannelIdAttribute(op, collective, builder);
1083   return Status::OK();
1084 }
1085 }  // namespace
1086 
EmitAllToAllOp(const HloInstruction * instr)1087 StatusOr<lmhlo::AllToAllOp> LhloDialectEmitter::EmitAllToAllOp(
1088     const HloInstruction* instr) {
1089   TF_ASSIGN_OR_RETURN(auto all_to_all_op,
1090                       CreateOpWithoutAttrs<lmhlo::AllToAllOp>(instr));
1091   auto* all_to_all = xla::Cast<xla::HloAllToAllInstruction>(instr);
1092   TF_RETURN_IF_ERROR(
1093       SetupCommonCollectiveOpAttributes(all_to_all_op, instr, builder_));
1094   if (all_to_all->split_dimension().has_value()) {
1095     all_to_all_op.split_dimensionAttr(
1096         builder_.getI64IntegerAttr(*all_to_all->split_dimension()));
1097   }
1098   return all_to_all_op;
1099 }
1100 
EmitAllGatherOp(const HloInstruction * instr)1101 StatusOr<lmhlo::AllGatherOp> LhloDialectEmitter::EmitAllGatherOp(
1102     const HloInstruction* instr) {
1103   TF_ASSIGN_OR_RETURN(auto all_gather_op,
1104                       CreateOpWithoutAttrs<lmhlo::AllGatherOp>(instr));
1105   auto* all_gather = xla::Cast<xla::HloAllGatherInstruction>(instr);
1106   TF_RETURN_IF_ERROR(
1107       SetupCommonCollectiveOpAttributes(all_gather_op, instr, builder_));
1108   all_gather_op.use_global_device_idsAttr(
1109       builder_.getBoolAttr(all_gather->use_global_device_ids()));
1110   all_gather_op.all_gather_dimensionAttr(
1111       builder_.getI64IntegerAttr(all_gather->all_gather_dimension()));
1112   return all_gather_op;
1113 }
1114 
EmitAllReduceOp(const HloInstruction * instr)1115 StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
1116     const HloInstruction* instr) {
1117   TF_ASSIGN_OR_RETURN(auto all_reduce_op,
1118                       CreateOpWithoutAttrs<lmhlo::AllReduceOp>(instr));
1119   auto* all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
1120   TF_RETURN_IF_ERROR(
1121       SetupCommonCollectiveOpAttributes(all_reduce_op, instr, builder_));
1122   all_reduce_op.use_global_device_idsAttr(
1123       builder_.getBoolAttr(all_reduce->use_global_device_ids()));
1124   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
1125       *instr->called_computations()[0], &all_reduce_op.computation(),
1126       &builder_));
1127   return all_reduce_op;
1128 }
1129 
EmitAllReduceStartOp(const HloInstruction * instr)1130 StatusOr<lmhlo_gpu::AllReduceStartOp> LhloDialectEmitter::EmitAllReduceStartOp(
1131     const HloInstruction* instr) {
1132   llvm::SmallVector<Value, 4> operands;
1133   for (const HloInstruction* operand : instr->operands()) {
1134     TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands));
1135   }
1136   // Only include result index {1}. {0} always aliases the inputs.
1137   TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{1}));
1138 
1139   Location loc = getLocation(instr);
1140   mlir::Type token_type = mlir::mhlo::TokenType::get(builder_.getContext());
1141   std::array<mlir::Type, 1> result_types = {token_type};
1142   lmhlo_gpu::AllReduceStartOp all_reduce_start_op =
1143       builder_.create<lmhlo_gpu::AllReduceStartOp>(loc, result_types, operands);
1144 
1145   auto* all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
1146   TF_RETURN_IF_ERROR(
1147       SetupCommonCollectiveOpAttributes(all_reduce_start_op, instr, builder_));
1148   all_reduce_start_op.use_global_device_idsAttr(
1149       builder_.getBoolAttr(all_reduce->use_global_device_ids()));
1150   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
1151       *instr->called_computations()[0], &all_reduce_start_op.computation(),
1152       &builder_));
1153 
1154   TF_RET_CHECK(all_reduce_start_ops_.emplace(instr, all_reduce_start_op).second)
1155       << "all-reduce-start already lowered";
1156   return all_reduce_start_op;
1157 }
1158 
EmitAllReduceDoneOp(const HloInstruction * instr)1159 StatusOr<lmhlo_gpu::AllReduceDoneOp> LhloDialectEmitter::EmitAllReduceDoneOp(
1160     const HloInstruction* instr) {
1161   auto it = all_reduce_start_ops_.find(instr->operand(0));
1162   TF_RET_CHECK(it != all_reduce_start_ops_.end())
1163       << "didn't find all-reduce-start op";
1164 
1165   llvm::SmallVector<Value, 4> operands;
1166   operands.push_back(it->second.token());
1167   all_reduce_start_ops_.erase(it);
1168 
1169   for (const HloInstruction* operand : instr->operands()) {
1170     TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands));
1171   }
1172   // We don't need to add buffers for the outputs, as these always alias inputs.
1173   return builder_.create<lmhlo_gpu::AllReduceDoneOp>(
1174       getLocation(instr), /*resultTypes=*/llvm::None, operands);
1175 }
1176 
EmitReduceScatterOp(const HloInstruction * instr)1177 StatusOr<lmhlo::ReduceScatterOp> LhloDialectEmitter::EmitReduceScatterOp(
1178     const HloInstruction* instr) {
1179   TF_ASSIGN_OR_RETURN(auto reduce_scatter_op,
1180                       CreateOpWithoutAttrs<lmhlo::ReduceScatterOp>(instr));
1181   auto* ars = xla::Cast<xla::HloReduceScatterInstruction>(instr);
1182   TF_RETURN_IF_ERROR(
1183       SetupCommonCollectiveOpAttributes(reduce_scatter_op, instr, builder_));
1184   reduce_scatter_op.use_global_device_idsAttr(
1185       builder_.getBoolAttr(ars->use_global_device_ids()));
1186   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
1187       *instr->called_computations()[0], &reduce_scatter_op.computation(),
1188       &builder_));
1189   reduce_scatter_op.scatter_dimensionAttr(
1190       builder_.getI64IntegerAttr(ars->scatter_dimension()));
1191   return reduce_scatter_op;
1192 }
1193 
1194 StatusOr<lmhlo::CollectivePermuteOp>
EmitCollectivePermuteOp(const HloInstruction * instr)1195 LhloDialectEmitter::EmitCollectivePermuteOp(const HloInstruction* instr) {
1196   TF_ASSIGN_OR_RETURN(auto permute_op,
1197                       CreateOpWithoutAttrs<lmhlo::CollectivePermuteOp>(instr));
1198   auto* permute = xla::Cast<xla::HloCollectivePermuteInstruction>(instr);
1199   SetupChannelIdAttribute(permute_op, permute, builder_);
1200   mlir::NamedAttribute source_target_pairs_attr =
1201       xla::HloFunctionImporter::ConvertSourceTargetPairs(
1202           permute->source_target_pairs(), &builder_);
1203   permute_op->setAttr(source_target_pairs_attr.first,
1204                       source_target_pairs_attr.second);
1205   return permute_op;
1206 }
1207 
StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp(
    const HloInstruction* instr) {
  const HloInfeedInstruction* infeed = xla::Cast<HloInfeedInstruction>(instr);
  // The HLO infeed instruction has a single operand of token type and a tuple
  // with buffers and a token as its output. The LMHLO infeed operation does
  // not need the token operand or result, so drop it.
  SmallVector<Value, 2> operands;
  TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{0}));
  auto infeed_op = CreateOpWithoutAttrs<lmhlo::InfeedOp>(instr, operands);
  infeed_op.configAttr(builder_.getStringAttr(infeed->infeed_config()));
  return infeed_op;
}

StatusOr<lmhlo::OutfeedOp> LhloDialectEmitter::EmitOutfeedOp(
    const HloInstruction* instr) {
  const HloOutfeedInstruction* outfeed =
      xla::Cast<HloOutfeedInstruction>(instr);
  // The HLO outfeed instruction has two operands, the source and a token, and
  // a single token output. The LMHLO outfeed operation does not need the token
  // operand or result, so drop them.
  SmallVector<Value, 2> operands;
  TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(0), &operands));
  auto outfeed_op = CreateOpWithoutAttrs<lmhlo::OutfeedOp>(instr, operands);
  outfeed_op.configAttr(builder_.getStringAttr(outfeed->outfeed_config()));
  return outfeed_op;
}

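// Lowers rng-get-and-update-state to lmhlo::RngGetAndUpdateStateOp, carrying
// over the delta of the underlying HLO instruction.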
xla::StatusOr<lmhlo::RngGetAndUpdateStateOp>
LhloDialectEmitter::EmitRngGetAndUpdateStateOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(
      auto rng, CreateOpWithoutAttrs<lmhlo::RngGetAndUpdateStateOp>(instr));
  auto hlo_rng = xla::Cast<xla::HloRngGetAndUpdateStateInstruction>(instr);
  rng.deltaAttr(builder_.getI64IntegerAttr(hlo_rng->delta()));
  return rng;
}

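// Lowers an FFT instruction to lmhlo::FftOp, translating the FFT type and
// length into attributes.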
xla::StatusOr<lmhlo::FftOp> LhloDialectEmitter::EmitFftOp(
    const HloInstruction* instr) {
  auto hlo_fft = xla::Cast<xla::HloFftInstruction>(instr);
  TF_ASSIGN_OR_RETURN(auto fft, CreateOpWithoutAttrs<lmhlo::FftOp>(instr));
  TF_ASSIGN_OR_RETURN(mlir::mhlo::FftType fft_type,
                      xla::ConvertFftType(hlo_fft->fft_type()));
  StringAttr fft_type_attr =
      builder_.getStringAttr(mlir::mhlo::stringifyFftType(fft_type));
  fft.fft_typeAttr(fft_type_attr);
  fft.fft_lengthAttr(GetI64DenseElementsAttr(instr->fft_length()));
  return fft;
}

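// Lowers a triangular-solve instruction to lmhlo::TriangularSolveOp, copying
// the solver options as attributes along with the operand and result layouts.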
xla::StatusOr<lmhlo::TriangularSolveOp>
LhloDialectEmitter::EmitTriangularSolveOp(const xla::HloInstruction* instr) {
  auto hlo_triangular_solve =
      xla::Cast<xla::HloTriangularSolveInstruction>(instr);
  TF_ASSIGN_OR_RETURN(auto triangular_solve,
                      CreateOpWithoutAttrs<lmhlo::TriangularSolveOp>(instr));
  const xla::TriangularSolveOptions& options =
      hlo_triangular_solve->triangular_solve_options();
  triangular_solve.left_sideAttr(builder_.getBoolAttr(options.left_side()));
  triangular_solve.lowerAttr(builder_.getBoolAttr(options.lower()));
  triangular_solve.unit_diagonalAttr(
      builder_.getBoolAttr(options.unit_diagonal()));
  TF_ASSIGN_OR_RETURN(mlir::mhlo::Transpose transpose,
                      xla::ConvertTranspose(options.transpose_a()));
  triangular_solve.transpose_aAttr(
      builder_.getStringAttr(mlir::mhlo::stringifyTranspose(transpose)));
  triangular_solve.layout_aAttr(
      GetLayoutAttribute(instr->operand(0)->shape().layout(), &builder_));
  triangular_solve.layout_bAttr(
      GetLayoutAttribute(instr->operand(1)->shape().layout(), &builder_));
  triangular_solve.layout_outputAttr(
      GetLayoutAttribute(instr->shape().layout(), &builder_));
  return triangular_solve;
}

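// Bitcasts are not lowered to any LMHLO op; they merely reinterpret a buffer.
// This only verifies that buffer assignment gave the input and the output the
// same slice, and returns nullptr.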
xla::StatusOr<Operation*> LhloDialectEmitter::EmitBitcast(
    const xla::HloInstruction* instr) {
  // XLA buffer assignment should assign the same slice to a bitcast input and
  // output.
  const xla::ShapeIndex top_index;
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
                      assignment_.GetUniqueSlice(instr, top_index));
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
                      assignment_.GetUniqueSlice(instr->operand(0), top_index));

  if (input_slice != result_slice) {
    return xla::InvalidArgument(
        "Bitcast input and result slices should be the same");
  }
  return nullptr;
}

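// Converts an XLA layout's minor_to_major ordering into an MLIR index tensor
// attribute.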
mlir::DenseIntElementsAttr LhloDialectEmitter::GetLayoutAttribute(
    const xla::Layout& layout, Builder* builder) {
  llvm::SmallVector<int64_t, 4> minor_to_major(layout.minor_to_major().begin(),
                                               layout.minor_to_major().end());
  return builder->getIndexTensorAttr(minor_to_major);
}

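// Emits the given computation into `region` as LMHLO ops, following the
// sequential order from buffer assignment, and terminates the region with
// lmhlo.terminator. The builder's insertion point is restored on return.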
Status LhloDialectEmitter::ImportAsLmhloRegion(xla::HloComputation* computation,
                                               mlir::Region* region) {
  auto after = builder_.saveInsertionPoint();
  auto reverter = xla::MakeCleanup(
      [this, after] { builder_.restoreInsertionPoint(after); });

  builder_ = OpBuilder(region);
  const xla::HloInstructionSequence* schedule =
      assignment_.hlo_ordering().SequentialOrder(*computation);
  if (!schedule)
    return xla::Unimplemented("Missing sequential order for the computation");
  TF_RETURN_IF_ERROR(
      computation->AcceptOrdered(this, schedule->instructions()));
  builder_.create<lmhlo::TerminatorOp>(builder_.getUnknownLoc());
  return Status::OK();
}

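// Lowers a case instruction to lmhlo::CaseOp with one region per branch; each
// branch computation is imported as an LMHLO region.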
StatusOr<lmhlo::CaseOp> LhloDialectEmitter::EmitCaseOp(
    const HloInstruction* instr) {
  Location loc = getLocation(instr);
  llvm::SmallVector<Value, 4> operands;
  size_t num_arguments, num_results;
  TF_RETURN_IF_ERROR(CreateOperands(instr, 1, TokenLoweringMode::kUseNull,
                                    operands, num_arguments, num_results));

  auto case_op =
      builder_.create<lmhlo::CaseOp>(loc, operands[0], instr->branch_count());

  for (int i = 0; i < instr->branch_count(); i++) {
    case_op.branches()[i].push_back(new mlir::Block());
    TF_RETURN_IF_ERROR(ImportAsLmhloRegion(instr->called_computations()[i],
                                           &case_op.branches()[i]));
  }

  return case_op;
}

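// Lowers a while instruction to lmhlo::WhileOp. The op takes the buffer that
// holds the condition result (plus an optional known trip count from the
// backend config); the condition and body computations are imported as the
// op's regions.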
xla::StatusOr<lmhlo::WhileOp> LhloDialectEmitter::EmitWhileOp(
    const xla::HloInstruction* instr) {
  Location loc = getLocation(instr);
  SmallVector<Value, 1> operands;
  TF_RETURN_IF_ERROR(GetOrCreateView(
      instr->called_computations()[1]->root_instruction(), &operands));
  TF_RET_CHECK(operands.size() == 1);

  TF_ASSIGN_OR_RETURN(auto config,
                      instr->backend_config<xla::WhileLoopBackendConfig>());
  mlir::IntegerAttr trip_count;
  if (config.has_known_trip_count()) {
    trip_count = builder_.getI64IntegerAttr(config.known_trip_count().n());
  }
  lmhlo::WhileOp while_op =
      builder_.create<lmhlo::WhileOp>(loc, operands[0], trip_count);

  while_op.cond().push_back(new mlir::Block());
  while_op.body().push_back(new mlir::Block());
  TF_RETURN_IF_ERROR(
      ImportAsLmhloRegion(instr->called_computations()[1], &while_op.cond()));

  TF_RETURN_IF_ERROR(
      ImportAsLmhloRegion(instr->called_computations()[0], &while_op.body()));

  return while_op;
}

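// Returns (and caches) a memref Value for the buffer slice that backs
// (instr, shape_index). Constants are materialized directly; other slices get
// a memref.view into their allocation, followed by a memref.reinterpret_cast
// when the logical type differs from the physical (major-to-minor) layout.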
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
    const xla::HloInstruction* instr, const xla::Shape& current_shape,
    const xla::ShapeIndex& shape_index) {
  // Cache generated ViewOp and StaticMemRefCastOp by (instruction,
  // shape_index).
  auto& cached_value = slices_[std::make_pair(instr, shape_index)];
  if (cached_value) {
    return cached_value;
  }

  if (instr->IsConstant() && shape_index.empty()) {
    TF_ASSIGN_OR_RETURN(Value constant_memref, EmitConstant(instr));
    return cached_value = constant_memref;
  }

  // If the shape happens to have dynamic dimensions, create the memref using
  // the underlying static shape.
  // TODO(jurahul): Revisit this when we can model memrefs with dynamic shape
  // but static bounds in MLIR.
  const Shape static_shape = xla::ShapeUtil::MakeStaticShape(current_shape);

  TF_ASSIGN_OR_RETURN(Type out_type, xla::ConvertShapeToType<MemRefType>(
                                         static_shape, builder_));
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
                      assignment_.GetUniqueSlice(instr, shape_index));
  Value alloc = allocations_[slice.allocation()];

  // TODO(timshen): revisit location handling.
  Location loc = builder_.getUnknownLoc();

  Value result;
  if (AllocationShouldLowerToTypedArg(slice.allocation())) {
    TF_RET_CHECK(slice.offset() == 0);
    TF_RET_CHECK(slice.size() == slice.allocation()->size());
    result = alloc;
  } else {
    Value byte_shift =
        builder_.create<ConstantIndexOp>(alloc.getLoc(), slice.offset());

    xla::Shape physical_shape =
        xla::ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
            static_shape);
    TF_ASSIGN_OR_RETURN(
        Type physical_out_type,
        xla::ConvertShapeToType<MemRefType>(physical_shape, builder_));

    // ViewOp only takes memrefs without affine maps (layouts). Let ViewOp
    // produce the physical shape (where dimensions are ordered from major to
    // minor) first, then follow up with a MemRefReinterpretCast to cast the
    // resulting memref to the original layout.
    result = builder_.create<memref::ViewOp>(loc, physical_out_type, alloc,
                                             byte_shift,
                                             /*sizes=*/ValueRange{});
  }
  if (result.getType() != out_type) {
    int64_t out_offset;
    SmallVector<int64_t, 4> out_strides;
    auto out_memref_type = out_type.dyn_cast<MemRefType>();
    if (!out_memref_type)
      return tensorflow::errors::Internal(
          "Expected memref type when creating a view for leaf type of a "
          "tuple.");
    if (failed(getStridesAndOffset(out_memref_type, out_strides, out_offset)))
      return tensorflow::errors::Internal(
          "Failed to get strides and offset from the output type.");
    result = builder_.create<memref::ReinterpretCastOp>(
        loc, out_memref_type, result, out_offset, out_memref_type.getShape(),
        out_strides);
  }
  return cached_value = result;
}

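// Recursively walks `current_shape`, collecting one buffer view per leaf array
// into `values`. Tokens either fail the lowering or contribute a null Value,
// depending on `token_mode`.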
Status LhloDialectEmitter::GetOrCreateViewImpl(
    const HloInstruction* instr, const Shape& current_shape,
    xla::ShapeIndex* current_shape_index, SmallVectorImpl<Value>* values,
    TokenLoweringMode token_mode) {
  if (current_shape.IsTuple()) {
    for (int i = 0; i < current_shape.tuple_shapes().size(); ++i) {
      current_shape_index->push_back(i);
      TF_RETURN_IF_ERROR(
          GetOrCreateViewImpl(instr, current_shape.tuple_shapes(i),
                              current_shape_index, values, token_mode));
      current_shape_index->pop_back();
    }
    return Status::OK();
  }
  if (current_shape.IsArray()) {
    TF_ASSIGN_OR_RETURN(auto v, GetOrCreateArrayView(instr, current_shape,
                                                     *current_shape_index));
    values->push_back(v);
    return Status::OK();
  }
  if (current_shape.IsToken()) {
    switch (token_mode) {
      case TokenLoweringMode::kFailToLower:
        return xla::InternalError(
            "Unexpected token kind for %s and shape index %s",
            instr->ToString(), current_shape_index->ToString());

      case TokenLoweringMode::kUseNull:
        values->push_back(Value{});
        return Status::OK();
    }
  }
  return xla::InternalError("Unexpected shape kind for %s and shape index %s",
                            instr->ToString(), current_shape_index->ToString());
}

// Returns a view for the result of an instruction.
// We first get a view for the slice in the allocation, and then may need to
// create another view to adjust the slice for the shape of the instruction.
Status LhloDialectEmitter::GetOrCreateView(const HloInstruction* instr,
                                           SmallVectorImpl<Value>* values,
                                           const xla::ShapeIndex& result_subset,
                                           TokenLoweringMode token_mode) {
  xla::ShapeIndex shape_index = result_subset;
  const Shape& sub_shape =
      xla::ShapeUtil::GetSubshape(instr->shape(), shape_index);
  return GetOrCreateViewImpl(instr, sub_shape, &shape_index, values,
                             token_mode);
}

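// Builds the entry FuncOp for the lowered module: one memref argument per
// non-thread-local buffer allocation, decorated with lmhlo.* attributes that
// describe parameters, outputs, constants, and aliasing. Leaves the builder
// positioned before the function terminator so that subsequently emitted ops
// land inside the entry block.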
Status LhloDialectEmitter::Initialize() {
  TF_RET_CHECK(computation_.IsEntryComputation());

  mlir::IntegerAttr unique_id =
      builder_.getI32IntegerAttr(computation_.parent()->unique_id());
  module_->setAttr("hlo.unique_id", unique_id);
  std::string function_name =
      computation_.name().empty() ? "__compute" : computation_.name();

  // Create the function as () -> (); we'll compute the arguments from the
  // buffer allocations and update the function type afterwards.
  auto func_op = FuncOp::create(builder_.getUnknownLoc(), function_name,
                                builder_.getFunctionType({}, {}));

  {
    // This is an optional attribute used by the XLA backend. If the resulting
    // LMHLO doesn't go through XLA, this is not needed.
    const Shape& shape = computation_.root_instruction()->shape();
    func_op->setAttr(
        "result_xla_shape",
        builder_.getStringAttr(shape.ToString(/*print_layout=*/true)));
  }
  Block* block = func_op.addEntryBlock();

  llvm::SmallVector<const BufferAllocation*, 8> ordered_allocations;
  for (const BufferAllocation& alloc : assignment_.Allocations())
    ordered_allocations.push_back(&alloc);

  if (computation_.IsEntryComputation()) {
    // Sort the rather arbitrarily ordered allocations to match the
    // input/output parameters. Specifically, we want to sort buffer
    // allocations in the following order:
    // * Parameters always come before non-parameters.
    // * Different parameters are ordered by parameter number.
    // * Different allocations for the same parameter are ordered by shape
    //   index.
    //
    // TODO(timshen): there should be only one non-parameter buffer, the temp
    // buffer. Check on that.
    const auto allocation_comparator = [](const BufferAllocation* lhs,
                                          const BufferAllocation* rhs) {
      if (lhs->is_entry_computation_parameter() !=
          rhs->is_entry_computation_parameter()) {
        return lhs->is_entry_computation_parameter() >
               rhs->is_entry_computation_parameter();
      }
      if (lhs->is_entry_computation_parameter()) {
        return std::tuple<int, const xla::ShapeIndex&>(
                   lhs->parameter_number(), lhs->param_shape_index()) <
               std::tuple<int, const xla::ShapeIndex&>(
                   rhs->parameter_number(), rhs->param_shape_index());
      }
      return false;
    };

    std::stable_sort(ordered_allocations.begin(), ordered_allocations.end(),
                     allocation_comparator);
  }

  absl::flat_hash_map<const BufferAllocation*,
                      std::pair<const Shape*, xla::ShapeIndex>>
      allocation_to_output_info;
  TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
      computation_.root_instruction()->shape(),
      [&](const Shape& sub_shape, xla::ShapeIndex index) -> Status {
        TF_ASSIGN_OR_RETURN(
            auto slice,
            assignment_.GetUniqueSlice(computation_.root_instruction(), index));
        const BufferAllocation* alloc = slice.allocation();
        TF_RET_CHECK(slice.offset() == 0);
        TF_RET_CHECK(slice.size() == alloc->size());
        allocation_to_output_info[alloc] = std::make_pair(&sub_shape, index);
        return Status::OK();
      }));

  // The function signature will be composed of:
  // - one memref for each of the parameters.
  // - one memref for each other buffer allocation.
  llvm::SmallVector<DictionaryAttr, 8> args_attrs;
  for (const BufferAllocation* alloc : ordered_allocations) {
    if (alloc->is_thread_local()) {
      continue;
    }

    // These are optional attributes that help the program run through XLA. XLA
    // defines ExecutionInput and ExecutionOutput structures to carry
    // input-output type and buffer information, so any information they need
    // (mainly the type structure, potentially containing tuples) has to be
    // preserved here. They are not needed if the generated LMHLO is not sent
    // to XLA.
    NamedAttrList arg_attr_list;
    mlir::Type arg_type;
    if (AllocationShouldLowerToTypedArg(alloc)) {
      xla::Shape buffer_shape = xla::ShapeUtil::GetSubshape(
          computation_.parameter_instruction(alloc->parameter_number())
              ->shape(),
          alloc->param_shape_index());

      if (buffer_shape.IsTuple()) {
        arg_type = MemRefType::get({alloc->size()}, i8_type_);
      } else {
        // TODO(jurahul): Revisit this when we can model memrefs with dynamic
        // shape but static bounds in MLIR.
        const Shape static_shape =
            xla::ShapeUtil::MakeStaticShape(buffer_shape);
        TF_ASSIGN_OR_RETURN(arg_type, xla::ConvertShapeToType<MemRefType>(
                                          static_shape, builder_));
      }
    } else {
      arg_type = MemRefType::get({alloc->size()}, i8_type_);
    }

    if (alloc->is_entry_computation_parameter()) {
      arg_attr_list.set("lmhlo.params",
                        builder_.getIndexAttr(alloc->parameter_number()));
      if (!alloc->param_shape_index().empty()) {
        arg_attr_list.set("lmhlo.param_shape_index",
                          builder_.getI64TensorAttr(llvm::makeArrayRef(
                              alloc->param_shape_index().begin(),
                              alloc->param_shape_index().end())));
      }
    }
    // Optional: an attribute for optimization. If a kernel uses this
    // allocation, but the allocation has lmhlo.constant_name, then the kernel
    // will instead use the global value indicated by the name for potentially
    // more optimizations (e.g. constant propagation).
    if (alloc->is_constant()) {
      arg_attr_list.set(
          "lmhlo.constant_name",
          builder_.getStringAttr(
              xla::llvm_ir::ConstantBufferAllocationToGlobalName(*alloc)));
    }
    auto iter = allocation_to_output_info.find(alloc);
    if (iter != allocation_to_output_info.end()) {
      const Shape* sub_shape = iter->second.first;
      const xla::ShapeIndex& shape_index = iter->second.second;
      if (!sub_shape->IsArray()) {
        continue;
      }
      arg_attr_list.set("lmhlo.output_index",
                        builder_.getI64TensorAttr(llvm::makeArrayRef(
                            shape_index.begin(), shape_index.end())));
      if (auto alias = computation_.parent()
                           ->input_output_alias_config()
                           .GetAliasedParameter(shape_index)) {
        if (alias->must_alias()) {
          arg_attr_list.set("lmhlo.must_alias", builder_.getUnitAttr());
        }
      }
    }
    block->addArgument(arg_type);
    allocations_[alloc] = block->getArguments().back();
    args_attrs.push_back(arg_attr_list.getDictionary(builder_.getContext()));
  }

  FunctionType function_type =
      builder_.getFunctionType(block->getArgumentTypes(), {});
  func_op.setType(function_type);
  func_op.setAllArgAttrs(args_attrs);

  SymbolTable symbol_table(module_);
  symbol_table.insert(func_op);
  builder_.setInsertionPointToEnd(block);

  auto return_op =
      builder_.create<lmhlo::TerminatorOp>(builder_.getUnknownLoc());
  builder_ = OpBuilder(return_op);

  return Status::OK();
}

std::unique_ptr<OperationPass<ModuleOp>> createXlaHloToLhloWithXlaPass() {
  return std::make_unique<XlaHloToLhloPass>();
}

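// Lowers the scheduled entry computation of `hlo_module` into `module` as
// LMHLO, using the given buffer assignment, and verifies the resulting module.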
Status HloToLhloModule(const BufferAssignment& assignment,
                       const HloModule& hlo_module, ModuleOp module) {
  module.getContext()
      ->loadDialect<StandardOpsDialect, memref::MemRefDialect,
                    mhlo::MhloDialect, lmhlo::LmhloDialect,
                    lmhlo_gpu::LmhloGpuDialect>();

  module->setLoc(mlir::NameLoc::get(
      mlir::Identifier::get(hlo_module.name(), module.getContext())));

  const HloComputation* computation = hlo_module.entry_computation();

  LhloDialectEmitter emitter(assignment, *computation, module);
  TF_RETURN_IF_ERROR(emitter.Initialize());

  const xla::HloInstructionSequence* schedule =
      assignment.hlo_ordering().SequentialOrder(*computation);
  if (!schedule)
    return xla::Unimplemented("Missing sequential order for the computation");
  const std::vector<HloInstruction*>& ordering = schedule->instructions();
  TF_RETURN_IF_ERROR(computation->AcceptOrdered(&emitter, ordering));
  TF_RET_CHECK(succeeded(mlir::verify(module)));
  return Status::OK();
}

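// Parses HLO text into an HloModule, then optimizes it for the "Host" platform
// and converts it into a fresh LMHLO module. Failures are fatal (TF_CHECK_OK).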
OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input,
                                               MLIRContext* context) {
  StatusOr<std::unique_ptr<HloModule>> maybe_module =
      xla::ParseAndReturnUnverifiedModule(
          absl::string_view(input.data(), input.size()));
  TF_CHECK_OK(maybe_module.status());

  OwningModuleRef module = ModuleOp::create(UnknownLoc::get(context));

  TF_CHECK_OK(OptimizeAndConvertHloToLmhlo(maybe_module.ConsumeValueOrDie(),
                                           module.get(), "Host"));

  return module;
}

static PassRegistration<XlaHloToLhloPass> registration;

}  // namespace mlir