# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test histogram summary."""

import logging
import os
import tempfile

import numpy as np

from mindspore.common.tensor import Tensor
from mindspore.train.summary._summary_adapter import _calc_histogram_bins
from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
from tests.summary_utils import SummaryReader
from tests.security_utils import security_off_wrap

CUR_DIR = os.getcwd()
# The sub-directory name must be relative: os.path.join discards every earlier
# component once it meets an absolute one, so a leading "/" here would make the
# result point at the filesystem root instead of a directory under CUR_DIR.
SUMMARY_DIR = os.path.join(CUR_DIR, "test_temp_summary_event_file/")

LOG = logging.getLogger("test")
LOG.setLevel(level=logging.ERROR)


def _wrap_test_data(input_data: Tensor) -> list:
    """
    Wraps test data to summary format.

    Args:
        input_data (Tensor): Input data.

    Returns:
        list[dict], a single-element list whose entry holds the summary tag
        ("test_data[:Histogram]") under "name" and the tensor under "data".
    """

    return [{
        "name": "test_data[:Histogram]",
        "data": input_data
    }]


@security_off_wrap
def test_histogram_summary():
    """Test histogram summary."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
            test_data = _wrap_test_data(Tensor([[1, 2, 3], [4, 5, 6]]))
            _cache_summary_tensor_data(test_data)
            test_writer.record(step=1)

        # Read only after the writer's context has exited, like the other
        # tests here, so the recorded events are presumably flushed to disk.
        file_name = os.path.realpath(test_writer.log_dir)
        with SummaryReader(file_name) as reader:
            event = reader.read_event()
            assert event.summary.value[0].histogram.count == 6


@security_off_wrap
def test_histogram_multi_summary():
    """Test histogram multiple step."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:

            rng = np.random.RandomState(10)
            size = 50
            num_step = 5

            for i in range(num_step):
                arr = rng.normal(size=size)

                test_data = _wrap_test_data(Tensor(arr))
                _cache_summary_tensor_data(test_data)
                test_writer.record(step=i)

        file_name = os.path.realpath(test_writer.log_dir)
        with SummaryReader(file_name) as reader:
            # One event is expected per recorded step.
            for _ in range(num_step):
                event = reader.read_event()
                assert event.summary.value[0].histogram.count == size


@security_off_wrap
def test_histogram_summary_empty_tensor():
    """Test histogram summary, input is an empty tensor."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
            test_data = _wrap_test_data(Tensor([]))
            _cache_summary_tensor_data(test_data)
            test_writer.record(step=1)

        file_name = os.path.realpath(test_writer.log_dir)
        with SummaryReader(file_name) as reader:
            event = reader.read_event()
            assert event.summary.value[0].histogram.count == 0


@security_off_wrap
def test_histogram_summary_same_value():
    """Test histogram summary, input is a tensor of ones."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
            dim1 = 100
            dim2 = 100

            test_data = _wrap_test_data(Tensor(np.ones([dim1, dim2])))
            _cache_summary_tensor_data(test_data)
            test_writer.record(step=1)

        file_name = os.path.realpath(test_writer.log_dir)
        with SummaryReader(file_name) as reader:
            event = reader.read_event()
            LOG.debug(event)

            # Even with a single distinct value, the bucket count must match
            # what the adapter computes for this many elements.
            assert len(event.summary.value[0].histogram.buckets) == _calc_histogram_bins(dim1 * dim2)


@security_off_wrap
def test_histogram_summary_high_dims():
    """Test histogram summary, input is a 4-dimension tensor."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
            dim = 10

            rng = np.random.RandomState(0)
            tensor_data = rng.normal(size=[dim, dim, dim, dim])
            test_data = _wrap_test_data(Tensor(tensor_data))
            _cache_summary_tensor_data(test_data)
            test_writer.record(step=1)

        file_name = os.path.realpath(test_writer.log_dir)
        with SummaryReader(file_name) as reader:
            event = reader.read_event()
            LOG.debug(event)

            assert event.summary.value[0].histogram.count == tensor_data.size


@security_off_wrap
def test_histogram_summary_nan_inf():
    """Test histogram summary, input tensor mixes nan, +inf and -inf with valid values."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
            dim1 = 100
            dim2 = 100

            arr = np.ones([dim1, dim2])
            arr[0][0] = np.nan
            arr[0][1] = np.inf
            arr[0][2] = -np.inf
            test_data = _wrap_test_data(Tensor(arr))

            _cache_summary_tensor_data(test_data)
            test_writer.record(step=1)

        file_name = os.path.realpath(test_writer.log_dir)
        with SummaryReader(file_name) as reader:
            event = reader.read_event()
            LOG.debug(event)

            # Each special value was injected exactly once above, so each
            # dedicated counter must be exactly one.
            histogram = event.summary.value[0].histogram
            assert histogram.nan_count == 1
            assert histogram.pos_inf_count == 1
            assert histogram.neg_inf_count == 1


@security_off_wrap
def test_histogram_summary_all_nan_inf():
    """Test histogram summary, input tensor has no valid number."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
            test_data = _wrap_test_data(Tensor(np.array([np.nan, np.nan, np.nan, np.inf, -np.inf])))
            _cache_summary_tensor_data(test_data)
            test_writer.record(step=1)

        file_name = os.path.realpath(test_writer.log_dir)
        with SummaryReader(file_name) as reader:
            event = reader.read_event()
            LOG.debug(event)

            histogram = event.summary.value[0].histogram
            assert histogram.nan_count == 3
            assert histogram.pos_inf_count == 1
            assert histogram.neg_inf_count == 1