# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from mindspore.nn import Cell
from mindspore.common import Tensor, Parameter
import mindspore.ops.operations as P
from mindspore import context, ops, lazy_inline, nn

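# Graph mode on Ascend; jit_level 'O2' selects the highest graph optimization
# level (sink execution), which is where lazy inline and recompute take effect.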
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(jit_level='O2')


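# Helper that differentiates a network: calling the wrapped cell returns the
# gradient of net's output with respect to the input x.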
class Grad(Cell):
    def __init__(self, net):
        super(Grad, self).__init__()
        self.grad = ops.GradOperation()
        self.net = net

    def construct(self, x):
        grad_net = self.grad(self.net)
        return grad_net(x)


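# Attention-style block: two scaled transposes of x feed a batched matmul, the
# result is masked, softmax-ed and passed through dropout, then multiplied with
# a third transpose. P.Dropout returns (output, mask), hence dropout[0] below.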
class Block(Cell):
    def __init__(self):
        super(Block, self).__init__()
        self.transpose1 = P.Transpose()
        self.transpose2 = P.Transpose()
        self.transpose3 = P.Transpose()
        self.transpose4 = P.Transpose()
        self.real_div1 = P.RealDiv()
        self.real_div2 = P.RealDiv()
        self.batch_matmul1 = P.BatchMatMul()
        self.batch_matmul2 = P.BatchMatMul()
        self.add = P.Add()
        self.softmax = P.Softmax(-1)
        self.dropout = P.Dropout(0.9)
        self.expand_dims = P.ExpandDims()
        self.sub = P.Sub()
        self.mul = P.Mul()
        self.y = Parameter(Tensor(np.ones((8, 128, 128)).astype(np.float32)))

    def construct(self, x):
        transpose1 = self.transpose1(x, (0, 2, 1, 3))
        real_div1 = self.real_div1(transpose1, Tensor(2.37891))
        transpose2 = self.transpose2(x, (0, 2, 3, 1))
        real_div2 = self.real_div2(transpose2, Tensor(2.37891))
        batch_matmul1 = self.batch_matmul1(real_div1, real_div2)
        expand_dims = self.expand_dims(self.y, 1)
        sub = self.sub(Tensor([1.0]), expand_dims)
        mul = self.mul(sub, Tensor([-0.0001]))
        add = self.add(mul, batch_matmul1)
        soft_max = self.softmax(add)
        dropout = self.dropout(soft_max)
        transpose3 = self.transpose3(x, (0, 2, 1, 3))
        batch_matmul2 = self.batch_matmul2(dropout[0], transpose3)
        transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3))
        return transpose4


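# Simple straight-line arithmetic block (not referenced by the tests below).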
class TestBlock(Cell):
    def __init__(self):
        super(TestBlock, self).__init__()
        self.y = Parameter(Tensor(5))

    def construct(self, x):
        x = x + self.y
        x = x + self.y * 2
        x = x - 9
        return x


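# Like TestBlock, but with an if/else branch on the input value (also not
# referenced by the tests below).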
class TestIfBlock(Cell):
    def __init__(self):
        super(TestIfBlock, self).__init__()
        self.y = Parameter(Tensor(5))

    def construct(self, x):
        if x > 10:
            x = x + self.y * 2
        else:
            x = x + self.y
        x = x - 9
        return x


def test_recompute_block_recompute():
    """
    Feature: Recompute with lazy inline.
    Description: Each block is set to recompute via the Cell recompute API.
    Expectation: Runs successfully and the memory usage is reduced.
    """

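    # @lazy_inline marks OuterBlock.__init__ so the cell's graph is kept as a
    # reusable subgraph instead of being inlined eagerly; Cell.recompute() in
    # Net below then marks each whole block for recomputation in backward.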
    class OuterBlock(Cell):
        @lazy_inline
        def __init__(self):
            super(OuterBlock, self).__init__()
            self.block = Block()

        def construct(self, x):
            return self.block(x)

    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.blocks = nn.CellList()
            for _ in range(3):
                b = OuterBlock()
                b.recompute()
                self.blocks.append(b)

        def construct(self, x):
            out = x
            for i in range(3):
                out = self.blocks[i](out)
            return out

    x = Tensor(np.ones((8, 128, 16, 32)).astype(np.float32))
    net = Net()
    grad_net = Grad(net)
    grad_net(x)


def test_recompute_op_recompute1():
    """
    Feature: Recompute with lazy inline.
    Description: Each block is set to recompute via the primitive recompute API.
    Expectation: Runs successfully and the memory usage is reduced.
    """

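    # Only a chain of primitives inside the block (real_div1 -> batch_matmul1
    # -> add -> softmax) is marked for recompute here, exercising recompute
    # boundaries inside a lazy-inlined cell.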
    class OuterBlock(Cell):
        @lazy_inline
        def __init__(self):
            super(OuterBlock, self).__init__()
            self.block = Block()
            self.block.real_div1.recompute()
            self.block.batch_matmul1.recompute()
            self.block.add.recompute()
            self.block.softmax.recompute()

        def construct(self, x):
            return self.block(x)

    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.blocks = nn.CellList()
            for _ in range(3):
                b = OuterBlock()
                self.blocks.append(b)

        def construct(self, x):
            out = x
            for i in range(3):
                out = self.blocks[i](out)
            return out

    x = Tensor(np.ones((8, 128, 16, 32)).astype(np.float32))
    net = Net()
    grad_net = Grad(net)
    grad_net(x)


def test_recompute_op_recompute2():
    """
    Feature: Recompute with lazy inline.
    Description: Each block is set to recompute via the primitive recompute API.
    Expectation: Runs successfully and the memory usage is reduced.
    """

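    # Same as test_recompute_op_recompute1, but almost every primitive in the
    # block (including the transposes and dropout) is marked, so nearly the
    # whole forward pass is recomputed.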
    class OuterBlock(Cell):
        @lazy_inline
        def __init__(self):
            super(OuterBlock, self).__init__()
            self.block = Block()
            self.block.transpose1.recompute()
            self.block.transpose2.recompute()
            self.block.real_div1.recompute()
            self.block.real_div2.recompute()
            self.block.batch_matmul1.recompute()
            self.block.add.recompute()
            self.block.softmax.recompute()
            self.block.dropout.recompute()

        def construct(self, x):
            return self.block(x)

    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.blocks = nn.CellList()
            for _ in range(3):
                b = OuterBlock()
                self.blocks.append(b)

        def construct(self, x):
            out = x
            for i in range(3):
                out = self.blocks[i](out)
            return out

    x = Tensor(np.ones((8, 128, 16, 32)).astype(np.float32))
    net = Net()
    grad_net = Grad(net)
    grad_net(x)


def test_recompute_op_recompute3():
    """
    Feature: Recompute with lazy inline.
    Description: Each block is set to recompute via the primitive recompute API.
    Expectation: Runs successfully and the memory usage is reduced.
    """

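    # Block1 differs from Block: the mask arithmetic (sub1/sub2/mul) is derived
    # from the input itself and feeds the softmax input directly; OuterBlock
    # below marks the primitives along that chain for recompute.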
    class Block1(Cell):
        def __init__(self):
            super(Block1, self).__init__()
            self.transpose1 = P.Transpose()
            self.transpose2 = P.Transpose()
            self.transpose3 = P.Transpose()
            self.transpose4 = P.Transpose()
            self.real_div1 = P.RealDiv()
            self.real_div2 = P.RealDiv()
            self.batch_matmul1 = P.BatchMatMul()
            self.batch_matmul2 = P.BatchMatMul()
            self.add = P.Add()
            self.softmax = P.Softmax(-1)
            self.dropout = P.Dropout(0.9)
            self.expand_dims = P.ExpandDims()
            self.sub1 = P.Sub()
            self.sub2 = P.Sub()
            self.mul = P.Mul()
            self.y = Parameter(Tensor(np.ones((8, 16, 128, 128)).astype(np.float32)))

        def construct(self, x):
            transpose1 = self.transpose1(x, (0, 2, 1, 3))
            real_div1 = self.real_div1(transpose1, Tensor(2.37891))
            sub1 = self.sub1(Tensor([1.0]), transpose1)
            sub2 = self.sub2(Tensor([1.0]), sub1)
            mul = self.mul(sub2, Tensor([-0.0001]))
            add = self.add(mul, real_div1)
            soft_max = self.softmax(add)
            dropout = self.dropout(soft_max)
            transpose3 = self.transpose3(x, (0, 2, 1, 3))
            batch_matmul2 = self.batch_matmul2(dropout[0], transpose3)
            transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3))
            return transpose4

    class OuterBlock(Cell):
        @lazy_inline
        def __init__(self):
            super(OuterBlock, self).__init__()
            self.block = Block1()
            self.block.mul.recompute()
            self.block.real_div1.recompute()
            self.block.transpose1.recompute()
            self.block.sub1.recompute()
            self.block.add.recompute()
            self.block.softmax.recompute()

        def construct(self, x):
            return self.block(x)

    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.blocks = nn.CellList()
            for _ in range(3):
                b = OuterBlock()
                self.blocks.append(b)

        def construct(self, x):
            out = x
            for i in range(3):
                out = self.blocks[i](out)
            return out

    x = Tensor(np.ones((8, 128, 16, 128)).astype(np.float32))
    net = Net()
    grad_net = Grad(net)
    grad_net(x)


def test_recompute_cell_and_op_recompute1():
    """
    Feature: Recompute with lazy inline.
    Description: Each block is set to recompute via both the primitive and the Cell recompute APIs.
    Expectation: Runs successfully and the memory usage is reduced.
    """

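    # Mixes the two APIs: Net1 is recomputed as a whole sub-cell via
    # Cell.recompute(), while selected primitives of Block1 are marked
    # individually via Primitive.recompute().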
    class Net1(Cell):
        def __init__(self):
            super(Net1, self).__init__()
            self.transpose2 = P.Transpose()
            self.real_div2 = P.RealDiv()

        def construct(self, x):
            transpose2 = self.transpose2(x, (0, 2, 3, 1))
            real_div2 = self.real_div2(transpose2, Tensor(2.37891))
            return real_div2

    class Block1(Cell):
        def __init__(self):
            super(Block1, self).__init__()
            self.transpose1 = P.Transpose()
            self.transpose2 = P.Transpose()
            self.transpose3 = P.Transpose()
            self.transpose4 = P.Transpose()
            self.real_div1 = P.RealDiv()
            self.real_div1.recompute()
            self.real_div2 = P.RealDiv()
            self.batch_matmul1 = P.BatchMatMul()
            self.batch_matmul1.recompute()
            self.batch_matmul2 = P.BatchMatMul()
            self.add = P.Add()
            self.add.recompute()
            self.softmax = P.Softmax(-1)
            self.softmax.recompute()
            self.dropout = P.Dropout(0.9)
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.net1 = Net1()
            self.net1.recompute()
            self.y = Parameter(Tensor(np.ones((8, 128, 128)).astype(np.float32)))

        def construct(self, x):
            transpose1 = self.transpose1(x, (0, 2, 1, 3))
            real_div1 = self.real_div1(transpose1, Tensor(2.37891))
            real_div2 = self.net1(x)
            batch_matmul1 = self.batch_matmul1(real_div1, real_div2)
            expand_dims = self.expand_dims(self.y, 1)
            sub = self.sub(Tensor([1.0]), expand_dims)
            mul = self.mul(sub, Tensor([-0.0001]))
            add = self.add(mul, batch_matmul1)
            soft_max = self.softmax(add)
            dropout = self.dropout(soft_max)
            transpose3 = self.transpose3(x, (0, 2, 1, 3))
            batch_matmul2 = self.batch_matmul2(dropout[0], transpose3)
            transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3))
            return transpose4

    class OuterBlock(Cell):
        @lazy_inline
        def __init__(self):
            super(OuterBlock, self).__init__()
            self.block = Block1()

        def construct(self, x):
            return self.block(x)

    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.blocks = nn.CellList()
            for _ in range(3):
                b = OuterBlock()
                self.blocks.append(b)

        def construct(self, x):
            out = x
            for i in range(3):
                out = self.blocks[i](out)
            return out

    x = Tensor(np.ones((8, 128, 16, 32)).astype(np.float32))
    net = Net()
    grad_net = Grad(net)
    grad_net(x)


def test_recompute_cell_and_op_recompute2():
    """
    Feature: Recompute with lazy inline.
    Description: Each block is set to recompute via both the primitive and the Cell recompute APIs.
    Expectation: Runs successfully and the memory usage is reduced.
    """

    class Net1(Cell):
        def __init__(self):
            super(Net1, self).__init__()
            self.transpose2 = P.Transpose()
            self.real_div2 = P.RealDiv()

        def construct(self, x):
            transpose2 = self.transpose2(x, (0, 2, 3, 1))
            real_div2 = self.real_div2(transpose2, Tensor(2.37891))
            return real_div2

    class Block1(Cell):
        def __init__(self):
            super(Block1, self).__init__()
            self.transpose1 = P.Transpose()
            self.transpose2 = P.Transpose()
            self.transpose3 = P.Transpose()
            self.transpose4 = P.Transpose()
            self.real_div1 = P.RealDiv()
            self.real_div1.recompute()
            self.real_div2 = P.RealDiv()
            self.batch_matmul1 = P.BatchMatMul()
            self.batch_matmul1.recompute()
            self.batch_matmul2 = P.BatchMatMul()
            self.add = P.Add()
            self.add.recompute()
            self.softmax = P.Softmax(-1)
            self.softmax.recompute()
            self.dropout = P.Dropout(0.9)
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.depend = ops.Depend()
            self.net1 = Net1()
            self.net1.recompute()
            self.y = Parameter(Tensor(np.ones((8, 128, 128)).astype(np.float32)))

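        # Depend returns its first input (x) while adding a control dependency
        # on the second (the recomputed sub-cell's output), pinning Net1's
        # execution before the ops that consume x.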
        def construct(self, x):
            real_div2 = self.net1(x)
            depend = self.depend(x, real_div2)
            transpose1 = self.transpose1(depend, (0, 2, 1, 3))
            real_div1 = self.real_div1(transpose1, Tensor(2.37891))
            batch_matmul1 = self.batch_matmul1(real_div1, real_div2)
            expand_dims = self.expand_dims(self.y, 1)
            sub = self.sub(Tensor([1.0]), expand_dims)
            mul = self.mul(sub, Tensor([-0.0001]))
            add = self.add(mul, batch_matmul1)
            soft_max = self.softmax(add)
            dropout = self.dropout(soft_max)
            transpose3 = self.transpose3(x, (0, 2, 1, 3))
            batch_matmul2 = self.batch_matmul2(dropout[0], transpose3)
            transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3))
            return transpose4

    class OuterBlock(Cell):
        @lazy_inline
        def __init__(self):
            super(OuterBlock, self).__init__()
            self.block = Block1()

        def construct(self, x):
            return self.block(x)

    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.blocks = nn.CellList()
            for _ in range(3):
                b = OuterBlock()
                self.blocks.append(b)

        def construct(self, x):
            out = x
            for i in range(3):
                out = self.blocks[i](out)
            return out

    x = Tensor(np.ones((8, 128, 16, 32)).astype(np.float32))
    net = Net()
    grad_net = Grad(net)
    grad_net(x)


def test_recompute_cell_and_op_recompute_with_tuple_outputs1():
    """
    Feature: Recompute with lazy inline.
    Description: Each block is set to recompute via both the primitive and the Cell recompute APIs, and the block returns a tuple.
    Expectation: Runs successfully and the memory usage is reduced.
    """

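    # Same mixed cell/primitive recompute as above, but Block1 returns the
    # tuple (add1, transpose4), so recompute must handle multiple outputs.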
    class Net1(Cell):
        def __init__(self):
            super(Net1, self).__init__()
            self.transpose2 = P.Transpose()
            self.real_div2 = P.RealDiv()

        def construct(self, x):
            transpose2 = self.transpose2(x, (0, 2, 3, 1))
            real_div2 = self.real_div2(transpose2, Tensor(2.37891))
            return real_div2

    class Block1(Cell):
        def __init__(self):
            super(Block1, self).__init__()
            self.transpose1 = P.Transpose()
            self.transpose2 = P.Transpose()
            self.transpose3 = P.Transpose()
            self.transpose4 = P.Transpose()
            self.transpose4.recompute()
            self.real_div1 = P.RealDiv()
            self.real_div1.recompute()
            self.real_div2 = P.RealDiv()
            self.batch_matmul1 = P.BatchMatMul()
            self.batch_matmul1.recompute()
            self.batch_matmul2 = P.BatchMatMul()
            self.add = P.Add()
            self.add.recompute()
            self.add1 = P.Add()
            self.softmax = P.Softmax(-1)
            self.softmax.recompute()
            self.dropout = P.Dropout(0.9)
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.net1 = Net1()
            self.net1.recompute()
            self.y = Parameter(Tensor(np.ones((8, 128, 128)).astype(np.float32)))

        def construct(self, x, z):
            transpose1 = self.transpose1(x, (0, 2, 1, 3))
            real_div1 = self.real_div1(transpose1, Tensor(2.37891))
            real_div2 = self.net1(x)
            batch_matmul1 = self.batch_matmul1(real_div1, real_div2)
            expand_dims = self.expand_dims(self.y, 1)
            sub = self.sub(Tensor([1.0]), expand_dims)
            mul = self.mul(sub, Tensor([-0.0001]))
            add = self.add(mul, batch_matmul1)
            soft_max = self.softmax(add)
            dropout = self.dropout(soft_max)
            transpose3 = self.transpose3(x, (0, 2, 1, 3))
            batch_matmul2 = self.batch_matmul2(dropout[0], transpose3)
            transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3))
            add1 = self.add1(transpose4, z)
            return add1, transpose4

    class OuterBlock(Cell):
        @lazy_inline
        def __init__(self):
            super(OuterBlock, self).__init__()
            self.block = Block1()

        def construct(self, x, z):
            return self.block(x, z)

    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.blocks = nn.CellList()
            for _ in range(3):
                b = OuterBlock()
                self.blocks.append(b)

        def construct(self, x):
            out1, out2 = x, x
            for i in range(3):
                out1, out2 = self.blocks[i](out1, out2)
            return out1, out2

    x = Tensor(np.ones((8, 128, 16, 32)).astype(np.float32))
    net = Net()
    grad_net = Grad(net)
    grad_net(x)


def test_recompute_cell_and_op_recompute_with_tuple_outputs2():
    """
    Feature: Recompute with lazy inline.
    Description: Each block is set to recompute via both the primitive and the Cell recompute APIs, and the block returns a tuple.
    Expectation: Runs successfully and the memory usage is reduced.
    """

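    # Variant of the previous tuple-output test in which add1, the producer of
    # the first tuple element, is marked for recompute as well.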
    class Net1(Cell):
        def __init__(self):
            super(Net1, self).__init__()
            self.transpose2 = P.Transpose()
            self.real_div2 = P.RealDiv()

        def construct(self, x):
            transpose2 = self.transpose2(x, (0, 2, 3, 1))
            real_div2 = self.real_div2(transpose2, Tensor(2.37891))
            return real_div2

    class Block1(Cell):
        def __init__(self):
            super(Block1, self).__init__()
            self.transpose1 = P.Transpose()
            self.transpose2 = P.Transpose()
            self.transpose3 = P.Transpose()
            self.transpose4 = P.Transpose()
            self.transpose4.recompute()
            self.real_div1 = P.RealDiv()
            self.real_div1.recompute()
            self.real_div2 = P.RealDiv()
            self.batch_matmul1 = P.BatchMatMul()
            self.batch_matmul1.recompute()
            self.batch_matmul2 = P.BatchMatMul()
            self.add = P.Add()
            self.add.recompute()
            self.add1 = P.Add()
            self.add1.recompute()
            self.softmax = P.Softmax(-1)
            self.softmax.recompute()
            self.dropout = P.Dropout(0.9)
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.net1 = Net1()
            self.net1.recompute()
            self.y = Parameter(Tensor(np.ones((8, 128, 128)).astype(np.float32)))

        def construct(self, x, z):
            transpose1 = self.transpose1(x, (0, 2, 1, 3))
            real_div1 = self.real_div1(transpose1, Tensor(2.37891))
            real_div2 = self.net1(x)
            batch_matmul1 = self.batch_matmul1(real_div1, real_div2)
            expand_dims = self.expand_dims(self.y, 1)
            sub = self.sub(Tensor([1.0]), expand_dims)
            mul = self.mul(sub, Tensor([-0.0001]))
            add = self.add(mul, batch_matmul1)
            soft_max = self.softmax(add)
            dropout = self.dropout(soft_max)
            transpose3 = self.transpose3(x, (0, 2, 1, 3))
            batch_matmul2 = self.batch_matmul2(dropout[0], transpose3)
            transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3))
            add1 = self.add1(transpose4, z)
            return add1, transpose4

    class OuterBlock(Cell):
        @lazy_inline
        def __init__(self):
            super(OuterBlock, self).__init__()
            self.block = Block1()

        def construct(self, x, z):
            return self.block(x, z)

    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.blocks = nn.CellList()
            for _ in range(3):
                b = OuterBlock()
                self.blocks.append(b)

        def construct(self, x):
            out1, out2 = x, x
            for i in range(3):
                out1, out2 = self.blocks[i](out1, out2)
            return out1, out2

    x = Tensor(np.ones((8, 128, 16, 32)).astype(np.float32))
    net = Net()
    grad_net = Grad(net)
    grad_net(x)