1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_ 18 19 #include <map> 20 #include <memory> 21 #include <queue> 22 #include <random> 23 #include <string> 24 #include <unordered_map> 25 #include <vector> 26 27 #include "minddata/dataset/core/tensor.h" 28 #include "minddata/dataset/core/tensor_shape.h" 29 #include "minddata/dataset/engine/dataset_iterator.h" 30 #include "minddata/dataset/engine/datasetops/pipeline_op.h" 31 #include "minddata/dataset/util/status.h" 32 33 namespace mindspore { 34 namespace dataset { 35 // Forward declare 36 class ExecutionTree; 37 38 class DbConnector; 39 40 class ShuffleOp : public PipelineOp { 41 // Shuffle buffer state flags 42 // 43 // Shuffle buffer is in a state of being initialized 44 static constexpr int32_t kShuffleStateInit = 0; 45 46 // Shuffle buffer is in a state of being actively drained from, but refilling as well 47 static constexpr int32_t kShuffleStateActive = 1; 48 49 // Shuffle buffer is in a state of being drained 50 static constexpr int32_t kShuffleStateDrain = 2; 51 52 public: 53 // Constructor of the ShuffleOp 54 // @note The builder class should be used to call it 55 // @param shuffle_size - The size for the shuffle buffer 56 // @param shuffle_seed - The seed to use for random number generation 57 // @param op_connector_size - The output connector queue size 58 ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_connector_size, bool reset_every_epoch); 59 60 // Destructor 61 ~ShuffleOp() = default; 62 63 // A print method typically used for debugging 64 // @param out - The output stream to write output to 65 // @param show_all - A bool to control if you want to show all info or just a summary 66 void Print(std::ostream &out, bool show_all) const override; 67 68 // << Stream output operator overload 69 // @notes This allows you to write the debug print info using stream operators 70 // @param out - reference to the output stream being overloaded 71 // @param so - reference to the ShuffleOp to display 72 // @return - the output stream must be returned 73 friend std::ostream &operator<<(std::ostream &out, const ShuffleOp &so) { 74 so.Print(out, false); 75 return out; 76 } 77 78 // Class functor operator () override. 79 // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will 80 // provide the master loop that drives the logic for performing the work 81 // @return Status The status code returned 82 Status operator()() override; 83 84 // Base-class override for special eoe handler. 85 // ShuffleOp must override this because it shall not perform default handling of eoe. Instead 86 // the ShuffleOp needs to manage actions related to the end of the epoch itself. 87 // @return Status The status code returned 88 Status EoeReceived(int32_t worker_id) override; 89 90 // Op name getter 91 // @return Name of the current Op Name()92 std::string Name() const override { return kShuffleOp; } 93 94 private: 95 // Private function to add a new row to the shuffle buffer. 96 // @return Status The status code returned 97 Status AddRowToShuffleBuffer(TensorRow new_shuffle_row); 98 99 // Private function to populate the shuffle buffer initially by fetching from the child output 100 // connector until the shuffle buffer is full (or there is no more data coming). 101 // @return Status The status code returned 102 Status InitShuffleBuffer(); 103 104 // Private function to re-init the shuffle op for another epoch. Shuffle op calls this by 105 // itself rather than waiting for the reset driven from operators above it in the pipeline. 106 // @return Status The status code returned 107 Status SelfReset(); 108 109 int32_t shuffle_size_; // User config for the size of the shuffle buffer (number of rows) 110 uint32_t shuffle_seed_; 111 bool reshuffle_each_epoch_; 112 // rng_ is seeded initially with shuffle_seed_. mt19937 is used for its large period. 113 // specifically mt19937_64 is used to generate larger random numbers to reduce bias when 114 // modding to fit within our desired range. we dont use a distribution 115 // (ie uniform_int_distribution) because we will need to create up to |dataset| instances 116 // of the distribution object in the common case of a perfect shuffle 117 std::mt19937_64 rng_; 118 // A single (potentially large) buffer of tensor rows for performing shuffling. 119 std::unique_ptr<TensorTable> shuffle_buffer_; 120 int32_t shuffle_last_row_idx_; // Internal tracking of the last slot of our shuffle buffer 121 int32_t shuffle_buffer_state_; // State tracking for the shuffle buffer phases of work 122 123 std::unique_ptr<ChildIterator> child_iterator_; // An iterator for fetching. 124 }; 125 } // namespace dataset 126 } // namespace mindspore 127 128 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_ 129