Searched refs:delta_shape (Results 1 – 3 of 3) sorted by relevance
/third_party/mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/ |
D | adam_delta_cpu_kernel.cc | 63 std::vector<size_t> delta_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); in InitKernel() local 68 if (!IsSameShape(delta_shape, m_shape)) { in InitKernel() 71 if (!IsSameShape(delta_shape, v_shape)) { in InitKernel() 74 if (!IsSameShape(delta_shape, grad_shape)) { in InitKernel() 77 if (delta_shape.empty()) { in InitKernel() 81 for (size_t i = 0; i < delta_shape.size(); ++i) { in InitKernel() 82 elem_num_ *= delta_shape[i]; in InitKernel()
|
/third_party/mindspore/mindspore/ops/operations/ |
D | nn_ops.py | 6734 def infer_shape(self, var_shape, alpha_shape, delta_shape): argument 6735 validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name) 6822 def infer_shape(self, var_shape, alpha_shape, l1_shape, l2_shape, delta_shape): argument 6823 validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name)
|
D | array_ops.py | 6025 def check_shape(self, start_shape, limit_shape, delta_shape): argument 6028 validator.check("delta_shape", len(delta_shape), "", 0, Rel.EQ, self.name)
|