# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest
from mindspore.nn import Cell
from mindspore.common import Tensor, Parameter
from mindspore import context, ops, lazy_inline, nn, no_inline, jit


class Grad(Cell):
    """Cell that applies ``ops.GradOperation()`` to a wrapped network."""

    def __init__(self, net):
        super().__init__()
        self.grad = ops.GradOperation()
        self.net = net

    def construct(self, x):
        # Build the gradient function of the wrapped net and evaluate it at x.
        return self.grad(self.net)(x)


class TestBlock(Cell):
    """Innermost cell: adds ``y``, adds ``y * 2``, then subtracts 9."""

    def __init__(self):
        super().__init__()
        self.y = Parameter(Tensor(5))

    def construct(self, x):
        return x + self.y + self.y * 2 - 9


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_nest():
    """
    Feature: Nested cell reuse with lazy inline.
    Description: Stack lazy-inlined cells three levels deep (MyBlock, InnerBlock,
        OuterBlock), then run the forward pass and the gradient in graph mode.
    Expectation: Run successfully.
    """

    class MyBlock(Cell):
        """Lazy-inlined cell (policy "front") wrapping one TestBlock."""

        @lazy_inline(policy="front")
        def __init__(self):
            super().__init__()
            self.block = TestBlock()

        def construct(self, x):
            return self.block(x + 3) + 4

    class InnerBlock(Cell):
        """Lazy-inlined cell (policy "front") chaining five MyBlock cells."""

        @lazy_inline(policy="front")
        def __init__(self):
            super().__init__()
            self.blocks = nn.SequentialCell()
            for _ in range(5):
                self.blocks.append(MyBlock())

        def construct(self, x):
            return self.blocks(x + 1)

    class OuterBlock(Cell):
        """Lazy-inlined cell chaining five InnerBlock cells."""

        @lazy_inline
        def __init__(self):
            super().__init__()
            self.blocks = nn.SequentialCell()
            for _ in range(5):
                self.blocks.append(InnerBlock())

        def construct(self, x):
            return self.blocks(x + 2)

    # Not instantiated by this test; it is the variant of Net1 that adds a
    # constant (0.1) between the two passes instead of the original input.
    class Net(Cell):
        """Three OuterBlock cells applied twice, adding 0.1 in between."""

        def __init__(self):
            super().__init__()
            self.blocks = nn.SequentialCell()
            for _ in range(3):
                self.blocks.append(OuterBlock())

        def construct(self, x):
            first_pass = self.blocks(x)
            return self.blocks(first_pass + 0.1)

    class Net1(Cell):
        """Three OuterBlock cells applied twice, adding the input in between."""

        def __init__(self):
            super().__init__()
            self.blocks = nn.SequentialCell()
            for _ in range(3):
                self.blocks.append(OuterBlock())

        def construct(self, x):
            first_pass = self.blocks(x)
            return self.blocks(first_pass + x)

    context.set_context(mode=context.GRAPH_MODE, save_graphs=0, save_graphs_path="./lazy")
    inputs = Tensor(10)
    forward_net = Net1()
    # Forward pass first, then the gradient of the same network instance.
    forward_net(inputs)
    grad_net = Grad(forward_net)
    grad_net(inputs)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_no_inline():
    """
    Feature: Function reuse with no_inline.
    Description: Call a no_inline function 100 times in a loop inside a jit function.
    Expectation: Run successfully.
    """

    @no_inline
    def no_inline_fun(val):
        return val * 3 + 2

    @jit
    def call_no_inline_fun(val):
        for _ in range(100):
            val = no_inline_fun(val)
        return val

    start = Tensor(1)
    call_no_inline_fun(start)