# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This dataset module creates an internal queue class to more optimally pass data
between multiple processes in Python.  It has the same API as multiprocessing.Queue
but it will pass large data through shared memory.
"""

import errno
import multiprocessing
import platform
import queue
import types

import numpy as np

from mindspore import log as logger
import mindspore._c_dataengine as cde
from ..transforms.py_transforms_util import ExceptionHandler
32
33
class _SharedQueue(multiprocessing.queues.Queue):
    """
    Class to implement a queue using shared memory for better performance.

    Args:
        size: Number of elements in the queue.
        count: Shared variable to suppress log printing.
        copy_out: Flag to indicate whether an extra copy should be done before returning.  If data will immediately be
                  copied before returning, then this can be set to False.
        max_rowsize: Maximum size of any element in the Queue in MB. On non-Windows platforms, -1 means
                     shared memory is allocated dynamically per row instead of using fixed segments.
    """

    def __init__(self, size, count, copy_out=False, max_rowsize=6):
        super().__init__(size, ctx=multiprocessing.get_context())

        self.copy_out = copy_out

        # pipe can hold up to 65,636 bytes at a time
        # there is less benefit for small data. For small data it can be slower as we need to pass 100 bytes of
        # metadata and then access the shared memory.
        self.min_shared_mem = 10000
        self.data_immediate = 0
        self.data_shared = 1
        self.count = count
        self.print_error = True

        if platform.system().lower() != 'windows' and max_rowsize == -1:
            # allocate a dedicated shared-memory block per row on demand
            self.dynamic_shm = True
        else:
            self.dynamic_shm = False
            # change max_rowsize in MB into bytes
            self.seg_size = max_rowsize * 1024 * 1024
            self.shm_list = []
            self.seg_pos = 0
            # num_seg has to be 2 more than the queue size.  We can have remote worker filling a buffer, main process
            # reading a buffer and also have a full queue of buffers in the meta-data queue
            self.num_seg = size + 2
            for _ in range(self.num_seg):
                try:
                    a = multiprocessing.Array("b", self.seg_size)
                except OSError as e:
                    if e.errno == errno.ENOMEM:
                        raise RuntimeError("Failed to allocate shared memory for {0} elements of {1}MB: {2}"
                                           .format(self.num_seg, self.seg_size / 1024 / 1024, e))
                    raise
                else:
                    self.shm_list.append(a)

    def put_until(self, data, timeout=None, exit_signal=None):
        """Put data into the queue. Block until timeout is reached or exit_signal is set."""
        while True:
            try:
                self.put(data, timeout=timeout)
                return
            except queue.Full:
                if exit_signal is None:
                    # no interrupt mechanism supplied; propagate with original traceback
                    raise
                if exit_signal.is_set():
                    return
                continue

    def put(self, data, timeout=None):
        """
        Put one row into the queue.

        Exception wrappers are forwarded unchanged. Large numpy arrays are passed through shared
        memory (dynamic blocks or pre-allocated segments); everything else is sent inline through
        the metadata queue.
        """
        if isinstance(data, ExceptionHandler):  # pylint: disable=too-many-nested-blocks
            super().put(data, timeout=timeout)
        else:
            name_list = []
            start_bytes = 0
            if not isinstance(data, tuple):
                data = (data,)
            # data is always a tuple at this point, so every element goes through the loop below
            # (the old `isinstance(data, np.ndarray)` branch here was unreachable and was removed)
            for r in data:
                # the map:pyfunc is a yield generator which can't be serialized
                if isinstance(r, types.GeneratorType):
                    raise TypeError("Cannot pickle {} object, please verify pyfunc return with numpy array"
                                    .format(type(r)))
                if isinstance(r, np.ndarray) and self.dynamic_shm:
                    # allocate a dedicated shared-memory block for this row and send its handle
                    byte = r.nbytes
                    shm = cde.SharedMemory(None, True, -1, byte)
                    dest = np.ndarray(r.shape, r.dtype, buffer=shm.buf())
                    np.copyto(dest, r)
                    fd = shm.fd()
                    df = multiprocessing.reduction.DupFd(fd)
                    name_list.append((self.data_shared, r.dtype, r.shape, shm.name(), df, shm.size()))
                elif (isinstance(r, np.ndarray) and r.size > self.min_shared_mem
                      and start_bytes + r.nbytes < self.seg_size):
                    # need to convert start_bytes to offset in array
                    start_offset = start_bytes
                    dest = np.ndarray(r.shape, r.dtype, buffer=self.shm_list[self.seg_pos].get_obj(),
                                      offset=start_offset)
                    np.copyto(dest, r)
                    byte = r.nbytes
                    # keep each row 8-byte aligned within the segment
                    byte = 8 * ((byte + 7) // 8)
                    start_bytes += byte
                    name_list.append((self.data_shared, self.seg_pos, byte, r.dtype, r.shape))
                else:
                    if isinstance(r, np.ndarray) and r.size > self.min_shared_mem:
                        # Only print out error the first time it happens
                        if self.count.value == 0 and self.print_error:
                            logger.warning(
                                "Using shared memory queue, but rowsize is larger than allocated memory "
                                + "max_rowsize: "
                                + str(self.seg_size / 1024 / 1024)
                                + "MB, current rowsize: "
                                + str((start_bytes + r.nbytes) / 1024 / 1024)
                                + "MB."
                            )
                            self.print_error = False
                            self.count.value += 1
                    name_list.append((self.data_immediate, r))
            super().put(name_list, timeout=timeout)
            # note above could generate a queue full exception.  It will be handled by the caller
            # only increment seg_pos after successfully adding to metadata queue

            if start_bytes > 0:
                self.seg_pos = (self.seg_pos + 1) % self.num_seg

    def get_until(self, timeout=None, exit_signal=None):
        """Get data from the queue. Block until timeout is reached or exit_signal is set."""
        while True:
            try:
                r = self.get(timeout=timeout)
            except queue.Empty:
                if exit_signal is None:
                    raise
                if exit_signal.is_set():
                    return None
                continue
            if r is None:
                # receive finish signal
                return None
            # fix: only drain-and-discard when an exit signal was provided and is set;
            # previously this dereferenced exit_signal unconditionally and crashed on None
            if exit_signal is not None and exit_signal.is_set():
                # loop until the queue becomes empty
                continue
            return r

    def get(self, timeout=None):
        """Get one row from the queue, reconstructing shared-memory entries into numpy arrays."""
        result = super().get(timeout=timeout)
        if isinstance(result, ExceptionHandler):
            return result
        r = []
        start_bytes = 0
        for x in result:
            if x[0] == self.data_shared:
                if self.dynamic_shm:
                    dtype, shape, shm_name, df, buf_size = x[1:]
                    fd = df.detach()
                    shm = cde.SharedMemory(shm_name, False, fd, buf_size)
                    data = np.ndarray(shape, dtype, buffer=shm.buf())
                    # always copy out of a dynamic block so the shared memory can be released
                    dest = np.copy(data)
                    r.append(dest)
                else:
                    seg_pos, byte, dtype, shape = x[1:]
                    start_offset = start_bytes
                    b = self.shm_list[seg_pos]
                    data = np.ndarray(shape, dtype, buffer=b.get_obj(), offset=start_offset)
                    start_bytes += byte
                    if self.copy_out:
                        dest = np.copy(data)
                        r.append(dest)
                    else:
                        # zero-copy view into the shared segment; caller must consume it before
                        # the segment is reused
                        r.append(data)
            elif x[0] == self.data_immediate:
                r.append(x[1])
            else:
                raise RuntimeError("SharedQueue, invalid entry in metadata.")
        return tuple(r)

    def __del__(self):
        # guard against partially constructed instances: __init__ may have raised before
        # dynamic_shm was assigned, in which case there are no segments to release
        if not getattr(self, "dynamic_shm", True):
            shm_list_len = len(self.shm_list)
            for idx in range(shm_list_len):
                del self.shm_list[shm_list_len - idx - 1]
            self.shm_list.clear()
            del self.shm_list

        self.close()
        self.join_thread()
212
213
214class _Queue(multiprocessing.queues.Queue):
215    """Specialized multiprocessing Queue that supports interrupted operations."""
216
217    def __init__(self, size):
218        super().__init__(size, ctx=multiprocessing.get_context())
219
220    def put_until(self, data, timeout=None, exit_signal=None):
221        """Put data into the queue. Block until timeout is reached or exit_signal is set."""
222        while True:
223            try:
224                self.put(data, timeout=timeout)
225                return
226            except queue.Full as e:
227                if exit_signal is None:
228                    raise e
229                if exit_signal.is_set():
230                    return
231                continue
232
233    def get_until(self, timeout=None, exit_signal=None):
234        """Get data from the queue. Block until timeout is reached or exit_signal is set."""
235        while True:
236            try:
237                r = self.get(timeout=timeout)
238            except queue.Empty as e:
239                if exit_signal is None:
240                    raise e
241                if exit_signal.is_set():
242                    return None
243                continue
244            if r is None:
245                # receive finish signal
246                return None
247            if exit_signal.is_set():
248                # loop until the queue becomes empty
249                continue
250            return r
251