/third_party/mindspore/mindspore/train/ |
D | _utils.py | 210 def check_value_type(arg_name, arg_value, valid_types): argument 212 valid_types = tuple(valid_types) if isinstance(valid_types, Iterable) else (valid_types,) 216 if isinstance(arg_value, int) and isinstance(arg_value, bool) and bool not in valid_types: 219 if not isinstance(arg_value, valid_types):
|
/third_party/mindspore/mindspore/core/ops/ |
D | apply_momentum.cc | 79 const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64}; in ApplyMomentumInfer() local 80 (void)CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_tensor_type, valid_types, prim_name); in ApplyMomentumInfer() 81 (void)CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_tensor_type, valid_types, prim_name); in ApplyMomentumInfer() 86 CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, valid_types, prim_name); in ApplyMomentumInfer()
|
D | apply_adagrad_d_a.cc | 75 const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; in InferType() local 82 (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name); in InferType() 89 (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_lr, valid_types, prim_name); in InferType() 91 (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_l1, valid_types, prim_name); in InferType() 93 (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_l2, valid_types, prim_name); in InferType()
|
D | fill.cc | 40 auto valid_types = common_valid_types; in FillInfer() local 41 valid_types.insert(kBool); in FillInfer() 42 (void)CheckAndConvertUtils::CheckTypeValid("output datatype", dtype, valid_types, prim_name); in FillInfer()
|
D | gather_d.cc | 67 std::set<TypePtr> valid_types = {kInt32, kInt64}; in GatherDInfer() local 68 …kAndConvertUtils::CheckTensorTypeValid("index", input_args[kInputIndex2]->BuildType(), valid_types, in GatherDInfer() 70 …dConvertUtils::CheckSubClass("dim", input_args[kInputIndex1]->BuildType(), valid_types, prim_name); in GatherDInfer()
|
D | batch_norm.cc | 113 const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; in BatchNormInfer() local 115 …ertUtils::CheckTensorTypeValid("x", input_args[kInputIndex0]->BuildType(), valid_types, prim_name); in BatchNormInfer() 119 (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name); in BatchNormInfer() 123 (void)CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name); in BatchNormInfer()
|
D | layer_norm.cc | 71 auto valid_types = {kFloat16, kFloat32}; in LayerNormInfer() local 72 …vertUtils::CheckTensorTypeValid("x_dtype", input_args[x_index]->BuildType(), valid_types, op_name); in LayerNormInfer() 73 …onvertUtils::CheckTensorTypeValid("gamma_dtype", input_args[gamma_index]->BuildType(), valid_types, in LayerNormInfer() 75 …dConvertUtils::CheckTensorTypeValid("beta_dtype", input_args[beta_index]->BuildType(), valid_types, in LayerNormInfer()
|
D | binary_cross_entropy.cc | 59 const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; in BinaryCrossEntroyInferType() local 63 auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in BinaryCrossEntroyInferType() 67 infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in BinaryCrossEntroyInferType()
|
D | ceil.cc | 35 const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; in CeilInfer() local 37 …auto data_type = CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, pri… in CeilInfer()
|
D | atan.cc | 35 const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32}; in AtanInfer() local 36 …auto element = CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name… in AtanInfer()
|
D | tan.cc | 41 const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32}; in TanInfer() local 42 …auto infered_type = CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim… in TanInfer()
|
D | index_add.cc | 66 …const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kUInt8, kFloat16, kFloat32, kFloat64… in IndexAddInferType() local 72 …(void)CheckAndConvertUtils::CheckTensorTypeValid("input_y type", updates_type, valid_types, prim->… in IndexAddInferType() 73 …return CheckAndConvertUtils::CheckTensorTypeValid("input_x type", var_type, valid_types, prim->nam… in IndexAddInferType()
|
D | asin.cc | 38 const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32}; in AsinInfer() local 39 …auto infer_type = CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_n… in AsinInfer()
|
D | dtype.cc | 35 const std::set<TypePtr> valid_types = {kTensorType}; in DTypeInferValue() local 37 …ConvertUtils::CheckTensorTypeValid("infer type", input_args[0]->BuildType(), valid_types, op_name); in DTypeInferValue()
|
D | logical_and.cc | 40 const std::set<TypePtr> valid_types = {kBool}; in InferType() local 43 return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in InferType()
|
D | squared_difference.cc | 34 const std::set<TypePtr> valid_types = {kInt32, kFloat16, kFloat32}; in InferType() local 38 return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in InferType()
|
D | hshrink.cc | 38 const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; in InferType() local 39 …turn CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, in InferType()
|
D | floor.cc | 35 const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64}; in InferType() local 38 return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in InferType()
|
D | logical_or.cc | 38 const std::set<TypePtr> valid_types = {kBool}; in InferType() local 41 return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in InferType()
|
D | hsigmoid.cc | 38 const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32}; in InferType() local 40 return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in InferType()
|
D | sigmoid_cross_entropy_with_logits.cc | 43 …const std::set<TypePtr> valid_types = {kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUIn… in SigmoidCrossEntropyWithLogitsInfer() local 48 auto x_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name); in SigmoidCrossEntropyWithLogitsInfer()
|
/third_party/mindspore/mindspore/ |
D | _checkparam.py | 586 def check_value_type(arg_name, arg_value, valid_types, prim_name=None): argument 588 valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) 592 type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types] 593 num_types = len(valid_types) 601 if isinstance(arg_value, bool) and bool not in tuple(valid_types): 603 if not isinstance(arg_value, tuple(valid_types)): 608 def check_type_name(arg_name, arg_type, valid_types, prim_name): argument 610 valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) 614 type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types] 615 num_types = len(valid_types) [all …]
|
/third_party/mindspore/mindspore/core/ops/grad/ |
D | soft_margin_loss_grad.cc | 45 const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; in SoftMarginLossGradInferType() local 50 (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name); in SoftMarginLossGradInferType() 51 …kAndConvertUtils::CheckTensorTypeValid("logits", input_args[0]->BuildType(), valid_types, op_name); in SoftMarginLossGradInferType()
|
D | binary_cross_entropy_grad.cc | 41 const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; in BinaryCrossEntroyGradInferType() local 45 auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in BinaryCrossEntroyGradInferType() 49 infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); in BinaryCrossEntroyGradInferType()
|
/third_party/mindspore/tests/ut/python/ir/ |
D | test_dtype.py | 151 valid_types = [dtype.float16, dtype.float32] 152 assert t1 not in valid_types 153 assert dtype.int32 not in valid_types 154 assert dtype.float32 in valid_types
|