1 /** 2 * Copyright 2021-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_ 18 19 #include <algorithm> 20 #include <memory> 21 #include <mutex> 22 #include <string> 23 #include <vector> 24 #include <map> 25 26 #include "minddata/dataset/util/wait_post.h" 27 #include "minddata/dataset/util/auto_index.h" 28 #include "minddata/dataset/util/status.h" 29 #include "minddata/dataset/core/tensor.h" 30 #include "minddata/dataset/engine/datasetops/parallel_op.h" 31 32 namespace mindspore { 33 namespace dataset { 34 template <typename T> 35 class Queue; 36 37 template <class T> 38 class Connector; 39 40 class JaggedConnector; 41 class FilenameBlock; 42 43 using StringIndex = AutoIndexObj<std::string>; 44 45 class NonMappableLeafOp : public ParallelOp<TensorRow, TensorRow> { 46 public: 47 // NONE: No compression_type is used 48 // GZIP: GZIP compression_type with num_samples provided 49 // ZLIB: ZLIB compression_type with num_samples provided 50 // GZIP_WITH_COUNT: GZIP compression_type with num_samples not provided 51 // ZLIB_WITH_COUNT: ZLIB compression_type with num_samples not provided 52 enum class CompressionType { NONE = 0, GZIP = 1, ZLIB = 2, GZIP_WITH_COUNT = 3, ZLIB_WITH_COUNT = 4 }; 53 54 // Constructor of TFReaderOp (2) 55 // @note The builder class should be used to call this constructor. 56 // @param num_workers - number of worker threads reading data from tf_file files. 57 // @param worker_connector_size - size of each internal queue. 58 // @param total_num_rows - Number of rows to read 59 // @param dataset_files_list - list of filepaths for the dataset files. 60 // @param op_connector_size - size of each queue in the connector that the child operator pulls from. 61 // @param columns_to_load - the names of the columns to load data from. 62 // @param shuffle_files - whether or not to shuffle the files before reading data. 63 // @param equal_rows_per_shard - whether or not to get equal rows for each process. 64 // @param compression_type - the compression type of the tf_file files 65 NonMappableLeafOp(int32_t num_workers, int32_t worker_connector_size, int64_t total_num_rows, 66 int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id, 67 const CompressionType &compression_type = CompressionType::NONE); 68 69 // Default destructor 70 ~NonMappableLeafOp() override = default; 71 72 // Instantiates the internal queues and connectors. 73 // @return Status - the error code returned. 74 virtual Status Init() = 0; 75 76 // Class functor operator () override. 77 // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will 78 // provide the master loop that drives the logic for performing the work 79 // @return Status - the error code returned. 80 Status operator()() override; 81 82 // Overrides base class reset method. Cleans up any state info from it's previous execution and 83 // reinitializes itself so that it can be executed again, as if it was just created. 84 // @return Status - the error code returned. 85 Status Reset() override; 86 87 // Op name getter 88 // @return Name of the current Op Name()89 std::string Name() const override { return "NonMappableLeafOp"; } 90 91 // \Common implementation for PrepareOperators and PrepareOperatorPullBased 92 // @return Status The status code returned 93 Status PrepareOperatorImplementation(); 94 95 // \brief During tree prepare phase, operators may have specific post-operations to perform depending on 96 // their role. 97 // \notes Derived versions of this function should always call their superclass version first 98 // before providing their own implementations. 99 // @return Status The status code returned 100 Status PrepareOperator() override; 101 102 // \brief During tree prepare phase, operators may have specific post-operations to perform depending on 103 // their role. This is the implementation for pull mode. 104 // \notes Derived versions of this function should always call its superclass version first 105 // before providing their own implementations. 106 // \return Status The status code returned 107 Status PrepareOperatorPullBased() override; 108 109 /// \brief In pull mode, gets the next row 110 /// \param row[out] - Fetched TensorRow 111 /// \return Status The status code returned 112 Status GetNextRowPullMode(TensorRow *const row) override; 113 114 protected: 115 // The entry point for when workers are launched. 116 // @param worker_id - the id of the worker that is executing this function. 117 // @return Status - the error code returned. 118 Status WorkerEntry(int32_t worker_id) override; 119 120 // Pushes a control indicator onto the IOBlockQueue for each worker to consume. 121 // When the worker pops this control indicator, it will shut itself down gracefully. 122 // @return Status - the error code returned. 123 Status PostEndOfData(); 124 125 // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker 126 // pops this control indicator, it will wait until the next epoch starts and then resume execution. 127 // @return Status - the error code returned. 128 Status PostEndOfEpoch(int32_t queue_index); 129 130 // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. 131 // @return Status - the error code returned. 132 Status WaitToFillIOBlockQueue(); 133 134 // Notifies the thread which called WaitToFillIOBlockQueue to resume execution. 135 void NotifyToFillIOBlockQueue(); 136 137 // Pops an element from a queue in IOBlockQueue. 138 // @param index - the index of the queue to pop from. 139 // @param out_block - the popped element. 140 // @return Status - the error code returned. 141 Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block); 142 143 // Pushes an element to a queue in IOBlockQueue. 144 // @param index - the index of the queue to push to. 145 // @param io_block - the element to push onto the queue. 146 // @return Status - the error code returned. 147 Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block); 148 149 // Reads a tf_file file and loads the data into multiple TensorRows. 150 // @param filename - the tf_file file to read. 151 // @param start_offset - the start offset of file. 152 // @param end_offset - the end offset of file. 153 // @param worker_id - the id of the worker that is executing this function. 154 // @return Status - the error code returned. 155 virtual Status LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) = 0; 156 157 // Select file and push it to the block queue. 158 // @param file_name - File name. 159 // @param start_file - If file contains the first sample of data. 160 // @param end_file - If file contains the end sample of data. 161 // @param pre_count - Total rows of previous files. 162 // @return Status - the error code returned. 163 bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, 164 const int64_t &pre_count); 165 166 // Calculate number of rows in each shard. 167 // @return Status - the error code returned. 168 virtual Status CalculateNumRowsPerShard() = 0; 169 170 void ShuffleKeys(); 171 172 // Fill the IOBlockQueue. 173 // @para i_keys - keys of file to fill to the IOBlockQueue 174 // @return Status - the error code returned. 175 virtual Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) = 0; 176 GetLoadIoBlockQueue()177 virtual bool GetLoadIoBlockQueue() { 178 bool ret_load_io_block_queue = false; 179 { 180 std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_); 181 ret_load_io_block_queue = load_io_block_queue_; 182 } 183 return ret_load_io_block_queue; 184 } 185 GetLoadJaggedConnector()186 virtual bool GetLoadJaggedConnector() { 187 bool ret_load_jagged_connector = false; 188 { 189 std::unique_lock<std::mutex> lock(load_jagged_connector_mutex_); 190 ret_load_jagged_connector = load_jagged_connector_; 191 } 192 return ret_load_jagged_connector; 193 } 194 195 /// \brief Prepare data by reading from disk and caching tensors into the jagged_row_connector queue. 196 /// \return Status The status code returned 197 Status PrepareData(); 198 199 /// \brief Gets the implementation status for operator in pull mode 200 /// \return implementation status PullModeImplementationStatus()201 ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; } 202 203 /// \brief reset the op and update repeat and epoch number if the condition is met. 204 /// \return Status The status code returned 205 Status ResetAndUpdateRepeat(); 206 207 int32_t device_id_; 208 int32_t num_devices_; 209 bool load_jagged_connector_; 210 std::mutex load_jagged_connector_mutex_; 211 std::unique_ptr<StringIndex> filename_index_; 212 213 QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_; 214 std::map<std::string, int64_t> filename_numrows_; 215 bool finished_reading_dataset_; 216 // Note: If compression_type_ is not empty, then total_rows_ is the total rows that will be read per shard 217 int64_t total_rows_; 218 CompressionType compression_type_; 219 220 WaitPost io_block_queue_wait_post_; 221 bool load_io_block_queue_; 222 std::mutex load_io_block_queue_mutex_; 223 std::unique_ptr<JaggedConnector> jagged_rows_connector_; 224 bool shuffle_files_; 225 int64_t num_rows_per_shard_; 226 int64_t num_rows_; 227 bool prepared_data_; // flag to indicate whether the data is prepared before taking for pull mode 228 uint32_t curr_row_; // current row number count for pull mode 229 uint32_t workers_done_; // how many workers have done the tensors reading work for pull mode 230 231 private: 232 std::vector<int64_t> shuffled_keys_; // to store shuffled filename indices 233 uint32_t seed_; // used to shuffle filename indices 234 }; 235 } // namespace dataset 236 } // namespace mindspore 237 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_ 238