# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test summary ops."""
import os
import shutil
import tempfile

import numpy as np
import pytest

from mindspore import nn, Tensor, context
from mindspore.common.initializer import Normal
from mindspore.nn.metrics import Loss
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.train.summary.summary_record import _get_summary_tensor_data
from tests.st.summary.dataset import create_mnist_dataset
from tests.security_utils import security_off_wrap


class LeNet5(nn.Cell):
    """LeNet network"""

    def __init__(self, num_class=10, num_channel=1, include_top=True):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.include_top = include_top
        if self.include_top:
            self.flatten = nn.Flatten()
            self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))

        self.scalar_summary = P.ScalarSummary()
        self.image_summary = P.ImageSummary()
        self.tensor_summary = P.TensorSummary()
        self.channel = Tensor(num_channel)

    def construct(self, x):
        """construct"""
        self.image_summary('x', x)
        self.tensor_summary('x', x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        if not self.include_top:
            return x
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        self.scalar_summary('x_fc3', x[0][0])
        return x


class TestSummaryOps:
    """Test summary ops."""
    base_summary_dir = ''

    @classmethod
    def setup_class(cls):
        """Run before tests in this class."""
        device_id = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
        context.set_context(mode=context.GRAPH_MODE, device_id=device_id)
        cls.base_summary_dir = tempfile.mkdtemp(suffix='summary')

    @classmethod
    def teardown_class(cls):
        """Run after tests in this class."""
        if os.path.exists(cls.base_summary_dir):
            shutil.rmtree(cls.base_summary_dir)

    @pytest.mark.level0
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.env_onecard
    @security_off_wrap
    def test_summary_ops(self):
        """Test summary operators."""
        ds_train = create_mnist_dataset('train', num_samples=1, batch_size=1)
        ds_train_iter = ds_train.create_dict_iterator()
        expected_data = next(ds_train_iter)['image'].asnumpy()

        net = LeNet5()
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
        optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        model = Model(net, loss_fn=loss, optimizer=optim, metrics={'loss': Loss()})
        model.train(1, ds_train, dataset_sink_mode=False)

        summary_data = _get_summary_tensor_data()
        image_data = summary_data['x[:Image]'].asnumpy()
        tensor_data = summary_data['x[:Tensor]'].asnumpy()
        x_fc3 = summary_data['x_fc3[:Scalar]'].asnumpy()

        assert np.allclose(expected_data, image_data)
        assert np.allclose(expected_data, tensor_data)
        assert not np.allclose(0, x_fc3)
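
# How the assertions map to the summary ops above: the dict returned by
# _get_summary_tensor_data() is keyed as '<tag>[:<summary type>]', so the
# ImageSummary/TensorSummary calls recorded under tag 'x' in LeNet5.construct
# are read back as 'x[:Image]' and 'x[:Tensor]', and the ScalarSummary under
# 'x_fc3' as 'x_fc3[:Scalar]'.
#
# A minimal sketch for running this test alone on one device; the file path is
# an assumption and may differ in your checkout:
#   DEVICE_ID=0 pytest tests/st/summary/test_summary_ops.py::TestSummaryOps::test_summary_ops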