1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_ 16 #define TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_ 17 18 #include <memory> 19 20 #include "absl/container/flat_hash_map.h" 21 #include "llvm/ADT/ArrayRef.h" 22 #include "llvm/ADT/StringRef.h" 23 #include "mlir/IR/Builders.h" // from @llvm-project 24 #include "mlir/IR/BuiltinOps.h" // from @llvm-project 25 #include "mlir/IR/Location.h" // from @llvm-project 26 #include "mlir/IR/Operation.h" // from @llvm-project 27 #include "mlir/IR/Value.h" // from @llvm-project 28 #include "tensorflow/compiler/xla/client/xla_builder.h" 29 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 30 #include "tensorflow/compiler/xla/shape.h" 31 #include "tensorflow/compiler/xla/types.h" 32 #include "tensorflow/core/platform/types.h" 33 #include "tensorflow/stream_executor/lib/statusor.h" 34 35 namespace xla { 36 37 // Provides a way to construct mhlo dialect ops in MLIR using XlaBuilder 38 // interface. 39 // 40 // Requires that all XlaOp arguments are either returned by any of the builder 41 // method or constructed using MakeXlaOp method in this builder. 42 // 43 // TODO(hinsu): Support more ops and utility functions to set special attributes 44 // like OpMetadata and Sharding. 45 class MlirHloBuilder : public XlaBuilder { 46 public: 47 // Constructs builder for the given function. New operations are added to the 48 // beginning of the function, if it is non empty and has a block. MlirHloBuilder(mlir::FuncOp func)49 explicit MlirHloBuilder(mlir::FuncOp func) 50 : XlaBuilder(func.getName().str()), 51 builder_(&func.getBody()), 52 loc_(builder_.getUnknownLoc()) {} 53 54 // TODO(hinsu): Add a constructor to build a new MLIR function from scratch 55 // and override Build methods. 56 MlirHloBuilder(std::string name,mlir::OpBuilder builder,mlir::Location loc)57 MlirHloBuilder(std::string name, mlir::OpBuilder builder, mlir::Location loc) 58 : XlaBuilder(name), builder_(builder), loc_(loc) {} 59 60 MlirHloBuilder(const MlirHloBuilder&) = delete; 61 MlirHloBuilder& operator=(const MlirHloBuilder&) = delete; 62 63 ~MlirHloBuilder() override; 64 65 // Wraps the given MLIR value under an XlaOp instance. Note that all HLO 66 // operations returns exactly one result therefore each op has an XlaOp 67 // wrapping result of the op. 68 // 69 // Returns an error if the HLO dialect doesn't support type of the given 70 // value. 71 StatusOr<XlaOp> MakeXlaOp(mlir::Value val); 72 73 // Returns value corresponding to the given op. 74 // 75 // Requires that the op was created by this builder. GetValue(XlaOp op)76 mlir::Value GetValue(XlaOp op) { 77 void* ptr = reinterpret_cast<void*>(op.handle()); 78 return mlir::Value::getFromOpaquePointer(ptr); 79 } 80 81 // Returns MLIR values corresponding to the given XLA ops. 82 // 83 // Requires that the ops were created by this builder. GetValues(absl::Span<const XlaOp> ops)84 std::vector<mlir::Value> GetValues(absl::Span<const XlaOp> ops) { 85 std::vector<mlir::Value> values; 86 for (auto xla_op : ops) { 87 values.push_back(GetValue(xla_op)); 88 } 89 return values; 90 } 91 92 // Sets location for newly built ops, until reset. SetLocation(mlir::Location loc)93 void SetLocation(mlir::Location loc) { loc_ = loc; } 94 95 // Update insertion point so that newly built ops are inserted before the 96 // given op in order, until reset. setInsertionPoint(mlir::Operation * op)97 void setInsertionPoint(mlir::Operation* op) { 98 builder_.setInsertionPoint(op); 99 } 100 101 // Returns the shape of the given op. 102 StatusOr<const Shape*> GetShapePtr(XlaOp op) const override; 103 104 // Creates the given op at the current location. 105 template <typename OpTy, typename... Args> create(Args &&...args)106 OpTy create(Args&&... args) { 107 return builder_.create<OpTy>(loc_, std::forward<Args>(args)...); 108 } 109 110 private: 111 XlaOp ConstantLiteral(const LiteralSlice& literal) override; 112 113 StatusOr<XlaOp> ConvGeneralDilatedInternal( 114 const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, 115 absl::Span<const int64> window_strides, 116 absl::Span<const std::pair<int64, int64>> padding, 117 absl::Span<const int64> lhs_dilation, 118 absl::Span<const int64> rhs_dilation, 119 const ConvolutionDimensionNumbers& dimension_numbers, 120 int64_t feature_group_count, int64_t batch_group_count, 121 const PrecisionConfig* precision_config) override; 122 123 StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand, 124 FftType fft_type, 125 absl::Span<const int64> fft_length) override; 126 127 StatusOr<XlaOp> TriangularSolveInternal( 128 const Shape& shape, XlaOp a, XlaOp b, 129 TriangularSolveOptions options) override; 130 131 StatusOr<XlaOp> CholeskyInternal(const Shape& shape, XlaOp a, 132 bool lower) override; 133 134 StatusOr<XlaOp> CustomCallInternal( 135 const string& call_target_name, absl::Span<const XlaOp> operands, 136 const Shape& shape, const string& opaque, 137 absl::optional<absl::Span<const Shape>> operand_shapes_with_layout, 138 bool has_side_effect, 139 absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>> 140 output_operand_aliasing, 141 const Literal* literal, absl::optional<Window> window, 142 absl::optional<ConvolutionDimensionNumbers> dnums, 143 CustomCallSchedule schedule, CustomCallApiVersion api_version) override; 144 145 StatusOr<XlaOp> ReduceInternal( 146 const Shape& shape, absl::Span<const XlaOp> all_operands, 147 const XlaComputation& computation, 148 absl::Span<const int64> dimensions_to_reduce) override; 149 150 StatusOr<XlaOp> ReduceWindowInternal(const Shape& shape, XlaOp operand, 151 XlaOp init_value, 152 const XlaComputation& computation, 153 Window window) override; 154 155 XlaOp Iota(const Shape& shape, int64_t iota_dimension) override; 156 157 StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape, 158 XlaOp operand) override; 159 160 StatusOr<XlaOp> TransposeInternal( 161 const Shape& shape, XlaOp operand, 162 absl::Span<const int64> permutation) override; 163 164 StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand, 165 absl::Span<const int64> dimensions) override; 166 167 StatusOr<XlaOp> SortInternal(const Shape& shape, 168 absl::Span<const XlaOp> operands, 169 const XlaComputation& comparator, 170 int64_t dimension, bool is_stable) override; 171 172 StatusOr<XlaOp> WhileInternal(const Shape& shape, 173 const XlaComputation& condition, 174 const XlaComputation& body, 175 XlaOp init) override; 176 177 StatusOr<XlaOp> ReducePrecisionInternal(const Shape& shape, XlaOp operand, 178 const int exponent_bits, 179 const int mantissa_bits) override; 180 181 StatusOr<XlaOp> GatherInternal( 182 const Shape& shape, XlaOp input, XlaOp start_indices, 183 const GatherDimensionNumbers& dimension_numbers, 184 absl::Span<const int64> slice_sizes, bool indices_are_sorted) override; 185 186 StatusOr<XlaOp> ScatterInternal( 187 const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates, 188 const XlaComputation& update_computation, 189 const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted, 190 bool unique_indices) override; 191 192 StatusOr<XlaOp> SetDimensionSizeInternal(const Shape& shape, XlaOp operand, 193 XlaOp val, 194 int64_t dimension) override; 195 196 StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution, 197 absl::Span<const XlaOp> parameters, 198 const Shape& shape) override; 199 StatusOr<XlaOp> RngBitGeneratorInternal(const Shape& full_result_shape, 200 RandomAlgorithm algorithm, 201 XlaOp initial_state) override; 202 203 StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand, 204 int64_t inferred_dimension) override; 205 206 StatusOr<XlaOp> DotGeneralInternal( 207 const Shape& shape, XlaOp lhs, XlaOp rhs, 208 const DotDimensionNumbers& dimension_number, 209 const PrecisionConfig* precision_config) override; 210 211 StatusOr<XlaOp> InDimBroadcast( 212 const Shape& shape, XlaOp operand, 213 absl::Span<const int64> broadcast_dimensions) override; 214 215 StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, 216 absl::Span<const XlaOp> operands) override; 217 218 StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, 219 ComparisonDirection direction, 220 Comparison::Type type) override; 221 222 XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, XlaOp lhs, 223 XlaOp rhs) override; 224 225 StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape, 226 absl::Span<const XlaOp> operands) override; 227 228 XlaOp CreateToken() override; 229 230 StatusOr<XlaOp> InfeedWithTokenInternal(const Shape& infeed_instruction_shape, 231 XlaOp token, 232 const string& config) override; 233 StatusOr<XlaOp> OutfeedWithTokenInternal( 234 XlaOp operand, XlaOp token, const Shape& shape_with_layout, 235 const string& outfeed_config) override; 236 237 StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape, 238 absl::Span<const XlaOp> operands, 239 int64_t dimension) override; 240 241 StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape, XlaOp tuple_data, 242 int64_t index) override; 243 244 StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand, 245 absl::Span<const int64> start_indices, 246 absl::Span<const int64> limit_indices, 247 absl::Span<const int64> strides) override; 248 249 StatusOr<XlaOp> DynamicSliceInternal( 250 const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices, 251 absl::Span<const int64> slice_sizes) override; 252 253 StatusOr<XlaOp> DynamicUpdateSliceInternal( 254 const Shape& shape, XlaOp operand, XlaOp update, 255 absl::Span<const XlaOp> start_indices) override; 256 257 StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand, 258 XlaOp padding_value, 259 const PaddingConfig& padding_config) override; 260 261 StatusOr<XlaOp> TupleInternal(const Shape& shape, 262 absl::Span<const XlaOp> elements) override; 263 264 // Creates HLO dialect op and returns the result as an XlaOp. 265 StatusOr<XlaOp> CreateOp( 266 const std::string& op_name, const Shape& shape, 267 llvm::ArrayRef<XlaOp> operands, 268 llvm::ArrayRef<mlir::NamedAttribute> attributes = {}); 269 270 Status ImportComputation(const HloModuleProto& computation, 271 mlir::Region* region); 272 273 mlir::OpBuilder builder_; 274 mlir::Location loc_; 275 276 absl::flat_hash_map<int64, std::unique_ptr<Shape>> handle_to_shape_; 277 }; 278 279 } // namespace xla 280 281 #endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_ 282