/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_

#include "absl/types/optional.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace mlir {

// This class processes an HloModule together with the supplied
// BufferAssignment and populates the MLIR ModuleOp with the computation
// converted to the LHLO dialect.
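//
// Example usage (an illustrative sketch, not the exact driver code; the
// `assignment`, `computation`, and `module` values below are assumed to be
// provided by the caller):
//
//   LhloDialectEmitter emitter(assignment, computation, module);
//   TF_RETURN_IF_ERROR(emitter.Initialize());
//   for (const xla::HloInstruction* instr :
//        computation.MakeInstructionPostOrder()) {
//     TF_ASSIGN_OR_RETURN(mlir::Operation* op, emitter.EmitOp(instr));
//     // ... use `op` ...
//   }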
class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
 public:
  // Initializes internal data structures. It must be called before calling
  // any of the visitors.
  tensorflow::Status Initialize();

  LhloDialectEmitter(const xla::BufferAssignment& assignment,
                     const xla::HloComputation& computation, ModuleOp module)
      : assignment_(std::move(assignment)),
        computation_(computation),
        module_(module),
        builder_(module.getContext()),
        i8_type_(builder_.getIntegerType(8)) {}

  xla::StatusOr<mlir::Operation*> EmitOp(const xla::HloInstruction* instr);

  xla::StatusOr<mhlo::ScatterDimensionNumbers> GetScatterDimensionNumbers(
      const xla::HloInstruction* instr);

 private:
  xla::StatusOr<lmhlo::SortOp> EmitSortOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::FusionOp> EmitFusionOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::SelectAndScatterOp> EmitSelectAndScatterOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<Operation*> EmitCustomCallOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo_gpu::CholeskyOp> EmitCholesky(
      const xla::HloCustomCallInstruction* custom_call);
  xla::StatusOr<Operation*> EmitGemm(
      const xla::HloCustomCallInstruction* custom_call);
  xla::StatusOr<Operation*> EmitDnnConvolution(
      const xla::HloCustomCallInstruction* custom_call);
  xla::StatusOr<Operation*> EmitDnnBatchNorm(
      const xla::HloCustomCallInstruction* custom_call);

  xla::StatusOr<lmhlo::ReduceOp> EmitReduceOp(const xla::HloInstruction* instr);
  xla::StatusOr<GetGlobalMemrefOp> EmitConstant(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::CompareOp> EmitCompareOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::InfeedOp> EmitInfeedOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::OutfeedOp> EmitOutfeedOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::MapOp> EmitMapOp(const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::ReducePrecisionOp> EmitReducePrecisionOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::AllToAllOp> EmitAllToAllOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::AllGatherOp> EmitAllGatherOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::CollectivePermuteOp> EmitCollectivePermuteOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::BroadcastInDimOp> EmitBroadcastOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::ConcatenateOp> EmitConcatenateOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::IotaOp> EmitIotaOp(const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::ReverseOp> EmitReverseOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::TransposeOp> EmitTransposeOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::PadOp> EmitPadOp(const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::ReduceWindowOp> EmitReduceWindowOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::SliceOp> EmitSliceOp(const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::GatherOp> EmitGatherOp(const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::DynamicSliceOp> EmitDynamicSliceOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::DotOp> EmitDotOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::RngGetAndUpdateStateOp> EmitRngGetAndUpdateStateOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::FftOp> EmitFftOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::TriangularSolveOp> EmitTriangularSolveOp(
      const xla::HloInstruction* instr);

  // Create LHLO operation operands given an XLA HLO instruction. By default,
  // all XLA HLO operands and results are converted to MLIR and appended to
  // `operands`. If `num_operands` is specified, only the first `num_operands`
  // operands of the instruction are converted to MLIR. The function returns
  // the actual number of operands and results generated for MLIR in
  // `num_arguments` and `num_results`.
  xla::Status CreateOperands(const xla::HloInstruction* instr,
                             absl::optional<xla::int64> num_operands,
                             SmallVectorImpl<Value>& operands,
                             size_t& num_arguments, size_t& num_results);

  template <typename OpType>
  xla::StatusOr<OpType> CreateOpWithoutAttrs(
      const xla::HloInstruction* instr,
      absl::optional<xla::int64> num_operands = absl::nullopt) {
    size_t unused;
    return CreateOpWithoutAttrs<OpType>(instr, unused, unused, num_operands);
  }

  template <typename OpType>
  xla::StatusOr<OpType> CreateOpWithoutAttrs(
      const xla::HloInstruction* instr, size_t& num_arguments,
      size_t& num_results,
      absl::optional<xla::int64> num_operands = absl::nullopt);

  template <typename OpType>
  OpType CreateOpWithoutAttrs(const xla::HloInstruction* instr,
                              ValueRange operands);

  template <typename T>
  DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) {
    return builder_.getI64TensorAttr(
        {container.data(), static_cast<size_t>(container.size())});
  }

  DenseIntElementsAttr GetWindowElements(
      const xla::Window& window,
      std::function<int64_t(const xla::WindowDimension& dim)> getter) {
    llvm::SmallVector<int64_t, 4> elements;
    elements.reserve(window.dimensions_size());
    for (const xla::WindowDimension& dim : window.dimensions()) {
      elements.push_back(getter(dim));
    }
    return GetI64DenseElementsAttr(elements);
  }

  static mlir::DenseIntElementsAttr GetLayoutAttribute(
      const xla::Layout& layout, Builder* builder);

  tensorflow::Status DefaultAction(const xla::HloInstruction* instr) final;

  // Computation parameters don't need any specific handling when they are
  // visited; they are already processed when we enter a new computation.
  tensorflow::Status HandleParameter(const xla::HloInstruction* instr) final {
    return tensorflow::Status::OK();
  }

  // Helper function that recursively visits the tuple structure in
  // `current_shape` and reconstructs a matching lmhlo::TupleOp.
  // Each leaf node is converted to an std.view op with corresponding offsets.
  // If the shape is not a tuple, it simply returns a view of the buffer.
  tensorflow::Status GetOrCreateViewImpl(const xla::HloInstruction* instr,
                                         const xla::Shape& current_shape,
                                         xla::ShapeIndex* current_shape_index,
                                         SmallVectorImpl<Value>* values);

  // Helper function to create a view (or tuple of views) to a buffer for a
  // given instruction result. `result_subset` can be used for instructions
  // that have a tuple result when the MLIR conversion needs to convert only
  // one of the tuple elements. Note that, if needed, this can be extended to
  // take a list of ShapeIndex values in case we need finer control over which
  // elements of the output tuple are converted to MLIR.
  tensorflow::Status GetOrCreateView(const xla::HloInstruction* instr,
                                     SmallVectorImpl<Value>* values,
                                     const xla::ShapeIndex& result_subset = {});

  xla::StatusOr<Value> GetOrCreateArrayView(
      const xla::HloInstruction* instr, const xla::Shape& current_shape,
      const xla::ShapeIndex& current_shape_index);

  xla::StatusOr<Value> RewriteFusionOperand(const xla::HloInstruction* root,
                                            const xla::Shape& shape,
                                            xla::ShapeIndex* shape_index,
                                            OpBuilder* b, Location loc);

  // Return an MLIR location for an HLO instruction.
  Location getLocation(const xla::HloInstruction* inst) {
    return NameLoc::get(builder_.getIdentifier(inst->name()),
                        builder_.getContext());
  }

  // This map provides access to MLIR buffers for each HLO buffer allocation.
  // The MLIR buffers are all `memref<{size}xi8>` and correspond to function
  // parameters. It is populated at the beginning of the processing with all
  // the buffer allocations and is unchanged afterward. Every HloInstruction
  // uses a "slice" of a buffer allocation, which provides the shape, layout,
  // and dtype. An MLIR view is used separately to model slices into the
  // allocations (see below).
  llvm::DenseMap<const xla::BufferAllocation*, Value> allocations_;

  // This map provides access to MLIR buffers for each HLO instruction, keyed
  // by instruction identity. A slice is contained in a BufferAllocation, and
  // has an offset and a size.
  //
  // As for why we don't use HloInstruction*, see GetOrCreateView(), but
  // mostly we want to better leverage the aliased buffers.
  //
  // If the HloInstruction is a tuple, all leaf nodes are stored flattened.
  // Otherwise, there will be a single buffer.
  //
  // An MLIR buffer is either an input parameter, or a ViewOp in the case
  // where the slice is only part of its allocation.
  //
  // `slices_` is populated lazily in the `GetOrCreateView()` helper as we
  // process every instruction.
  absl::flat_hash_map<std::pair<const xla::HloInstruction*, xla::ShapeIndex>,
                      Value>
      slices_;

  // The BufferAssignment computed by XLA ahead of time.
  const xla::BufferAssignment& assignment_;

  // The HLO computation that will be converted.
  const xla::HloComputation& computation_;

  // This is the MLIR module in which a function will be created for every HLO
  // computation.
  ModuleOp module_;

  // The builder keeps track of the current insertion point in the MLIR
  // module.
  OpBuilder builder_;
  // Convenient "cached" access to this widely used MLIR type (i8).
  Type i8_type_;
};

// Populates the MLIR `module` with the computation from the `hlo_module`
// using the provided buffer `assignment`. The returned `Status` indicates
// success or failure of the conversion.
tensorflow::Status HloToLhloModule(const xla::BufferAssignment& assignment,
                                   const xla::HloModule& hlo_module,
                                   ModuleOp module);
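
// Example usage of HloToLhloModule (an illustrative sketch; it assumes the
// dialects used by the emitted ops are already registered on `context`, and
// that `assignment` and `hlo_module` come from an already-compiled XLA
// module):
//
//   mlir::MLIRContext context;
//   mlir::OwningModuleRef module =
//       mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
//   TF_RETURN_IF_ERROR(HloToLhloModule(assignment, hlo_module, *module));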

OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input,
                                               MLIRContext* context);

}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_