# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mobilenetv2_train_export.

Builds a MobileNetV2 network with a 10-class head, wraps it together with a
loss function and a Momentum optimizer, and exports the wrapped network in
MINDIR format to ``mindir/mobilenetv2_train``.

When a command-line argument is given, it is used as a path prefix for
``save_inout``.
"""

import sys
import numpy as np
from train_utils import save_inout, train_wrap
from official.cv.mobilenetv2.src.mobilenetV2 import MobileNetV2Backbone, MobileNetV2Head, mobilenet_v2
import mindspore.common.dtype as mstype
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export

# Graph mode on GPU, without dumping intermediate graphs.
context.set_context(
    mode=context.GRAPH_MODE,
    device_target="GPU",
    save_graphs=False,
)

# Batch size shared by the sample input and the sample label below.
batch = 8

# Assemble the network: backbone + classification head with 10 classes.
backbone_net = MobileNetV2Backbone()
head_net = MobileNetV2Head(
    input_channel=backbone_net.out_channels,
    num_classes=10,
)
n = mobilenet_v2(backbone_net, head_net)

# Dense (non-sparse) labels, matching the (batch, 10) float label tensor below.
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
optimizer = nn.Momentum(
    n.trainable_params(),
    learning_rate=0.01,
    momentum=0.9,
    use_nesterov=False,
)
# NOTE(review): train_wrap is defined in train_utils (not visible here);
# presumably it combines network, loss and optimizer into one trainable cell.
net = train_wrap(n, loss_fn, optimizer)

# Sample inputs used to trace the graph: a random float32 image batch of shape
# (batch, 3, 224, 224) and an all-zero float32 label of shape (batch, 10).
x = Tensor(np.random.randn(batch, 3, 224, 224), mstype.float32)
label = Tensor(np.zeros((batch, 10), dtype=np.float32))

export(
    net,
    x,
    label,
    file_name="mindir/mobilenetv2_train",
    file_format='MINDIR',
)

# Optional first CLI argument: a prefix to which "mobilenetv2" is appended.
# NOTE(review): save_inout lives in train_utils; presumably it stores the
# sample input/label and the network outputs for later comparison - confirm.
if len(sys.argv) >= 2:
    save_inout(sys.argv[1] + "mobilenetv2", x, label, n, net, sparse=False)