/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"

#include <climits>
#include <memory>
#include <tuple>

#include "absl/algorithm/container.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/AffineExpr.h"  // from @llvm-project
#include "mlir/IR/AffineMap.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassOptions.h"  // from @llvm-project
#include "mlir/Translation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

using xla::BufferAllocation;
using xla::BufferAssignment;
using xla::HloComputation;
using xla::HloCustomCallInstruction;
using xla::HloInfeedInstruction;
using xla::HloInstruction;
using xla::HloModule;
using xla::HloModuleProto;
using xla::HloOutfeedInstruction;
using xla::HloProto;
using xla::Shape;
using xla::StatusOr;

namespace mlir {
namespace {

absl::string_view StringRefToView(llvm::StringRef ref) {
  return {ref.data(), ref.size()};
}

StatusOr<std::unique_ptr<HloModule>> HloModuleFromProto(
    const HloProto& hlo_proto) {
  const HloModuleProto& module_proto = hlo_proto.hlo_module();
  TF_ASSIGN_OR_RETURN(const xla::HloModuleConfig module_config,
                      HloModule::CreateModuleConfigFromProto(
                          module_proto, xla::GetDebugOptionsFromFlags()));
  return HloModule::CreateFromProto(module_proto, module_config);
}

// Convert the MLIR `module` from HLO dialect to LHLO dialect using XLA for the
// given platform.
Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module,
                     StringRef platform_name) {
  auto platform = xla::se::MultiPlatformManager::PlatformWithName(
      StringRefToView(platform_name));
  if (!platform.ok()) {
    std::string error_msg;
    llvm::raw_string_ostream os(error_msg);
    os << "failed to get platform: " << platform.status().ToString()
       << " (available platforms: ";
    std::vector<std::string> available_platforms;
    (void)xla::se::MultiPlatformManager::PlatformsWithFilter(
        [&](const stream_executor::Platform* p) {
          available_platforms.push_back(p->Name());
          return false;
        });
    llvm::interleaveComma(available_platforms, os);
    os << ")";
    return xla::InvalidArgument("%s", os.str().c_str());
  }

  xla::BackendOptions backend_options;
  backend_options.set_platform(platform.ValueOrDie());
  auto backend_or_err = xla::Backend::CreateBackend(backend_options);
  TF_RETURN_WITH_CONTEXT_IF_ERROR(backend_or_err.status(),
                                  "failed to create XLA Backend ");
  auto backend = std::move(backend_or_err.ValueOrDie());

  // Run all HLO passes to produce an optimized module.
  auto result_or = backend->compiler()->RunHloPassesAndBufferAssignement(
      std::move(hlo_module), backend->default_stream_executor(),
      optimize_xla_hlo, {backend->memory_allocator()});
  TF_RETURN_WITH_CONTEXT_IF_ERROR(result_or.status(),
                                  "running XLA pass pipeline");
  std::unique_ptr<HloModule> optimized_hlo_module =
      std::move(std::get<0>(result_or.ValueOrDie()));
  std::unique_ptr<BufferAssignment> assignment =
      std::move(std::get<1>(result_or.ValueOrDie()));

  // Clear the module before populating it back with the result of the
  // conversion.
  module.getBody()->clear();
  OpBuilder builder(module);
  module.ensureTerminator(module.getBodyRegion(), builder, module.getLoc());

  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      HloToLhloModule(*assignment, *optimized_hlo_module, module),
      "converting HLO to LHLO");

  return Status::OK();
}

// This pass takes an MLIR HLO module, converts it to XLA HLO, runs the XLA
// HLO optimization pipeline for the requested platform, and then converts the
// result back to MLIR LHLO.
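//
// Usage sketch (registration under a textual pass name happens outside this
// file): the pass is configured through the `platform` option declared below,
// e.g. `platform=CUDA`, and defaults to "Host". It rewrites the module body in
// place with the LHLO equivalent produced by ConvertModule above.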
class XlaHloToLhloPass
    : public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry
        .insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect,
                mlir::lmhlo::LmhloDialect, mlir::lmhlo_gpu::LmhloGpuDialect>();
  }

 public:
  XlaHloToLhloPass() = default;
  XlaHloToLhloPass(const XlaHloToLhloPass&) {}

 private:
  void runOnOperation() final {
    ModuleOp module = getOperation();

    auto status = [&module, this]() -> Status {
      SymbolTable symbol_table(module);
      if (!symbol_table.lookup("main")) {
        return xla::InvalidArgument(
            "conversion to HLO module failed: missing main()");
      }
      HloProto hlo_proto;
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          ConvertMlirHloToHlo(module, &hlo_proto,
                              /*use_tuple_args=*/false,
                              /*return_tuple=*/false,
                              /*shape_representation_fn=*/nullptr),
          "conversion to XLA HLO proto failed");

      auto statusOrHloModule = HloModuleFromProto(hlo_proto);
      TF_RETURN_WITH_CONTEXT_IF_ERROR(statusOrHloModule.status(),
                                      "parsing HLO proto to HLO module failed");
      std::unique_ptr<HloModule> hlo_module =
          std::move(statusOrHloModule.ValueOrDie());

      return ConvertModule(std::move(hlo_module), module, platform_);
    }();
    if (!status.ok()) {
      module.emitError() << status.ToString();
      return signalPassFailure();
    }
  }

  Option<std::string> platform_{
      *this, "platform",
      llvm::cl::desc("The platform to use for the XLA optimization pipeline."),
      llvm::cl::init("Host")};
};

}  // namespace

// Creates MLIR operands corresponding to operands and results of the XLA HLO
// instruction. If `num_operands` is valid, then only the first `num_operands`
// operands of the HLO instruction will be considered.
Status LhloDialectEmitter::CreateOperands(
    const HloInstruction* instr, absl::optional<xla::int64> num_operands,
    llvm::SmallVectorImpl<Value>& operands, size_t& num_arguments,
    size_t& num_results) {
  if (num_operands.value_or(0) > instr->operand_count())
    return xla::InvalidArgument("num_operands must be <= operand count");
  for (xla::int64 i = 0; i < num_operands.value_or(instr->operand_count());
       ++i) {
    TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands));
  }
  num_arguments = operands.size();
  TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands));
  num_results = operands.size() - num_arguments;
  return Status::OK();
}

template <typename OpType>
OpType LhloDialectEmitter::CreateOpWithoutAttrs(const HloInstruction* instr,
                                                ValueRange operands) {
  Location loc = getLocation(instr);
  return builder_.create<OpType>(loc, llvm::None, operands,
                                 llvm::ArrayRef<NamedAttribute>{});
}

template <typename OpType>
StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
    const HloInstruction* instr, size_t& num_arguments, size_t& num_results,
    absl::optional<xla::int64> num_operands) {
  llvm::SmallVector<Value, 4> operands;
  TF_RETURN_IF_ERROR(CreateOperands(instr, num_operands, operands,
                                    num_arguments, num_results));
  return CreateOpWithoutAttrs<OpType>(instr, operands);
}

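// Dispatches on the HLO opcode. Element-wise and other attribute-free ops are
// lowered 1:1 through CreateOpWithoutAttrs; ops that carry attributes or
// regions are handled by the dedicated Emit* helpers below. kBitcast returns
// nullptr: no LMHLO op is emitted because a bitcast reuses its operand's
// buffer.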
StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
    const HloInstruction* instr) {
  using xla::HloOpcode;
  switch (instr->opcode()) {
    case HloOpcode::kAbs:
      return CreateOpWithoutAttrs<lmhlo::AbsOp>(instr);
    case HloOpcode::kAdd:
      return CreateOpWithoutAttrs<lmhlo::AddOp>(instr);
    case HloOpcode::kAllToAll:
      return EmitAllToAllOp(instr);
    case HloOpcode::kAllGather:
      return EmitAllGatherOp(instr);
    case HloOpcode::kAllReduce:
      return EmitAllReduceOp(instr);
    case HloOpcode::kAnd:
      return CreateOpWithoutAttrs<lmhlo::AndOp>(instr);
    case HloOpcode::kAtan2:
      return CreateOpWithoutAttrs<lmhlo::Atan2Op>(instr);
    case HloOpcode::kBitcast:
      return nullptr;
    case HloOpcode::kBitcastConvert:
      return CreateOpWithoutAttrs<lmhlo::BitcastConvertOp>(instr);
    case HloOpcode::kBroadcast:
      return EmitBroadcastOp(instr);
    case HloOpcode::kCeil:
      return CreateOpWithoutAttrs<lmhlo::CeilOp>(instr);
    case HloOpcode::kCbrt:
      return CreateOpWithoutAttrs<lmhlo::CbrtOp>(instr);
    case HloOpcode::kClamp:
      return CreateOpWithoutAttrs<lmhlo::ClampOp>(instr);
    case HloOpcode::kCollectivePermute:
      return EmitCollectivePermuteOp(instr);
    case HloOpcode::kClz:
      return CreateOpWithoutAttrs<lmhlo::ClzOp>(instr);
    case HloOpcode::kCompare:
      return EmitCompareOp(instr);
    case HloOpcode::kComplex:
      return CreateOpWithoutAttrs<lmhlo::ComplexOp>(instr);
    case HloOpcode::kConcatenate:
      return EmitConcatenateOp(instr);
    case HloOpcode::kConvert:
      return CreateOpWithoutAttrs<lmhlo::ConvertOp>(instr);
    case HloOpcode::kCopy:
      return CreateOpWithoutAttrs<lmhlo::CopyOp>(instr);
    case HloOpcode::kCos:
      return CreateOpWithoutAttrs<lmhlo::CosOp>(instr);
    case HloOpcode::kDivide:
      return CreateOpWithoutAttrs<lmhlo::DivOp>(instr);
    case HloOpcode::kDot:
      return EmitDotOp(instr);
    case HloOpcode::kDynamicSlice:
      return EmitDynamicSliceOp(instr);
    case HloOpcode::kDynamicUpdateSlice:
      return CreateOpWithoutAttrs<lmhlo::DynamicUpdateSliceOp>(instr);
    case HloOpcode::kFft:
      return EmitFftOp(instr);
    case HloOpcode::kExp:
      return CreateOpWithoutAttrs<lmhlo::ExpOp>(instr);
    case HloOpcode::kExpm1:
      return CreateOpWithoutAttrs<lmhlo::Expm1Op>(instr);
    case HloOpcode::kFloor:
      return CreateOpWithoutAttrs<lmhlo::FloorOp>(instr);
    case HloOpcode::kGather:
      return EmitGatherOp(instr);
    case HloOpcode::kImag:
      return CreateOpWithoutAttrs<lmhlo::ImagOp>(instr);
    case HloOpcode::kInfeed:
      return EmitInfeedOp(instr);
    case HloOpcode::kIota:
      return EmitIotaOp(instr);
    case HloOpcode::kIsFinite:
      return CreateOpWithoutAttrs<lmhlo::IsFiniteOp>(instr);
    case HloOpcode::kLog:
      return CreateOpWithoutAttrs<lmhlo::LogOp>(instr);
    case HloOpcode::kLog1p:
      return CreateOpWithoutAttrs<lmhlo::Log1pOp>(instr);
    case HloOpcode::kMap:
      return EmitMapOp(instr);
    case HloOpcode::kMaximum:
      return CreateOpWithoutAttrs<lmhlo::MaxOp>(instr);
    case HloOpcode::kMinimum:
      return CreateOpWithoutAttrs<lmhlo::MinOp>(instr);
    case HloOpcode::kMultiply:
      return CreateOpWithoutAttrs<lmhlo::MulOp>(instr);
    case HloOpcode::kNegate:
      return CreateOpWithoutAttrs<lmhlo::NegOp>(instr);
    case HloOpcode::kNot:
      return CreateOpWithoutAttrs<lmhlo::NotOp>(instr);
    case HloOpcode::kOr:
      return CreateOpWithoutAttrs<lmhlo::OrOp>(instr);
    case HloOpcode::kOutfeed:
      return EmitOutfeedOp(instr);
    case HloOpcode::kPartitionId:
      return CreateOpWithoutAttrs<lmhlo::PartitionIdOp>(instr);
    case HloOpcode::kPad:
      return EmitPadOp(instr);
    case HloOpcode::kPopulationCount:
      return CreateOpWithoutAttrs<lmhlo::PopulationCountOp>(instr);
    case HloOpcode::kPower:
      return CreateOpWithoutAttrs<lmhlo::PowOp>(instr);
    case HloOpcode::kReal:
      return CreateOpWithoutAttrs<lmhlo::RealOp>(instr);
    case HloOpcode::kReshape:
      return CreateOpWithoutAttrs<lmhlo::ReshapeOp>(instr);
    case HloOpcode::kReducePrecision:
      return EmitReducePrecisionOp(instr);
    case HloOpcode::kReduceWindow:
      return EmitReduceWindowOp(instr);
    case HloOpcode::kRemainder:
      return CreateOpWithoutAttrs<lmhlo::RemOp>(instr);
    case HloOpcode::kReplicaId:
      return CreateOpWithoutAttrs<lmhlo::ReplicaIdOp>(instr);
    case HloOpcode::kReverse:
      return EmitReverseOp(instr);
    case HloOpcode::kRoundNearestAfz:
      return CreateOpWithoutAttrs<lmhlo::RoundOp>(instr);
    case HloOpcode::kRsqrt:
      return CreateOpWithoutAttrs<lmhlo::RsqrtOp>(instr);
    case HloOpcode::kSelect:
      return CreateOpWithoutAttrs<lmhlo::SelectOp>(instr);
    case HloOpcode::kShiftLeft:
      return CreateOpWithoutAttrs<lmhlo::ShiftLeftOp>(instr);
    case HloOpcode::kShiftRightLogical:
      return CreateOpWithoutAttrs<lmhlo::ShiftRightLogicalOp>(instr);
    case HloOpcode::kShiftRightArithmetic:
      return CreateOpWithoutAttrs<lmhlo::ShiftRightArithmeticOp>(instr);
    case HloOpcode::kSign:
      return CreateOpWithoutAttrs<lmhlo::SignOp>(instr);
    case HloOpcode::kSin:
      return CreateOpWithoutAttrs<lmhlo::SinOp>(instr);
    case HloOpcode::kSlice:
      return EmitSliceOp(instr);
    case HloOpcode::kSqrt:
      return CreateOpWithoutAttrs<lmhlo::SqrtOp>(instr);
    case HloOpcode::kSubtract:
      return CreateOpWithoutAttrs<lmhlo::SubOp>(instr);
    case HloOpcode::kTanh:
      return CreateOpWithoutAttrs<lmhlo::TanhOp>(instr);
    case HloOpcode::kTranspose:
      return EmitTransposeOp(instr);
    case HloOpcode::kTriangularSolve:
      return EmitTriangularSolveOp(instr);
    case HloOpcode::kXor:
      return CreateOpWithoutAttrs<lmhlo::XorOp>(instr);
    case HloOpcode::kSort:
      return EmitSortOp(instr);
    case HloOpcode::kFusion:
      return EmitFusionOp(instr);
    case HloOpcode::kScatter:
      return EmitScatterOp(instr);
    case HloOpcode::kSelectAndScatter:
      return EmitSelectAndScatterOp(instr);
    case HloOpcode::kCustomCall:
      return EmitCustomCallOp(instr);
    case HloOpcode::kConstant:
      return EmitConstant(instr);
    case HloOpcode::kReduce:
      return EmitReduceOp(instr);
    case HloOpcode::kRngGetAndUpdateState:
      return EmitRngGetAndUpdateStateOp(instr);
    default:
      llvm::errs() << instr->ToString();
      return tensorflow::errors::Internal(
          absl::StrCat("LHLO opcode ", xla::HloOpcodeString(instr->opcode()),
                       " is not supported."));
  }
}

Status LhloDialectEmitter::DefaultAction(const HloInstruction* instr) {
  return EmitOp(instr).status();
}

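// Emits an lmhlo.sort op: the sort dimension and stability flag are copied
// from the HLO instruction and the comparator computation is imported as the
// op's region.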
StatusOr<lmhlo::SortOp> LhloDialectEmitter::EmitSortOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs<lmhlo::SortOp>(instr));
  auto* sort_instr = xla::Cast<xla::HloSortInstruction>(instr);
  sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension()));
  sort.is_stableAttr(builder_.getBoolAttr(sort_instr->is_stable()));
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      *sort_instr->called_computations()[0], &sort.comparator(), &builder_));
  return sort;
}

// Walks MHLO::TupleOp recursively.
Status WalkTuplePostOrder(Value v,
                          const std::function<Status(Value)>& visitor) {
  if (auto* op = v.getDefiningOp()) {
    if (auto tuple = dyn_cast<mhlo::TupleOp>(op)) {
      for (Value sub_v : tuple.val()) {
        TF_RETURN_IF_ERROR(WalkTuplePostOrder(sub_v, visitor));
      }
      return Status::OK();
    }
  }
  return visitor(v);
}

// This function removes all uses of a fused region argument and rewires those
// uses to a `tensor_load %memref`, where %memref is the caller argument.
//
// It also flattens all input/output tuples into more region arguments /
// results.
StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
    const HloInstruction* root, const Shape& shape,
    xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) {
  if (shape.IsTuple()) {
    llvm::SmallVector<Value, 4> values;
    for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
      shape_index->push_back(i);
      TF_ASSIGN_OR_RETURN(
          auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index,
                                       b, loc));
      values.push_back(v);
      shape_index->pop_back();
    }
    return Value(b->create<mhlo::TupleOp>(loc, values));
  }
  TF_ASSIGN_OR_RETURN(Value memref,
                      GetOrCreateArrayView(root, shape, *shape_index));
  auto load = b->create<TensorLoadOp>(loc, memref);
  if (shape.layout() !=
      xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) {
    llvm::SmallVector<int64_t, 4> minor_to_major(
        shape.layout().minor_to_major().begin(),
        shape.layout().minor_to_major().end());
    load->setAttr("minor_to_major", GetLayoutAttribute(shape.layout(), b));
  }
  return load.getResult();
}

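// Emits an lmhlo.fusion op. Each operand of the fused HLO is rewritten into a
// tensor_load of its buffer view, the fused computation is imported into the
// fusion region, and the results are written back with tensor_store ops.
// GetTupleElement/Tuple pairs created by operand flattening are folded away
// and trivially dead ops are erased before returning.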
StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp(
    const HloInstruction* instr) {
  Location loc = getLocation(instr);

  auto* fusion_instr = xla::Cast<xla::HloFusionInstruction>(instr);

  auto fusion = builder_.create<lmhlo::FusionOp>(getLocation(instr));
  auto after_fusion = builder_.saveInsertionPoint();
  builder_ = mlir::OpBuilder(fusion);

  auto region_builder = OpBuilder::atBlockBegin(&fusion.region().front());

  llvm::SmallVector<Value, 8> arguments;
  for (int i = 0; i < instr->operands().size(); ++i) {
    const HloInstruction* operand = instr->operand(i);
    xla::ShapeIndex shape_index;
    TF_ASSIGN_OR_RETURN(
        auto arg, RewriteFusionOperand(operand, operand->shape(), &shape_index,
                                       &region_builder, loc));
    arguments.push_back(arg);
  }

  TF_ASSIGN_OR_RETURN(Value result,
                      xla::HloFunctionImporter::ImportInstructions(
                          *fusion_instr->fused_instructions_computation(),
                          arguments, &region_builder));

  {
    int i = 0;
    llvm::SmallVector<Value, 4> output;
    TF_RETURN_IF_ERROR(GetOrCreateView(instr, &output));
    TF_RETURN_IF_ERROR(WalkTuplePostOrder(result, [&](Value v) mutable {
      region_builder.create<TensorStoreOp>(loc, v, output[i++]);
      return Status::OK();
    }));
    if (i != output.size()) {
      return xla::InternalError("output sizes don't match");
    }
  }

  // Fold GTE/Tuple pairs.
  //
  // Since the fused region refers to values in its parent region, we can't
  // call applyPatternsAndFoldGreedily. We optimize it manually.
  //
  // Only walk once, because post-ordering is exactly what we need for GTE
  // optimizations.
  fusion.region().walk([](mhlo::GetTupleElementOp gte) {
    SmallVector<Value, 4> folded_values;
    if (succeeded(OpBuilder(gte).tryFold(gte, folded_values))) {
      gte.replaceAllUsesWith(folded_values[0]);
    }
  });

  // Effectively a DCE on the region.
  {
    llvm::SmallVector<mlir::Operation*, 4> ops;
    fusion.region().walk([&](mlir::Operation* op) { ops.push_back(op); });
    // Visit the user first.
    std::reverse(ops.begin(), ops.end());
    for (auto op : ops) {
      if (isOpTriviallyDead(op)) op->erase();
    }
  }

  builder_.restoreInsertionPoint(after_fusion);
  return fusion;
}

StatusOr<mhlo::ScatterDimensionNumbers>
LhloDialectEmitter::GetScatterDimensionNumbers(const HloInstruction* instr) {
  auto* scatter_instr = xla::Cast<xla::HloScatterInstruction>(instr);

  const xla::ScatterDimensionNumbers& xla_scatter_dim =
      scatter_instr->scatter_dimension_numbers();
  auto scatter_dimension_numbers = mhlo::ScatterDimensionNumbers::get(
      GetI64DenseElementsAttr(xla_scatter_dim.update_window_dims()),
      GetI64DenseElementsAttr(xla_scatter_dim.inserted_window_dims()),
      GetI64DenseElementsAttr(xla_scatter_dim.scatter_dims_to_operand_dims()),
      builder_.getI64IntegerAttr(xla_scatter_dim.index_vector_dim()),
      module_.getContext());
  return scatter_dimension_numbers;
}

StatusOr<lmhlo::ScatterOp> LhloDialectEmitter::EmitScatterOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto scatter,
                      CreateOpWithoutAttrs<lmhlo::ScatterOp>(instr));

  // Copy attributes.
  auto* scatter_instr = xla::Cast<xla::HloScatterInstruction>(instr);

  TF_ASSIGN_OR_RETURN(auto scatter_dimension_numbers,
                      GetScatterDimensionNumbers(instr));
  scatter.scatter_dimension_numbersAttr(scatter_dimension_numbers);
  scatter.indices_are_sortedAttr(
      builder_.getBoolAttr(scatter_instr->indices_are_sorted()));
  scatter.unique_indicesAttr(
      builder_.getBoolAttr(scatter_instr->unique_indices()));

  // Import the update computation as a region.
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      *scatter_instr->called_computations()[0], &scatter.update_computation(),
      &builder_));

  return scatter;
}

StatusOr<lmhlo::SelectAndScatterOp> LhloDialectEmitter::EmitSelectAndScatterOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto select_and_scatter,
                      CreateOpWithoutAttrs<lmhlo::SelectAndScatterOp>(instr));

  // Copy attributes.
  auto* select_and_scatter_instr =
      xla::Cast<xla::HloSelectAndScatterInstruction>(instr);
  const xla::Window& window = select_and_scatter_instr->window();

  select_and_scatter.window_dimensionsAttr(
      GetWindowElements(window, [](const xla::WindowDimension& dim) {
        return static_cast<int64_t>(dim.size());
      }));
  select_and_scatter.window_stridesAttr(
      GetWindowElements(window, [](const xla::WindowDimension& dim) {
        return static_cast<int64_t>(dim.stride());
      }));
  select_and_scatter.paddingAttr(
      GetWindowElements(window, [](const xla::WindowDimension& dim) {
        return static_cast<int64_t>(dim.padding_low());
      }));

  // Import the select and scatter computations as regions.
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      *select_and_scatter_instr->select(), &select_and_scatter.select(),
      &builder_));
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      *select_and_scatter_instr->scatter(), &select_and_scatter.scatter(),
      &builder_));
  return select_and_scatter;
}

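// Custom calls that target well-known cuSOLVER, cuBLAS, or cuDNN routines are
// lowered to dedicated LMHLO GPU ops; anything else becomes a generic
// lmhlo.custom_call carrying the call target name and the opaque backend
// config string.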
StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp(
    const HloInstruction* instr) {
  auto* custom_call_instr = xla::Cast<xla::HloCustomCallInstruction>(instr);

  if (xla::gpu::IsCustomCallToCusolver(*instr)) {
    return EmitCholesky(custom_call_instr);
  }

  if (xla::gpu::IsCublasGemm(*instr)) {
    return EmitGemm(custom_call_instr);
  }

  if (xla::gpu::IsCustomCallToDnnConvolution(*instr)) {
    return EmitDnnConvolution(custom_call_instr);
  }

  if (xla::gpu::IsCustomCallToDnnBatchNorm(*instr)) {
    return EmitDnnBatchNorm(custom_call_instr);
  }

  size_t num_arguments, num_results;
  TF_ASSIGN_OR_RETURN(auto custom_call,
                      CreateOpWithoutAttrs<lmhlo::CustomCallOp>(
                          instr, num_arguments, num_results));
  custom_call.call_target_nameAttr(
      builder_.getStringAttr(custom_call_instr->custom_call_target()));
  custom_call.backend_configAttr(
      builder_.getStringAttr(custom_call_instr->opaque()));
  const int32_t segments[2] = {static_cast<int32_t>(num_arguments),
                               static_cast<int32_t>(num_results)};
  custom_call->setAttr(lmhlo::CustomCallOp::getOperandSegmentSizeAttr(),
                       builder_.getI32VectorAttr(segments));
  return custom_call.getOperation();
}

StatusOr<lmhlo_gpu::CholeskyOp> LhloDialectEmitter::EmitCholesky(
    const HloCustomCallInstruction* custom_call) {
  TF_ASSIGN_OR_RETURN(auto cholesky_op,
                      CreateOpWithoutAttrs<lmhlo_gpu::CholeskyOp>(custom_call));
  TF_ASSIGN_OR_RETURN(xla::CholeskyOptions options,
                      custom_call->backend_config<xla::CholeskyOptions>());
  cholesky_op.is_lowerAttr(builder_.getBoolAttr(options.lower()));
  return cholesky_op;
}

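// cuBLAS GEMM custom calls come in two flavors: plain GEMM (two operands) and
// GEMM with bias (three operands, carrying an extra `beta` scale). Both share
// the dot dimension numbers, alpha, batch size, and, when present, the
// selected algorithm from the backend config.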
StatusOr<Operation*> LhloDialectEmitter::EmitGemm(
    const HloCustomCallInstruction* custom_call) {
  TF_ASSIGN_OR_RETURN(
      auto const config,
      custom_call->backend_config<xla::gpu::GemmBackendConfig>());

  auto set_common_attributes = [&](auto op) -> Operation* {
    auto hlo_dims = config.dot_dimension_numbers();
    auto mlir_dims = mhlo::DotDimensionNumbers::get(
        GetI64DenseElementsAttr(hlo_dims.lhs_batch_dimensions()),
        GetI64DenseElementsAttr(hlo_dims.rhs_batch_dimensions()),
        GetI64DenseElementsAttr(hlo_dims.lhs_contracting_dimensions()),
        GetI64DenseElementsAttr(hlo_dims.rhs_contracting_dimensions()),
        builder_.getContext());
    op.dot_dimension_numbersAttr(mlir_dims);
    op.alpha_realAttr(builder_.getF64FloatAttr(config.alpha_real()));
    op.alpha_imagAttr(builder_.getF64FloatAttr(config.alpha_imag()));
    op.batch_sizeAttr(builder_.getI64IntegerAttr(config.batch_size()));
    if (config.algorithm_case() ==
        xla::gpu::GemmBackendConfig::kSelectedAlgorithm) {
      op.algorithmAttr(builder_.getI64IntegerAttr(config.selected_algorithm()));
    }
    return op.getOperation();
  };

  if (custom_call->operand_count() == 2) {
    TF_ASSIGN_OR_RETURN(auto gemm,
                        CreateOpWithoutAttrs<lmhlo_gpu::GEMMOp>(custom_call));
    return set_common_attributes(gemm);
  }

  if (custom_call->operand_count() == 3) {
    TF_ASSIGN_OR_RETURN(
        auto gemm_bias,
        CreateOpWithoutAttrs<lmhlo_gpu::GEMM_BiasOp>(custom_call));
    gemm_bias.betaAttr(builder_.getF64FloatAttr(config.beta()));
    return set_common_attributes(gemm_bias);
  }

  return xla::InvalidArgument("GEMM custom call should have 2 or 3 operands");
}

static StatusOr<mlir::lmhlo_gpu::Activation> GetLHLOActivation(
    stream_executor::dnn::ActivationMode activation) {
  switch (activation) {
    case stream_executor::dnn::kNone:
      return mlir::lmhlo_gpu::Activation::None;
    case stream_executor::dnn::kSigmoid:
      return mlir::lmhlo_gpu::Activation::Sigmoid;
    case stream_executor::dnn::kRelu:
      return mlir::lmhlo_gpu::Activation::Relu;
    case stream_executor::dnn::kRelu6:
      return mlir::lmhlo_gpu::Activation::Relu6;
    case stream_executor::dnn::kReluX:
      return mlir::lmhlo_gpu::Activation::ReluX;
    case stream_executor::dnn::kTanh:
      return mlir::lmhlo_gpu::Activation::Tanh;
    case stream_executor::dnn::kBandPass:
      return mlir::lmhlo_gpu::Activation::BandPass;
    default:
      return xla::InternalError("Unknown activation");
  }
}

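// cuDNN convolution custom calls map to one of four LMHLO GPU ops depending on
// the convolution kind: forward, backward-input, backward-filter, or fused
// forward (with or without a side input). Window, dimension-number, and
// backend-config attributes are shared across all of them.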
StatusOr<Operation*> LhloDialectEmitter::EmitDnnConvolution(
    const HloCustomCallInstruction* custom_call) {
  TF_ASSIGN_OR_RETURN(
      auto const backend_config,
      custom_call->backend_config<xla::gpu::CudnnConvBackendConfig>());

  TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnConvKind kind,
                      xla::gpu::GetCudnnConvKind(custom_call));

  auto get_layout_attribute = [&](const xla::Layout& layout) {
    std::vector<int64_t> minor_to_major(layout.minor_to_major_size());
    absl::c_transform(layout.minor_to_major(), minor_to_major.begin(),
                      [](xla::int64 x) { return static_cast<int64_t>(x); });
    return builder_.getI64ArrayAttr(minor_to_major);
  };

  auto set_common_conv_attributes = [&, this](auto op) -> Operation* {
    const xla::Window& window = custom_call->window();
    // The window size for a cuDNN convolution is the same as the kernel size.
    op.window_stridesAttr(
        GetWindowElements(window, [](const xla::WindowDimension& dim) {
          return static_cast<int64_t>(dim.stride());
        }));
    // cuDNN convolutions require low and high padding to be equal.
    op.paddingAttr(
        GetWindowElements(window, [](const xla::WindowDimension& dim) {
          return static_cast<int64_t>(dim.padding_low());
        }));
    // LHS dilation is encoded in base_dilation of the backend config.
    // RHS dilation is encoded in window_dilation of the backend config.
    op.lhs_dilationAttr(
        GetWindowElements(window, [](const xla::WindowDimension& dim) {
          return static_cast<int64_t>(dim.base_dilation());
        }));
    op.rhs_dilationAttr(
        GetWindowElements(window, [](const xla::WindowDimension& dim) {
          return static_cast<int64_t>(dim.window_dilation());
        }));
    // Set up window reversal.
    auto window_reversal = llvm::to_vector<4>(llvm::map_range(
        window.dimensions(),
        [](const xla::WindowDimension& dim) { return dim.window_reversal(); }));
    auto type = RankedTensorType::get(op.window_strides()->getType().getShape(),
                                      builder_.getIntegerType(/*width=*/1));
    op.window_reversalAttr(DenseElementsAttr::get(type, window_reversal));

    op.dimension_numbersAttr(xla::ConvertConvDimensionNumbers(
        custom_call->convolution_dimension_numbers(), &builder_));
    op.feature_group_countAttr(
        builder_.getI64IntegerAttr(custom_call->feature_group_count()));
    op.batch_group_countAttr(
        builder_.getI64IntegerAttr(custom_call->batch_group_count()));
    op.precision_configAttr(xla::ConvertPrecisionConfig(
        &custom_call->precision_config(), &builder_));
    op.result_scaleAttr(
        builder_.getF64FloatAttr(backend_config.conv_result_scale()));
    auto config = mlir::lmhlo_gpu::ConvolutionBackendConfig::get(
        builder_.getI64IntegerAttr(backend_config.algorithm()),
        builder_.getBoolAttr(backend_config.tensor_ops_enabled()),
        get_layout_attribute(custom_call->operand(0)->shape().layout()),
        get_layout_attribute(custom_call->operand(1)->shape().layout()),
        get_layout_attribute(custom_call->shape().tuple_shapes(0).layout()),
        builder_.getContext());
    op.backend_configAttr(config);

    return op.getOperation();
  };

  auto set_activation = [&, this](auto op) -> Status {
    auto se_activation = static_cast<stream_executor::dnn::ActivationMode>(
        backend_config.activation_mode());
    TF_ASSIGN_OR_RETURN(mlir::lmhlo_gpu::Activation activation,
                        GetLHLOActivation(se_activation));
    StringAttr activation_attr = builder_.getStringAttr(
        mlir::lmhlo_gpu::stringifyActivation(activation));
    op.activation_modeAttr(activation_attr);
    return Status::OK();
  };

  switch (kind) {
    case xla::gpu::CudnnConvKind::kForward: {
      TF_ASSIGN_OR_RETURN(
          auto cnn_forward,
          CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardOp>(custom_call));
      return set_common_conv_attributes(cnn_forward);
    }
    case xla::gpu::CudnnConvKind::kBackwardInput: {
      TF_ASSIGN_OR_RETURN(
          auto cnn_backward,
          CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardInputOp>(custom_call));
      return set_common_conv_attributes(cnn_backward);
    }
    case xla::gpu::CudnnConvKind::kBackwardFilter: {
      TF_ASSIGN_OR_RETURN(
          auto cnn_backward,
          CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardFilterOp>(custom_call));
      return set_common_conv_attributes(cnn_backward);
    }
    case xla::gpu::CudnnConvKind::kForwardActivation: {
      // Fused conv can be either with side input or without.
      if (custom_call->operand_count() == 3) {
        TF_ASSIGN_OR_RETURN(
            auto cnn_fused,
            CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedOp>(custom_call));
        TF_RETURN_IF_ERROR(set_activation(cnn_fused));
        return set_common_conv_attributes(cnn_fused);
      }

      TF_RET_CHECK(custom_call->operand_count() == 4);
      TF_ASSIGN_OR_RETURN(
          auto cnn_fused_side_input,
          CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedSideInputOp>(
              custom_call));
      cnn_fused_side_input.side_input_scaleAttr(
          builder_.getF64FloatAttr(backend_config.side_input_scale()));
      TF_RETURN_IF_ERROR(set_activation(cnn_fused_side_input));
      return set_common_conv_attributes(cnn_fused_side_input);
    }
  }
}

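// cuDNN batch norm custom calls encode epsilon and feature_index as their last
// two (constant) operands. Those operands are dropped from the LMHLO op (hence
// `num_operands - 2`) and re-attached as attributes instead.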
StatusOr<Operation*> LhloDialectEmitter::EmitDnnBatchNorm(
    const HloCustomCallInstruction* custom_call) {
  const xla::int64 num_operands = custom_call->operand_count();
  auto set_batchnorm_attributes = [&](auto op) -> StatusOr<Operation*> {
    // The last 2 operands of a custom call for batch norm are the epsilon and
    // feature_index.
    const HloInstruction* epsilon = custom_call->operand(num_operands - 2);
    TF_RET_CHECK(epsilon->IsConstant());
    float epsilon_value = epsilon->literal().Get<float>({});

    const HloInstruction* feature_index =
        custom_call->operand(num_operands - 1);
    TF_RET_CHECK(feature_index->IsConstant());
    xla::int64 feature_index_value =
        feature_index->literal().Get<xla::int64>({});

    op.epsilonAttr(builder_.getF32FloatAttr(epsilon_value));
    op.feature_indexAttr(builder_.getI64IntegerAttr(feature_index_value));
    return op.getOperation();
  };

  const std::string& target = custom_call->custom_call_target();
  if (target == xla::gpu::kCudnnBatchNormForwardTrainingCallTarget) {
    TF_ASSIGN_OR_RETURN(auto fwd_training,
                        CreateOpWithoutAttrs<lmhlo_gpu::BatchNormTrainingOp>(
                            custom_call, num_operands - 2));
    return set_batchnorm_attributes(fwd_training);
  }

  if (target == xla::gpu::kCudnnBatchNormBackwardCallTarget) {
    TF_ASSIGN_OR_RETURN(auto backward,
                        CreateOpWithoutAttrs<lmhlo_gpu::BatchNormGradOp>(
                            custom_call, num_operands - 2));
    return set_batchnorm_attributes(backward);
  }

  if (target == xla::gpu::kCudnnBatchNormForwardInferenceCallTarget) {
    TF_ASSIGN_OR_RETURN(auto fwd_inference,
                        CreateOpWithoutAttrs<lmhlo_gpu::BatchNormInferenceOp>(
                            custom_call, num_operands - 2));
    return set_batchnorm_attributes(fwd_inference);
  }

  return xla::Unimplemented("Unsupported batch norm operation");
}

// Convert an XLA HLO constant to a global_memref + get_global_memref pair.
StatusOr<mlir::GetGlobalMemrefOp> LhloDialectEmitter::EmitConstant(
    const HloInstruction* instr) {
  // Insert a global_memref in the module.
  Location loc = getLocation(instr);

  auto const_instr = xla::Cast<xla::HloConstantInstruction>(instr);
  TF_RET_CHECK(const_instr->shape().IsArray() &&
               const_instr->shape().is_static());
  TF_ASSIGN_OR_RETURN(Type type, xla::ConvertShapeToType<MemRefType>(
                                     const_instr->shape(), builder_));
  auto memref_type = type.dyn_cast<MemRefType>();
  TF_RET_CHECK(memref_type != nullptr);

  TF_ASSIGN_OR_RETURN(
      DenseElementsAttr initial_value,
      CreateDenseElementsAttrFromLiteral(const_instr->literal(), builder_));

  std::string constant_name = xla::llvm_ir::ConstantNameToGlobalName(
      xla::llvm_ir::SanitizeConstantName(instr->name()));

  // Insert the global memref at the top level.
  {
    OpBuilder::InsertionGuard guard(builder_);
    builder_.clearInsertionPoint();
    auto global_var = builder_.create<GlobalMemrefOp>(
        loc, constant_name, builder_.getStringAttr("private"),
        TypeAttr::get(memref_type), initial_value, true);
    SymbolTable(module_).insert(global_var);
    global_var.getOperation()->moveBefore(&module_.front());

    // For operations that do not fold this constant value in their codegen, we
    // still need to materialize it into a buffer. Since buffer allocation is
    // already done, annotate the global_memref with the information to get to
    // the allocated buffer slice for this constant if need be.
    TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
                        assignment_.GetUniqueTopLevelSlice(instr));
    global_var->setAttr("lmhlo.alloc", builder_.getIndexAttr(slice.index()));
    TF_RET_CHECK(slice.offset() == 0)
        << "Each constant should have its own allocation from BufferAssignment";
    TF_RET_CHECK(slice.allocation()->size() == slice.size())
        << "Each constant should have its own allocation from BufferAssignment";
  }

  auto get_global_memref =
      builder_.create<GetGlobalMemrefOp>(loc, memref_type, constant_name);

  // Update the cache to remember this value.
  auto& cached_value = slices_[std::make_pair(instr, xla::ShapeIndex())];
  TF_RET_CHECK(cached_value == nullptr);
  cached_value = get_global_memref;
  return get_global_memref;
}

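// Emits an lmhlo.reduce op: the reduction dimensions become an attribute and
// the reduction computation is imported as the op's body region.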
StatusOr<lmhlo::ReduceOp> LhloDialectEmitter::EmitReduceOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto reduce_op,
                      CreateOpWithoutAttrs<lmhlo::ReduceOp>(instr));
  auto* reduce = xla::Cast<xla::HloReduceInstruction>(instr);
  std::vector<int64_t> dimensions(reduce->dimensions().begin(),
                                  reduce->dimensions().end());
  reduce_op.dimensionsAttr(GetI64DenseElementsAttr(dimensions));
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      *instr->called_computations()[0], &reduce_op.body(), &builder_));
  return reduce_op;
}

StatusOr<lmhlo::MapOp> LhloDialectEmitter::EmitMapOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto map_op, CreateOpWithoutAttrs<lmhlo::MapOp>(instr));
  auto* map = xla::Cast<xla::HloMapInstruction>(instr);
  std::vector<int64_t> dimensions(map->dimensions().begin(),
                                  map->dimensions().end());
  map_op.dimensionsAttr(GetI64DenseElementsAttr(dimensions));
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      *instr->called_computations()[0], &map_op.computation(), &builder_));
  return map_op;
}

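// Emits an lmhlo.compare op, translating the XLA comparison direction and
// comparison type enums into their MHLO string attribute equivalents.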
StatusOr<lmhlo::CompareOp> LhloDialectEmitter::EmitCompareOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto compare_op,
                      CreateOpWithoutAttrs<lmhlo::CompareOp>(instr));

  auto* compare = xla::Cast<xla::HloCompareInstruction>(instr);
  auto direction = [&]() {
    switch (compare->direction()) {
      case xla::ComparisonDirection::kEq:
        return mhlo::ComparisonDirection::EQ;
      case xla::ComparisonDirection::kNe:
        return mhlo::ComparisonDirection::NE;
      case xla::ComparisonDirection::kGe:
        return mhlo::ComparisonDirection::GE;
      case xla::ComparisonDirection::kGt:
        return mhlo::ComparisonDirection::GT;
      case xla::ComparisonDirection::kLe:
        return mhlo::ComparisonDirection::LE;
      case xla::ComparisonDirection::kLt:
        return mhlo::ComparisonDirection::LT;
    }
  }();
  compare_op.comparison_directionAttr(
      builder_.getStringAttr(stringifyComparisonDirection(direction)));
  auto compare_type = [&]() {
    switch (compare->type()) {
      case xla::Comparison::Type::kFloat:
        return mhlo::ComparisonType::FLOAT;
      case xla::Comparison::Type::kFloatTotalOrder:
        return mhlo::ComparisonType::TOTALORDER;
      case xla::Comparison::Type::kSigned:
        return mhlo::ComparisonType::SIGNED;
      case xla::Comparison::Type::kUnsigned:
        return mhlo::ComparisonType::UNSIGNED;
    }
  }();
  compare_op.compare_typeAttr(
      builder_.getStringAttr(stringifyComparisonType(compare_type)));
  return compare_op;
}

StatusOr<lmhlo::ReducePrecisionOp> LhloDialectEmitter::EmitReducePrecisionOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto reduce_precision_op,
                      CreateOpWithoutAttrs<lmhlo::ReducePrecisionOp>(instr));
  auto* reduce_precision = xla::Cast<xla::HloReducePrecisionInstruction>(instr);
  reduce_precision_op.exponent_bitsAttr(
      builder_.getI32IntegerAttr(reduce_precision->exponent_bits()));
  reduce_precision_op.mantissa_bitsAttr(
      builder_.getI32IntegerAttr(reduce_precision->mantissa_bits()));
  return reduce_precision_op;
}

namespace {
template <typename OpT>
void SetupChannelIdAttribute(OpT op, const xla::HloChannelInstruction* instr,
                             mlir::Builder builder) {
  if (instr->channel_id().has_value()) {
    op.channel_idAttr(mlir::mhlo::ChannelHandle::get(
        builder.getI64IntegerAttr(*instr->channel_id()),
        builder.getI64IntegerAttr(0), builder.getContext()));
  }
}

template <typename OpT>
Status SetupCommonCollectiveOpAttributes(OpT op, const HloInstruction* instr,
                                         mlir::OpBuilder& builder) {
  auto* collective = xla::Cast<xla::HloCollectiveInstruction>(instr);
  auto replica_groups_attr = xla::HloFunctionImporter::ConvertReplicaGroups(
      collective->replica_groups(), &builder);
  op->setAttr(replica_groups_attr.first, replica_groups_attr.second);
  op.constrain_layoutAttr(builder.getBoolAttr(collective->constrain_layout()));
  SetupChannelIdAttribute(op, collective, builder);
  return Status::OK();
}
}  // namespace

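// Collective ops (all-to-all, all-gather, all-reduce, collective-permute)
// share replica_groups, constrain_layout, and channel_id handling via the
// helpers above; each emitter then adds its op-specific attributes.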
StatusOr<lmhlo::AllToAllOp> LhloDialectEmitter::EmitAllToAllOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto all_to_all_op,
                      CreateOpWithoutAttrs<lmhlo::AllToAllOp>(instr));
  auto* all_to_all = xla::Cast<xla::HloAllToAllInstruction>(instr);
  TF_RETURN_IF_ERROR(
      SetupCommonCollectiveOpAttributes(all_to_all_op, instr, builder_));
  if (all_to_all->split_dimension().has_value()) {
    all_to_all_op.split_dimensionAttr(
        builder_.getI64IntegerAttr(*all_to_all->split_dimension()));
  }
  return all_to_all_op;
}

StatusOr<lmhlo::AllGatherOp> LhloDialectEmitter::EmitAllGatherOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto all_gather_op,
                      CreateOpWithoutAttrs<lmhlo::AllGatherOp>(instr));
  auto* all_gather = xla::Cast<xla::HloAllGatherInstruction>(instr);
  TF_RETURN_IF_ERROR(
      SetupCommonCollectiveOpAttributes(all_gather_op, instr, builder_));
  all_gather_op.use_global_device_idsAttr(
      builder_.getBoolAttr(all_gather->use_global_device_ids()));
  all_gather_op.all_gather_dimensionAttr(
      builder_.getI64IntegerAttr(all_gather->all_gather_dimension()));
  return all_gather_op;
}

StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto all_reduce_op,
                      CreateOpWithoutAttrs<lmhlo::AllReduceOp>(instr));
  auto* all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
  TF_RETURN_IF_ERROR(
      SetupCommonCollectiveOpAttributes(all_reduce_op, instr, builder_));
  all_reduce_op.use_global_device_idsAttr(
      builder_.getBoolAttr(all_reduce->use_global_device_ids()));
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      *instr->called_computations()[0], &all_reduce_op.computation(),
      &builder_));
  return all_reduce_op;
}

StatusOr<lmhlo::CollectivePermuteOp>
LhloDialectEmitter::EmitCollectivePermuteOp(const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto permute_op,
                      CreateOpWithoutAttrs<lmhlo::CollectivePermuteOp>(instr));
  auto* permute = xla::Cast<xla::HloCollectivePermuteInstruction>(instr);
  SetupChannelIdAttribute(permute_op, permute, builder_);
  mlir::NamedAttribute source_target_pairs_attr =
      xla::HloFunctionImporter::ConvertSourceTargetPairs(
          permute->source_target_pairs(), &builder_);
  permute_op->setAttr(source_target_pairs_attr.first,
                      source_target_pairs_attr.second);
  return permute_op;
}

StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp(
    const HloInstruction* instr) {
  const HloInfeedInstruction* infeed = xla::Cast<HloInfeedInstruction>(instr);
  // The HLO Infeed instruction has a single operand of token type and a tuple
  // with buffers and a token as its output. The LMHLO Infeed operation does
  // not need the token operand or result, so drop them.
  SmallVector<Value, 2> operands;
  TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{0}));
  auto infeed_op = CreateOpWithoutAttrs<lmhlo::InfeedOp>(instr, operands);
  infeed_op.configAttr(builder_.getStringAttr(infeed->infeed_config()));
  return infeed_op;
}

StatusOr<lmhlo::OutfeedOp> LhloDialectEmitter::EmitOutfeedOp(
    const HloInstruction* instr) {
  const HloOutfeedInstruction* outfeed =
      xla::Cast<HloOutfeedInstruction>(instr);
  // The HLO Outfeed instruction has 2 operands, the source and a token, and a
  // single token output. LMHLO Outfeed does not need the token operand and
  // result, so drop them.
  SmallVector<Value, 2> operands;
  TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(0), &operands));
  auto outfeed_op = CreateOpWithoutAttrs<lmhlo::OutfeedOp>(instr, operands);
  outfeed_op.configAttr(builder_.getStringAttr(outfeed->outfeed_config()));
  return outfeed_op;
}

xla::StatusOr<lmhlo::BroadcastInDimOp> LhloDialectEmitter::EmitBroadcastOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto broadcast,
                      CreateOpWithoutAttrs<lmhlo::BroadcastInDimOp>(instr));
  broadcast.broadcast_dimensionsAttr(
      builder_.getI64TensorAttr(instr->dimensions()));
  return broadcast;
}

xla::StatusOr<lmhlo::ConcatenateOp> LhloDialectEmitter::EmitConcatenateOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto concat,
                      CreateOpWithoutAttrs<lmhlo::ConcatenateOp>(instr));
  auto hlo_concat = xla::Cast<xla::HloConcatenateInstruction>(instr);
  concat.dimensionAttr(
      builder_.getI64IntegerAttr(hlo_concat->concatenate_dimension()));
  return concat;
}

xla::StatusOr<lmhlo::IotaOp> LhloDialectEmitter::EmitIotaOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto iota, CreateOpWithoutAttrs<lmhlo::IotaOp>(instr));
  auto hlo_iota = xla::Cast<xla::HloIotaInstruction>(instr);
  iota.iota_dimensionAttr(
      builder_.getI64IntegerAttr(hlo_iota->iota_dimension()));
  return iota;
}

xla::StatusOr<lmhlo::ReverseOp> LhloDialectEmitter::EmitReverseOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto reverse,
                      CreateOpWithoutAttrs<lmhlo::ReverseOp>(instr));
  reverse.dimensionsAttr(builder_.getI64TensorAttr(instr->dimensions()));
  return reverse;
}

xla::StatusOr<lmhlo::TransposeOp> LhloDialectEmitter::EmitTransposeOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto transpose,
                      CreateOpWithoutAttrs<lmhlo::TransposeOp>(instr));
  transpose.permutationAttr(builder_.getI64TensorAttr(instr->dimensions()));
  return transpose;
}

xla::StatusOr<lmhlo::PadOp> LhloDialectEmitter::EmitPadOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto pad, CreateOpWithoutAttrs<lmhlo::PadOp>(instr));
  auto hlo_pad = xla::Cast<xla::HloPadInstruction>(instr);
  std::vector<xla::int64> low, high, interior;
  for (const auto& dim : hlo_pad->padding_config().dimensions()) {
    low.push_back(dim.edge_padding_low());
    high.push_back(dim.edge_padding_high());
    interior.push_back(dim.interior_padding());
  }
  pad.edge_padding_lowAttr(builder_.getI64TensorAttr(low));
  pad.edge_padding_highAttr(builder_.getI64TensorAttr(high));
  pad.interior_paddingAttr(builder_.getI64TensorAttr(interior));
  return pad;
}

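// Emits an lmhlo.reduce_window op. Window dimensions, strides, dilations, and
// padding are flattened from the HLO window config; padding is stored as an
// Nx2 tensor of (low, high) pairs. The reduction computation is imported as
// the op's body region.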
EmitReduceWindowOp(const xla::HloInstruction * instr)1191 xla::StatusOr<lmhlo::ReduceWindowOp> LhloDialectEmitter::EmitReduceWindowOp(
1192 const xla::HloInstruction* instr) {
1193 TF_ASSIGN_OR_RETURN(auto reduce_window,
1194 CreateOpWithoutAttrs<lmhlo::ReduceWindowOp>(instr));
1195 auto hlo_reduce_window = xla::Cast<xla::HloReduceWindowInstruction>(instr);
1196 std::vector<xla::int64> dims, strides, base_dilations, window_dilations,
1197 paddings;
1198 for (const auto& dim : hlo_reduce_window->window().dimensions()) {
1199 dims.push_back(dim.size());
1200 strides.push_back(dim.stride());
1201 base_dilations.push_back(dim.base_dilation());
1202 window_dilations.push_back(dim.window_dilation());
1203 paddings.push_back(dim.padding_low());
1204 paddings.push_back(dim.padding_high());
1205 }
1206 reduce_window.window_dimensionsAttr(builder_.getI64TensorAttr(dims));
1207 if (xla::window_util::HasStride(hlo_reduce_window->window())) {
1208 reduce_window.window_stridesAttr(builder_.getI64TensorAttr(strides));
1209 }
1210 if (xla::window_util::HasBaseDilation(hlo_reduce_window->window())) {
1211 reduce_window.base_dilationsAttr(builder_.getI64TensorAttr(base_dilations));
1212 }
1213 if (xla::window_util::HasWindowDilation(hlo_reduce_window->window())) {
1214 reduce_window.window_dilationsAttr(
1215 builder_.getI64TensorAttr(window_dilations));
1216 }
1217 CHECK_EQ(0, paddings.size() % 2);
1218 if (xla::window_util::HasPadding(hlo_reduce_window->window())) {
1219 reduce_window.paddingAttr(DenseIntElementsAttr::get(
1220 RankedTensorType::get({static_cast<int64_t>(paddings.size() / 2), 2},
1221 builder_.getIntegerType(64)),
1222 paddings));
1223 }
1224 TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
1225 *hlo_reduce_window->called_computations()[0], &reduce_window.body(),
1226 &builder_));
1227 return reduce_window;
1228 }
1229
EmitSliceOp(const xla::HloInstruction * instr)1230 xla::StatusOr<lmhlo::SliceOp> LhloDialectEmitter::EmitSliceOp(
1231 const xla::HloInstruction* instr) {
1232 TF_ASSIGN_OR_RETURN(auto slice, CreateOpWithoutAttrs<lmhlo::SliceOp>(instr));
1233 auto hlo_slice = xla::Cast<xla::HloSliceInstruction>(instr);
1234 slice.start_indicesAttr(builder_.getI64TensorAttr(hlo_slice->slice_starts()));
1235 slice.limit_indicesAttr(builder_.getI64TensorAttr(hlo_slice->slice_limits()));
1236 slice.stridesAttr(builder_.getI64TensorAttr(hlo_slice->slice_strides()));
1237 return slice;
1238 }
1239
EmitGatherOp(const xla::HloInstruction * instr)1240 xla::StatusOr<lmhlo::GatherOp> LhloDialectEmitter::EmitGatherOp(
1241 const xla::HloInstruction* instr) {
1242 TF_ASSIGN_OR_RETURN(auto gather,
1243 CreateOpWithoutAttrs<lmhlo::GatherOp>(instr));
1244 auto hlo_gather = xla::Cast<xla::HloGatherInstruction>(instr);
1245 gather.dimension_numbersAttr(xla::ConvertGatherDimensionNumbers(
1246 hlo_gather->gather_dimension_numbers(), &builder_));
1247 gather.slice_sizesAttr(builder_.getI64TensorAttr(
1248 std::vector<int64_t>(hlo_gather->gather_slice_sizes().begin(),
1249 hlo_gather->gather_slice_sizes().end())));
1250 return gather;
1251 }
1252
EmitDynamicSliceOp(const xla::HloInstruction * instr)1253 xla::StatusOr<lmhlo::DynamicSliceOp> LhloDialectEmitter::EmitDynamicSliceOp(
1254 const xla::HloInstruction* instr) {
1255 TF_ASSIGN_OR_RETURN(auto dynamic_slice,
1256 CreateOpWithoutAttrs<lmhlo::DynamicSliceOp>(instr));
1257 auto hlo_dynamic_slice = xla::Cast<xla::HloDynamicSliceInstruction>(instr);
1258 dynamic_slice.slice_sizesAttr(
1259 builder_.getI64TensorAttr(hlo_dynamic_slice->dynamic_slice_sizes()));
1260 return dynamic_slice;
1261 }
1262
EmitDotOp(const xla::HloInstruction * instr)1263 xla::StatusOr<lmhlo::DotOp> LhloDialectEmitter::EmitDotOp(
1264 const xla::HloInstruction* instr) {
1265 TF_ASSIGN_OR_RETURN(auto dot, CreateOpWithoutAttrs<lmhlo::DotOp>(instr));
1266 auto hlo_dot = xla::Cast<xla::HloDotInstruction>(instr);
1267 dot.dot_dimension_numbersAttr(xla::ConvertDotDimensionNumbers(
1268 hlo_dot->dot_dimension_numbers(), &builder_));
1269 dot.precision_configAttr(
1270 xla::ConvertPrecisionConfig(&hlo_dot->precision_config(), &builder_));
1271 return dot;
1272 }
1273
1274 xla::StatusOr<lmhlo::RngGetAndUpdateStateOp>
EmitRngGetAndUpdateStateOp(const xla::HloInstruction * instr)1275 LhloDialectEmitter::EmitRngGetAndUpdateStateOp(
1276 const xla::HloInstruction* instr) {
1277 TF_ASSIGN_OR_RETURN(
1278 auto rng, CreateOpWithoutAttrs<lmhlo::RngGetAndUpdateStateOp>(instr));
1279 auto hlo_rng = xla::Cast<xla::HloRngGetAndUpdateStateInstruction>(instr);
1280 rng.deltaAttr(builder_.getI64IntegerAttr(hlo_rng->delta()));
1281 return rng;
1282 }
1283
EmitFftOp(const HloInstruction * instr)1284 xla::StatusOr<lmhlo::FftOp> LhloDialectEmitter::EmitFftOp(
1285 const HloInstruction* instr) {
1286 auto hlo_fft = xla::Cast<xla::HloFftInstruction>(instr);
1287 TF_ASSIGN_OR_RETURN(auto fft, CreateOpWithoutAttrs<lmhlo::FftOp>(instr));
1288 TF_ASSIGN_OR_RETURN(mlir::mhlo::FftType fft_type,
1289 xla::ConvertFftType(hlo_fft->fft_type()));
1290 StringAttr fft_type_attr =
1291 builder_.getStringAttr(mlir::mhlo::stringifyFftType(fft_type));
1292 fft.fft_typeAttr(fft_type_attr);
1293 fft.fft_lengthAttr(GetI64DenseElementsAttr(instr->fft_length()));
1294 return fft;
1295 }
1296
xla::StatusOr<lmhlo::TriangularSolveOp>
LhloDialectEmitter::EmitTriangularSolveOp(const xla::HloInstruction* instr) {
  auto hlo_triangular_solve =
      xla::Cast<xla::HloTriangularSolveInstruction>(instr);
  TF_ASSIGN_OR_RETURN(auto triangular_solve,
                      CreateOpWithoutAttrs<lmhlo::TriangularSolveOp>(instr));
  const xla::TriangularSolveOptions& options =
      hlo_triangular_solve->triangular_solve_options();
  triangular_solve.left_sideAttr(builder_.getBoolAttr(options.left_side()));
  triangular_solve.lowerAttr(builder_.getBoolAttr(options.lower()));
  triangular_solve.unit_diagonalAttr(
      builder_.getBoolAttr(options.unit_diagonal()));
  TF_ASSIGN_OR_RETURN(mlir::mhlo::Transpose transpose,
                      xla::ConvertTranspose(options.transpose_a()));
  triangular_solve.transpose_aAttr(
      builder_.getStringAttr(mlir::mhlo::stringifyTranspose(transpose)));
  triangular_solve.layout_aAttr(
      GetLayoutAttribute(instr->operand(0)->shape().layout(), &builder_));
  triangular_solve.layout_bAttr(
      GetLayoutAttribute(instr->operand(1)->shape().layout(), &builder_));
  triangular_solve.layout_outputAttr(
      GetLayoutAttribute(instr->shape().layout(), &builder_));
  return triangular_solve;
}

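// Converts a layout's minor_to_major order into an index tensor attribute.
// For example, a row-major 2D layout has minor_to_major = {1, 0} and becomes
// the index tensor [1, 0].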
mlir::DenseIntElementsAttr LhloDialectEmitter::GetLayoutAttribute(
    const xla::Layout& layout, Builder* builder) {
  llvm::SmallVector<int64_t, 4> minor_to_major(layout.minor_to_major().begin(),
                                               layout.minor_to_major().end());
  return builder->getIndexTensorAttr(minor_to_major);
}

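// Returns (and caches) a memref value that views the buffer slice assigned
// to `instr` at `shape_index`. Constants are materialized directly;
// otherwise a ViewOp into the underlying allocation is emitted, followed by
// a MemRefReinterpretCastOp when the logical layout differs from the
// physical one.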
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
    const xla::HloInstruction* instr, const xla::Shape& current_shape,
    const xla::ShapeIndex& shape_index) {
  // Cache generated ViewOp and StaticMemRefCastOp by (instruction,
  // shape_index).
  auto& cached_value = slices_[std::make_pair(instr, shape_index)];
  if (cached_value) {
    return cached_value;
  }

  if (instr->IsConstant() && shape_index.empty()) {
    TF_ASSIGN_OR_RETURN(Value constant_memref, EmitConstant(instr));
    return cached_value = constant_memref;
  }

  // If the shape happens to have dynamic dimensions, create the memref using
  // the underlying static shape.
  // TODO(jurahul): Revisit this when we can model memrefs with dynamic shape
  // but static bounds in MLIR.
  const Shape static_shape = xla::ShapeUtil::MakeStaticShape(current_shape);

  TF_ASSIGN_OR_RETURN(Type out_type, xla::ConvertShapeToType<MemRefType>(
                                         static_shape, builder_));
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
                      assignment_.GetUniqueSlice(instr, shape_index));
  Value alloc = allocations_[slice.allocation()];
  if (alloc.getType() == out_type && slice.offset() == 0) {
    return cached_value = alloc;
  }

  auto out_memref_type = out_type.dyn_cast<MemRefType>();
  if (!out_memref_type)
    return tensorflow::errors::Internal(
        "Expected memref type when creating a view for leaf type of a "
        "tuple.");

  Value byte_shift =
      builder_.create<ConstantIndexOp>(alloc.getLoc(), slice.offset());

  xla::Shape physical_shape =
      xla::ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
          static_shape);
  TF_ASSIGN_OR_RETURN(
      Type physical_out_type,
      xla::ConvertShapeToType<MemRefType>(physical_shape, builder_));

  // TODO(timshen): revisit location handling.
  Location loc = builder_.getUnknownLoc();

  // ViewOp only takes memrefs without affine maps (layouts). Let ViewOp
  // produce the physical shape (where dimensions are ordered in major to
  // minor) first, then follow up with a MemRefReinterpretCast to cast the
  // resulting memref to the original layout.
  Value result =
      builder_.create<ViewOp>(loc, physical_out_type, alloc, byte_shift,
                              /*sizes=*/ValueRange{});
  if (physical_out_type != out_type) {
    int64_t out_offset;
    SmallVector<int64_t, 4> out_strides;
    if (failed(getStridesAndOffset(out_memref_type, out_strides, out_offset)))
      return tensorflow::errors::Internal(
          "Failed to get strides and offset from the output type.");
    result = builder_.create<MemRefReinterpretCastOp>(
        loc, out_memref_type, result, out_offset, out_memref_type.getShape(),
        out_strides);
  }
  return cached_value = result;
}

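// Recursively walks a (possibly tuple-shaped) result and appends one memref
// view per leaf array to `values`, extending `current_shape_index` while
// descending into tuple elements.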
Status LhloDialectEmitter::GetOrCreateViewImpl(
    const HloInstruction* instr, const Shape& current_shape,
    xla::ShapeIndex* current_shape_index, SmallVectorImpl<Value>* values) {
  if (current_shape.IsTuple()) {
    for (int i = 0; i < current_shape.tuple_shapes().size(); ++i) {
      current_shape_index->push_back(i);
      TF_RETURN_IF_ERROR(GetOrCreateViewImpl(
          instr, current_shape.tuple_shapes(i), current_shape_index, values));
      current_shape_index->pop_back();
    }
    return Status::OK();
  }
  if (current_shape.IsArray()) {
    TF_ASSIGN_OR_RETURN(auto v, GetOrCreateArrayView(instr, current_shape,
                                                     *current_shape_index));
    values->push_back(v);
    return Status::OK();
  }
  return xla::InternalError("Unexpected shape kind for %s and shape index %s",
                            instr->ToString(),
                            current_shape_index->ToString());
}

// Returns a view for the result of an instruction.
// We first get a view for the slice in the allocation, and then may need to
// create another view to adjust the slice for the shape of the instruction.
Status LhloDialectEmitter::GetOrCreateView(
    const HloInstruction* instr, SmallVectorImpl<Value>* values,
    const xla::ShapeIndex& result_subset) {
  xla::ShapeIndex shape_index = result_subset;
  const Shape& sub_shape =
      xla::ShapeUtil::GetSubshape(instr->shape(), shape_index);
  return GetOrCreateViewImpl(instr, sub_shape, &shape_index, values);
}

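// Builds the LMHLO function skeleton for the computation: tags the module
// with the hlo.unique_id attribute, creates a FuncOp whose arguments are the
// buffer allocations, and annotates each argument with lmhlo.alloc plus
// either lmhlo.params or lmhlo.liveout. The builder is left positioned just
// before the function terminator so that subsequent ops are emitted into the
// body. The resulting signature looks roughly like this (illustrative types
// only):
//
//   func @main(%arg0: memref<16xf32> {lmhlo.alloc = 0 : index,
//                                     lmhlo.params = 0 : index},
//              %arg1: memref<64xi8> {lmhlo.alloc = 1 : index,
//                                    lmhlo.liveout = true}) {
//     ...
//   }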
Status LhloDialectEmitter::Initialize() {
  mlir::IntegerAttr unique_id =
      builder_.getI32IntegerAttr(computation_.parent()->unique_id());
  module_->setAttr("hlo.unique_id", unique_id);
  std::string function_name =
      computation_.name().empty() ? "__compute" : computation_.name();

  // Create the function as () -> (); we'll compute the arguments from the
  // buffer allocations and update the function type afterwards.
  auto func_op = FuncOp::create(builder_.getUnknownLoc(), function_name,
                                builder_.getFunctionType({}, {}));
  Block* block = func_op.addEntryBlock();

  llvm::SmallVector<const BufferAllocation*, 8> ordered_allocations;
  for (const BufferAllocation& alloc : assignment_.Allocations())
    ordered_allocations.push_back(&alloc);

  if (computation_.IsEntryComputation()) {
    // Sort the rather arbitrarily ordered allocations to match the
    // input/output parameters. Specifically, we want to sort buffer
    // allocations in the following order:
    // * Parameters always order before non-parameters.
    // * Different parameters order by parameter number.
    // * Different allocations for the same parameter order by the shape
    //   index.
    //
    // TODO(timshen): there should be only one non-parameter buffer, the temp
    // buffer. Check on that.
    const auto allocation_comparator = [](const BufferAllocation* lhs,
                                          const BufferAllocation* rhs) {
      if (lhs->is_entry_computation_parameter() !=
          rhs->is_entry_computation_parameter()) {
        return lhs->is_entry_computation_parameter() >
               rhs->is_entry_computation_parameter();
      }
      if (lhs->is_entry_computation_parameter()) {
        return std::tuple<int, const xla::ShapeIndex&>(
                   lhs->parameter_number(), lhs->param_shape_index()) <
               std::tuple<int, const xla::ShapeIndex&>(
                   rhs->parameter_number(), rhs->param_shape_index());
      }
      return false;
    };

    std::stable_sort(ordered_allocations.begin(), ordered_allocations.end(),
                     allocation_comparator);
  }
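  // For example (illustrative only): allocations for entry parameter 0
  // (shape index {}), parameter 1 (shape index {}), and the temp buffer end
  // up in exactly that order after the sort.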

  // The function signature will be composed of:
  // - one memref for each of the parameters.
  // - one memref for each other buffer allocation.
  llvm::SmallVector<DictionaryAttr, 8> args_attrs;
  for (const BufferAllocation* alloc : ordered_allocations) {
    if (computation_.IsEntryComputation() &&
        alloc->is_entry_computation_parameter()) {
      const xla::Shape& buffer_shape = xla::ShapeUtil::GetSubshape(
          computation_.parameter_instruction(alloc->parameter_number())
              ->shape(),
          alloc->param_shape_index());

      TF_ASSIGN_OR_RETURN(auto arg_type, xla::ConvertShapeToType<MemRefType>(
                                             buffer_shape, builder_));

      // First map parameters to memrefs on the operation.
      block->addArgument(arg_type);
      allocations_[alloc] = block->getArguments().back();
      NamedAttrList arg_attr_list;
      arg_attr_list.set("lmhlo.alloc", builder_.getIndexAttr(alloc->index()));
      arg_attr_list.set("lmhlo.params",
                        builder_.getIndexAttr(alloc->parameter_number()));
      args_attrs.push_back(arg_attr_list.getDictionary(builder_.getContext()));
    } else {
      block->addArgument(MemRefType::get({alloc->size()}, i8_type_));
      allocations_[alloc] = block->getArguments().back();

      NamedAttrList arg_attr_list;
      arg_attr_list.set("lmhlo.alloc", builder_.getIndexAttr(alloc->index()));
      arg_attr_list.set("lmhlo.liveout", builder_.getBoolAttr(true));
      args_attrs.push_back(arg_attr_list.getDictionary(builder_.getContext()));
    }
  }

  FunctionType function_type =
      builder_.getFunctionType(block->getArgumentTypes(), {});
  func_op.setType(function_type);
  func_op.setAllArgAttrs(args_attrs);

  SymbolTable symbol_table(module_);
  symbol_table.insert(func_op);
  builder_.setInsertionPointToEnd(block);

  auto return_op = builder_.create<ReturnOp>(builder_.getUnknownLoc());
  builder_ = OpBuilder(return_op);

  return Status::OK();
}

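// Factory for the pass registered below under "xla-hlo-to-lhlo-with-xla".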
std::unique_ptr<OperationPass<ModuleOp>> createXlaHloToLhloWithXlaPass() {
  return std::make_unique<XlaHloToLhloPass>();
}

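// Converts an HLO module (with its buffer assignment) into an LMHLO module:
// loads the required dialects, emits the function skeleton, and then visits
// the entry computation's instructions in the buffer assignment's sequential
// order, emitting one LMHLO op per HLO instruction.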
Status HloToLhloModule(const BufferAssignment& assignment,
                       const HloModule& hlo_module, ModuleOp module) {
  module.getContext()
      ->loadDialect<StandardOpsDialect, mhlo::MhloDialect, lmhlo::LmhloDialect,
                    lmhlo_gpu::LmhloGpuDialect>();
  const HloComputation* computation = hlo_module.entry_computation();

  LhloDialectEmitter emitter(assignment, *computation, module);
  TF_RETURN_IF_ERROR(emitter.Initialize());

  const xla::HloInstructionSequence* schedule =
      assignment.hlo_ordering().SequentialOrder(*computation);
  if (!schedule)
    return xla::Unimplemented("Missing sequential order for the computation");
  const std::vector<HloInstruction*>& ordering = schedule->instructions();
  return computation->AcceptOrdered(&emitter, ordering);
}

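// Translation entry point: parses HLO text, converts it through the XLA
// pipeline via ConvertModule (here targeting the "Host" platform), and
// returns the resulting LMHLO module. CHECK-fails on parse or conversion
// errors.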
OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input,
                                               MLIRContext* context) {
  StatusOr<std::unique_ptr<HloModule>> maybe_module =
      xla::ParseAndReturnUnverifiedModule(
          absl::string_view(input.data(), input.size()));
  TF_CHECK_OK(maybe_module.status());

  OwningModuleRef module = ModuleOp::create(UnknownLoc::get(context));

  TF_CHECK_OK(
      ConvertModule(maybe_module.ConsumeValueOrDie(), module.get(), "Host"));

  return module;
}

static PassRegistration<XlaHloToLhloPass> registration(
    "xla-hlo-to-lhlo-with-xla",
    "Emit LHLO from HLO using the existing XLA implementation");

}  // namespace mlir