# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SummaryCollector scripts without main function.

The checks run at module top level: each process trains and evaluates a small
LeNet5 with a SummaryCollector attached, reads back the summary file it wrote,
and asserts on the recorded tags and how often each one was collected.
"""
import os
import re
from collections import Counter
import tempfile
import shutil
import sys
sys.path.append('../../../../')
from mindspore import nn, Tensor, context
from mindspore.common.initializer import Normal
from mindspore.train import Loss
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore import SummaryCollector
from mindspore.communication import init, get_rank
import mindspore as ms
from tests.st.summary.dataset import create_mnist_dataset
from tests.summary_utils import SummaryReader


context.set_context(mode=ms.GRAPH_MODE)
# Initialise the communication backend so get_rank() identifies this process.
init()
rank_id = get_rank()
# Private temporary root for this process; removed at the end of the script.
base_summary_dir = tempfile.mkdtemp(suffix='summary')


class LeNet5(nn.Cell):
    """
    LeNet-5 network that also records image, histogram, tensor and scalar summaries.

    Args:
        num_class (int): Number of classes. Default: 10.
        num_channel (int): Number of channels. Default: 1.
        include_top (bool): Whether to append the flatten + fully connected
            classifier head. When False, ``construct`` returns the pooled
            features of the second conv layer. Default: True.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet5(num_class=10)
    """

    def __init__(self, num_class=10, num_channel=1, include_top=True):
        super().__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid', weight_init="normal", bias_init="zeros")
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid', weight_init="normal", bias_init="zeros")
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.include_top = include_top
        if self.include_top:
            self.flatten = nn.Flatten()
            self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02), bias_init="zeros")
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02), bias_init="zeros")
            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02), bias_init="zeros")

        # Summary operators; their tags ('image', 'histogram', 'tensor',
        # 'scalar') are asserted on at the end of this script.
        self.scalar_summary = P.ScalarSummary()
        self.image_summary = P.ImageSummary()
        self.histogram_summary = P.HistogramSummary()
        self.tensor_summary = P.TensorSummary()
        self.channel = Tensor(num_channel)

    def construct(self, x):
        """construct."""
        self.image_summary('image', x)
        x = self.conv1(x)
        self.histogram_summary('histogram', x)
        x = self.relu(x)
        self.tensor_summary('tensor', x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        self.scalar_summary('scalar', self.channel)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        if not self.include_top:
            return x
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def run_network(dataset_sink_mode=False, num_samples=2, **kwargs):
    """Train LeNet5 for one epoch, then evaluate it, with a SummaryCollector attached.

    Args:
        dataset_sink_mode (bool): Forwarded to both ``Model.train`` and ``Model.eval``.
        num_samples (int): Number of samples requested from the training dataset.
        **kwargs: Extra keyword arguments forwarded to ``SummaryCollector``.

    Returns:
        str, the directory this rank's SummaryCollector wrote into.
    """
    lenet = LeNet5()
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    optim = Momentum(lenet.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = Model(lenet, loss_fn=loss, optimizer=optim, metrics={'loss': Loss()})
    # The rank id is part of the path, so each rank gets its own summary directory.
    summary_dir = os.path.join(base_summary_dir, 'summary_' + str(rank_id))
    summary_collector = SummaryCollector(summary_dir=summary_dir, collect_freq=2, **kwargs)

    ds_train = create_mnist_dataset("train", num_samples=num_samples)
    model.train(1, ds_train, callbacks=[summary_collector], dataset_sink_mode=dataset_sink_mode)

    ds_eval = create_mnist_dataset("test")
    model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode, callbacks=[summary_collector])
    return summary_dir


def list_summary_tags(summary_dir):
    """Return every value tag, in read order, from the first ``_MS`` file in ``summary_dir``.

    Args:
        summary_dir (str): Directory to search for a summary file.

    Returns:
        list[str], one entry per recorded summary value (tags repeat once per
        collection).

    Raises:
        AssertionError: If no file whose name contains ``_MS`` is found.
    """
    summary_file_path = ''
    for file_name in os.listdir(summary_dir):
        if re.search(r"_MS", file_name):
            summary_file_path = os.path.join(summary_dir, file_name)
            break
    assert summary_file_path, "no summary file found in " + summary_dir

    tags = []
    with SummaryReader(summary_file_path) as summary_reader:
        while True:
            summary_event = summary_reader.read_event()
            if not summary_event:
                break
            tags.extend(value.tag for value in summary_event.summary.value)
    return tags


try:
    summary_path = run_network(num_samples=10)

    tag_list = list_summary_tags(summary_path)

    expected_tag_set = {'conv1.weight/auto', 'conv2.weight/auto', 'fc1.weight/auto', 'fc1.bias/auto',
                        'fc2.weight/auto', 'input_data/auto', 'loss/auto',
                        'histogram', 'image', 'scalar', 'tensor'}
    assert expected_tag_set == set(tag_list), \
        "unexpected summary tags: %s" % sorted(set(tag_list) ^ expected_tag_set)

    # num samples is 10, batch size is 2, so step is 5, collect freq is 2,
    # SummaryCollector will collect the first step and the 2nd and 4th steps
    tag_count = 3
    for tag, count in Counter(tag_list).items():
        assert count == tag_count, "tag %r collected %d times, expected %d" % (tag, count, tag_count)
finally:
    # Always remove the temporary directory, even when a check above fails.
    if os.path.exists(base_summary_dir):
        shutil.rmtree(base_summary_dir)