#pragma once

#include <ATen/ATen.h>
#include <ATen/core/stack.h>
#include <c10/util/Exception.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/codegen/fuser/arg_spec.h>
#include <torch/csrc/jit/codegen/fuser/fused_kernel.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/interpreter.h>

#include <algorithm>
#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>
#include <type_traits>
#include <unordered_map>
#include <vector>

namespace torch::jit::fuser {

// Helper struct containing partition information: the number of tensors
// created and the dimension the partitioning is performed on.
// Note: created during upfront compilation; once the tensors are known
//   at runtime the partition info is logically combined with the tensor
//   descriptions to create PartitionDesc objects.
struct TORCH_API PartitionInfo {
  PartitionInfo(const int64_t _nSubTensors, const int64_t _dim)
      : nSubTensors_{_nSubTensors}, dim_{_dim} {}

  int64_t nSubTensors() const {
    return nSubTensors_;
  }
  int64_t dim() const {
    return dim_;
  }

 private:
  int64_t nSubTensors_;
  int64_t dim_;
};

// "Kernel Specification" - contains device-independent fusion information.
// Each kernel specification contains a map of instantiated generated functions
// that implement some or most of its functionality. Multiple generated
// functions are needed by each abstract specification because of different
// devices (cpu vs gpu, different gpus) and different inputs (int vs float,
// contiguous vs discontiguous).
// Note: uses a mutex to control access to its kernel store.
// Note: unordered containers do not invalidate references/pointers on
//   rehashing, which is critical for thread-safety.
// TODO: allow abstract kernels to use multiple generated kernels
// TODO: allow abstract kernels to reuse generated kernels from a common pool
struct TORCH_API KernelSpec {
  // Note: assumes the spec is a single block.
  // Note: This is the appropriate place to generalize if you want to add other
  //   passes to upfront compilation that walk the graph.
  KernelSpec(const int64_t _key, const std::shared_ptr<Graph>& _graph)
      : key_{_key},
        graph_{_graph},
        code_{_graph, ""},
        nInputs_{_graph->inputs().size()},
        inputBroadcastGroups_{},
        inputChunks_{},
        kernels_{} {
    // No need to iterate by reference since n is a pointer
    for (const auto n : graph_->nodes()) {
      static_assert(std::is_pointer_v<decltype(n)>, "n must be a pointer");
      if (n->kind() == aten::rand_like) {
        has_random_ = true;
        break;
      }
    }
    nTensorInputs_ = std::count_if(
        graph_->inputs().begin(), graph_->inputs().end(), [](const Value* v) {
          return v->type()->isSubtypeOf(*TensorType::get());
        });
  }

  // Getters
  int64_t key() const {
    return key_;
  }
  std::shared_ptr<Graph> graph() const {
    return graph_;
  }
  const Code& code() const {
    return code_;
  }
  int64_t nInputs() const {
    return nInputs_;
  }
  int64_t nTensorInputs() const {
    return nTensorInputs_;
  }

  std::vector<std::vector<int64_t>>& inputBroadcastGroups() {
    return inputBroadcastGroups_;
  }
  const std::vector<std::vector<int64_t>>& inputBroadcastGroups() const {
    return inputBroadcastGroups_;
  }

  std::vector<PartitionInfo>& inputChunks() {
    return inputChunks_;
  }
  const std::vector<PartitionInfo>& inputChunks() const {
    return inputChunks_;
  }

  bool hasRandom() const {
    return has_random_;
  }

  // Cache functions
  std::optional<std::shared_ptr<FusedKernel>> findKernel(
      const ArgSpec& arg_spec) const {
    std::lock_guard<std::mutex> guard{mutex_};
    const auto it = kernels_.find(arg_spec);
    if (it == kernels_.end())
      return std::nullopt;
    return it->second;
  }
  void cacheKernel(
      const ArgSpec& arg_spec,
      const std::shared_ptr<FusedKernel>& kernel) const {
    std::lock_guard<std::mutex> guard{mutex_};
    kernels_.emplace(arg_spec, kernel);
  }

 private:
  int64_t key_;
  std::shared_ptr<Graph> graph_;
  Code code_;
  uint64_t nInputs_;
  uint64_t nTensorInputs_{};
  std::vector<std::vector<int64_t>> inputBroadcastGroups_;
  std::vector<PartitionInfo> inputChunks_;
  bool has_random_{false};
  mutable std::mutex mutex_;
  mutable std::unordered_map<
      ArgSpec,
      std::shared_ptr<FusedKernel>,
      c10::hash<ArgSpec>>
      kernels_;
};

} // namespace torch::jit::fuser
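
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the fuser API): shows the
// lookup-then-cache pattern that findKernel()/cacheKernel() are designed for.
// The helper name getOrCompileKernel and the compileFn callback are
// hypothetical; the real call site is the fuser executor, which compiles
// through compileKernel() declared in compiler.h.

#include <functional> // for std::function (needed only by this sketch)

namespace torch::jit::fuser {

inline std::shared_ptr<FusedKernel> getOrCompileKernel(
    const KernelSpec& spec,
    const ArgSpec& arg_spec,
    const std::function<std::shared_ptr<FusedKernel>()>& compileFn) {
  // Fast path: another thread may have already compiled a kernel for this
  // (device, dtype, contiguity) signature.
  if (auto kernel = spec.findKernel(arg_spec)) {
    return *kernel;
  }
  // Slow path: compile outside the spec's mutex, then publish the result.
  // If two threads race here, unordered_map::emplace keeps the first entry;
  // the losing thread's kernel is equivalent, so returning it is still safe.
  auto kernel = compileFn();
  spec.cacheKernel(arg_spec, kernel);
  return kernel;
}

} // namespace torch::jit::fuser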