• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DATASET_OP_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DATASET_OP_H_
18 
19 #include <memory>
20 #include <mutex>
21 #include <string>
22 #include <unordered_map>
23 #include <vector>
24 #include <utility>
25 
26 #include "minddata/dataset/callback/callback_manager.h"
27 #include "minddata/dataset/include/dataset/constants.h"
28 #include "minddata/dataset/engine/db_connector.h"
29 #include "minddata/dataset/util/status.h"
30 
31 namespace mindspore {
32 namespace dataset {
33 
// Named string constants, one per dataset operator (or operator base class) name.
// Each is a constexpr NUL-terminated char array whose text equals its identifier minus the leading 'k'.
constexpr char kBarrierOp[]{"BarrierOp"};
constexpr char kBatchOp[]{"BatchOp"};
constexpr char kBucketBatchByLengthOp[]{"BucketBatchByLengthOp"};
constexpr char kBuildSentencePieceVocabOp[]{"BuildSentencePieceVocabOp"};
constexpr char kBuildVocabOp[]{"BuildVocabOp"};
constexpr char kCacheBase[]{"CacheBase"};
constexpr char kCacheLookupOp[]{"CacheLookupOp"};
constexpr char kCacheMergeOp[]{"CacheMergeOp"};
constexpr char kCacheOp[]{"CacheOp"};
constexpr char kConcatOp[]{"ConcatOp"};
constexpr char kDatasetOp[]{"DatasetOp"};
constexpr char kDeviceQueueOp[]{"DeviceQueueOp"};
constexpr char kEpochCtrlOp[]{"EpochCtrlOp"};
constexpr char kFilterOp[]{"FilterOp"};
constexpr char kMapOp[]{"MapOp"};
constexpr char kParallelOp[]{"ParallelOp"};
constexpr char kPipelineOp[]{"PipelineOp"};
constexpr char kProjectOp[]{"ProjectOp"};
constexpr char kRenameOp[]{"RenameOp"};
constexpr char kRepeatOp[]{"RepeatOp"};
constexpr char kShuffleOp[]{"ShuffleOp"};
constexpr char kSkipOp[]{"SkipOp"};
constexpr char kTakeOp[]{"TakeOp"};
constexpr char kZipOp[]{"ZipOp"};
58 
// Forward declarations: only the type names are needed at this point, so the
// corresponding definitions are not included by this header.
class ExecutionTree;
class NodePass;
class SamplerRT;
65 
66 // \brief The base class DatasetOp is the main tree node.  It is an abstract class, so
67 // the actual implementation of the operators will be derived from here.
68 class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
69   // Allow execution tree to access internal members
70   friend class ExecutionTree;
71 
72  public:
73   static constexpr int32_t kInvalidOperatorId = -1;
74   static constexpr int32_t kInfiniteRepeat = -1;
75 
76   // Flags that control operator runtime behaviours
77   enum OpState { kDeOpRunning = 0, kDeOpIdle = 1, kDeOpTerminated };
78 
79   // \brief Constructor
80   // \param op_connector_size - The size for the output connector of this operator.
81   // \param sampler - The sampler for the op
82   DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler);
83 
84   // \brief Destructor
~DatasetOp()85   virtual ~DatasetOp() { tree_ = nullptr; }
86 
87   // \brief Adds a operator to become our child.
88   // \param child - shared pointer to the child to add.
89   Status AddChild(std::shared_ptr<DatasetOp> child);
90 
91   // \brief Remove a operator from our children.
92   // \param child - shared pointer to the child to remove.
93   Status RemoveChild(std::shared_ptr<DatasetOp> child);
94 
95   // \brief Removes this node from the tree and connects it's parent/child together
96   // \return Status eerror code returned
97   Status Remove();
98 
99   // Removes child operator in this operator.
100   Status RemoveChildren();
101 
102   // \brief Getter function to get a shared pointer to our child
103   // \param[in] child_index An operator can have n children. Indicates which child to return.
104   // \return The shared pointer to the child.  If there are no children, it returns null regardless of the given index
105   std::shared_ptr<DatasetOp> child(int32_t child_index) const;
106 
107   // \brief Getter function to get the pointer to our parent
108   //     If there are no parents, it returns null regardless of the given index
109   // \param[in] parent_index An operator can have n parents. Indicates which parent to return.
110   void Parent(DatasetOp **parent, int32_t parent_index) const;
111 
112   // Getter function to get all of our parents.
113   std::vector<DatasetOp *> parents() const;
114 
115   // \brief Inserts a operator as the parent current op.
116   // \notes Inserted op will become the sole parent of the current op.
117   //     The existing parent of the current op will be transferred to the inserted op.
118   Status InsertAsParent(std::shared_ptr<DatasetOp> to_add);
119 
120   // \brief Creates the connector within this operator
121   // \param num_producers - number of threads that write into this connector
122   // \param num_consumers - number of threads that read from this connector
123   void CreateConnector(int32_t num_producers, int32_t num_consumers);
124 
125   // \brief A print method typically used for debugging
126   // \param out - The output stream to write output to
127   // \param show_all - A bool to control if you want to show all info or just a summary
128   virtual void Print(std::ostream &out, bool show_all) const;
129 
130   /// \brief Gets the next row
131   /// \param row[out] - Fetched TensorRow
132   /// \return Status The status code returned
133   virtual Status GetNextRowPullMode(TensorRow *const row);
134 
135   /// \brief << Stream output operator overload
136   /// \notes This allows you to write the debug print info using stream operators
137   /// \param out - reference to the output stream being overloaded
138   /// \param dO - reference to the DatasetOp to display
139   /// \return - the output stream must be returned
140   friend std::ostream &operator<<(std::ostream &out, const DatasetOp &dO) {
141     dO.Print(out, false);
142     return out;
143   }
144 
145   // \brief Class functor operator ().
146   // \notes DatasetOps operate by launching a thread (see ExecutionTree).
147   //     This pure virtual version makes the requirement that derived classes must provide a functor
148   //     that will execute their main runtime loop code.
149   // \return Status The status code returned
150   virtual Status operator()() = 0;
151 
152   /// \brief Gets the next row from the given child
153   /// \param row[out] - Fetched TensorRow
154   /// \param worker_id[in] - The worker id, default to 0.
155   /// \return Status The status code returned
156   virtual Status GetNextRow(TensorRow *row, int32_t worker_id = 0) { return GetNextRow(row, worker_id, false); }
157 
158   /// \brief Gets the next row from the given child
159   /// \param row[out] - Fetched TensorRow
160   /// \param worker_id[in] - The worker id, default to 0.
161   /// \param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
162   /// \return Status The status code returned
163   virtual Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe);
164 
165   // \brief Gets the batch size
166   // \return Status - The status code return
167   virtual int64_t GetTreeBatchSize();
168 
169   // \brief Gets the repeat count
170   // \return Status - The status code return
171   virtual int64_t GetTreeRepeatCount();
172 
173   // \brief Gets the number of classes
174   // \return Status - The status code return
175   virtual Status GetNumClasses(int64_t *num_classes);
176 
177   // \brief Gets the class indexing
178   // \return Status - The status code return
179   virtual Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
180 
181   // \brief Performs handling for when an eoe message is received.
182   //     The base class implementation simply flows the eoe message to output. Derived classes
183   //     may override if they need to perform special eoe handling.
184   // \param worker_id - The worker id
185   // \return Status The status code returned
186   virtual Status EoeReceived(int32_t worker_id);
187 
188   // \brief Performs handling for when an eof message is received.
189   //     The base class implementation simply flows the eof message to output. Derived classes
190   //     may override if they need to perform special eof handling.
191   // \param worker_id - The worker id
192   // \return Status The status code returned
193   virtual Status EofReceived(int32_t worker_id);
194 
195   // \brief Derived classes may implement the reset function if the operator is stateful and needs
196   //     specific reset handling that is not contained in this common code version of the reset
197   // \return Status The status code returned
198   virtual Status Reset();
199 
200   // \brief During tree prepare phase, operators may have specific post-operations to perform depending on
201   //     their role.
202   // \notes Derived versions of this function should always call it's superclass version first
203   //     before providing their own implementations.
204   virtual Status PrepareOperator();
205 
206   // \brief Getter function
207   // \return The operator id
id()208   int32_t id() const { return operator_id_; }
209 
210   // \brief Getter function
211   // \return The number of workers in this op
212   virtual int32_t NumWorkers() const = 0;
213 
214   // \brief Getter function
215   // \return The number of threads consuming from previous op.
216   virtual int32_t NumConsumers() const = 0;
217 
218   // \brief Getter function
219   // \return The number of threads producing to the output connector.
220   virtual int32_t NumProducers() const = 0;
221 
222   // \brief Getter function
223   // \return T/F if this is an inlined operator
inlined()224   bool inlined() const { return (oc_queue_size_ == 0); }
225 
226   // \brief Setter function, set the number of total repeats for the operator
SetTotalRepeats(int32_t total_repeats)227   void SetTotalRepeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; }
228 
229   // \brief Setter function, set the number of repeats per epoch for the operator
SetNumRepeatsPerEpoch(int32_t num_repeats_per_epoch)230   void SetNumRepeatsPerEpoch(int32_t num_repeats_per_epoch) { op_num_repeats_per_epoch_ = num_repeats_per_epoch; }
231 
232   // \brief Getter function
233   // \return The number of required repeats for the operator
GetOpTotalRepeats()234   int32_t GetOpTotalRepeats() { return op_total_repeats_; }
235 
236   // \brief Getter function
237   // \return The number of repeats per epoch for the operator
GetOpNumRepeatsPerEpoch()238   int32_t GetOpNumRepeatsPerEpoch() const { return op_num_repeats_per_epoch_; }
239 
240   // \brief Register the internal worker connectors. No op unless it is a parallel op
241   // \return Status
RegisterWorkerConnectors()242   virtual Status RegisterWorkerConnectors() { return Status::OK(); }
243 
244   // \brief Getter for the column name mapping
245   // \return The returned map
column_name_id_map()246   std::unordered_map<std::string, int32_t> column_name_id_map() const { return column_name_id_map_; }
247 
248   // \brief Checks if the column name map has been set up yet for this op
249   // \return - T/F if the operator has the map set up
HasColumnNameMap()250   bool HasColumnNameMap() const { return (column_name_id_map_.empty()); }
251 
252   // \brief gives a string output for the column map for handy debug printing
253   // \return - the column name map as a string
254   std::string ColumnNameMapAsString() const;
255 
256   // \brief Getter function
257   // \return connector size of current op
ConnectorSize()258   int32_t ConnectorSize() const {
259     if (!inlined()) {
260       return out_connector_->size();
261     }
262     // Return child connector size for inlined op
263     return ChildOpConnectorSize();
264   }
265 
266   /// \brief Counting number of rows sent out by a connector
ConnectorOutRowsCount()267   int64_t ConnectorOutRowsCount() const {
268     return out_connector_ == nullptr ? int64_t(-1) : static_cast<int64_t>(out_connector_->out_rows_count());
269   }
270 
271   // \brief Getter function
272   // \return connector size of current op
ConnectorCapacity()273   int32_t ConnectorCapacity() const {
274     if (!inlined()) {
275       return out_connector_->capacity();
276     }
277     // Return child connector capacity for inlined op
278     return ChildOpConnectorCapacity();
279   }
280 
281   // \brief Getter function
282   // \return connector size of child op
283   int32_t ChildOpConnectorSize(int32_t child_index = 0) const { return child_[child_index]->ConnectorSize(); }
284 
285   // \brief Getter function
286   // \return connector capacity of child op
287   int32_t ChildOpConnectorCapacity(int32_t child_index = 0) const { return child_[child_index]->ConnectorCapacity(); }
288 
289   // \brief Children Getter
290   // \return Vector of Children
Children()291   std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; }
292 
293   // Op name getter
294   // \return Name of the current Op
295   virtual std::string Name() const = 0;
296 
297   // Op name and ID getter
298   // \return Name and ID of the current Op
NameWithID()299   std::string NameWithID() const { return Name() + "(ID:" + std::to_string(id()) + ")"; }
300 
301   // Execution Tree getter
302   // \return Pointer to the ExecutionTree the current op belongs to, no ownership
Tree()303   ExecutionTree *Tree() { return tree_; }
304 
305   // Getter for the sampler
306   // \return Shared pointer to the sampler (may return nullptr)
sampler()307   std::shared_ptr<SamplerRT> sampler() { return sampler_; }
308 
309   // \brief Getter for the sampler, and it also removes the sampler from the op
310   // \param[out] sampler A pointer to the output sampler that was removed
311   // \return Status error code
312   Status FetchRemoveSampler(std::shared_ptr<SamplerRT> *sampler);
313 
314 #ifndef ENABLE_ANDROID
315   // Computes a CRC value for the operator
316   static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op);
317 #endif
318 
319   // \brief A helper templated function for casting "this" pointer to shared_ptr<derived>
320   //     Similar to shared_from_this, except this one will give you the derived class as shared_ptr
321   // \return A shared_ptr casted to the derived class
322   template <typename Derived>
shared_from_base()323   std::shared_ptr<Derived> shared_from_base() {
324     return std::static_pointer_cast<Derived>(shared_from_this());
325   }
326 
327   // \brief Setter for the sampler.  Allows you to overwrite a previous sampler with a new one.
SetSampler(std::shared_ptr<SamplerRT> sampler)328   void SetSampler(std::shared_ptr<SamplerRT> sampler) { sampler_ = sampler; }
329 
330   // \brief Checks if this is a leaf node (0 children)
331   // \return boolean returns true if it's a leaf
IsLeaf()332   bool IsLeaf() { return (child_.empty()); }
333 
334   // Checks if an operator has reached its last iteration
335   // \return boolean returns true if it's last iteration
IsLastIteration()336   bool IsLastIteration() { return op_total_repeats_ == op_current_repeats_ + 1; }
337 
338   // This function is only intended to be called by CallbackManager within the master thread of ParallelOp
339   // The expected behavior is this, when this function is invoked, this function will block until all the workers
340   // have finished their remaining work and go to sleep. Since all ParallelOps use a QueueList to sync with master.
341   // They would automatically wait on the QueueList when they are done.
342   // \return Status
WaitForWorkers()343   virtual Status WaitForWorkers() { return Status::OK(); }
344 
345   // \brief Add callback to DatasetOp, only MapOp supports Callback at the moment
AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks)346   void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) { callback_manager_.AddCallbacks(callbacks); }
347 
348   // \brief Remove all callbacks from DatasetOp
ClearCallbacks()349   void ClearCallbacks() { callback_manager_.ClearCallbacks(); }
350 
351  protected:
352   // \brief Removes a parent operator from this operator
353   // \notes External callers do not have access to this function
354   // \param[in] parent The parent node to remove
355   void RemoveParent(const DatasetOp *parent);
356 
357   // \brief Adds a parent operator to this operator
358   // \notes External callers do not have access to this function
359   // \param[in] parent The parent node to add
360   void AddParent(DatasetOp *parent);
361 
362   // Compute the current op's column map using its child's column map.
363   // Get called during the tree post-prepare phase in PrepareOperator.
364   // This base implementation just inherits the map from child 0, and can only be used if the number of children is 1.
365   // Operations changing the column map it inherits from the child must overwrite this function.
366   // \return - Status
367   virtual Status ComputeColMap();
368 
369   // Increase op_current_repeats_ by 1 when one repeat finished.
370   // If this repeat happen to be the last repeat in the current epoch, also increase op_current_epochs_ by 1.
371   void UpdateRepeatAndEpochCounter();
372 
373   std::vector<std::shared_ptr<DatasetOp>> child_;                // Child nodes
374   std::vector<DatasetOp *> parent_;                              // Parent nodes. No ownership
375   std::shared_ptr<SamplerRT> sampler_;                           // Some leaf ops might have a sampler
376   int32_t oc_queue_size_;                                        // Capacity for each out_connector_
377   int32_t operator_id_;                                          // Generated id for the node
378   ExecutionTree *tree_;                                          // Back pointer to our tree.
379   OpState state_;                                                // The state of the operator, Running, Idle, Terminated
380   int32_t op_total_repeats_;                                     // Required number of repeats for the operator
381   int32_t op_num_repeats_per_epoch_;                             // Total number of repeats per epoch for the operator
382   int32_t op_current_repeats_;                                   // Current number of repeats the operator has handled
383   int32_t op_current_epochs_;                                    // Current number of epochs the operator has handled
384   std::unique_ptr<DbConnector> out_connector_;                   // Output Connector
385   std::unordered_map<std::string, int32_t> column_name_id_map_;  // Mapping between col index and col name
386   std::mutex column_name_map_mutex_;                             // For protecting shared access to the column map
387   CallbackManager callback_manager_;                             // Manages callbacks associated with a DatasetOp
388   int64_t dataset_size_;                                         // Size of the dataset
389   int64_t num_classes_;                                          // Number of classes
390 
391  private:
392   // Sets the operator id.
393   // \notes No public interface.  Only the class itself, or it's friend the execution tree can set
394   // this
395   // \param op_id - the Id value to set into the operator
SetId(int32_t op_id)396   void SetId(int32_t op_id) { operator_id_ = op_id; }
397 
398   // Sets the tree into the op so that the operator has a back pointer to the tree.
399   // \param tree - the tree to assign to the op.
set_tree(ExecutionTree * tree)400   void set_tree(ExecutionTree *tree) { tree_ = tree; }
401 };
402 }  // namespace dataset
403 }  // namespace mindspore
404 
405 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DATASET_OP_H_
406