# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Tests for the text.ToNumber dataset op: eager calls, typical values, type boundaries,
out-of-range inputs and invalid inputs/types
"""
import numpy as np
import pytest

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.text as text

# The numpy and mindspore lists are index-aligned: entry i of one corresponds to entry i of the other.
np_integral_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16,
                     np.uint32, np.uint64]
ms_integral_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8,
                     mstype.uint16, mstype.uint32, mstype.uint64]

np_non_integral_types = [np.float16, np.float32, np.float64]
ms_non_integral_types = [mstype.float16, mstype.float32, mstype.float64]


def string_dataset_generator(strings):
    """
    Yield each element of strings as a one-element tuple holding a numpy bytes ('S') array
    """
    for string in strings:
        yield (np.array(string, dtype='S'),)


def _to_number_dataset(strings, ms_type):
    """
    Build a GeneratorDataset with a single "strings" column and map text.ToNumber(ms_type) over it
    """
    dataset = ds.GeneratorDataset(string_dataset_generator(strings), "strings")
    return dataset.map(operations=text.ToNumber(ms_type), input_columns=["strings"])


def _assert_iteration_fails(dataset, *expected_messages):
    """
    Iterate dataset once, require a RuntimeError, and check every expected message is in the error text
    """
    with pytest.raises(RuntimeError) as info:
        for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    error_text = str(info.value)
    for expected_message in expected_messages:
        assert expected_message in error_text


def test_to_number_eager():
    """
    Test ToNumber op is callable
    """
    input_strings = [["1", "2", "3"], ["4", "5", "6"]]
    op = text.ToNumber(mstype.int8)

    # test input_strings as one 2D tensor
    result1 = op(input_strings)  # np array: [[1 2 3] [4 5 6]]
    assert np.array_equal(result1, np.array([[1, 2, 3], [4, 5, 6]], dtype='i'))

    # test input multiple tensors
    with pytest.raises(RuntimeError) as info:
        # test input_strings as two 1D tensors. It is an error because to_number is a OneToOne op
        _ = op(*input_strings)
    assert "The op is OneToOne, can only accept one tensor as input." in str(info.value)

    # test input invalid tensor (ragged rows)
    invalid_input = [["1", "2", "3"], ["4", "5"]]
    with pytest.raises(TypeError) as info:
        _ = op(invalid_input)
    assert "Invalid user input. Got <class 'list'>: [['1', '2', '3'], ['4', '5']], cannot be converted into tensor" in \
           str(info.value)


def test_to_number_typical_case_integral():
    """
    Test ToNumber on typical in-range integer strings
    """
    input_strings = [["-121", "14"], ["-2219", "7623"], ["-8162536", "162371864"],
                     ["-1726483716", "98921728421"]]

    # zip stops at the shorter sequence: with four input rows only the first four
    # entries of ms_integral_types (int8, int16, int32, int64) are exercised here.
    for ms_type, inputs in zip(ms_integral_types, input_strings):
        dataset = _to_number_dataset(inputs, ms_type)

        expected_output = [int(string) for string in inputs]
        output = [data["strings"] for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True)]

        assert output == expected_output


def test_to_number_typical_case_non_integral():
    """
    Test ToNumber on typical in-range floating point strings
    """
    input_strings = [["-1.1", "1.4"], ["-2219.321", "7623.453"], ["-816256.234282", "162371864.243243"]]
    # Absolute tolerance applied to every value of every type. The previous six-entry
    # list was zipped against the two values of each type, so only its first two
    # entries (both 0.001) were ever used; this keeps the same effective tolerance.
    epsilon = 0.001

    for ms_type, inputs in zip(ms_non_integral_types, input_strings):
        dataset = _to_number_dataset(inputs, ms_type)

        expected_output = [float(string) for string in inputs]
        output = [data["strings"] for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True)]

        for expected, actual in zip(expected_output, output):
            assert abs(expected - actual) < epsilon


def out_of_bounds_error_message_check(dataset, np_type, value_to_cast):
    """
    Iterate dataset and check the RuntimeError reports value_to_cast as out of bounds for np_type

    The error text must name the offending string and the target type, and must state the
    valid range taken from np.iinfo(np_type).
    """
    type_info = np.iinfo(np_type)
    type_max = str(type_info.max)
    type_min = str(type_info.min)
    type_name = str(np.dtype(np_type))

    _assert_iteration_fails(
        dataset,
        "string input " + value_to_cast + " will be out of bounds if cast to " + type_name,
        "valid range is: [" + type_min + ", " + type_max + "]")


def test_to_number_out_of_bounds_integral():
    """
    Test ToNumber fails for strings 10 above the max and 10 below the min of every integral type
    """
    for np_type, ms_type in zip(np_integral_types, ms_integral_types):
        type_info = np.iinfo(np_type)
        for out_of_range in (str(type_info.max + 10), str(type_info.min - 10)):
            dataset = _to_number_dataset([out_of_range], ms_type)
            out_of_bounds_error_message_check(dataset, np_type, out_of_range)


def test_to_number_out_of_bounds_non_integral():
    """
    Test ToNumber fails for strings above and below the range of every floating point type
    """
    # NOTE(review): max * 10 / min * 10 are computed on numpy float16/float32 scalars and may
    # already overflow to +/-inf before str() is applied - confirm that this is the intended input.
    above_range = [str(np.finfo(np.float16).max * 10), str(np.finfo(np.float32).max * 10), "1.8e+308"]
    below_range = [str(np.finfo(np.float16).min * 10), str(np.finfo(np.float32).min * 10), "-1.8e+308"]

    for float16_input, float32_input, float64_input in (above_range, below_range):
        # The float16 check matches a different message from the float32/float64 checks.
        _assert_iteration_fails(_to_number_dataset([float16_input], ms_non_integral_types[0]),
                                "outside of valid float16 range")
        _assert_iteration_fails(_to_number_dataset([float32_input], ms_non_integral_types[1]),
                                "string input " + float32_input + " will be out of bounds if cast to float32")
        _assert_iteration_fails(_to_number_dataset([float64_input], ms_non_integral_types[2]),
                                "string input " + float64_input + " will be out of bounds if cast to float64")


def test_to_number_boundaries_integral():
    """
    Test ToNumber converts the max, the min and zero of every integral type
    """
    for np_type, ms_type in zip(np_integral_types, ms_integral_types):
        type_info = np.iinfo(np_type)
        for boundary in (str(type_info.max), str(type_info.min), str(0)):
            dataset = _to_number_dataset([boundary], ms_type)
            for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
                assert data["strings"] == int(boundary)


def test_to_number_invalid_input():
    """
    Test ToNumber fails at iteration time for a string that is not a number
    """
    input_strings = ["a8fa9ds8fa"]
    dataset = _to_number_dataset(input_strings, mstype.int32)

    _assert_iteration_fails(dataset, "it is invalid to convert \"" + input_strings[0] + "\" to a number")


def test_to_number_invalid_type():
    """
    Test ToNumber rejects a non-numeric target type with a TypeError
    """
    # The TypeError is expected while the pipeline is being built; the dataset is never iterated.
    with pytest.raises(TypeError) as info:
        _to_number_dataset(["a8fa9ds8fa"], mstype.bool_)
    assert "data_type: Bool is not numeric data type" in str(info.value)


if __name__ == '__main__':
    test_to_number_eager()
    test_to_number_typical_case_integral()
    test_to_number_typical_case_non_integral()
    test_to_number_boundaries_integral()
    test_to_number_out_of_bounds_integral()
    test_to_number_out_of_bounds_non_integral()
    test_to_number_invalid_input()
    test_to_number_invalid_type()