# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test CentralCrop
"""
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.api import _cell_graph_executor


class CentralCropNet(nn.Cell):
    """Wrap ``nn.CentralCrop`` in a Cell so it can be graph-compiled.

    Args:
        central_fraction: Forwarded unchanged to ``nn.CentralCrop``; all
            argument validation (and the resulting TypeError/ValueError)
            happens inside that constructor.
    """

    def __init__(self, central_fraction):
        super(CentralCropNet, self).__init__()
        self.net = nn.CentralCrop(central_fraction)

    def construct(self, image):
        """Return ``image`` cropped by the wrapped ``nn.CentralCrop``."""
        return self.net(image)


def _random_image(shape):
    """Return a float32 Tensor of the given shape filled with values in [0, 1).

    Every test builds its input through this helper so that the dtype is
    always float32. A test that expects a failure then fails only because of
    the property it deliberately varies (e.g. rank), not an incidental dtype.
    """
    return Tensor(np.random.random(shape), mstype.float32)


def test_compile_3d_central_crop():
    """A 3-D input compiles with a valid central_fraction."""
    net = CentralCropNet(central_fraction=0.2)
    _cell_graph_executor.compile(net, _random_image((3, 16, 16)))


def test_compile_4d_central_crop():
    """A 4-D (batched) input compiles with a valid central_fraction."""
    net = CentralCropNet(central_fraction=0.5)
    _cell_graph_executor.compile(net, _random_image((8, 3, 16, 16)))


def test_central_fraction_bool():
    """A bool central_fraction is rejected with TypeError at construction."""
    with pytest.raises(TypeError):
        _ = CentralCropNet(True)


def test_central_crop_central_fraction_negative():
    """A negative central_fraction is rejected with ValueError."""
    with pytest.raises(ValueError):
        _ = CentralCropNet(-1.0)


def test_central_fraction_zero():
    """central_fraction == 0.0 is rejected with ValueError (lower bound is exclusive)."""
    with pytest.raises(ValueError):
        _ = CentralCropNet(0.0)


def test_central_fraction_greater_than_one():
    """A central_fraction above 1.0 is rejected with ValueError.

    Per the nn.CentralCrop API documentation the valid range is (0.0, 1.0].
    """
    with pytest.raises(ValueError):
        _ = CentralCropNet(1.5)


def test_central_crop_invalid_5d_input():
    """A 5-D input is rejected with ValueError when the graph is compiled."""
    # float32 (via _random_image) matches the valid-input tests, so rank is the
    # only invalid property of this input; the original used implicit float64.
    invalid_image = _random_image((8, 3, 16, 16, 1))

    net = CentralCropNet(central_fraction=0.5)
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(net, invalid_image)