• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_MLIR_RUNTIME_CUSTOM_CALL_ENCODING_H_
17 #define XLA_MLIR_RUNTIME_CUSTOM_CALL_ENCODING_H_
18 
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <tuple>
23 #include <type_traits>
24 #include <utility>
25 #include <vector>
26 
27 #include "llvm/ADT/StringRef.h"
28 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
29 #include "mlir/IR/Attributes.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
31 #include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
32 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
33 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
34 #include "tensorflow/compiler/xla/runtime/custom_call.h"
35 #include "tensorflow/compiler/xla/runtime/type_id.h"
36 
37 namespace xla {
38 namespace runtime {
39 
40 //===----------------------------------------------------------------------===//
41 // Helper classes to build XLA custom calls' lowering to the LLVM dialect.
42 //===----------------------------------------------------------------------===//
43 //
44 // Arguments to the custom call API intrinsic are encoded as an array of opaque
45 // pointers and at the runtime side available as `void**`. Runtime decodes
46 // opaque pointers to the C++ data structures (see runtime/custom_call.h), and
47 // passes them to the registered callback. Argument encoding/decoding must be
48 // compatible, otherwise it's very easy to get a segfault because of an illegal
49 // memory access.
50 //
51 // Attributes are encoded into a separate opaque storage together with names,
52 // so the runtime side can decode the attributes it needs and check that all
53 // required attributes were passed to the custom call handler.
54 //
55 // Custom call attributes are encoded as module global constants, and at run
56 // time we only need to pass a pointer to the constant section.
57 //
58 // Custom call arguments are encoded as an array of pointers allocated on the
59 // stack. Each individual argument is also encoded on the stack, because
60 // arguments are run time values and we can't encode them in the constant
61 // section.
62 
63 // Forward declare class declared below.
64 class Globals;
65 
66 //===----------------------------------------------------------------------===//
67 // Custom call arguments encoding.
68 //===----------------------------------------------------------------------===//
69 
70 // Encodes argument into stack allocated storage according to the ABI. If
71 // argument is a constant, then it can be packed as a global constant.
72 class CustomCallArgEncoding {
73  public:
74   struct Encoded {
75     mlir::Value type_id;  // !llvm.ptr<i64>
76     mlir::Value value;    // !llvm.ptr<ArgType>
77   };
78 
79   virtual ~CustomCallArgEncoding() = default;
80 
81   virtual mlir::LogicalResult Match(mlir::Value value,
82                                     mlir::Value conterted) const = 0;
83 
84   virtual mlir::FailureOr<Encoded> Encode(Globals &g,
85                                           mlir::ImplicitLocOpBuilder &b,
86                                           mlir::Value value,
87                                           mlir::Value converted) const = 0;
88 };
89 
90 // A set of registered custom call arguments encodings.
91 class CustomCallArgEncodingSet {
92  public:
93   using Encoded = CustomCallArgEncoding::Encoded;
94 
95   // Finds matching argument encoding and tries to encode the values. Returns
96   // failure if didn't match values to any of the argument encodings.
97   mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b,
98                                   mlir::Value value,
99                                   mlir::Value converted) const;
100 
101   template <typename... Ts, typename = std::enable_if_t<sizeof...(Ts) != 0>>
Add()102   CustomCallArgEncodingSet &Add() {
103     (encodings_.emplace_back(std::make_unique<Ts>()), ...);
104     return *this;
105   }
106 
107  private:
108   std::vector<std::unique_ptr<CustomCallArgEncoding>> encodings_;
109 };
110 
111 //===----------------------------------------------------------------------===//
112 // Custom call attributes encoding.
113 //===----------------------------------------------------------------------===//
114 
115 // Attributes encoding packs attribute name, data type and a value into the
116 // module global constant, and returns values pointing to the encoded data.
117 struct CustomCallAttrEncoding {
118   static constexpr char kAttrName[] = "__rt_attr_name";
119   static constexpr char kAttrValue[] = "__rt_attr_value";
120 
121   struct Encoded {
122     mlir::Value name;     // !llvm.ptr<i8>
123     mlir::Value type_id;  // !llvm.ptr<i64>
124     mlir::Value value;    // !llvm.ptr<EncodedAttrType>
125   };
126 
127   virtual ~CustomCallAttrEncoding() = default;
128 
129   virtual mlir::LogicalResult Match(llvm::StringRef name,
130                                     mlir::Attribute attr) const = 0;
131 
132   virtual mlir::FailureOr<Encoded> Encode(Globals &g,
133                                           mlir::ImplicitLocOpBuilder &b,
134                                           llvm::StringRef name,
135                                           mlir::Attribute attr) const = 0;
136 };
137 
138 // A set of registered custom call attributes encodings.
139 class CustomCallAttrEncodingSet {
140  public:
141   using Encoded = CustomCallAttrEncoding::Encoded;
142 
143   // Finds matching attribute encoding and tries to encode the attribute.
144   // Returns failure if didn't match attribute to any of the encodings.
145   mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b,
146                                   llvm::StringRef name,
147                                   mlir::Attribute attr) const;
148 
149   template <typename... Ts, typename = std::enable_if_t<sizeof...(Ts) != 0>>
Add()150   CustomCallAttrEncodingSet &Add() {
151     (encodings_.emplace_back(std::make_unique<Ts>()), ...);
152     return *this;
153   }
154 
155   template <typename... Ts, typename ConstructorArg,
156             typename... ConstructorArgs,
157             typename = std::enable_if_t<sizeof...(Ts) != 0>>
Add(ConstructorArg && arg,ConstructorArgs &&...args)158   CustomCallAttrEncodingSet &Add(ConstructorArg &&arg,
159                                  ConstructorArgs &&...args) {
160     (encodings_.emplace_back(std::make_unique<Ts>(arg, args...)), ...);
161     return *this;
162   }
163 
164  private:
165   std::vector<std::unique_ptr<CustomCallAttrEncoding>> encodings_;
166 };
167 
168 //===----------------------------------------------------------------------===//
169 // A set of helper functions for packing primitive attributes.
170 //===----------------------------------------------------------------------===//
171 
172 // Packs TypeID as `i64` constant value and casts it to the `!llvm.ptr<i8>`,
173 // because type id internally is implemented as an opaque pointer.
174 mlir::Value PackTypeId(Globals &g, mlir::ImplicitLocOpBuilder &b,
175                        mlir::TypeID type_id);
176 
177 // Packs string as a module global null-terminated string constant. We reuse
178 // the encoding scheme for arrays to store sting with its size, to avoid
179 // computing the length of the null-terminated string at run tine.
180 //
181 // Returns `!llvm.ptr<EncodedArray<char>>`.
182 mlir::Value PackString(Globals &g, mlir::ImplicitLocOpBuilder &b,
183                        llvm::StringRef strref, llvm::StringRef symbol_base);
184 
185 // Packs scalar attribute as a global constant. Returns `!llvm.ptr<AttrType>`.
186 mlir::Value PackScalarAttribute(Globals &g, mlir::ImplicitLocOpBuilder &b,
187                                 mlir::Attribute value,
188                                 mlir::StringRef symbol_base);
189 
190 //===----------------------------------------------------------------------===//
191 // A helper class to create global constants in the module.
192 //===----------------------------------------------------------------------===//
193 
194 class Globals {
195  public:
196   // Global value initializer that build the initialization region.
197   using GlobalInitializer =
198       std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Attribute)>;
199 
200   // Global value initializer that can return failure if it can't initialize the
201   // global value from the given attribute.
202   using FailureOrGlobalInitializer = std::function<mlir::LogicalResult(
203       mlir::ImplicitLocOpBuilder &, mlir::Attribute)>;
204 
Globals(mlir::ModuleOp module,TypeIDNameRegistry type_id_names)205   Globals(mlir::ModuleOp module, TypeIDNameRegistry type_id_names)
206       : module_(module),
207         sym_table_(module_),
208         type_id_names_(std::move(type_id_names)) {}
209 
210   // Creates a global external variable for the type id.
211   mlir::LLVM::GlobalOp GetOrCreate(mlir::ImplicitLocOpBuilder &b,
212                                    mlir::TypeID type_id);
213 
214   // Creates a global null-terminated string constant.
215   mlir::LLVM::GlobalOp GetOrCreate(mlir::ImplicitLocOpBuilder &b,
216                                    llvm::StringRef strref,
217                                    llvm::StringRef symbol_base);
218 
219   // Creates a global constant value from the attribute. Attribute type must be
220   // a valid type compatible with LLVM globals.
221   mlir::LLVM::GlobalOp GetOrCreate(mlir::ImplicitLocOpBuilder &b,
222                                    mlir::TypedAttr attr,
223                                    llvm::StringRef symbol_base);
224 
225   // Creates a global constant value of the given type from the attribute, using
226   // optional user-provided global constant initialization.
227   mlir::LLVM::GlobalOp GetOrCreate(
228       mlir::ImplicitLocOpBuilder &b, mlir::Attribute attr, mlir::Type type,
229       llvm::StringRef symbol_base, GlobalInitializer initialize = {},
230       mlir::LLVM::Linkage linkage = mlir::LLVM::Linkage::Internal);
231 
232   // Creates a global constant value of the given type from the attribute, using
233   // optional user-provided global constant initialization. Returns failure if
234   // user-provided initialization failed to initialize the global value.
235   mlir::FailureOr<mlir::LLVM::GlobalOp> TryGetOrCreate(
236       mlir::ImplicitLocOpBuilder &b, mlir::Attribute attr, mlir::Type type,
237       llvm::StringRef symbol_base, FailureOrGlobalInitializer initialize = {},
238       mlir::LLVM::Linkage linkage = mlir::LLVM::Linkage::Internal);
239 
240   // Returns the address of the global value.
241   static mlir::Value AddrOf(mlir::ImplicitLocOpBuilder &b,
242                             mlir::LLVM::GlobalOp global);
243 
244   // Return the address of the global value casted to `!llvm.ptr<i8>`.
245   static mlir::Value OpaqueAddrOf(mlir::ImplicitLocOpBuilder &b,
246                                   mlir::LLVM::GlobalOp global);
247 
module()248   mlir::ModuleOp module() { return module_; }
249 
250  private:
251   // Globals key: {attribute, encoded-type, sym-name}. We can only have global
252   // constants of one of the LLVM types, and there could be multiple ways to
253   // encode an attribute as an LLVM type, e.g. strings can be stored as null
254   // terminated array of bytes, or a pair of string size and and array of bytes.
255   using Key = std::tuple<mlir::Attribute, mlir::Type, mlir::StringAttr>;
256 
257   mlir::LLVM::GlobalOp Find(Key key);
258 
259   mlir::ModuleOp module_;
260   mlir::SymbolTable sym_table_;  // symbol table for the `module_`
261   llvm::DenseMap<Key, mlir::LLVM::GlobalOp> globals_;
262 
263   // A mapping from the TypeID to the unique type name for encoding external
264   // globals corresponding to types ids.
265   TypeIDNameRegistry type_id_names_;
266 };
267 
268 //===----------------------------------------------------------------------===//
269 // Custom call attributes encoding.
270 //===----------------------------------------------------------------------===//
271 
272 // Encodes attribute using a scheme compatible with run time attributes decoding
273 // (see `internal::DecodedAttrs` in the custom call header file).
274 //
275 // Returns a value of `!llvm.ptr<ptr<i8>>` (void**) type pointing to the encoded
276 // attributes array (array of pointers).
277 //
278 // This function is used to encode:
279 //
280 //   1. Struct attributes as aggregates of nested attributes, where the order of
281 //      attributes matches the order defined with the `AggregateAttrDef` schema
282 //      defined below.
283 //
284 //   2. Custom call attributes, where the attributes sorted lexicographically by
285 //      name, to be able to efficiently decode named attributes.
286 //
287 mlir::FailureOr<mlir::Value> EncodeAttributes(
288     Globals &g, mlir::ImplicitLocOpBuilder &b,
289     const CustomCallAttrEncodingSet &encoding, llvm::StringRef symbol_base,
290     llvm::ArrayRef<mlir::NamedAttribute> attrs);
291 
292 struct StringAttrEncoding : public CustomCallAttrEncoding {
293   mlir::LogicalResult Match(llvm::StringRef, mlir::Attribute) const final;
294   mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b,
295                                   mlir::StringRef, mlir::Attribute) const final;
296 };
297 
298 struct ScalarAttrEncoding : public CustomCallAttrEncoding {
299   mlir::LogicalResult Match(llvm::StringRef, mlir::Attribute) const final;
300   mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b,
301                                   mlir::StringRef, mlir::Attribute) const final;
302 };
303 
304 struct DenseElementsAttrEncoding : public CustomCallAttrEncoding {
305   mlir::LogicalResult Match(llvm::StringRef, mlir::Attribute) const final;
306   mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b,
307                                   mlir::StringRef, mlir::Attribute) const final;
308 };
309 
310 struct ArrayAttrEncoding : public CustomCallAttrEncoding {
311   mlir::LogicalResult Match(llvm::StringRef, mlir::Attribute) const final;
312   mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b,
313                                   mlir::StringRef, mlir::Attribute) const final;
314 };
315 
316 struct DenseArrayAttrEncoding : public CustomCallAttrEncoding {
317   mlir::LogicalResult Match(llvm::StringRef, mlir::Attribute) const final;
318   mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b,
319                                   mlir::StringRef, mlir::Attribute) const final;
320 };
321 
322 struct EmptyArrayAttrEncoding : public CustomCallAttrEncoding {
323   mlir::LogicalResult Match(llvm::StringRef, mlir::Attribute) const final;
324   mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b,
325                                   mlir::StringRef, mlir::Attribute) const final;
326 };
327 
328 // Custom call attribute encoding that encodes enums using their underlying
329 // scalar type. Type id is based on the enum type passed to the runtime.
330 //
331 // This encoding can convert enum types defined in the compiler (e.g. dialect
332 // enums defined in MLIR) to the enum types used at run time.
333 template <typename AttrType, typename EnumType,
334           typename RuntimeEnumType = EnumType>
335 struct EnumAttrEncoding : public CustomCallAttrEncoding {
336   static_assert(std::is_enum<RuntimeEnumType>::value, "must be an enum class");
337 
338   // Convert from the compile time enum to the run time enum.
339   using Converter = std::function<RuntimeEnumType(EnumType)>;
340 
EnumAttrEncodingEnumAttrEncoding341   EnumAttrEncoding() {
342     static_assert(std::is_same<EnumType, RuntimeEnumType>::value,
343                   "requires enum converter");
344     convert = [](EnumType value) { return value; };
345   }
346 
EnumAttrEncodingEnumAttrEncoding347   explicit EnumAttrEncoding(Converter convert) : convert(std::move(convert)) {}
348 
MatchEnumAttrEncoding349   mlir::LogicalResult Match(llvm::StringRef, mlir::Attribute attr) const final {
350     return mlir::success(attr.isa<AttrType>());
351   }
352 
EncodeEnumAttrEncoding353   mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b,
354                                   mlir::StringRef name,
355                                   mlir::Attribute attr) const final {
356     // Convert enum underlying integral value to an attribute.
357     EnumType compile_time_enum = attr.cast<AttrType>().getValue();
358     RuntimeEnumType run_time_enum = convert(compile_time_enum);
359 
360     using T = std::underlying_type_t<RuntimeEnumType>;
361     T underlying_value = static_cast<T>(run_time_enum);
362 
363     mlir::TypeID type_id = mlir::TypeID::get<Tagged<RuntimeEnumType>>();
364     mlir::Attribute underlying_attr = AsAttr(b, underlying_value);
365 
366     Encoded encoded;
367     encoded.name = PackString(g, b, name, kAttrName);
368     encoded.type_id = PackTypeId(g, b, type_id);
369     encoded.value = PackScalarAttribute(g, b, underlying_attr, kAttrValue);
370 
371     return encoded;
372   }
373 
AsAttrEnumAttrEncoding374   static mlir::Attribute AsAttr(mlir::ImplicitLocOpBuilder &b, uint32_t value) {
375     return b.getI32IntegerAttr(value);
376   }
377 
378   Converter convert;
379 };
380 
381 // A helper type to define `AttrType` encoding scheme.
382 template <typename AttrType>
383 struct AggregateAttrDef {
384   template <typename T>
385   using Extract = T (AttrType::*)() const;
386 
387   template <typename T, typename Attr = mlir::Attribute>
388   using Encode = Attr (mlir::Builder::*)(T);
389 
390   template <typename T, typename Attr = mlir::Attribute>
AddAggregateAttrDef391   AggregateAttrDef &Add(std::string name, Extract<T> extract,
392                         Encode<T, Attr> encode) {
393     bindings.emplace_back([=](AttrType attr, mlir::Builder &b) {
394       auto encoded = std::invoke(encode, b, std::invoke(extract, attr));
395       return mlir::NamedAttribute(b.getStringAttr(name), encoded);
396     });
397     return *this;
398   }
399 
AddAggregateAttrDef400   AggregateAttrDef &Add(std::string name, Extract<bool> extract) {
401     return Add(name, extract, &mlir::Builder::getBoolAttr);
402   }
403 
AddAggregateAttrDef404   AggregateAttrDef &Add(std::string name, Extract<int64_t> extract) {
405     return Add(name, extract, &mlir::Builder::getI64IntegerAttr);
406   }
407 
AddAggregateAttrDef408   AggregateAttrDef &Add(std::string name,
409                         Extract<llvm::ArrayRef<int64_t>> extract) {
410     return Add(name, extract, &mlir::Builder::getI64TensorAttr);
411   }
412 
413   // A list of functions to destruct `AttrType` attribute into the aggregate
414   // attributes that will be used for encoding.
415   using Bind = std::function<mlir::NamedAttribute(AttrType, mlir::Builder &)>;
416   llvm::SmallVector<Bind> bindings;
417 };
418 
419 // Custom call attribute encoding for the user-defined attributes which encodes
420 // them as an aggregate of primitive attributes. It uses the encoding scheme
421 // compatible with the custom call attributes decoding.
422 template <typename AttrType, typename RuntimeType = AttrType>
423 struct AggregateAttrEncoding : public CustomCallAttrEncoding {
424   using AttrDef = AggregateAttrDef<AttrType>;
425 
AggregateAttrEncodingAggregateAttrEncoding426   AggregateAttrEncoding(const CustomCallAttrEncodingSet &encoding,
427                         AttrDef attrdef)
428       : encoding(encoding), attrdef(std::move(attrdef)) {}
429 
MatchAggregateAttrEncoding430   mlir::LogicalResult Match(llvm::StringRef, mlir::Attribute attr) const final {
431     return mlir::success(attr.isa<AttrType>());
432   }
433 
EncodeAggregateAttrEncoding434   mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b,
435                                   mlir::StringRef name,
436                                   mlir::Attribute attr) const final {
437     // Extract aggregate attributes from the user-defined attributes.
438     llvm::SmallVector<mlir::NamedAttribute> attrs;
439     for (auto &bind : attrdef.bindings)
440       attrs.emplace_back(bind(attr.cast<AttrType>(), b));
441 
442     // Encode extracted attributes as an aggregate.
443     auto type_id = mlir::TypeID::get<Tagged<RuntimeType>>();
444     auto sym = "__rt_aggregate_" + AttrType::getMnemonic();
445     auto aggregate = EncodeAttributes(g, b, encoding, sym.str(), attrs);
446     if (mlir::failed(aggregate)) return mlir::failure();
447 
448     Encoded encoded;
449     encoded.name = PackString(g, b, name, kAttrName);
450     encoded.type_id = PackTypeId(g, b, type_id);
451     encoded.value = *aggregate;
452     return encoded;
453   }
454 
455   const CustomCallAttrEncodingSet &encoding;
456   AttrDef attrdef;
457 };
458 
459 //===----------------------------------------------------------------------===//
460 // Custom call arguments encoding.
461 //===----------------------------------------------------------------------===//
462 
463 // Encodes scalar operands.
464 class ScalarArgEncoding : public CustomCallArgEncoding {
465  public:
466   mlir::LogicalResult Match(mlir::Value, mlir::Value) const final;
467   mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b,
468                                   mlir::Value, mlir::Value) const final;
469 };
470 
471 // Encodes MemRef operands according to the (Strided)MemrefView ABI.
472 class MemrefArgEncoding : public CustomCallArgEncoding {
473  public:
474   mlir::LogicalResult Match(mlir::Value, mlir::Value) const final;
475   mlir::FailureOr<Encoded> Encode(Globals &g, mlir::ImplicitLocOpBuilder &b,
476                                   mlir::Value, mlir::Value) const final;
477 
478  private:
479   // Encodes memref as LLVM struct value:
480   //
481   //   { i8: dtype, i8: rank, ptr<i8>: data,
482   //     array<2*rank x i64>: sizes_and_strides }
483   //
484   // This is a type erased version of the MLIR memref descriptor without base
485   // pointer. We pack sizes and strides as a single array member, so that on
486   // the runtime side we can read it back using C flexible array member.
487   mlir::Value EncodeMemRef(mlir::ImplicitLocOpBuilder &b,
488                            mlir::MemRefType memref_ty,
489                            mlir::Value descriptor) const;
490 };
491 
492 //===----------------------------------------------------------------------===//
493 // Default encodings for arguments and attributes.
494 //===----------------------------------------------------------------------===//
495 
496 CustomCallArgEncodingSet DefaultArgEncodings();
497 CustomCallAttrEncodingSet DefaultAttrEncodings();
498 
499 }  // namespace runtime
500 }  // namespace xla
501 
502 #endif  // XLA_MLIR_RUNTIME_CUSTOM_CALL_ENCODING_H_
503