/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_

#include "absl/types/optional.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace mlir {

// This class processes an HloModule together with the supplied
// BufferAssignment and populates the MLIR ModuleOp with the computation
// converted to the LHLO dialect.
class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
 public:
  // Initializes internal data structures. It must be called before calling
  // any of the visitors.
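  //
  // A rough usage sketch (hypothetical driver code; `ordering` is assumed to
  // be a valid schedule of the computation's instructions):
  //
  //   LhloDialectEmitter emitter(assignment, computation, module);
  //   TF_RETURN_IF_ERROR(emitter.Initialize());
  //   TF_RETURN_IF_ERROR(computation.AcceptOrdered(&emitter, ordering));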
  tensorflow::Status Initialize();

  LhloDialectEmitter(const xla::BufferAssignment& assignment,
                     const xla::HloComputation& computation, ModuleOp module)
      : assignment_(assignment),
        computation_(computation),
        module_(module),
        builder_(module.getContext()),
        i8_type_(builder_.getIntegerType(8)) {}

  xla::StatusOr<mlir::Operation*> EmitOp(const xla::HloInstruction* instr);

  static xla::StatusOr<mhlo::ScatterDimensionNumbers>
  GetScatterDimensionNumbers(const xla::HloInstruction* instr,
                             mlir::MLIRContext* context);

 private:
  xla::StatusOr<lmhlo::SortOp> EmitSortOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::FusionOp> EmitFusionOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::SelectAndScatterOp> EmitSelectAndScatterOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<Operation*> EmitCustomCallOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo_gpu::CholeskyOp> EmitCholesky(
      const xla::HloCustomCallInstruction* custom_call);
  xla::StatusOr<Operation*> EmitGemm(
      const xla::HloCustomCallInstruction* custom_call);
  xla::StatusOr<Operation*> EmitDnnConvolution(
      const xla::HloCustomCallInstruction* custom_call);
  xla::StatusOr<Operation*> EmitDnnBatchNorm(
      const xla::HloCustomCallInstruction* custom_call);

  xla::StatusOr<memref::GetGlobalOp> EmitConstant(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::InfeedOp> EmitInfeedOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::OutfeedOp> EmitOutfeedOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::AllToAllOp> EmitAllToAllOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::AllGatherOp> EmitAllGatherOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo_gpu::AllReduceStartOp> EmitAllReduceStartOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo_gpu::AllReduceDoneOp> EmitAllReduceDoneOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::ReduceScatterOp> EmitReduceScatterOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::CollectivePermuteOp> EmitCollectivePermuteOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::RngGetAndUpdateStateOp> EmitRngGetAndUpdateStateOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::FftOp> EmitFftOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::TriangularSolveOp> EmitTriangularSolveOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<Operation*> EmitBitcast(const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::CaseOp> EmitCaseOp(const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::WhileOp> EmitWhileOp(const xla::HloInstruction* instr);

  xla::Status ImportAsLmhloRegion(xla::HloComputation* computation,
                                  mlir::Region* region);

  // Since the LMHLO dialect does not define token types, this enum controls
  // how token operands/results from XLA:HLO are lowered to MLIR.
  enum class TokenLoweringMode {
    kFailToLower,  // Fail lowering if token inputs are encountered.
    kUseNull,      // Use a null Value in the operand list for each token.
    // kSkip,     // Skip any token inputs or outputs (not yet needed)
  };

  // Creates LHLO operation operands given an XLA HLO instruction. By default,
  // all XLA HLO operands and results are converted to MLIR and appended to
  // `operands`. If `num_operands` is specified, only the first `num_operands`
  // operands of the instruction are converted to MLIR. The function returns
  // the actual number of operands and results generated for MLIR in
  // `num_arguments` and `num_results`.
  xla::Status CreateOperands(const xla::HloInstruction* instr,
                             absl::optional<xla::int64> num_operands,
                             TokenLoweringMode token_mode,
                             SmallVectorImpl<Value>& operands,
                             size_t& num_arguments, size_t& num_results);

  template <typename OpType>
  xla::StatusOr<OpType> CreateOpWithoutAttrs(
      const xla::HloInstruction* instr,
      absl::optional<xla::int64> num_operands = absl::nullopt) {
    size_t unused;
    return CreateOpWithoutAttrs<OpType>(instr, unused, unused, num_operands);
  }

  template <typename OpType>
  xla::StatusOr<OpType> CreateOpWithoutAttrs(
      const xla::HloInstruction* instr, size_t& num_arguments,
      size_t& num_results,
      absl::optional<xla::int64> num_operands = absl::nullopt);

  template <typename OpType>
  OpType CreateOpWithoutAttrs(const xla::HloInstruction* instr,
                              ValueRange operands);

  xla::StatusOr<mlir::Operation*> CreateOpInFusion(
      const xla::HloInstruction* instr, ValueRange buffer_operands,
      size_t num_arguments, size_t num_results);

  xla::StatusOr<mlir::Operation*> CreateOpInFusion(
      const xla::HloInstruction* instr);

  // Converts an arbitrary container of int64 values into an MLIR i64 tensor
  // attribute.
  template <typename T>
  DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) {
    return builder_.getI64TensorAttr(
        {container.data(), static_cast<size_t>(container.size())});
  }

  // Extracts one field (e.g. stride or padding) from every dimension of an
  // XLA window and packs the values into an i64 tensor attribute.
  DenseIntElementsAttr GetWindowElements(
      const xla::Window& window,
      std::function<int64_t(const xla::WindowDimension& dim)> getter) {
    llvm::SmallVector<int64_t, 4> elements;
    elements.reserve(window.dimensions_size());
    for (const xla::WindowDimension& dim : window.dimensions()) {
      elements.push_back(getter(dim));
    }
    return GetI64DenseElementsAttr(elements);
  }

  static mlir::DenseIntElementsAttr GetLayoutAttribute(
      const xla::Layout& layout, Builder* builder);

  tensorflow::Status DefaultAction(const xla::HloInstruction* instr) final;

  // Computation parameters don't need any specific handling when they are
  // visited; they are already processed when we enter a new computation.
  tensorflow::Status HandleParameter(const xla::HloInstruction* instr) final {
    return tensorflow::Status::OK();
  }

  // Helper function that recursively visits the tuple structure in
  // `current_shape` and reconstructs a matching lmhlo::TupleOp.
  // Each leaf node is converted to an std.view op with corresponding offsets.
  // If the shape is not a tuple, it simply returns a view of the buffer.
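  //
  // For example (illustrative only): a result with the tuple shape
  // (f32[4], f32[8]) yields two leaf views, each created at the offset that
  // the BufferAssignment records for that leaf within its allocation.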
  tensorflow::Status GetOrCreateViewImpl(const xla::HloInstruction* instr,
                                         const xla::Shape& current_shape,
                                         xla::ShapeIndex* current_shape_index,
                                         SmallVectorImpl<Value>* values,
                                         TokenLoweringMode token_mode);

  // Helper function to create a view (or tuple of views) to a buffer for a
  // given instruction result. `result_subset` can be used for instructions
  // that have a tuple result when the MLIR conversion needs to convert only
  // one of the tuple elements. Note that, if needed, this can be extended to
  // take a list of ShapeIndex values in case we need finer control over which
  // elements of the output tuple are converted to MLIR.
  tensorflow::Status GetOrCreateView(
      const xla::HloInstruction* instr, SmallVectorImpl<Value>* values,
      const xla::ShapeIndex& result_subset = {},
      TokenLoweringMode token_mode = TokenLoweringMode::kFailToLower);

  xla::StatusOr<Value> GetOrCreateArrayView(
      const xla::HloInstruction* instr, const xla::Shape& current_shape,
      const xla::ShapeIndex& current_shape_index);

  xla::StatusOr<Value> RewriteFusionOperand(const xla::HloInstruction* root,
                                            const xla::Shape& shape,
                                            xla::ShapeIndex* shape_index,
                                            OpBuilder* b, Location loc);

  // Returns an MLIR location for an HLO instruction.
  Location getLocation(const xla::HloInstruction* inst) {
    return NameLoc::get(builder_.getIdentifier(inst->name()));
  }

  // This map provides access to MLIR buffers for each HLO buffer allocation.
  // The MLIR buffers are all `memref<{size}xi8>` and correspond to function
  // parameters. It is populated at the beginning of the processing with all
  // the buffer allocations and is unchanged afterward. Every HloInstruction
  // uses a "slice" of a buffer allocation, which provides the shape, layout,
  // and dtype. An MLIR view is used separately to model slices into the
  // allocations (see below).
  llvm::DenseMap<const xla::BufferAllocation*, Value> allocations_;

  // This map provides access to MLIR buffers for each HLO instruction, keyed
  // by instruction identity. A slice is contained in a BufferAllocation, and
  // has an offset and a size.
  //
  // As for why we don't use HloInstruction*, see GetOrCreateView(), but
  // mostly we want to make better use of the aliased buffers.
  //
  // If the HloInstruction is a tuple, all leaf nodes are stored flattened.
  // Otherwise, there will be a single buffer.
  //
  // An MLIR buffer is either an input parameter, or a ViewOp in the case
  // where the slice is only part of its allocation.
  //
  // `slices_` is populated lazily in the `GetOrCreateView()` helper as we
  // process every instruction.
  absl::flat_hash_map<std::pair<const xla::HloInstruction*, xla::ShapeIndex>,
                      Value>
      slices_;

  // The BufferAssignment computed by XLA ahead of time.
  const xla::BufferAssignment& assignment_;

  // The HLO computation that will be converted.
  const xla::HloComputation& computation_;

  // This is the MLIR module in which a function will be created for every HLO
  // computation.
  ModuleOp module_;

  // The builder keeps track of the current insertion point in the MLIR
  // module.
  OpBuilder builder_;

  // Convenient "cached" access to this widely used MLIR type (i8).
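  // Buffer allocations are exposed as `memref<{size}xi8>` (see `allocations_`
  // above), so this byte element type recurs throughout the lowering.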
  Type i8_type_;

  // Map all-reduce-start ops to their LHLO op, so we can connect the
  // all-reduce-done op with the correct token.
  absl::flat_hash_map<const xla::HloInstruction*, lmhlo_gpu::AllReduceStartOp>
      all_reduce_start_ops_;
};

// Populate the MLIR `module` with the computation from the `hlo_module` using
// the provided buffer `assignment`. The returned `Status` indicates success
// or failure in the conversion.
tensorflow::Status HloToLhloModule(const xla::BufferAssignment& assignment,
                                   const xla::HloModule& hlo_module,
                                   ModuleOp module);

tensorflow::Status OptimizeAndConvertHloToLmhlo(
    std::unique_ptr<xla::HloModule> hlo_module, ModuleOp module,
    StringRef platform_name);

OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input,
                                               MLIRContext* context);

}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_