# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing ComplexNorm op in DE.
"""
import numpy as np
from numpy import random

import mindspore.dataset as ds
import mindspore.dataset.audio.transforms as audio
from mindspore import log as logger


def test_complex_norm():
    """
    Test ComplexNorm(2) applied through GeneratorDataset.map (pipeline mode).

    Each row of the (3, 2) input is a (real, imag) pair; with power 2 the
    expected output is real**2 + imag**2 per row, giving shape (3,).
    """
    logger.info("Test ComplexNorm.")

    def gen():
        data = np.array([[1.0, 1.0], [2.0, 3.0], [4.0, 4.0]])
        yield (np.array(data, dtype=np.float32),)

    dataset = ds.GeneratorDataset(source=gen, column_names=["multi_dim_data"])
    dataset = dataset.map(operations=audio.ComplexNorm(2), input_columns=["multi_dim_data"])

    expected = np.array([2., 13., 32.])
    num_rows = 0
    for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        output = item["multi_dim_data"]
        assert output.shape == (3,)
        # Tolerance-based comparison: the op may compute in float32.
        np.testing.assert_allclose(output, expected, rtol=1e-6)
        num_rows += 1
    # The generator yields exactly one row; fail if the loop never ran instead
    # of passing vacuously.
    assert num_rows == 1

    logger.info("Finish testing ComplexNorm.")


def test_complex_norm_eager():
    """
    Test ComplexNorm with its default power, called directly on a NumPy array (eager mode).

    The expected values are the plain magnitudes sqrt(real**2 + imag**2) of
    each (real, imag) row, giving shape (3,).
    """
    logger.info("Test ComplexNorm callable.")

    input_t = np.array([[1.0, 1.0], [2.0, 3.0], [4.0, 4.0]])
    output_t = audio.ComplexNorm()(input_t)
    assert output_t.shape == (3,)
    expected = np.array([1.4142135623730951, 3.605551275463989, 5.656854249492381])
    # Exact float equality on sqrt results is brittle across backends/dtypes;
    # compare with a tight relative tolerance instead.
    np.testing.assert_allclose(output_t, expected, rtol=1e-6)

    logger.info("Finish testing ComplexNorm.")


def test_complex_norm_uncallable():
    """
    Test that ComplexNorm rejects a negative power by raising ValueError.
    """
    logger.info("Test ComplexNorm not callable.")

    input_t = random.rand(2, 4, 3, 2)
    try:
        audio.ComplexNorm(-3.)(input_t)
    except ValueError as e:
        assert 'Input power is not within the required interval of [0, 16777216].' in str(e)
    else:
        # Reaching here means the invalid power was accepted; fail explicitly
        # rather than letting the test pass without checking anything.
        raise AssertionError("ComplexNorm(-3.) was expected to raise ValueError")

    logger.info("Finish testing ComplexNorm.")


if __name__ == "__main__":
    test_complex_norm()
    test_complex_norm_eager()
    test_complex_norm_uncallable()