• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_
18 
19 #include <limits>
20 #include <map>
21 #include <memory>
22 #include <random>
23 #include <vector>
24 
25 #include "minddata/dataset/core/tensor.h"
26 
27 #include "minddata/dataset/engine/data_schema.h"
28 #include "minddata/dataset/engine/datasetops/dataset_op.h"
29 
30 namespace mindspore {
31 namespace dataset {
32 //  RandomAccessOp is a base class that all data-producing leaf operators
33 //  must inherit from if those leaf operator wish to support sampling.
34 class RandomAccessOp {
35  public:
36   // Sampler get number of rows in the dataset
37   // @param int64_t num - return number of rows for this dataset
38   // @return Status The status code returned
39   Status GetNumRowsInDataset(int64_t *num_rows) const;
40 
41   // sampler gets label , imageIds from corresponding Dataset Op, this function is unique to PK
42   // @param std::map<int64_t, std::vector<int64_t>> * map
43   // @return Status The status code returned
GetClassIds(std::map<int32_t,std::vector<int64_t>> * map)44   virtual Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *map) const {
45     RETURN_STATUS_UNEXPECTED("[Internal ERROR] GetClassIds needs to be override to support PK.");
46   }
47 
48   // default destructor
49   virtual ~RandomAccessOp() = default;
50 
51   /// Set num_rows
52   /// \param num_rows
SetNumRows(int64_t num_rows)53   void SetNumRows(int64_t num_rows) { num_rows_ = num_rows; }
54 
55  protected:
56   // The amount of rows in the dataset itself. This is the before-sampling value, the
57   // total count of rows.  A sampler may choose to sample less than this amount.
58   int64_t num_rows_ = -1;
59 };
60 
61 class SamplerRT {
62  public:
63   // Constructor
64   // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0
65   //                indicates that the sampler should produce the complete set of ids.
66   // @param int64_t samples_per_tensor: Num of Sampler Ids to fetch via 1 GetNextSample call
67   SamplerRT(int64_t num_samples, int64_t samples_per_tensor);
68 
SamplerRT(const SamplerRT & s)69   SamplerRT(const SamplerRT &s) : SamplerRT(s.num_samples_, s.samples_per_tensor_) {}
70 
71   // Copy assignment operator
72   SamplerRT &operator=(const SamplerRT &other) {
73     num_samples_ = other.num_samples_;
74     samples_per_tensor_ = other.samples_per_tensor_;
75     return *this;
76   }
77 
78   // default destructor
79   virtual ~SamplerRT() = default;
80 
81   // Get a list of sample ids.
82   // @note It is Sampler responsibility to make sure that the id is not out of bound.
83   // @param TensorRow to be returned to StorageOp
84   // @param int32_t workerId - not meant to be used
85   // @return Status The status code returned
86   virtual Status GetNextSample(TensorRow *out) = 0;
87 
88 // This function only called by python layer. Not needed by Android.
89 #ifdef ENABLE_PYTHON
90   // return all ids in one epoch as a numpy array, then call reset
91   Status GetAllIdsThenReset(py::array *data);
92 #endif
93 
94   /// \brief Reset for next epoch.
95   /// \note If failover_reset is set, any override of this function must support the scenario where consecutive calls to
96   /// it are executed successfully (to prepare the sampler for a specific epoch, including updating any random
97   /// generator's internal state)
98   /// \param[in] failover_reset - A boolean to show whether we are resetting the pipeline
99   /// \return Status The status code returned
100   virtual Status ResetSampler(const bool failover_reset = false) = 0;
101 
102   // first handshake between leaf source op and Sampler. This func will determine the amount of data
103   // in the dataset that we can sample from.
104   // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is
105   // @param reset_count - reset the random generator these many times (used in fast_recovery mode of reset)
106   // @return status error code
107   virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count = 0);
108 
109   // initialize sampler and perform checks on certain vars
InitSampler()110   virtual Status InitSampler() { return Status::OK(); }
111 
112   // setter for num samples
113   // @param num_samples - the number of samples to assign.
114   // @return status error code
115   Status SetNumSamples(int64_t num_samples);
116 
117   // getter for num samples
118   // @return number of samples
119   int64_t GetNumSamples() const;
120 
121   // Calculate num samples. Unlike GetNumSamples, it is not a getter and doesn't necessarily return the value of
122   // num_samples_
123   // @return number of samples, return -1 if sampler cannot determine this value (e.g. PKSampler)
124   virtual int64_t CalculateNumSamples(int64_t num_rows);
125 
126   // setter for num or records in the dataset
127   // @param num_rows - the number of records
128   // @return status error code
129   Status SetNumRowsInDataset(int64_t num_rows);
130 
131   // Adds a sampler to become our child.
132   // @param std::shared_ptr<DatasetOp> - The sampler to add as a child.
133   // @return Status The status code returned
134   Status AddChild(std::shared_ptr<SamplerRT> child);
135 
136   // A helper function to create an int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
137   // @param std::shared_ptr<Tensor>* sampleIds
138   // @param int64_t numElements - must be a non 0 number
139   // @return Status The status code returned
140   Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);
141 
142   // A print method typically used for debugging
143   // @param out - The output stream to write output to
144   // @param show_all - A bool to control if you want to show all info or just a summary
145   virtual void SamplerPrint(std::ostream &out, bool show_all) const;
146 
147   // << Stream output operator overload
148   // @notes This allows you to write the debug print info using stream operators
149   // @param out - reference to the output stream being overloaded
150   // @param sampler - reference to teh sampler to print
151   // @return - the output stream must be returned
152   friend std::ostream &operator<<(std::ostream &out, const SamplerRT &sampler) {
153     sampler.SamplerPrint(out, false);
154     return out;
155   }
156 
157   // Checks if this sampler has a child sampler.
158   // @return - tre if there is a child sampler, false otherwise.
159   bool HasChildSampler() const;
160 
161   // Uses id as an index for the list of ids generated by the child sampler, and gets the
162   // associated id.
163   // @param int64_t* out_associated_id - Out parameter, contains the associated id.
164   // @param int64_t id - The id used as an index to get the associated child id.
165   // @return Status The status code returned
166   Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id);
167 
168   /// \brief Get the arguments of node
169   /// \param[out] out_json JSON string of all attributes
170   /// \return Status of the function
171   virtual Status to_json(nlohmann::json *out_json);
172 
173  protected:
174   // Number of rows of data from the place this sampler is sampling from. If this sampler
175   // has a child sampler, num_rows_ is the number of ids the child sampler will
176   // output. Otherwise, num_rows_ is the number of rows in the dataset.
177   int64_t num_rows_;
178 
179   // The user may want to sample less than the full amount of data.  num_samples_ reduces the number
180   // of id's returned as request by the user.  Derived classes will choose how to sample the smaller
181   // amount.
182   int64_t num_samples_;
183 
184   bool is_initialized;
185   int64_t samples_per_tensor_;
186   std::unique_ptr<ColDescriptor> col_desc_;
187   std::vector<std::shared_ptr<SamplerRT>> child_;  // Child nodes
188   TensorRow child_ids_;
189 };
190 }  // namespace dataset
191 }  // namespace mindspore
192 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_
193