# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class NetCheckValid(nn.Cell):
    """Minimal Cell that forwards its two inputs to ``P.CheckValid``."""

    def __init__(self):
        super(NetCheckValid, self).__init__()
        self.valid = P.CheckValid()

    def construct(self, anchor, image_metas):
        """Return ``CheckValid(anchor, image_metas)``."""
        return self.valid(anchor, image_metas)


def check_valid(nptype):
    """Run ``CheckValid`` on a fixed set of boxes and compare to a known mask.

    The same inputs are evaluated in GRAPH mode and then in PYNATIVE mode on
    the CPU backend. Both modes must produce the hard-coded expected mask.

    Args:
        nptype: NumPy dtype used for both the anchor boxes and the image
            metadata (e.g. ``np.float32``, ``np.float16``, ``np.int16``).
    """
    # NOTE(review): each anchor row presumably is (x1, y1, x2, y2) and
    # image_metas presumably is (height, width, ratio) -- confirm against the
    # CheckValid operator documentation.
    anchor = np.array([[50, 0, 100, 700], [-2, 2, 8, 100], [10, 20, 300, 2000]], nptype)
    image_metas = np.array([768, 1280, 1], nptype)
    anchor_box = Tensor(anchor)
    image_metas_box = Tensor(image_metas)
    # `np.bool` was removed in NumPy 1.24; `np.bool_` works on every version.
    expect = np.array([True, False, False], np.bool_)

    # The operator must agree across both execution modes; GRAPH runs first,
    # matching the original ordering of context switches.
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target='CPU')
        net = NetCheckValid()
        output = net(anchor_box, image_metas_box)
        assert np.array_equal(output.asnumpy(), expect), \
            f"mode={mode}: got {output.asnumpy()}, expected {expect}"


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_valid_float32():
    """CheckValid on float32 inputs."""
    check_valid(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_valid_float16():
    """CheckValid on float16 inputs."""
    check_valid(np.float16)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_valid_int16():
    """CheckValid on int16 inputs."""
    check_valid(np.int16)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_valid_uint8():
    """CheckValid on uint8 inputs, in both GRAPH and PYNATIVE mode.

    This test uses its own smaller data instead of the shared ``check_valid``
    helper. Values such as -2, 300 and 2000 in the shared data cannot be
    represented in uint8.
    """
    anchor = np.array([[5, 0, 10, 70], [2, 2, 8, 10], [1, 2, 30, 200]], np.uint8)
    image_metas = np.array([76, 128, 1], np.uint8)
    anchor_box = Tensor(anchor)
    image_metas_box = Tensor(image_metas)
    # `np.bool` was removed in NumPy 1.24; `np.bool_` works on every version.
    expect = np.array([True, True, False], np.bool_)

    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target='CPU')
        net = NetCheckValid()
        output = net(anchor_box, image_metas_box)
        assert np.array_equal(output.asnumpy(), expect), \
            f"mode={mode}: got {output.asnumpy()}, expected {expect}"