# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing Mask op in DE
"""
import operator

import numpy as np
import pytest

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as ops

# Maps each MindSpore dtype to the NumPy type used to cast the reference mask.
# np.bool_ / np.str_ are used instead of the np.bool / np.str aliases, which were
# deprecated in NumPy 1.20 and removed in NumPy 1.24 (the old spelling made this
# module fail at import time).
mstype_to_np_type = {
    mstype.bool_: np.bool_,
    mstype.int8: np.int8,
    mstype.uint8: np.uint8,
    mstype.int16: np.int16,
    mstype.uint16: np.uint16,
    mstype.int32: np.int32,
    mstype.uint32: np.uint32,
    mstype.int64: np.int64,
    mstype.uint64: np.uint64,
    mstype.float16: np.float16,
    mstype.float32: np.float32,
    mstype.float64: np.float64,
    mstype.string: np.str_,
}

# Each relational operator paired with the elementwise function that builds the
# expected mask. Insertion order (EQ, NE, LT, LE, GT, GE) is the order in which
# the comparison tests exercise the operators.
_RELATIONAL_TO_NUMPY = {
    ops.Relational.EQ: operator.eq,
    ops.Relational.NE: operator.ne,
    ops.Relational.LT: operator.lt,
    ops.Relational.LE: operator.le,
    ops.Relational.GT: operator.gt,
    ops.Relational.GE: operator.ge,
}


def mask_compare(array, op, constant, dtype=mstype.bool_):
    """
    Run ops.Mask over a one-row dataset and check it against a NumPy reference.

    Args:
        array: the single input row (a list of numbers or strings).
        op: an ops.Relational member selecting the comparison.
        constant: the value every element of `array` is compared against.
        dtype: MindSpore dtype requested for the generated mask.

    Raises:
        RuntimeError: propagated from the dataset pipeline when Mask rejects the
            input/constant/dtype combination.
        AssertionError: if the pipeline does not yield exactly one row, or the
            produced mask differs from the reference.
    """
    data = ds.NumpySlicesDataset([array])
    data = data.map(operations=ops.Mask(op, constant, dtype))

    # Drain the pipeline before building the reference: invalid combinations must
    # surface as the pipeline's RuntimeError, not as a NumPy casting error
    # (e.g. np.array("3.5", dtype=np.int64) raises ValueError).
    outputs = [row[0].asnumpy() for row in data]
    assert len(outputs) == 1, "expected exactly one row, got {}".format(len(outputs))

    expected = np.array(array)
    # The constant is cast to the input's dtype, so e.g. 3.5 compared against an
    # integer array is compared as 3.
    rhs = np.array(constant, dtype=expected.dtype)
    expected = _RELATIONAL_TO_NUMPY[op](expected, rhs)
    expected = expected.astype(dtype=mstype_to_np_type[dtype])

    np.testing.assert_array_equal(expected, outputs[0])


def _compare_all_relations(array, constant):
    """Call mask_compare for every non-string mask dtype and every relational op."""
    for dtype in mstype_to_np_type:
        # A string mask dtype is expected to be rejected by Mask; that case is
        # checked separately in test_mask_exceptions_str.
        if dtype == mstype.string:
            continue
        for op in _RELATIONAL_TO_NUMPY:
            mask_compare(array, op, constant, dtype)


def test_mask_int_comparison():
    """Integer input compared against an integer constant."""
    _compare_all_relations([1, 2, 3, 4, 5], 3)


def test_mask_float_comparison():
    """Float input compared against an integer constant."""
    _compare_all_relations([1.5, 2.5, 3., 4.5, 5.5], 3)


def test_mask_float_comparison2():
    """Integer input compared against a float constant."""
    _compare_all_relations([1, 2, 3, 4, 5], 3.5)


def test_mask_string_comparison():
    """String input compared against a string constant."""
    _compare_all_relations(["1.5", "2.5", "3.", "4.5", "5.5"], "3.")


def test_mask_exceptions_str():
    """Mismatched input/constant types and a string mask dtype raise RuntimeError."""
    with pytest.raises(RuntimeError) as info:
        mask_compare([1, 2, 3, 4, 5], ops.Relational.EQ, "3.5")
    assert "input datatype does not match the value datatype" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        mask_compare(["1", "2", "3", "4", "5"], ops.Relational.EQ, 3.5)
    assert "input datatype does not match the value datatype" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        mask_compare(["1", "2", "3", "4", "5"], ops.Relational.EQ, "3.5", mstype.string)
    assert "Only supports bool or numeric datatype for generated mask type" in str(info.value)


if __name__ == "__main__":
    test_mask_int_comparison()
    test_mask_float_comparison()
    test_mask_float_comparison2()
    test_mask_string_comparison()
    test_mask_exceptions_str()