Home
last modified time | relevance | path

Searched refs:x_dtype (Results 1 – 25 of 40) sorted by relevance

12

/third_party/mindspore/mindspore/ops/operations/
Dcomm_ops.py164 def infer_dtype(self, x_dtype): argument
165 validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
166 return x_dtype
241 def infer_dtype(self, x_dtype): argument
242 validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
243 return x_dtype
278 def infer_dtype(self, x_dtype, z_shape): argument
279 validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
280 return x_dtype
309 def infer_dtype(self, x_dtype, z_dtype): argument
[all …]
Dmath_ops.py93 def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type, prim_name=None): argument
95 args_type = {"x": x_dtype, "y": y_dtype}
97 if x_dtype in complex_types or y_dtype in complex_types:
106 if (x_dtype.element_type(), y_dtype.element_type()) not in tpye_infer_dict.keys():
111 return tpye_infer_dict.get((x_dtype.element_type(), y_dtype.element_type()))
114 return x_dtype
116 def infer_dtype(self, x_dtype, y_dtype): argument
117 return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.name)
1516 def infer_dtype(self, x_dtype): argument
1517 validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
[all …]
D_grad_ops.py131 def infer_dtype(self, x_dtype, dout_dtype): argument
132 args = {"x": x_dtype, "dout": dout_dtype}
134 return x_dtype
148 def infer_dtype(self, x_dtype, dout_dtype): argument
149 args = {"x": x_dtype, "dout": dout_dtype}
153 return x_dtype
171 def infer_dtype(self, x_dtype, dout_dtype): argument
172 args = {"x": x_dtype, "dout": dout_dtype}
175 return x_dtype
733 def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype): argument
[all …]
Darray_ops.py69 def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): argument
71 args = {"x": x_dtype, "updates": updates_dtype}
73 return x_dtype
109 def check_dtype(self, x_dtype, indices_dtype, updates_dtype): argument
111 args = {"x": x_dtype, "updates": updates_dtype}
686 def infer_dtype(self, x_dtype): argument
687 validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
688 return x_dtype
1686 def infer_dtype(self, x_dtype): argument
1687 validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
[all …]
D_inner_ops.py170 def infer_dtype(self, x_dtype): argument
171 validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name)
172 return x_dtype
314 def infer_dtype(self, x_dtype, assist_dtype): argument
316 args = {"x": x_dtype, "assist": assist_dtype}
318 return x_dtype
363 def infer_dtype(self, x_dtype, assist_dtype): argument
365 args = {"x": x_dtype, "assist": assist_dtype}
367 return x_dtype
430 def infer_dtype(self, x_dtype): argument
[all …]
Dnn_ops.py233 def infer_dtype(self, x_dtype): argument
234 …validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float16, mstype.float32, mstype.flo…
236 return x_dtype
502 def infer_dtype(self, x_dtype): argument
503 … validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
504 return x_dtype
556 def infer_dtype(self, x_dtype): argument
558 validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
559 return x_dtype
753 def infer_dtype(self, x_dtype): argument
[all …]
D_thor_ops.py528 def infer_dtype(self, x_dtype): argument
530 validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
531 return x_dtype
602 def infer_dtype(self, x_dtype): argument
605 validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
606 return x_dtype
/third_party/mindspore/mindspore/core/ops/
Dreluv2.cc30 …GetOutputMaskShape(const std::vector<int64_t> &input_shape, const std::shared_ptr<Type> &x_dtype) { in GetOutputMaskShape() argument
38 if (x_dtype == kUInt8 || x_dtype == kInt8) { in GetOutputMaskShape()
49 if (x_dtype == kUInt8 || x_dtype == kInt8) { in GetOutputMaskShape()
71 auto x_dtype = input_type->element(); in InferShape() local
72 auto mask_shape = GetOutputMaskShape(input_shape, x_dtype); in InferShape()
80 auto min_mask_shape = GetOutputMaskShape(min_shape, x_dtype); in InferShape()
81 auto max_mask_shape = GetOutputMaskShape(max_shape, x_dtype); in InferShape()
Dcos.cc32 auto x_dtype = input_args[0]->BuildType(); in CosInferType() local
33 (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, common_valid_types, prim->name()); in CosInferType()
34 return x_dtype; in CosInferType()
Dbroadcast_to.cc69 auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>(); in BroadcastToInferType() local
71 (void)CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim->name()); in BroadcastToInferType()
72 return x_dtype->element(); in BroadcastToInferType()
Dtile.cc87 auto x_dtype = x_type_map->cast<TensorTypePtr>(); in TileInferType() local
88 MS_EXCEPTION_IF_NULL(x_dtype); in TileInferType()
90 return CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", x_dtype, template_types, prim_name); in TileInferType()
Dsquare.cc45 auto x_dtype = input_args[kInputIndex0]->BuildType(); in SquareInferType() local
46 (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, common_valid_types, prim->name()); in SquareInferType()
47 return x_dtype; in SquareInferType()
Ddiag.cc39 auto x_dtype = input_args[0]->BuildType(); in PartInferType() local
40 …return CheckAndConvertUtils::CheckTensorTypeValid("input type", x_dtype, common_valid_types, primi… in PartInferType()
Ddiag_part.cc49 auto x_dtype = input_args[0]->BuildType(); in DiagPartInferType() local
50 …return CheckAndConvertUtils::CheckTensorTypeValid("input type", x_dtype, common_valid_types, primi… in DiagPartInferType()
Darg_min.cc60 auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); in ArgMinInfer() local
61 …return std::make_shared<abstract::AbstractTensor>(x_dtype, std::make_shared<abstract::Shape>(out_s… in ArgMinInfer()
/third_party/mindspore/mindspore/ops/_op_impl/_custom_op/
Dminmax_update_perchannel.py85 x_dtype = x.get("dtype")
104 x_dtype = x_dtype.lower()
107 util.check_dtype_rule(x_dtype, check_list)
115 input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype)
116 min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
117 max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
Dfake_quant_perchannel.py95 x_dtype = x.get("dtype")
114 x_dtype = x_dtype.lower()
117 util.check_dtype_rule(x_dtype, check_list)
125 return x_shape, shape_c, x_dtype
138 x_shape, shape_c, x_dtype = fake_quant_perchannel_param(x, min_val, max_val,
140 input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
141 min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
142 max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
Dfake_quant_perchannel_grad.py119 x_dtype = x.get("dtype")
138 x_dtype = x_dtype.lower()
141 util.check_dtype_rule(x_dtype, check_list)
149 return x_shape, shape_c, x_dtype
166 x_shape, shape_c, x_dtype = fake_quant_perchannel_grad_param(x, min_val, max_val,
168 dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
169 input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
170 min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
171 max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
Dminmax_update_perlayer.py98 x_dtype = input_dtype.lower()
101 util.check_dtype_rule(x_dtype, check_list)
108 input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
Dfake_quant_perlayer_grad.py138 x_dtype = input_dtype.lower()
141 util.check_dtype_rule(x_dtype, check_list)
153 dout_data = tvm.placeholder(input_shape, name="dout", dtype=x_dtype)
154 input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
Dfake_quant_perlayer.py112 x_dtype = input_dtype.lower()
115 util.check_dtype_rule(x_dtype, check_list)
127 input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
/third_party/mindspore/mindspore/ccsrc/backend/optimizer/ascend/mindir/
Davg_pool_grad_unify_mindir.cc102 const PadMode pad_mode, const TypeId x_dtype) { in CreateMeanMatrixValueNode() argument
146 …auto output_tensor = std::make_shared<tensor::Tensor>(x_dtype, output_shape, &output[0], kNumberTy… in CreateMeanMatrixValueNode()
148 auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(x_dtype), output_shape); in CreateMeanMatrixValueNode()
157 const std::vector<int64_t> &k_size, const TypeId x_dtype) { in CreateKernelMatrixValueNode() argument
168 …auto kernel_matrix_tensor = std::make_shared<tensor::Tensor>(x_dtype, kernel_shape, &data[0], kNum… in CreateKernelMatrixValueNode()
170 auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(x_dtype), kernel_shape); in CreateKernelMatrixValueNode()
194 auto x_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(avgpool_grad, 0); in Process() local
200 …o mean_matrix_vnode = CreateMeanMatrixValueNode(graph, x_shape, k_size, stride, pad_mode, x_dtype); in Process()
201 auto kernel_matrix_vnode = CreateKernelMatrixValueNode(graph, x_shape, k_size, x_dtype); in Process()
/third_party/mindspore/mindspore/nn/layer/
Dbasic.py351 def _dtype_check(x_dtype, prim_name=None): argument
353 if x_dtype not in [mstype.float32, mstype.float16]:
999 def tril(x_shape, x_dtype, k): argument
1003 return Tensor(mask, x_dtype)
1086 def triu(x_shape, x_dtype, k): argument
1090 return Tensor(mask, x_dtype)
1173 def _get_matrix_diag_assist(x_shape, x_dtype): argument
1177 return Tensor(assist, x_dtype)
1181 def _get_matrix_diag_part_assist(x_shape, x_dtype): argument
1185 return Tensor(assist, x_dtype)
[all …]
/third_party/mindspore/mindspore/ops/_grad/
Dgrad_inner_ops.py26 def _get_matrix_diag_assist(x_shape, x_dtype): argument
27 base_eye = P.Eye()(x_shape[-1], x_shape[-1], x_dtype).flatten()
33 def _get_matrix_diag_part_assist(x_shape, x_dtype): argument
34 base_eye = P.Eye()(x_shape[-2], x_shape[-1], x_dtype).flatten()
/third_party/mindspore/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/
Dspace_to_depth_split.cc108 TypeId x_dtype = AnfAlgo::GetOutputInferDataType(ori_inputs[kIndex1], 0); in Process() local
109 if (x_dtype != kNumberTypeFloat16) { in Process()
110 …INFO) << "Node " << cnode->DebugString() << ": The data type of node's first input is: " << x_dtype in Process()

12