# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for constructing and parsing ``mindspore.dataset.Schema`` objects."""
import json

import pytest

import mindspore.dataset as ds
from mindspore import log as logger
from util import dataset_equal

# All fixtures live under one directory; derive the individual paths from it
# so they cannot drift apart.
DATASET_ROOT = "../data/dataset/testTFTestAllTypes/"
FILES = [DATASET_ROOT + "test.data"]
SCHEMA_FILE = DATASET_ROOT + "datasetSchema.json"


def test_schema_simple():
    """Check that a Schema can be constructed from SCHEMA_FILE without error."""
    logger.info("test_schema_simple")
    ds.Schema(SCHEMA_FILE)


def test_schema_file_vs_string():
    """Check that a file-path Schema and a from_json Schema read the same data.

    One Schema is built from the SCHEMA_FILE path, the other from the parsed
    JSON object via ``from_json``. Two TFRecordDatasets over FILES, one per
    schema, are then compared with ``dataset_equal``.
    """
    logger.info("test_schema_file_vs_string")

    schema1 = ds.Schema(SCHEMA_FILE)
    # Pin the encoding so the test does not depend on the platform locale.
    with open(SCHEMA_FILE, encoding="utf-8") as file:
        json_obj = json.load(file)
    schema2 = ds.Schema()
    schema2.from_json(json_obj)

    ds1 = ds.TFRecordDataset(FILES, schema1)
    ds2 = ds.TFRecordDataset(FILES, schema2)

    dataset_equal(ds1, ds2, 0)


def test_schema_exception():
    """Check the errors raised for a non-string path and a nameless column."""
    logger.info("test_schema_exception")

    with pytest.raises(TypeError) as info:
        ds.Schema(1)
    assert "path: 1 is not string" in str(info.value)

    # Build the schema outside the raises-context so that only parse_columns is
    # expected to fail. Otherwise a RuntimeError while loading SCHEMA_FILE
    # would be mistaken for the error under test.
    schema = ds.Schema(SCHEMA_FILE)
    columns = [{'type': 'int8', 'shape': [3, 3]}]  # deliberately has no 'name'
    with pytest.raises(RuntimeError) as info:
        schema.parse_columns(columns)
    assert "Column's name is missing" in str(info.value)


if __name__ == '__main__':
    test_schema_simple()
    test_schema_file_vs_string()
    test_schema_exception()