• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Stub for MindData. Remove once MindData has been merged into MindSpore."""
16import numpy as np
17
18from mindspore import Tensor
19
20
class MindData:
    """Stub for MindData.

    Reports a fixed dataset size, batch size and repeat count. When iterated,
    it yields ``size`` tuples, each holding one zero-filled ``Tensor`` per
    (shape, dtype) pair taken from ``output_shapes`` / ``np_types``. The
    device-queue related methods are no-ops.

    Args:
        size (int): Value reported by ``get_dataset_size``/``len`` and the
            number of items yielded before ``StopIteration``.
        batch_size: Value returned by ``get_batch_size``.
        repeat_count (int): Value returned by ``get_repeat_count``.
        np_types: Sequence of numpy dtypes, one per output column.
            ``None`` is treated as "no columns" when iterating.
        output_shapes: Sequence of shapes, paired positionally with
            ``np_types``. ``None`` is treated as "no columns" when iterating.
        input_indexs: Value exposed through the ``input_indexs`` property.
    """

    def __init__(self, size=1, batch_size=None, repeat_count=1,
                 np_types=None, output_shapes=None, input_indexs=()):
        self._size = size
        self._batch_size = batch_size
        self._repeat_count = repeat_count
        self._np_types = np_types
        self._output_shapes = output_shapes
        self._input_indexs = input_indexs
        # Items yielded since construction or the last reset().
        self._iter_num = 0
        # Returned by get_init_step(); this stub never advances it.
        self._global_step = 0

    def get_dataset_size(self):
        """Return the configured dataset size."""
        return self._size

    def get_repeat_count(self):
        """Return the configured repeat count."""
        return self._repeat_count

    def get_batch_size(self):
        """Return the configured batch size (may be None)."""
        return self._batch_size

    def output_types(self):
        """Return the configured numpy dtypes (may be None)."""
        return self._np_types

    def output_shapes(self):
        """Return the configured output shapes (may be None)."""
        return self._output_shapes

    @property
    def input_indexs(self):
        """Return the configured input indexes."""
        return self._input_indexs

    def device_que(self, send_epoch_end=True, create_data_info_queue=False, queue_name=""):
        """Pretend to set up a device queue and return this stub.

        Sets ``self.queue_name`` to a fixed identifier (the ``queue_name``
        argument is ignored) and records ``send_epoch_end``.
        ``create_data_info_queue`` is ignored.
        """
        self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736'
        self.send_epoch_end = send_epoch_end
        return self

    def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
        """Return this stub as its own iterator; the arguments are ignored.

        Iteration state is not reset here; call ``reset()`` to iterate again.
        """
        return self.__iter__()

    def send(self, num_epochs=-1):
        """No-op; ``num_epochs`` is ignored."""

    def stop_send(self):
        """No-op."""

    def release(self):
        """No-op."""

    def continue_send(self):
        """No-op."""

    def get_data_info(self):
        """Return None; the stub has no data info."""

    def get_mbuf_queue_size(self):
        """Return None; the stub has no mbuf queue."""

    def get_send_info(self):
        """Return None; the stub has no send info."""

    def __len__(self):
        return self._size

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next tuple of tensors.

        Raises:
            StopIteration: once ``size`` items have been produced since
                construction or the last ``reset()``.
        """
        # ">=" so that exactly ``size`` items are produced, matching __len__.
        if self._iter_num >= self._size:
            raise StopIteration
        self._iter_num += 1
        shapes = () if self._output_shapes is None else self._output_shapes
        types = () if self._np_types is None else self._np_types
        # np.zeros rather than np.ndarray: the latter returns uninitialized
        # memory, which makes the generated data non-deterministic.
        return tuple(Tensor(np.zeros(shape, typ)) for shape, typ in zip(shapes, types))

    def next(self):
        """Alias for ``__next__``."""
        return self.__next__()

    def reset(self):
        """Restart iteration from the first item."""
        self._iter_num = 0

    def get_init_step(self):
        """Return the initial global step (always 0 for this stub)."""
        return self._global_step