• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DATASET_OP_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DATASET_OP_H_
18 
19 #include <memory>
20 #include <mutex>
21 #include <string>
22 #include <unordered_map>
23 #include <vector>
24 #include <utility>
25 
26 #include "minddata/dataset/callback/callback_manager.h"
27 #include "minddata/dataset/include/dataset/constants.h"
28 #include "minddata/dataset/engine/operator_connector.h"
29 #include "minddata/dataset/engine/perf/info_collector.h"
30 #include "minddata/dataset/util/status.h"
31 
32 namespace mindspore {
33 namespace dataset {
34 
// Canonical operator names, as reported by DatasetOp::Name() implementations.
// Listed alphabetically by constant name; every value is an exact string literal.

// -- Core / structural operators --
constexpr char kBarrierOp[] = "BarrierOp";
constexpr char kBatchOp[] = "BatchOp";
constexpr char kBucketBatchByLengthOp[] = "BucketBatchByLengthOp";
constexpr char kBuildSentencePieceVocabOp[] = "BuildSentencePieceVocabOp";
constexpr char kBuildVocabOp[] = "BuildVocabOp";

// -- Cache-related operators --
constexpr char kCacheBase[] = "CacheBase";
constexpr char kCacheLookupOp[] = "CacheLookupOp";
constexpr char kCacheMergeOp[] = "CacheMergeOp";
constexpr char kCacheOp[] = "CacheOp";

constexpr char kConcatOp[] = "ConcatOp";
constexpr char kDatasetOp[] = "DatasetOp";
// Note: the constant keeps its historical name, but its value is "DataQueueOp".
constexpr char kDeviceQueueOp[] = "DataQueueOp";
constexpr char kEpochCtrlOp[] = "EpochCtrlOp";
constexpr char kFilterOp[] = "FilterOp";
constexpr char kMapOp[] = "MapOp";
constexpr char kParallelOp[] = "ParallelOp";
constexpr char kPipelineOp[] = "PipelineOp";
constexpr char kProjectOp[] = "ProjectOp";
constexpr char kReceiveBridgeOp[] = "ReceiveBridgeOp";
constexpr char kRenameOp[] = "RenameOp";
constexpr char kRepeatOp[] = "RepeatOp";
constexpr char kSendBridgeOp[] = "SendBridgeOp";
constexpr char kShuffleOp[] = "ShuffleOp";
constexpr char kSkipOp[] = "SkipOp";
constexpr char kTakeOp[] = "TakeOp";
constexpr char kZipOp[] = "ZipOp";
61 
// Forward declarations: this header only needs these names, not their full definitions.
class ExecutionTree;  // Owning tree of operators; DatasetOp keeps a non-owning back pointer.
class NodePass;
class SamplerRT;      // Sampler held by some operators via std::shared_ptr.
68 
69 // \brief The base class DatasetOp is the main tree node.  It is an abstract class, so
70 // the actual implementation of the operators will be derived from here.
71 class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
72   // Allow execution tree to access internal members
73   friend class ExecutionTree;
74 
75  public:
76   static constexpr int32_t kInvalidOperatorId = -1;
77   static constexpr int32_t kInfiniteRepeat = -1;
78 
79   // Flags that control operator runtime behaviors
80   enum OpState { kDeOpRunning = 0, kDeOpIdle = 1, kDeOpTerminated };
81 
82   // \brief Constructor
83   // \param op_connector_size - The size for the output connector of this operator.
84   // \param sampler - The sampler for the op
85   DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler);
86 
87   // \brief Destructor
~DatasetOp()88   virtual ~DatasetOp() { tree_ = nullptr; }
89 
90   // \brief Adds a operator to become our child.
91   // \param child - shared pointer to the child to add.
92   Status AddChild(std::shared_ptr<DatasetOp> child);
93 
94   // \brief Remove a operator from our children.
95   // \param child - shared pointer to the child to remove.
96   Status RemoveChild(std::shared_ptr<DatasetOp> child);
97 
98   // \brief Removes this node from the tree and connects it's parent/child together
99   // \return Status eerror code returned
100   Status Remove();
101 
102   // Removes child operator in this operator.
103   Status RemoveChildren();
104 
105   // \brief Getter function to get a shared pointer to our child
106   // \param[in] child_index An operator can have n children. Indicates which child to return.
107   // \return The shared pointer to the child.  If there are no children, it returns null regardless of the given index
108   std::shared_ptr<DatasetOp> child(int32_t child_index) const;
109 
110   // \brief Getter function to get the pointer to our parent
111   //     If there are no parents, it returns null regardless of the given index
112   // \param[in] parent_index An operator can have n parents. Indicates which parent to return.
113   void Parent(DatasetOp **parent, int32_t parent_index) const;
114 
115   // Getter function to get all of our parents.
116   std::vector<DatasetOp *> parents() const;
117 
118   virtual Status AddNewWorkers(int32_t num_new_workers = 1) {
119     return Status(StatusCode::kMDUnexpectedError, "Add new workers is not supported for non-ParallelOps");
120   }
121 
122   virtual Status RemoveWorkers(int32_t num_workers = 1) {
123     return Status(StatusCode::kMDUnexpectedError, "Remove workers is not supported for non-ParallelOps");
124   }
125 
126   // \brief Inserts a operator as the parent current op.
127   // \notes Inserted op will become the sole parent of the current op.
128   //     The existing parent of the current op will be transferred to the inserted op.
129   Status InsertAsParent(std::shared_ptr<DatasetOp> to_add);
130 
131   // \brief Creates the connector within this operator
132   void CreateConnector();
133 
134   // \brief A print method typically used for debugging
135   // \param out - The output stream to write output to
136   // \param show_all - A bool to control if you want to show all info or just a summary
137   virtual void Print(std::ostream &out, bool show_all) const;
138 
139   /// \brief Gets the next row
140   /// \param row[out] - Fetched TensorRow
141   /// \return Status The status code returned
142   virtual Status GetNextRowPullMode(TensorRow *const row);
143 
144   /// \brief << Stream output operator overload
145   /// \notes This allows you to write the debug print info using stream operators
146   /// \param out - reference to the output stream being overloaded
147   /// \param dO - reference to the DatasetOp to display
148   /// \return - the output stream must be returned
149   friend std::ostream &operator<<(std::ostream &out, const DatasetOp &dO) {
150     dO.Print(out, false);
151     return out;
152   }
153 
154   // \brief Class functor operator ().
155   // \notes DatasetOps operate by launching a thread (see ExecutionTree).
156   //     This pure virtual version makes the requirement that derived classes must provide a functor
157   //     that will execute their main runtime loop code.
158   // \return Status The status code returned
159   virtual Status operator()() = 0;
160 
161   /// \brief Gets the next row from the given child
162   /// \param row[out] - Fetched TensorRow
163   /// \return Status The status code returned
164   virtual Status GetNextRow(TensorRow *row);
165 
166   // \brief Gets the batch size
167   // \return Status - The status code return
168   virtual int64_t GetTreeBatchSize();
169 
170   // \brief Gets the repeat count
171   // \return Status - The status code return
172   virtual int64_t GetTreeRepeatCount();
173 
174   // \brief Gets the number of classes
175   // \return Status - The status code return
176   virtual Status GetNumClasses(int64_t *num_classes);
177 
178   // \brief Gets the class indexing
179   // \return Status - The status code return
180   virtual Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
181 
182   // \brief Performs handling for when an eoe message is received.
183   //     The base class implementation simply flows the eoe message to output. Derived classes
184   //     may override if they need to perform special eoe handling.
185   // \param worker_id - The worker id
186   // \return Status The status code returned
187   virtual Status EoeReceived(int32_t worker_id);
188 
189   // \brief Performs handling for when an eof message is received.
190   //     The base class implementation simply flows the eof message to output. Derived classes
191   //     may override if they need to perform special eof handling.
192   // \param worker_id - The worker id
193   // \return Status The status code returned
194   virtual Status EofReceived(int32_t worker_id);
195 
196   // \brief Derived classes may implement the reset function if the operator is stateful and needs
197   //     specific reset handling that is not contained in this common code version of the reset
198   // \return Status The status code returned
199   virtual Status Reset();
200 
201   // \brief During tree prepare phase, operators may have specific post-operations to perform depending on
202   //     their role.
203   // \notes Derived versions of this function should always call their superclass version first
204   //     before providing their own implementations.
205   virtual Status PrepareOperator();
206 
207   // \brief During tree prepare phase, operators may have specific post-operations to perform depending on
208   //     their role.
209   // \notes Derived versions of this function should always call its superclass version first
210   //     before providing their own implementations.
211   virtual Status PrepareOperatorPullBased();
212 
213   // \brief Getter function
214   // \return The operator id
id()215   int32_t id() const { return operator_id_; }
216 
217   // \brief Getter function
218   // \return The number of workers in this op
219   virtual int32_t NumWorkers() const = 0;
220 
221   // \brief Getter function
222   // \return T/F if this is an inlined operator
inlined()223   bool inlined() const { return (oc_queue_size_ == 0); }
224 
225   // \brief Set the epoch number for op manually. This is only used in reset mode.
226   // \param[in] epoch The new epoch number to restart the pipeline from
227   // \return - Status
228   Status SetEpoch(const int64_t epoch);
229 
230   // \brief Setter function, set the number of total repeats for the operator
SetTotalRepeats(int32_t total_repeats)231   void SetTotalRepeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; }
232 
233   // \brief Setter function, set the number of repeats per epoch for the operator
SetNumRepeatsPerEpoch(int32_t num_repeats_per_epoch)234   void SetNumRepeatsPerEpoch(int32_t num_repeats_per_epoch) { op_num_repeats_per_epoch_ = num_repeats_per_epoch; }
235 
236   // \brief Getter function
237   // \return The number of required repeats for the operator
GetOpTotalRepeats()238   int32_t GetOpTotalRepeats() { return op_total_repeats_; }
239 
240   // \brief Getter function
241   // \return The number of repeats per epoch for the operator
GetOpNumRepeatsPerEpoch()242   int32_t GetOpNumRepeatsPerEpoch() const { return op_num_repeats_per_epoch_; }
243 
244   // \brief Register the internal worker connectors. No op unless it is a parallel op
245   // \return Status
RegisterWorkerConnectors()246   virtual Status RegisterWorkerConnectors() { return Status::OK(); }
247 
248   // \brief Getter for the column name mapping
249   // \return The returned map
column_name_id_map()250   std::unordered_map<std::string, int32_t> column_name_id_map() const { return column_name_id_map_; }
251 
252   // \brief Checks if the column name map has been set up yet for this op
253   // \return - T/F if the operator has the map set up
HasColumnNameMap()254   bool HasColumnNameMap() const { return (column_name_id_map_.empty()); }
255 
256   // \brief gives a string output for the column map for handy debug printing
257   // \return - the column name map as a string
258   std::string ColumnNameMapAsString() const;
259 
OutputConnector()260   OperatorConnector *OutputConnector() const { return out_connector_.get(); }
261 
262   // \brief Getter function
263   // \return connector size of current op
ConnectorSize()264   int32_t ConnectorSize() const {
265     if (!inlined()) {
266       return out_connector_->size();
267     }
268     // Return child connector size for inlined op
269     return ChildOpConnectorSize();
270   }
271 
272   /// \brief Counting number of rows sent out by a connector
ConnectorOutRowsCount()273   int64_t ConnectorOutRowsCount() const {
274     return out_connector_ == nullptr ? int64_t(-1) : static_cast<int64_t>(out_connector_->out_rows_count());
275   }
276 
277   // \brief Getter function
278   // \return connector size of current op
ConnectorCapacity()279   int32_t ConnectorCapacity() const {
280     if (!inlined()) {
281       return out_connector_->capacity();
282     }
283     // Return child connector capacity for inlined op
284     return ChildOpConnectorCapacity();
285   }
286 
287   // \brief Getter function
288   // \return connector size of child op
289   int32_t ChildOpConnectorSize(int32_t child_index = 0) const { return child_[child_index]->ConnectorSize(); }
290 
291   // \brief Getter function
292   // \return connector capacity of child op
293   int32_t ChildOpConnectorCapacity(int32_t child_index = 0) const { return child_[child_index]->ConnectorCapacity(); }
294 
295   // \brief Children Getter
296   // \return Vector of Children
Children()297   std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; }
298 
299   // Op name getter
300   // \return Name of the current Op
301   virtual std::string Name() const = 0;
302 
303   // Op name and ID getter
304   // \return Name and ID of the current Op
NameWithID()305   std::string NameWithID() const { return Name() + "(ID:" + std::to_string(id()) + ")"; }
306 
307   // Execution Tree getter
308   // \return Pointer to the ExecutionTree the current op belongs to, no ownership
Tree()309   ExecutionTree *Tree() { return tree_; }
310 
311   // Getter for the sampler
312   // \return Shared pointer to the sampler (may return nullptr)
sampler()313   std::shared_ptr<SamplerRT> sampler() { return sampler_; }
314 
315   // \brief Getter for the sampler, and it also removes the sampler from the op
316   // \param[out] sampler A pointer to the output sampler that was removed
317   // \return Status error code
318   Status FetchRemoveSampler(std::shared_ptr<SamplerRT> *sampler);
319 
320 #ifndef ENABLE_ANDROID
321   // Computes a CRC value for the operator
322   static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op);
323 #endif
324 
325   // \brief A helper templated function for casting "this" pointer to shared_ptr<derived>
326   //     Similar to shared_from_this, except this one will give you the derived class as shared_ptr
327   // \return A shared_ptr casted to the derived class
328   template <typename Derived>
shared_from_base()329   std::shared_ptr<Derived> shared_from_base() {
330     return std::static_pointer_cast<Derived>(shared_from_this());
331   }
332 
333   // \brief Setter for the sampler.  Allows you to overwrite a previous sampler with a new one.
SetSampler(std::shared_ptr<SamplerRT> sampler)334   void SetSampler(std::shared_ptr<SamplerRT> sampler) { sampler_ = sampler; }
335 
336   // \brief Checks if this is a leaf node (0 children)
337   // \return boolean returns true if it's a leaf
IsLeaf()338   bool IsLeaf() { return (child_.empty()); }
339 
340   // Checks if an operator has reached its last iteration
341   // \return boolean returns true if it's last iteration
IsLastIteration()342   bool IsLastIteration() { return op_total_repeats_ == op_current_repeats_ + 1; }
343 
344   // This function is only intended to be called by CallbackManager within the master thread of ParallelOp
345   // The expected behavior is this, when this function is invoked, this function will block until all the workers
346   // have finished their remaining work and go to sleep. Since all ParallelOps use a QueueList to sync with master.
347   // They would automatically wait on the QueueList when they are done.
348   // \return Status
WaitForWorkers()349   virtual Status WaitForWorkers() { return Status::OK(); }
350 
PostForWorkers()351   virtual Status PostForWorkers() { return Status::OK(); }
352 
NumWorkers()353   virtual int32_t NumWorkers() { return 0; }
354 
SendQuitFlagToWorker(int32_t worker_id)355   virtual Status SendQuitFlagToWorker(int32_t worker_id) { return Status::OK(); }
356 
SendWaitFlagToWorker(int32_t worker_id)357   virtual Status SendWaitFlagToWorker(int32_t worker_id) { return Status::OK(); }
358 
359   // \brief Add callback to DatasetOp, only MapOp supports Callback at the moment
AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks)360   void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) { callback_manager_.AddCallbacks(callbacks); }
361 
362   // \brief Remove all callbacks from DatasetOp
ClearCallbacks()363   void ClearCallbacks() { callback_manager_.ClearCallbacks(); }
364 
IsPython()365   virtual bool IsPython() const { return false; }
366 
367   virtual std::vector<int32_t> GetMPWorkerPIDs() const;
368 
369  protected:
370   // \brief Removes a parent operator from this operator
371   // \notes External callers do not have access to this function
372   // \param[in] parent The parent node to remove
373   void RemoveParent(const DatasetOp *parent);
374 
375   // \brief Adds a parent operator to this operator
376   // \notes External callers do not have access to this function
377   // \param[in] parent The parent node to add
378   void AddParent(DatasetOp *parent);
379 
380   // Compute the current op's column map using its child's column map.
381   // Get called during the tree post-prepare phase in PrepareOperator.
382   // This base implementation just inherits the map from child 0, and can only be used if the number of children is 1.
383   // Operations changing the column map it inherits from the child must overwrite this function.
384   // \return - Status
385   virtual Status ComputeColMap();
386 
387   // Increase op_current_repeats_ by 1 when one repeat finished.
388   // If this repeat happen to be the last repeat in the current epoch, also increase op_current_epochs_ by 1.
389   void UpdateRepeatAndEpochCounter();
390 
391   // Launch the Op
Launch()392   virtual Status Launch() { return Status::OK(); }
393 
394   enum ImplementedPullMode { NotImplemented = 0, Implemented, DisabledDebugMode };
395   /// \brief Gets the implementation status for operator in pull mode
396   /// \return implementation status
PullModeImplementationStatus()397   virtual ImplementedPullMode PullModeImplementationStatus() const { return ImplementedPullMode::NotImplemented; }
398 
399   std::vector<std::shared_ptr<DatasetOp>> child_;                // Child nodes
400   std::vector<DatasetOp *> parent_;                              // Parent nodes. No ownership
401   std::shared_ptr<SamplerRT> sampler_;                           // Some leaf ops might have a sampler
402   int32_t oc_queue_size_;                                        // Capacity for each out_connector_
403   int32_t operator_id_;                                          // Generated id for the node
404   ExecutionTree *tree_;                                          // Back pointer to our tree.
405   OpState state_;                                                // The state of the operator, Running, Idle, Terminated
406   int32_t op_total_repeats_;                                     // Required number of repeats for the operator
407   int32_t op_num_repeats_per_epoch_;                             // Total number of repeats per epoch for the operator
408   int32_t op_current_repeats_;                                   // Current number of repeats the operator has handled
409   int32_t op_current_epochs_;                                    // Current number of epochs the operator has handled
410   std::unique_ptr<OperatorConnector> out_connector_;             // Output Connector
411   std::unordered_map<std::string, int32_t> column_name_id_map_;  // Mapping between col index and col name
412   std::mutex column_name_map_mutex_;                             // For protecting shared access to the column map
413   CallbackManager callback_manager_;                             // Manages callbacks associated with a DatasetOp
414   int64_t dataset_size_;                                         // Size of the dataset
415   int64_t num_classes_;                                          // Number of classes
416 
417  private:
418   // Sets the operator id.
419   // \notes No public interface.  Only the class itself, or it's friend the execution tree can set
420   // this
421   // \param op_id - the Id value to set into the operator
SetId(int32_t op_id)422   void SetId(int32_t op_id) { operator_id_ = op_id; }
423 
424   // Sets the tree into the op so that the operator has a back pointer to the tree.
425   // \param tree - the tree to assign to the op.
set_tree(ExecutionTree * tree)426   void set_tree(ExecutionTree *tree) { tree_ = tree; }
427 };
428 }  // namespace dataset
429 }  // namespace mindspore
430 
431 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DATASET_OP_H_
432