#include <torch/csrc/autograd/input_buffer.h>

#include <ATen/CachedTensorUtils.h>
#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/grad_mode.h>

#include <c10/core/DeviceGuard.h>
#include <c10/core/Event.h>
#include <c10/core/StreamGuard.h>
#include <c10/core/impl/VirtualGuardImpl.h>

#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

namespace torch::autograd {

namespace {
// look what you made me do >.<
// Divergent paths for per-Impl stream recording that leak implementation
// details of the impls should not be needed here.
// See https://github.com/pytorch/pytorch/issues/60306
// TODO: clean this up when https://github.com/pytorch/pytorch/issues/60306 is
// improved
void record_stream_any_impl(Variable& var, c10::Stream& stream) {
  const auto guard = c10::impl::VirtualGuardImpl(device_of(var).value().type());

  if (C10_UNLIKELY(at::isBatchedTensor(var))) {
    auto* impl = at::maybeGetBatchedImpl(var);
    if (impl) {
      guard.recordDataPtrOnStream(impl->value().storage().data_ptr(), stream);
    } else {
      TORCH_INTERNAL_ASSERT(false, "Expected batched tensor");
    }
  } else {
    switch (var.layout()) {
      case c10::kSparseCsr:
      case c10::kSparseCsc:
      case c10::kSparseBsr:
      case c10::kSparseBsc: {
        auto* impl = at::sparse_csr::get_sparse_csr_impl(var);
        guard.recordDataPtrOnStream(
            impl->values().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->compressed_indices().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->plain_indices().storage().data_ptr(), stream);
        break;
      }
      case c10::kSparse: {
        auto* impl = at::sparse::get_sparse_impl(var);
        guard.recordDataPtrOnStream(
            impl->values().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->indices().storage().data_ptr(), stream);
        break;
      }
      case c10::kStrided:
        guard.recordDataPtrOnStream(var.storage().data_ptr(), stream);
        break;
      default:
        TORCH_INTERNAL_ASSERT(
            false, "Unknown layout in record_stream_any_impl");
    }
  }
}

bool can_accumulate_inplace(const Variable& v) {
  return (
      // `v` is a "vanilla" Tensor
      !(at::isTensorSubclassLike(v) || v._is_zerotensor() || v.is_nested()) &&

      // with a favorable memory layout
      v.is_non_overlapping_and_dense() &&

      // and we hold the last reference
      at::caching::adjusted_use_count(v) == 1 && v.has_storage() &&
      v.storage().use_count() == 1);
}
} // anonymous namespace
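// `can_accumulate_inplace` guards the in-place fast path in `accumulate`
// below: mutating a tensor in place is only safe when no other Tensor or
// Storage handle could observe the write. A minimal sketch of the same idea
// with plain ATen calls (standalone illustration only; `lhs` and `rhs` are
// hypothetical names, and the real check uses at::caching::adjusted_use_count
// to account for tensors cached by CUDA graphs):
//
//   at::Tensor lhs = at::ones({2, 2});
//   at::Tensor rhs = at::ones({2, 2});
//   if (lhs.use_count() == 1 && lhs.storage().use_count() == 1) {
//     lhs.add_(rhs); // sole owner: reuse lhs's memory for the sum
//   } else {
//     lhs = lhs + rhs; // aliased somewhere: fall back to out-of-place
//   }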
static void accumulate(
    std::vector<Variable>& buffer,
    const size_t pos,
    Variable&& var) {
  TORCH_INTERNAL_ASSERT(pos < buffer.size());
  auto& old_var = buffer[pos];
  // If we hold the last reference to `old_var` AND its storage we will try to
  // repurpose it to store the output. (Or, if `old_var` is sparse then `var`
  // becomes the candidate output Tensor.) We only do this if:
  //  1) GradMode is disabled since Autograd has special handling for inplace
  //     mutation which we don't want to trigger.
  //
  //  2) We hold the last reference.
  //     (Both `.use_count` and `.storage().use_count()` are one)
  //
  //  3) The candidate tensor is a contiguous, non-overlapping, dense, and
  //     otherwise stock standard Tensor.
  //
  //  4) The candidate is mutable. Currently only ZeroTensors are immutable.
  //
  //  5) The other Tensor is not a Tensor subclass (except sparse), since
  //     it's hard to predict the semantics of arbitrary subclass behavior.
  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (at::GradMode::is_enabled()) {
    buffer[pos] = old_var + var;
  } else if (
      // ATen doesn't route sparse additions correctly...
      old_var.is_sparse() || old_var.is_sparse_csr()) {
    if (can_accumulate_inplace(var)) {
      buffer[pos] = var.add_(old_var);
    } else {
      buffer[pos] = var + old_var;
    }
  } else if (
      can_accumulate_inplace(old_var) && !at::isTensorSubclassLike(var)) {
    buffer[pos] = old_var.add_(var);
  } else {
    buffer[pos] = old_var + var;
  }
}
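// `InputBuffer::add` below never blocks the host to order cross-stream work:
// it records a c10::Event on the stream that produced the gradient and makes
// the accumulation stream wait on that event. A minimal sketch of the pattern
// with c10 primitives (standalone illustration only; `producer` and
// `accumulator` are hypothetical c10::Stream objects on the same CUDA device):
//
//   c10::Event ready{c10::DeviceType::CUDA};
//   ready.record(producer);    // mark the point where the gradient is ready
//   accumulator.wait(ready);   // later work on `accumulator` orders after it
//
// Both calls return immediately on the host; the dependency is enforced on
// the device. This is also why `record_stream_any_impl` must tell the caching
// allocator that `var`'s storage is now in use on the accumulation stream.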
void InputBuffer::add(
    size_t pos,
    Variable&& var,
    const std::optional<c10::Stream>& opt_producer_stream,
    const std::optional<c10::Stream>& opt_consumer_stream) {
  TORCH_INTERNAL_ASSERT(pos < buffer.size());
  if (!var.defined()) {
    return;
  }

  // Switches to accumulate device
  // The device (and stream) chosen for accumulation is:
  //  (1) var is not a CUDA/privateuse1 variable. Accumulation happens on var's
  //      device.
  //  (2) var is a CUDA/privateuse1 variable and it, the consumer, and the
  //      producer share the same device:
  //      (2a) Uses the consumer's stream as the accumulation stream
  //      (2b) Syncs the accumulation stream with the producer's stream (if
  //           different)
  //      (2c) Accumulates.
  //  (3) var is a CUDA/privateuse1 variable and it shares a device with the
  //      consumer but not the producer:
  //      (3a) Uses the consumer's stream as the accumulation stream
  //      (3b) Syncs the accumulation stream with the consumer device's default
  //           stream
  //      (3c) Accumulates.
  //  (4) var is a CUDA/privateuse1 variable and it shares a device with the
  //      producer but not the consumer:
  //      (4a) Uses the producer device's default stream as the accumulation
  //           stream
  //      (4b) Syncs the accumulation stream with the producer's stream
  //      (4c) Accumulates.
  //  (5) var is a CUDA/privateuse1 variable and it does not share a device
  //      with the consumer or producer.
  //      Accumulation happens on the var device's default stream.

  TORCH_INTERNAL_ASSERT(device_of(var));
  std::optional<c10::Stream> opt_accumulate_stream = std::nullopt;
  const auto device_type = device_of(var).value().type();
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  if (device_of(var)->is_cuda() || device_of(var)->is_privateuseone()) {
    const auto on_producer =
        opt_producer_stream && device_of(var) == opt_producer_stream->device();
    const auto on_consumer =
        opt_consumer_stream && device_of(var) == opt_consumer_stream->device();

    if (on_producer && on_consumer) {
      // (2a)
      opt_accumulate_stream = opt_consumer_stream;
      if (opt_accumulate_stream != opt_producer_stream) {
        // (2b)
        auto event = c10::Event{device_type};
        event.record(*opt_producer_stream);
        opt_accumulate_stream->wait(event);
        record_stream_any_impl(var, *opt_accumulate_stream);
      }
    } else {
      std::optional<c10::Stream> opt_sync_stream = std::nullopt;
      const auto guard = c10::impl::VirtualGuardImpl{device_type};
      if (on_consumer && !on_producer) {
        // (3a)
        opt_accumulate_stream = opt_consumer_stream;
        opt_sync_stream =
            guard.getDefaultStream(opt_consumer_stream->device());
      } else if (on_producer && !on_consumer) {
        // (4a)
        opt_accumulate_stream =
            guard.getDefaultStream(opt_producer_stream->device());
        opt_sync_stream = opt_producer_stream;
      } else {
        // (5)
        // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
        opt_accumulate_stream = guard.getDefaultStream(*device_of(var));
      }
      if (opt_sync_stream && (opt_accumulate_stream != opt_sync_stream)) {
        // (3b), (4b)
        c10::OptionalDeviceGuard device_guard{opt_sync_stream->device()};
        auto event = c10::Event{device_type};
        event.record(*opt_sync_stream);
        opt_accumulate_stream->wait(event);
        const auto guard = c10::impl::VirtualGuardImpl(device_type);
        record_stream_any_impl(var, *opt_accumulate_stream);
      }
    }
  }

  auto& old_var = buffer[pos];
  if (!old_var.defined()) {
    buffer[pos] = std::move(var);
  } else {
    if (opt_accumulate_stream) {
      c10::OptionalStreamGuard stream_guard{opt_accumulate_stream};
      accumulate(buffer, pos, std::move(var));
    } else {
      // (1) non-CUDA/privateuse1 variable
      //     Accumulation happens on variable's device
      c10::OptionalDeviceGuard device_guard{device_of(var)};
      accumulate(buffer, pos, std::move(var));
    }
  }
}

auto InputBuffer::device() const -> at::Device {
  // Since we pick the first non-CPU tensor, this won't work with
  // mixed device-type operations (e.g., an op that is both CUDA
  // and XLA). This is *incredibly* unlikely, so we don't worry
  // about it.
  for (auto& var : buffer) {
    if (var.defined()) {
      auto device = var.device();
      if (device.type() != at::kCPU) {
        return device;
      }
    }
  }
  // Only report to the CPU thread if there really were no tensors
  // from other devices.
  return at::kCPU;
}

auto InputBuffer::variables(InputBuffer&& g) -> std::vector<Variable> {
  std::vector<Variable> result = std::move(g.buffer);
  return result;
}

} // namespace torch::autograd
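// Usage sketch: roughly how the autograd engine drives an InputBuffer when
// two gradients flow into the same input slot of a node (hypothetical and
// simplified; the real call sites live in torch/csrc/autograd/engine.cpp,
// and `grad_a`/`grad_b` are placeholder Variables):
//
//   using torch::autograd::InputBuffer;
//   InputBuffer buffer(/*size=*/1);  // one slot per node input
//   buffer.add(0, std::move(grad_a), std::nullopt, std::nullopt);
//   buffer.add(0, std::move(grad_b), std::nullopt, std::nullopt); // sums
//   auto grads = InputBuffer::variables(std::move(buffer)); // drain buffer
//
// With no producer/consumer streams supplied (the std::nullopt arguments),
// accumulation simply runs on the gradients' own device, per case (1) above.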