1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_ 18 19 #include <atomic> 20 #include <memory> 21 #include <mutex> 22 #include <random> 23 #include <string> 24 #include <vector> 25 #include <utility> 26 #include "minddata/dataset/util/status.h" 27 #include "minddata/dataset/core/tensor.h" 28 #include "minddata/dataset/core/data_type.h" 29 #include "minddata/dataset/engine/data_schema.h" 30 #include "minddata/dataset/engine/datasetops/parallel_op.h" 31 #include "minddata/dataset/util/wait_post.h" 32 33 namespace mindspore { 34 namespace dataset { 35 // The RandomDataOp is a leaf node storage operator that generates random data based 36 // on the schema specifications. Typically, it's used for testing and demonstrating 37 // various dataset operator pipelines. It is not "real" data to train with. 38 // The data that is random created is just random and repeated bytes, there is no 39 // "meaning" behind what these bytes are. 40 class RandomDataOp : public ParallelOp { 41 public: 42 // Some constants to provide limits to random generation. 43 static constexpr int32_t kMaxNumColumns = 4; 44 static constexpr int32_t kMaxRank = 4; 45 static constexpr int32_t kMaxDimValue = 32; 46 static constexpr int32_t kMaxTotalRows = 1024; 47 48 /** 49 * Constructor for RandomDataOp 50 * @note Private constructor. Must use builder to construct. 51 * @param num_workers - The number of workers 52 * @param op_connector_size - The size of the output connector 53 * @param data_schema - A user-provided schema 54 * @param total_rows - The total number of rows in the dataset 55 * @return Builder - The modified builder by reference 56 */ 57 RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t total_rows, 58 std::unique_ptr<DataSchema> data_schema); 59 60 /** 61 * Destructor 62 */ 63 ~RandomDataOp() = default; 64 65 /** 66 * A print method typically used for debugging 67 * @param out - The output stream to write output to 68 * @param show_all - A bool to control if you want to show all info or just a summary 69 */ 70 void Print(std::ostream &out, bool show_all) const override; 71 72 /** 73 * << Stream output operator overload 74 * @notes This allows you to write the debug print info using stream operators 75 * @param out - reference to the output stream being overloaded 76 * @param so - reference to the ShuffleOp to display 77 * @return - the output stream must be returned 78 */ 79 friend std::ostream &operator<<(std::ostream &out, const RandomDataOp &op) { 80 op.Print(out, false); 81 return out; 82 } 83 84 /** 85 * Class functor operator () override. 86 * All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will 87 * provide the master loop that drives the logic for performing the work. 88 * @return Status The status code returned 89 */ 90 Status operator()() override; 91 92 /** 93 * Overrides base class reset method. When an operator does a reset, it cleans up any state 94 * info from it's previous execution and then initializes itself so that it can be executed 95 * again. 96 * @return Status The status code returned 97 */ 98 Status Reset() override; 99 100 /** 101 * Quick getter for total rows. 102 */ GetTotalRows()103 int64_t GetTotalRows() const { return total_rows_; } 104 105 // Op name getter 106 // @return Name of the current Op Name()107 std::string Name() const override { return "RandomDataOp"; } 108 109 private: 110 /** 111 * The entry point code for when workers are launched 112 * @param worker_id - The worker id 113 * @return Status The status code returned 114 */ 115 Status WorkerEntry(int32_t worker_id) override; 116 117 /** 118 * Helper function to produce a default/random schema if one didn't exist 119 */ 120 void GenerateSchema(); 121 122 /** 123 * Performs a synchronization between workers at the end of an epoch 124 * @param worker_id - The worker id 125 * @return Status The status code returned 126 */ 127 Status EpochSync(int32_t worker_id, bool *quitting); 128 129 /** 130 * A helper function to create random data for the row 131 * @param worker_id - The worker id 132 * @param new_row - The output row to produce 133 * @return Status The status code returned 134 */ 135 Status CreateRandomRow(int32_t worker_id, TensorRow *new_row); 136 137 /** 138 * A quick inline for producing a random number between (and including) min/max 139 * @param min - minimum number that can be generated 140 * @param max - maximum number that can be generated 141 * @return - The generated random number 142 */ GenRandomInt(int32_t min,int32_t max)143 inline int32_t GenRandomInt(int32_t min, int32_t max) { 144 std::uniform_int_distribution<int32_t> uniDist(min, max); 145 return uniDist(rand_gen_); 146 } 147 148 // Private function for computing the assignment of the column name map. 149 // @return - Status 150 Status ComputeColMap() override; 151 152 int64_t total_rows_; 153 int64_t epoch_rows_sent_; 154 std::atomic<int32_t> guys_in_; 155 std::atomic<int32_t> guys_out_; 156 int32_t eoe_worker_id_; 157 std::unique_ptr<DataSchema> data_schema_; 158 std::vector<int64_t> worker_max_rows_; 159 std::vector<int64_t> worker_rows_packed_; 160 std::mt19937 rand_gen_; 161 WaitPost epoch_sync_wait_post_; 162 WaitPost all_out_; 163 }; // class RandomDataOp 164 } // namespace dataset 165 } // namespace mindspore 166 167 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_ 168