• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_
18 
19 #include <algorithm>
20 #include <memory>
21 #include <mutex>
22 #include <string>
23 #include <vector>
24 #include <map>
25 
26 #include "minddata/dataset/util/wait_post.h"
27 #include "minddata/dataset/util/auto_index.h"
28 #include "minddata/dataset/util/status.h"
29 #include "minddata/dataset/core/tensor.h"
30 #include "minddata/dataset/engine/datasetops/parallel_op.h"
31 
32 namespace mindspore {
33 namespace dataset {
34 template <typename T>
35 class Queue;
36 
37 template <class T>
38 class Connector;
39 
40 class JaggedConnector;
41 class FilenameBlock;
42 
43 using StringIndex = AutoIndexObj<std::string>;
44 
45 class NonMappableLeafOp : public ParallelOp<TensorRow, TensorRow> {
46  public:
47   // NONE: No compression_type is used
48   // GZIP: GZIP compression_type with num_samples provided
49   // ZLIB: ZLIB compression_type with num_samples provided
50   // GZIP_WITH_COUNT: GZIP compression_type with num_samples not provided
51   // ZLIB_WITH_COUNT: ZLIB compression_type with num_samples not provided
52   enum class CompressionType { NONE = 0, GZIP = 1, ZLIB = 2, GZIP_WITH_COUNT = 3, ZLIB_WITH_COUNT = 4 };
53 
54   // Constructor of TFReaderOp (2)
55   // @note The builder class should be used to call this constructor.
56   // @param num_workers - number of worker threads reading data from tf_file files.
57   // @param worker_connector_size - size of each internal queue.
58   // @param total_num_rows - Number of rows to read
59   // @param dataset_files_list - list of filepaths for the dataset files.
60   // @param op_connector_size - size of each queue in the connector that the child operator pulls from.
61   // @param columns_to_load - the names of the columns to load data from.
62   // @param shuffle_files - whether or not to shuffle the files before reading data.
63   // @param equal_rows_per_shard - whether or not to get equal rows for each process.
64   // @param compression_type - the compression type of the tf_file files
65   NonMappableLeafOp(int32_t num_workers, int32_t worker_connector_size, int64_t total_num_rows,
66                     int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id,
67                     const CompressionType &compression_type = CompressionType::NONE);
68 
69   // Default destructor
70   ~NonMappableLeafOp() override = default;
71 
72   // Instantiates the internal queues and connectors.
73   // @return Status - the error code returned.
74   virtual Status Init() = 0;
75 
76   // Class functor operator () override.
77   // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
78   // provide the master loop that drives the logic for performing the work
79   // @return Status - the error code returned.
80   Status operator()() override;
81 
82   // Overrides base class reset method. Cleans up any state info from it's previous execution and
83   // reinitializes itself so that it can be executed again, as if it was just created.
84   // @return Status - the error code returned.
85   Status Reset() override;
86 
87   // Op name getter
88   // @return Name of the current Op
Name()89   std::string Name() const override { return "NonMappableLeafOp"; }
90 
91   // \Common implementation for PrepareOperators and PrepareOperatorPullBased
92   // @return Status The status code returned
93   Status PrepareOperatorImplementation();
94 
95   // \brief During tree prepare phase, operators may have specific post-operations to perform depending on
96   //     their role.
97   // \notes Derived versions of this function should always call their superclass version first
98   //     before providing their own implementations.
99   // @return Status The status code returned
100   Status PrepareOperator() override;
101 
102   // \brief During tree prepare phase, operators may have specific post-operations to perform depending on
103   //     their role. This is the implementation for pull mode.
104   // \notes Derived versions of this function should always call its superclass version first
105   //     before providing their own implementations.
106   // \return Status The status code returned
107   Status PrepareOperatorPullBased() override;
108 
109   /// \brief In pull mode, gets the next row
110   /// \param row[out] - Fetched TensorRow
111   /// \return Status The status code returned
112   Status GetNextRowPullMode(TensorRow *const row) override;
113 
114  protected:
115   // The entry point for when workers are launched.
116   // @param worker_id - the id of the worker that is executing this function.
117   // @return Status - the error code returned.
118   Status WorkerEntry(int32_t worker_id) override;
119 
120   // Pushes a control indicator onto the IOBlockQueue for each worker to consume.
121   // When the worker pops this control indicator, it will shut itself down gracefully.
122   // @return Status - the error code returned.
123   Status PostEndOfData();
124 
125   // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
126   // pops this control indicator, it will wait until the next epoch starts and then resume execution.
127   // @return Status - the error code returned.
128   Status PostEndOfEpoch(int32_t queue_index);
129 
130   // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
131   // @return Status - the error code returned.
132   Status WaitToFillIOBlockQueue();
133 
134   // Notifies the thread which called WaitToFillIOBlockQueue to resume execution.
135   void NotifyToFillIOBlockQueue();
136 
137   // Pops an element from a queue in IOBlockQueue.
138   // @param index - the index of the queue to pop from.
139   // @param out_block - the popped element.
140   // @return Status - the error code returned.
141   Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
142 
143   // Pushes an element to a queue in IOBlockQueue.
144   // @param index - the index of the queue to push to.
145   // @param io_block - the element to push onto the queue.
146   // @return Status - the error code returned.
147   Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
148 
149   // Reads a tf_file file and loads the data into multiple TensorRows.
150   // @param filename - the tf_file file to read.
151   // @param start_offset - the start offset of file.
152   // @param end_offset - the end offset of file.
153   // @param worker_id - the id of the worker that is executing this function.
154   // @return Status - the error code returned.
155   virtual Status LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) = 0;
156 
157   // Select file and push it to the block queue.
158   // @param file_name - File name.
159   // @param start_file - If file contains the first sample of data.
160   // @param end_file - If file contains the end sample of data.
161   // @param pre_count - Total rows of previous files.
162   // @return Status - the error code returned.
163   bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
164                                 const int64_t &pre_count);
165 
166   // Calculate number of rows in each shard.
167   // @return Status - the error code returned.
168   virtual Status CalculateNumRowsPerShard() = 0;
169 
170   void ShuffleKeys();
171 
172   // Fill the IOBlockQueue.
173   // @para i_keys - keys of file to fill to the IOBlockQueue
174   // @return Status - the error code returned.
175   virtual Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) = 0;
176 
GetLoadIoBlockQueue()177   virtual bool GetLoadIoBlockQueue() {
178     bool ret_load_io_block_queue = false;
179     {
180       std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
181       ret_load_io_block_queue = load_io_block_queue_;
182     }
183     return ret_load_io_block_queue;
184   }
185 
GetLoadJaggedConnector()186   virtual bool GetLoadJaggedConnector() {
187     bool ret_load_jagged_connector = false;
188     {
189       std::unique_lock<std::mutex> lock(load_jagged_connector_mutex_);
190       ret_load_jagged_connector = load_jagged_connector_;
191     }
192     return ret_load_jagged_connector;
193   }
194 
195   /// \brief Prepare data by reading from disk and caching tensors into the jagged_row_connector queue.
196   /// \return Status The status code returned
197   Status PrepareData();
198 
199   /// \brief Gets the implementation status for operator in pull mode
200   /// \return implementation status
PullModeImplementationStatus()201   ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; }
202 
203   /// \brief reset the op and update repeat and epoch number if the condition is met.
204   /// \return Status The status code returned
205   Status ResetAndUpdateRepeat();
206 
207   int32_t device_id_;
208   int32_t num_devices_;
209   bool load_jagged_connector_;
210   std::mutex load_jagged_connector_mutex_;
211   std::unique_ptr<StringIndex> filename_index_;
212 
213   QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
214   std::map<std::string, int64_t> filename_numrows_;
215   bool finished_reading_dataset_;
216   // Note: If compression_type_ is not empty, then total_rows_ is the total rows that will be read per shard
217   int64_t total_rows_;
218   CompressionType compression_type_;
219 
220   WaitPost io_block_queue_wait_post_;
221   bool load_io_block_queue_;
222   std::mutex load_io_block_queue_mutex_;
223   std::unique_ptr<JaggedConnector> jagged_rows_connector_;
224   bool shuffle_files_;
225   int64_t num_rows_per_shard_;
226   int64_t num_rows_;
227   bool prepared_data_;     // flag to indicate whether the data is prepared before taking for pull mode
228   uint32_t curr_row_;      // current row number count for pull mode
229   uint32_t workers_done_;  // how many workers have done the tensors reading work for pull mode
230 
231  private:
232   std::vector<int64_t> shuffled_keys_;  // to store shuffled filename indices
233   uint32_t seed_;                       // used to shuffle filename indices
234 };
235 }  // namespace dataset
236 }  // namespace mindspore
237 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_
238