# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for ``Tensor.expand_as`` executed in GRAPH_MODE."""
import mindspore as ms
import mindspore.nn as nn
import mindspore.common.initializer as init
from mindspore import Tensor
from mindspore import context

context.set_context(mode=context.GRAPH_MODE)

# Expected result of broadcasting the row [1, 2, 3] to shape (2, 3).
_EXPECTED_ROWS = [[1, 2, 3], [1, 2, 3]]
# Expected result of broadcasting a (1, 3) tensor of ones to shape (2, 3).
_EXPECTED_ONES = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]


def test_expand_as():
    """
    Feature: Tensor.expand_as on constant tensors stored on a Cell.
    Description: Expand a (3,) tensor to the shape of a (2, 3) tensor inside construct.
    Expectation: The result has shape (2, 3) and every row equals [1, 2, 3].
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.t1 = Tensor([1, 2, 3])
            self.t2 = Tensor([[1, 1, 1], [1, 1, 1]])

        def construct(self):
            return self.t1.expand_as(self.t2)

    net = Net()
    out = net()
    # Previously the result was discarded; pin both shape and values.
    assert out.shape == (2, 3)
    assert out.asnumpy().tolist() == _EXPECTED_ROWS


def test_initializer_expand_as():
    """
    Feature: Tensor.expand_as on tensors created by mindspore.common.initializer.
    Description: Expand a (1, 3) 'one'-initialized float32 tensor to the shape of a
        (2, 3) 'one'-initialized tensor inside construct.
    Expectation: The result has shape (2, 3) and is filled with ones.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.t1 = init.initializer('one', [1, 3], ms.float32)
            self.t2 = init.initializer('one', [2, 3], ms.float32)

        def construct(self):
            return self.t1.expand_as(self.t2)

    net = Net()
    out = net()
    assert out.shape == (2, 3)
    assert out.asnumpy().tolist() == _EXPECTED_ONES


def test_expand_as_parameter():
    """
    Feature: Tensor.expand_as with the target shape supplied as a construct input.
    Description: Expand a (3,) Cell attribute to the shape of a (2, 3) input tensor.
    Expectation: The result has shape (2, 3) and every row equals [1, 2, 3].
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.t1 = Tensor([1, 2, 3])

        def construct(self, x):
            return self.t1.expand_as(x)

    net = Net()
    out = net(Tensor([[1, 1, 1], [1, 1, 1]]))
    assert out.shape == (2, 3)
    assert out.asnumpy().tolist() == _EXPECTED_ROWS


def test_expand_tensor_as_parameter_1():
    """
    Feature: Tensor.expand_as called on a construct input.
    Description: Expand a (3,) input tensor to the shape of a (2, 3) Cell attribute.
    Expectation: The result has shape (2, 3) and every row equals [1, 2, 3].
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.t2 = Tensor([[1, 1, 1], [1, 1, 1]])

        def construct(self, x):
            return x.expand_as(self.t2)

    net = Net()
    out = net(Tensor([1, 2, 3]))
    assert out.shape == (2, 3)
    assert out.asnumpy().tolist() == _EXPECTED_ROWS