• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_RUNTIME_CUSTOM_CALL_H_
17 #define XLA_RUNTIME_CUSTOM_CALL_H_
18 
#include <any>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/base/dynamic_annotations.h"
#include "third_party/eigen3/Eigen/Core"
#include "llvm/ADT/Any.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Error.h"
#include "tensorflow/compiler/xla/runtime/diagnostics.h"
#include "tensorflow/compiler/xla/runtime/logical_result.h"
#include "tensorflow/compiler/xla/runtime/type_id.h"
#include "tfrt/dtype/dtype.h"  // from @tf_runtime
#include "tfrt/support/map_by_type.h"  // from @tf_runtime
46 
47 namespace xla {
48 namespace runtime {
49 
50 // Forward declare.
51 struct KernelContext;
52 
53 // Forward declare template defined below.
54 template <typename... Ts>
55 class CustomCallBinding;
56 
57 // Registers mappings from TypeIDs supported by the custom calls to their unique
58 // names in the given registry.
59 void PopulateCustomCallTypeIdNames(TypeIDNameRegistry& registry);
60 
61 class CustomCall {
62  public:
63   // Container for passing data between XLA user and the custom call handler.
64   using UserData = tfrt::PtrMapByType<CustomCall>;
65 
66   // A type for matching all remaining custom call arguments.
67   class RemainingArgs;
68 
69   // A type for passing an argument of different types at the same position,
70   // and the handler will do the decoding.
71   class VariantArg;
72   class VariantAttr;
73 
74   // A type for representing tensors with shapes.
75   template <typename T>
76   struct TensorRef {
77     llvm::ArrayRef<int64_t> shape;
78     llvm::ArrayRef<T> data;
79   };
80 
81   // Custom call handler can check arguments and attributes types and names
82   // at runtime, however this comes at extra cost and can be optionally
83   // disabled. If the version of the compiler that generated the XLA executable
84   // doesn't match the custom call handler, it can lead to undefined behavior.
85   enum class RuntimeChecks : uint8_t {
86     // Check arguments and attributes types, also check attribute names. It is
87     // safe to pass extra arguments to the custom call handler when name
88     // checking is enabled, because it will safely skip irrelevant attributes.
89     kDefault = 0,
90 
91     // Check only the types of the arguments and attributes. If an attribute
92     // with the same type but different name is passed to the custom call
93     // handler,
94     // it will happily proceed ignoring the name mismatch.
95     kTypes = 1,
96 
97     // Do not check the number of arguments and attributes and their types, and
98     // do not check that the user data was passed to the custom call. This is
99     // the most dangerous option, because it blindly reinterprets opaque memory
100     // passed to the handler, and can easily lead to segfaults if the data
101     // doesn't match the expected custom call signature.
102     kNone = 2
103   };
104 
105   // Allows to bind custom calls to handlers with optional arguments without
106   // spelling the full type.
107   //
108   // Example:
109   //
110   //   LogicalResult MyCustomCall(Optional<int32_t> version);
111   //
112   //   CustomCall::Bind("api").Value(CustomCall::None).To(MyCustomCall);
113   //
114   // Works around the fact that llvm::Optional can't store an instance of
115   // llvm::NoneType (llvm::Optional<llvm::NoneType> has ambiguous constructor).
116   struct NoneType {
117     template <typename T>
118     operator llvm::Optional<T>() const {  // NOLINT
119       return llvm::None;
120     }
121   };
122 
123   static constexpr NoneType None = {};  // NOLINT
124 
CheckNames(RuntimeChecks checks)125   static constexpr bool CheckNames(RuntimeChecks checks) {
126     return checks == RuntimeChecks::kDefault;
127   }
128 
CheckTypes(RuntimeChecks checks)129   static constexpr bool CheckTypes(RuntimeChecks checks) {
130     return checks != RuntimeChecks::kNone;
131   }
132 
CheckUserData(RuntimeChecks checks)133   static constexpr bool CheckUserData(RuntimeChecks checks) {
134     return checks != RuntimeChecks::kNone;
135   }
136 
137   template <typename T>
CheckType(RuntimeChecks checks,TypeID type_id)138   static bool CheckType(RuntimeChecks checks, TypeID type_id) {
139     return !CheckTypes(checks) || type_id == TypeID::get<T>();
140   }
141 
142   virtual ~CustomCall() = default;
143 
144   virtual llvm::StringRef name() const = 0;
145   virtual LogicalResult call(void** args, void** attrs,
146                              const UserData* user_data,
147                              const DiagnosticEngine* diagnostic) const = 0;
148 
149   static CustomCallBinding<> Bind(std::string callee);
150 };
151 
152 // Direct custom call is a custom call that can be linked directly with the
153 // compiled executable, and doesn't have to go through the custom call look up
154 // by name at run time (see CustomCallRegistry).
155 //
156 // Direct custom call is a preffered way of implemenenting custom calls with
157 // low run time overheads, as they will become just an indirect function calls
158 // once LLVM ORC links them with the executable.
159 //
160 // See `GetSymbolsBinding` to convert custom call library to symbols binding.
161 class DirectCustomCallLibrary {
162  public:
163   // Function type corresponding to the direct custom call (custom calls
164   // linked directly with the compiled executable).
165   using DirectCustomCall = bool (*)(KernelContext* kernel_context, void** args,
166                                     void** attrs);
167 
Insert(llvm::StringRef name,DirectCustomCall custom_call)168   void Insert(llvm::StringRef name, DirectCustomCall custom_call) {
169     lib_.try_emplace(name, custom_call);
170   }
171 
ForEach(std::function<void (llvm::StringRef,DirectCustomCall)> f)172   void ForEach(std::function<void(llvm::StringRef, DirectCustomCall)> f) const {
173     for (auto& kv : lib_) f(kv.first(), kv.second);
174   }
175 
176  private:
177   llvm::StringMap<DirectCustomCall> lib_;
178 };
179 
180 // Forward declare template defined below.
181 template <CustomCall::RuntimeChecks checks, typename Fn, typename... Ts>
182 class CustomCallHandler;
183 
184 namespace internal {
185 
// Tag type: marks a `CustomCallBinding` template argument that is tied to an
// attribute of type `T`.
template <typename T>
struct Attr {};

// Tag type: marks a `CustomCallBinding` template argument that is tied to
// user data of type `T`.
template <typename T>
struct UserData {};

// Tag type: marks a `CustomCallBinding` template argument that is tied to a
// constant value of type `T`.
template <typename T>
struct Value {};

// Type trait that is true for the `Attr`, `UserData` and `Value` tags and
// false for every other type.
template <typename T>
struct IsWrapped : std::false_type {};

template <typename T>
struct IsWrapped<Attr<T>> : std::true_type {};

template <typename T>
struct IsWrapped<UserData<T>> : std::true_type {};

template <typename T>
struct IsWrapped<Value<T>> : std::true_type {};
213 
// Checks if remaining arguments are in the parameter pack, i.e. if any of the
// `Ts` is exactly `CustomCall::RemainingArgs`. Use `::value` to get the
// boolean result.
template <typename... Ts>
using HasRemainingArgs =
    std::disjunction<std::is_same<CustomCall::RemainingArgs, Ts>...>;
218 
219 }  // namespace internal
220 
221 // Custom call binding describes the function signature of the expected custom
222 // call handler using its variadic template parameter.
223 //
224 //   Custom call binding:
225 //     CustomCallBinding<int32_t, MemrefView>
226 //
227 //   Function signature:
228 //     LogicalResult MyHandle(int32_t algo, MemrefView memref);
229 //
230 template <typename... Ts>
231 class CustomCallBinding {
232  public:
233   using RuntimeChecks = CustomCall::RuntimeChecks;
234 
235   template <typename T>
236   CustomCallBinding<Ts..., T> Arg() && {
237     return {std::move(*this)};
238   }
239 
240   CustomCallBinding<Ts..., CustomCall::RemainingArgs> RemainingArgs() && {
241     static_assert(!internal::HasRemainingArgs<Ts...>::value,
242                   "remaining arguments can be passed just once");
243     return {std::move(*this)};
244   }
245 
246   template <typename T>
247   CustomCallBinding<Ts..., internal::Attr<T>> Attr(std::string attr) && {
248     attrs_.push_back(std::move(attr));
249     return {std::move(*this)};
250   }
251 
252   template <typename T>
253   CustomCallBinding<Ts..., internal::UserData<T>> UserData() && {
254     static_assert(std::is_pointer<T>::value, "user data must be a pointer");
255     return {std::move(*this)};
256   }
257 
258   template <typename T>
259   CustomCallBinding<Ts..., internal::Value<T>> Value(T value) && {
260     values_.push_back(std::move(value));
261     return {std::move(*this)};
262   }
263 
264   template <RuntimeChecks checks = RuntimeChecks::kDefault, typename Fn>
265   std::unique_ptr<CustomCallHandler<checks, Fn, Ts...>> To(Fn fn) {
266     return std::unique_ptr<CustomCallHandler<checks, Fn, Ts...>>(
267         new CustomCallHandler<checks, Fn, Ts...>(
268             std::forward<Fn>(fn), std::move(callee_), std::move(attrs_),
269             std::move(values_)));
270   }
271 
272  private:
273   template <typename...>
274   friend class CustomCallBinding;
275   friend class CustomCall;
276 
277   explicit CustomCallBinding(std::string callee) : callee_(std::move(callee)) {
278     static_assert(sizeof...(Ts) == 0, "custom call arguments must be empty");
279   }
280 
281   template <typename... TTs>
282   CustomCallBinding(CustomCallBinding<TTs...>&& other)  // NOLINT
283       : callee_(std::move(other.callee_)),
284         attrs_(std::move(other.attrs_)),
285         values_(std::move(other.values_)) {}
286 
287   CustomCallBinding(CustomCallBinding&) = delete;
288 
289   std::string callee_;              // custom call target
290   std::vector<std::string> attrs_;  // names of bound attributes
291   std::vector<llvm::Any> values_;   // values bound to arguments
292 };
293 
294 inline CustomCallBinding<> CustomCall::Bind(std::string callee) {
295   return CustomCallBinding<>(std::move(callee));
296 }
297 
298 // Custom call arguments decoding must be defined by specializing this template.
299 //
300 // Example: decoding for the `MyType` arguments
301 //
302 //   template <CustomCall::RuntimeChecks checks>
303 //   struct CustomCallArgDecoding<MyType, checks> {
304 //    static FailureOr<MyType> Decode(TypeID type_id, void* value);
305 //   };
306 //
307 template <typename T, CustomCall::RuntimeChecks>
308 struct CustomCallArgDecoding;
309 
310 // Custom call attribute decoding must be defined by specializing this template.
311 //
312 // Example: decoding for the `MyType` attributes
313 //
314 //   template <CustomCall::RuntimeChecks checks>
315 //   struct CustomCallAttrDecoding<MyType, checks> {
316 //    static FailureOr<MyType> Decode(llvm::StringRef name,
317 //                                    TypeID type_id, void* value);
//   };
319 //
320 template <typename T, CustomCall::RuntimeChecks>
321 struct CustomCallAttrDecoding;
322 
// A type tag to declare MLIR TypeID specializations for types passed to the
// custom calls. We don't want to declare specializations for scalar types
// directly in this translation unit, so we rely on a tag to wrap them.
//
// The struct is intentionally empty: only the template argument `T` carries
// information.
//
// See explicit TypeID declarations at the end of this file.
template <typename T>
struct Tagged {};
330 
// A type tag to represent empty arrays of unknown element type. Carries no
// data.
struct EmptyArrayRef {};
333 
334 //===----------------------------------------------------------------------===//
335 // C structures corresponding to the `rt-to-llvm` pass LLVM structs encoding
336 // various types of arguments/attributes.
337 
338 namespace internal {
339 
// Encoded memref argument. Field order and types must stay in sync with the
// struct emitted by the `rt-to-llvm` pass.
struct EncodedMemref {
  uint8_t dtype;  // element type; decoders cast it to `tfrt::DType`
  uint8_t rank;   // number of dimensions
  void* data;     // opaque pointer to the memref data
  // Trailing variable-length storage; the first `rank` entries are read as the
  // sizes. NOTE(review): presumably followed by `rank` strides - confirm
  // against the encoder.
  int64_t dims[];
};
346 
// Encoded array of `T`: an element count followed by a pointer to the
// elements. Also used with `T = char` for encoded attribute names.
template <typename T>
struct EncodedArray {
  int64_t size;   // number of elements pointed to by `data`
  const T* data;  // pointer to the first element
};
352 
// Encoded dense elements: an `EncodedArray<T>` payload followed by the rank
// and a trailing variable-length shape array.
// NOTE(review): `shape` presumably holds `rank` entries - confirm against the
// encoder.
template <typename T>
struct EncodedDenseElements {
  struct EncodedArray<T> payload;
  int64_t rank;
  int64_t shape[];
};
359 
360 }  // namespace internal
361 
362 //===----------------------------------------------------------------------===//
363 // Helpers for decoding opaque arguments and attributes memory.
364 
365 namespace internal {
366 
// Decoded pair of an argument type and opaque value.
struct DecodedArg {
  TypeID type_id;  // identifies the encoded type of `value`
  void* value;     // opaque pointer to the encoded argument
};
372 
// Decoded triple of an attribute name, type and opaque value.
struct DecodedAttr {
  llvm::StringRef name;  // attribute name; a view, the characters are not owned
  TypeID type_id;        // identifies the encoded type of `value`
  void* value;           // opaque pointer to the encoded attribute
};
379 
380 // A convenience wrapper around opaque arguments memory.
381 class DecodedArgs {
382  public:
383   explicit DecodedArgs(void** args)
384       : args_(args), num_args_(*reinterpret_cast<int64_t*>(args_[0])) {}
385 
386   LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t size() const { return num_args_; }
387 
388   LLVM_ATTRIBUTE_ALWAYS_INLINE DecodedArg operator[](size_t i) const {
389     void** arg_base = args_ + 1 + i * 2;
390 
391     DecodedArg arg;
392     arg.type_id = TypeID::getFromOpaquePointer(arg_base[0]);
393     arg.value = arg_base[1];
394 
395     return arg;
396   }
397 
398  private:
399   void** args_;
400   int64_t num_args_;
401 };
402 
403 // A convenience wrapper around opaque attributes memory.
404 class DecodedAttrs {
405  public:
406   explicit DecodedAttrs(void** attrs)
407       : attrs_(attrs), num_attrs_(*reinterpret_cast<int64_t*>(attrs_[0])) {}
408 
409   LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t size() const { return num_attrs_; }
410 
411   LLVM_ATTRIBUTE_ALWAYS_INLINE DecodedAttr operator[](size_t i) const {
412     void** attr_base = attrs_ + 1 + i * 3;
413 
414     DecodedAttr attr;
415     auto* name = reinterpret_cast<internal::EncodedArray<char>*>(attr_base[0]);
416     attr.name = llvm::StringRef(name->data, name->size);
417     attr.type_id = TypeID::getFromOpaquePointer(attr_base[1]);
418     attr.value = attr_base[2];
419 
420     return attr;
421   }
422 
423  private:
424   void** attrs_;
425   int64_t num_attrs_;
426 };
427 
428 }  // namespace internal
429 
430 //===----------------------------------------------------------------------===//
431 // CustomCall remaining arguments wraps the type-erased `DecodedArg` container,
432 // and provides a type-safe API for accessing individual arguments.
433 
434 class CustomCall::RemainingArgs {
435  public:
436   using RuntimeChecks = CustomCall::RuntimeChecks;
437 
438   RemainingArgs(internal::DecodedArgs args, size_t offset)
439       : args_(args), offset_(offset) {
440     assert(offset <= args_.size() && "illegal remaining args offset");
441   }
442 
443   size_t size() const { return args_.size() - offset_; }
444   bool empty() const { return size() == 0; }
445 
446   template <typename T>
447   bool isa(size_t index) const {
448     return args_[index + offset_].type_id == TypeID::get<Tagged<T>>();
449   }
450 
451   template <typename T, RuntimeChecks checks = RuntimeChecks::kDefault>
452   FailureOr<T> get(size_t index) const {
453     internal::DecodedArg arg = args_[index + offset_];
454     return CustomCallArgDecoding<T, checks>::Decode(arg.type_id, arg.value);
455   }
456 
457  private:
458   internal::DecodedArgs args_;
459   size_t offset_;
460 };
461 
462 class CustomCall::VariantArg {
463  public:
464   using RuntimeChecks = CustomCall::RuntimeChecks;
465 
466   VariantArg(internal::DecodedArgs args, size_t offset)
467       : args_(args), offset_(offset) {
468     assert(offset <= args_.size() && "illegal remaining args offset");
469   }
470 
471   template <typename T>
472   bool isa() const {
473     return args_[offset_].type_id == TypeID::get<Tagged<T>>();
474   }
475 
476   template <typename T, RuntimeChecks checks = RuntimeChecks::kDefault>
477   FailureOr<T> get() const {
478     internal::DecodedArg arg = args_[offset_];
479     return CustomCallArgDecoding<T, checks>::Decode(arg.type_id, arg.value);
480   }
481 
482  private:
483   internal::DecodedArgs args_;
484   size_t offset_;
485 };
486 
487 class CustomCall::VariantAttr {
488  public:
489   using RuntimeChecks = CustomCall::RuntimeChecks;
490 
491   VariantAttr(llvm::StringRef name, TypeID type_id, void* value)
492       : name_(name), type_id_(type_id), value_(value) {}
493 
494   template <typename T>
495   bool isa() const {
496     return type_id_ == TypeID::get<Tagged<T>>();
497   }
498 
499   template <typename T, RuntimeChecks checks = RuntimeChecks::kDefault>
500   FailureOr<T> get() const {
501     return CustomCallAttrDecoding<T, checks>::Decode(name_, type_id_, value_);
502   }
503 
504  private:
505   llvm::StringRef name_;
506   TypeID type_id_;
507   void* value_;
508 };
509 
510 //===----------------------------------------------------------------------===//
511 // A little bit of template metaprogramming to implement type safe binding
512 // of custom calls to C++ functions. This is internal implementation details,
513 // and must not be relied on in any of the client code.
514 
515 namespace internal {
516 
517 // A helper struct to extract the type of the handler argument.
518 template <typename T>
519 struct FnArgType {
520   using Type = T;
521 };
522 
523 // Extracts the underlying type from the attribute type tag.
524 template <typename T>
525 struct FnArgType<internal::Attr<T>> {
526   using Type = T;
527 };
528 
529 // Extracts the underlying type from the user data type tag.
530 template <typename T>
531 struct FnArgType<internal::UserData<T>> {
532   using Type = T;
533 };
534 
535 // Extracts the underlying type from the value type tag.
536 template <typename T>
537 struct FnArgType<internal::Value<T>> {
538   using Type = T;
539 };
540 
541 // A template for counting regular arguments in the Ts pack.
542 template <typename T, typename... Ts>
543 struct NumArgs {
544   static constexpr int64_t value = !IsWrapped<T>::value + NumArgs<Ts...>::value;
545 };
546 
547 template <typename T>
548 struct NumArgs<T> {
549   static constexpr int64_t value = !IsWrapped<T>::value;
550 };
551 
// When decoding input data we need to keep track of how many arguments and
// attributes we decoded so far to index into the correct data structure.
struct DecodingOffsets {
  int64_t args = 0;    // number of regular arguments consumed so far
  int64_t attrs = 0;   // number of attributes consumed so far
  int64_t values = 0;  // number of bound values consumed so far
};
559 
560 template <typename T, CustomCall::RuntimeChecks checks>
561 struct Decode {
562   LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<T> call(
563       DecodingOffsets& offsets, internal::DecodedArgs args,
564       llvm::ArrayRef<std::string> attrs_names, llvm::ArrayRef<size_t> attrs_idx,
565       internal::DecodedAttrs attrs, llvm::ArrayRef<llvm::Any> values,
566       const CustomCall::UserData* user_data) {
567     internal::DecodedArg arg = args[offsets.args++];
568     return CustomCallArgDecoding<T, checks>::Decode(arg.type_id, arg.value);
569   }
570 };
571 
572 template <typename T, CustomCall::RuntimeChecks checks>
573 struct Decode<internal::Attr<T>, checks> {
574   LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<T> call(
575       DecodingOffsets& offsets, internal::DecodedArgs args,
576       llvm::ArrayRef<std::string> attrs_names, llvm::ArrayRef<size_t> attrs_idx,
577       internal::DecodedAttrs attrs, llvm::ArrayRef<llvm::Any> values,
578       const CustomCall::UserData* user_data) {
579     // Find decoded attribute corresponding for the given attribute index.
580     int64_t idx = offsets.attrs++;
581 
582     // Do not check the attribute name, and decode attribute at the given index.
583     if (!CustomCall::CheckNames(checks)) {
584       size_t i = attrs_idx[idx];
585       return CustomCallAttrDecoding<T, checks>::Decode(
586           attrs[i].name, attrs[i].type_id, attrs[i].value);
587     }
588 
589     llvm::StringRef attr = attrs_names[idx];
590 
591     // Given that attributes are passed to the custom call handler
592     // lexicographically sorted by name, we can find the attribute we are
593     // looking for only between the `attrs_idx` offset and the end of the
594     // attributes array.
595     for (size_t i = attrs_idx[idx]; i < attrs.size(); ++i) {
596       if (LLVM_LIKELY(attrs[i].name == attr))
597         return CustomCallAttrDecoding<T, checks>::Decode(
598             attrs[i].name, attrs[i].type_id, attrs[i].value);
599     }
600 
601     // Attribute we were looking for was not passed as an argument.
602     return mlir::failure();
603   }
604 };
605 
606 template <typename T, CustomCall::RuntimeChecks checks>
607 struct Decode<internal::UserData<T>, checks> {
608   LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<T> call(
609       DecodingOffsets& offsets, internal::DecodedArgs args,
610       llvm::ArrayRef<std::string> attrs_names, llvm::ArrayRef<size_t> attrs_idx,
611       internal::DecodedAttrs attrs, llvm::ArrayRef<llvm::Any> values,
612       const CustomCall::UserData* user_data) {
613     using UserDataT = std::remove_pointer_t<T>;
614 
615     if (!CustomCall::CheckUserData(checks)) return user_data->get<UserDataT>();
616 
617     // TODO(ezhulenev): Add an option to request nullable user data, because
618     // right now we do not distinguish between a user data pointer that doesn't
619     // exist, and a null pointer passed by the user.
620 
621     // Get the requested value if user data was passed to the custom call.
622     auto* ptr = user_data ? user_data->getIfExists<UserDataT>() : nullptr;
623     if (LLVM_UNLIKELY(!ptr)) return mlir::failure();
624     return ptr;
625   }
626 };
627 
628 template <typename T, CustomCall::RuntimeChecks checks>
629 struct Decode<internal::Value<T>, checks> {
630   LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<T> call(
631       DecodingOffsets& offsets, internal::DecodedArgs args,
632       llvm::ArrayRef<std::string> attrs_names, llvm::ArrayRef<size_t> attrs_idx,
633       internal::DecodedAttrs attrs, llvm::ArrayRef<llvm::Any> values,
634       const CustomCall::UserData* user_data) {
635     return llvm::any_cast<T>(values[offsets.values++]);
636   }
637 };
638 
639 template <CustomCall::RuntimeChecks checks>
640 struct Decode<CustomCall::RemainingArgs, checks> {
641   LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<CustomCall::RemainingArgs> call(
642       DecodingOffsets& offsets, internal::DecodedArgs args,
643       llvm::ArrayRef<std::string> attr_names, llvm::ArrayRef<size_t> attrs_idx,
644       internal::DecodedAttrs attrs, llvm::ArrayRef<llvm::Any> values,
645       const CustomCall::UserData* user_data) {
646     return CustomCall::RemainingArgs(args, offsets.args);
647   }
648 };
649 
650 template <CustomCall::RuntimeChecks checks>
651 struct Decode<CustomCall::VariantArg, checks> {
652   LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<CustomCall::VariantArg> call(
653       DecodingOffsets& offsets, internal::DecodedArgs args,
654       llvm::ArrayRef<std::string> attr_names, llvm::ArrayRef<size_t> attrs_idx,
655       internal::DecodedAttrs attrs, llvm::ArrayRef<llvm::Any> values,
656       const CustomCall::UserData* user_data) {
657     return CustomCall::VariantArg(args, offsets.args++);
658   }
659 };
660 
661 }  // namespace internal
662 
663 // Custom call handler binds concrete custom call implementation of type `Fn` to
664 // the custom call function signature. `Fn` can be a function pointer, or a
665 // lambda.
666 //
667 // Custom call handler uses the variadic template parameter `Ts` to decode the
668 // opaque pointers passed to the `call` function into the C++ types that are
669 // forwarded to the custom call implementation.
670 template <CustomCall::RuntimeChecks checks, typename Fn, typename... Ts>
671 class CustomCallHandler : public CustomCall {
672   static constexpr int64_t kSize = sizeof...(Ts);
673   static constexpr int64_t kNumArgs = internal::NumArgs<Ts...>::value;
674 
675   template <typename T>
676   using FnArgType = typename internal::FnArgType<T>::Type;
677 
678   // Custom call can signal error using a LogicalError result.
679   static constexpr bool kIsLogicalErr =
680       std::is_invocable_r_v<LogicalResult, Fn, FnArgType<Ts>...>;
681 
682   // Custom call can signal error together with a detailed error message.
683   static constexpr bool kIsDetailedErr =
684       std::is_invocable_r_v<llvm::Error, Fn, FnArgType<Ts>...>;
685 
686   static_assert(kIsLogicalErr || kIsDetailedErr,
687                 "incompatible custom call handler types");
688 
689  public:
690   llvm::StringRef name() const final { return callee_; }
691 
692   LLVM_ATTRIBUTE_ALWAYS_INLINE LogicalResult
693   call(void** args, void** attrs, const UserData* user_data,
694        const DiagnosticEngine* diagnostic) const final {
695     // Unpoison the first pointer to get the args and attrs sizes.
696     ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(args, sizeof(void*));
697     ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(attrs, sizeof(void*));
698 
699     // Decode arguments and attributes from the opaque pointers.
700     internal::DecodedArgs decoded_args(args);
701     internal::DecodedAttrs decoded_attrs(attrs);
702 
703     int64_t num_args = decoded_args.size();
704     int64_t num_attrs = decoded_attrs.size();
705 
706     // Unpoison the rest of the of args and attrs data.
707     ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(args,
708                                         (1 + 2 * num_args) * sizeof(void*));
709     ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(attrs,
710                                         (1 + 3 * num_attrs) * sizeof(void*));
711 
712     if (LLVM_UNLIKELY(diagnostic == nullptr))
713       diagnostic = DiagnosticEngine::DefaultDiagnosticEngine();
714 
715     // If all runtime checks are disabled we are just reinterpreting opaque
716     // `args` and `attrs` memory acording to the requested handler signature.
717     if (checks != RuntimeChecks::kNone) {
718       // Check that the number of passed arguments matches the signature. Each
719       // individual argument decoding will check the actual type.
720       if (internal::HasRemainingArgs<Ts...>::value) {
721         if (LLVM_UNLIKELY(num_args < kNumArgs - 1))
722           return diagnostic->EmitError()
723                  << "Wrong number of arguments: expected at least "
724                  << (kNumArgs - 1) << " got " << num_args;
725       } else {
726         if (LLVM_UNLIKELY(num_args != kNumArgs))
727           return diagnostic->EmitError()
728                  << "Wrong number of arguments: expected " << kNumArgs
729                  << " got " << num_args;
730       }
731 
732       // Check that we have enough attributes passed to the custom call. Each
733       // individual attribute decoding will check the name and the type.
734       if (LLVM_UNLIKELY(num_attrs < attrs_.size()))
735         return diagnostic->EmitError()
736                << "Wrong number of attributes: expected at least "
737                << attrs_.size() << " got " << num_attrs;
738     }
739 
740     return call(decoded_args, decoded_attrs, user_data, diagnostic,
741                 std::make_index_sequence<kSize>{});
742   }
743 
744   template <size_t... Is>
745   LLVM_ATTRIBUTE_ALWAYS_INLINE LogicalResult
746   call(internal::DecodedArgs args, internal::DecodedAttrs attrs,
747        const UserData* user_data, const DiagnosticEngine* diagnostic,
748        std::index_sequence<Is...>) const {
749     // A helper structure to allow each decoder find the correct offset in the
750     // arguments or attributes.
751     internal::DecodingOffsets offsets;
752 
753     // Check if all arguments and attributes were decoded.
754     bool all_decoded = true;
755     auto check_all_decoded = [&](auto result) {
756       all_decoded &= mlir::succeeded(result);
757       return std::move(result);
758     };
759 
760     // Decode all arguments into FailureOr containers. It is guaranteed
761     // that initializer list will be evaluated left-to-right, and we can rely
762     // on correct offsets computation.
763     std::tuple<FailureOr<FnArgType<Ts>>...> fn_args = {
764         check_all_decoded(internal::Decode<Ts, checks>::call(
765             offsets, args, attrs_, attrs_idx_, attrs, values_, user_data))...};
766     if (LLVM_UNLIKELY(!all_decoded))
767       return diagnostic->EmitError()
768              << "Failed to decode all custom call arguments and attributes";
769 
770     // Custom call returns logical result to signal failures.
771     if constexpr (kIsLogicalErr)
772       return fn_(std::move(*std::get<Is>(fn_args))...);
773 
774     // Custom call returns detailed error to signal failures.
775     if constexpr (kIsDetailedErr) {
776       if (auto err = fn_(std::move(*std::get<Is>(fn_args))...))
777         return diagnostic->EmitError() << std::move(err);
778       return mlir::success();
779     }
780 
781     llvm_unreachable("unexpected custom call type");
782   }
783 
784  private:
785   template <typename...>
786   friend class CustomCallBinding;
787 
788   CustomCallHandler(Fn fn, std::string callee, std::vector<std::string> attrs,
789                     std::vector<llvm::Any> values)
790       : fn_(std::move(fn)),
791         callee_(std::move(callee)),
792         attrs_(std::move(attrs)),
793         values_(std::move(values)),
794         attrs_idx_(attrs_.size()) {
795     // Sort attributes names.
796     std::vector<std::string> sorted = attrs_;
797     llvm::sort(sorted);
798 
799     // Find index or every attribute in the sorted attributes vector.
800     for (size_t i = 0; i < attrs_.size(); ++i) {
801       const std::string& attr = attrs_[i];
802       attrs_idx_[i] = std::distance(sorted.begin(), llvm::find(sorted, attr));
803     }
804   }
805 
806   Fn fn_;
807   std::string callee_;
808   std::vector<std::string> attrs_;
809   std::vector<llvm::Any> values_;
810   // A mapping from the attribute index to its index in the lexicographically
811   // sorter vector of attribute names. Attributes passed in the custom call
812   // handler sorted by the name, we use this index to efficiently find the
813   // decoded attribute entry.
814   std::vector<size_t> attrs_idx_;
815 };
816 
// Out-of-class definitions of the static constexpr data members declared in
// `CustomCallHandler`. Since C++17 such members are implicitly inline, which
// makes these definitions redundant.
// NOTE(review): presumably kept for toolchains building in an older language
// mode - confirm before removing.
template <CustomCall::RuntimeChecks checks, typename Fn, typename... Ts>
constexpr int64_t CustomCallHandler<checks, Fn, Ts...>::kSize;

template <CustomCall::RuntimeChecks checks, typename Fn, typename... Ts>
constexpr int64_t CustomCallHandler<checks, Fn, Ts...>::kNumArgs;
822 
823 //===----------------------------------------------------------------------===//
824 // Custom arguments attributes decoding.
825 
// A view into the memref argument. Corresponds to the MemrefDesc, however it
// doesn't own the sizes/strides vectors, and cheap to pass around. Memrefs with
// non-identity layouts can be decoded only as a StridedMemrefView.
struct StridedMemrefView {
  tfrt::DType dtype;                // element type of the memref
  void* data;                       // opaque pointer to the memref data
  llvm::ArrayRef<int64_t> sizes;    // non-owning view of the dimension sizes
  llvm::ArrayRef<int64_t> strides;  // non-owning view of the strides
};
835 
836 // A view into the memref argument with an identity (row major) layout.
837 struct MemrefView {
838   tfrt::DType dtype;
839   void* data;
840   llvm::ArrayRef<int64_t> sizes;
841 };
842 
843 // A flat view into memref argument with an identity (row major) layout. If the
844 // memref shape and strides are not required for the custom call, it's cheaper
845 // to pass the flat view.
846 struct FlatMemrefView {
847   tfrt::DType dtype;
848   void* data;
849   int64_t size_in_bytes;
850 };
851 
852 llvm::raw_ostream& operator<<(llvm::raw_ostream& os, const StridedMemrefView&);
853 llvm::raw_ostream& operator<<(llvm::raw_ostream& os, const MemrefView&);
854 llvm::raw_ostream& operator<<(llvm::raw_ostream& os, const FlatMemrefView&);
855 
856 template <CustomCall::RuntimeChecks checks>
857 struct CustomCallArgDecoding<StridedMemrefView, checks> {
858   using EncodedMemref = internal::EncodedMemref;
859 
860   LLVM_ATTRIBUTE_ALWAYS_INLINE
861   static FailureOr<StridedMemrefView> Decode(TypeID type_id, void* value) {
862     if (!(CustomCall::CheckType<Tagged<MemrefView>>(checks, type_id) ||
863           CustomCall::CheckType<Tagged<StridedMemrefView>>(checks, type_id)))
864       return mlir::failure();
865 
866     auto* encoded = reinterpret_cast<EncodedMemref*>(value);
867     ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(encoded, sizeof(EncodedMemref));
868     ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(
869         encoded, sizeof(EncodedMemref) + encoded->rank * sizeof(int64_t));
870 
871     tfrt::DType dtype = static_cast<tfrt::DType>(encoded->dtype);
872     return StridedMemrefView{dtype,
873                              encoded->data,
874                              {encoded->dims, encoded->rank},
875                              {encoded->dims + encoded->rank, encoded->rank}};
876   }
877 };
878 
879 template <CustomCall::RuntimeChecks checks>
880 struct CustomCallArgDecoding<MemrefView, checks> {
881   using EncodedMemref = internal::EncodedMemref;
882 
883   LLVM_ATTRIBUTE_ALWAYS_INLINE
884   static FailureOr<MemrefView> Decode(TypeID type_id, void* value) {
885     if (!CustomCall::CheckType<Tagged<MemrefView>>(checks, type_id))
886       return mlir::failure();
887 
888     auto* encoded = reinterpret_cast<EncodedMemref*>(value);
889     ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(encoded, sizeof(EncodedMemref));
890     ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(
891         encoded, sizeof(EncodedMemref) + encoded->rank * sizeof(int64_t));
892 
893     tfrt::DType dtype = static_cast<tfrt::DType>(encoded->dtype);
894     return MemrefView{dtype, encoded->data, {encoded->dims, encoded->rank}};
895   }
896 };
897 
898 template <CustomCall::RuntimeChecks checks>
899 struct CustomCallArgDecoding<FlatMemrefView, checks> {
900   using EncodedMemref = internal::EncodedMemref;
901 
902   LLVM_ATTRIBUTE_ALWAYS_INLINE
903   static FailureOr<FlatMemrefView> Decode(TypeID type_id, void* value) {
904     if (!CustomCall::CheckType<Tagged<MemrefView>>(checks, type_id))
905       return mlir::failure();
906 
907     auto* encoded = reinterpret_cast<EncodedMemref*>(value);
908     ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(encoded, sizeof(EncodedMemref));
909     ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(
910         encoded, sizeof(EncodedMemref) + encoded->rank * sizeof(int64_t));
911 
912     tfrt::DType dtype = static_cast<tfrt::DType>(encoded->dtype);
913     int64_t size_in_bytes = GetHostSize(dtype);
914     for (int d = 0; d < encoded->rank; ++d) size_in_bytes *= encoded->dims[d];
915     return FlatMemrefView{dtype, encoded->data, size_in_bytes};
916   }
917 };
918 
919 #define XLA_RUNTIME_REGISTER_SCALAR_ARG_DECODING(T)                         \
920   template <CustomCall::RuntimeChecks checks>                               \
921   struct CustomCallArgDecoding<T, checks> {                                 \
922     LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<T> Decode(TypeID type_id, \
923                                                             void* value) {  \
924       if (!CustomCall::CheckType<Tagged<T>>(checks, type_id))               \
925         return mlir::failure();                                             \
926                                                                             \
927       return *reinterpret_cast<T*>(value);                                  \
928     }                                                                       \
929   }
930 
931 XLA_RUNTIME_REGISTER_SCALAR_ARG_DECODING(bool);
932 XLA_RUNTIME_REGISTER_SCALAR_ARG_DECODING(int32_t);
933 XLA_RUNTIME_REGISTER_SCALAR_ARG_DECODING(int64_t);
934 XLA_RUNTIME_REGISTER_SCALAR_ARG_DECODING(float);
935 XLA_RUNTIME_REGISTER_SCALAR_ARG_DECODING(double);
936 
937 #undef XLA_RUNTIME_REGISTER_SCALAR_ARG_DECODING
938 
939 template <CustomCall::RuntimeChecks checks>
940 struct CustomCallArgDecoding<Eigen::half, checks> {
941   LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<Eigen::half> Decode(
942       TypeID type_id, void* value) {
943     if (!CustomCall::CheckType<Tagged<Eigen::half>>(checks, type_id))
944       return mlir::failure();
945 
946     auto* src = reinterpret_cast<uint16_t*>(value);
947     return Eigen::numext::bit_cast<Eigen::half>(*src);
948   }
949 };
950 
951 //===----------------------------------------------------------------------===//
952 // Custom call attributes decoding.
953 
954 template <CustomCall::RuntimeChecks checks>
955 struct CustomCallAttrDecoding<llvm::StringRef, checks> {
956   using StringRef = llvm::StringRef;
957 
958   LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<StringRef> Decode(
959       llvm::StringRef name, TypeID type_id, void* value) {
960     if (!CustomCall::CheckType<Tagged<StringRef>>(checks, type_id))
961       return mlir::failure();
962 
963     auto* encoded = reinterpret_cast<internal::EncodedArray<char>*>(value);
964     return StringRef(encoded->data, encoded->size);
965   }
966 };
967 
968 template <CustomCall::RuntimeChecks checks>
969 struct CustomCallAttrDecoding<CustomCall::VariantAttr, checks> {
970   LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<CustomCall::VariantAttr> Decode(
971       llvm::StringRef name, TypeID type_id, void* value) {
972     return CustomCall::VariantAttr(name, type_id, value);
973   }
974 };
975 
976 #define XLA_RUNTIME_REGISTER_SCALAR_ATTR_DECODING(T)          \
977   template <CustomCall::RuntimeChecks checks>                 \
978   struct CustomCallAttrDecoding<T, checks> {                  \
979     LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<T> Decode(  \
980         llvm::StringRef name, TypeID type_id, void* value) {  \
981       if (!CustomCall::CheckType<Tagged<T>>(checks, type_id)) \
982         return mlir::failure();                               \
983                                                               \
984       return *reinterpret_cast<T*>(value);                    \
985     }                                                         \
986   }
987 
988 XLA_RUNTIME_REGISTER_SCALAR_ATTR_DECODING(bool);
989 XLA_RUNTIME_REGISTER_SCALAR_ATTR_DECODING(int32_t);
990 XLA_RUNTIME_REGISTER_SCALAR_ATTR_DECODING(int64_t);
991 XLA_RUNTIME_REGISTER_SCALAR_ATTR_DECODING(float);
992 XLA_RUNTIME_REGISTER_SCALAR_ATTR_DECODING(double);
993 
994 #undef XLA_RUNTIME_REGISTER_SCALAR_ATTR_DECODING
995 
996 // Both EncodedArray and 1-D EncodedDenseElements can be decoded as an
997 // llvm::ArrayRef. Pointers to both EncodedArray and 1-D EncodedDenseElements
998 // can be dereferenced as a pointer to EncodedArray.
999 #define XLA_RUNTIME_REGISTER_ARRAY_ATTR_DECODING(T)                           \
1000   template <CustomCall::RuntimeChecks checks>                                 \
1001   struct CustomCallAttrDecoding<llvm::ArrayRef<T>, checks> {                  \
1002     LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<llvm::ArrayRef<T>> Decode(  \
1003         llvm::StringRef name, TypeID type_id, void* value) {                  \
1004       if ((!CustomCall::CheckType<Tagged<llvm::ArrayRef<T>>>(checks,          \
1005                                                              type_id)) &&     \
1006           (!CustomCall::CheckType<Tagged<CustomCall::TensorRef<T>>>(          \
1007               checks, type_id)) &&                                            \
1008           (!CustomCall::CheckType<Tagged<EmptyArrayRef>>(checks, type_id))) { \
1009         return mlir::failure();                                               \
1010       }                                                                       \
1011                                                                               \
1012       auto* encoded = reinterpret_cast<internal::EncodedArray<T>*>(value);    \
1013       return llvm::ArrayRef<T>(encoded->data, encoded->size);                 \
1014     }                                                                         \
1015   }
1016 
1017 XLA_RUNTIME_REGISTER_ARRAY_ATTR_DECODING(int32_t);
1018 XLA_RUNTIME_REGISTER_ARRAY_ATTR_DECODING(int64_t);
1019 XLA_RUNTIME_REGISTER_ARRAY_ATTR_DECODING(float);
1020 XLA_RUNTIME_REGISTER_ARRAY_ATTR_DECODING(double);
1021 
1022 #undef XLA_RUNTIME_REGISTER_ARRAY_ATTR_DECODING
1023 
1024 #define XLA_RUNTIME_REGISTER_DENSE_ELEMENTS_ATTR_DECODING(T)                 \
1025   template <CustomCall::RuntimeChecks checks>                                \
1026   struct CustomCallAttrDecoding<CustomCall::TensorRef<T>, checks> {          \
1027     LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<CustomCall::TensorRef<T>>  \
1028     Decode(llvm::StringRef name, TypeID type_id, void* value) {              \
1029       if (!CustomCall::CheckType<Tagged<CustomCall::TensorRef<T>>>(checks,   \
1030                                                                    type_id)) \
1031         return mlir::failure();                                              \
1032                                                                              \
1033       auto* encoded =                                                        \
1034           reinterpret_cast<internal::EncodedDenseElements<T>*>(value);       \
1035       auto payload = encoded->payload;                                       \
1036       llvm::ArrayRef<T> data(payload.data, payload.size);                    \
1037       llvm::ArrayRef<int64_t> shape(encoded->shape, encoded->rank);          \
1038       return CustomCall::TensorRef<T>({shape, data});                        \
1039     }                                                                        \
1040   }
1041 
1042 XLA_RUNTIME_REGISTER_DENSE_ELEMENTS_ATTR_DECODING(int32_t);
1043 XLA_RUNTIME_REGISTER_DENSE_ELEMENTS_ATTR_DECODING(int64_t);
1044 XLA_RUNTIME_REGISTER_DENSE_ELEMENTS_ATTR_DECODING(float);
1045 XLA_RUNTIME_REGISTER_DENSE_ELEMENTS_ATTR_DECODING(double);
1046 
1047 #undef XLA_RUNTIME_REGISTER_DENSE_ELEMENTS_ATTR_DECODING
1048 
//===----------------------------------------------------------------------===//
// Register an XLA custom call attribute decoding for an enum class. At runtime
// the value must be passed as the underlying type of the enum.
//===----------------------------------------------------------------------===//

// Example: register decoding for a user-defined enum class
//
//   enum class MyEnumType { kFoo, kBar, kBaz };
//
//   XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(MyEnumType);
//
// The decoder loads a value of the underlying type through the opaque pointer
// and converts it to `T` with a static_cast.
#define XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(T)                   \
  template <CustomCall::RuntimeChecks checks>                        \
  struct CustomCallAttrDecoding<T, checks> {                         \
    static_assert(std::is_enum<T>::value, "expected enum class");    \
    using U = std::underlying_type_t<T>;                             \
                                                                     \
    LLVM_ATTRIBUTE_ALWAYS_INLINE                                     \
    static FailureOr<T> Decode(llvm::StringRef name, TypeID type_id, \
                               void* value) {                        \
      if (!CustomCall::CheckType<Tagged<T>>(checks, type_id))        \
        return mlir::failure();                                      \
                                                                     \
      const U underlying = *reinterpret_cast<U*>(value);             \
      return static_cast<T>(underlying);                             \
    }                                                                \
  }
1074 
1075 //===----------------------------------------------------------------------===//
1076 // Register an XLA custom call attribute decoding for aggregate attributes.
1077 //===----------------------------------------------------------------------===//
1078 
// A workaround for passing braced initializers to a macro: the arguments are
// taken as a variadic list and wrapped in braces by the expansion.
#define XLA_RUNTIME_AGGREGATE_FIELDS(...) { __VA_ARGS__ }
1082 
// Example: register decoding for a user-defined struct
//
//   struct PairOfI64 { int64_t a; int64_t b; };
//
//   XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING(
//     PairOfI64, XLA_RUNTIME_AGGREGATE_FIELDS("a", "b"),
//     int64_t, int64_t);
//
// `NAMES` is the braced list of member names and the variadic arguments are
// the member types; both are forwarded to internal::DecodeAggregateAttr once
// the type id check for `T` has passed.
#define XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING(T, NAMES, ...)       \
  template <CustomCall::RuntimeChecks checks>                             \
  struct CustomCallAttrDecoding<T, checks> {                              \
    LLVM_ATTRIBUTE_ALWAYS_INLINE                                          \
    static FailureOr<T> Decode(llvm::StringRef name, TypeID type_id,      \
                               void* value) {                             \
      using Impl = internal::DecodeAggregateAttr<T, checks, __VA_ARGS__>; \
                                                                          \
      if (CustomCall::CheckType<Tagged<T>>(checks, type_id))              \
        return Impl::Decode(reinterpret_cast<void**>(value), NAMES);      \
                                                                          \
      return mlir::failure();                                             \
    }                                                                     \
  }
1103 
1104 namespace internal {
1105 // Decodes aggregate attribute into the object of type `T` that must be
1106 // constructible from the `Ts` types.
1107 template <typename T, CustomCall::RuntimeChecks checks, typename... Ts>
1108 struct DecodeAggregateAttr {
1109   static constexpr size_t kSize = sizeof...(Ts);
1110 
1111   using RuntimeChecks = CustomCall::RuntimeChecks;
1112 
1113   LLVM_ATTRIBUTE_ALWAYS_INLINE
1114   static FailureOr<T> Decode(void** value,
1115                              std::array<llvm::StringRef, kSize> names) {
1116     internal::DecodedAttrs attrs(value);
1117     return Decode(attrs, names, std::make_index_sequence<kSize>{});
1118   }
1119 
1120   template <size_t... Is>
1121   LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<T> Decode(
1122       internal::DecodedAttrs attrs, std::array<llvm::StringRef, kSize> names,
1123       std::index_sequence<Is...>) {
1124     // Check that the number of encoded attributes matches the signature.
1125     if (checks != RuntimeChecks::kNone && kSize != attrs.size())
1126       return mlir::failure();
1127 
1128     // Check that aggregate member names match the expected names.
1129     if (CustomCall::CheckNames(checks)) {
1130       for (unsigned i = 0; i < kSize; ++i)
1131         if (attrs[i].name != names[i]) return mlir::failure();
1132     }
1133 
1134     // Check if all members were decoded.
1135     bool all_decoded = true;
1136     auto check_all_decoded = [&](auto result) {
1137       all_decoded &= mlir::succeeded(result);
1138       return std::move(result);
1139     };
1140 
1141     // Decode all arguments into FailureOr containers. It is guaranteed
1142     // that initializer list will be evaluated left-to-right, and we can rely
1143     // on correct offsets computation.
1144     std::tuple<FailureOr<Ts>...> members = {
1145         check_all_decoded(CustomCallAttrDecoding<Ts, checks>::Decode(
1146             attrs[Is].name, attrs[Is].type_id, attrs[Is].value))...};
1147     if (LLVM_UNLIKELY(!all_decoded)) return mlir::failure();
1148 
1149     // Forward unpacked members to the type constructor.
1150     return T{std::move(*std::get<Is>(members))...};
1151   }
1152 };
1153 }  // namespace internal
1154 
// Declare/define an explicit specialization of TypeID for types used by the
// custom calls. This forces the compiler to emit a strong definition for a
// class and controls which translation unit and shared object will actually
// have it.
//
// See TypeID for more documentation.
//
// Custom calls do not "own" the types passed across the function boundary, so
// the specializations are declared/defined for the `Tagged<T>` wrapper rather
// than for `T` itself, to avoid potential conflicts with other libraries.
#define XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(T) \
  MLIR_DECLARE_EXPLICIT_TYPE_ID(::xla::runtime::Tagged<T>)

#define XLA_RUNTIME_DEFINE_EXPLICIT_TYPE_ID(T) \
  MLIR_DEFINE_EXPLICIT_TYPE_ID(::xla::runtime::Tagged<T>)
1170 
1171 }  // namespace runtime
1172 }  // namespace xla
1173 
1174 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(llvm::StringRef);
1175 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(xla::runtime::StridedMemrefView);
1176 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(xla::runtime::MemrefView);
1177 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(xla::runtime::FlatMemrefView);
1178 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(int32_t);
1179 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(int64_t);
1180 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(float);
1181 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(double);
1182 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(llvm::ArrayRef<int32_t>);
1183 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(llvm::ArrayRef<int64_t>);
1184 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(llvm::ArrayRef<float>);
1185 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(llvm::ArrayRef<double>);
1186 
1187 #endif  // XLA_RUNTIME_CUSTOM_CALL_H_
1188