1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_ 17 #define TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_ 18 19 #include <unordered_map> 20 21 #include "absl/types/optional.h" 22 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project 23 #include "mlir/IR/Attributes.h" // from @llvm-project 24 #include "mlir/IR/Builders.h" // from @llvm-project 25 #include "mlir/IR/BuiltinOps.h" // from @llvm-project 26 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project 27 #include "mlir/IR/MLIRContext.h" // from @llvm-project 28 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" 29 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" 30 #include "tensorflow/compiler/xla/comparison_util.h" 31 #include "tensorflow/compiler/xla/status.h" 32 #include "tensorflow/compiler/xla/statusor.h" 33 #include "tensorflow/compiler/xla/xla_data.pb.h" 34 #include "tensorflow/core/platform/types.h" 35 36 namespace xla { 37 38 class HloModule; 39 class HloComputation; 40 class HloInstruction; 41 class Shape; 42 43 // Helper class for importing HloComputations. 44 class HloFunctionImporter { 45 public: 46 // Imports the given computation as a function in the given module. This also 47 // imports any computations referred by instructions in this computation. 48 static Status ImportAsFunc(const xla::HloComputation& computation, 49 mlir::ModuleOp module, 50 std::unordered_map<const xla::HloComputation*, 51 mlir::FuncOp>* function_map, 52 mlir::Builder* builder); 53 54 // Imports the given hlo computation to the specified region. 55 static Status ImportAsRegion(const xla::HloComputation& computation, 56 mlir::Region* region, mlir::Builder* builder); 57 58 // Imports the given computation to the given place specified by `builder`. 59 // `arguments` contains values for all parameters. 60 static StatusOr<mlir::Value> ImportInstructions( 61 const xla::HloComputation& computation, 62 const llvm::SmallVectorImpl<mlir::Value>& arguments, 63 mlir::OpBuilder* builder); 64 65 static StatusOr<mlir::Operation*> ImportInstruction( 66 const xla::HloInstruction* instr, 67 const llvm::SmallVectorImpl<mlir::Value>& operands, 68 mlir::OpBuilder* builder); 69 70 static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape, 71 llvm::StringRef attr_name = "minor_to_major"); 72 73 // TODO(b/179166199): move this to attribute_importer.h. 74 // Converts XLA instruction source target pairs to MLIR attribute. 75 static mlir::NamedAttribute ConvertSourceTargetPairs( 76 const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>& 77 source_target_pairs, 78 mlir::Builder* builder); 79 80 // TODO(b/179166199): move this to attribute_importer.h. 81 // Converts replica groups to attribute 82 static mlir::NamedAttribute ConvertReplicaGroups( 83 absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder); 84 85 private: HloFunctionImporter(mlir::ModuleOp module,std::unordered_map<const xla::HloComputation *,mlir::FuncOp> * function_map,mlir::Builder * builder)86 HloFunctionImporter(mlir::ModuleOp module, 87 std::unordered_map<const xla::HloComputation*, 88 mlir::FuncOp>* function_map, 89 mlir::Builder* builder) 90 : context_(module.getContext()), 91 module_(module), 92 builder_(builder), 93 function_map_(function_map) { 94 context_->loadDialect<mlir::StandardOpsDialect>(); 95 context_->loadDialect<mlir::mhlo::MhloDialect>(); 96 } 97 98 // Imports the given computation as a new function, if it hasn't been already 99 // imported. 100 StatusOr<mlir::FuncOp> ImportAsFunc(const xla::HloComputation& computation); 101 102 // Imports the given computation in the specified region. 103 tensorflow::Status ImportAsRegion(const HloComputation& computation, 104 mlir::Region* region); 105 106 // Imports instructions from the given computation in the specified block. 107 // Assumes that the block already has correct arguments populated. 108 tensorflow::Status ImportInstructions(const HloComputation& computation, 109 mlir::Block* block); 110 StatusOr<mlir::Value> ImportInstructionsImpl( 111 const xla::HloComputation& computation, 112 const llvm::SmallVectorImpl<mlir::Value>& arguments, 113 mlir::OpBuilder* builder); 114 115 // Imports an instruction. 116 StatusOr<mlir::Operation*> ImportInstructionWithLayout( 117 const xla::HloInstruction* instruction, 118 const llvm::SmallVectorImpl<mlir::Value>& operands, 119 mlir::OpBuilder* func_builder); 120 StatusOr<mlir::Operation*> ImportInstructionImpl( 121 const HloInstruction* instruction, 122 const llvm::SmallVectorImpl<mlir::Value>& operands, 123 mlir::OpBuilder* func_builder); 124 125 // Gets the MLIR operand values from an HLO Instruction. 126 StatusOr<llvm::SmallVector<mlir::Value, 4>> GetOperands( 127 const xla::HloInstruction* instruction); 128 129 // Converts xla Tensor type to the corresponding MLIR type. 130 StatusOr<mlir::RankedTensorType> ConvertTensorType(const xla::Shape& shape); 131 132 // Converts an XLA shape/layout to the corresponding MLIR layout 133 StatusOr<mlir::Attribute> ConvertShapeToMlirLayout(const xla::Shape& shape); 134 135 // Returns the output type of an HloInstruction. 136 StatusOr<mlir::Type> GetReturnType(const xla::HloInstruction* instruction); 137 138 // Takes a list of HloInstructions and generates the list of types used for 139 // input, bypassing tuples to subsets. 140 Status GetMlirTypes(const std::vector<xla::HloInstruction*>& instructions, 141 llvm::SmallVectorImpl<mlir::Type>* types); 142 143 // Returns the Mlir Value for the corresponding HloInstruction. 144 StatusOr<mlir::Value> GetMlirValue(const xla::HloInstruction* instruction); 145 146 // Converts an XLA ComparisonDirection to the corresponding MLIR attribute. 147 mlir::NamedAttribute ConvertComparisonDirection( 148 ComparisonDirection direction); 149 150 // Converts an XLA Comparison::Type to the corresponding MLIR attribute. 151 mlir::NamedAttribute ConvertComparisonType(Comparison::Type type); 152 153 // Converts the dimensions of an HLO instruction into an MLIR attribute. 154 mlir::DenseIntElementsAttr ConvertDimensions( 155 llvm::ArrayRef<tensorflow::int64> op_dimensions); 156 157 // Converts Array ref to an DenseIntElementsAttr. 158 mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> elements); 159 160 // Converts Array ref to padding attribute. Input is a flattened list of 161 // padding low and padding high for each of the spatial dimensions. 162 mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding); 163 164 // Converts channel id to attribute 165 mlir::NamedAttribute ConvertChannelHandle( 166 absl::optional<tensorflow::int64> channel_id); 167 168 // Converts channel handle to attribute 169 mlir::NamedAttribute ConvertChannelHandle(const xla::ChannelHandle& channel); 170 171 mlir::MLIRContext* context_; 172 mlir::ModuleOp module_; 173 mlir::Builder* builder_; 174 175 // Mapping from HloComputation to the created MLIR function. 176 std::unordered_map<const xla::HloComputation*, mlir::FuncOp>* function_map_; 177 178 // Mapping from HloInstructions to the associative MLIR values. 179 std::unordered_map<const xla::HloInstruction*, mlir::Value> 180 instruction_value_map_; 181 }; 182 183 } // namespace xla 184 185 #endif // TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_ 186