1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_ 18 19 #include <memory> 20 #include <string> 21 #include <utility> 22 #include <vector> 23 #include "minddata/dataset/engine/datasetops/pipeline_op.h" 24 25 namespace mindspore { 26 namespace dataset { 27 class RepeatOp : public PipelineOp { 28 public: 29 // Constructor of the RepeatOp. 30 // @note The builder class should be used to call it 31 // @param count - The number of repeats to do 32 explicit RepeatOp(int32_t count); 33 34 // Destructor 35 ~RepeatOp(); 36 37 // A print method typically used for debugging 38 // @param out - The output stream to write output to 39 // @param show_all - A bool to control if you want to show all info or just a summary 40 void Print(std::ostream &out, bool show_all) const override; 41 42 // << Stream output operator overload 43 // @notes This allows you to write the debug print info using stream operators 44 // @param out - reference to the output stream being overloaded 45 // @param ro - reference to the RepeatOp to display 46 // @return - the output stream must be returned 47 friend std::ostream &operator<<(std::ostream &out, const RepeatOp &ro) { 48 ro.Print(out, false); 49 return out; 50 } 51 52 // Class functor operator () override. 53 // Most dataset ops operate by launching a thread (see ExecutionTree). 54 // However, the RepeatOp is defined as a inlined operator, so it is invalid to launch the 55 // functor since this op runs inlined inside another operator. The function is overloaded to 56 // ensure that it is not called by mistake (it will generate an error). 57 // @return Status The status code returned 58 Status operator()() override; 59 60 // This function returns the row that is at the top of our output connector. The caller is 61 // typically our parent node, when the parent is asking us to provide the next row of data. 62 // Since RepeatOp is an inlined op, getting a row from us will simply bounce you to get 63 // a row from our child. 64 // @note This function sets the `retryIfEoe` flag when popping from the child connector. This way, 65 // this function will retry to pop the connector again and will get the non-EOE row if any. 66 // @param row - output pointer to the buffer that it will fetch. 67 // @param worker_id - The worker id 68 // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. 69 // @return Status The status code returned 70 Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override; 71 72 // Base-class override for handling cases when an eoe is received. 73 // @param worker_id - The worker id 74 Status EoeReceived(int32_t worker_id) override; 75 76 // Base-class override for handling cases when an eof is received. 77 // @param worker_id - The worker id 78 Status EofReceived(int32_t worker_id) override; 79 80 // Base-class override. Return the number of workers in the first parent. 81 // @param workerId - The worker id 82 int32_t NumConsumers() const override; 83 84 // Base-class override. Return the number of producers in the first child. 85 // @param workerId - The worker id 86 int32_t NumProducers() const override; 87 88 // Op name getter 89 // @return Name of the current Op Name()90 std::string Name() const override { return kRepeatOp; } 91 92 /// \brief Getter function 93 /// \return The number of repeats that the user requested num_repeats()94 int32_t num_repeats() { return num_repeats_; } 95 96 /// \brief reset Op 97 /// \@return Status The status code returned 98 Status Reset() override; 99 100 int64_t GetTreeRepeatCount() override; 101 102 // \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes 103 // \param[in] eoe_op The input leaf/eoe operator to add to the list AddToEoeList(std::shared_ptr<DatasetOp> eoe_op)104 void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); } 105 106 std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat. 107 108 protected: 109 // The number of repeats that the user requested. 110 // Note that num_repeats_ is different with op_total_repeats_ or op_num_repeats_per_epoch_ in base DatasetOp class. 111 // For example, for repeat1 op in pipeline tfreader -> repeat1(3) -> repeat2(2) -> epoch ctrl(4), 112 // num_repeats_ = 3, op_total_repeats_ = 24, op_num_repeats_per_epoch_ = 6. 113 int32_t num_repeats_; 114 // A counter for the current number of executed repeats. 115 // Note that repeat_count_ is different with op_current_repeats_ in the base DatasetOp class 116 // because it counts the repeats in the current epoch, whereas op_current_repeats_ counts the global total repeats. 117 int32_t repeat_count_; 118 }; 119 } // namespace dataset 120 } // namespace mindspore 121 122 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_ 123