# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""
16This dataset module creates an internal queue class to more optimally pass data
17between multiple processes in Python.  It has same API as multiprocessing.queue
18but it will pass large data through shared memory.
19"""
20
21import multiprocessing.queues
22import multiprocessing
23import types
24import numpy as np
25
26from mindspore import log as logger
27from ..transforms.py_transforms_util import ExceptionHandler
28
29
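
# A minimal usage sketch (illustrative only; _SharedQueue is internal to the
# dataset pipeline, and the sizes below are made-up example values):
#
#     queue = _SharedQueue(size=4, copy_out=True, max_rowsize=6)
#     queue.put((np.ones((1024, 1024), dtype=np.float32),))
#     rows = queue.get()    # tuple of numpy arrays copied out of shared memory
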
class _SharedQueue(multiprocessing.queues.Queue):
    """
    Class to implement a queue using shared memory for better performance.

    Args:
        size: Number of elements in the queue.
        copy_out: Whether to make an extra copy of the data out of shared memory before
            returning it. If the caller copies the data immediately anyway, this can be
            left as False.
        max_rowsize: Maximum size, in MB, of any element in the queue.
    """

    def __init__(self, size, copy_out=False, max_rowsize=6):
        super().__init__(size, ctx=multiprocessing.get_context())

        self.copy_out = copy_out

        # convert max_rowsize from MB to bytes
        self.seg_size = max_rowsize * 1024 * 1024
        # a pipe can hold up to 65,536 bytes at a time
        self.min_shared_mem = 10000
        self.shm_list = []
        self.seg_pos = 0
        # num_seg has to be 2 more than the queue size: a remote worker can be filling one buffer and
        # the main process reading another while the metadata queue still holds a full queue of buffers
        self.num_seg = size + 2
        self.data_immediate = 0
        self.data_shared = 1
        self.print_error = True

        try:
            for _ in range(self.num_seg):
                a = multiprocessing.Array("b", self.seg_size)
                self.shm_list.append(a)
        except Exception:
            raise RuntimeError(
                "_SharedQueue: Error allocating "
                + str(self.seg_size)
                + " bytes, "
                + str(self.num_seg)
                + " elements."
                + " This might be caused by insufficient shm, and the recommended shm size is at least 5 GB."
            )

    def put(self, data, timeout=None):
        if isinstance(data, ExceptionHandler):
            super().put(data, timeout=timeout)
        else:
            name_list = []
            count = 0
            start_bytes = 0
            if not isinstance(data, (tuple, np.ndarray)):
                raise TypeError("The return value of a user-defined python function in GeneratorDataset or"
                                " map should be a numpy array or a tuple of numpy arrays.")
            for r in data:
                # a pyfunc that yields is a generator, which cannot be pickled
                if isinstance(r, types.GeneratorType):
                    raise TypeError("Cannot pickle {} object, please verify that the pyfunc returns numpy arrays"
                                    .format(type(r)))
                if (isinstance(r, np.ndarray) and r.size > self.min_shared_mem
                        and start_bytes + r.nbytes < self.seg_size):
                    # copy the array into the current shared-memory segment, starting at the current offset
                    start_offset = start_bytes
                    dest = np.ndarray(r.shape, r.dtype, buffer=self.shm_list[self.seg_pos].get_obj(),
                                      offset=start_offset)
                    np.copyto(dest, r)
                    # pad the length to a multiple of 8 bytes (e.g. 13 -> 16) to keep offsets aligned
                    byte = r.nbytes
                    byte = 8 * ((byte + 7) // 8)
                    start_bytes += byte
                    # metadata entry: (tag, segment index, padded byte length, dtype, shape)
                    name_list.append((self.data_shared, self.seg_pos, byte, r.dtype, r.shape))
                    count += 1
                else:
                    if isinstance(r, np.ndarray) and r.size >= self.min_shared_mem:
                        # only log this warning the first time it happens
                        if self.print_error:
                            logger.warning(
                                "Using shared memory queue, but rowsize is larger than allocated memory: "
                                + "max_rowsize: "
                                + str(self.seg_size)
                                + " bytes, current rowsize: "
                                + str(start_bytes + r.nbytes)
                                + " bytes."
                            )
                            self.print_error = False
                    name_list.append((self.data_immediate, r))
            super().put(name_list, timeout=timeout)
            # note: the put above could raise a queue-full exception; it will be handled by the caller.
            # only advance seg_pos after the metadata has been successfully queued

            if start_bytes > 0:
                self.seg_pos = (self.seg_pos + 1) % self.num_seg

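    # For example, putting (np.zeros(50000), "label") enqueues metadata like
    #     [(data_shared, seg_idx, padded_nbytes, dtype, shape), (data_immediate, "label")]
    # where the tags are the integers data_shared=1 / data_immediate=0 (illustrative values).
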
    def get(self, timeout=None):
        result = super().get(timeout=timeout)
        if isinstance(result, ExceptionHandler):
            return result
        r = []
        start_bytes = 0
        for x in result:
            if x[0] == self.data_shared:
                # rebuild a numpy view over the shared-memory segment described by the metadata
                seg_pos = x[1]
                byte = x[2]
                dtype = x[3]
                shape = x[4]
                start_offset = start_bytes
                b = self.shm_list[seg_pos]
                data = np.ndarray(shape, dtype, buffer=b.get_obj(), offset=start_offset)
                start_bytes += byte
                if self.copy_out:
                    # copy out of shared memory so the segment can be reused safely
                    data2 = np.copy(data)
                    r.append(data2)
                else:
                    r.append(data)
            elif x[0] == self.data_immediate:
                r.append(x[1])
            else:
                raise RuntimeError("_SharedQueue: invalid entry in metadata queue.")
        return tuple(r)
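

# A minimal smoke-test sketch (not part of the original module; illustrative only).
# It runs only when the module is executed directly within its package, e.g. via
# `python -m mindspore.dataset.engine.queue` -- the exact module path is an assumption,
# and the relative import above must resolve for this to work.
if __name__ == "__main__":
    q = _SharedQueue(size=2, copy_out=True, max_rowsize=1)
    src = np.arange(100000, dtype=np.int64)    # 800,000 bytes, large enough for the shared-memory path
    q.put((src,))
    out = q.get()
    assert np.array_equal(out[0], src)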