1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RANDOM_NODE_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RANDOM_NODE_H_ 19 20 #include <memory> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 25 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" 26 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" 27 #include "minddata/dataset/include/dataset/samplers.h" 28 29 namespace mindspore { 30 namespace dataset { 31 class RandomNode : public NonMappableSourceNode { 32 public: 33 // Some constants to provide limits to random generation. 34 static constexpr int32_t kMaxNumColumns = 4; 35 static constexpr int32_t kMaxRank = 4; 36 static constexpr int32_t kMaxDimValue = 32; 37 38 /// \brief Constructor RandomNode(int32_t total_rows,const std::shared_ptr<SchemaObj> & schema,const std::vector<std::string> & columns_list,const std::shared_ptr<DatasetCache> & cache)39 RandomNode(int32_t total_rows, const std::shared_ptr<SchemaObj> &schema, const std::vector<std::string> &columns_list, 40 const std::shared_ptr<DatasetCache> &cache) 41 : NonMappableSourceNode(cache), 42 total_rows_(total_rows), 43 schema_path_(""), 44 schema_(schema), 45 columns_list_(columns_list) {} 46 47 /// \brief Constructor RandomNode(int32_t total_rows,const std::string & schema_path,const std::vector<std::string> & columns_list,const std::shared_ptr<DatasetCache> & cache)48 RandomNode(int32_t total_rows, const std::string &schema_path, const std::vector<std::string> &columns_list, 49 const std::shared_ptr<DatasetCache> &cache) 50 : NonMappableSourceNode(cache), 51 total_rows_(total_rows), 52 schema_path_(schema_path), 53 schema_(nullptr), 54 columns_list_(columns_list) {} 55 56 /// \brief Destructor 57 ~RandomNode() override = default; 58 59 /// \brief Node name getter 60 /// \return Name of the current node Name()61 std::string Name() const override { return kRandomNode; } 62 63 /// \brief Print the description 64 /// \param out - The output stream to write output to 65 void Print(std::ostream &out) const override; 66 67 /// \brief Copy the node to a new object 68 /// \return A shared pointer to the new copy 69 std::shared_ptr<DatasetNode> Copy() override; 70 71 /// \brief a base class override function to create the required runtime dataset op objects for this class 72 /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create 73 /// \return Status Status::OK() if build successfully 74 Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override; 75 76 /// \brief Parameters validation 77 /// \return Status Status::OK() if all the parameters are valid 78 Status ValidateParams() override; 79 80 /// \brief Get the shard id of node 81 /// \return Status Status::OK() if get shard id successfully 82 Status GetShardId(int32_t *const shard_id) override; 83 84 /// \brief Base-class override for GetDatasetSize 85 /// \param[in] size_getter Shared pointer to DatasetSizeGetter 86 /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting 87 /// dataset size at the expense of accuracy. 88 /// \param[out] dataset_size the size of the dataset 89 /// \return Status of the function 90 Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, 91 int64_t *dataset_size) override; 92 93 /// \brief Getter functions TotalRows()94 int32_t TotalRows() const { return total_rows_; } SchemaPath()95 const std::string &SchemaPath() const { return schema_path_; } GetSchema()96 const std::shared_ptr<SchemaObj> &GetSchema() const { return schema_; } ColumnsList()97 const std::vector<std::string> &ColumnsList() const { return columns_list_; } RandGen()98 const std::mt19937 &RandGen() const { return rand_gen_; } GetDataSchema()99 const std::unique_ptr<DataSchema> &GetDataSchema() const { return data_schema_; } 100 101 /// \brief RandomDataset by itself is a non-mappable dataset that does not support sampling. 102 /// However, if a cache operator is injected at some other place higher in the tree, that cache can 103 /// inherit this sampler from the leaf, providing sampling support from the caching layer. 104 /// That is why we setup the sampler for a leaf node that does not use sampling. 105 /// \param[in] sampler The sampler to setup 106 /// \return Status of the function 107 Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override; 108 109 /// \brief Random node will always produce the full set of data into the cache 110 /// \return Status of the function MakeSimpleProducer()111 Status MakeSimpleProducer() override { return Status::OK(); } 112 113 /// \brief Base-class override for accepting IRNodePass visitor 114 /// \param[in] p The node to visit 115 /// \param[out] modified Indicator if the node was modified 116 /// \return Status of the node visit 117 Status Accept(IRNodePass *const p, bool *const modified) override; 118 119 /// \brief Base-class override for accepting IRNodePass visitor 120 /// \param[in] p The node to visit 121 /// \param[out] modified Indicator if the node was modified 122 /// \return Status of the node visit 123 Status AcceptAfter(IRNodePass *const p, bool *const modified) override; 124 125 private: 126 /// \brief A quick inline for producing a random number between (and including) min/max 127 /// \param[in] min minimum number that can be generated. 128 /// \param[in] max maximum number that can be generated. 129 /// \return The generated random number 130 int32_t GenRandomInt(int32_t min, int32_t max); 131 132 int32_t total_rows_; 133 std::string schema_path_; 134 std::shared_ptr<SchemaObj> schema_; 135 std::vector<std::string> columns_list_; 136 std::mt19937 rand_gen_; 137 std::unique_ptr<DataSchema> data_schema_; 138 }; 139 } // namespace dataset 140 } // namespace mindspore 141 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RANDOM_NODE_H_ 142