• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 
3 #include <ATen/record_function.h>
4 #include <c10/util/Synchronized.h>
5 #include <map>
6 #include <set>
7 #include <string>
8 
9 namespace torch::jit::mobile {
10 /* The KernelDTypeTracer class handles the attachment and removal of a recording
11  * callback that traces the invocation of code that handles specific dtypes in
12  * kernel function implementations that are tagged with specific tags.
13  *
14  * You can get the set of kernel tags and the dtypes using
15  * getCalledKernelTags().
16  *
17  * Note: This class is not thread safe or re-entrant, and should not be used
18  * across multiple threads of execution.
19  *
20  */
21 struct KernelDTypeTracer final {
22   at::CallbackHandle handle_;
23   /* The key of the map below (std::string) is the kernel tag name (constant
24    * character string) which shows up in code. The value part of type
25    * std::set<std::string> is the collection of dtypes for which we need to
26    * generate code for the said kernel tag.
27    */
28   typedef std::map<std::string, std::set<std::string>> kernel_tags_type;
29 
30   KernelDTypeTracer();
31   static c10::Synchronized<kernel_tags_type>& getCalledKernelTags();
32 
~KernelDTypeTracerfinal33   ~KernelDTypeTracer() {
34     at::removeCallback(handle_);
35   }
36 };
37 } // namespace torch::jit::mobile
38