Searched refs:wrapper_cls (Results 1 – 10 of 10) sorted by relevance
/external/pytorch/torch/distributed/fsdp/ |
D | wrap.py | 411 *, wrapper_cls: Any, **wrapper_kwargs: Any 438 "wrapper_cls": wrapper_cls, 471 assert _ConfigAutoWrap.wrapper_cls is not None 476 _ConfigAutoWrap.wrapper_cls, 482 def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module: 483 assert wrapper_cls is not None 490 return wrapper_cls(module, **overrides) 492 return wrapper_cls(module, **kwargs) 498 wrapper_cls: Callable, 522 assert wrapper_cls is not None, "Must specify wrapper_cls" [all …]
|
/external/tensorflow/tensorflow/python/training/experimental/ |
D | mixed_precision.py | 36 def register_loss_scale_wrapper(optimizer_cls, wrapper_fn, wrapper_cls=None): argument 57 wrapper_fn, wrapper_cls or wrapper_fn)
|
/external/tensorflow/tensorflow/tools/api/golden/v2/ |
D | tensorflow.__internal__.mixed_precision.pbtxt | 5 …argspec: "args=[\'optimizer_cls\', \'wrapper_fn\', \'wrapper_cls\'], varargs=None, keywords=None, …
|
/external/pytorch/test/distributed/fsdp/ |
D | test_distributed_checkpoint.py | 54 with enable_wrap(wrapper_cls=FSDP):
|
D | test_wrap.py | 394 with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group): 419 with enable_wrap(wrapper_cls=FSDP, process_group=pg): 429 with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group): 775 with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs): 804 with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
|
D | test_fsdp_meta.py | 260 wrapper_cls=FSDP, 274 with enable_wrap(wrapper_cls=FSDP):
|
D | test_fsdp_ignored_modules.py | 165 wrapper_cls = fully_shard if composable else FSDP 166 wrapped_model = wrapper_cls(model, **fsdp_kwargs)
|
D | test_fsdp_state_dict.py | 944 ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else nullcontext() 1123 with enable_wrap(wrapper_cls=FSDP):
|
/external/tensorflow/tensorflow/python/kernel_tests/nn_ops/ |
D | rnn_cell_test.py | 3057 wrapper_cls = rnn_cell_impl.DeviceWrapper 3059 wrapper = wrapper_cls(cell, "/cpu:0") 3067 reconstructed_wrapper = wrapper_cls.from_config(config_copy) 3069 self.assertIsInstance(reconstructed_wrapper, wrapper_cls) 3072 wrapper_cls = rnn_cell_impl.ResidualWrapper 3074 wrapper = wrapper_cls(cell) 3082 reconstructed_wrapper = wrapper_cls.from_config(config_copy) 3084 self.assertIsInstance(reconstructed_wrapper, wrapper_cls) 3086 wrapper = wrapper_cls(cell, residual_fn=lambda i, o: i + i + o) 3092 reconstructed_wrapper = wrapper_cls.from_config(config_copy) [all …]
|
/external/pytorch/torch/distributed/algorithms/_checkpoint/ |
D | checkpoint_wrapper.py | 319 wrapper_cls=checkpoint_wrapper_fn,
|