1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ 18 19 #include <memory> 20 #include <mutex> 21 #include <string> 22 #include <unordered_map> 23 #include <vector> 24 #include <utility> 25 26 #include "minddata/dataset/callback/callback_manager.h" 27 #include "minddata/dataset/include/dataset/constants.h" 28 #include "minddata/dataset/engine/db_connector.h" 29 #include "minddata/dataset/util/status.h" 30 31 namespace mindspore { 32 namespace dataset { 33 34 constexpr char kBarrierOp[] = "BarrierOp"; 35 constexpr char kBatchOp[] = "BatchOp"; 36 constexpr char kBucketBatchByLengthOp[] = "BucketBatchByLengthOp"; 37 constexpr char kBuildSentencePieceVocabOp[] = "BuildSentencePieceVocabOp"; 38 constexpr char kBuildVocabOp[] = "BuildVocabOp"; 39 constexpr char kCacheBase[] = "CacheBase"; 40 constexpr char kCacheLookupOp[] = "CacheLookupOp"; 41 constexpr char kCacheMergeOp[] = "CacheMergeOp"; 42 constexpr char kCacheOp[] = "CacheOp"; 43 constexpr char kConcatOp[] = "ConcatOp"; 44 constexpr char kDatasetOp[] = "DatasetOp"; 45 constexpr char kDeviceQueueOp[] = "DeviceQueueOp"; 46 constexpr char kEpochCtrlOp[] = "EpochCtrlOp"; 47 constexpr char kFilterOp[] = "FilterOp"; 48 constexpr char kMapOp[] = "MapOp"; 49 constexpr char kParallelOp[] = "ParallelOp"; 50 constexpr char kPipelineOp[] = "PipelineOp"; 51 constexpr char kProjectOp[] = "ProjectOp"; 52 constexpr char kRenameOp[] = "RenameOp"; 53 constexpr char kRepeatOp[] = "RepeatOp"; 54 constexpr char kShuffleOp[] = "ShuffleOp"; 55 constexpr char kSkipOp[] = "SkipOp"; 56 constexpr char kTakeOp[] = "TakeOp"; 57 constexpr char kZipOp[] = "ZipOp"; 58 59 // Forward declare 60 class ExecutionTree; 61 62 class NodePass; 63 64 class SamplerRT; 65 66 // \brief The base class DatasetOp is the main tree node. It is an abstract class, so 67 // the actual implementation of the operators will be derived from here. 68 class DatasetOp : public std::enable_shared_from_this<DatasetOp> { 69 // Allow execution tree to access internal members 70 friend class ExecutionTree; 71 72 public: 73 static constexpr int32_t kInvalidOperatorId = -1; 74 static constexpr int32_t kInfiniteRepeat = -1; 75 76 // Flags that control operator runtime behaviours 77 enum OpState { kDeOpRunning = 0, kDeOpIdle = 1, kDeOpTerminated }; 78 79 // \brief Constructor 80 // \param op_connector_size - The size for the output connector of this operator. 81 // \param sampler - The sampler for the op 82 DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler); 83 84 // \brief Destructor ~DatasetOp()85 virtual ~DatasetOp() { tree_ = nullptr; } 86 87 // \brief Adds a operator to become our child. 88 // \param child - shared pointer to the child to add. 89 Status AddChild(std::shared_ptr<DatasetOp> child); 90 91 // \brief Remove a operator from our children. 92 // \param child - shared pointer to the child to remove. 93 Status RemoveChild(std::shared_ptr<DatasetOp> child); 94 95 // \brief Removes this node from the tree and connects it's parent/child together 96 // \return Status eerror code returned 97 Status Remove(); 98 99 // Removes child operator in this operator. 100 Status RemoveChildren(); 101 102 // \brief Getter function to get a shared pointer to our child 103 // \param[in] child_index An operator can have n children. Indicates which child to return. 104 // \return The shared pointer to the child. If there are no children, it returns null regardless of the given index 105 std::shared_ptr<DatasetOp> child(int32_t child_index) const; 106 107 // \brief Getter function to get the pointer to our parent 108 // If there are no parents, it returns null regardless of the given index 109 // \param[in] parent_index An operator can have n parents. Indicates which parent to return. 110 void Parent(DatasetOp **parent, int32_t parent_index) const; 111 112 // Getter function to get all of our parents. 113 std::vector<DatasetOp *> parents() const; 114 115 // \brief Inserts a operator as the parent current op. 116 // \notes Inserted op will become the sole parent of the current op. 117 // The existing parent of the current op will be transferred to the inserted op. 118 Status InsertAsParent(std::shared_ptr<DatasetOp> to_add); 119 120 // \brief Creates the connector within this operator 121 // \param num_producers - number of threads that write into this connector 122 // \param num_consumers - number of threads that read from this connector 123 void CreateConnector(int32_t num_producers, int32_t num_consumers); 124 125 // \brief A print method typically used for debugging 126 // \param out - The output stream to write output to 127 // \param show_all - A bool to control if you want to show all info or just a summary 128 virtual void Print(std::ostream &out, bool show_all) const; 129 130 /// \brief Gets the next row 131 /// \param row[out] - Fetched TensorRow 132 /// \return Status The status code returned 133 virtual Status GetNextRowPullMode(TensorRow *const row); 134 135 /// \brief << Stream output operator overload 136 /// \notes This allows you to write the debug print info using stream operators 137 /// \param out - reference to the output stream being overloaded 138 /// \param dO - reference to the DatasetOp to display 139 /// \return - the output stream must be returned 140 friend std::ostream &operator<<(std::ostream &out, const DatasetOp &dO) { 141 dO.Print(out, false); 142 return out; 143 } 144 145 // \brief Class functor operator (). 146 // \notes DatasetOps operate by launching a thread (see ExecutionTree). 147 // This pure virtual version makes the requirement that derived classes must provide a functor 148 // that will execute their main runtime loop code. 149 // \return Status The status code returned 150 virtual Status operator()() = 0; 151 152 /// \brief Gets the next row from the given child 153 /// \param row[out] - Fetched TensorRow 154 /// \param worker_id[in] - The worker id, default to 0. 155 /// \return Status The status code returned 156 virtual Status GetNextRow(TensorRow *row, int32_t worker_id = 0) { return GetNextRow(row, worker_id, false); } 157 158 /// \brief Gets the next row from the given child 159 /// \param row[out] - Fetched TensorRow 160 /// \param worker_id[in] - The worker id, default to 0. 161 /// \param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. 162 /// \return Status The status code returned 163 virtual Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe); 164 165 // \brief Gets the batch size 166 // \return Status - The status code return 167 virtual int64_t GetTreeBatchSize(); 168 169 // \brief Gets the repeat count 170 // \return Status - The status code return 171 virtual int64_t GetTreeRepeatCount(); 172 173 // \brief Gets the number of classes 174 // \return Status - The status code return 175 virtual Status GetNumClasses(int64_t *num_classes); 176 177 // \brief Gets the class indexing 178 // \return Status - The status code return 179 virtual Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing); 180 181 // \brief Performs handling for when an eoe message is received. 182 // The base class implementation simply flows the eoe message to output. Derived classes 183 // may override if they need to perform special eoe handling. 184 // \param worker_id - The worker id 185 // \return Status The status code returned 186 virtual Status EoeReceived(int32_t worker_id); 187 188 // \brief Performs handling for when an eof message is received. 189 // The base class implementation simply flows the eof message to output. Derived classes 190 // may override if they need to perform special eof handling. 191 // \param worker_id - The worker id 192 // \return Status The status code returned 193 virtual Status EofReceived(int32_t worker_id); 194 195 // \brief Derived classes may implement the reset function if the operator is stateful and needs 196 // specific reset handling that is not contained in this common code version of the reset 197 // \return Status The status code returned 198 virtual Status Reset(); 199 200 // \brief During tree prepare phase, operators may have specific post-operations to perform depending on 201 // their role. 202 // \notes Derived versions of this function should always call it's superclass version first 203 // before providing their own implementations. 204 virtual Status PrepareOperator(); 205 206 // \brief Getter function 207 // \return The operator id id()208 int32_t id() const { return operator_id_; } 209 210 // \brief Getter function 211 // \return The number of workers in this op 212 virtual int32_t NumWorkers() const = 0; 213 214 // \brief Getter function 215 // \return The number of threads consuming from previous op. 216 virtual int32_t NumConsumers() const = 0; 217 218 // \brief Getter function 219 // \return The number of threads producing to the output connector. 220 virtual int32_t NumProducers() const = 0; 221 222 // \brief Getter function 223 // \return T/F if this is an inlined operator inlined()224 bool inlined() const { return (oc_queue_size_ == 0); } 225 226 // \brief Setter function, set the number of total repeats for the operator SetTotalRepeats(int32_t total_repeats)227 void SetTotalRepeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; } 228 229 // \brief Setter function, set the number of repeats per epoch for the operator SetNumRepeatsPerEpoch(int32_t num_repeats_per_epoch)230 void SetNumRepeatsPerEpoch(int32_t num_repeats_per_epoch) { op_num_repeats_per_epoch_ = num_repeats_per_epoch; } 231 232 // \brief Getter function 233 // \return The number of required repeats for the operator GetOpTotalRepeats()234 int32_t GetOpTotalRepeats() { return op_total_repeats_; } 235 236 // \brief Getter function 237 // \return The number of repeats per epoch for the operator GetOpNumRepeatsPerEpoch()238 int32_t GetOpNumRepeatsPerEpoch() const { return op_num_repeats_per_epoch_; } 239 240 // \brief Register the internal worker connectors. No op unless it is a parallel op 241 // \return Status RegisterWorkerConnectors()242 virtual Status RegisterWorkerConnectors() { return Status::OK(); } 243 244 // \brief Getter for the column name mapping 245 // \return The returned map column_name_id_map()246 std::unordered_map<std::string, int32_t> column_name_id_map() const { return column_name_id_map_; } 247 248 // \brief Checks if the column name map has been set up yet for this op 249 // \return - T/F if the operator has the map set up HasColumnNameMap()250 bool HasColumnNameMap() const { return (column_name_id_map_.empty()); } 251 252 // \brief gives a string output for the column map for handy debug printing 253 // \return - the column name map as a string 254 std::string ColumnNameMapAsString() const; 255 256 // \brief Getter function 257 // \return connector size of current op ConnectorSize()258 int32_t ConnectorSize() const { 259 if (!inlined()) { 260 return out_connector_->size(); 261 } 262 // Return child connector size for inlined op 263 return ChildOpConnectorSize(); 264 } 265 266 /// \brief Counting number of rows sent out by a connector ConnectorOutRowsCount()267 int64_t ConnectorOutRowsCount() const { 268 return out_connector_ == nullptr ? int64_t(-1) : static_cast<int64_t>(out_connector_->out_rows_count()); 269 } 270 271 // \brief Getter function 272 // \return connector size of current op ConnectorCapacity()273 int32_t ConnectorCapacity() const { 274 if (!inlined()) { 275 return out_connector_->capacity(); 276 } 277 // Return child connector capacity for inlined op 278 return ChildOpConnectorCapacity(); 279 } 280 281 // \brief Getter function 282 // \return connector size of child op 283 int32_t ChildOpConnectorSize(int32_t child_index = 0) const { return child_[child_index]->ConnectorSize(); } 284 285 // \brief Getter function 286 // \return connector capacity of child op 287 int32_t ChildOpConnectorCapacity(int32_t child_index = 0) const { return child_[child_index]->ConnectorCapacity(); } 288 289 // \brief Children Getter 290 // \return Vector of Children Children()291 std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; } 292 293 // Op name getter 294 // \return Name of the current Op 295 virtual std::string Name() const = 0; 296 297 // Op name and ID getter 298 // \return Name and ID of the current Op NameWithID()299 std::string NameWithID() const { return Name() + "(ID:" + std::to_string(id()) + ")"; } 300 301 // Execution Tree getter 302 // \return Pointer to the ExecutionTree the current op belongs to, no ownership Tree()303 ExecutionTree *Tree() { return tree_; } 304 305 // Getter for the sampler 306 // \return Shared pointer to the sampler (may return nullptr) sampler()307 std::shared_ptr<SamplerRT> sampler() { return sampler_; } 308 309 // \brief Getter for the sampler, and it also removes the sampler from the op 310 // \param[out] sampler A pointer to the output sampler that was removed 311 // \return Status error code 312 Status FetchRemoveSampler(std::shared_ptr<SamplerRT> *sampler); 313 314 #ifndef ENABLE_ANDROID 315 // Computes a CRC value for the operator 316 static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op); 317 #endif 318 319 // \brief A helper templated function for casting "this" pointer to shared_ptr<derived> 320 // Similar to shared_from_this, except this one will give you the derived class as shared_ptr 321 // \return A shared_ptr casted to the derived class 322 template <typename Derived> shared_from_base()323 std::shared_ptr<Derived> shared_from_base() { 324 return std::static_pointer_cast<Derived>(shared_from_this()); 325 } 326 327 // \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one. SetSampler(std::shared_ptr<SamplerRT> sampler)328 void SetSampler(std::shared_ptr<SamplerRT> sampler) { sampler_ = sampler; } 329 330 // \brief Checks if this is a leaf node (0 children) 331 // \return boolean returns true if it's a leaf IsLeaf()332 bool IsLeaf() { return (child_.empty()); } 333 334 // Checks if an operator has reached its last iteration 335 // \return boolean returns true if it's last iteration IsLastIteration()336 bool IsLastIteration() { return op_total_repeats_ == op_current_repeats_ + 1; } 337 338 // This function is only intended to be called by CallbackManager within the master thread of ParallelOp 339 // The expected behavior is this, when this function is invoked, this function will block until all the workers 340 // have finished their remaining work and go to sleep. Since all ParallelOps use a QueueList to sync with master. 341 // They would automatically wait on the QueueList when they are done. 342 // \return Status WaitForWorkers()343 virtual Status WaitForWorkers() { return Status::OK(); } 344 345 // \brief Add callback to DatasetOp, only MapOp supports Callback at the moment AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks)346 void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) { callback_manager_.AddCallbacks(callbacks); } 347 348 // \brief Remove all callbacks from DatasetOp ClearCallbacks()349 void ClearCallbacks() { callback_manager_.ClearCallbacks(); } 350 351 protected: 352 // \brief Removes a parent operator from this operator 353 // \notes External callers do not have access to this function 354 // \param[in] parent The parent node to remove 355 void RemoveParent(const DatasetOp *parent); 356 357 // \brief Adds a parent operator to this operator 358 // \notes External callers do not have access to this function 359 // \param[in] parent The parent node to add 360 void AddParent(DatasetOp *parent); 361 362 // Compute the current op's column map using its child's column map. 363 // Get called during the tree post-prepare phase in PrepareOperator. 364 // This base implementation just inherits the map from child 0, and can only be used if the number of children is 1. 365 // Operations changing the column map it inherits from the child must overwrite this function. 366 // \return - Status 367 virtual Status ComputeColMap(); 368 369 // Increase op_current_repeats_ by 1 when one repeat finished. 370 // If this repeat happen to be the last repeat in the current epoch, also increase op_current_epochs_ by 1. 371 void UpdateRepeatAndEpochCounter(); 372 373 std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes 374 std::vector<DatasetOp *> parent_; // Parent nodes. No ownership 375 std::shared_ptr<SamplerRT> sampler_; // Some leaf ops might have a sampler 376 int32_t oc_queue_size_; // Capacity for each out_connector_ 377 int32_t operator_id_; // Generated id for the node 378 ExecutionTree *tree_; // Back pointer to our tree. 379 OpState state_; // The state of the operator, Running, Idle, Terminated 380 int32_t op_total_repeats_; // Required number of repeats for the operator 381 int32_t op_num_repeats_per_epoch_; // Total number of repeats per epoch for the operator 382 int32_t op_current_repeats_; // Current number of repeats the operator has handled 383 int32_t op_current_epochs_; // Current number of epochs the operator has handled 384 std::unique_ptr<DbConnector> out_connector_; // Output Connector 385 std::unordered_map<std::string, int32_t> column_name_id_map_; // Mapping between col index and col name 386 std::mutex column_name_map_mutex_; // For protecting shared access to the column map 387 CallbackManager callback_manager_; // Manages callbacks associated with a DatasetOp 388 int64_t dataset_size_; // Size of the dataset 389 int64_t num_classes_; // Number of classes 390 391 private: 392 // Sets the operator id. 393 // \notes No public interface. Only the class itself, or it's friend the execution tree can set 394 // this 395 // \param op_id - the Id value to set into the operator SetId(int32_t op_id)396 void SetId(int32_t op_id) { operator_id_ = op_id; } 397 398 // Sets the tree into the op so that the operator has a back pointer to the tree. 399 // \param tree - the tree to assign to the op. set_tree(ExecutionTree * tree)400 void set_tree(ExecutionTree *tree) { tree_ = tree; } 401 }; 402 } // namespace dataset 403 } // namespace mindspore 404 405 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ 406