Home
last modified time | relevance | path

Searched refs:dout_alpha_shape (Results 1 – 3 of 3) sorted by relevance

/third_party/mindspore/mindspore/ops/_op_impl/_custom_op/
fake_learned_scale_quant_perchannel_grad_reduce.py
53 dout_alpha_shape = dout_alpha.get("shape")
54 axis = list(range(len(dout_alpha_shape)))
65 dout_alpha_shape = dout_alpha.get("shape")
69 util.check_shape_rule(dout_alpha_shape)
70 util.check_tensor_shape_size(dout_alpha_shape)
76 dout_alpha_data = tvm.placeholder(dout_alpha_shape, name="dout_alpha", dtype=dout_alpha_dtype)
fake_learned_scale_quant_perlayer_grad_reduce.py
64 dout_alpha_shape = dout_alpha.get("shape")
68 util.check_shape_rule(dout_alpha_shape)
69 util.check_tensor_shape_size(dout_alpha_shape)
75 input_shape = (functools_reduce(lambda x, y: x * y, dout_alpha_shape[:]),)
/third_party/mindspore/mindspore/ops/operations/
_quant_ops.py
323 def infer_shape(self, dout_alpha_shape): argument
489 def infer_shape(self, dout_alpha_shape): argument
490 return (dout_alpha_shape[self.channel_axis],)