1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef XLA_MLIR_RUNTIME_CUSTOM_CALL_ENCODING_H_ 17 #define XLA_MLIR_RUNTIME_CUSTOM_CALL_ENCODING_H_ 18 19 #include <functional> 20 #include <memory> 21 #include <string> 22 #include <tuple> 23 #include <type_traits> 24 #include <utility> 25 #include <vector> 26 27 #include "llvm/ADT/StringRef.h" 28 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project 29 #include "mlir/IR/Attributes.h" // from @llvm-project 30 #include "mlir/IR/BuiltinOps.h" // from @llvm-project 31 #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project 32 #include "mlir/IR/SymbolTable.h" // from @llvm-project 33 #include "mlir/Support/LogicalResult.h" // from @llvm-project 34 #include "tensorflow/compiler/xla/runtime/custom_call.h" 35 #include "tensorflow/compiler/xla/runtime/type_id.h" 36 37 namespace xla { 38 namespace runtime { 39 40 //===----------------------------------------------------------------------===// 41 // Helper classes to build XLA custom calls' lowering to the LLVM dialect. 42 //===----------------------------------------------------------------------===// 43 // 44 // Arguments to the custom call API intrinsic are encoded as an array of opaque 45 // pointers and at the runtime side available as `void**`. Runtime decodes 46 // opaque pointers to the C++ data structures (see runtime/custom_call.h), and 47 // passes them to the registered callback. Argument encoding/decoding must be 48 // compatible, otherwise it's very easy to get a segfault because of an illegal 49 // memory access. 50 // 51 // Attributes are encoded into a separate opaque storage together with names, 52 // so the runtime side can decode the attributes it needs and check that all 53 // required attributes were passed to the custom call handler. 54 // 55 // Custom call attributes are encoded as module global constants, and at run 56 // time we only need to pass a pointer to the constant section. 57 // 58 // Custom call arguments are encoded as an array of pointers allocated on the 59 // stack. Each individual argument is also encoded on the stack, because 60 // arguments are run time values and we can't encode them in the constant 61 // section. 62 63 // Forward declare class declared below. 64 class Globals; 65 66 //===----------------------------------------------------------------------===// 67 // Custom call arguments encoding. 68 //===----------------------------------------------------------------------===// 69 70 // Encodes argument into stack allocated storage according to the ABI. If 71 // argument is a constant, then it can be packed as a global constant. 72 class CustomCallArgEncoding { 73 public: 74 struct Encoded { 75 mlir::Value type_id; // !llvm.ptr<i64> 76 mlir::Value value; // !llvm.ptr<ArgType> 77 }; 78 79 virtual ~CustomCallArgEncoding() = default; 80 81 virtual mlir::LogicalResult Match(mlir::Value value, 82 mlir::Value conterted) const = 0; 83 84 virtual mlir::FailureOr<Encoded> Encode(Globals &g, 85 mlir::ImplicitLocOpBuilder &b, 86 mlir::Value value, 87 mlir::Value converted) const = 0; 88 }; 89 90 // A set of registered custom call arguments encodings. 91 class CustomCallArgEncodingSet { 92 public: 93 using Encoded = CustomCallArgEncoding::Encoded; 94 95 // Finds matching argument encoding and tries to encode the values. Returns 96 // failure if didn't match values to any of the argument encodings. 97 mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, 98 mlir::Value value, 99 mlir::Value converted) const; 100 101 template <typename... Ts, typename = std::enable_if_t<sizeof...(Ts) != 0>> Add()102 CustomCallArgEncodingSet &Add() { 103 (encodings_.emplace_back(std::make_unique<Ts>()), ...); 104 return *this; 105 } 106 107 private: 108 std::vector<std::unique_ptr<CustomCallArgEncoding>> encodings_; 109 }; 110 111 //===----------------------------------------------------------------------===// 112 // Custom call attributes encoding. 113 //===----------------------------------------------------------------------===// 114 115 // Attributes encoding packs attribute name, data type and a value into the 116 // module global constant, and returns values pointing to the encoded data. 117 struct CustomCallAttrEncoding { 118 static constexpr char kAttrName[] = "__rt_attr_name"; 119 static constexpr char kAttrValue[] = "__rt_attr_value"; 120 121 struct Encoded { 122 mlir::Value name; // !llvm.ptr<i8> 123 mlir::Value type_id; // !llvm.ptr<i64> 124 mlir::Value value; // !llvm.ptr<EncodedAttrType> 125 }; 126 127 virtual ~CustomCallAttrEncoding() = default; 128 129 virtual mlir::LogicalResult Match(llvm::StringRef name, 130 mlir::Attribute attr) const = 0; 131 132 virtual mlir::FailureOr<Encoded> Encode(Globals &g, 133 mlir::ImplicitLocOpBuilder &b, 134 llvm::StringRef name, 135 mlir::Attribute attr) const = 0; 136 }; 137 138 // A set of registered custom call attributes encodings. 139 class CustomCallAttrEncodingSet { 140 public: 141 using Encoded = CustomCallAttrEncoding::Encoded; 142 143 // Finds matching attribute encoding and tries to encode the attribute. 144 // Returns failure if didn't match attribute to any of the encodings. 145 mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, 146 llvm::StringRef name, 147 mlir::Attribute attr) const; 148 149 template <typename... Ts, typename = std::enable_if_t<sizeof...(Ts) != 0>> Add()150 CustomCallAttrEncodingSet &Add() { 151 (encodings_.emplace_back(std::make_unique<Ts>()), ...); 152 return *this; 153 } 154 155 template <typename... Ts, typename ConstructorArg, 156 typename... ConstructorArgs, 157 typename = std::enable_if_t<sizeof...(Ts) != 0>> Add(ConstructorArg && arg,ConstructorArgs &&...args)158 CustomCallAttrEncodingSet &Add(ConstructorArg &&arg, 159 ConstructorArgs &&...args) { 160 (encodings_.emplace_back(std::make_unique<Ts>(arg, args...)), ...); 161 return *this; 162 } 163 164 private: 165 std::vector<std::unique_ptr<CustomCallAttrEncoding>> encodings_; 166 }; 167 168 //===----------------------------------------------------------------------===// 169 // A set of helper functions for packing primitive attributes. 170 //===----------------------------------------------------------------------===// 171 172 // Packs TypeID as `i64` constant value and casts it to the `!llvm.ptr<i8>`, 173 // because type id internally is implemented as an opaque pointer. 174 mlir::Value PackTypeId(Globals &g, mlir::ImplicitLocOpBuilder &b, 175 mlir::TypeID type_id); 176 177 // Packs string as a module global null-terminated string constant. We reuse 178 // the encoding scheme for arrays to store sting with its size, to avoid 179 // computing the length of the null-terminated string at run tine. 180 // 181 // Returns `!llvm.ptr<EncodedArray<char>>`. 182 mlir::Value PackString(Globals &g, mlir::ImplicitLocOpBuilder &b, 183 llvm::StringRef strref, llvm::StringRef symbol_base); 184 185 // Packs scalar attribute as a global constant. Returns `!llvm.ptr<AttrType>`. 186 mlir::Value PackScalarAttribute(Globals &g, mlir::ImplicitLocOpBuilder &b, 187 mlir::Attribute value, 188 mlir::StringRef symbol_base); 189 190 //===----------------------------------------------------------------------===// 191 // A helper class to create global constants in the module. 192 //===----------------------------------------------------------------------===// 193 194 class Globals { 195 public: 196 // Global value initializer that build the initialization region. 197 using GlobalInitializer = 198 std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Attribute)>; 199 200 // Global value initializer that can return failure if it can't initialize the 201 // global value from the given attribute. 202 using FailureOrGlobalInitializer = std::function<mlir::LogicalResult( 203 mlir::ImplicitLocOpBuilder &, mlir::Attribute)>; 204 Globals(mlir::ModuleOp module,TypeIDNameRegistry type_id_names)205 Globals(mlir::ModuleOp module, TypeIDNameRegistry type_id_names) 206 : module_(module), 207 sym_table_(module_), 208 type_id_names_(std::move(type_id_names)) {} 209 210 // Creates a global external variable for the type id. 211 mlir::LLVM::GlobalOp GetOrCreate(mlir::ImplicitLocOpBuilder &b, 212 mlir::TypeID type_id); 213 214 // Creates a global null-terminated string constant. 215 mlir::LLVM::GlobalOp GetOrCreate(mlir::ImplicitLocOpBuilder &b, 216 llvm::StringRef strref, 217 llvm::StringRef symbol_base); 218 219 // Creates a global constant value from the attribute. Attribute type must be 220 // a valid type compatible with LLVM globals. 221 mlir::LLVM::GlobalOp GetOrCreate(mlir::ImplicitLocOpBuilder &b, 222 mlir::TypedAttr attr, 223 llvm::StringRef symbol_base); 224 225 // Creates a global constant value of the given type from the attribute, using 226 // optional user-provided global constant initialization. 227 mlir::LLVM::GlobalOp GetOrCreate( 228 mlir::ImplicitLocOpBuilder &b, mlir::Attribute attr, mlir::Type type, 229 llvm::StringRef symbol_base, GlobalInitializer initialize = {}, 230 mlir::LLVM::Linkage linkage = mlir::LLVM::Linkage::Internal); 231 232 // Creates a global constant value of the given type from the attribute, using 233 // optional user-provided global constant initialization. Returns failure if 234 // user-provided initialization failed to initialize the global value. 235 mlir::FailureOr<mlir::LLVM::GlobalOp> TryGetOrCreate( 236 mlir::ImplicitLocOpBuilder &b, mlir::Attribute attr, mlir::Type type, 237 llvm::StringRef symbol_base, FailureOrGlobalInitializer initialize = {}, 238 mlir::LLVM::Linkage linkage = mlir::LLVM::Linkage::Internal); 239 240 // Returns the address of the global value. 241 static mlir::Value AddrOf(mlir::ImplicitLocOpBuilder &b, 242 mlir::LLVM::GlobalOp global); 243 244 // Return the address of the global value casted to `!llvm.ptr<i8>`. 245 static mlir::Value OpaqueAddrOf(mlir::ImplicitLocOpBuilder &b, 246 mlir::LLVM::GlobalOp global); 247 module()248 mlir::ModuleOp module() { return module_; } 249 250 private: 251 // Globals key: {attribute, encoded-type, sym-name}. We can only have global 252 // constants of one of the LLVM types, and there could be multiple ways to 253 // encode an attribute as an LLVM type, e.g. strings can be stored as null 254 // terminated array of bytes, or a pair of string size and and array of bytes. 255 using Key = std::tuple<mlir::Attribute, mlir::Type, mlir::StringAttr>; 256 257 mlir::LLVM::GlobalOp Find(Key key); 258 259 mlir::ModuleOp module_; 260 mlir::SymbolTable sym_table_; // symbol table for the `module_` 261 llvm::DenseMap<Key, mlir::LLVM::GlobalOp> globals_; 262 263 // A mapping from the TypeID to the unique type name for encoding external 264 // globals corresponding to types ids. 265 TypeIDNameRegistry type_id_names_; 266 }; 267 268 //===----------------------------------------------------------------------===// 269 // Custom call attributes encoding. 270 //===----------------------------------------------------------------------===// 271 272 // Encodes attribute using a scheme compatible with run time attributes decoding 273 // (see `internal::DecodedAttrs` in the custom call header file). 274 // 275 // Returns a value of `!llvm.ptr<ptr<i8>>` (void**) type pointing to the encoded 276 // attributes array (array of pointers). 277 // 278 // This function is used to encode: 279 // 280 // 1. Struct attributes as aggregates of nested attributes, where the order of 281 // attributes matches the order defined with the `AggregateAttrDef` schema 282 // defined below. 283 // 284 // 2. Custom call attributes, where the attributes sorted lexicographically by 285 // name, to be able to efficiently decode named attributes. 286 // 287 mlir::FailureOr<mlir::Value> EncodeAttributes( 288 Globals &g, mlir::ImplicitLocOpBuilder &b, 289 const CustomCallAttrEncodingSet &encoding, llvm::StringRef symbol_base, 290 llvm::ArrayRef<mlir::NamedAttribute> attrs); 291 292 struct StringAttrEncoding : public CustomCallAttrEncoding { 293 mlir::LogicalResult Match(llvm::StringRef, mlir::Attribute) const final; 294 mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, 295 mlir::StringRef, mlir::Attribute) const final; 296 }; 297 298 struct ScalarAttrEncoding : public CustomCallAttrEncoding { 299 mlir::LogicalResult Match(llvm::StringRef, mlir::Attribute) const final; 300 mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, 301 mlir::StringRef, mlir::Attribute) const final; 302 }; 303 304 struct DenseElementsAttrEncoding : public CustomCallAttrEncoding { 305 mlir::LogicalResult Match(llvm::StringRef, mlir::Attribute) const final; 306 mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, 307 mlir::StringRef, mlir::Attribute) const final; 308 }; 309 310 struct ArrayAttrEncoding : public CustomCallAttrEncoding { 311 mlir::LogicalResult Match(llvm::StringRef, mlir::Attribute) const final; 312 mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, 313 mlir::StringRef, mlir::Attribute) const final; 314 }; 315 316 struct DenseArrayAttrEncoding : public CustomCallAttrEncoding { 317 mlir::LogicalResult Match(llvm::StringRef, mlir::Attribute) const final; 318 mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, 319 mlir::StringRef, mlir::Attribute) const final; 320 }; 321 322 struct EmptyArrayAttrEncoding : public CustomCallAttrEncoding { 323 mlir::LogicalResult Match(llvm::StringRef, mlir::Attribute) const final; 324 mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, 325 mlir::StringRef, mlir::Attribute) const final; 326 }; 327 328 // Custom call attribute encoding that encodes enums using their underlying 329 // scalar type. Type id is based on the enum type passed to the runtime. 330 // 331 // This encoding can convert enum types defined in the compiler (e.g. dialect 332 // enums defined in MLIR) to the enum types used at run time. 333 template <typename AttrType, typename EnumType, 334 typename RuntimeEnumType = EnumType> 335 struct EnumAttrEncoding : public CustomCallAttrEncoding { 336 static_assert(std::is_enum<RuntimeEnumType>::value, "must be an enum class"); 337 338 // Convert from the compile time enum to the run time enum. 339 using Converter = std::function<RuntimeEnumType(EnumType)>; 340 EnumAttrEncodingEnumAttrEncoding341 EnumAttrEncoding() { 342 static_assert(std::is_same<EnumType, RuntimeEnumType>::value, 343 "requires enum converter"); 344 convert = [](EnumType value) { return value; }; 345 } 346 EnumAttrEncodingEnumAttrEncoding347 explicit EnumAttrEncoding(Converter convert) : convert(std::move(convert)) {} 348 MatchEnumAttrEncoding349 mlir::LogicalResult Match(llvm::StringRef, mlir::Attribute attr) const final { 350 return mlir::success(attr.isa<AttrType>()); 351 } 352 EncodeEnumAttrEncoding353 mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, 354 mlir::StringRef name, 355 mlir::Attribute attr) const final { 356 // Convert enum underlying integral value to an attribute. 357 EnumType compile_time_enum = attr.cast<AttrType>().getValue(); 358 RuntimeEnumType run_time_enum = convert(compile_time_enum); 359 360 using T = std::underlying_type_t<RuntimeEnumType>; 361 T underlying_value = static_cast<T>(run_time_enum); 362 363 mlir::TypeID type_id = mlir::TypeID::get<Tagged<RuntimeEnumType>>(); 364 mlir::Attribute underlying_attr = AsAttr(b, underlying_value); 365 366 Encoded encoded; 367 encoded.name = PackString(g, b, name, kAttrName); 368 encoded.type_id = PackTypeId(g, b, type_id); 369 encoded.value = PackScalarAttribute(g, b, underlying_attr, kAttrValue); 370 371 return encoded; 372 } 373 AsAttrEnumAttrEncoding374 static mlir::Attribute AsAttr(mlir::ImplicitLocOpBuilder &b, uint32_t value) { 375 return b.getI32IntegerAttr(value); 376 } 377 378 Converter convert; 379 }; 380 381 // A helper type to define `AttrType` encoding scheme. 382 template <typename AttrType> 383 struct AggregateAttrDef { 384 template <typename T> 385 using Extract = T (AttrType::*)() const; 386 387 template <typename T, typename Attr = mlir::Attribute> 388 using Encode = Attr (mlir::Builder::*)(T); 389 390 template <typename T, typename Attr = mlir::Attribute> AddAggregateAttrDef391 AggregateAttrDef &Add(std::string name, Extract<T> extract, 392 Encode<T, Attr> encode) { 393 bindings.emplace_back([=](AttrType attr, mlir::Builder &b) { 394 auto encoded = std::invoke(encode, b, std::invoke(extract, attr)); 395 return mlir::NamedAttribute(b.getStringAttr(name), encoded); 396 }); 397 return *this; 398 } 399 AddAggregateAttrDef400 AggregateAttrDef &Add(std::string name, Extract<bool> extract) { 401 return Add(name, extract, &mlir::Builder::getBoolAttr); 402 } 403 AddAggregateAttrDef404 AggregateAttrDef &Add(std::string name, Extract<int64_t> extract) { 405 return Add(name, extract, &mlir::Builder::getI64IntegerAttr); 406 } 407 AddAggregateAttrDef408 AggregateAttrDef &Add(std::string name, 409 Extract<llvm::ArrayRef<int64_t>> extract) { 410 return Add(name, extract, &mlir::Builder::getI64TensorAttr); 411 } 412 413 // A list of functions to destruct `AttrType` attribute into the aggregate 414 // attributes that will be used for encoding. 415 using Bind = std::function<mlir::NamedAttribute(AttrType, mlir::Builder &)>; 416 llvm::SmallVector<Bind> bindings; 417 }; 418 419 // Custom call attribute encoding for the user-defined attributes which encodes 420 // them as an aggregate of primitive attributes. It uses the encoding scheme 421 // compatible with the custom call attributes decoding. 422 template <typename AttrType, typename RuntimeType = AttrType> 423 struct AggregateAttrEncoding : public CustomCallAttrEncoding { 424 using AttrDef = AggregateAttrDef<AttrType>; 425 AggregateAttrEncodingAggregateAttrEncoding426 AggregateAttrEncoding(const CustomCallAttrEncodingSet &encoding, 427 AttrDef attrdef) 428 : encoding(encoding), attrdef(std::move(attrdef)) {} 429 MatchAggregateAttrEncoding430 mlir::LogicalResult Match(llvm::StringRef, mlir::Attribute attr) const final { 431 return mlir::success(attr.isa<AttrType>()); 432 } 433 EncodeAggregateAttrEncoding434 mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, 435 mlir::StringRef name, 436 mlir::Attribute attr) const final { 437 // Extract aggregate attributes from the user-defined attributes. 438 llvm::SmallVector<mlir::NamedAttribute> attrs; 439 for (auto &bind : attrdef.bindings) 440 attrs.emplace_back(bind(attr.cast<AttrType>(), b)); 441 442 // Encode extracted attributes as an aggregate. 443 auto type_id = mlir::TypeID::get<Tagged<RuntimeType>>(); 444 auto sym = "__rt_aggregate_" + AttrType::getMnemonic(); 445 auto aggregate = EncodeAttributes(g, b, encoding, sym.str(), attrs); 446 if (mlir::failed(aggregate)) return mlir::failure(); 447 448 Encoded encoded; 449 encoded.name = PackString(g, b, name, kAttrName); 450 encoded.type_id = PackTypeId(g, b, type_id); 451 encoded.value = *aggregate; 452 return encoded; 453 } 454 455 const CustomCallAttrEncodingSet &encoding; 456 AttrDef attrdef; 457 }; 458 459 //===----------------------------------------------------------------------===// 460 // Custom call arguments encoding. 461 //===----------------------------------------------------------------------===// 462 463 // Encodes scalar operands. 464 class ScalarArgEncoding : public CustomCallArgEncoding { 465 public: 466 mlir::LogicalResult Match(mlir::Value, mlir::Value) const final; 467 mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, 468 mlir::Value, mlir::Value) const final; 469 }; 470 471 // Encodes MemRef operands according to the (Strided)MemrefView ABI. 472 class MemrefArgEncoding : public CustomCallArgEncoding { 473 public: 474 mlir::LogicalResult Match(mlir::Value, mlir::Value) const final; 475 mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, 476 mlir::Value, mlir::Value) const final; 477 478 private: 479 // Encodes memref as LLVM struct value: 480 // 481 // { i8: dtype, i8: rank, ptr<i8>: data, 482 // array<2*rank x i64>: sizes_and_strides } 483 // 484 // This is a type erased version of the MLIR memref descriptor without base 485 // pointer. We pack sizes and strides as a single array member, so that on 486 // the runtime side we can read it back using C flexible array member. 487 mlir::Value EncodeMemRef(mlir::ImplicitLocOpBuilder &b, 488 mlir::MemRefType memref_ty, 489 mlir::Value descriptor) const; 490 }; 491 492 //===----------------------------------------------------------------------===// 493 // Default encodings for arguments and attributes. 494 //===----------------------------------------------------------------------===// 495 496 CustomCallArgEncodingSet DefaultArgEncodings(); 497 CustomCallAttrEncodingSet DefaultAttrEncodings(); 498 499 } // namespace runtime 500 } // namespace xla 501 502 #endif // XLA_MLIR_RUNTIME_CUSTOM_CALL_ENCODING_H_ 503