• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 
3 #include <torch/csrc/jit/api/module.h>
4 
5 #include <ATen/core/jit_type.h>
6 
7 #include <functional>
8 
9 namespace torch {
10 namespace jit {
11 
12 using DebugHandleType = int64_t;
13 
14 using NodeToDebugHandle = std::unordered_map<Node*, DebugHandleType>;
15 
16 using BackendDebugHandleGenerator =
17     std::function<NodeToDebugHandle(const std::shared_ptr<Graph>&)>;
18 
19 namespace detail {
20 
21 using BackendPreprocessFunction = std::function<c10::IValue(
22     const Module&,
23     const c10::Dict<IValue, IValue>&,
24     const BackendDebugHandleGenerator& generate_debug_handles)>;
25 
26 TORCH_API void registerBackendPreprocessFunction(
27     const std::string& name,
28     const BackendPreprocessFunction& preprocess);
29 
30 bool hasBackendPreprocessFunction(const std::string& name);
31 
32 BackendPreprocessFunction getBackendPreprocessFunction(const std::string& name);
33 
34 TORCH_API Module codegen_backend_module(
35     const std::string& backend_name,
36     const Module& orig_module,
37     const c10::Dict<IValue, IValue>& method_compile_spec,
38     const c10::DictTypePtr& any_dict_ty);
39 } // namespace detail
40 } // namespace jit
41 } // namespace torch
42