1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_ 18 19 #include <algorithm> 20 #include <memory> 21 #include <mutex> 22 #include <string> 23 #include <vector> 24 #include <utility> 25 #include <map> 26 27 #include "minddata/dataset/util/wait_post.h" 28 #include "minddata/dataset/util/auto_index.h" 29 #include "minddata/dataset/util/status.h" 30 #include "minddata/dataset/core/tensor.h" 31 #include "minddata/dataset/engine/datasetops/parallel_op.h" 32 33 namespace mindspore { 34 namespace dataset { 35 template <typename T> 36 class Queue; 37 38 template <class T> 39 class Connector; 40 41 class JaggedConnector; 42 class FilenameBlock; 43 44 using StringIndex = AutoIndexObj<std::string>; 45 46 class NonMappableLeafOp : public ParallelOp { 47 public: 48 // Constructor of TFReaderOp (2) 49 // @note The builder class should be used to call this constructor. 50 // @param num_workers - number of worker threads reading data from tf_file files. 51 // @param worker_connector_size - size of each internal queue. 52 // @param total_num_rows - Number of rows to read 53 // @param dataset_files_list - list of filepaths for the dataset files. 54 // @param op_connector_size - size of each queue in the connector that the child operator pulls from. 55 // @param columns_to_load - the names of the columns to load data from. 56 // @param shuffle_files - whether or not to shuffle the files before reading data. 57 // @param equal_rows_per_shard - whether or not to get equal rows for each process. 58 NonMappableLeafOp(int32_t num_workers, int32_t worker_connector_size, int64_t total_num_rows, 59 int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id); 60 61 // Default destructor 62 ~NonMappableLeafOp() = default; 63 64 // Instantiates the internal queues and connectors. 65 // @return Status - the error code returned. 66 virtual Status Init() = 0; 67 68 // Class functor operator () override. 69 // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will 70 // provide the master loop that drives the logic for performing the work 71 // @return Status - the error code returned. 72 Status operator()() override; 73 74 // Overrides base class reset method. Cleans up any state info from it's previous execution and 75 // reinitializes itself so that it can be executed again, as if it was just created. 76 // @return Status - the error code returned. 77 Status Reset() override; 78 79 // Op name getter 80 // @return Name of the current Op Name()81 std::string Name() const override { return "NonMappableLeafOp"; } 82 83 protected: 84 // The entry point for when workers are launched. 85 // @param worker_id - the id of the worker that is executing this function. 86 // @return Status - the error code returned. 87 Status WorkerEntry(int32_t worker_id) override; 88 89 // Pushes a control indicator onto the IOBlockQueue for each worker to consume. 90 // When the worker pops this control indicator, it will shut itself down gracefully. 91 // @return Status - the error code returned. 92 Status PostEndOfData(); 93 94 // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker 95 // pops this control indicator, it will wait until the next epoch starts and then resume execution. 96 // @return Status - the error code returned. 97 Status PostEndOfEpoch(int32_t queue_index); 98 99 // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. 100 // @return Status - the error code returned. 101 Status WaitToFillIOBlockQueue(); 102 103 // Notifies the thread which called WaitToFillIOBlockQueue to resume execution. 104 void NotifyToFillIOBlockQueue(); 105 106 // Pops an element from a queue in IOBlockQueue. 107 // @param index - the index of the queue to pop from. 108 // @param out_block - the popped element. 109 // @return Status - the error code returned. 110 Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block); 111 112 // Pushes an element to a queue in IOBlockQueue. 113 // @param index - the index of the queue to push to. 114 // @param io_block - the element to push onto the queue. 115 // @return Status - the error code returned. 116 Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block); 117 118 // Reads a tf_file file and loads the data into multiple TensorRows. 119 // @param filename - the tf_file file to read. 120 // @param start_offset - the start offset of file. 121 // @param end_offset - the end offset of file. 122 // @param worker_id - the id of the worker that is executing this function. 123 // @return Status - the error code returned. 124 virtual Status LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) = 0; 125 126 // Select file and push it to the block queue. 127 // @param file_name - File name. 128 // @param start_file - If file contains the first sample of data. 129 // @param end_file - If file contains the end sample of data. 130 // @param pre_count - Total rows of previous files. 131 // @return Status - the error code returned. 132 bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, 133 const int64_t &pre_count); 134 135 // Calculate number of rows in each shard. 136 // @return Status - the error code returned. 137 virtual Status CalculateNumRowsPerShard() = 0; 138 139 static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed); 140 141 // Fill the IOBlockQueue. 142 // @para i_keys - keys of file to fill to the IOBlockQueue 143 // @return Status - the error code returned. 144 virtual Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) = 0; 145 146 int32_t device_id_; 147 int32_t num_devices_; 148 bool load_jagged_connector_; 149 std::unique_ptr<StringIndex> filename_index_; 150 151 QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_; 152 std::map<std::string, int64_t> filename_numrows_; 153 bool finished_reading_dataset_; 154 int64_t total_rows_; 155 156 WaitPost io_block_queue_wait_post_; 157 bool load_io_block_queue_; 158 std::mutex load_io_block_queue_mutex_; 159 std::unique_ptr<JaggedConnector> jagged_rows_connector_; 160 bool shuffle_files_; 161 int64_t num_rows_per_shard_; 162 int64_t num_rows_; 163 }; 164 } // namespace dataset 165 } // namespace mindspore 166 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_ 167