#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include namespace torch::jit { struct Node; struct Value; struct Graph; struct Module; namespace tracer { using ::c10::ivalue::Shared; using ::c10::IValue; using ::c10::ivalue::Future; using ::c10::ArrayRef; using ::c10::TupleType; using ::c10::TupleTypePtr; using ::c10::ivalue::ConstantString; using torch::autograd::Variable; using variable_list = std::vector; TORCH_API std::atomic& getTracerStateWarnMode(); struct TORCH_API TracingState : public std::enable_shared_from_this { TracingState(); ~TracingState(); // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::shared_ptr graph; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) bool warn = getTracerStateWarnMode(); // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) bool strict = true; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) bool force_outplace = false; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::function lookup_var_name_fn = [](const Variable& var) { return ""; }; void enterFrame() { env_stack.emplace_back(); } void leaveFrame() { env_stack.pop_back(); } void setValue(const IValue& v, Value* value); void delValue(const IValue& var); Value* getValue(const IValue& var); Value* getOutput(const IValue& var, size_t i); bool hasValue(const IValue& var) const; Node* createNode(c10::Symbol op_name, size_t num_outputs); void insertNode(Node* node); private: using WeakIValue = at::WeakIValue; struct WeakIValueHasher { size_t operator()(const WeakIValue& t) const { return t.hash(); } }; struct WeakIValueEq { bool operator()(const WeakIValue& t1, const WeakIValue& t2) const { return t1.isSameIdentity(t2); } }; using Frame = std::unordered_map; std::vector env_stack; }; // This is meant to be used as a thread local place, where we can store extra // info that gets lost when we call into ATen from Python bindings. One example // for when this happens is when we get an IntArrayRef argument with e.g. sizes // for view. When tracing, those might be tensors, which let us encode extra // data dependencies, but once they get to the ATen call where we actually have // the tracing logic, they get converted into a raw IntArrayRef, and we loose // all information. To prevent this, we temporarily stash it in here. struct ArgumentStash { struct IntArrayRefTrace : std::vector { IntArrayRefTrace(size_t size) : std::vector(size, nullptr) {} }; static bool empty() { return stash.intlists.empty(); } TORCH_API static void stashIntArrayRefElem( const std::string& arg_name, size_t size, size_t idx, const Variable& var); static bool hasIntArrayRef(const std::string& arg_name) { return stash.intlists.count(arg_name) > 0; } static IntArrayRefTrace popIntArrayRef(const std::string& arg_name) { auto info = std::move(stash.intlists.at(arg_name)); stash.intlists.erase(arg_name); return info; } // Value stashing: Use these methods to stash arguments which correspond // to regular Value*'s in the graph. i.e. they don't require special // handling like in the case of IntArrayRefs TORCH_API static void stashValue( const std::string& arg_name, size_t idx, const Variable& var, const c10::TypePtr& type = nullptr); static bool hasValue(const std::string& arg_name) { return stash.values.count(arg_name) > 0; } static Value* popValue(const std::string& arg_name) { auto info = stash.values.at(arg_name); stash.values.erase(arg_name); return info; } private: static thread_local ArgumentStash stash; std::unordered_map intlists; std::unordered_map values; }; // Retrieve or set the current tracing state. Returns a nullptr if tracing is // disabled. TORCH_API const std::shared_ptr& getTracingState(); TORCH_API void setTracingState(std::shared_ptr state); inline bool isTracing() { return static_cast(getTracingState()); } using warn_fn_type = void (*)(const std::string& msg); TORCH_API extern const char* WARN_PYTHON_DATAFLOW; TORCH_API extern const char* WARN_CONSTRUCTOR; TORCH_API extern const char* WARN_RESIZE; TORCH_API extern const char* STRICT_TRACER_MSG; TORCH_API void _do_warn(const char* _reason, const char* _kind); inline void warn(const char* _reason, const char* _kind = nullptr) { if (const auto& state = getTracingState()) { if (!state->warn) return; _do_warn(_reason, _kind); } } TORCH_API void setWarn(warn_fn_type fn); struct TORCH_API NoWarn { NoWarn() : state(getTracingState()) { if (state) { prev = state->warn; state->warn = false; } } ~NoWarn() { if (state) { state->warn = prev; } } std::shared_ptr state; bool prev{false}; }; struct WithNestedTracingFrame { WithNestedTracingFrame() { getTracingState()->enterFrame(); } ~WithNestedTracingFrame() { getTracingState()->leaveFrame(); } }; TORCH_API void recordSourceLocation(Node* n); TORCH_API void setRecordSourceLocation(void (*v)(Node*)); TORCH_API std::vector pythonCallstack(); TORCH_API void setPythonCallstack(std::vector (*v)()); // Having finished adding a new 'node' to the graph IR 'setValueTrace' // associates this node with an output variable, so that further operations // involving this variable know which node in the IR to reference. TORCH_API void setValueTrace(const IValue& v, Value* value); TORCH_API void delValueTrace(const IValue& var); TORCH_API std::function pauseTracing(); TORCH_API Value* getValueTrace(const IValue& var); TORCH_API std::pair, Stack> trace( Stack inputs, const std::function& traced_fn, std::function var_name_lookup_fn, bool strict = true, bool force_outplace = false, Module* self = nullptr, const std::vector& argument_names = {}); TORCH_API void abandon(); // NB: those serve both as an intermediate steps in addInputs below, // as well as the overloads that terminate template recursion TORCH_API void addInputs(Node* n, const char* name, int64_t value); TORCH_API void addInputs(Node* n, const char* name, const c10::SymInt& value); TORCH_API void addInputs( Node* n, const char* name, std::optional value); TORCH_API void addInputs(Node* n, const char* name, bool value); TORCH_API void addInputs( Node* n, const char* name, const std::optional& value); TORCH_API void addInputs(Node* n, const char* name, double value); TORCH_API void addInputs( Node* n, const char* name, const std::optional& value); TORCH_API void addInputs(Node* n, const char* name, const at::Scalar& value); TORCH_API void addInputs( Node* n, const char* name, const std::optional& value); TORCH_API void addInputs(Node* n, const char* name, const at::Tensor& value); TORCH_API void addInputs( Node* n, const char* name, const std::optional& value); TORCH_API void addInputs(Node* n, const char* name, ArrayRef value); TORCH_API void addInputs(Node* n, const char* name, c10::SymIntArrayRef value); TORCH_API void addInputs( Node* n, const char* name, std::optional value); TORCH_API void addInputs( Node* n, const char* name, const std::optional>& value); TORCH_API void addInputs( Node* n, const char* name, const at::OptionalIntArrayRef& opt_value); TORCH_API void addInputs( Node* n, const char* name, const at::OptionalSymIntArrayRef& opt_value); TORCH_API void addInputs( Node* n, const char* name, ArrayRef value, bool allow_undefined = false); TORCH_API void addInputs( Node* n, const char* name, const std::vector& value, bool allow_undefined = false); TORCH_API void addInputs( Node* n, const char* name, at::ITensorListRef value, bool allow_undefined = false); TORCH_API void addInputs( Node* n, const char* name, const List>& value); TORCH_API void addInputs( Node* n, const char* name, ArrayRef> value, const c10::ClassTypePtr& class_type); TORCH_API void addInputs(Node* n, const char* name, ArrayRef value); TORCH_API void addInputs( Node* n, const char* name, const std::optional>& value); TORCH_API void addInputs( Node* n, const char* name, const c10::string_view value); TORCH_API void addInputs( Node* n, const char* name, const std::optional& value); TORCH_API void addInputs(Node* n, const char* name, at::Device value); TORCH_API void addInputs(Node* n, const char* name, c10::Stream stream); TORCH_API void addInputs(Node* n, const char* name, at::Layout value); TORCH_API void addInputs(Node* n, const char* name, at::ScalarType value); TORCH_API void addInputs( Node* n, const char* name, const std::optional& value); TORCH_API void addInputs( Node* n, const char* name, const std::optional& value); TORCH_API void addInputs( Node* n, const char* name, const std::optional& value); TORCH_API void addInputs(Node* n, const char* name, at::MemoryFormat value); TORCH_API void addInputs( Node* n, const char* name, std::optional value); TORCH_API void addInputs( Node* n, const char* name, const std::optional& value); TORCH_API void addInputs( Node* n, const char* name, const std::optional& value); inline void addInputs( Node* n, const char* name, const std::vector& value) { AT_ERROR("Tracing a list of bool type is currently not supported!"); } template void addInputs(Node* n, const char* name, ArrayRef value) { AT_ERROR("Tracing a list of arbitrary type is currently not supported!"); } template void addInputs( Node* n, const char* name, const std::unordered_map& value) { AT_ERROR("Tracing a dict of arbitrary types is currently not supported!"); } template void addInputs(Node* n, const char* name, std::array value) { throw std::runtime_error( "Found an unsupported argument type in the JIT tracer. File a bug report."); } TORCH_API void addInputs( Node* n, const char* name, const c10::intrusive_ptr& obj); TORCH_API void ensureUniqueIfOutOfPlaced( const char* name, const at::Tensor& tensor); TORCH_API void ensureUniqueIfOutOfPlaced( const char* name, const std::optional& tensor); template < typename T, typename = std::enable_if_t< (!std::is_convertible_v, at::TensorList> && !std::is_convertible_v, c10::List> && !std::is_convertible_v, at::Tensor> && !std::is_convertible_v< std::decay_t, c10::intrusive_ptr>)>> void addOutput(Node* node, T&&) { AT_ERROR( "Found an unsupported argument type ", c10::demangle_type(), " in the JIT tracer. File a bug report."); } TORCH_API void addOutput(Node* node, const at::Tensor& tensor); TORCH_API void setOutput(Value* value, const at::Tensor& output); TORCH_API void addOutput(Node* node, const std::vector& list); TORCH_API void addOutput(Node* node, const c10::List& list); TORCH_API void addOutput( Node* node, const c10::intrusive_ptr& output); TORCH_API autograd::Variable getSizeOf( const autograd::Variable& var, int64_t dim); TORCH_API autograd::Variable getNumelOf(const autograd::Variable& var); } // namespace tracer } // namespace torch::jit