1# Copyright 2022 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""Models for test.""" 16from typing import Optional 17from mindspore import Tensor 18import mindspore.nn as nn 19 20 21class BaseNet(nn.Cell): 22 def __init__(self, a): 23 super().__init__() 24 self.relu = nn.ReLU() 25 self.a = a 26 27 def construct(self, x: Optional[Tensor]): 28 return x 29 30 def add_a(self, x): 31 x = x + self.a 32 return x 33 34 35class NoCellNet(): 36 def __init__(self, a, b): 37 self.a = a 38 self.b = b 39 40 def no_cell_func(self, x: Optional[Tensor]): 41 return x 42 43 44def external_func(x): 45 return x 46 47 48def external_func2(x): 49 return x 50 51EXTERN_LIST = [Tensor(1)] 52 53class NetWithClassVar(): 54 var1 = Tensor(1.0) 55 var2 = external_func 56 if True: # pylint: disable=using-constant-test 57 var3 = external_func2 58 var4 = EXTERN_LIST 59 60 def __init__(self, a): 61 self.a = a 62 63 def class_var_func(self, x: Optional[Tensor]): 64 # test class variables 65 x = x + self.var1 66 x = NetWithClassVar.var2(x) 67 x = NetWithClassVar.var3(x) 68 x = x + NetWithClassVar.var4[0] 69 # test instance variable 70 x = x + self.a 71 return x 72