# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Exception tests for the dataset pipeline: invalid arguments to
TFRecordDataset / vision transforms must raise the expected error, while
valid arguments must still produce a working pipeline.
"""
import pytest

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger

# NOTE: these paths are relative to the working directory the tests are
# launched from, not to this file.
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def test_exception_01():
    """
    Test single exception with invalid input.

    ``vision.Resize(100, 100)`` passes ``100`` as the interpolation argument
    instead of an ``Inter`` enum member, so building the map must raise a
    TypeError whose message names the interpolation argument.
    """
    logger.info("test_exception_01")
    data = ds.TFRecordDataset(DATA_DIR, columns_list=["image"])
    with pytest.raises(TypeError) as info:
        data.map(operations=vision.Resize(100, 100), input_columns=["image"])
    assert "Argument interpolation with value 100 is not of type [<enum 'Inter'>]" in str(info.value)


def test_exception_02():
    """
    Test exceptions with invalid input, and test valid input.

    A negative ``num_samples`` must be rejected with a ValueError describing
    the allowed range. With ``num_samples=1``, a Decode + Resize pipeline must
    yield exactly one row, both when iterating the dataset directly and
    through a single-epoch dict iterator.
    """
    logger.info("test_exception_02")

    # Invalid: num_samples below the lower bound of 0.
    num_samples = -1
    with pytest.raises(ValueError) as info:
        ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
    assert 'num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)' in str(info.value)

    # Valid: limit the dataset to a single sample and run a real pipeline.
    num_samples = 1
    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
    data = data.map(operations=vision.Decode(), input_columns=["image"])
    data = data.map(operations=vision.Resize((100, 100)), input_columns=["image"])

    # Confirm 1 sample in dataset; a generator expression counts rows
    # without materialising a throwaway list.
    assert sum(1 for _ in data) == 1

    # Confirm the same count through an explicit single-epoch dict iterator.
    num_iters = 0
    for _ in data.create_dict_iterator(num_epochs=1):
        num_iters += 1
    assert num_iters == 1


if __name__ == '__main__':
    test_exception_01()
    test_exception_02()