1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_ 16 #define TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_ 17 18 #include <memory> 19 20 #include "absl/container/flat_hash_map.h" 21 #include "llvm/ADT/ArrayRef.h" 22 #include "llvm/ADT/StringRef.h" 23 #include "mlir/IR/Builders.h" // from @llvm-project 24 #include "mlir/IR/BuiltinOps.h" // from @llvm-project 25 #include "mlir/IR/Location.h" // from @llvm-project 26 #include "mlir/IR/Operation.h" // from @llvm-project 27 #include "mlir/IR/Value.h" // from @llvm-project 28 #include "tensorflow/compiler/xla/client/xla_builder.h" 29 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 30 #include "tensorflow/compiler/xla/shape.h" 31 #include "tensorflow/compiler/xla/types.h" 32 #include "tensorflow/core/platform/types.h" 33 #include "tensorflow/stream_executor/lib/statusor.h" 34 35 namespace xla { 36 37 // Provides a way to construct mhlo dialect ops in MLIR using XlaBuilder 38 // interface. 39 // 40 // Requires that all XlaOp arguments are either returned by any of the builder 41 // method or constructed using MakeXlaOp method in this builder. 42 // 43 // TODO(hinsu): Support more ops and utility functions to set special attributes 44 // like OpMetadata and Sharding. 45 class MlirHloBuilder : public XlaBuilder { 46 public: 47 // Constructs builder for the given function. New operations are added to the 48 // beginning of the function, if it is non empty and has a block. MlirHloBuilder(mlir::FuncOp func)49 explicit MlirHloBuilder(mlir::FuncOp func) 50 : XlaBuilder(func.getName().str()), 51 builder_(&func.getBody()), 52 loc_(builder_.getUnknownLoc()) {} 53 54 // TODO(hinsu): Add a constructor to build a new MLIR function from scratch 55 // and override Build methods. 56 MlirHloBuilder(std::string name,mlir::OpBuilder builder,mlir::Location loc)57 MlirHloBuilder(std::string name, mlir::OpBuilder builder, mlir::Location loc) 58 : XlaBuilder(name), builder_(builder), loc_(loc) {} 59 60 MlirHloBuilder(const MlirHloBuilder&) = delete; 61 MlirHloBuilder& operator=(const MlirHloBuilder&) = delete; 62 63 ~MlirHloBuilder() override; 64 65 // Wraps the given MLIR value under an XlaOp instance. Note that all HLO 66 // operations returns exactly one result therefore each op has an XlaOp 67 // wrapping result of the op. 68 // 69 // Returns an error if the HLO dialect doesn't support type of the given 70 // value. 71 StatusOr<XlaOp> MakeXlaOp(mlir::Value val); 72 73 // Returns value corresponding to the given op. 74 // 75 // Requires that the op was created by this builder. GetValue(XlaOp op)76 mlir::Value GetValue(XlaOp op) { 77 void* ptr = reinterpret_cast<void*>(op.handle()); 78 return mlir::Value::getFromOpaquePointer(ptr); 79 } 80 81 // Returns MLIR values corresponding to the given XLA ops. 82 // 83 // Requires that the ops were created by this builder. GetValues(absl::Span<const XlaOp> ops)84 std::vector<mlir::Value> GetValues(absl::Span<const XlaOp> ops) { 85 std::vector<mlir::Value> values; 86 for (auto xla_op : ops) { 87 values.push_back(GetValue(xla_op)); 88 } 89 return values; 90 } 91 92 // Sets location for newly built ops, until reset. SetLocation(mlir::Location loc)93 void SetLocation(mlir::Location loc) { loc_ = loc; } 94 95 // Update insertion point so that newly built ops are inserted before the 96 // given op in order, until reset. setInsertionPoint(mlir::Operation * op)97 void setInsertionPoint(mlir::Operation* op) { 98 builder_.setInsertionPoint(op); 99 } 100 101 // Returns the shape of the given op. 102 StatusOr<const Shape*> GetShapePtr(XlaOp op) const override; 103 104 // Creates the given op at the current location. 105 template <typename OpTy, typename... Args> create(Args &&...args)106 OpTy create(Args&&... args) { 107 return builder_.create<OpTy>(loc_, std::forward<Args>(args)...); 108 } 109 110 private: 111 XlaOp ConstantLiteral(const LiteralSlice& literal) override; 112 113 StatusOr<XlaOp> ConvGeneralDilatedInternal( 114 const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, 115 absl::Span<const int64> window_strides, 116 absl::Span<const std::pair<int64, int64>> padding, 117 absl::Span<const int64> lhs_dilation, 118 absl::Span<const int64> rhs_dilation, 119 const ConvolutionDimensionNumbers& dimension_numbers, 120 int64 feature_group_count, int64 batch_group_count, 121 const PrecisionConfig* precision_config) override; 122 123 StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand, 124 FftType fft_type, 125 absl::Span<const int64> fft_length) override; 126 127 StatusOr<XlaOp> TriangularSolveInternal( 128 const Shape& shape, XlaOp a, XlaOp b, 129 TriangularSolveOptions options) override; 130 131 StatusOr<XlaOp> CholeskyInternal(const Shape& shape, XlaOp a, 132 bool lower) override; 133 134 StatusOr<XlaOp> CustomCallInternal( 135 const string& call_target_name, absl::Span<const XlaOp> operands, 136 const Shape& shape, const string& opaque, 137 absl::optional<absl::Span<const Shape>> operand_shapes_with_layout, 138 bool has_side_effect, 139 absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>> 140 output_operand_aliasing, 141 const Literal* literal) override; 142 143 StatusOr<XlaOp> ReduceInternal( 144 const Shape& shape, absl::Span<const XlaOp> all_operands, 145 const XlaComputation& computation, 146 absl::Span<const int64> dimensions_to_reduce) override; 147 148 StatusOr<XlaOp> ReduceWindowInternal(const Shape& shape, XlaOp operand, 149 XlaOp init_value, 150 const XlaComputation& computation, 151 Window window) override; 152 153 XlaOp Iota(const Shape& shape, int64 iota_dimension) override; 154 155 StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape, 156 XlaOp operand) override; 157 158 StatusOr<XlaOp> TransposeInternal( 159 const Shape& shape, XlaOp operand, 160 absl::Span<const int64> permutation) override; 161 162 StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand, 163 absl::Span<const int64> dimensions) override; 164 165 StatusOr<XlaOp> SortInternal(const Shape& shape, 166 absl::Span<const XlaOp> operands, 167 const XlaComputation& comparator, 168 int64 dimension, bool is_stable) override; 169 170 StatusOr<XlaOp> WhileInternal(const Shape& shape, 171 const XlaComputation& condition, 172 const XlaComputation& body, 173 XlaOp init) override; 174 175 StatusOr<XlaOp> ReducePrecisionInternal(const Shape& shape, XlaOp operand, 176 const int exponent_bits, 177 const int mantissa_bits) override; 178 179 StatusOr<XlaOp> GatherInternal( 180 const Shape& shape, XlaOp input, XlaOp start_indices, 181 const GatherDimensionNumbers& dimension_numbers, 182 absl::Span<const int64> slice_sizes, bool indices_are_sorted) override; 183 184 StatusOr<XlaOp> ScatterInternal( 185 const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates, 186 const XlaComputation& update_computation, 187 const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted, 188 bool unique_indices) override; 189 190 StatusOr<XlaOp> SetDimensionSizeInternal(const Shape& shape, XlaOp operand, 191 XlaOp val, int64 dimension) override; 192 193 StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution, 194 absl::Span<const XlaOp> parameters, 195 const Shape& shape) override; 196 StatusOr<XlaOp> RngBitGeneratorInternal(const Shape& full_result_shape, 197 RandomAlgorithm algorithm, 198 XlaOp initial_state) override; 199 200 StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand, 201 int64 inferred_dimension) override; 202 203 StatusOr<XlaOp> DotGeneralInternal( 204 const Shape& shape, XlaOp lhs, XlaOp rhs, 205 const DotDimensionNumbers& dimension_number, 206 const PrecisionConfig* precision_config) override; 207 208 StatusOr<XlaOp> InDimBroadcast( 209 const Shape& shape, XlaOp operand, 210 absl::Span<const int64> broadcast_dimensions) override; 211 212 StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, 213 absl::Span<const XlaOp> operands) override; 214 215 StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, 216 ComparisonDirection direction, 217 Comparison::Type type) override; 218 219 XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, XlaOp lhs, 220 XlaOp rhs) override; 221 222 StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape, 223 absl::Span<const XlaOp> operands) override; 224 225 XlaOp CreateToken() override; 226 227 StatusOr<XlaOp> InfeedWithTokenInternal(const Shape& infeed_instruction_shape, 228 XlaOp token, 229 const string& config) override; 230 StatusOr<XlaOp> OutfeedWithTokenInternal( 231 XlaOp operand, XlaOp token, const Shape& shape_with_layout, 232 const string& outfeed_config) override; 233 234 StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape, 235 absl::Span<const XlaOp> operands, 236 int64 dimension) override; 237 238 StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape, XlaOp tuple_data, 239 int64 index) override; 240 241 StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand, 242 absl::Span<const int64> start_indices, 243 absl::Span<const int64> limit_indices, 244 absl::Span<const int64> strides) override; 245 246 StatusOr<XlaOp> DynamicSliceInternal( 247 const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices, 248 absl::Span<const int64> slice_sizes) override; 249 250 StatusOr<XlaOp> DynamicUpdateSliceInternal( 251 const Shape& shape, XlaOp operand, XlaOp update, 252 absl::Span<const XlaOp> start_indices) override; 253 254 StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand, 255 XlaOp padding_value, 256 const PaddingConfig& padding_config) override; 257 258 StatusOr<XlaOp> TupleInternal(const Shape& shape, 259 absl::Span<const XlaOp> elements) override; 260 261 // Creates HLO dialect op and returns the result as an XlaOp. 262 StatusOr<XlaOp> CreateOp( 263 const std::string& op_name, const Shape& shape, 264 llvm::ArrayRef<XlaOp> operands, 265 llvm::ArrayRef<mlir::NamedAttribute> attributes = {}); 266 267 Status ImportComputation(const HloModuleProto& computation, 268 mlir::Region* region); 269 270 mlir::OpBuilder builder_; 271 mlir::Location loc_; 272 273 absl::flat_hash_map<int64, std::unique_ptr<Shape>> handle_to_shape_; 274 }; 275 276 } // namespace xla 277 278 #endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_ 279