• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test summary ops."""
16import os
17import shutil
18import tempfile
19
20import numpy as np
21import pytest
22
23from mindspore import nn, Tensor, context
24from mindspore.common.initializer import Normal
25from mindspore.nn.metrics import Loss
26from mindspore.nn.optim import Momentum
27from mindspore.ops import operations as P
28from mindspore.train import Model
29from mindspore.train.summary.summary_record import _get_summary_tensor_data
30from tests.st.summary.dataset import create_mnist_dataset
31from tests.security_utils import security_off_wrap
32
33
class LeNet5(nn.Cell):
    """LeNet-5 style network that also emits summary data during its forward pass.

    Args:
        num_class (int): Number of logits produced by the last dense layer. Default: 10.
        num_channel (int): Number of input channels consumed by the first convolution. Default: 1.
        include_top (bool): If True, a flatten + three dense layers head is built and
            ``construct`` returns class logits; if False, ``construct`` returns the
            feature map after the second pooling stage. Default: True.
    """

    def __init__(self, num_class=10, num_channel=1, include_top=True):
        super().__init__()
        # Feature extractor: two valid-padded 5x5 convolutions, shared ReLU and 2x2/stride-2 pooling.
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

        self.include_top = include_top
        if self.include_top:
            self.flatten = nn.Flatten()
            # 16 * 5 * 5 matches a 32x32 input (32 -> 28 -> 14 -> 10 -> 5);
            # presumably the dataset yields 32x32 images -- confirm.
            # A fresh Normal(0.02) initializer is built for each layer on purpose.
            self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))

        # Summary primitives used inside construct().
        self.scalar_summary = P.ScalarSummary()
        self.image_summary = P.ImageSummary()
        self.tensor_summary = P.TensorSummary()
        # NOTE(review): not read by construct(); kept so the attribute set is unchanged.
        self.channel = Tensor(num_channel)

    def construct(self, x):
        """Record the input as image/tensor summaries, then run the forward pass."""
        self.image_summary('x', x)
        self.tensor_summary('x', x)

        features = self.max_pool2d(self.relu(self.conv1(x)))
        features = self.max_pool2d(self.relu(self.conv2(features)))
        if not self.include_top:
            return features

        hidden = self.flatten(features)
        hidden = self.relu(self.fc1(hidden))
        hidden = self.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        # Record element [0][0] of the final layer output as a scalar summary.
        self.scalar_summary('x_fc3', logits[0][0])
        return logits
73
74
class TestSummaryOps:
    """System test checking the values captured by the summary operators."""
    # Scratch directory: created in setup_class, removed in teardown_class.
    # NOTE(review): test_summary_ops itself does not use this directory.
    base_summary_dir = ''

    @classmethod
    def setup_class(cls):
        """Switch to graph mode on the chosen device and create a scratch directory."""
        # An unset or empty DEVICE_ID falls back to device 0.
        env_device = os.getenv('DEVICE_ID')
        device_id = int(env_device) if env_device else 0
        context.set_context(mode=context.GRAPH_MODE, device_id=device_id)
        cls.base_summary_dir = tempfile.mkdtemp(suffix='summary')

    @classmethod
    def teardown_class(cls):
        """Delete the scratch directory if it exists."""
        summary_dir = cls.base_summary_dir
        if not os.path.exists(summary_dir):
            return
        shutil.rmtree(summary_dir)

    @pytest.mark.level0
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.env_onecard
    @security_off_wrap
    def test_summary_ops(self):
        """Train one epoch and compare the recorded summaries with the input image.

        The dataset is requested with num_samples=1 and batch_size=1.
        """
        dataset = create_mnist_dataset('train', num_samples=1, batch_size=1)
        # Keep the iterator bound for the whole test, as the original did.
        dict_iter = dataset.create_dict_iterator()
        first_image = next(dict_iter)['image'].asnumpy()

        network = LeNet5()
        loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
        optimizer = Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
        model = Model(network, loss_fn=loss_fn, optimizer=optimizer, metrics={'loss': Loss()})
        model.train(1, dataset, dataset_sink_mode=False)

        recorded = _get_summary_tensor_data()
        image_summary = recorded['x[:Image]'].asnumpy()
        tensor_summary = recorded['x[:Tensor]'].asnumpy()
        fc3_scalar = recorded['x_fc3[:Scalar]'].asnumpy()

        # Both input summaries must reproduce the image that was fed in.
        assert np.allclose(first_image, image_summary)
        assert np.allclose(first_image, tensor_summary)
        # The recorded fc3 scalar must not be (close to) zero.
        assert not np.allclose(0, fc3_scalar)
118