# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" create train dataset. """
16
17import os
18import re
19import numpy as np
20from mindspore.communication.management import init
21from mindspore.communication.management import get_rank
22from mindspore.communication.management import get_group_size
23from mindspore import Tensor
24
25
26def _count_unequal_element(data_expected, data_me, rtol, atol):
27    assert data_expected.shape == data_me.shape
28    total_count = len(data_expected.flatten())
29    error = np.abs(data_expected - data_me)
30    greater = np.greater(error, atol + np.abs(data_me) * rtol)
31    loss_count = np.count_nonzero(greater)
32    assert (loss_count / total_count) < rtol, \
33        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
34            format(data_expected[greater], data_me[greater], error[greater])
35
36
def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    """Assert that two numpy arrays agree within tolerance.

    If ``data_expected`` contains any NaN, a strict ``np.allclose`` check is
    asserted; NaNs compare equal only when ``equal_nan`` is true.
    Otherwise, arrays that pass ``np.allclose`` are accepted immediately.
    Arrays that fail it are handed to ``_count_unequal_element``, which
    tolerates a small fraction of mismatching elements.
    """
    has_nan = np.any(np.isnan(data_expected))
    if has_nan:
        assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
        return
    if np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        return
    _count_unequal_element(data_expected, data_me, rtol, atol)
44
45
def clean_all_ir_files(folder_path):
    """Delete every ``.ir``, ``.dat`` and ``.dot`` entry directly inside *folder_path*.

    Only the top level of the folder is scanned; there is no recursion.
    A folder that does not exist is silently ignored.
    """
    if not os.path.exists(folder_path):
        return
    ir_suffixes = ('.ir', '.dat', '.dot')
    for entry in os.listdir(folder_path):
        if entry.endswith(ir_suffixes):
            os.remove(os.path.join(folder_path, entry))
51
52
def find_newest_validateir_file(folder_path):
    """Return the path of the newest ``<digits>_validate_<digits>.ir`` file.

    Only entries whose whole name matches that pattern are considered.
    The newest is chosen by ``os.path.getctime``, which is the metadata-change
    time on Unix and the creation time on Windows.

    Args:
        folder_path: directory to scan (not recursive).

    Returns:
        ``os.path.join(folder_path, name)`` for the newest matching entry.

    Raises:
        ValueError: if no entry in ``folder_path`` matches the pattern.
    """
    # Escaped dot plus fullmatch: the former unescaped, start-anchored pattern
    # also accepted names such as "1_validate_2.ir.bak" or "1_validate_23ir".
    pattern = re.compile(r'\d+_validate_\d+\.ir')
    candidates = [os.path.join(folder_path, name)
                  for name in os.listdir(folder_path)
                  if pattern.fullmatch(name)]
    if not candidates:
        raise ValueError("no '<n>_validate_<m>.ir' file found in {}".format(folder_path))
    return max(candidates, key=os.path.getctime)
57
58
class FakeDataInitMode:
    """Integer codes selecting how ``FakeData`` fills its image arrays."""

    # 0 -> RandomInit: standard-normal random values (np.random.randn)
    # 1 -> OnesInit:   all ones
    # 2 -> UniqueInit: ascending values, element index * 0.0001
    # 3 -> ZerosInit:  all zeros
    RandomInit, OnesInit, UniqueInit, ZerosInit = range(4)
