#pragma once
#include <c10/core/DispatchKey.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
#include <c10/util/llvmMathExtras.h>
#include <array>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <ostream>
#include <string>
#include <type_traits>

namespace c10 {

struct FunctionalityOffsetAndMask {
  // The empty constructor shouldn't be used; it is only needed to initialize
  // the array before populating it.
  FunctionalityOffsetAndMask() = default;
  FunctionalityOffsetAndMask(uint16_t offset, uint16_t mask)
      : offset(offset), mask(mask) {}
  // This needs to be big enough to cover the size of the operator table.
  uint16_t offset{};
  // See Note [No More Than 16 Backends]
  // This mask needs to be big enough to mask all of the backend bits.
  // We probably don't ever want to have more than 16 backend bits, so
  // uint16_t should be enough.
  uint16_t mask{};
};
static_assert(
    c10::num_runtime_entries < 65536,
    "The dispatcher currently only supports up to 2^16 runtime entries");

C10_API std::array<FunctionalityOffsetAndMask, num_functionality_keys>
initializeFunctionalityOffsetsAndMasks();

C10_ALWAYS_INLINE static const std::
    array<FunctionalityOffsetAndMask, num_functionality_keys>&
    offsetsAndMasks() {
  static auto offsets_and_masks_ = initializeFunctionalityOffsetsAndMasks();
  return offsets_and_masks_;
}

// A representation of a set of DispatchKeys. A DispatchKeySet contains both
// "functionality" bits and "backend bits", and every tensor holds its own
// DispatchKeySet. The Dispatcher implements multiple dispatch by grabbing the
// keyset on every input tensor, or'ing them together, and dispatching to a
// specific piece of functionality. The functionality bits are *ordered*. When
// multiple functionality bits are set, we use the highest priority
// functionality. Similarly, multiple backend bits can theoretically be set if
// you call an operator with multiple tensors from different devices (e.g. CPU
// and CUDA), although support for mixed device dispatch is limited (the only
// kernels that gracefully handle mixed device inputs for now are CUDA kernels
// that take in a scalar CPU tensor).

// A representation of a set of DispatchKeys. A tensor may have multiple
// tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the
// DispatchKeySet specifies what type ids apply. The internal representation is
// as a 64-bit bit set (this means only 64 tensor type ids are supported).
//
// As mentioned above, DispatchKeys are ordered; thus, we can ask questions
// like "what is the highest priority DispatchKey in the set"? (The set itself
// is not ordered; two sets with the same ids will always have the ids ordered
// in the same way.)
//
// Note [DispatchKeySet Internal Representation]
// Internally, dispatch keys are packed into 64-bit DispatchKeySet objects
// that get passed around at runtime.
// However, there isn't necessarily a 1-to-1 mapping between bits in the
// keyset and individual dispatch keys.
//
// First: why do we have this distinction, and why not map every dispatch key
// directly to a bit? This is mostly because we have several types of
// functionalities that different backends would like to customize. For
// example, we have:
// - "Dense":     CPU, CUDA, XLA, ... (~12 keys)
// - "Sparse":    SparseCPU, SparseCUDA, ...
// - "SparseCsr": SparseCsrCPU, SparseCsrCUDA, ...
// - "Quantized": QuantizedCPU, QuantizedCUDA, QuantizedXLA, ...
// - "Autograd":  AutogradCPU, AutogradCUDA, AutogradXLA, ...
// The problem is that the total number of keys grows as the product of
// [# backends] x [# functionalities], making it very difficult to map each
// key directly to a bit in a bitset without dramatically increasing the size
// of the bitset over time.
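//
// For example (illustrative): instead of reserving a dedicated bit for every
// (functionality, backend) pair, the keyset of a sparse CPU tensor stores one
// "Sparse" functionality bit plus one "CPU" backend bit, and the pair decodes
// back to the single runtime key DispatchKey::SparseCPU:
//   auto sparse_cpu_ks =
//       DispatchKeySet({DispatchKey::Sparse, DispatchKey::CPU});
//   sparse_cpu_ks.has(DispatchKey::SparseCPU); // true
// Adding a new backend then costs one backend bit, rather than one bit per
// functionality.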
//
// The two enums (BackendComponent and DispatchKey) can be divided roughly
// into the following categories:
//
// (1) "Building block" keys
//    (a) backends: everything in the BackendComponent enum (e.g. CPUBit,
//        CUDABit)
//    (b) functionalities: (per-backend) functionality-bit DispatchKeys
//        (e.g. AutogradFunctionality, SparseCsr, Sparse, Dense)
// (2) "Runtime" keys
//    (a) "non-customizable backends" (e.g. FPGA)
//    (b) "non-customizable functionalities" (e.g. Functionalize)
//    (c) "per-backend instances of customizable functionalities" (e.g. CPU,
//        SparseCPU, AutogradCPU)
// (3) "Alias" DispatchKeys (see Note [Alias Dispatch Keys])
//
// (1) Building block keys always correspond to individual bits in a
// DispatchKeySet. They can also be combined in a DispatchKeySet to form
// actual runtime keys. e.g.
//    auto dense_cpu_ks = DispatchKeySet(BackendComponent::CPUBit) |
//        DispatchKeySet(DispatchKey::Dense);
//    // The keyset has the runtime dense-cpu key.
//    dense_cpu_ks.has(DispatchKey::CPU);
//    // And it contains the building block keys too.
//    dense_cpu_ks.has_backend(BackendComponent::CPUBit);
//    dense_cpu_ks.has(DispatchKey::Dense);
//
// Not every backend and not every functionality counts as a "building block
// key". This is mostly to give us more levers to pull in the design space.
// Backend keys and functionality keys that count as "building blocks" will
// contribute to a full cross product of functionality that can be overridden.
//
// For example, right now we have at least 12 "backend" building blocks (CPU,
// CUDA, XLA, ...) and at least 5 "functionality" building blocks (Dense,
// Sparse, SparseCsr, Quantized, AutogradFunctionality, ...). These keys
// together allow every dispatcher operator to be customized in up to 12*5
// different ways. Each of those requires a slot in the operator table of
// every dispatcher operator. Not every piece of functionality necessarily
// needs to be customizable per-backend, and not every backend necessarily
// needs to be able to customize every type of functionality.
//
// (2) Every runtime key corresponds directly to a slot in an operator's
// runtime dispatch table, and you can directly register kernels to a runtime
// dispatch key.
//
// For per-backend functionalities like "Dense" or "AutogradFunctionality",
// you can think of the corresponding runtime dispatch keys as "instances" of
// that functionality, per backend. E.g. "CPU", "CUDA", "XLA", etc. are all
// runtime instances of the "Dense" building block key.
//
// (2a) and (2b) are represented identically in the DispatchKeySet logic:
// - backend-agnostic functionalities (e.g. FuncTorchBatched) are NOT
//   customizable per backend.
//   In order to do so, we'd need to promote it to a per-backend functionality
//   "building block" key.
// - non-customizable backends (e.g. FPGA) can NOT customize existing
//   functionality like Sparse, Autograd, etc.
//   In order to do so, we'd need to promote it to a backend "building block"
//   key.
//
// In both cases, these keys directly correspond to runtime slots in the
// operator table.
//
// (3) "Alias" keys
// See Note [Alias Dispatch Keys]
//
// Final note: for anyone making future changes to the Dispatcher +
// DispatchKeySet internals, there's a closed PR with a basic
// python-implementation of the Dispatcher that might be useful in quickly
// testing out and validating changes. See it at
// https://github.com/pytorch/pytorch/pull/68743
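// For example (illustrative): the building-block <-> runtime-key mapping
// described above can be queried directly:
//   toRuntimePerBackendFunctionalityKey(
//       DispatchKey::Dense, BackendComponent::CUDABit) == DispatchKey::CUDA
//   toFunctionalityKey(DispatchKey::SparseCPU) == DispatchKey::Sparse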
// An undefined tensor is one with an empty tensor type set.
class DispatchKeySet final {
 public:
  enum Full { FULL };
  enum FullAfter { FULL_AFTER };
  enum Raw { RAW };

  // NB: default constructor representation as zero is MANDATORY as
  // use of DispatchKeySet in TLS requires this.
  constexpr DispatchKeySet() = default;

  constexpr DispatchKeySet(Full)
      : repr_((1ULL << (num_backends + num_functionality_keys - 1)) - 1) {}

  constexpr DispatchKeySet(FullAfter, DispatchKey t)
      // LSB after t are OK, but not t itself.
      // "functionalities" have a notion of ordering (e.g. Autograd > Sparse >
      // Quantized > Dense). But backends don't really have an ordering.
      // Therefore, we're enforcing that FullAfter can only be used on
      // "functionality" keys.
      : repr_(
            (1ULL
             << (num_backends + static_cast<uint8_t>(toFunctionalityKey(t)) -
                 1)) -
            1) {
    *this = add(DispatchKey::PythonDispatcher);
  }

  // Public version of DispatchKeySet(uint64_t) API; external users
  // must be explicit when they do this!
  constexpr DispatchKeySet(Raw, uint64_t x) : repr_(x) {}

  constexpr explicit DispatchKeySet(BackendComponent k) {
    if (k == BackendComponent::InvalidBit) {
      repr_ = 0;
    } else {
      repr_ = 1ULL << (static_cast<uint8_t>(k) - 1);
    }
  }

  constexpr explicit DispatchKeySet(DispatchKey k) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (k == DispatchKey::Undefined) {
      // Case 1: handle Undefined specifically
      repr_ = 0;
    } else if (k <= DispatchKey::EndOfFunctionalityKeys) {
      // Case 2: handle "functionality-only" keys
      // These keys have a functionality bit set, but no backend bits.
      // They can technically be either:
      // - valid runtime keys (e.g. DispatchKey::AutogradOther,
      //   DispatchKey::FuncTorchBatched, etc)
      // - "building block" keys that aren't actual runtime keys (e.g.
      //   DispatchKey::Dense or Sparse)
      uint64_t functionality_val = 1ULL
          << (num_backends + static_cast<uint8_t>(k) - 1);
      repr_ = functionality_val;
    } else if (k <= DispatchKey::EndOfRuntimeBackendKeys) {
      // Case 3: "runtime" keys that have a functionality bit AND a backend
      // bit. These are the runtime instances of "per-backend functionality"
      // keys. For example, given DispatchKey::CPU, we should set:
      // - the Dense functionality bit
      // - the CPUBit backend bit

      // First compute which bit to flip for the functionality.
      auto functionality_k = toFunctionalityKey(k);
      // The - 1 is because Undefined is technically a "functionality" that
      // doesn't show up in the bitset. So e.g. Dense is technically the
      // second functionality, but the lowest functionality bit.
      uint64_t functionality_val = 1ULL
          << (num_backends + static_cast<uint8_t>(functionality_k) - 1);

      // Then compute which bit to flip for the backend.
      auto backend_k = toBackendComponent(k);
      uint64_t backend_val = backend_k == BackendComponent::InvalidBit
          ? 0
          : 1ULL << (static_cast<uint8_t>(backend_k) - 1);
      repr_ = functionality_val + backend_val;
    } else {
      // At this point, we should have covered every case except for alias
      // keys. Technically it would be possible to add alias dispatch keys to
      // a DispatchKeySet, but the semantics are a little confusing and this
      // currently isn't needed anywhere.
      repr_ = 0;
    }
  }
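  // For example (illustrative): DispatchKeySet(DispatchKey::CPU) sets two
  // bits, the Dense functionality bit and the CPUBit backend bit, whereas
  // DispatchKeySet(DispatchKey::FuncTorchBatched) sets only a single
  // functionality bit.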
  constexpr uint64_t keys_to_repr(std::initializer_list<DispatchKey> ks) {
    uint64_t repr = 0;
    for (auto k : ks) {
      repr |= DispatchKeySet(k).repr_;
    }
    return repr;
  }

  constexpr uint64_t backend_bits_to_repr(
      std::initializer_list<BackendComponent> ks) {
    uint64_t repr = 0;
    for (auto k : ks) {
      repr |= DispatchKeySet(k).repr_;
    }
    return repr;
  }

  explicit constexpr DispatchKeySet(std::initializer_list<DispatchKey> ks)
      : repr_(keys_to_repr(ks)) {}

  explicit constexpr DispatchKeySet(std::initializer_list<BackendComponent> ks)
      // Note: for some reason, putting this logic directly in the constructor
      // appears to fail to compile on CUDA 10.1.
      // See an example internal failure at
      // https://www.internalfb.com/intern/skycastle/run/76561193669136035/artifact/actionlog.76561193742069401.stderr
      : repr_(backend_bits_to_repr(ks)) {}

  // Test if a DispatchKey is in the set
  inline bool has(DispatchKey t) const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined);
    return has_all(DispatchKeySet(t));
  }
  constexpr bool has_backend(BackendComponent t) const {
    return has_all(DispatchKeySet(t));
  }

  // Given a DispatchKeySet of functionality keys and (potentially) backend
  // keys, tests if all of them are in the current set.
  constexpr bool has_all(DispatchKeySet ks) const {
    return static_cast<bool>((repr_ & ks.repr_) == ks.repr_);
  }

  // Given a DispatchKeySet of functionality keys and (potentially) backend
  // keys, tests if any of them are in the current set. This could technically
  // be pretty easily implemented using has(). It is strictly a perf
  // optimization though. There are many places in the code base where we want
  // to test for multiple functionality keys together. HOWEVER, runtime
  // per-backend functionality keys aren't allowed to be used with this
  // function, because you can end up with weird results. e.g.
  // DispatchKeySet(DispatchKey::AutogradCPU).has_any(DispatchKeySet(DispatchKey::CPU))
  // would return true.
  inline bool has_any(DispatchKeySet ks) const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        // Either there are no backend bits in the input keyset
        ((ks.repr_ & full_backend_mask) == 0) ||
        // or there are no per-backend-functionality bits
        // See [Note: Per-Backend Functionality Dispatch Keys]
        ((ks &
          DispatchKeySet({
              DispatchKey::Dense,
              DispatchKey::Quantized,
              DispatchKey::Sparse,
              DispatchKey::SparseCsr,
              DispatchKey::AutogradFunctionality,
          }))
             .repr_ == 0));
    return static_cast<bool>((repr_ & ks.repr_) != 0);
  }
  // Test if DispatchKeySet is a superset of ks.
  bool isSupersetOf(DispatchKeySet ks) const {
    return (repr_ & ks.repr_) == ks.repr_;
  }
  // Perform set union
  constexpr DispatchKeySet operator|(DispatchKeySet other) const {
    return DispatchKeySet(repr_ | other.repr_);
  }
  // Perform set intersection
  constexpr DispatchKeySet operator&(DispatchKeySet other) const {
    return DispatchKeySet(repr_ & other.repr_);
  }
  // Compute the set difference self - other,
  // but ONLY for the functionality keys.
  // Any backend bits set on self will remain unchanged.
  // See Note [Removing keys from DispatchKeySet Only Affects Functionality
  // Keys]
  constexpr DispatchKeySet operator-(DispatchKeySet other) const {
    return DispatchKeySet(repr_ & (full_backend_mask | ~other.repr_));
  }

  // Compute self ^ other
  constexpr DispatchKeySet operator^(DispatchKeySet other) const {
    return DispatchKeySet(repr_ ^ other.repr_);
  }
  bool operator==(DispatchKeySet other) const {
    return repr_ == other.repr_;
  }
  bool operator!=(DispatchKeySet other) const {
    return repr_ != other.repr_;
  }
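  // For example (illustrative), the functionality-only semantics of set
  // difference in action:
  //   auto ks = DispatchKeySet({DispatchKey::AutogradCPU, DispatchKey::CPU});
  //   // Subtracting the Autograd functionality clears only that bit; the
  //   // CPU backend bit (and hence DispatchKey::CPU) survives.
  //   auto diff = ks - DispatchKeySet(DispatchKey::AutogradFunctionality);
  //   diff.has(DispatchKey::CPU);         // true
  //   diff.has(DispatchKey::AutogradCPU); // false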
  // Add a DispatchKey to the DispatchKeySet. Does NOT mutate;
  // returns the extended DispatchKeySet!
  C10_NODISCARD constexpr DispatchKeySet add(DispatchKey t) const {
    return *this | DispatchKeySet(t);
  }
  C10_NODISCARD constexpr DispatchKeySet add(DispatchKeySet ks) const {
    return *this | ks;
  }

  // Remove a DispatchKey from the DispatchKeySet.
  // This is generally not an operation you should be doing
  // (it's used to implement the printing overload, operator<<)
  //
  // Note [Removing keys from DispatchKeySet Only Affects Functionality Keys]
  // Only functionality bits are allowed to be removed from a keyset. This is
  // specifically needed by the fallthrough key calculation logic.
  // Why is removing backend bits problematic? Consider this example:
  //
  // DispatchKeySet([DispatchKey.CPU, DispatchKey.AutogradCUDA,
  //   DispatchKey.CUDA]).remove(DispatchKey.AutogradCUDA)
  // DispatchKeySet([DispatchKey.CPU,
  //   DispatchKey.AutogradCUDA]).remove(DispatchKey.AutogradCUDA)
  //
  // What do we want to happen?
  // Technically, we'd like it to be true that after removal,
  // the first keyset still has the CUDA dispatch key while the second
  // doesn't. Unfortunately there's no way to represent that, because the two
  // keysets are represented the same way internally:
  //   functionality bits: Autograd, Dense
  //   backend bits: CPU, CUDA
  //
  // Instead, remove(DispatchKey.AutogradCPU) will only remove the "Autograd"
  // bit from the bitset.
  C10_NODISCARD constexpr DispatchKeySet remove(DispatchKey t) const {
    return DispatchKeySet(
        repr_ & ~(DispatchKeySet(t).repr_ & ~full_backend_mask));
  }

  // You're allowed to remove a backend bit from a DispatchKeySet,
  // but you have to be explicit about it (remove_backend() instead of
  // remove()).
  constexpr DispatchKeySet remove_backend(BackendComponent b) const {
    return DispatchKeySet(repr_ & ~(DispatchKeySet(b).repr_));
  }

  // Is the set empty? (AKA undefined tensor)
  bool empty() const {
    return repr_ == 0;
  }
  uint64_t raw_repr() {
    return repr_;
  }

  DispatchKey highestFunctionalityKey() const {
    auto functionality_idx = indexOfHighestBit();
    // This means that none of the functionality bits were set.
    if (functionality_idx < num_backends)
      return DispatchKey::Undefined;
    // The first num_backends bits in the keyset don't correspond to real
    // dispatch keys.
    return static_cast<DispatchKey>(functionality_idx - num_backends);
  }

  // This is similar to toBackendComponent(DispatchKey), but less restrictive.
  // toBackendComponent() errors out if the key that it was passed has no
  // backend bits, which is useful for error checking. We need a version of
  // that here that can also handle "fake" backends like FPGA, because they
  // need to map to the AutogradOther key. For those backends, we return
  // BackendComponent::InvalidBit.
  BackendComponent highestBackendKey() const {
    // mask out the functionality bits
    auto backend_idx =
        DispatchKeySet(repr_ & full_backend_mask).indexOfHighestBit();
    // all zeros across the backend bits means that no backend bits are set.
    if (backend_idx == 0)
      return BackendComponent::InvalidBit;
    return static_cast<BackendComponent>(backend_idx);
  }

  // Returns the DispatchKey of highest priority in the set.
  DispatchKey highestPriorityTypeId() const {
    auto functionality_k = highestFunctionalityKey();
    if (isPerBackendFunctionalityKey(functionality_k)) {
      return toRuntimePerBackendFunctionalityKey(
          functionality_k, highestBackendKey());
    }
    return functionality_k;
  }
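  // For example (illustrative):
  //   DispatchKeySet({DispatchKey::CPU, DispatchKey::AutogradCPU})
  //       .highestPriorityTypeId() == DispatchKey::AutogradCPU
  // (Autograd is a higher-priority functionality than Dense, and the highest
  // backend bit is CPU, so the per-backend runtime key AutogradCPU wins.)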
  // Returns the index of the most-significant bit in the keyset.
  // This is used as part of the calculation of the index into the operator
  // table to get:
  // - the highest "functionality" bit in the keyset.
  // - the highest "backend" bit in the keyset.
  uint8_t indexOfHighestBit() const {
    return 64 - llvm::countLeadingZeros(repr_);
  }

#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
  // [Note: Trimmed Mobile Dispatch Keys]
  /**
   * The method below maps the dispatch key in the enum DispatchKey to an
   * integer index in the dispatchTable_ array in OperatorEntry. The array
   * is trimmed for mobile to reduce peak memory usage, since it's
   * unnecessary to reserve additional space for dispatch keys that will
   * never be used on mobile.
   */
  int getDispatchTableIndexForDispatchKeySet() const {
    auto dk = highestPriorityTypeId();
    switch (dk) {
      case DispatchKey::Undefined:
        return 0;
      case DispatchKey::CPU:
        return 1;
      case DispatchKey::QuantizedCPU:
        return 2;
      case DispatchKey::SparseCPU:
        return 3;
      case DispatchKey::BackendSelect:
        return 4;
      case DispatchKey::ADInplaceOrView:
        return 5;
      case DispatchKey::AutogradOther:
        return 6;
      case DispatchKey::AutogradCPU:
        return 7;
      default:
        return -1;
    }
  }
#else
  // Returns the index in the operator table of the highest priority key in
  // the keyset. Note that we could in theory implement this using
  // highestPriorityTypeId(), but this code is very hotpath and we can do it
  // faster without it.
  int getDispatchTableIndexForDispatchKeySet() const {
    auto functionality_idx =
        DispatchKeySet(repr_ >> num_backends).indexOfHighestBit();
    auto offset_and_mask = offsetsAndMasks()[functionality_idx];
    // Mask the functionality bits out first, then right-shift by 1.
    // right-shifting by 1 because everything is zero-indexed.
    // E.g. 000001 (CPU) should give us a backend index of 0, 000010 (CUDA)
    // should give us a backend index of 1, etc.
    auto backend_idx =
        DispatchKeySet((repr_ & offset_and_mask.mask) >> 1).indexOfHighestBit();
    return offset_and_mask.offset + backend_idx;
  }
#endif

  // Returns the "index" of the highest priority backend in the keyset.
  // This is pretty similar to highestBackendKey(), but:
  // - It's hotpath code (part of the runtime bitset calculation)
  // - It returns an integer index, not an enum value
  // - Everything is shifted to the right by 1.
  //   BackendComponent::InvalidBit is technically the lowest enum value,
  //   but it isn't included in the runtime table. So CPUBit = 1, CUDABit = 2,
  //   etc.
  uint64_t getBackendIndex() const {
    return DispatchKeySet((repr_ & full_backend_mask) >> 1).indexOfHighestBit();
  }
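  // For example (illustrative; the actual offsets are computed at startup by
  // initializeFunctionalityOffsetsAndMasks()): if the Dense functionality
  // were assigned offset 1 and a mask covering all backend bits, then a
  // keyset whose highest bits are {Dense, CPUBit} would yield backend_idx 0
  // and table slot 1 + 0, while {Dense, CUDABit} would yield backend_idx 1
  // and table slot 1 + 1, i.e. adjacent per-backend slots within the Dense
  // block of the operator table.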
 private:
  constexpr DispatchKeySet(uint64_t repr) : repr_(repr) {}
  uint64_t repr_ = 0;

 public:
  // STL iterator for DispatchKeySet. Iterates through all runtime
  // DispatchKeys in the set. The iterator is only invalidated by the
  // destruction of the underlying DispatchKeySet, as the iterator stores a
  // pointer to the raw representation of the DispatchKeySet. Note: When we
  // encounter a per-backend functionality (e.g. Dense or Sparse), we will
  // iterate through EVERY backend in the keyset, for that functionality. For
  // example, if the next functionality key to iterate over is Autograd, and
  // the backend bits in the keyset correspond to
  // [BackendComponent::CPUBit, BackendComponent::CUDABit], then the next two
  // keys we return will be DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA
  // (CPU first because it has lower precedence than CUDA in DispatchKey.h).
  class iterator {
   public:
    using self_type = iterator;
    using iterator_category = std::input_iterator_tag;
    using value_type = DispatchKey;
    using difference_type = ptrdiff_t;
    using reference = value_type&;
    using pointer = value_type*;
    // final mask value should mask out the entire keyset
    static const uint8_t end_iter_mask_val =
        num_backends + num_functionality_keys;
    // final key value should be the last DispatchKey
    static const uint8_t end_iter_key_val = num_functionality_keys;

    // current_dispatchkey_idx_ will iterate through all functionality bits.
    // current_backendcomponent_idx_ will iterate through all backend bits.
    explicit iterator(
        const uint64_t* data_ptr,
        uint8_t next_functionality = num_backends,
        uint8_t next_backend = 0)
        : data_ptr_(data_ptr),
          next_functionality_(next_functionality),
          next_backend_(next_backend),
          // These are in an invalid state at construction time, and set by
          // the first increment call
          current_dispatchkey_idx_(end_iter_key_val),
          current_backendcomponent_idx_(end_iter_key_val) {
      // Go to the first key in the set
      TORCH_INTERNAL_ASSERT(
          next_functionality_ >= num_backends,
          "num_backends=",
          static_cast<int32_t>(num_backends),
          "next_functionality_=",
          static_cast<int32_t>(next_functionality_));
      ++(*this);
    }

    C10_API self_type& operator++();

    self_type operator++(int) {
      self_type previous_iterator = *this;
      ++(*this);
      return previous_iterator;
    }

    bool operator==(const self_type& rhs) const {
      return next_functionality_ == rhs.next_functionality_ &&
          current_dispatchkey_idx_ == rhs.current_dispatchkey_idx_ &&
          next_backend_ == rhs.next_backend_ &&
          current_backendcomponent_idx_ == rhs.current_backendcomponent_idx_;
    }
    bool operator!=(const self_type& rhs) const {
      return next_functionality_ != rhs.next_functionality_ ||
          current_dispatchkey_idx_ != rhs.current_dispatchkey_idx_ ||
          next_backend_ != rhs.next_backend_ ||
          current_backendcomponent_idx_ != rhs.current_backendcomponent_idx_;
    }
    DispatchKey operator*() const {
      auto functionality_key =
          static_cast<DispatchKey>(current_dispatchkey_idx_);
      if (isPerBackendFunctionalityKey(functionality_key)) {
        auto next_key = toRuntimePerBackendFunctionalityKey(
            functionality_key,
            static_cast<BackendComponent>(current_backendcomponent_idx_));
        // We expect all of the Dense, Sparse, Quantized, and Autograd keys to
        // be ordered the same way with respect to their backends
        TORCH_INTERNAL_ASSERT(
            toBackendComponent(next_key) ==
                static_cast<BackendComponent>(current_backendcomponent_idx_),
            "Tried to map functionality key ",
            toString(functionality_key),
            " and backend bit ",
            toString(
                static_cast<BackendComponent>(current_backendcomponent_idx_)),
            " to a runtime key, but ended up with ",
            toString(next_key),
            ". This can happen if the order of the backend dispatch keys in "
            "DispatchKey.h isn't consistent.",
            " Please double check that enum for inconsistencies.");
        return next_key;
      } else {
        return functionality_key;
      }
    }

   private:
    const uint64_t* data_ptr_;
    uint8_t next_functionality_;
    uint8_t next_backend_;
    uint8_t current_dispatchkey_idx_;
    uint8_t current_backendcomponent_idx_;
  };

 public:
  // Returns iterator to the first key in the set. If no keys are in the
  // set, then will return the end iterator.
  iterator begin() const {
    return iterator(&repr_);
  }
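  // For example (illustrative): iterating expands per-backend functionality
  // bits into runtime keys, one per backend in the set:
  //   DispatchKeySet ks({DispatchKey::CPU, DispatchKey::CUDA});
  //   for (DispatchKey k : ks) {
  //     // visits DispatchKey::CPU, then DispatchKey::CUDA
  //   }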
  // We do not need to iterate beyond EndOfFunctionalityKeys so we will treat
  // this as the end iterator.
  iterator end() const {
    return iterator(&repr_, iterator::end_iter_mask_val);
  }
};

C10_API std::string toString(DispatchKeySet);
C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet);

C10_API inline int getDispatchTableIndexForDispatchKey(DispatchKey k) {
  return DispatchKeySet(k).getDispatchTableIndexForDispatchKeySet();
}

// Alias key DispatchKey::Autograd maps to
// (autograd_dispatch_keyset x full_backend_mask)
// NB: keys in this set also get associated with CompositeImplicitAutograd
//
// Note [autograd_dispatch_keyset Does Not Include Backend Bits]
// We don't want to include any backend bits (BackendComponent::CPUBit, etc)
// directly in autograd_dispatch_keyset.
// Why? keysets like autograd_dispatch_keyset are commonly used to remove
// autograd keys from a DispatchKeySet throughout the code base. However, you
// are only allowed to remove functionality bits from a keyset, not backend
// bits. See Note [Removing keys from DispatchKeySet Only Affects
// Functionality Keys] for details. To be consistent and avoid confusion,
// we're explicitly setting up autograd_dispatch_keyset to not have any
// backend bits.
constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
    DispatchKey::AutogradFunctionality,
    DispatchKey::AutogradOther,
    DispatchKey::AutogradNestedTensor,
});

constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
    DispatchKey::AutocastCPU,
    DispatchKey::AutocastMPS,
    DispatchKey::AutocastCUDA,
    DispatchKey::AutocastXPU,
    DispatchKey::AutocastIPU,
    DispatchKey::AutocastHPU,
    DispatchKey::AutocastXLA,
    DispatchKey::AutocastPrivateUse1,
});

// See Note [TLS Initialization]
constexpr DispatchKeySet default_included_set = DispatchKeySet({
    DispatchKey::BackendSelect,
    DispatchKey::ADInplaceOrView,
});

constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
    DispatchKey::AutocastCPU,
    DispatchKey::AutocastMPS,
    DispatchKey::AutocastCUDA,
    DispatchKey::AutocastXPU,
    DispatchKey::AutocastIPU,
    DispatchKey::AutocastHPU,
    DispatchKey::AutocastXLA,
    DispatchKey::AutocastPrivateUse1,
});

constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
    autograd_dispatch_keyset | DispatchKeySet(DispatchKey::ADInplaceOrView);

constexpr DispatchKeySet python_ks = DispatchKeySet({
    DispatchKey::Python,
    DispatchKey::PythonTLSSnapshot,
});

constexpr DispatchKeySet sparse_ks = DispatchKeySet(DispatchKey::Sparse);

constexpr DispatchKeySet sparse_csr_ks = DispatchKeySet(DispatchKey::SparseCsr);

constexpr DispatchKeySet mkldnn_ks = DispatchKeySet(DispatchKey::MkldnnCPU);

// backend dispatch keys that map to DispatchKey::AutogradOther
// NB: keys in this set also get associated with CompositeImplicitAutograd
constexpr DispatchKeySet autogradother_backends =
    DispatchKeySet(
        // HIP and VE aren't in this list: they now have their own backend
        // bits, which means that they can now have their own Autograd keys.
        // Technically, HIP will now redispatch to its own custom AutogradHIP
        // slot in the runtime table.
        {DispatchKey::FPGA,
         DispatchKey::MAIA,
         DispatchKey::Vulkan,
         DispatchKey::Metal,
         DispatchKey::CustomRNGKeyId,
         DispatchKey::MkldnnCPU,
         // Sparse and Quantized backends also live here.
         DispatchKey::Sparse,
         DispatchKey::SparseCsr,
         DispatchKey::Quantized})
    // Including the backend bits because this keyset is used during op
    // registration, which requires looping over all runtime autogradother
    // backend keys.
    | DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
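// For example (illustrative): because autogradother_backends includes the
// full backend mask, it contains every runtime instance of the per-backend
// functionality bits listed above, e.g.
//   autogradother_backends.has(DispatchKey::SparseCPU);     // true
//   autogradother_backends.has(DispatchKey::QuantizedCUDA); // true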
// The set of dispatch keys that come after autograd
// n.b. this relies on the fact that AutogradOther is currently the lowest
// Autograd key
constexpr DispatchKeySet after_autograd_keyset =
    DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradOther);

// The set of dispatch keys that come after ADInplaceOrView
constexpr DispatchKeySet after_ADInplaceOrView_keyset = DispatchKeySet(
    DispatchKeySet::FULL_AFTER,
    c10::DispatchKey::ADInplaceOrView);

// The set of dispatch keys that come after Functionalize
constexpr DispatchKeySet after_func_keyset =
    DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::Functionalize)
        .remove(
            // NOTE: we also need to remove ADInplaceOrView from the keyset
            // when redispatching after the func kernels. This is because
            // we're not calling the same op; we originally called an inplace
            // op, and now we aren't. The original key calculation figured out
            // which keys were Fallthrough based on the inplace op. That means
            // that it did not include the ADInplaceOrView kernel as a
            // fallthrough key. However, we WANT the ADInplaceOrView kernel to
            // be ignored now that we're calling an out-of-place op.
            // Re-invoking Dispatcher::call would re-run the Fallthrough key
            // calculation and get us that, but at::redispatch is more
            // performant. We can get away with it by explicitly removing the
            // key here.
            c10::DispatchKey::ADInplaceOrView);

constexpr DispatchKeySet backend_bitset_mask =
    DispatchKeySet(DispatchKeySet::RAW, (1ULL << num_backends) - 1);

constexpr auto inplace_or_view_ks =
    DispatchKeySet(DispatchKey::ADInplaceOrView);
constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU);
constexpr auto autograd_ipu_ks = DispatchKeySet(DispatchKey::AutogradIPU);
constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU);
constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA);
constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA);
constexpr auto autograd_lazy_ks = DispatchKeySet(DispatchKey::AutogradLazy);
constexpr auto autograd_meta_ks = DispatchKeySet(DispatchKey::AutogradMeta);
constexpr auto autograd_mps_ks = DispatchKeySet(DispatchKey::AutogradMPS);
constexpr auto autograd_hpu_ks = DispatchKeySet(DispatchKey::AutogradHPU);
constexpr auto autograd_privateuse1_ks =
    DispatchKeySet(DispatchKey::AutogradPrivateUse1);
constexpr auto autograd_privateuse2_ks =
    DispatchKeySet(DispatchKey::AutogradPrivateUse2);
constexpr auto autograd_privateuse3_ks =
    DispatchKeySet(DispatchKey::AutogradPrivateUse3);
constexpr auto autograd_other_ks = DispatchKeySet(DispatchKey::AutogradOther);
constexpr auto autograd_nested =
    DispatchKeySet(DispatchKey::AutogradNestedTensor);
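// For example (illustrative): the FULL_AFTER sets above keep all backend bits
// plus every functionality bit of strictly lower priority than the given key,
// so after_autograd_keyset still contains DispatchKey::Dense and all of the
// backend bits, but none of the autograd functionality bits.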
// keyset corresponding to functorch keys that have their own dedicated
// TensorImpl subclass.
constexpr auto functorch_transforms_ks = DispatchKeySet(
    {DispatchKey::FuncTorchBatched,
     DispatchKey::FuncTorchVmapMode,
     DispatchKey::Batched,
     DispatchKey::VmapMode,
     DispatchKey::FuncTorchGradWrapper});

constexpr auto functorch_batched_ks =
    DispatchKeySet({DispatchKey::FuncTorchBatched});

// This keyset has:
// (1) the functionality bits corresponding to backends (dense, sparse,
//     quantized)
// (2) all of the backend bits set
constexpr DispatchKeySet backend_functionality_keys =
    DispatchKeySet({
        DispatchKey::Dense,
        DispatchKey::Quantized,
        DispatchKey::Sparse,
        DispatchKey::SparseCsr,
    }) |
    DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);

struct OpTableOffsetAndMask {
  uint16_t offset;
  uint16_t backend_mask;
};

static_assert(
    num_backends <= 16,
    "Right now we expect the number of backends not to exceed 16. In the "
    "(unlikely) event that this changes, the size of "
    "OpTableOffsetAndMask::backend_mask needs to be increased too.");

// true if t is a backend dispatch key
C10_API bool isBackendDispatchKey(DispatchKey t);

// Resolve alias dispatch key to DispatchKeySet if applicable
C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t);

// Resolve alias dispatch key to DispatchKeySet if applicable,
// and check if k is a part of that set
C10_API bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k);

// Returns a DispatchKeySet of all backend keys mapped to Autograd dispatch
// key t; the DispatchKeySet is empty if t is not an alias of
// DispatchKey::Autograd.
C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t);

// Returns a DispatchKeySet of autograd related keys mapped to backend.
// For a given backend key, use the associated autograd key.
// For non-backend keys, use AutogradOther as a default.
// Note: it's convenient and fast to return a default here rather than (say)
// returning an std::optional, or throwing. But it makes callers
// responsible for either a) enforcing the invariant that only backend keys
// be passed as arguments, or b) interpreting our return value carefully.
inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
  switch (t) {
    case BackendComponent::CPUBit:
      return inplace_or_view_ks | autograd_cpu_ks;
    case BackendComponent::IPUBit:
      return inplace_or_view_ks | autograd_ipu_ks;
    case BackendComponent::XPUBit:
      return inplace_or_view_ks | autograd_xpu_ks;
    case BackendComponent::CUDABit:
      return inplace_or_view_ks | autograd_cuda_ks;
    case BackendComponent::XLABit:
      return inplace_or_view_ks | autograd_xla_ks;
    case BackendComponent::LazyBit:
      return inplace_or_view_ks | autograd_lazy_ks;
    case BackendComponent::MetaBit:
      return inplace_or_view_ks | autograd_meta_ks;
    case BackendComponent::MPSBit:
      return inplace_or_view_ks | autograd_mps_ks;
    case BackendComponent::HPUBit:
      return inplace_or_view_ks | autograd_hpu_ks;
    case BackendComponent::PrivateUse1Bit:
      return inplace_or_view_ks | autograd_privateuse1_ks;
    case BackendComponent::PrivateUse2Bit:
      return inplace_or_view_ks | autograd_privateuse2_ks;
    case BackendComponent::PrivateUse3Bit:
      return inplace_or_view_ks | autograd_privateuse3_ks;
    default:
      return inplace_or_view_ks | autograd_other_ks;
  }
}
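// For example (illustrative):
//   getAutogradRelatedKeySetFromBackend(BackendComponent::CUDABit) ==
//       DispatchKeySet(
//           {DispatchKey::ADInplaceOrView, DispatchKey::AutogradCUDA})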
// Returns a DispatchKeySet of autocast related keys mapped to backend.
inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
  constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU);
  constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU);
  constexpr auto autocast_ipu_ks = DispatchKeySet(DispatchKey::AutocastIPU);
  constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU);
  constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA);
  constexpr auto autocast_xla_ks = DispatchKeySet(DispatchKey::AutocastXLA);
  constexpr auto autocast_privateuse1_ks =
      DispatchKeySet(DispatchKey::AutocastPrivateUse1);
  constexpr auto autocast_mps_ks = DispatchKeySet(DispatchKey::AutocastMPS);
  switch (t) {
    case BackendComponent::CPUBit:
      return autocast_cpu_ks;
    case BackendComponent::XPUBit:
      return autocast_xpu_ks;
    case BackendComponent::IPUBit:
      return autocast_ipu_ks;
    case BackendComponent::HPUBit:
      return autocast_hpu_ks;
    case BackendComponent::CUDABit:
      return autocast_cuda_ks;
    case BackendComponent::XLABit:
      return autocast_xla_ks;
    case BackendComponent::PrivateUse1Bit:
      return autocast_privateuse1_ks;
    case BackendComponent::MPSBit:
      return autocast_mps_ks;
    default:
      return DispatchKeySet();
  }
}

// Returns the "backend" DispatchKey of highest priority in the set.
// This is basically like highestBackendKey(), except that we have some
// "functionality" bits that correspond to backends (Sparse, Quantized)
inline DispatchKey highestPriorityBackendTypeId(DispatchKeySet ks) {
  return (ks & backend_functionality_keys).highestPriorityTypeId();
}

// This API exists because we have a use case for checking
// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined)
// in OperatorEntry.cpp, but we disallow it in the has() API.
C10_API bool isIncludedInAlias(DispatchKey k, DispatchKey alias);

// Historically, every tensor only had a single DispatchKey, and it was
// always something like CPU, and there wasn't any of this business where TLS
// could cause the DispatchKey of a tensor to change. But we still have some
// legacy code that is still using DispatchKey for things like instanceof
// checks; if at all possible, refactor the code to stop using DispatchKey in
// those cases.
inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
  // NB: If you add any extra keys that can be stored in TensorImpl on
  // top of existing "backend" keys like CPU/CUDA, you need to add it
  // here. At the moment, autograd keys and the ADInplaceOrView key need this
  // treatment.
  return (s - autograd_dispatch_keyset_with_ADInplaceOrView -
          autocast_dispatch_keyset -
          DispatchKeySet(
              {DispatchKey::Functionalize,
               DispatchKey::PythonTLSSnapshot,
               DispatchKey::FuncTorchGradWrapper,
               DispatchKey::FuncTorchVmapMode,
               DispatchKey::FuncTorchBatched,
               DispatchKey::Python}))
      .highestPriorityTypeId();
}
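// For example (illustrative):
//   legacyExtractDispatchKey(
//       DispatchKeySet({DispatchKey::CPU, DispatchKey::AutogradCPU})) ==
//       DispatchKey::CPU
// (the autograd bits are stripped before taking the highest priority key)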
template <class T>
using is_not_DispatchKeySet = std::negation<std::is_same<DispatchKeySet, T>>;

// Given a function type, constructs a function_traits type that drops the
// first parameter type if the first parameter is of type DispatchKeySet. NB:
// DispatchKeySet is currently explicitly hidden from JIT (mainly to avoid
// pushing unnecessary arguments on the stack - see Note [Plumbing Keys
// Through the Dispatcher] for details). If at any point in the future we need
// to expose this type to JIT, revisit the usage of this type alias.
template <class FuncType>
using remove_DispatchKeySet_arg_from_func = guts::make_function_traits_t<
    typename guts::infer_function_traits_t<FuncType>::return_type,
    typename std::conditional_t<
        std::is_same_v<
            DispatchKeySet,
            typename guts::typelist::head_with_default_t<
                void,
                typename guts::infer_function_traits_t<
                    FuncType>::parameter_types>>,
        guts::typelist::drop_if_nonempty_t<
            typename guts::infer_function_traits_t<FuncType>::parameter_types,
            1>,
        typename guts::infer_function_traits_t<FuncType>::parameter_types>>;

} // namespace c10