• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_
18 
19 #include <map>
20 #include <memory>
21 #include <queue>
22 #include <random>
23 #include <string>
24 #include <unordered_map>
25 #include <vector>
26 
27 #include "minddata/dataset/core/tensor.h"
28 #include "minddata/dataset/core/tensor_shape.h"
29 #include "minddata/dataset/engine/dataset_iterator.h"
30 #include "minddata/dataset/engine/datasetops/pipeline_op.h"
31 #include "minddata/dataset/util/status.h"
32 
33 namespace mindspore {
34 namespace dataset {
35 // Forward declare
36 class ExecutionTree;
37 
38 class DbConnector;
39 
40 class ShuffleOp : public PipelineOp {
41   // Shuffle buffer state flags
42   //
43   // Shuffle buffer is in a state of being initialized
44   static constexpr int32_t kShuffleStateInit = 0;
45 
46   // Shuffle buffer is in a state of being actively drained from, but refilling as well
47   static constexpr int32_t kShuffleStateActive = 1;
48 
49   // Shuffle buffer is in a state of being drained
50   static constexpr int32_t kShuffleStateDrain = 2;
51 
52  public:
53   // Constructor of the ShuffleOp
54   // @note The builder class should be used to call it
55   // @param shuffle_size - The size for the shuffle buffer
56   // @param shuffle_seed - The seed to use for random number generation
57   // @param op_connector_size - The output connector queue size
58   ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_connector_size, bool reset_every_epoch);
59 
60   // Destructor
61   ~ShuffleOp() = default;
62 
63   // A print method typically used for debugging
64   // @param out - The output stream to write output to
65   // @param show_all - A bool to control if you want to show all info or just a summary
66   void Print(std::ostream &out, bool show_all) const override;
67 
68   // << Stream output operator overload
69   // @notes This allows you to write the debug print info using stream operators
70   // @param out - reference to the output stream being overloaded
71   // @param so - reference to the ShuffleOp to display
72   // @return - the output stream must be returned
73   friend std::ostream &operator<<(std::ostream &out, const ShuffleOp &so) {
74     so.Print(out, false);
75     return out;
76   }
77 
78   // Class functor operator () override.
79   // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
80   // provide the master loop that drives the logic for performing the work
81   // @return Status The status code returned
82   Status operator()() override;
83 
84   // Base-class override for special eoe handler.
85   // ShuffleOp must override this because it shall not perform default handling of eoe. Instead
86   // the ShuffleOp needs to manage actions related to the end of the epoch itself.
87   // @return Status The status code returned
88   Status EoeReceived(int32_t worker_id) override;
89 
90   // Op name getter
91   // @return Name of the current Op
Name()92   std::string Name() const override { return kShuffleOp; }
93 
94  private:
95   // Private function to add a new row to the shuffle buffer.
96   // @return Status The status code returned
97   Status AddRowToShuffleBuffer(TensorRow new_shuffle_row);
98 
99   // Private function to populate the shuffle buffer initially by fetching from the child output
100   // connector until the shuffle buffer is full (or there is no more data coming).
101   // @return Status The status code returned
102   Status InitShuffleBuffer();
103 
104   // Private function to re-init the shuffle op for another epoch.  Shuffle op calls this by
105   // itself rather than waiting for the reset driven from operators above it in the pipeline.
106   // @return Status The status code returned
107   Status SelfReset();
108 
109   int32_t shuffle_size_;  // User config for the size of the shuffle buffer (number of rows)
110   uint32_t shuffle_seed_;
111   bool reshuffle_each_epoch_;
112   // rng_ is seeded initially with shuffle_seed_. mt19937 is used for its large period.
113   // specifically mt19937_64 is used to generate larger random numbers to reduce bias when
114   // modding to fit within our desired range. we dont use a distribution
115   // (ie uniform_int_distribution) because we will need to create up to |dataset| instances
116   // of the distribution object in the common case of a perfect shuffle
117   std::mt19937_64 rng_;
118   // A single (potentially large) buffer of tensor rows for performing shuffling.
119   std::unique_ptr<TensorTable> shuffle_buffer_;
120   int32_t shuffle_last_row_idx_;  // Internal tracking of the last slot of our shuffle buffer
121   int32_t shuffle_buffer_state_;  // State tracking for the shuffle buffer phases of work
122 
123   std::unique_ptr<ChildIterator> child_iterator_;  // An iterator for fetching.
124 };
125 }  // namespace dataset
126 }  // namespace mindspore
127 
128 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_
129