• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""SummaryCollector scripts without main function."""
16import os
17import re
18from collections import Counter
19import tempfile
20import shutil
21import sys
22sys.path.append('../../../../')
23from mindspore import nn, Tensor, context
24from mindspore.common.initializer import Normal
25from mindspore.train import Loss
26from mindspore.nn.optim import Momentum
27from mindspore.ops import operations as P
28from mindspore.train import Model
29from mindspore import SummaryCollector
30from mindspore.communication import init, get_rank
31import mindspore as ms
32from tests.st.summary.dataset import create_mnist_dataset
33from tests.summary_utils import SummaryReader
34
35
# Select graph mode and initialise the communication layer before asking for
# this process's rank; the rank is used to name the per-rank summary directory.
context.set_context(mode=ms.GRAPH_MODE)
init()
rank_id = get_rank()
# Temporary root for all summary output. run_network() writes into a
# sub-directory of it, and the script removes it once the checks are done.
base_summary_dir = tempfile.mkdtemp(suffix='summary')
40
class LeNet5(nn.Cell):
    """
    LeNet-5 style network instrumented with summary operators.

    In addition to the convolution, pooling and dense layers, the forward
    pass records one image, one histogram, one tensor and one scalar summary
    under the tags 'image', 'histogram', 'tensor' and 'scalar'.

    Args:
        num_class (int): Number of classes. Default: 10.
        num_channel (int): Number of input channels. Default: 1.
        include_top (bool): Whether to build and apply the flatten and dense
            layers. Default: True.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet5(num_class=10)

    """

    def __init__(self, num_class=10, num_channel=1, include_top=True):
        super().__init__()
        # Feature extractor: two 5x5 'valid' convolutions sharing one ReLU and
        # one 2x2 max-pooling cell.
        self.conv1 = nn.Conv2d(in_channels=num_channel, out_channels=6, kernel_size=5,
                               pad_mode='valid', weight_init="normal", bias_init="zeros")
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5,
                               pad_mode='valid', weight_init="normal", bias_init="zeros")
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classifier head, only created when requested.
        self.include_top = include_top
        if include_top:
            self.flatten = nn.Flatten()
            self.fc1 = nn.Dense(in_channels=16 * 5 * 5, out_channels=120,
                                weight_init=Normal(0.02), bias_init="zeros")
            self.fc2 = nn.Dense(in_channels=120, out_channels=84,
                                weight_init=Normal(0.02), bias_init="zeros")
            self.fc3 = nn.Dense(in_channels=84, out_channels=num_class,
                                weight_init=Normal(0.02), bias_init="zeros")

        # Summary operators used inside construct().
        self.scalar_summary = P.ScalarSummary()
        self.image_summary = P.ImageSummary()
        self.histogram_summary = P.HistogramSummary()
        self.tensor_summary = P.TensorSummary()
        # Value recorded by the scalar summary.
        self.channel = Tensor(num_channel)

    def construct(self, x):
        """Run the forward pass, recording the four summaries along the way."""
        # Raw input is logged as the image summary.
        self.image_summary('image', x)

        out = self.conv1(x)
        self.histogram_summary('histogram', out)
        out = self.relu(out)
        self.tensor_summary('tensor', out)
        # ReLU is applied a second time here, exactly as in the original flow.
        out = self.relu(out)
        out = self.max_pool2d(out)
        self.scalar_summary('scalar', self.channel)

        out = self.max_pool2d(self.relu(self.conv2(out)))

        if self.include_top:
            out = self.flatten(out)
            out = self.fc1(out)
            out = self.relu(out)
            out = self.fc2(out)
            out = self.relu(out)
            out = self.fc3(out)
        return out
95
96
def run_network(dataset_sink_mode=False, num_samples=2, **kwargs):
    """Train LeNet5 for one epoch, evaluate it, and return the summary directory.

    Args:
        dataset_sink_mode (bool): Forwarded to both ``model.train`` and
            ``model.eval``. Default: False.
        num_samples (int): Number of samples requested for the training
            dataset. Default: 2.
        **kwargs: Extra keyword arguments forwarded to ``SummaryCollector``.

    Returns:
        str, the directory passed to the ``SummaryCollector``.
    """
    network = LeNet5()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    optimizer = Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = Model(network, loss_fn=loss_fn, optimizer=optimizer, metrics={'loss': Loss()})

    # Per-rank sub-directory of the module-level temporary directory.
    summary_dir = f"{base_summary_dir}/summary_{rank_id}"
    collector = SummaryCollector(summary_dir=summary_dir, collect_freq=2, **kwargs)

    # The same collector instance observes both the training and the eval run.
    train_dataset = create_mnist_dataset("train", num_samples=num_samples)
    model.train(1, train_dataset, callbacks=[collector], dataset_sink_mode=dataset_sink_mode)

    eval_dataset = create_mnist_dataset("test")
    model.eval(eval_dataset, dataset_sink_mode=dataset_sink_mode, callbacks=[collector])

    return summary_dir
112
113
def list_summary_tags(summary_dir):
    """Return the tag of every summary value stored in a summary file.

    The file read is the first entry returned by ``os.listdir(summary_dir)``
    whose name contains "_MS".

    Args:
        summary_dir (str): Directory that holds the summary file.

    Returns:
        list, one tag per summary value in the order the events are read;
        repeated tags are kept.

    Raises:
        AssertionError: If no entry containing "_MS" exists in ``summary_dir``.
    """
    candidates = (entry for entry in os.listdir(summary_dir) if "_MS" in entry)
    first_match = next(candidates, None)
    summary_file_path = '' if first_match is None else os.path.join(summary_dir, first_match)
    assert summary_file_path

    tags = []
    with SummaryReader(summary_file_path) as reader:
        # read_event() yields a falsy value once the file is exhausted.
        event = reader.read_event()
        while event:
            tags.extend(item.tag for item in event.summary.value)
            event = reader.read_event()
    return tags
133
134
# Run the training and all checks inside try/finally so the temporary summary
# directory is removed even when the run or one of the assertions fails.
try:
    summary_path = run_network(num_samples=10)

    tag_list = list_summary_tags(summary_path)

    # Every tag that must appear in the summary file, and nothing else.
    expected_tag_set = {'conv1.weight/auto', 'conv2.weight/auto', 'fc1.weight/auto', 'fc1.bias/auto',
                        'fc2.weight/auto', 'input_data/auto', 'loss/auto',
                        'histogram', 'image', 'scalar', 'tensor'}
    assert expected_tag_set == set(tag_list)

    # num_samples is 10 and the batch size is 2, so there are 5 steps; with
    # collect_freq=2 SummaryCollector collects the 1st, 2nd and 4th steps,
    # hence every tag is expected exactly three times.
    tag_count = 3
    for count in Counter(tag_list).values():
        assert count == tag_count
finally:
    if os.path.exists(base_summary_dir):
        shutil.rmtree(base_summary_dir)
152