# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest

import mindspore.dataset as ds
import mindspore.dataset.audio.transforms as a_c_trans


def count_unequal_element(data_expected, data_me, rtol, atol):
    """
    Assert that two arrays agree within tolerance on almost every element.

    An element counts as a mismatch when its absolute error exceeds
    ``atol + abs(expected) * rtol``. The check passes when the fraction of
    mismatching elements is strictly below ``rtol``.

    Args:
        data_expected: Reference numpy array.
        data_me: Numpy array under test; must have the same shape.
        rtol: Relative tolerance, also used as the allowed mismatch ratio.
        atol: Absolute tolerance.

    Raises:
        AssertionError: If the shapes differ or too many elements mismatch.
    """
    assert data_expected.shape == data_me.shape
    total_count = data_expected.size
    if total_count == 0:
        # Nothing to compare; also avoids dividing by zero below.
        return
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_expected) * rtol)
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
        data_expected[greater], data_me[greater], error[greater])


def test_func_dc_shift_eager():
    """
    Eager Test

    Calls DCShift(1.0, 0.04) directly on a 1-D double array and compares the
    result with precomputed expected values.
    """
    arr = np.array([0.60, 0.97, -1.04, -1.26, 0.97, 0.91, 0.48, 0.93, 0.71, 0.61], dtype=np.double)
    expected = np.array([0.0400, 0.0400, -0.0400, -0.2600, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400],
                        dtype=np.double)
    dcshift_op = a_c_trans.DCShift(1.0, 0.04)
    output = dcshift_op(arr)
    count_unequal_element(expected, output, 0.0001, 0.0001)


def test_func_dc_shift_pipeline():
    """
    Pipeline Test

    Maps DCShift(0.8, 0.03) over a NumpySlicesDataset built from a 2x4 double
    array and compares each produced row with its precomputed expected row.
    """
    arr = np.array([[1.14, -1.06, 0.94, 0.90], [-1.11, 1.40, -0.33, 1.43]], dtype=np.double)
    expected = np.array([[0.2300, -0.2600, 0.2300, 0.2300], [-0.3100, 0.2300, 0.4700, 0.2300]],
                        dtype=np.double)
    dataset = ds.NumpySlicesDataset(arr, column_names=["col1"], shuffle=False)
    dcshift_op = a_c_trans.DCShift(0.8, 0.03)
    dataset = dataset.map(operations=dcshift_op, input_columns=["col1"])
    num_rows = 0
    for item1, item2 in zip(dataset.create_dict_iterator(output_numpy=True), expected):
        count_unequal_element(item2, item1['col1'], 0.0001, 0.0001)
        num_rows += 1
    # zip() stops at the shorter input, so make sure no expected row was skipped.
    assert num_rows == len(expected)


def test_func_dc_shift_pipeline_error():
    """
    Pipeline Error Test

    Expects a ValueError about the shift interval when DCShift is given a
    first argument of 2.5 and used in a dataset pipeline.
    """
    # np.float64 instead of the np.float alias, which NumPy 1.24 removed.
    arr = np.random.uniform(-2, 2, size=(1000)).astype(np.float64)
    label = np.random.sample((1000, 1))
    data = (arr, label)
    dataset = ds.NumpySlicesDataset(data, column_names=["col1", "col2"], shuffle=False)
    with pytest.raises(ValueError, match=r"Input shift is not within the required interval of \[-2.0, 2.0\]."):
        dcshift_op = a_c_trans.DCShift(2.5, 0.03)
        dataset = dataset.map(operations=dcshift_op, input_columns=["col1"])
        # Drain the pipeline so the error is triggered wherever it is raised.
        for _ in dataset.create_dict_iterator(output_numpy=True):
            pass


if __name__ == "__main__":
    test_func_dc_shift_eager()
    test_func_dc_shift_pipeline()
    test_func_dc_shift_pipeline_error()