/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_

#include <atomic>
#include <cstdint>
#include <memory>
#include <ostream>
#include <string>
#include <vector>
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/util/status.h"
26 
27 namespace mindspore {
28 namespace dataset {
29 // global const in our namespace
30 constexpr int32_t kEndOfActions = -1;
31 
32 // Forward declares
33 class DbConnector;
34 
35 // A ParallelOp provides a multi-threaded DatasetOp
36 class ParallelOp : public DatasetOp {
37  public:
38   // Constructor
39   // @param num_workers
40   // @param op_connector_size - size of the output connector for this operator
41   // @param sampler - The sampler for the op
42   ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler = nullptr);
43 
44   // Destructor
45   ~ParallelOp() = default;
46 
47   // Creates the internal worker connector for the parallel op if the derived class wants to use it.
48   // @notes This changes the number of producers of this op to 1, since it establishes a master/worker
49   // relationship within the op, making all production flow through a single master.
50   // @return Status - The error return code
51   Status CreateWorkerConnector(int32_t worker_connector_size);
52 
53   // A print method typically used for debugging
54   // @param out - The output stream to write output to
55   // @param show_all - A bool to control if you want to show all info or just a summary
56   void Print(std::ostream &out, bool show_all) const override;
Name()57   std::string Name() const override { return kParallelOp; }
58 
59   // << Stream output operator overload
60   // @notes This allows you to write the debug print info using stream operators
61   // @param out - reference to the output stream being overloaded
62   // @param pO - reference to the ParallelOp to display
63   // @return - the output stream must be returned
64   friend std::ostream &operator<<(std::ostream &out, const ParallelOp &po) {
65     po.Print(out, false);
66     return out;
67   }
68 
69   // Override base class reset to provide reset actions specific to the ParallelOp class.
70   // @return Status The status code returned
71   Status Reset() override;
72 
73   // Getter
74   // @return the number of workers
NumWorkers()75   int32_t NumWorkers() const override { return num_workers_; }
76 
77   // Getter
78   // @return the number of threads consuming from the previous Connector
NumConsumers()79   int32_t NumConsumers() const override { return num_workers_; }
80 
81   // Getter
82   // @return the number of producers pushing to the output Connector
83   // @notes The number of producers is commonly the same as number of workers, except in the case
84   // when a worker connector is set up.  In that case, there are n workers, and a single master
85   // such that only 1 thread is a producer rather than the n workers.
86   // @return the number of producers
NumProducers()87   int32_t NumProducers() const override { return num_producers_; }
88 
89   // Register the internal worker connectors.
90   // @return Status
91   Status RegisterWorkerConnectors() override;
92 
93  protected:
94   // Interface for derived classes to implement. All derived classes must provide the entry
95   // function with the main execution loop for worker threads.
96   // @return Status The status code returned
97   virtual Status WorkerEntry(int32_t workerId) = 0;
98 
99   // This function is only intended to be called by CallbackManager within the master thread of ParallelOp
100   // The expected behavior is this, when this function is invoked, this function will block until all the workers
101   // have finished their remaining work and go to sleep. Since all ParallelOps use a QueueList to sync with master.
102   // They would automatically wait on the QueueList when they are done.
103   // \return Status
104   Status WaitForWorkers() override;
105 
106   // Wait post used to perform the pausing logic
107   WaitPost wait_for_workers_post_;
108 
109   // Count number of workers that have signaled master
110   std::atomic_int num_workers_paused_;
111 
112   // Whether or not to sync worker threads at the end of each epoch
113   bool epoch_sync_flag_;
114 
115   int32_t num_workers_;    // The number of worker threads
116   int32_t num_producers_;  // The number of threads pushing to the out_connector_
117   int32_t worker_connector_size_;
118   std::unique_ptr<DbConnector> worker_connector_;        // The internal connector for worker threads
119   QueueList<std::unique_ptr<IOBlock>> io_block_queues_;  // queues of IOBlocks
120 };
121 }  // namespace dataset
122 }  // namespace mindspore

#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_