1 /** 2 * Copyright 2019-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ 18 19 #include <memory> 20 #include <mutex> 21 #include <string> 22 #include <unordered_map> 23 #include <vector> 24 #include <utility> 25 26 #include "minddata/dataset/callback/callback_manager.h" 27 #include "minddata/dataset/include/dataset/constants.h" 28 #include "minddata/dataset/engine/operator_connector.h" 29 #include "minddata/dataset/engine/perf/info_collector.h" 30 #include "minddata/dataset/util/status.h" 31 32 namespace mindspore { 33 namespace dataset { 34 35 constexpr char kBarrierOp[] = "BarrierOp"; 36 constexpr char kBatchOp[] = "BatchOp"; 37 constexpr char kBucketBatchByLengthOp[] = "BucketBatchByLengthOp"; 38 constexpr char kBuildSentencePieceVocabOp[] = "BuildSentencePieceVocabOp"; 39 constexpr char kBuildVocabOp[] = "BuildVocabOp"; 40 constexpr char kCacheBase[] = "CacheBase"; 41 constexpr char kCacheLookupOp[] = "CacheLookupOp"; 42 constexpr char kCacheMergeOp[] = "CacheMergeOp"; 43 constexpr char kCacheOp[] = "CacheOp"; 44 constexpr char kConcatOp[] = "ConcatOp"; 45 constexpr char kDatasetOp[] = "DatasetOp"; 46 constexpr char kDeviceQueueOp[] = "DataQueueOp"; 47 constexpr char kEpochCtrlOp[] = "EpochCtrlOp"; 48 constexpr char kFilterOp[] = "FilterOp"; 49 constexpr char kMapOp[] = "MapOp"; 50 constexpr char kParallelOp[] = "ParallelOp"; 51 constexpr char kPipelineOp[] = "PipelineOp"; 52 constexpr char kProjectOp[] = "ProjectOp"; 53 constexpr char kRenameOp[] = "RenameOp"; 54 constexpr char kRepeatOp[] = "RepeatOp"; 55 constexpr char kShuffleOp[] = "ShuffleOp"; 56 constexpr char kSkipOp[] = "SkipOp"; 57 constexpr char kTakeOp[] = "TakeOp"; 58 constexpr char kZipOp[] = "ZipOp"; 59 constexpr char kSendBridgeOp[] = "SendBridgeOp"; 60 constexpr char kReceiveBridgeOp[] = "ReceiveBridgeOp"; 61 62 // Forward declare 63 class ExecutionTree; 64 65 class NodePass; 66 67 class SamplerRT; 68 69 // \brief The base class DatasetOp is the main tree node. It is an abstract class, so 70 // the actual implementation of the operators will be derived from here. 71 class DatasetOp : public std::enable_shared_from_this<DatasetOp> { 72 // Allow execution tree to access internal members 73 friend class ExecutionTree; 74 75 public: 76 static constexpr int32_t kInvalidOperatorId = -1; 77 static constexpr int32_t kInfiniteRepeat = -1; 78 79 // Flags that control operator runtime behaviors 80 enum OpState { kDeOpRunning = 0, kDeOpIdle = 1, kDeOpTerminated }; 81 82 // \brief Constructor 83 // \param op_connector_size - The size for the output connector of this operator. 84 // \param sampler - The sampler for the op 85 DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler); 86 87 // \brief Destructor ~DatasetOp()88 virtual ~DatasetOp() { tree_ = nullptr; } 89 90 // \brief Adds a operator to become our child. 91 // \param child - shared pointer to the child to add. 92 Status AddChild(std::shared_ptr<DatasetOp> child); 93 94 // \brief Remove a operator from our children. 95 // \param child - shared pointer to the child to remove. 96 Status RemoveChild(std::shared_ptr<DatasetOp> child); 97 98 // \brief Removes this node from the tree and connects it's parent/child together 99 // \return Status eerror code returned 100 Status Remove(); 101 102 // Removes child operator in this operator. 103 Status RemoveChildren(); 104 105 // \brief Getter function to get a shared pointer to our child 106 // \param[in] child_index An operator can have n children. Indicates which child to return. 107 // \return The shared pointer to the child. If there are no children, it returns null regardless of the given index 108 std::shared_ptr<DatasetOp> child(int32_t child_index) const; 109 110 // \brief Getter function to get the pointer to our parent 111 // If there are no parents, it returns null regardless of the given index 112 // \param[in] parent_index An operator can have n parents. Indicates which parent to return. 113 void Parent(DatasetOp **parent, int32_t parent_index) const; 114 115 // Getter function to get all of our parents. 116 std::vector<DatasetOp *> parents() const; 117 118 virtual Status AddNewWorkers(int32_t num_new_workers = 1) { 119 return Status(StatusCode::kMDUnexpectedError, "Add new workers is not supported for non-ParallelOps"); 120 } 121 122 virtual Status RemoveWorkers(int32_t num_workers = 1) { 123 return Status(StatusCode::kMDUnexpectedError, "Remove workers is not supported for non-ParallelOps"); 124 } 125 126 // \brief Inserts a operator as the parent current op. 127 // \notes Inserted op will become the sole parent of the current op. 128 // The existing parent of the current op will be transferred to the inserted op. 129 Status InsertAsParent(std::shared_ptr<DatasetOp> to_add); 130 131 // \brief Creates the connector within this operator 132 void CreateConnector(); 133 134 // \brief A print method typically used for debugging 135 // \param out - The output stream to write output to 136 // \param show_all - A bool to control if you want to show all info or just a summary 137 virtual void Print(std::ostream &out, bool show_all) const; 138 139 /// \brief Gets the next row 140 /// \param row[out] - Fetched TensorRow 141 /// \return Status The status code returned 142 virtual Status GetNextRowPullMode(TensorRow *const row); 143 144 /// \brief << Stream output operator overload 145 /// \notes This allows you to write the debug print info using stream operators 146 /// \param out - reference to the output stream being overloaded 147 /// \param dO - reference to the DatasetOp to display 148 /// \return - the output stream must be returned 149 friend std::ostream &operator<<(std::ostream &out, const DatasetOp &dO) { 150 dO.Print(out, false); 151 return out; 152 } 153 154 // \brief Class functor operator (). 155 // \notes DatasetOps operate by launching a thread (see ExecutionTree). 156 // This pure virtual version makes the requirement that derived classes must provide a functor 157 // that will execute their main runtime loop code. 158 // \return Status The status code returned 159 virtual Status operator()() = 0; 160 161 /// \brief Gets the next row from the given child 162 /// \param row[out] - Fetched TensorRow 163 /// \return Status The status code returned 164 virtual Status GetNextRow(TensorRow *row); 165 166 // \brief Gets the batch size 167 // \return Status - The status code return 168 virtual int64_t GetTreeBatchSize(); 169 170 // \brief Gets the repeat count 171 // \return Status - The status code return 172 virtual int64_t GetTreeRepeatCount(); 173 174 // \brief Gets the number of classes 175 // \return Status - The status code return 176 virtual Status GetNumClasses(int64_t *num_classes); 177 178 // \brief Gets the class indexing 179 // \return Status - The status code return 180 virtual Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing); 181 182 // \brief Performs handling for when an eoe message is received. 183 // The base class implementation simply flows the eoe message to output. Derived classes 184 // may override if they need to perform special eoe handling. 185 // \param worker_id - The worker id 186 // \return Status The status code returned 187 virtual Status EoeReceived(int32_t worker_id); 188 189 // \brief Performs handling for when an eof message is received. 190 // The base class implementation simply flows the eof message to output. Derived classes 191 // may override if they need to perform special eof handling. 192 // \param worker_id - The worker id 193 // \return Status The status code returned 194 virtual Status EofReceived(int32_t worker_id); 195 196 // \brief Derived classes may implement the reset function if the operator is stateful and needs 197 // specific reset handling that is not contained in this common code version of the reset 198 // \return Status The status code returned 199 virtual Status Reset(); 200 201 // \brief During tree prepare phase, operators may have specific post-operations to perform depending on 202 // their role. 203 // \notes Derived versions of this function should always call their superclass version first 204 // before providing their own implementations. 205 virtual Status PrepareOperator(); 206 207 // \brief During tree prepare phase, operators may have specific post-operations to perform depending on 208 // their role. 209 // \notes Derived versions of this function should always call its superclass version first 210 // before providing their own implementations. 211 virtual Status PrepareOperatorPullBased(); 212 213 // \brief Getter function 214 // \return The operator id id()215 int32_t id() const { return operator_id_; } 216 217 // \brief Getter function 218 // \return The number of workers in this op 219 virtual int32_t NumWorkers() const = 0; 220 221 // \brief Getter function 222 // \return T/F if this is an inlined operator inlined()223 bool inlined() const { return (oc_queue_size_ == 0); } 224 225 // \brief Set the epoch number for op manually. This is only used in reset mode. 226 // \param[in] epoch The new epoch number to restart the pipeline from 227 // \return - Status 228 Status SetEpoch(const int64_t epoch); 229 230 // \brief Setter function, set the number of total repeats for the operator SetTotalRepeats(int32_t total_repeats)231 void SetTotalRepeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; } 232 233 // \brief Setter function, set the number of repeats per epoch for the operator SetNumRepeatsPerEpoch(int32_t num_repeats_per_epoch)234 void SetNumRepeatsPerEpoch(int32_t num_repeats_per_epoch) { op_num_repeats_per_epoch_ = num_repeats_per_epoch; } 235 236 // \brief Getter function 237 // \return The number of required repeats for the operator GetOpTotalRepeats()238 int32_t GetOpTotalRepeats() { return op_total_repeats_; } 239 240 // \brief Getter function 241 // \return The number of repeats per epoch for the operator GetOpNumRepeatsPerEpoch()242 int32_t GetOpNumRepeatsPerEpoch() const { return op_num_repeats_per_epoch_; } 243 244 // \brief Register the internal worker connectors. No op unless it is a parallel op 245 // \return Status RegisterWorkerConnectors()246 virtual Status RegisterWorkerConnectors() { return Status::OK(); } 247 248 // \brief Getter for the column name mapping 249 // \return The returned map column_name_id_map()250 std::unordered_map<std::string, int32_t> column_name_id_map() const { return column_name_id_map_; } 251 252 // \brief Checks if the column name map has been set up yet for this op 253 // \return - T/F if the operator has the map set up HasColumnNameMap()254 bool HasColumnNameMap() const { return (column_name_id_map_.empty()); } 255 256 // \brief gives a string output for the column map for handy debug printing 257 // \return - the column name map as a string 258 std::string ColumnNameMapAsString() const; 259 OutputConnector()260 OperatorConnector *OutputConnector() const { return out_connector_.get(); } 261 262 // \brief Getter function 263 // \return connector size of current op ConnectorSize()264 int32_t ConnectorSize() const { 265 if (!inlined()) { 266 return out_connector_->size(); 267 } 268 // Return child connector size for inlined op 269 return ChildOpConnectorSize(); 270 } 271 272 /// \brief Counting number of rows sent out by a connector ConnectorOutRowsCount()273 int64_t ConnectorOutRowsCount() const { 274 return out_connector_ == nullptr ? int64_t(-1) : static_cast<int64_t>(out_connector_->out_rows_count()); 275 } 276 277 // \brief Getter function 278 // \return connector size of current op ConnectorCapacity()279 int32_t ConnectorCapacity() const { 280 if (!inlined()) { 281 return out_connector_->capacity(); 282 } 283 // Return child connector capacity for inlined op 284 return ChildOpConnectorCapacity(); 285 } 286 287 // \brief Getter function 288 // \return connector size of child op 289 int32_t ChildOpConnectorSize(int32_t child_index = 0) const { return child_[child_index]->ConnectorSize(); } 290 291 // \brief Getter function 292 // \return connector capacity of child op 293 int32_t ChildOpConnectorCapacity(int32_t child_index = 0) const { return child_[child_index]->ConnectorCapacity(); } 294 295 // \brief Children Getter 296 // \return Vector of Children Children()297 std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; } 298 299 // Op name getter 300 // \return Name of the current Op 301 virtual std::string Name() const = 0; 302 303 // Op name and ID getter 304 // \return Name and ID of the current Op NameWithID()305 std::string NameWithID() const { return Name() + "(ID:" + std::to_string(id()) + ")"; } 306 307 // Execution Tree getter 308 // \return Pointer to the ExecutionTree the current op belongs to, no ownership Tree()309 ExecutionTree *Tree() { return tree_; } 310 311 // Getter for the sampler 312 // \return Shared pointer to the sampler (may return nullptr) sampler()313 std::shared_ptr<SamplerRT> sampler() { return sampler_; } 314 315 // \brief Getter for the sampler, and it also removes the sampler from the op 316 // \param[out] sampler A pointer to the output sampler that was removed 317 // \return Status error code 318 Status FetchRemoveSampler(std::shared_ptr<SamplerRT> *sampler); 319 320 #ifndef ENABLE_ANDROID 321 // Computes a CRC value for the operator 322 static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op); 323 #endif 324 325 // \brief A helper templated function for casting "this" pointer to shared_ptr<derived> 326 // Similar to shared_from_this, except this one will give you the derived class as shared_ptr 327 // \return A shared_ptr casted to the derived class 328 template <typename Derived> shared_from_base()329 std::shared_ptr<Derived> shared_from_base() { 330 return std::static_pointer_cast<Derived>(shared_from_this()); 331 } 332 333 // \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one. SetSampler(std::shared_ptr<SamplerRT> sampler)334 void SetSampler(std::shared_ptr<SamplerRT> sampler) { sampler_ = sampler; } 335 336 // \brief Checks if this is a leaf node (0 children) 337 // \return boolean returns true if it's a leaf IsLeaf()338 bool IsLeaf() { return (child_.empty()); } 339 340 // Checks if an operator has reached its last iteration 341 // \return boolean returns true if it's last iteration IsLastIteration()342 bool IsLastIteration() { return op_total_repeats_ == op_current_repeats_ + 1; } 343 344 // This function is only intended to be called by CallbackManager within the master thread of ParallelOp 345 // The expected behavior is this, when this function is invoked, this function will block until all the workers 346 // have finished their remaining work and go to sleep. Since all ParallelOps use a QueueList to sync with master. 347 // They would automatically wait on the QueueList when they are done. 348 // \return Status WaitForWorkers()349 virtual Status WaitForWorkers() { return Status::OK(); } 350 PostForWorkers()351 virtual Status PostForWorkers() { return Status::OK(); } 352 NumWorkers()353 virtual int32_t NumWorkers() { return 0; } 354 SendQuitFlagToWorker(int32_t worker_id)355 virtual Status SendQuitFlagToWorker(int32_t worker_id) { return Status::OK(); } 356 SendWaitFlagToWorker(int32_t worker_id)357 virtual Status SendWaitFlagToWorker(int32_t worker_id) { return Status::OK(); } 358 359 // \brief Add callback to DatasetOp, only MapOp supports Callback at the moment AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks)360 void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) { callback_manager_.AddCallbacks(callbacks); } 361 362 // \brief Remove all callbacks from DatasetOp ClearCallbacks()363 void ClearCallbacks() { callback_manager_.ClearCallbacks(); } 364 IsPython()365 virtual bool IsPython() const { return false; } 366 367 virtual std::vector<int32_t> GetMPWorkerPIDs() const; 368 369 protected: 370 // \brief Removes a parent operator from this operator 371 // \notes External callers do not have access to this function 372 // \param[in] parent The parent node to remove 373 void RemoveParent(const DatasetOp *parent); 374 375 // \brief Adds a parent operator to this operator 376 // \notes External callers do not have access to this function 377 // \param[in] parent The parent node to add 378 void AddParent(DatasetOp *parent); 379 380 // Compute the current op's column map using its child's column map. 381 // Get called during the tree post-prepare phase in PrepareOperator. 382 // This base implementation just inherits the map from child 0, and can only be used if the number of children is 1. 383 // Operations changing the column map it inherits from the child must overwrite this function. 384 // \return - Status 385 virtual Status ComputeColMap(); 386 387 // Increase op_current_repeats_ by 1 when one repeat finished. 388 // If this repeat happen to be the last repeat in the current epoch, also increase op_current_epochs_ by 1. 389 void UpdateRepeatAndEpochCounter(); 390 391 // Launch the Op Launch()392 virtual Status Launch() { return Status::OK(); } 393 394 enum ImplementedPullMode { NotImplemented = 0, Implemented, DisabledDebugMode }; 395 /// \brief Gets the implementation status for operator in pull mode 396 /// \return implementation status PullModeImplementationStatus()397 virtual ImplementedPullMode PullModeImplementationStatus() const { return ImplementedPullMode::NotImplemented; } 398 399 std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes 400 std::vector<DatasetOp *> parent_; // Parent nodes. No ownership 401 std::shared_ptr<SamplerRT> sampler_; // Some leaf ops might have a sampler 402 int32_t oc_queue_size_; // Capacity for each out_connector_ 403 int32_t operator_id_; // Generated id for the node 404 ExecutionTree *tree_; // Back pointer to our tree. 405 OpState state_; // The state of the operator, Running, Idle, Terminated 406 int32_t op_total_repeats_; // Required number of repeats for the operator 407 int32_t op_num_repeats_per_epoch_; // Total number of repeats per epoch for the operator 408 int32_t op_current_repeats_; // Current number of repeats the operator has handled 409 int32_t op_current_epochs_; // Current number of epochs the operator has handled 410 std::unique_ptr<OperatorConnector> out_connector_; // Output Connector 411 std::unordered_map<std::string, int32_t> column_name_id_map_; // Mapping between col index and col name 412 std::mutex column_name_map_mutex_; // For protecting shared access to the column map 413 CallbackManager callback_manager_; // Manages callbacks associated with a DatasetOp 414 int64_t dataset_size_; // Size of the dataset 415 int64_t num_classes_; // Number of classes 416 417 private: 418 // Sets the operator id. 419 // \notes No public interface. Only the class itself, or it's friend the execution tree can set 420 // this 421 // \param op_id - the Id value to set into the operator SetId(int32_t op_id)422 void SetId(int32_t op_id) { operator_id_ = op_id; } 423 424 // Sets the tree into the op so that the operator has a back pointer to the tree. 425 // \param tree - the tree to assign to the op. set_tree(ExecutionTree * tree)426 void set_tree(ExecutionTree *tree) { tree_ = tree; } 427 }; 428 } // namespace dataset 429 } // namespace mindspore 430 431 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ 432