• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Test histogram summary."""
16
17import logging
18import os
19import tempfile
20import numpy as np
21
22from mindspore.common.tensor import Tensor
23from mindspore.train.summary._summary_adapter import _calc_histogram_bins
24from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
25from tests.summary_utils import SummaryReader
26from tests.security_utils import security_off_wrap
27
# Working directory captured at import time; base directory for SUMMARY_DIR.
CUR_DIR = os.getcwd()
# Scratch directory for summary event files, located under CUR_DIR.
# The child component must be relative: os.path.join discards every earlier
# component when it meets an absolute one, which previously made this resolve
# to "/test_temp_summary_event_file/" at the filesystem root. The trailing ""
# keeps the trailing path separator the original value carried.
SUMMARY_DIR = os.path.join(CUR_DIR, "test_temp_summary_event_file", "")

# Module logger limited to ERROR and above, so LOG.debug(...) output is
# suppressed unless the level is lowered.
LOG = logging.getLogger("test")
LOG.setLevel(level=logging.ERROR)
33
34
def _wrap_test_data(input_data: Tensor):
    """
    Wraps test data to summary format.

    The entry is tagged ``test_data[:Histogram]``; the ``[:Histogram]`` suffix
    presumably selects the histogram summary type (the tests in this module
    read the recorded value back through its ``histogram`` field).

    Args:
        input_data (Tensor): Input data.

    Returns:
        list[dict], a single-element list holding one ``{"name": ..., "data": ...}``
        entry, in the form the tests pass to ``_cache_summary_tensor_data``.
    """

    return [{
        "name": "test_data[:Histogram]",
        "data": input_data
    }]
50
51
@security_off_wrap
def test_histogram_summary():
    """Record one 2x3 tensor as a histogram and check that all six values are counted."""
    with tempfile.TemporaryDirectory() as workdir:
        with SummaryRecord(workdir, file_suffix="_MS_HISTOGRAM") as recorder:
            _cache_summary_tensor_data(_wrap_test_data(Tensor([[1, 2, 3], [4, 5, 6]])))
            recorder.record(step=1)
        # The writer is closed at this point, so its event file is complete.
        log_path = os.path.realpath(recorder.log_dir)
        with SummaryReader(log_path) as summary_reader:
            histogram = summary_reader.read_event().summary.value[0].histogram
            assert histogram.count == 6
64
65
@security_off_wrap
def test_histogram_multi_summary():
    """Record several steps of random data and check each event keeps every sample."""
    num_step = 5
    size = 50
    with tempfile.TemporaryDirectory() as workdir:
        with SummaryRecord(workdir, file_suffix="_MS_HISTOGRAM") as recorder:
            # Seeded generator keeps the recorded data reproducible.
            generator = np.random.RandomState(10)
            for step in range(num_step):
                samples = generator.normal(size=size)
                _cache_summary_tensor_data(_wrap_test_data(Tensor(samples)))
                recorder.record(step=step)

        log_path = os.path.realpath(recorder.log_dir)
        with SummaryReader(log_path) as summary_reader:
            # One event per recorded step, read back in order.
            for _ in range(num_step):
                histogram = summary_reader.read_event().summary.value[0].histogram
                assert histogram.count == size
88
89
@security_off_wrap
def test_histogram_summary_empty_tensor():
    """Record an empty tensor as a histogram and check that the stored count is zero."""
    with tempfile.TemporaryDirectory() as workdir:
        with SummaryRecord(workdir, file_suffix="_MS_HISTOGRAM") as recorder:
            empty = Tensor([])
            _cache_summary_tensor_data(_wrap_test_data(empty))
            recorder.record(step=1)

        log_path = os.path.realpath(recorder.log_dir)
        with SummaryReader(log_path) as summary_reader:
            histogram = summary_reader.read_event().summary.value[0].histogram
            assert histogram.count == 0
103
104
@security_off_wrap
def test_histogram_summary_same_value():
    """Record a constant (all-ones) tensor and check the number of histogram buckets."""
    rows, cols = 100, 100
    with tempfile.TemporaryDirectory() as workdir:
        with SummaryRecord(workdir, file_suffix="_MS_HISTOGRAM") as recorder:
            ones = Tensor(np.ones([rows, cols]))
            _cache_summary_tensor_data(_wrap_test_data(ones))
            recorder.record(step=1)

        log_path = os.path.realpath(recorder.log_dir)
        with SummaryReader(log_path) as summary_reader:
            event = summary_reader.read_event()
            LOG.debug(event)

            # Bucket count must follow the same rule the adapter uses for this many elements.
            expected_bins = _calc_histogram_bins(rows * cols)
            assert len(event.summary.value[0].histogram.buckets) == expected_bins
123
124
@security_off_wrap
def test_histogram_summary_high_dims():
    """Record a random 4-D tensor and check that every element is counted."""
    dim = 10
    with tempfile.TemporaryDirectory() as workdir:
        with SummaryRecord(workdir, file_suffix="_MS_HISTOGRAM") as recorder:
            samples = np.random.RandomState(0).normal(size=[dim] * 4)
            _cache_summary_tensor_data(_wrap_test_data(Tensor(samples)))
            recorder.record(step=1)

        log_path = os.path.realpath(recorder.log_dir)
        with SummaryReader(log_path) as summary_reader:
            event = summary_reader.read_event()
            LOG.debug(event)

            assert event.summary.value[0].histogram.count == samples.size
144
145
@security_off_wrap
def test_histogram_summary_nan_inf():
    """Test histogram summary, input tensor mixes finite values with one nan, one +inf and one -inf."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
            dim1 = 100
            dim2 = 100

            arr = np.ones([dim1, dim2])
            # Inject exactly one value of each invalid kind; the rest stay finite.
            arr[0][0] = np.nan
            arr[0][1] = np.inf
            arr[0][2] = -np.inf
            test_data = _wrap_test_data(Tensor(arr))

            _cache_summary_tensor_data(test_data)
            test_writer.record(step=1)

        file_name = os.path.realpath(test_writer.log_dir)
        with SummaryReader(file_name) as reader:
            event = reader.read_event()
            LOG.debug(event)

            histogram = event.summary.value[0].histogram
            # Every injected invalid value must be tallied in its own counter,
            # not just the nan one.
            assert histogram.nan_count == 1
            assert histogram.pos_inf_count == 1
            assert histogram.neg_inf_count == 1
169
170
@security_off_wrap
def test_histogram_summary_all_nan_inf():
    """Record a tensor with no finite values and check that each invalid kind is tallied."""
    invalid_values = np.array([np.nan, np.nan, np.nan, np.inf, -np.inf])
    with tempfile.TemporaryDirectory() as workdir:
        with SummaryRecord(workdir, file_suffix="_MS_HISTOGRAM") as recorder:
            _cache_summary_tensor_data(_wrap_test_data(Tensor(invalid_values)))
            recorder.record(step=1)

        log_path = os.path.realpath(recorder.log_dir)
        with SummaryReader(log_path) as summary_reader:
            event = summary_reader.read_event()
            LOG.debug(event)

            histogram = event.summary.value[0].histogram
            # Expected tally per counter field, checked in declaration order.
            expected = {"nan_count": 3, "pos_inf_count": 1, "neg_inf_count": 1}
            for field, count in expected.items():
                assert getattr(histogram, field) == count
189