# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test nn.Tril()
"""
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context

context.set_context(mode=context.GRAPH_MODE)

# 3x3 input shared by every test below. Expected element sums of the
# lower-triangular result for each diagonal offset k:
#   k = 0  -> 1 + 4 + 5 + 7 + 8 + 9 = 34
#   k = 1  -> 34 + 2 + 6           = 42
#   k = -1 -> 4 + 7 + 8            = 19
_MATRIX = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]


def test_tril():
    """Tril with k=0 on a tensor stored on the cell; expected sum is 34."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor(_MATRIX)

        def construct(self):
            tril = nn.Tril()
            return tril(self.value, 0)

    net = Net()
    out = net()
    assert np.sum(out.asnumpy()) == 34


def test_tril_1():
    """Tril with k=1 on a tensor stored on the cell; expected sum is 42."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor(_MATRIX)

        def construct(self):
            tril = nn.Tril()
            return tril(self.value, 1)

    net = Net()
    out = net()
    assert np.sum(out.asnumpy()) == 42


def test_tril_2():
    """Tril with k=-1 on a tensor stored on the cell; expected sum is 19."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor(_MATRIX)

        def construct(self):
            tril = nn.Tril()
            return tril(self.value, -1)

    net = Net()
    out = net()
    assert np.sum(out.asnumpy()) == 19


def test_tril_parameter():
    """Tril with k=0 on a tensor passed as a construct argument; expected sum is 34."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            tril = nn.Tril()
            return tril(x, 0)

    net = Net()
    out = net(Tensor(_MATRIX))
    # Check the computed value, not merely that the call succeeds.
    assert np.sum(out.asnumpy()) == 34


def test_tril_parameter_1():
    """Tril with k=1 on a tensor passed as a construct argument; expected sum is 42."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            tril = nn.Tril()
            return tril(x, 1)

    net = Net()
    out = net(Tensor(_MATRIX))
    # Check the computed value, not merely that the call succeeds.
    assert np.sum(out.asnumpy()) == 42


def test_tril_parameter_2():
    """Tril with k=-1 on a tensor passed as a construct argument; expected sum is 19."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            tril = nn.Tril()
            return tril(x, -1)

    net = Net()
    out = net(Tensor(_MATRIX))
    # Check the computed value, not merely that the call succeeds.
    assert np.sum(out.asnumpy()) == 19