#pragma once

#include <cstddef>
#include <cstdint>
#include <list>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include <ATen/record_function.h>
#include <c10/macros/Macros.h>
#include <c10/util/Logging.h>
#include <c10/util/StringUtil.h>
#include <c10/util/hash.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/source_range.h>

// TODO: replace with pytorch/rfcs#43 when it is ready.
// Soft assertion: on failure, logs the failing condition and either throws
// (when softAssertRaises() is true) or warns once. The expression evaluates
// to true when the condition holds and false otherwise, so call sites can
// branch on the result.
#define SOFT_ASSERT(cond, ...)                         \
  [&]() -> bool {                                      \
    if (C10_UNLIKELY(!(cond))) {                       \
      torch::profiler::impl::logSoftAssert(            \
          __func__,                                    \
          __FILE__,                                    \
          static_cast<uint32_t>(__LINE__),             \
          #cond,                                       \
          ::c10::str(__VA_ARGS__));                    \
      if (torch::profiler::impl::softAssertRaises()) { \
        TORCH_INTERNAL_ASSERT(cond, __VA_ARGS__);      \
      } else {                                         \
        TORCH_WARN_ONCE(__VA_ARGS__);                  \
      }                                                \
      return false;                                    \
    }                                                  \
    return true;                                       \
  }()
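// A minimal usage sketch (illustrative only; validateSize and its argument are
// hypothetical, not part of this header). Because the macro evaluates to a
// bool, callers can bail out gracefully when the soft assert fails:
//
//   bool validateSize(int64_t size) {
//     if (!SOFT_ASSERT(size >= 0, "expected a non-negative size, got ", size)) {
//       return false; // failure was logged; warns or throws per softAssertRaises()
//     }
//     return true;
//   }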
namespace torch::profiler::impl {
TORCH_API bool softAssertRaises();
TORCH_API void setSoftAssertRaises(std::optional<bool> value);
TORCH_API void logSoftAssert(
    const char* func,
    const char* file,
    uint32_t line,
    const char* cond,
    const char* args);
TORCH_API inline void logSoftAssert(
    const char* func,
    const char* file,
    uint32_t line,
    const char* cond,
    ::c10::detail::CompileTimeEmptyString args) {
  logSoftAssert(func, file, line, cond, (const char*)args);
}
TORCH_API void logSoftAssert(
    const char* func,
    const char* file,
    uint32_t line,
    const char* cond,
    const std::string& args);

// A recorded input shape is either a single tensor's sizes or, for a tensor
// list, a vector of per-tensor sizes.
using shape =
    std::variant<std::vector<int64_t>, std::vector<std::vector<int64_t>>>;
constexpr int TENSOR_LIST_DISPLAY_LENGTH_LIMIT = 30;

std::string getNvtxStr(
    const char* name,
    int64_t sequence_nr,
    const std::vector<std::vector<int64_t>>& shapes,
    at::RecordFunctionHandle op_id = 0,
    const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids =
        {});

struct TORCH_API FileLineFunc {
  std::string filename;
  size_t line;
  std::string funcname;
};

TORCH_API std::vector<FileLineFunc> prepareCallstack(
    const std::vector<jit::StackEntry>& cs);
TORCH_API std::vector<std::string> callstackStr(
    const std::vector<FileLineFunc>& cs);
TORCH_API std::string stacksToStr(
    const std::vector<std::string>& stacks,
    const char* delim);
TORCH_API std::vector<std::vector<int64_t>> inputSizes(
    const at::RecordFunction& fn,
    const bool flatten_list_enabled = false);
TORCH_API std::string variantShapesToStr(const std::vector<shape>& shapes);
TORCH_API std::string shapesToStr(
    const std::vector<std::vector<int64_t>>& shapes);
TORCH_API std::string strListToStr(const std::vector<std::string>& types);
TORCH_API std::string inputOpIdsToStr(
    const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids);
TORCH_API std::string ivalueToStr(const c10::IValue& val, bool isString);
TORCH_API std::string ivalueListToStr(const std::vector<c10::IValue>& list);
TORCH_API std::vector<std::string> inputTypes(const at::RecordFunction& fn);

std::unordered_map<std::string, c10::IValue> TORCH_API
saveExtraArgs(const at::RecordFunction& fn);
std::unordered_map<std::string, std::string> TORCH_API
saveNcclMeta(const at::RecordFunction& fn, bool truncate = true);

uint64_t TORCH_API computeFlops(
    const std::string& op_name,
    const std::unordered_map<std::string, c10::IValue>& extra_args);

std::string shapeToStr(const std::vector<int64_t>& shape);

// Owns at most one process-wide instance of T for the profiler. push()
// installs the state (warning if one is already set), get() borrows it, and
// pop() transfers ownership back to the caller and clears the slot.
template <typename T>
class TORCH_API GlobalStateManager {
 public:
  static GlobalStateManager& singleton() {
    static GlobalStateManager singleton_;
    return singleton_;
  }

  static void push(std::shared_ptr<T>&& state) {
    if (singleton().state_) {
      LOG(WARNING) << "GlobalStatePtr already exists!";
    } else {
      singleton().state_ = std::move(state);
    }
  }

  static auto* get() {
    return singleton().state_.get();
  }

  static std::shared_ptr<T> pop() {
    auto out = singleton().state_;
    singleton().state_.reset();
    return out;
  }

 private:
  GlobalStateManager() = default;

  std::shared_ptr<T> state_;
};

// Hash functor that combines pair and tuple elements via c10::get_hash.
struct HashCombine {
  template <typename T0, typename T1>
  size_t operator()(const std::pair<T0, T1>& i) {
    return c10::get_hash((*this)(i.first), (*this)(i.second));
  }

  template <typename... Args>
  size_t operator()(const std::tuple<Args...>& i) {
    return c10::get_hash(i);
  }

  template <typename T>
  size_t operator()(const T& i) {
    return c10::get_hash(i);
  }
};

#ifdef USE_DISTRIBUTED
// Metadata keys attached to profiler events for collective communication ops
// (see saveNcclMeta above).
constexpr auto kCommsName = "Collective name";
constexpr auto kDtype = "dtype";
constexpr auto kInMsgNelems = "In msg nelems";
constexpr auto kOutMsgNelems = "Out msg nelems";
constexpr auto kInSplit = "In split size";
constexpr auto kOutSplit = "Out split size";
constexpr auto kGlobalRankStart = "Global rank start";
constexpr auto kGlobalRankStride = "Global rank stride";
constexpr auto kGroupSize = "Group size";
constexpr auto kProcessGroupName = "Process Group Name";
constexpr auto kProcessGroupDesc = "Process Group Description";
constexpr auto kGroupRanks = "Process Group Ranks";
constexpr auto kRank = "Rank";
constexpr auto kP2pSrc = "Src Rank";
constexpr auto kP2pDst = "Dst Rank";
#endif // USE_DISTRIBUTED
} // namespace torch::profiler::impl
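// A minimal usage sketch for GlobalStateManager (illustrative only;
// MyProfilerState is a hypothetical type, not part of this header). The
// manager holds at most one shared state instance per template argument:
//
//   using Mgr = torch::profiler::impl::GlobalStateManager<MyProfilerState>;
//   Mgr::push(std::make_shared<MyProfilerState>());       // install; warns if already set
//   MyProfilerState* borrowed = Mgr::get();               // observe without taking ownership
//   std::shared_ptr<MyProfilerState> owned = Mgr::pop();  // take ownership and clear the slot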