# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This dataset module creates an internal queue class to pass data between
multiple Python processes more efficiently. It has the same API as
multiprocessing.Queue, but passes large data through shared memory.
"""

import multiprocessing.queues
import multiprocessing
import types
import numpy as np

from mindspore import log as logger
from ..transforms.py_transforms_util import ExceptionHandler


class _SharedQueue(multiprocessing.queues.Queue):
    """
    Class to implement a queue using shared memory for better performance.

    Args:
        size: Number of elements in the queue.
        copy_out: Whether to copy data out of shared memory before returning it from ``get``. If the caller
            copies the data immediately after ``get`` returns, this can be left as False.
        max_rowsize: Maximum size of any element in the queue, in MB.
    """

    def __init__(self, size, copy_out=False, max_rowsize=6):
        super().__init__(size, ctx=multiprocessing.get_context())

        self.copy_out = copy_out

        # convert max_rowsize from MB to bytes
        self.seg_size = max_rowsize * 1024 * 1024
        # a pipe can hold up to 65,536 bytes at a time
        self.min_shared_mem = 10000
        self.shm_list = []
        self.seg_pos = 0
        # num_seg has to be 2 more than the queue size: a remote worker can be filling one buffer and the main
        # process reading another while the metadata queue still holds a full set of buffers
        self.num_seg = size + 2
        self.data_immediate = 0
        self.data_shared = 1
        self.print_error = True

        try:
            for _ in range(self.num_seg):
                a = multiprocessing.Array("b", self.seg_size)
                self.shm_list.append(a)
        except Exception:
            raise RuntimeError(
                "_SharedQueue: Error allocating "
                + str(self.seg_size)
                + " bytes, "
                + str(self.num_seg)
                + " elements."
                + " This might be caused by insufficient shm, and the recommended shm size is at least 5 GB."
            )

    def put(self, data, timeout=None):
        if isinstance(data, ExceptionHandler):
            super().put(data, timeout=timeout)
        else:
            name_list = []
            count = 0
            start_bytes = 0
            if not isinstance(data, tuple) and not isinstance(data, np.ndarray):
                raise TypeError("return value of user defined python function in GeneratorDataset or"
                                " map should be numpy array or tuple of numpy array.")
            for r in data:
                # a pyfunc that yields returns a generator object, which cannot be serialized
                if isinstance(r, types.GeneratorType):
                    raise TypeError("Can not pickle {} object, please verify pyfunc return with numpy array"
                                    .format(type(r)))
                if (isinstance(r, np.ndarray) and r.size > self.min_shared_mem
                        and start_bytes + r.nbytes < self.seg_size):
                    # need to convert start_bytes to offset in array
                    start_offset = start_bytes
                    dest = np.ndarray(r.shape, r.dtype, buffer=self.shm_list[self.seg_pos].get_obj(),
                                      offset=start_offset)
                    np.copyto(dest, r)
                    byte = r.nbytes
                    # round up to a multiple of 8 bytes so the next array starts on an aligned offset
                    byte = 8 * ((byte + 7) // 8)
                    start_bytes += byte
                    name_list.append((self.data_shared, self.seg_pos, byte, r.dtype, r.shape))
                    count += 1
                else:
                    if isinstance(r, np.ndarray) and r.size >= self.min_shared_mem:
                        # only log the warning the first time it happens
                        if self.print_error:
                            logger.warning(
                                "Using shared memory queue, but rowsize is larger than allocated memory "
                                + "max_rowsize: "
                                + str(self.seg_size)
                                + " bytes, current rowsize: "
                                + str(start_bytes + r.nbytes)
                                + " bytes."
                            )
                            self.print_error = False
                    name_list.append((self.data_immediate, r))
            super().put(name_list, timeout=timeout)
            # note: the put above could raise a queue-full exception; it will be handled by the caller.
            # only increment seg_pos after successfully adding to the metadata queue

            if start_bytes > 0:
                self.seg_pos = (self.seg_pos + 1) % self.num_seg

    def get(self, timeout=None):
        result = super().get(timeout=timeout)
        if isinstance(result, ExceptionHandler):
            return result
        r = []
        start_bytes = 0
        for x in result:
            if x[0] == self.data_shared:
                # rebuild the array as a view over the shared memory segment
                seg_pos = x[1]
                byte = x[2]
                dtype = x[3]
                shape = x[4]
                start_offset = start_bytes
                b = self.shm_list[seg_pos]
                data = np.ndarray(shape, dtype, buffer=b.get_obj(), offset=start_offset)
                start_bytes += byte
                if self.copy_out:
                    data2 = np.copy(data)
                    r.append(data2)
                else:
                    # without copy_out the array aliases the segment and is only valid until it is reused
                    r.append(data)
            elif x[0] == self.data_immediate:
                r.append(x[1])
            else:
                raise RuntimeError("SharedQueue, invalid entry in metadata.")
        return tuple(r)
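

# ------------------------------------------------------------------------------
# Minimal round-trip sketch (a usage illustration, not part of this module's
# API): a worker process puts a tuple of numpy arrays on a _SharedQueue; the
# large array travels through a shared-memory segment while the small one is
# pickled through the metadata queue. Assumptions: the module is run with
# `python -m` so the package-relative imports above resolve, and the helper
# names _demo_worker / _demo_round_trip are illustrative only.
def _demo_worker(queue):
    # 1024*1024 float32 elements: the size exceeds min_shared_mem and its 4 MB
    # payload fits in the default 6 MB segment, so it goes to shared memory;
    # np.arange(4) is below the threshold and is sent inline
    queue.put((np.ones((1024, 1024), dtype=np.float32), np.arange(4)))


def _demo_round_trip():
    queue = _SharedQueue(size=2, copy_out=True, max_rowsize=6)
    worker = multiprocessing.Process(target=_demo_worker, args=(queue,))
    worker.start()
    big, small = queue.get(timeout=10)
    worker.join()
    assert big.shape == (1024, 1024) and small.size == 4


if __name__ == "__main__":
    _demo_round_trip()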