Home
last modified time | relevance | path

Searched refs:updates_shape (Results 1 – 6 of 6) sorted by relevance

/third_party/mindspore/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/
scatter_nd_functor_gpu_kernel.h
90 auto updates_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); in Init() local
98 if (updates_shape.size() != indices_shape.size() - 1 + input_shape.size() - index_depth) { in Init()
102 if (updates_shape[i] != indices_shape[i]) { in Init()
116 for (size_t i = 0; i < updates_shape.size(); i++) { in Init()
117 updates_size_ *= updates_shape[i]; in Init()
122 for (size_t i = indices_shape.size() - 1; i < updates_shape.size(); ++i) { in Init()
123 unit_size_ *= SizeToInt(updates_shape[i]); in Init()
126 num_units_ *= updates_shape[indices_shape.size() - 2]; in Init()
128 num_units_ *= updates_shape[i]; in Init()
/third_party/mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/
scatter_nd_cpu_kernel.cc
61 auto updates_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); in InitKernel() local
69 if (updates_shape.size() != indices_shape.size() - 1 + shape.size() - indices_unit_rank) { in InitKernel()
73 if (updates_shape[i] != indices_shape[i]) { in InitKernel()
78 for (size_t i = indices_shape.size() - 1; i < updates_shape.size(); ++i) { in InitKernel()
79 unit_size_ *= SizeToInt(updates_shape[i]); in InitKernel()
82 num_units_ *= updates_shape[indices_shape.size() - 2]; in InitKernel()
84 num_units_ *= updates_shape[IntToSize(i)]; in InitKernel()
scatter_nd_update_cpu_kernel.cc
64 auto updates_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); in InitKernel() local
72 if (updates_shape.size() != indices_shape.size() - 1 + shape.size() - indices_unit_rank) { in InitKernel()
76 if (updates_shape[i] != indices_shape[i]) { in InitKernel()
82 for (size_t i = indices_shape.size() - 1; i < updates_shape.size(); ++i) { in InitKernel()
83 unit_size_ *= SizeToInt(updates_shape[i]); in InitKernel()
86 num_units_ *= updates_shape[indices_shape.size() - 2]; in InitKernel()
88 num_units_ *= updates_shape[i]; in InitKernel()
/third_party/mindspore/mindspore/ops/composite/multitype_ops/
_constexpr_utils.py
476 updates_shape = indices_shape + data_shape[1:]
478 updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:]
479 return Tensor(np.full(updates_shape, value), dtype=data_dtype)
486 updates_shape = index_shape + data_shape[1:]
488 updates_shape = index_shape[:-1] + data_shape[index_shape[-1]:]
489 return updates_shape
_compile_utils.py
580 updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type)
581 need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value.shape)
583 return _broadcast(updates_shape, value)
/third_party/mindspore/mindspore/ops/operations/
array_ops.py
52 def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): argument
53 if indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]:
65 def infer_shape(self, x_shape, indices_shape, updates_shape): argument
66 self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
86 def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): argument
92 if np.any(np.array(indices_shape) == -1) or np.any(np.array(updates_shape) == -1):
94 … elif indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]:
106 def check_shape(self, x_shape, indices_shape, updates_shape): argument
107 self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
120 def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): argument
[all …]