#pragma once #include #include namespace c10::impl { enum class TorchDispatchModeKey : int8_t { FAKE, PROXY, FUNCTIONAL, NUM_MODE_KEYS }; using PyObject_TorchDispatchMode = SafePyObjectT; struct C10_API TorchDispatchModeTLS { // This API is NOT invariant safe. // It must not take in an infra mode that uses TorchDispatchModeKey // If you're pushing an infra mode onto the stack, we expect // you to use set_mode static void push_non_infra_mode_onto_stack( std::shared_ptr mode); // Pops the top mode of the stack, // giving precedence to user modes before attempting to pop // any infra modes static const std::shared_ptr pop_stack(); // Returns the highest-priority infra mode on the stack, // along with its mode key. static const std:: tuple, TorchDispatchModeKey> pop_highest_infra_mode(); static const std::shared_ptr& get_stack_at( int64_t idx); static int64_t stack_len(); static const std::optional> get_mode(TorchDispatchModeKey mode_key); static const std::optional> unset_mode(TorchDispatchModeKey mode_key); static void set_mode( const std::shared_ptr& mode, TorchDispatchModeKey mode_key); static const TorchDispatchModeTLS& get_state(); static void set_state(TorchDispatchModeTLS state); static bool any_modes_set(bool skip_infra_modes = false); private: std::vector> stack_; // Users are allowed to push multiple ProxyTorchDispatchMode objects onto the // stack // However, we only allow a single FakeTensorMode onto the stack at a time // (Pushing additional FakeTensorModes onto the stack is a no-op) std::array< std::optional>, static_cast(TorchDispatchModeKey::NUM_MODE_KEYS)> infra_modes_; }; C10_API bool dispatch_mode_enabled(); C10_API std::string to_string(TorchDispatchModeKey mode_key); } // namespace c10::impl