# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Graph-mode tests for writing to Parameters.

Covers two paths:

* rebinding a Parameter attribute inside ``construct`` (``self.b = w``) on a
  cell that has been switched to float16 with ``Cell.to_float``;
* ``ScatterNdUpdate`` applied to a Parameter with a float16 update tensor.
"""
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
import mindspore as ms


class AssignW(nn.Cell):
    """Cell that copies ``w`` into ``x`` with ``P.Assign`` and returns ``x``.

    Not referenced by the tests in this module.
    """

    def __init__(self):
        super(AssignW, self).__init__()
        self.assign = P.Assign()

    def construct(self, x, w):
        # NOTE(review): P.Assign documents its first input as a Parameter;
        # callers are assumed to pass one -- confirm at call sites.
        self.assign(x, w)
        return x


class AssignOp(nn.Cell):
    """Cell whose ``construct`` rebinds Parameter ``b`` to its input.

    ``b`` is a length-5 Parameter initialised to ones.
    """

    def __init__(self):
        super(AssignOp, self).__init__()
        self.b = Parameter(initializer('ones', [5]), name='b')

    def construct(self, w):
        # The attribute store below targets Parameter ``b`` in graph mode;
        # presumably the graph parser lowers it to an assign on ``b`` --
        # that path is what test_assign_by_operator exercises.
        self.b = w
        return w


def test_assign_by_operator():
    """Run ``self.b = w`` in graph mode on a float16 cell fed float32 data."""
    context.set_context(mode=context.GRAPH_MODE)
    net = AssignOp()
    # Switch the cell to float16 while the input below is float32. No output
    # value is checked: the test only requires compile + run to succeed.
    net.to_float(ms.float16)
    input_data = Tensor(np.ones([5]).astype(np.float32))
    net(input_data)


class NetScatterNdUpdate(nn.Cell):
    """Apply ``P.ScatterNdUpdate`` to a 5x5 Parameter ``b`` initialised to ones."""

    def __init__(self):
        super(NetScatterNdUpdate, self).__init__()
        self.b = Parameter(initializer('ones', [5, 5]), name='b')
        self.scatter = P.ScatterNdUpdate()

    def construct(self, idx, x):
        # Write the slices of ``x`` into ``b`` at positions ``idx`` and
        # return the op's output.
        return self.scatter(self.b, idx, x)


def test_scatter_nd_update():
    """Run ``ScatterNdUpdate`` with a float16 update on Parameter ``b``."""
    context.set_context(mode=context.GRAPH_MODE)
    net = NetScatterNdUpdate()
    # idx has shape [1] and value 1, so it addresses row 1 of the [5, 5]
    # Parameter and the update must have shape [5]. The update is float16,
    # while ``b`` uses the initializer's default dtype (float32 per the
    # initializer docs), so the op is fed mixed dtypes. No output value is
    # checked: the test only requires compile + run to succeed.
    x = Tensor(np.ones([5]).astype(np.float16))
    idx = Tensor(np.ones([1]).astype(np.int32))
    net(idx, x)