/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_

#include <algorithm>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include <utility>
#include <map>

#include "minddata/dataset/util/wait_post.h"
#include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"

namespace mindspore {
namespace dataset {
template <typename T>
class Queue;

template <class T>
class Connector;

class JaggedConnector;
class FilenameBlock;

using StringIndex = AutoIndexObj<std::string>;

class NonMappableLeafOp : public ParallelOp {
 public:
  // Constructor of NonMappableLeafOp
  // @param num_workers - number of worker threads reading data from the dataset files.
  // @param worker_connector_size - size of each internal queue.
  // @param total_num_rows - number of rows to read.
  // @param op_connector_size - size of each queue in the connector that the child operator pulls from.
  // @param shuffle_files - whether or not to shuffle the files before reading data.
  // @param num_devices - number of devices (shards) that the data is distributed across.
  // @param device_id - id (shard index) of the current device.
  NonMappableLeafOp(int32_t num_workers, int32_t worker_connector_size, int64_t total_num_rows,
                    int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id);

  // Default destructor
  ~NonMappableLeafOp() = default;

  // Instantiates the internal queues and connectors; implemented by derived classes
  // (see the illustrative derived-class sketch after this class declaration).
  // @return Status - the error code returned.
  virtual Status Init() = 0;

  // Class functor operator () override.
  // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
  // provide the master loop that drives the logic for performing the work.
  // @return Status - the error code returned.
  Status operator()() override;

  // Overrides base class reset method. Cleans up any state info from its previous execution and
  // reinitializes itself so that it can be executed again, as if it was just created.
  // @return Status - the error code returned.
  Status Reset() override;

  // Op name getter
  // @return Name of the current Op
  std::string Name() const override { return "NonMappableLeafOp"; }

 protected:
  // The entry point for when workers are launched.
  // @param worker_id - the id of the worker that is executing this function.
  // @return Status - the error code returned.
  Status WorkerEntry(int32_t worker_id) override;

  // Pushes a control indicator onto the IOBlockQueue for each worker to consume.
  // When the worker pops this control indicator, it will shut itself down gracefully.
  // @return Status - the error code returned.
  Status PostEndOfData();

  // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
  // pops this control indicator, it will wait until the next epoch starts and then resume execution.
  // @return Status - the error code returned.
  Status PostEndOfEpoch(int32_t queue_index);

  // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
  // @return Status - the error code returned.
  Status WaitToFillIOBlockQueue();

  // Notifies the thread which called WaitToFillIOBlockQueue to resume execution.
  void NotifyToFillIOBlockQueue();

  // Pops an element from a queue in IOBlockQueue.
  // @param index - the index of the queue to pop from.
  // @param out_block - the popped element.
  // @return Status - the error code returned.
  Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);

  // Pushes an element to a queue in IOBlockQueue.
  // @param index - the index of the queue to push to.
  // @param io_block - the element to push onto the queue.
  // @return Status - the error code returned.
  Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);

  // Reads a dataset file and loads the data into multiple TensorRows.
  // @param filename - the file to read.
  // @param start_offset - the start offset of the file.
  // @param end_offset - the end offset of the file.
  // @param worker_id - the id of the worker that is executing this function.
  // @return Status - the error code returned.
  virtual Status LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) = 0;

  // Checks whether the file should be pushed to the block queue and, if so, computes its read offsets.
  // @param file_name - the name of the file to check.
  // @param start_offset - output, the offset within the file to start reading from.
  // @param end_offset - output, the offset within the file to stop reading at.
  // @param pre_count - total number of rows in the preceding files.
  // @return bool - true if the file should be pushed to the block queue.
  bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
                                const int64_t &pre_count);

  // Calculate number of rows in each shard.
  // @return Status - the error code returned.
  virtual Status CalculateNumRowsPerShard() = 0;

  // Shuffles the file index keys in place using the given seed (used when shuffle_files_ is true).
  // @param i_keys - the keys to shuffle.
  // @param seed - the random seed for shuffling.
  static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed);

  // Fill the IOBlockQueue.
  // @param i_keys - the keys of the files to be filled into the IOBlockQueue.
  // @return Status - the error code returned.
  virtual Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) = 0;

  int32_t device_id_;                            // id (shard index) of the current device
  int32_t num_devices_;                          // number of devices (shards)
  bool load_jagged_connector_;                   // whether workers should keep pushing rows into the jagged connector
  std::unique_ptr<StringIndex> filename_index_;  // auto-index of the dataset file names

  QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;  // one IO block queue per worker
  std::map<std::string, int64_t> filename_numrows_;            // number of rows in each file
  bool finished_reading_dataset_;                              // true once all requested rows have been read
  int64_t total_rows_;                                         // number of rows to read (0 means read all rows)

  WaitPost io_block_queue_wait_post_;                       // signals when the IO block queues should be refilled
  bool load_io_block_queue_;                                // whether to keep filling the IO block queues
  std::mutex load_io_block_queue_mutex_;                    // protects load_io_block_queue_
  std::unique_ptr<JaggedConnector> jagged_rows_connector_;  // connector that workers push loaded rows into
  bool shuffle_files_;                                      // whether to shuffle the files before reading
  int64_t num_rows_per_shard_;                              // number of rows in each shard
  int64_t num_rows_;                                        // total number of rows across all files
};
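
// Illustrative sketch (hypothetical ExampleLeafOp, not an actual MindSpore operator): it shows the minimum a
// concrete leaf op derived from NonMappableLeafOp has to provide. The base class supplies the master loop
// (operator()), the worker threads, and the IO block queue plumbing; a derived class forwards the constructor
// arguments and implements the four pure virtual hooks. The bodies below are stubs that only indicate where the
// real work would go.
class ExampleLeafOp : public NonMappableLeafOp {
 public:
  ExampleLeafOp(int32_t num_workers, int32_t worker_connector_size, int64_t total_num_rows, int32_t op_connector_size,
                bool shuffle_files, int32_t num_devices, int32_t device_id)
      : NonMappableLeafOp(num_workers, worker_connector_size, total_num_rows, op_connector_size, shuffle_files,
                          num_devices, device_id) {}

  // Would create the worker connector and the per-worker IO block queues.
  Status Init() override { return Status::OK(); }

  // Would parse one file between the given offsets and push each loaded row to jagged_rows_connector_.
  Status LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) override {
    return Status::OK();
  }

  // Would count the rows in every file (filling filename_numrows_) and derive num_rows_per_shard_.
  Status CalculateNumRowsPerShard() override { return Status::OK(); }

  // Would turn the (possibly shuffled) file keys into FilenameBlocks on the IO block queues.
  Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override { return Status::OK(); }
};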
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_