1 #pragma once 2 3 #include <ATen/record_function.h> 4 #include <c10/util/Synchronized.h> 5 #include <map> 6 #include <set> 7 #include <string> 8 9 namespace torch::jit::mobile { 10 /* The KernelDTypeTracer class handles the attachment and removal of a recording 11 * callback that traces the invocation of code that handles specific dtypes in 12 * kernel function implementations that are tagged with specific tags. 13 * 14 * You can get the set of kernel tags and the dtypes using 15 * getCalledKernelTags(). 16 * 17 * Note: This class is not thread safe or re-entrant, and should not be used 18 * across multiple threads of execution. 19 * 20 */ 21 struct KernelDTypeTracer final { 22 at::CallbackHandle handle_; 23 /* The key of the map below (std::string) is the kernel tag name (constant 24 * character string) which shows up in code. The value part of type 25 * std::set<std::string> is the collection of dtypes for which we need to 26 * generate code for the said kernel tag. 27 */ 28 typedef std::map<std::string, std::set<std::string>> kernel_tags_type; 29 30 KernelDTypeTracer(); 31 static c10::Synchronized<kernel_tags_type>& getCalledKernelTags(); 32 ~KernelDTypeTracerfinal33 ~KernelDTypeTracer() { 34 at::removeCallback(handle_); 35 } 36 }; 37 } // namespace torch::jit::mobile 38