Searched refs:model_cls (Results 1 – 2 of 2) sorted by relevance

/external/pytorch/benchmarks/dynamo/
huggingface.py
130 def get_sequence_length(model_cls, model_name): argument
172 model_cls, model, model_name, bs, device, include_loss_args=False argument
177 seq_length = get_sequence_length(model_cls, model_name)
207 or model_cls
233 if model_cls in [ElectraForPreTraining, LxmertForPreTraining]:
238 if model_cls in [AlbertForPreTraining]
369 model_cls = get_module_cls_by_model_name(model_name)
370 config_cls = model_cls.config_class
375 model_cls
381 or model_cls.__name__.startswith("Roberta")
[all …]
/external/pytorch/test/distributed/fsdp/
test_fsdp_sharded_grad_scaler.py
189 model_cls = NonUniformReqGradNWM
193 model_cls = NestedWrappedModule # type: ignore[assignment]
197 model_cls,