• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2import threading
3from typing import Any, Dict
4
5import torch._C._lazy
6
7
class DeviceContext:
    """Container for state associated with a single device key.

    Instances are meant to be obtained through ``get_device_context``, which
    caches them in the class-level ``_CONTEXTS`` mapping (keyed by the device
    string) and guards that mapping with ``_CONTEXTS_LOCK``.
    """

    # Class-wide cache shared by every caller: device key -> DeviceContext.
    _CONTEXTS: Dict[str, Any] = {}
    # Serializes the lookup-then-insert sequence performed on ``_CONTEXTS``.
    _CONTEXTS_LOCK = threading.Lock()

    def __init__(self, device) -> None:
        # Remember which device key this context was created for (stored as-is).
        self.device = device
14
15
def get_device_context(device=None):
    """Return the cached ``DeviceContext`` for *device*, creating it if needed.

    Args:
        device: The device to look up. When ``None``, the key is whatever
            ``torch._C._lazy._get_default_device_type()`` returns; otherwise
            the key is ``str(device)``.

    Returns:
        The ``DeviceContext`` stored in ``DeviceContext._CONTEXTS`` under that
        key. A new one is constructed and stored on first request for the key.
    """
    # Resolve the registry key outside the lock; only the cache access is guarded.
    key = (
        torch._C._lazy._get_default_device_type()
        if device is None
        else str(device)
    )

    with DeviceContext._CONTEXTS_LOCK:
        registry = DeviceContext._CONTEXTS
        ctx = registry.get(key, None)
        if ctx is None:
            # First request for this key (or a None placeholder): build and cache.
            ctx = DeviceContext(key)
            registry[key] = ctx
        return ctx
27