1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RANDOM_NODE_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RANDOM_NODE_H_ 19 20 #include <memory> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 25 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" 26 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" 27 #include "minddata/dataset/include/dataset/samplers.h" 28 29 namespace mindspore { 30 namespace dataset { 31 32 class RandomNode : public NonMappableSourceNode { 33 public: 34 // Some constants to provide limits to random generation. 35 static constexpr int32_t kMaxNumColumns = 4; 36 static constexpr int32_t kMaxRank = 4; 37 static constexpr int32_t kMaxDimValue = 32; 38 39 /// \brief Constructor RandomNode(const int32_t & total_rows,std::shared_ptr<SchemaObj> schema,const std::vector<std::string> & columns_list,std::shared_ptr<DatasetCache> cache)40 RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list, 41 std::shared_ptr<DatasetCache> cache) 42 : NonMappableSourceNode(std::move(cache)), 43 total_rows_(total_rows), 44 schema_path_(""), 45 schema_(std::move(schema)), 46 columns_list_(columns_list) {} 47 48 /// \brief Constructor RandomNode(const int32_t & total_rows,std::string schema_path,const std::vector<std::string> & columns_list,std::shared_ptr<DatasetCache> cache)49 RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list, 50 std::shared_ptr<DatasetCache> cache) 51 : NonMappableSourceNode(std::move(cache)), 52 total_rows_(total_rows), 53 schema_path_(schema_path), 54 schema_(nullptr), 55 columns_list_(columns_list) {} 56 57 /// \brief Destructor 58 ~RandomNode() = default; 59 60 /// \brief Node name getter 61 /// \return Name of the current node Name()62 std::string Name() const override { return kRandomNode; } 63 64 /// \brief Print the description 65 /// \param out - The output stream to write output to 66 void Print(std::ostream &out) const override; 67 68 /// \brief Copy the node to a new object 69 /// \return A shared pointer to the new copy 70 std::shared_ptr<DatasetNode> Copy() override; 71 72 /// \brief a base class override function to create the required runtime dataset op objects for this class 73 /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create 74 /// \return Status Status::OK() if build successfully 75 Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override; 76 77 /// \brief Parameters validation 78 /// \return Status Status::OK() if all the parameters are valid 79 Status ValidateParams() override; 80 81 /// \brief Get the shard id of node 82 /// \return Status Status::OK() if get shard id successfully 83 Status GetShardId(int32_t *const shard_id) override; 84 85 /// \brief Base-class override for GetDatasetSize 86 /// \param[in] size_getter Shared pointer to DatasetSizeGetter 87 /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting 88 /// dataset size at the expense of accuracy. 89 /// \param[out] dataset_size the size of the dataset 90 /// \return Status of the function 91 Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, 92 int64_t *dataset_size) override; 93 94 /// \brief Getter functions TotalRows()95 int32_t TotalRows() const { return total_rows_; } SchemaPath()96 const std::string &SchemaPath() const { return schema_path_; } GetSchema()97 const std::shared_ptr<SchemaObj> &GetSchema() const { return schema_; } ColumnsList()98 const std::vector<std::string> &ColumnsList() const { return columns_list_; } RandGen()99 const std::mt19937 &RandGen() const { return rand_gen_; } GetDataSchema()100 const std::unique_ptr<DataSchema> &GetDataSchema() const { return data_schema_; } 101 102 /// \brief RandomDataset by itself is a non-mappable dataset that does not support sampling. 103 /// However, if a cache operator is injected at some other place higher in the tree, that cache can 104 /// inherit this sampler from the leaf, providing sampling support from the caching layer. 105 /// That is why we setup the sampler for a leaf node that does not use sampling. 106 /// \param[in] sampler The sampler to setup 107 /// \return Status of the function 108 Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override; 109 110 /// \brief Random node will always produce the full set of data into the cache 111 /// \return Status of the function MakeSimpleProducer()112 Status MakeSimpleProducer() override { return Status::OK(); } 113 114 /// \brief Base-class override for accepting IRNodePass visitor 115 /// \param[in] p The node to visit 116 /// \param[out] modified Indicator if the node was modified 117 /// \return Status of the node visit 118 Status Accept(IRNodePass *const p, bool *const modified) override; 119 120 /// \brief Base-class override for accepting IRNodePass visitor 121 /// \param[in] p The node to visit 122 /// \param[out] modified Indicator if the node was modified 123 /// \return Status of the node visit 124 Status AcceptAfter(IRNodePass *const p, bool *const modified) override; 125 126 private: 127 /// \brief A quick inline for producing a random number between (and including) min/max 128 /// \param[in] min minimum number that can be generated. 129 /// \param[in] max maximum number that can be generated. 130 /// \return The generated random number 131 int32_t GenRandomInt(int32_t min, int32_t max); 132 133 int32_t total_rows_; 134 std::string schema_path_; 135 std::shared_ptr<SchemaObj> schema_; 136 std::vector<std::string> columns_list_; 137 std::mt19937 rand_gen_; 138 std::unique_ptr<DataSchema> data_schema_; 139 }; 140 141 } // namespace dataset 142 } // namespace mindspore 143 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RANDOM_NODE_H_ 144