64
65
class FakeData:
    """Deterministic, in-memory stand-in for a batched image-classification dataset.

    The data for global batch ``i`` is generated from the seed
    ``i + random_offset``, so the same index always yields the same batch.
    Every rank generates the whole global batch and keeps only row
    ``rank_id``. Ranks that share ``random_offset`` therefore receive
    disjoint slices of one common batch.

    Iterating yields one ``(image, label)`` pair of ``Tensor`` objects per
    batch, ``get_dataset_size()`` pairs in total.
    """

    def __init__(self, size=1024, batch_size=32, image_size=(3, 224, 224),
                 num_classes=10, random_offset=0, use_parallel=False,
                 fakedata_mode=FakeDataInitMode.RandomInit):
        """Configure the fake dataset.

        Args:
            size: total number of samples across all ranks; must be a multiple
                of ``batch_size * rank_size``.
            batch_size: number of samples per rank in each batch.
            image_size: per-sample image shape; any sequence of ints.
            num_classes: labels are drawn from ``[0, num_classes)``.
            random_offset: added to the batch index to form the RNG seed.
            use_parallel: if true, call ``init()`` and take the rank size and
                rank id from ``get_group_size()`` / ``get_rank()``.
            fakedata_mode: one of the ``FakeDataInitMode`` values.

        Raises:
            AssertionError: if ``size`` is not a multiple of the global batch size.
        """
        self.size = size
        self.rank_batch_size = batch_size
        self.total_batch_size = self.rank_batch_size
        self.random_offset = random_offset
        self.image_size = image_size
        self.num_classes = num_classes
        self.rank_size = 1
        self.rank_id = 0
        self.batch_index = 0
        self.image_data_type = np.float32
        self.label_data_type = np.float32
        self.is_onehot = True
        self.fakedata_mode = fakedata_mode

        if use_parallel:
            init()
            self.rank_size = get_group_size()
            self.rank_id = get_rank()

        self.total_batch_size = self.rank_batch_size * self.rank_size

        assert (self.size % self.total_batch_size) == 0, \
            "size must be a multiple of batch_size * rank_size"

        # tuple() so that a list image_size (e.g. [3, 224, 224]) also works;
        # tuple + list would raise TypeError.
        self.total_batch_data_size = (self.rank_size, self.rank_batch_size) + tuple(image_size)

    def get_dataset_size(self):
        """Return the number of batches per epoch."""
        return self.size // self.total_batch_size

    def get_repeat_count(self):
        """Return the repeat count, which is always 1."""
        return 1

    def set_image_data_type(self, data_type):
        """Set the numpy dtype the image arrays are cast to."""
        self.image_data_type = data_type

    def set_label_data_type(self, data_type):
        """Set the numpy dtype the label arrays are cast to."""
        self.label_data_type = data_type

    def set_label_onehot(self, is_onehot=True):
        """Choose one-hot labels (``True``) or class-index labels (``False``)."""
        self.is_onehot = is_onehot

    def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
        """Return ``self`` as the iterator; both arguments are ignored."""
        _ = num_epochs, do_copy
        return self

    def __getitem__(self, batch_index):
        """Return the ``(image, label)`` Tensors of this rank's share of a batch.

        Raises:
            IndexError: if ``batch_index`` is past the last batch.
        """
        if batch_index * self.total_batch_size >= len(self):
            raise IndexError("{} index out of range".format(self.__class__.__name__))
        # A private RandomState seeded with the same value produces exactly the
        # numbers np.random.seed(...) + np.random.* would. It also leaves the
        # global numpy RNG untouched. The previous get_state/set_state pair
        # leaked the seeded global state if anything raised in between.
        rng = np.random.RandomState(batch_index + self.random_offset)
        if self.fakedata_mode == FakeDataInitMode.OnesInit:
            img = np.ones(self.total_batch_data_size)
        elif self.fakedata_mode == FakeDataInitMode.ZerosInit:
            img = np.zeros(self.total_batch_data_size)
        elif self.fakedata_mode == FakeDataInitMode.UniqueInit:
            total_size = int(np.prod(self.total_batch_data_size))
            img = (np.arange(total_size) * 0.0001).reshape(self.total_batch_data_size)
        else:
            img = rng.randn(*self.total_batch_data_size)
        target = rng.randint(0, self.num_classes, size=(self.rank_size, self.rank_batch_size))
        # The whole global batch was generated; keep only this rank's row.
        img = img[self.rank_id]
        target = target[self.rank_id]
        img_ret = img.astype(self.image_data_type)
        target_ret = target.astype(self.label_data_type)
        if self.is_onehot:
            target_onehot = np.zeros(shape=(self.rank_batch_size, self.num_classes))
            target_onehot[np.arange(self.rank_batch_size), target] = 1
            target_ret = target_onehot.astype(self.label_data_type)
        return Tensor(img_ret), Tensor(target_ret)

    def __len__(self):
        """Return the total number of samples (``size``), not the number of batches."""
        return self.size

    def __iter__(self):
        """Restart iteration from the first batch and return ``self``."""
        self.batch_index = 0
        return self

    def reset(self):
        """Rewind iteration to the first batch."""
        self.batch_index = 0

    def __next__(self):
        """Return the next batch, or raise ``StopIteration`` after the last one."""
        if self.batch_index * self.total_batch_size < len(self):
            data = self[self.batch_index]
            self.batch_index += 1
            return data
        raise StopIteration
158