# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test the uncertainty toolbox.

Loads a LeNet-5 checkpoint, builds MNIST train/eval pipelines and runs
``UncertaintyEvaluation`` to compute epistemic and aleatoric uncertainty
for every evaluation batch.
"""
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore import dtype as mstype
from mindspore.common.initializer import TruncatedNormal
from mindspore.dataset.vision import Inter
from mindspore.nn.probability.toolbox.uncertainty_evaluation import UncertaintyEvaluation
from mindspore.train import load_checkpoint, load_param_into_net

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


def weight_variable():
    """Return a new ``TruncatedNormal(0.02)`` initializer.

    Used for every convolution weight, dense weight and dense bias in this file.
    """
    return TruncatedNormal(0.02)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a bias-free 2-D convolution with ``pad_mode="valid"``.

    Weights are initialised with :func:`weight_variable`.
    """
    # NOTE(review): MindSpore's Conv2d expects padding == 0 unless
    # pad_mode == "pad"; confirm before passing a non-zero ``padding`` here.
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """Build a dense layer whose weight and bias both use :func:`weight_variable`."""
    weight = weight_variable()
    bias = weight_variable()
    # Keyword arguments make explicit which initializer goes where
    # (equivalent to the positional weight_init, bias_init order).
    return nn.Dense(input_channels, out_channels, weight_init=weight, bias_init=bias)


class LeNet5(nn.Cell):
    """LeNet-5 classifier.

    Two 5x5 ``valid`` convolutions, each followed by ReLU and 2x2/stride-2
    max-pooling, feed three dense layers with ReLU between them. ``fc1``
    expects ``16 * 5 * 5`` features, which is what a 32x32 input produces
    (32 -> 28 -> 14 -> 10 -> 5); :func:`create_dataset` resizes images to 32x32.

    Args:
        num_class (int): Number of outputs of the last dense layer. Default: 10.
        channel (int): Number of input image channels. Default: 1.
    """

    def __init__(self, num_class=10, channel=1):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Return the raw ``fc3`` outputs (no softmax is applied) for the image batch ``x``."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Create an MNIST dataset pipeline for training or evaluation.

    Labels are cast to int32. Images are resized to 32x32 (``Inter.LINEAR``),
    multiplied by 1/255, standardized as ``(x - 0.1307) / 0.3081``
    (presumably the MNIST mean/std) and converted from HWC to CHW. The
    result is shuffled (buffer of 10000), batched with the last incomplete
    batch dropped, and repeated.

    Args:
        data_path (str): Directory passed to ``ds.MnistDataset``.
        batch_size (int): Samples per batch. Default: 32.
        repeat_size (int): Argument to ``repeat``. Default: 1.
        num_parallel_workers (int): Workers used by every ``map`` step. Default: 1.

    Returns:
        The resulting MindSpore dataset.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on labels, then on images
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label",
                            num_parallel_workers=num_parallel_workers)
    # One map() per image op, in this exact order:
    # resize -> 1/255 rescale -> standardize -> HWC to CHW.
    for image_op in (resize_op, rescale_op, rescale_nml_op, hwc2chw_op):
        mnist_ds = mnist_ds.map(operations=image_op, input_columns="image",
                                num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


if __name__ == '__main__':
    # get trained model
    network = LeNet5()
    param_dict = load_checkpoint('checkpoint_lenet.ckpt')
    load_param_into_net(network, param_dict)
    # get train and eval dataset
    ds_train = create_dataset('workspace/mnist/train')
    ds_eval = create_dataset('workspace/mnist/test')
    evaluation = UncertaintyEvaluation(model=network,
                                       train_dataset=ds_train,
                                       task_type='classification',
                                       num_classes=10,
                                       epochs=1,
                                       epi_uncer_model_path=None,
                                       ale_uncer_model_path=None,
                                       save_model=False)
    for eval_batch in ds_eval.create_dict_iterator(output_numpy=True, num_epochs=1):
        # Convert into a new name rather than rebinding the loop variable
        # (the per-batch dict) to a Tensor.
        eval_images = Tensor(eval_batch['image'], mstype.float32)
        # Results are not inspected below.
        epistemic_uncertainty = evaluation.eval_epistemic_uncertainty(eval_images)
        aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_images)