• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Models for test."""
16from typing import Optional
17from mindspore import Tensor
18import mindspore.nn as nn
19
20
class BaseNet(nn.Cell):
    """Small ``nn.Cell`` subclass used as a test model.

    It owns an ``nn.ReLU`` instance and a caller-supplied value ``a``.
    ``construct`` passes its input through untouched; ``add_a`` adds the
    stored value to its argument.
    """

    def __init__(self, a):
        """Initialise the base Cell, create the ReLU and remember ``a``."""
        super().__init__()
        self.relu = nn.ReLU()
        self.a = a

    def construct(self, x: Optional[Tensor]):
        """Return ``x`` exactly as it was received."""
        return x

    def add_a(self, x):
        """Return ``x + self.a``."""
        return x + self.a
33
34
class NoCellNet:
    """Plain Python class (no base class) that stores two values.

    ``a`` and ``b`` are kept as instance attributes; ``no_cell_func``
    hands its argument straight back.
    """

    def __init__(self, a, b):
        """Keep ``a`` and ``b`` on the instance."""
        self.a, self.b = a, b

    def no_cell_func(self, x: Optional[Tensor]):
        """Return ``x`` without modifying it."""
        return x
42
43
def external_func(x):
    """Identity helper defined at module level: return ``x`` unchanged."""
    return x
46
47
def external_func2(x):
    """Second module-level identity helper: return ``x`` unchanged."""
    return x
50
# Module-level list containing a single Tensor built from the integer 1.
# NetWithClassVar binds this same list object to its class variable ``var4``.
EXTERN_LIST = [Tensor(1)]
52
class NetWithClassVar():
    """Test model whose class body binds several kinds of class variables.

    The class variables cover a Tensor value (``var1``), module-level
    functions (``var2``, ``var3`` -- the latter assigned inside an ``if``
    branch of the class body) and a module-level list (``var4``).
    ``class_var_func`` reads each of them plus the instance attribute ``a``.
    """
    # Tensor built from 1.0, stored directly on the class.
    var1 = Tensor(1.0)
    # Reference to the module-level function external_func.
    var2 = external_func
    # var3 is bound inside an always-true branch of the class body; the
    # constant condition is intentional, hence the pylint suppression.
    if True: # pylint: disable=using-constant-test
        var3 = external_func2
    # Reference to the module-level list EXTERN_LIST (same object, not a copy).
    var4 = EXTERN_LIST

    def __init__(self, a):
        """Store ``a`` as an instance attribute."""
        self.a = a

    def class_var_func(self, x: Optional[Tensor]):
        """Combine ``x`` with every class variable and with ``self.a``.

        Adds ``var1``, passes the value through ``var2`` and ``var3``, adds
        the first element of ``var4`` and finally adds ``self.a``.

        Returns:
            The accumulated result.
        """
        # test class variables
        # var1 is looked up through the instance.
        x = x + self.var1
        # var2/var3 are looked up through the class, so they are plain
        # functions here and receive only ``x`` (no implicit ``self``).
        x = NetWithClassVar.var2(x)
        x = NetWithClassVar.var3(x)
        x = x + NetWithClassVar.var4[0]
        # test instance variable
        x = x + self.a
        return x
72