#pragma once #include #include #include namespace c10::impl { /** * A StreamGuard is an RAII class that changes the current device * to the device corresponding to some stream, and changes the * default stream on that device to be this stream. * * InlineStreamGuard is a helper class for implementing StreamGuards. * See InlineDeviceGuard for guidance on how to use this class. */ template class InlineStreamGuard : private InlineDeviceGuard { public: /// No default constructor, see Note [Omitted default constructor from RAII] explicit InlineStreamGuard() = delete; /// Set the current device to the device associated with the passed stream, /// and set the current stream on that device to the passed stream. explicit InlineStreamGuard(Stream stream) : InlineDeviceGuard(stream.device()), original_stream_of_original_device_( this->impl_.getStream(original_device())), original_stream_of_current_device_(this->impl_.exchangeStream(stream)), current_stream_(stream) {} /// This constructor exists purely for testing template < typename U = T, typename = typename std::enable_if_t>> explicit InlineStreamGuard( Stream stream, const DeviceGuardImplInterface* impl) : InlineDeviceGuard( stream.device(), impl ? impl : getDeviceGuardImpl(stream.device_type())), original_stream_of_original_device_( this->impl_.getStream(original_device())), original_stream_of_current_device_(this->impl_.exchangeStream(stream)), current_stream_(stream) {} /// Copy is disallowed InlineStreamGuard(const InlineStreamGuard&) = delete; InlineStreamGuard& operator=(const InlineStreamGuard&) = delete; /// Move is disallowed, as StreamGuard does not have an uninitialized state, /// which is required for moves on types with nontrivial destructors. InlineStreamGuard(InlineStreamGuard&& other) = delete; InlineStreamGuard& operator=(InlineStreamGuard&& other) = delete; ~InlineStreamGuard() { this->impl_.exchangeStream(original_stream_of_current_device_); } /// Resets the currently set stream to the original stream and /// the currently set device to the original device. Then, /// set the current device to the device associated with the passed stream, /// and set the current stream on that device to the passed stream. /// /// NOTE: this implementation may skip some stream/device setting if /// it can prove that it is unnecessary. /// /// WARNING: reset_stream does NOT preserve previously set streams on /// different devices. If you need to set streams on multiple devices /// use MultiStreamGuard instead. void reset_stream(Stream stream) { // TODO: make a version that takes an impl argument. Unfortunately, // that will require SFINAE because impl is only valid for the // VirtualGuardImpl specialization. if (stream.device() == this->current_device()) { this->impl_.exchangeStream(stream); current_stream_ = stream; } else { // Destruct and reconstruct the StreamGuard in-place this->impl_.exchangeStream(original_stream_of_current_device_); this->reset_device(stream.device()); original_stream_of_current_device_ = this->impl_.exchangeStream(stream); current_stream_ = stream; } } // It's not clear if set_device should also reset the current stream // if the device is unchanged; therefore, we don't provide it. // The situation is somewhat clearer with reset_device, but it's still // a pretty weird thing to do, so haven't added this either. /// Returns the stream of the original device prior to this guard. Subtly, /// the stream returned here is the original stream of the *original* /// device; i.e., it's the stream that your computation *would* have /// been put on, if it hadn't been for this meddling stream guard. /// This is usually what you want. Stream original_stream() const { return original_stream_of_original_device_; } /// Returns the most recent stream that was set using this device guard, /// either from construction, or via set_stream. Stream current_stream() const { return current_stream_; } /// Returns the most recent device that was set using this device guard, /// either from construction, or via set_device/reset_device/set_index. Device current_device() const { return InlineDeviceGuard::current_device(); } /// Returns the device that was set at the most recent reset_stream(), /// or otherwise the device at construction time. Device original_device() const { return InlineDeviceGuard::original_device(); } private: Stream original_stream_of_original_device_; // what the user probably cares about Stream original_stream_of_current_device_; // what we need to restore Stream current_stream_; }; /** * An OptionalStreamGuard is an RAII class that sets a device to some value on * initialization, and resets the device to its original value on destruction. * See InlineOptionalDeviceGuard for more guidance on how to use this class. */ template class InlineOptionalStreamGuard { public: /// Creates an uninitialized stream guard. explicit InlineOptionalStreamGuard() : guard_() // See Note [Explicit initialization of optional fields] {} /// Set the current device to the device associated with the passed stream, /// and set the current stream on that device to the passed stream, /// if the passed stream is not nullopt. explicit InlineOptionalStreamGuard(std::optional stream_opt) : guard_() { if (stream_opt.has_value()) { guard_.emplace(stream_opt.value()); } } /// All constructors of StreamGuard are valid for OptionalStreamGuard template explicit InlineOptionalStreamGuard(Args&&... args) : guard_(std::in_place, std::forward(args)...) {} // See Note [Move construction for RAII guards is tricky] InlineOptionalStreamGuard(InlineOptionalStreamGuard&& other) = delete; // See Note [Move assignment for RAII guards is tricky] InlineOptionalStreamGuard& operator=(InlineOptionalStreamGuard&& other) = delete; /// Resets the currently set stream to the original stream and /// the currently set device to the original device. Then, /// set the current device to the device associated with the passed stream, /// and set the current stream on that device to the passed stream. /// Initializes the OptionalStreamGuard if it was not previously initialized. void reset_stream(Stream stream) { if (guard_.has_value()) { guard_->reset_stream(stream); } else { guard_.emplace(stream); } } /// Returns the stream that was set at the time the guard was most recently /// initialized, or nullopt if the guard is uninitialized. std::optional original_stream() const { return guard_.has_value() ? std::make_optional(guard_->original_stream()) : std::nullopt; } /// Returns the most recent stream that was set using this stream guard, /// either from construction, or via reset_stream, if the guard is /// initialized, or nullopt if the guard is uninitialized. std::optional current_stream() const { return guard_.has_value() ? std::make_optional(guard_->current_stream()) : std::nullopt; } /// Restore the original device and stream, resetting this guard to /// uninitialized state. void reset() { guard_.reset(); } private: std::optional> guard_; }; template class InlineMultiStreamGuard { public: /// Calls `set_stream` on each of the streams in the list. /// This may be useful if you need to set different streams /// for different devices. explicit InlineMultiStreamGuard(ArrayRef streams) { if (!streams.empty()) { impl_.emplace(getDeviceTypeOfStreams(streams)); original_streams_.reserve(streams.size()); for (const Stream& s : streams) { original_streams_.emplace_back(this->impl_->exchangeStream(s)); } } } /// Copy is disallowed InlineMultiStreamGuard(const InlineMultiStreamGuard&) = delete; InlineMultiStreamGuard& operator=(const InlineMultiStreamGuard&) = delete; /// Move is disallowed, as StreamGuard does not have an uninitialized state, /// which is required for moves on types with nontrivial destructors. InlineMultiStreamGuard(InlineMultiStreamGuard&& other) = delete; InlineMultiStreamGuard& operator=(InlineMultiStreamGuard&& other) = delete; ~InlineMultiStreamGuard() noexcept { if (this->impl_.has_value()) { for (const Stream& s : original_streams_) { this->impl_->exchangeStream(s); } } } protected: std::optional impl_; private: /// The original streams that were active on all devices. std::vector original_streams_; static DeviceType getDeviceTypeOfStreams(ArrayRef streams) { TORCH_INTERNAL_ASSERT(!streams.empty()); DeviceType type = streams[0].device_type(); for (const auto idx : c10::irange(1, streams.size())) { TORCH_CHECK_VALUE( streams[idx].device_type() == type, "Streams have a mix of device types: stream 0 is on ", streams[0].device(), " while stream ", idx, " is on device ", streams[idx].device()); } return type; } }; } // namespace c10::impl