Home
last modified time | relevance | path

Searched refs:wrapper_cls (Results 1 – 10 of 10) sorted by relevance

/external/pytorch/torch/distributed/fsdp/
Dwrap.py411 *, wrapper_cls: Any, **wrapper_kwargs: Any
438 "wrapper_cls": wrapper_cls,
471 assert _ConfigAutoWrap.wrapper_cls is not None
476 _ConfigAutoWrap.wrapper_cls,
482 def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
483 assert wrapper_cls is not None
490 return wrapper_cls(module, **overrides)
492 return wrapper_cls(module, **kwargs)
498 wrapper_cls: Callable,
522 assert wrapper_cls is not None, "Must specify wrapper_cls"
[all …]
/external/tensorflow/tensorflow/python/training/experimental/
Dmixed_precision.py36 def register_loss_scale_wrapper(optimizer_cls, wrapper_fn, wrapper_cls=None): argument
57 wrapper_fn, wrapper_cls or wrapper_fn)
/external/tensorflow/tensorflow/tools/api/golden/v2/
Dtensorflow.__internal__.mixed_precision.pbtxt5 …argspec: "args=[\'optimizer_cls\', \'wrapper_fn\', \'wrapper_cls\'], varargs=None, keywords=None, …
/external/pytorch/test/distributed/fsdp/
Dtest_distributed_checkpoint.py54 with enable_wrap(wrapper_cls=FSDP):
Dtest_wrap.py394 with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
419 with enable_wrap(wrapper_cls=FSDP, process_group=pg):
429 with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
775 with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
804 with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
Dtest_fsdp_meta.py260 wrapper_cls=FSDP,
274 with enable_wrap(wrapper_cls=FSDP):
Dtest_fsdp_ignored_modules.py165 wrapper_cls = fully_shard if composable else FSDP
166 wrapped_model = wrapper_cls(model, **fsdp_kwargs)
Dtest_fsdp_state_dict.py944 ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else nullcontext()
1123 with enable_wrap(wrapper_cls=FSDP):
/external/tensorflow/tensorflow/python/kernel_tests/nn_ops/
Drnn_cell_test.py3057 wrapper_cls = rnn_cell_impl.DeviceWrapper
3059 wrapper = wrapper_cls(cell, "/cpu:0")
3067 reconstructed_wrapper = wrapper_cls.from_config(config_copy)
3069 self.assertIsInstance(reconstructed_wrapper, wrapper_cls)
3072 wrapper_cls = rnn_cell_impl.ResidualWrapper
3074 wrapper = wrapper_cls(cell)
3082 reconstructed_wrapper = wrapper_cls.from_config(config_copy)
3084 self.assertIsInstance(reconstructed_wrapper, wrapper_cls)
3086 wrapper = wrapper_cls(cell, residual_fn=lambda i, o: i + i + o)
3092 reconstructed_wrapper = wrapper_cls.from_config(config_copy)
[all …]
/external/pytorch/torch/distributed/algorithms/_checkpoint/
Dcheckpoint_wrapper.py319 wrapper_cls=checkpoint_wrapper_fn,