# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Export the transfer-learning backbone and trainable head as MINDIR files.

Two exports are produced:

* ``transfer_learning_tod_backbone`` -- ``effnet(num_classes=1000)`` with its
  weights loaded from ``efficient_net_b0.ckpt``.
* ``transfer_learning_tod_head`` -- a ``Dense(1000, 10)`` head, wrapped by
  ``train_utils.train_wrap`` together with an SGD optimizer over the head's
  parameters only.
"""

import numpy as np
import mindspore as M
from mindspore.nn import Cell
from mindspore.train.serialization import load_checkpoint, export
from effnet import effnet
from train_utils import train_wrap


class TransferNet(Cell):
    """Chain two cells: the output is ``head(backbone(x))``.

    Args:
        backbone: Cell applied first to the input.
        head: Cell applied to the backbone's output.
    """

    def __init__(self, backbone, head):
        # Cell.__init__'s first positional parameter is the boolean
        # ``auto_prefix``; passing the class object there (as before) was a
        # mistake that only worked because a class is truthy. The default
        # (True) gives identical behaviour.
        super().__init__()
        self.backbone = backbone
        self.head = head

    def construct(self, x):
        """Return ``head(backbone(x))``."""
        x = self.backbone(x)
        x = self.head(x)
        return x


# Configure the execution context *before* any network is created: MindSpore
# advises against changing the mode after a net has been initialised, since
# some operators are implemented differently in graph and PyNative mode.
M.context.set_context(mode=M.context.PYNATIVE_MODE,
                      device_target="GPU", save_graphs=False)

BATCH_SIZE = 16

# --- Backbone ----------------------------------------------------------------
BACKBONE = effnet(num_classes=1000)
load_checkpoint("efficient_net_b0.ckpt", BACKBONE)

# Dummy all-ones input of shape (BATCH_SIZE, 3, 224, 224); export() uses it to
# trace the graph, so it determines the exported input shape and dtype.
X = M.Tensor(np.ones((BATCH_SIZE, 3, 224, 224)), M.float32)
export(BACKBONE, X, file_name="transfer_learning_tod_backbone", file_format='MINDIR')

# --- Trainable head ------------------------------------------------------------
label = M.Tensor(np.zeros([BATCH_SIZE, 10]).astype(np.float32))
HEAD = M.nn.Dense(1000, 10)
# Re-initialise the head: weights ~ N(0, 0.1), bias zero.
HEAD.weight.set_data(M.Tensor(np.random.normal(
    0, 0.1, HEAD.weight.data.shape).astype("float32")))
HEAD.bias.set_data(M.Tensor(np.zeros(HEAD.bias.data.shape, dtype="float32")))

# Only the head's parameters are optimised; the backbone was exported on its own.
sgd = M.nn.SGD(HEAD.trainable_params(), learning_rate=0.015, momentum=0.9,
               dampening=0.01, weight_decay=0.0, nesterov=False, loss_scale=1.0)
# NOTE(review): train_wrap lives in train_utils (not visible here); judging by
# the export call below, the returned cell takes (features, label) -- confirm.
net = train_wrap(HEAD, optimizer=sgd)
# Placeholder for the backbone's (BATCH_SIZE, 1000) output fed to the head.
backbone_out = M.Tensor(np.zeros([BATCH_SIZE, 1000]).astype(np.float32))
export(net, backbone_out, label, file_name="transfer_learning_tod_head", file_format='MINDIR')

print("Exported")