Home
last modified time | relevance | path

Searched refs:valid_types (Results 1 – 25 of 86) sorted by relevance

1234

/third_party/mindspore/mindspore/train/
D_utils.py210 def check_value_type(arg_name, arg_value, valid_types): argument
212 valid_types = tuple(valid_types) if isinstance(valid_types, Iterable) else (valid_types,)
216 if isinstance(arg_value, int) and isinstance(arg_value, bool) and bool not in valid_types:
219 if not isinstance(arg_value, valid_types):
/third_party/mindspore/mindspore/core/ops/
Dapply_momentum.cc79 const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64}; in ApplyMomentumInfer() local
80 (void)CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_tensor_type, valid_types, prim_name); in ApplyMomentumInfer()
81 (void)CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_tensor_type, valid_types, prim_name); in ApplyMomentumInfer()
86 CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, valid_types, prim_name); in ApplyMomentumInfer()
Dapply_adagrad_d_a.cc75 const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; in InferType() local
82 (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name); in InferType()
89 (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_lr, valid_types, prim_name); in InferType()
91 (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_l1, valid_types, prim_name); in InferType()
93 (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_l2, valid_types, prim_name); in InferType()
Dfill.cc40 auto valid_types = common_valid_types; in FillInfer() local
41 valid_types.insert(kBool); in FillInfer()
42 (void)CheckAndConvertUtils::CheckTypeValid("output datatype", dtype, valid_types, prim_name); in FillInfer()
Dgather_d.cc67 std::set<TypePtr> valid_types = {kInt32, kInt64}; in GatherDInfer() local
68 …kAndConvertUtils::CheckTensorTypeValid("index", input_args[kInputIndex2]->BuildType(), valid_types, in GatherDInfer()
70 …dConvertUtils::CheckSubClass("dim", input_args[kInputIndex1]->BuildType(), valid_types, prim_name); in GatherDInfer()
Dbatch_norm.cc113 const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; in BatchNormInfer() local
115 …ertUtils::CheckTensorTypeValid("x", input_args[kInputIndex0]->BuildType(), valid_types, prim_name); in BatchNormInfer()
119 (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name); in BatchNormInfer()
123 (void)CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name); in BatchNormInfer()
Dlayer_norm.cc71 auto valid_types = {kFloat16, kFloat32}; in LayerNormInfer() local
72 …vertUtils::CheckTensorTypeValid("x_dtype", input_args[x_index]->BuildType(), valid_types, op_name); in LayerNormInfer()
73 …onvertUtils::CheckTensorTypeValid("gamma_dtype", input_args[gamma_index]->BuildType(), valid_types, in LayerNormInfer()
75 …dConvertUtils::CheckTensorTypeValid("beta_dtype", input_args[beta_index]->BuildType(), valid_types, in LayerNormInfer()
Dbinary_cross_entropy.cc59 const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; in BinaryCrossEntroyInferType() local
63 auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in BinaryCrossEntroyInferType()
67 infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in BinaryCrossEntroyInferType()
Dceil.cc35 const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; in CeilInfer() local
37 …auto data_type = CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, pri… in CeilInfer()
Datan.cc35 const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32}; in AtanInfer() local
36 …auto element = CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name… in AtanInfer()
Dtan.cc41 const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32}; in TanInfer() local
42 …auto infered_type = CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim… in TanInfer()
Dindex_add.cc66 …const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kUInt8, kFloat16, kFloat32, kFloat64… in IndexAddInferType() local
72 …(void)CheckAndConvertUtils::CheckTensorTypeValid("input_y type", updates_type, valid_types, prim->… in IndexAddInferType()
73 …return CheckAndConvertUtils::CheckTensorTypeValid("input_x type", var_type, valid_types, prim->nam… in IndexAddInferType()
Dasin.cc38 const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32}; in AsinInfer() local
39 …auto infer_type = CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_n… in AsinInfer()
Ddtype.cc35 const std::set<TypePtr> valid_types = {kTensorType}; in DTypeInferValue() local
37 …ConvertUtils::CheckTensorTypeValid("infer type", input_args[0]->BuildType(), valid_types, op_name); in DTypeInferValue()
Dlogical_and.cc40 const std::set<TypePtr> valid_types = {kBool}; in InferType() local
43 return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in InferType()
Dsquared_difference.cc34 const std::set<TypePtr> valid_types = {kInt32, kFloat16, kFloat32}; in InferType() local
38 return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in InferType()
Dhshrink.cc38 const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; in InferType() local
39 …turn CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, in InferType()
Dfloor.cc35 const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64}; in InferType() local
38 return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in InferType()
Dlogical_or.cc38 const std::set<TypePtr> valid_types = {kBool}; in InferType() local
41 return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in InferType()
Dhsigmoid.cc38 const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32}; in InferType() local
40 return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in InferType()
Dsigmoid_cross_entropy_with_logits.cc43 …const std::set<TypePtr> valid_types = {kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUIn… in SigmoidCrossEntropyWithLogitsInfer() local
48 auto x_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name); in SigmoidCrossEntropyWithLogitsInfer()
/third_party/mindspore/mindspore/
D_checkparam.py586 def check_value_type(arg_name, arg_value, valid_types, prim_name=None): argument
588 valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
592 type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
593 num_types = len(valid_types)
601 if isinstance(arg_value, bool) and bool not in tuple(valid_types):
603 if not isinstance(arg_value, tuple(valid_types)):
608 def check_type_name(arg_name, arg_type, valid_types, prim_name): argument
610 valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
614 type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types]
615 num_types = len(valid_types)
[all …]
/third_party/mindspore/mindspore/core/ops/grad/
Dsoft_margin_loss_grad.cc45 const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; in SoftMarginLossGradInferType() local
50 (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name); in SoftMarginLossGradInferType()
51 …kAndConvertUtils::CheckTensorTypeValid("logits", input_args[0]->BuildType(), valid_types, op_name); in SoftMarginLossGradInferType()
Dbinary_cross_entropy_grad.cc41 const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; in BinaryCrossEntroyGradInferType() local
45 auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in BinaryCrossEntroyGradInferType()
49 infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in BinaryCrossEntroyGradInferType()
/third_party/mindspore/tests/ut/python/ir/
Dtest_dtype.py151 valid_types = [dtype.float16, dtype.float32]
152 assert t1 not in valid_types
153 assert dtype.int32 not in valid_types
154 assert dtype.float32 in valid_types

1234