• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <ATen/FuncTorchTLS.h>
2 
3 namespace at::functorch {
4 
5 namespace {
6 
7 thread_local std::unique_ptr<FuncTorchTLSBase> kFuncTorchTLS = nullptr;
8 
9 }
10 
getCopyOfFuncTorchTLS()11 std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS() {
12   if (kFuncTorchTLS == nullptr) {
13     return nullptr;
14   }
15   return kFuncTorchTLS->deepcopy();
16 }
17 
setFuncTorchTLS(const std::shared_ptr<const FuncTorchTLSBase> & state)18 void setFuncTorchTLS(const std::shared_ptr<const FuncTorchTLSBase>& state) {
19   if (state == nullptr) {
20     kFuncTorchTLS = nullptr;
21     return;
22   }
23   kFuncTorchTLS = state->deepcopy();
24 }
25 
functorchTLSAccessor()26 std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor() {
27   return kFuncTorchTLS;
28 }
29 
30 
31 } // namespace at::functorch
32