1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_ 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 #include "minddata/dataset/include/dataset/constants.h" 23 #include "minddata/dataset/engine/datasetops/dataset_op.h" 24 #include "minddata/dataset/engine/datasetops/source/io_block.h" 25 #include "minddata/dataset/util/status.h" 26 27 namespace mindspore { 28 namespace dataset { 29 // global const in our namespace 30 constexpr int32_t kEndOfActions = -1; 31 32 // Forward declares 33 class DbConnector; 34 35 // A ParallelOp provides a multi-threaded DatasetOp 36 class ParallelOp : public DatasetOp { 37 public: 38 // Constructor 39 // @param num_workers 40 // @param op_connector_size - size of the output connector for this operator 41 // @param sampler - The sampler for the op 42 ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler = nullptr); 43 44 // Destructor 45 ~ParallelOp() = default; 46 47 // Creates the internal worker connector for the parallel op if the derived class wants to use it. 48 // @notes This changes the number of producers of this op to 1, since it establishes a master/worker 49 // relationship within the op, making all production flow through a single master. 50 // @return Status - The error return code 51 Status CreateWorkerConnector(int32_t worker_connector_size); 52 53 // A print method typically used for debugging 54 // @param out - The output stream to write output to 55 // @param show_all - A bool to control if you want to show all info or just a summary 56 void Print(std::ostream &out, bool show_all) const override; Name()57 std::string Name() const override { return kParallelOp; } 58 59 // << Stream output operator overload 60 // @notes This allows you to write the debug print info using stream operators 61 // @param out - reference to the output stream being overloaded 62 // @param pO - reference to the ParallelOp to display 63 // @return - the output stream must be returned 64 friend std::ostream &operator<<(std::ostream &out, const ParallelOp &po) { 65 po.Print(out, false); 66 return out; 67 } 68 69 // Override base class reset to provide reset actions specific to the ParallelOp class. 70 // @return Status The status code returned 71 Status Reset() override; 72 73 // Getter 74 // @return the number of workers NumWorkers()75 int32_t NumWorkers() const override { return num_workers_; } 76 77 // Getter 78 // @return the number of threads consuming from the previous Connector NumConsumers()79 int32_t NumConsumers() const override { return num_workers_; } 80 81 // Getter 82 // @return the number of producers pushing to the output Connector 83 // @notes The number of producers is commonly the same as number of workers, except in the case 84 // when a worker connector is set up. In that case, there are n workers, and a single master 85 // such that only 1 thread is a producer rather than the n workers. 86 // @return the number of producers NumProducers()87 int32_t NumProducers() const override { return num_producers_; } 88 89 // Register the internal worker connectors. 90 // @return Status 91 Status RegisterWorkerConnectors() override; 92 93 protected: 94 // Interface for derived classes to implement. All derived classes must provide the entry 95 // function with the main execution loop for worker threads. 96 // @return Status The status code returned 97 virtual Status WorkerEntry(int32_t workerId) = 0; 98 99 // This function is only intended to be called by CallbackManager within the master thread of ParallelOp 100 // The expected behavior is this, when this function is invoked, this function will block until all the workers 101 // have finished their remaining work and go to sleep. Since all ParallelOps use a QueueList to sync with master. 102 // They would automatically wait on the QueueList when they are done. 103 // \return Status 104 Status WaitForWorkers() override; 105 106 // Wait post used to perform the pausing logic 107 WaitPost wait_for_workers_post_; 108 109 // Count number of workers that have signaled master 110 std::atomic_int num_workers_paused_; 111 112 // Whether or not to sync worker threads at the end of each epoch 113 bool epoch_sync_flag_; 114 115 int32_t num_workers_; // The number of worker threads 116 int32_t num_producers_; // The number of threads pushing to the out_connector_ 117 int32_t worker_connector_size_; 118 std::unique_ptr<DbConnector> worker_connector_; // The internal connector for worker threads 119 QueueList<std::unique_ptr<IOBlock>> io_block_queues_; // queues of IOBlocks 120 }; 121 } // namespace dataset 122 } // namespace mindspore 123 124 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_ 125