Home
last modified time | relevance | path

Searched refs:tensor_scatter_update (Results 1 – 6 of 6) sorted by relevance

/third_party/mindspore/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/
Dtensor_scatter_update_fission.cc25 CNodePtr CreateTensorMove(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update) { in CreateTensorMove() argument
27 MS_EXCEPTION_IF_NULL(tensor_scatter_update); in CreateTensorMove()
29 tensor_scatter_update->input(1)}; in CreateTensorMove()
32 tensor_move->set_scope(tensor_scatter_update->scope()); in CreateTensorMove()
33 tensor_move->set_abstract(tensor_scatter_update->abstract()); in CreateTensorMove()
38 CNodePtr CreateScatterNdUpdate(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update, in CreateScatterNdUpdate() argument
41 MS_EXCEPTION_IF_NULL(tensor_scatter_update); in CreateScatterNdUpdate()
44tensor_scatter_update->input(2), tensor_scatter_update->input(3)}; in CreateScatterNdUpdate()
47 scatter_nd_update->set_scope(tensor_scatter_update->scope()); in CreateScatterNdUpdate()
48 scatter_nd_update->set_abstract(tensor_scatter_update->abstract()); in CreateScatterNdUpdate()
[all …]
/third_party/mindspore/tests/ut/cpp/python_input/gtest_input/pre_activate/
Dtensor_scatter_update_fission_test.py19 tensor_scatter_update = P.TensorScatterUpdate() variable
42 res = tensor_scatter_update(x, indices, updates)
/third_party/mindspore/mindspore/ops/composite/multitype_ops/
D_compile_utils.py638 return F.tensor_scatter_update(data, index, updates)
685 return F.tensor_scatter_update(data, index, updates)
719 result = F.tensor_scatter_update(data, indices, value.astype(F.dtype(data)))
764 return F.tensor_scatter_update(data, indices, updates)
790 return F.tensor_scatter_update(data, index, value)
/third_party/mindspore/mindspore/ops/
Dfunctional.py139 tensor_scatter_update = P.TensorScatterUpdate() variable
/third_party/mindspore/mindspore/ops/_op_impl/tbe/
D__init__.py353 from .tensor_scatter_update import _tensor_scatter_update_tbe
/third_party/mindspore/mindspore/ops/_grad/
Dgrad_array_ops.py772 tensor_scatter_update = P.TensorScatterUpdate()
775 x_grad = tensor_scatter_update(dout, indices, zeros_like(update))