/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_

#include <limits>
#include <map>
#include <memory>
#include <random>
#include <vector>

#include "minddata/dataset/core/tensor.h"

#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"

namespace mindspore {
namespace dataset {
// RandomAccessOp is a base class that all data-producing leaf operators
// must inherit from if those leaf operators wish to support sampling.
34 class RandomAccessOp { 35 public: 36 // Sampler get number of rows in the dataset 37 // @param int64_t num - return number of rows for this dataset 38 // @return Status The status code returned 39 Status GetNumRowsInDataset(int64_t *num_rows) const; 40 41 // sampler gets label , imageIds from corresponding Dataset Op, this function is unique to PK 42 // @param std::map<int64_t, std::vector<int64_t>> * map 43 // @return Status The status code returned GetClassIds(std::map<int32_t,std::vector<int64_t>> * map)44 virtual Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *map) const { 45 RETURN_STATUS_UNEXPECTED("GetClassIds needs to be override to support PK"); 46 } 47 48 // default destructor 49 virtual ~RandomAccessOp() = default; 50 51 /// Set num_rows 52 /// \param num_rows SetNumRows(int64_t num_rows)53 void SetNumRows(int64_t num_rows) { num_rows_ = num_rows; } 54 55 protected: 56 // The amount of rows in the dataset itself. This is the before-sampling value, the 57 // total count of rows. A sampler may choose to sample less than this amount. 58 int64_t num_rows_ = -1; 59 }; 60 61 class SamplerRT { 62 public: 63 // Constructor 64 // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0 65 // indicates that the sampler should produce the complete set of ids. 66 // @param int64_t samples_per_tensor: Num of Sampler Ids to fetch via 1 GetNextSample call 67 SamplerRT(int64_t num_samples, int64_t samples_per_tensor); 68 SamplerRT(const SamplerRT & s)69 SamplerRT(const SamplerRT &s) : SamplerRT(s.num_samples_, s.samples_per_tensor_) {} 70 71 // default destructor 72 ~SamplerRT() = default; 73 74 // Get a list of sample ids. 75 // @note It is Sampler responsibility to make sure that the id is not out of bound. 
76 // @param TensorRow to be returned to StorageOp 77 // @param int32_t workerId - not meant to be used 78 // @return Status The status code returned 79 virtual Status GetNextSample(TensorRow *out) = 0; 80 81 // This function only called by python layer. Not needed by Android. 82 #ifdef ENABLE_PYTHON 83 // return all ids in one epoch as a numpy array, then call reset 84 Status GetAllIdsThenReset(py::array *data); 85 #endif 86 87 // for next epoch of sampleIds 88 // @return Status The status code returned 89 virtual Status ResetSampler() = 0; 90 91 // first handshake between leaf source op and Sampler. This func will determine the amount of data 92 // in the dataset that we can sample from. 93 // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is 94 // @return 95 virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op); 96 97 // initialize sampler and perform checks on certain vars InitSampler()98 virtual Status InitSampler() { return Status::OK(); } 99 100 // setter for num samples 101 // @param num_samples - the number of samples to assign. 102 // @return status error code 103 Status SetNumSamples(int64_t num_samples); 104 105 // getter for num samples 106 // @return number of samples 107 int64_t GetNumSamples() const; 108 109 // Calculate num samples. Unlike GetNumSamples, it is not a getter and doesn't necessarily return the value of 110 // num_samples_ 111 // @return number of samples, return -1 if sampler cannot determine this value (e.g. PKSampler) 112 virtual int64_t CalculateNumSamples(int64_t num_rows); 113 114 // setter for num or records in the dataset 115 // @param num_rows - the number of records 116 // @return status error code 117 Status SetNumRowsInDataset(int64_t num_rows); 118 119 // Adds a sampler to become our child. 120 // @param std::shared_ptr<DatasetOp> - The sampler to add as a child. 
121 // @return Status The status code returned 122 Status AddChild(std::shared_ptr<SamplerRT> child); 123 124 // A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler 125 // @param std::shared_ptr<Tensor>* sampleIds 126 // @param int64_t numElements - must be a non 0 number 127 // @return Status The status code returned 128 Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements); 129 130 // A print method typically used for debugging 131 // @param out - The output stream to write output to 132 // @param show_all - A bool to control if you want to show all info or just a summary 133 virtual void SamplerPrint(std::ostream &out, bool show_all) const; 134 135 // << Stream output operator overload 136 // @notes This allows you to write the debug print info using stream operators 137 // @param out - reference to the output stream being overloaded 138 // @param sampler - reference to teh sampler to print 139 // @return - the output stream must be returned 140 friend std::ostream &operator<<(std::ostream &out, const SamplerRT &sampler) { 141 sampler.SamplerPrint(out, false); 142 return out; 143 } 144 145 // Checks if this sampler has a child sampler. 146 // @return - tre if there is a child sampler, false otherwise. 147 bool HasChildSampler() const; 148 149 // Uses id as an index for the list of ids generated by the child sampler, and gets the 150 // associated id. 151 // @param int64_t* out_associated_id - Out parameter, contains the associated id. 152 // @param int64_t id - The id used as an index to get the associated child id. 
153 // @return Status The status code returned 154 Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id); 155 156 /// \brief Get the arguments of node 157 /// \param[out] out_json JSON string of all attributes 158 /// \return Status of the function 159 virtual Status to_json(nlohmann::json *out_json); 160 161 protected: 162 // Number of rows of data from the place this sampler is sampling from. If this sampler 163 // has a child sampler, num_rows_ is the number of ids the child sampler will 164 // output. Otherwise, num_rows_ is the number of rows in the dataset. 165 int64_t num_rows_; 166 167 // The user may want to sample less than the full amount of data. num_samples_ reduces the number 168 // of id's returned as request by the user. Derived classes will choose how to sample the smaller 169 // amount. 170 int64_t num_samples_; 171 172 bool is_initialized; 173 int64_t samples_per_tensor_; 174 std::unique_ptr<ColDescriptor> col_desc_; 175 std::vector<std::shared_ptr<SamplerRT>> child_; // Child nodes 176 TensorRow child_ids_; 177 }; 178 } // namespace dataset 179 } // namespace mindspore 180 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_ 181