# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This dataset module creates an internal queue class to pass data between
multiple Python processes more efficiently. It has the same API as
multiprocessing.Queue, but it passes large data through shared memory.
"""

import multiprocessing.queues
import multiprocessing
import types

import numpy as np

from mindspore import log as logger
from ..transforms.py_transforms_util import ExceptionHandler


class _SharedQueue(multiprocessing.queues.Queue):
    """
    Class to implement a queue using shared memory for better performance.

    Args:
        size: Number of elements in the queue.
        copy_out: Whether to make an extra copy of shared-memory data before returning it.
            If the caller copies the data immediately after `get` returns, this can be
            set to False.
        max_rowsize: Maximum size of any element in the queue, in MB.
    """

    def __init__(self, size, copy_out=False, max_rowsize=6):
        super().__init__(size, ctx=multiprocessing.get_context())

        self.copy_out = copy_out

        # change max_rowsize from MB into bytes
        self.seg_size = max_rowsize * 1024 * 1024
        # a pipe can hold up to 65,536 bytes at a time
        self.min_shared_mem = 10000
        self.shm_list = []
        self.seg_pos = 0
        # num_seg has to be 2 more than the queue size: a remote worker may be filling
        # one buffer and the main process reading another while the metadata queue
        # still holds a full queue of buffers.
        self.num_seg = size + 2
        self.data_immediate = 0
        self.data_shared = 1
        self.print_error = True

        try:
            for _ in range(self.num_seg):
                a = multiprocessing.Array("b", self.seg_size)
                self.shm_list.append(a)
        except Exception:
            raise RuntimeError(
                "_SharedQueue: Error allocating " + str(self.seg_size) + " bytes, "
                + str(self.num_seg) + " elements."
                + " This might be caused by insufficient shm, and the recommended shm size is at least 5 GB."
            )
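    # Each item placed on the underlying metadata queue is a list of entries.
    # A shared-memory entry has the form
    #     (data_shared, seg_pos, padded_nbytes, dtype, shape)
    # and is reconstructed from the shared segment on get(); small or
    # non-ndarray values travel inline as (data_immediate, value).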
    def put(self, data, timeout=None):
        if isinstance(data, ExceptionHandler):
            super().put(data, timeout=timeout)
        else:
            name_list = []
            count = 0
            start_bytes = 0
            if not isinstance(data, tuple) and not isinstance(data, np.ndarray):
                raise TypeError("return value of user defined python function in GeneratorDataset or"
                                " map should be numpy array or tuple of numpy array.")
            for r in data:
                # the map:pyfunc is a yield generator which can't be serialized
                if isinstance(r, types.GeneratorType):
                    raise TypeError("Can not pickle {} object, please verify pyfunc return with numpy array"
                                    .format(type(r)))
                if (isinstance(r, np.ndarray) and r.size > self.min_shared_mem
                        and start_bytes + r.nbytes < self.seg_size):
                    # need to convert start_bytes to an offset into the shared array
                    start_offset = start_bytes
                    dest = np.ndarray(r.shape, r.dtype, buffer=self.shm_list[self.seg_pos].get_obj(),
                                      offset=start_offset)
                    np.copyto(dest, r)
                    byte = r.nbytes
                    # pad each entry to an 8-byte boundary
                    byte = 8 * ((byte + 7) // 8)
                    start_bytes += byte
                    name_list.append((self.data_shared, self.seg_pos, byte, r.dtype, r.shape))
                    count += 1
                else:
                    if isinstance(r, np.ndarray) and r.size >= self.min_shared_mem:
                        # only print the warning the first time it happens
                        if self.print_error:
                            logger.warning(
                                "Using shared memory queue, but rowsize is larger than allocated memory "
                                + "max_rowsize " + str(self.seg_size)
                                + " current rowsize " + str(start_bytes + r.nbytes))
                            self.print_error = False
                    name_list.append((self.data_immediate, r))
            super().put(name_list, timeout=timeout)
            # note: the call above could raise a queue-full exception; it will be handled by the caller
            # only increment seg_pos after successfully adding to the metadata queue
            if start_bytes > 0:
                self.seg_pos = (self.seg_pos + 1) % self.num_seg

    def get(self, timeout=None):
        result = super().get(timeout=timeout)
        if isinstance(result, ExceptionHandler):
            return result
        r = []
        start_bytes = 0
        for x in result:
            if x[0] == self.data_shared:
                seg_pos = x[1]
                byte = x[2]
                dtype = x[3]
                shape = x[4]
                start_offset = start_bytes
                b = self.shm_list[seg_pos]
                data = np.ndarray(shape, dtype, buffer=b.get_obj(), offset=start_offset)
                start_bytes += byte
                if self.copy_out:
                    data2 = np.copy(data)
                    r.append(data2)
                else:
                    r.append(data)
            elif x[0] == self.data_immediate:
                r.append(x[1])
            else:
                raise RuntimeError("SharedQueue, invalid entry in metadata.")
        return tuple(r)
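

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's public
# API). It shows that _SharedQueue mirrors the multiprocessing.Queue put/get
# interface while moving large ndarrays through shared memory; the worker
# function name `_produce` is hypothetical.
#
#   import multiprocessing
#   import numpy as np
#
#   def _produce(queue):
#       # large enough (size > min_shared_mem) to take the shared-memory path
#       queue.put((np.ones((512, 512), dtype=np.float32),))
#
#   if __name__ == "__main__":
#       q = _SharedQueue(size=4, copy_out=True, max_rowsize=16)
#       worker = multiprocessing.Process(target=_produce, args=(q,))
#       worker.start()
#       (arr,) = q.get(timeout=10)  # copy_out=True returns a private copy
#       worker.join()
# ---------------------------------------------------------------------------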