1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_ 17 #define TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_ 18 19 #include <unordered_map> 20 21 #include "absl/types/optional.h" 22 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project 23 #include "mlir/IR/Attributes.h" // from @llvm-project 24 #include "mlir/IR/Builders.h" // from @llvm-project 25 #include "mlir/IR/BuiltinOps.h" // from @llvm-project 26 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project 27 #include "mlir/IR/MLIRContext.h" // from @llvm-project 28 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" 29 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" 30 #include "tensorflow/compiler/xla/comparison_util.h" 31 #include "tensorflow/compiler/xla/status.h" 32 #include "tensorflow/compiler/xla/statusor.h" 33 #include "tensorflow/compiler/xla/xla_data.pb.h" 34 #include "tensorflow/core/platform/types.h" 35 36 namespace xla { 37 38 class HloModule; 39 class HloComputation; 40 class HloInstruction; 41 class Shape; 42 43 // Helper class for importing HloComputations. 44 class HloFunctionImporter { 45 public: 46 // Imports the given computation as a function in the given module. This also 47 // imports any computations referred by instructions in this computation. 48 static Status ImportAsFunc(const xla::HloComputation& computation, 49 mlir::ModuleOp module, 50 std::unordered_map<const xla::HloComputation*, 51 mlir::FuncOp>* function_map, 52 mlir::Builder* builder); 53 54 // Imports the given hlo computation to the specified region. 55 static Status ImportAsRegion(const xla::HloComputation& computation, 56 mlir::Region* region, mlir::Builder* builder); 57 58 // Imports the given computation to the given place specified by `builder`. 59 // `arguments` contains values for all parameters. 60 static StatusOr<mlir::Value> ImportInstructions( 61 const xla::HloComputation& computation, 62 const llvm::SmallVectorImpl<mlir::Value>& arguments, 63 mlir::OpBuilder* builder); 64 65 static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape); 66 67 // TODO(b/179166199): move this to attribute_importer.h. 68 // Converts XLA instruction source target pairs to MLIR attribute. 69 static mlir::NamedAttribute ConvertSourceTargetPairs( 70 const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>& 71 source_target_pairs, 72 mlir::Builder* builder); 73 74 // TODO(b/179166199): move this to attribute_importer.h. 75 // Converts replica groups to attribute 76 static mlir::NamedAttribute ConvertReplicaGroups( 77 const std::vector<ReplicaGroup>& replica_groups, mlir::Builder* builder); 78 79 private: HloFunctionImporter(mlir::ModuleOp module,std::unordered_map<const xla::HloComputation *,mlir::FuncOp> * function_map,mlir::Builder * builder)80 HloFunctionImporter(mlir::ModuleOp module, 81 std::unordered_map<const xla::HloComputation*, 82 mlir::FuncOp>* function_map, 83 mlir::Builder* builder) 84 : context_(module.getContext()), 85 module_(module), 86 builder_(builder), 87 function_map_(function_map) { 88 context_->loadDialect<mlir::StandardOpsDialect>(); 89 context_->loadDialect<mlir::mhlo::MhloDialect>(); 90 } 91 92 // Imports the given computation as a new function, if it hasn't been already 93 // imported. 94 StatusOr<mlir::FuncOp> ImportAsFunc(const xla::HloComputation& computation); 95 96 // Imports the given computation in the specified region. 97 tensorflow::Status ImportAsRegion(const HloComputation& computation, 98 mlir::Region* region); 99 100 // Imports instructions from the given computation in the specified block. 101 // Assumes that the block already has correct arguments populated. 102 tensorflow::Status ImportInstructions(const HloComputation& computation, 103 mlir::Block* block); 104 StatusOr<mlir::Value> ImportInstructionsImpl( 105 const xla::HloComputation& computation, 106 const llvm::SmallVectorImpl<mlir::Value>& arguments, 107 mlir::OpBuilder* builder); 108 109 // Imports an instruction. 110 StatusOr<mlir::Operation*> ImportInstruction(xla::HloInstruction* instruction, 111 mlir::OpBuilder* func_builder); 112 StatusOr<mlir::Operation*> ImportInstructionImpl( 113 HloInstruction* instruction, mlir::OpBuilder* func_builder); 114 115 // Gets the MLIR operand values from an HLO Instruction. 116 StatusOr<llvm::SmallVector<mlir::Value, 4>> GetOperands( 117 xla::HloInstruction* instruction); 118 119 // Converts xla Tensor type to the corresponding MLIR type. 120 StatusOr<mlir::RankedTensorType> ConvertTensorType(const xla::Shape& shape); 121 122 // Returns the output type of an HloInstruction. 123 StatusOr<mlir::Type> GetReturnType(xla::HloInstruction* instruction); 124 125 // Takes a list of HloInstructions and generates the list of types used for 126 // input, bypassing tuples to subsets. 127 Status GetMlirTypes(const std::vector<xla::HloInstruction*>& instructions, 128 llvm::SmallVectorImpl<mlir::Type>* types); 129 130 // Returns the Mlir Value for the corresponding HloInstruction. 131 StatusOr<mlir::Value> GetMlirValue(xla::HloInstruction* instruction); 132 133 // Converts an XLA ComparisonDirection to the corresponding MLIR attribute. 134 mlir::NamedAttribute ConvertComparisonDirection( 135 ComparisonDirection direction); 136 137 // Converts an XLA Comparison::Type to the corresponding MLIR attribute. 138 mlir::NamedAttribute ConvertComparisonType(Comparison::Type type); 139 140 // Converts the dimensions of an HLO instruction into an MLIR attribute. 141 mlir::DenseIntElementsAttr ConvertDimensions( 142 llvm::ArrayRef<tensorflow::int64> op_dimensions); 143 144 // Converts Array ref to an DenseIntElementsAttr. 145 mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> elements); 146 147 // Converts Array ref to padding attribute. Input is a flattened list of 148 // padding low and padding high for each of the spatial dimensions. 149 mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding); 150 151 // Converts channel id to attribute 152 mlir::NamedAttribute ConvertChannelHandle( 153 absl::optional<tensorflow::int64> channel_id); 154 155 // Converts channel handle to attribute 156 mlir::NamedAttribute ConvertChannelHandle(const xla::ChannelHandle& channel); 157 158 mlir::MLIRContext* context_; 159 mlir::ModuleOp module_; 160 mlir::Builder* builder_; 161 162 // Mapping from HloComputation to the created MLIR function. 163 std::unordered_map<const xla::HloComputation*, mlir::FuncOp>* function_map_; 164 165 // Mapping from HloInstructions to the associative MLIR values. 166 std::unordered_map<xla::HloInstruction*, mlir::Value> instruction_value_map_; 167 }; 168 169 } // namespace xla 170 171 #endif // TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_ 172