#include <c10/core/DispatchKeySet.h>
#include <c10/util/irange.h>

namespace c10 {

// backend_dispatch_keyset includes all dispatch keys that map to backends.
// Alias key DispatchKey::CompositeExplicitAutograd maps to
// backend_dispatch_keyset
constexpr DispatchKeySet backend_dispatch_keyset =
    autogradother_backends | DispatchKeySet(DispatchKey::Dense);

// See Note [CompositeExplicitAutogradNonFunctional Key]
// We have several types of decompositions in aten, each with its own
// alias key. You should register your decomposition to the
// `CompositeExplicitAutogradNonFunctional` key if:
// (1) It's an out-of-place op
// (2) It decomposes into one or more mutation ops
// (3) It has a derivative formula
//     (In theory we could also have a separate key for
//     "CompositeImplicitAutogradNonFunctional", but there isn't much of a use
//     case for it currently.)
// This key is important for "functional" backends like LazyTensor / XLA.
// If you're a backend that only expects to deal with "functional ops",
// then you don't want to decompose a functional op into an op that causes
// aliasing. You should just directly write a kernel for that functional op
// instead!
constexpr DispatchKeySet non_functional_backend_dispatch_keyset =
    backend_dispatch_keyset
        // XLA and LazyTensor are currently the only 2 backends in core
        // that use the functionalization pass in eager mode.
        .remove(DispatchKey::Sparse)
        .remove_backend(BackendComponent::XLABit)
        .remove_backend(BackendComponent::LazyBit);

bool isBackendDispatchKey(DispatchKey t) {
  return t != DispatchKey::Undefined
      // See Note [No Alias Keys in DispatchKeySet]
      && !isAliasDispatchKey(t)
      // Note [NestedTensor Not Included in Backend Keys]
      // NestedTensor has been explicitly removed from the "backend keyset" due
      // to incompatibility with some kernels, so we don't want it to be
      // included in CompositeExplicitAutograd kernels.
      && t != DispatchKey::NestedTensor && backend_dispatch_keyset.has(t);
}

// math_dispatch_keyset contains all keys in backend_dispatch_keyset and
// autograd_dispatch_keyset. Alias key DispatchKey::CompositeImplicitAutograd
// maps to [math_dispatch_keyset x full_backend_mask]
constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
    autograd_dispatch_keyset |
    // See Note [NestedTensor Not Included in Backend Keys]
    // The caveat to that note is that nested_tensor is a special case
    // where we would like to support composite implicit kernels but not
    // explicit kernels, therefore we manually add the key to the
    // math_dispatch_keyset
    DispatchKeySet{DispatchKey::NestedTensor} |
    // Functionalize should always re-use CompositeImplicit decomps.
    DispatchKeySet{DispatchKey::Functionalize};

constexpr DispatchKeySet nested_dispatch_keyset =
    DispatchKeySet(
        {DispatchKey::AutogradNestedTensor, DispatchKey::NestedTensor}) |
    DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);

DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
  TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
  switch (t) {
    case DispatchKey::Autograd:
      // See Note [autograd_dispatch_keyset Does Not Include Backend Bits]
      // That's why we OR it with a mask of the backend bits here.
      // getRuntimeDispatchKeySet() expects to return a keyset of runtime
      // dispatch keys, like AutogradCPU, but that requires having backend
      // bits.
      return autograd_dispatch_keyset |
          DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
    case DispatchKey::CompositeImplicitAutograd:
      return math_dispatch_keyset;
    case DispatchKey::CompositeImplicitAutogradNestedTensor:
      return nested_dispatch_keyset;
    case DispatchKey::CompositeExplicitAutograd:
      return backend_dispatch_keyset;
    case DispatchKey::CompositeExplicitAutogradNonFunctional:
      return non_functional_backend_dispatch_keyset;
    default:
      return DispatchKeySet(t);
  }
}
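// An illustrative sketch of how the alias keys above expand (for exposition
// only; the exact membership is governed by the keyset definitions earlier in
// this file):
//
//   // CompositeExplicitAutograd fans out to concrete backend keys:
//   getRuntimeDispatchKeySet(DispatchKey::CompositeExplicitAutograd)
//       .has(DispatchKey::CPU);          // true
//   // Autograd needs the backend bits OR'd in, so that per-backend runtime
//   // keys such as AutogradCPU are included:
//   getRuntimeDispatchKeySet(DispatchKey::Autograd)
//       .has(DispatchKey::AutogradCPU);  // true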
bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) {
  TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
  switch (t) {
    case DispatchKey::Autograd:
      return autograd_dispatch_keyset.has(toFunctionalityKey(k));
    case DispatchKey::CompositeImplicitAutograd:
      // See Note [NestedTensor Not Included in Backend Keys]
      return math_dispatch_keyset.has(k);
    case DispatchKey::CompositeImplicitAutogradNestedTensor:
      // See Note [NestedTensor Not Included in Backend Keys]
      return nested_dispatch_keyset.has(k);
    case DispatchKey::CompositeExplicitAutograd:
      // See Note [NestedTensor Not Included in Backend Keys]
      return k != DispatchKey::NestedTensor && backend_dispatch_keyset.has(k);
    case DispatchKey::CompositeExplicitAutogradNonFunctional:
      // See Note [NestedTensor Not Included in Backend Keys]
      return k != DispatchKey::NestedTensor &&
          non_functional_backend_dispatch_keyset.has(k);
    case DispatchKey::FuncTorchBatchedDecomposition:
      return functorch_batched_ks.has(k);
    default:
      return t == k;
  }
}

// For a given autograd key, return the (guaranteed nonempty) set of associated
// backend keys. For a non-autograd key, return the empty keyset.
DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
  switch (t) {
    case DispatchKey::AutogradCPU:
      return DispatchKeySet(DispatchKey::CPU);
    case DispatchKey::AutogradCUDA:
      return DispatchKeySet(DispatchKey::CUDA);
    case DispatchKey::AutogradXLA:
      return DispatchKeySet(DispatchKey::XLA);
    case DispatchKey::AutogradLazy:
      return DispatchKeySet(DispatchKey::Lazy);
    case DispatchKey::AutogradMeta:
      return DispatchKeySet(DispatchKey::Meta);
    case DispatchKey::AutogradMPS:
      return DispatchKeySet(DispatchKey::MPS);
    case DispatchKey::AutogradHPU:
      return DispatchKeySet(DispatchKey::HPU);
    case DispatchKey::AutogradIPU:
      return DispatchKeySet(DispatchKey::IPU);
    case DispatchKey::AutogradXPU:
      return DispatchKeySet(DispatchKey::XPU);
    case DispatchKey::AutogradPrivateUse1:
      return DispatchKeySet(DispatchKey::PrivateUse1);
    case DispatchKey::AutogradPrivateUse2:
      return DispatchKeySet(DispatchKey::PrivateUse2);
    case DispatchKey::AutogradPrivateUse3:
      return DispatchKeySet(DispatchKey::PrivateUse3);
    case DispatchKey::AutogradNestedTensor:
      return DispatchKeySet(DispatchKey::NestedTensor) |
          DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
    case DispatchKey::AutogradOther:
      return autogradother_backends;
    default:
      return DispatchKeySet();
  }
}

bool isIncludedInAlias(DispatchKey k, DispatchKey alias) {
  return k != DispatchKey::Undefined && runtimeDispatchKeySetHas(alias, k);
}

std::string toString(DispatchKeySet ts) {
  std::stringstream ss;
  ss << ts;
  return ss.str();
}

std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
  if (ts.empty()) {
    os << "DispatchKeySet()";
    return os;
  }
  os << "DispatchKeySet(";
  bool first = true;
  for (auto k : ts) {
    if (!first) {
      os << ", ";
    }
    os << k;
    first = false;
  }
  os << ")";
  return os;
}
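// For illustration (a sketch of the printing/iteration behavior above, not
// exercised in this file):
//
//   toString(DispatchKeySet({DispatchKey::CPU, DispatchKey::CUDA}))
//   // -> "DispatchKeySet(CPU, CUDA)"
//
// Iteration walks the functionality bits from the lowest set bit upward, and
// for per-backend functionalities it walks every set backend bit before
// moving on to the next functionality; operator++ below implements that
// traversal.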
DispatchKeySet::iterator& DispatchKeySet::iterator::operator++() {
  TORCH_INTERNAL_ASSERT(next_functionality_ <= iterator::end_iter_mask_val);
  TORCH_INTERNAL_ASSERT(next_backend_ <= num_backends, next_backend_);

  // Create a masked version of the set representation to ignore previous
  // keys that we've iterated through.
  uint64_t masked_functionality_bits =
      llvm::maskTrailingZeros<uint64_t>(next_functionality_) & *data_ptr_;
  uint64_t masked_backend_bits =
      llvm::maskTrailingZeros<uint64_t>(next_backend_) & full_backend_mask &
      *data_ptr_;

  uint64_t first_functionality_idx =
      llvm::findFirstSet(masked_functionality_bits);
  uint64_t first_backendcomponent_idx = llvm::findFirstSet(masked_backend_bits);

  // If there are no keys, set to end iterator value
  if (first_functionality_idx == std::numeric_limits<uint64_t>::max() ||
      next_functionality_ == iterator::end_iter_mask_val) {
    // Set up state to be the same as end()
    next_functionality_ = iterator::end_iter_mask_val;
    current_dispatchkey_idx_ = iterator::end_iter_key_val;
    next_backend_ = 0;
    current_backendcomponent_idx_ = iterator::end_iter_key_val;
    return *this;
  }

  // The +1 is because of DispatchKey::Undefined and
  // BackendComponent::InvalidBit,
  auto new_next_functionality = first_functionality_idx + 1;
  auto new_backendcomponent_idx = first_backendcomponent_idx + 1;
  // and the -num_backends is because the first bits in the
  // keyset are not Dispatch Keys.
  auto next_dispatchkey_idx = new_next_functionality - num_backends;

  // If the current functionality bit is a per-backend bit, we need special
  // handling.
  if (isPerBackendFunctionalityKey(
          static_cast<DispatchKey>(next_dispatchkey_idx))) {
    // case 1: if the current backend is undefined, then there is no valid
    // backend instance of this functionality key, so we can skip it.
    if (first_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
      // increment the functionality mask so we skip the current functionality
      // bit on the next increment.
      next_functionality_ = new_next_functionality;
      ++(*this);
      return *this;
    }

    // Otherwise, at this point we know what the current backend and
    // functionality bits are.
    current_dispatchkey_idx_ = next_dispatchkey_idx;
    current_backendcomponent_idx_ = new_backendcomponent_idx;

    // Next, we need to set up the masks for the next increment.
    uint64_t next_backendcomponent_bits =
        llvm::maskTrailingZeros<uint64_t>(first_backendcomponent_idx + 1) &
        full_backend_mask & *data_ptr_;
    uint64_t next_backendcomponent_idx =
        llvm::findFirstSet(next_backendcomponent_bits);
    if (next_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
      // case 2: the current backend is valid, but there is not another backend
      // in the keyset. In this case, we need to bump the functionality mask
      // and reset the backend mask for the next increment.
      next_functionality_ = new_next_functionality;
      next_backend_ = 0;
    } else {
      // case 3: we have another backend to iterate over. We want to iterate
      // over the same functionality bit next time, but a different backend
      // bit.
      next_backend_ = first_backendcomponent_idx + 1;
    }
  } else {
    // Functionality bits that aren't per backend are simpler to handle. We can
    // ignore the backend bits.
    TORCH_INTERNAL_ASSERT(next_backend_ == 0);
    current_dispatchkey_idx_ = next_dispatchkey_idx;
    next_functionality_ = new_next_functionality;
  }
  return *this;
}

std::array<FunctionalityOffsetAndMask, num_functionality_keys>
initializeFunctionalityOffsetsAndMasks() {
  std::array<FunctionalityOffsetAndMask, num_functionality_keys>
      offsets_and_masks;
  // Manually set the first entry, which corresponds to Undefined.
  offsets_and_masks[0] = FunctionalityOffsetAndMask(0, 0);
  // Loop through every functionality key (aside from Undefined).
  for (const auto functionality_idx : c10::irange(1, num_functionality_keys)) {
    // functionality_idx should be Dense -> 1, ...
    auto prev_offset_and_mask = offsets_and_masks[functionality_idx - 1];
    auto k = static_cast<DispatchKey>(functionality_idx);

    // If the previous functionality was not per-backend, then we can just
    // increment the previous offset. Otherwise, the next offset =
    // previous_offset + num_backends.
    auto next_offset = prev_offset_and_mask.offset +
        (prev_offset_and_mask.mask == 0 ? 1 : num_backends);
    // The mask is used in the runtime index calculation to find the offset of
    // the backend. For non-per-backend functionalities, this offset should
    // always be 0. Otherwise, we need to get the index of the backend (which
    // we can do using a backend mask).
    auto next_mask = isPerBackendFunctionalityKey(k) ? full_backend_mask : 0;
    offsets_and_masks[functionality_idx] =
        FunctionalityOffsetAndMask(next_offset, next_mask);
  }
  // Sanity check that the computed offset index of the last functionality key
  // is correct. This assumes that the highest-priority functionality key is
  // not per-backend.
  TORCH_INTERNAL_ASSERT(
      offsets_and_masks[num_functionality_keys - 1].offset ==
          (num_runtime_entries - 1),
      "num_runtime_entries: ",
      num_runtime_entries,
      "last_offset: ",
      offsets_and_masks[num_functionality_keys - 1].offset);

  return offsets_and_masks;
}
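// A small worked example of the offset/mask table computed above (illustrative
// only; the actual values depend on the key ordering in DispatchKey.h):
//
//   Undefined            -> offset 0,                mask 0
//   Dense (per-backend)  -> offset 1,                mask full_backend_mask
//   <next functionality> -> offset 1 + num_backends, ...
//
// i.e. each per-backend functionality reserves num_backends contiguous slots
// in the runtime operator table, while every other functionality reserves
// exactly one.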
} // namespace c10