# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Export net test."""
import os
import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import export


class SliceNet(nn.Cell):
    """Apply ReLU to ``x``, then overwrite ``x[2,]`` with ``y`` and return ``x``.

    A minimal network whose graph contains an indexed tensor assignment,
    used below to check that such a graph can be exported.
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def construct(self, x, y):
        x = self.relu(x)
        # The trailing comma makes the index the tuple ``(2,)`` rather than the
        # int ``2``. NOTE(review): presumably deliberate, to exercise tuple-index
        # setitem; ``y`` is assumed to broadcast into ``x[2]`` (the test passes
        # a shape-(1,) ``y`` with a (4, 4, 4) ``x``) -- confirm.
        x[2,] = y
        return x


def _export_and_verify(net, inputs, file_name, file_format, suffix):
    """Export ``net`` in ``file_format`` and assert the output file is created.

    Args:
        net: The cell to export.
        inputs: Tuple of example input tensors forwarded to ``export``.
        file_name: Output path without extension, passed to ``export``.
        file_format: Export format string, e.g. ``'AIR'`` or ``'MINDIR'``.
        suffix: File extension ``export`` is expected to append for this format.

    Raises:
        AssertionError: If the expected output file does not exist after export.
    """
    verify_name = file_name + suffix
    # Remove leftovers from an earlier (e.g. interrupted) run so the existence
    # check below cannot pass on a stale file.
    if os.path.exists(verify_name):
        os.remove(verify_name)
    try:
        export(net, *inputs, file_name=file_name, file_format=file_format)
        assert os.path.exists(verify_name), "export did not produce " + verify_name
    finally:
        # Always clean up, even if the export raised or the assertion failed.
        if os.path.exists(verify_name):
            os.remove(verify_name)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_export_slice_net():
    """Export ``SliceNet`` as AIR and then MINDIR, checking each file is written."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    input_x = Tensor(np.random.rand(4, 4, 4), ms.float32)
    input_y = Tensor(np.array([1]), ms.float32)
    net = SliceNet()
    file_name = "slice_net"
    # Same order as before: AIR first, then MINDIR; a failure stops the test.
    for file_format, suffix in (("AIR", ".air"), ("MINDIR", ".mindir")):
        _export_and_verify(net, (input_x, input_y), file_name, file_format, suffix)