• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_
18 
19 #include <limits>
20 #include <map>
21 #include <memory>
22 #include <random>
23 #include <vector>
24 
25 #include "minddata/dataset/core/tensor.h"
26 
27 #include "minddata/dataset/engine/data_schema.h"
28 #include "minddata/dataset/engine/datasetops/dataset_op.h"
29 
30 namespace mindspore {
31 namespace dataset {
32 //  RandomAccessOp is a base class that all data-producing leaf operators
33 //  must inherit from if those leaf operator wish to support sampling.
34 class RandomAccessOp {
35  public:
36   // Sampler get number of rows in the dataset
37   // @param int64_t num - return number of rows for this dataset
38   // @return Status The status code returned
39   Status GetNumRowsInDataset(int64_t *num_rows) const;
40 
41   // sampler gets label , imageIds from corresponding Dataset Op, this function is unique to PK
42   // @param std::map<int64_t, std::vector<int64_t>> * map
43   // @return Status The status code returned
GetClassIds(std::map<int32_t,std::vector<int64_t>> * map)44   virtual Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *map) const {
45     RETURN_STATUS_UNEXPECTED("GetClassIds needs to be override to support PK");
46   }
47 
48   // default destructor
49   virtual ~RandomAccessOp() = default;
50 
51   /// Set num_rows
52   /// \param num_rows
SetNumRows(int64_t num_rows)53   void SetNumRows(int64_t num_rows) { num_rows_ = num_rows; }
54 
55  protected:
56   // The amount of rows in the dataset itself. This is the before-sampling value, the
57   // total count of rows.  A sampler may choose to sample less than this amount.
58   int64_t num_rows_ = -1;
59 };
60 
61 class SamplerRT {
62  public:
63   // Constructor
64   // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0
65   //                indicates that the sampler should produce the complete set of ids.
66   // @param int64_t samples_per_tensor: Num of Sampler Ids to fetch via 1 GetNextSample call
67   SamplerRT(int64_t num_samples, int64_t samples_per_tensor);
68 
SamplerRT(const SamplerRT & s)69   SamplerRT(const SamplerRT &s) : SamplerRT(s.num_samples_, s.samples_per_tensor_) {}
70 
71   // default destructor
72   ~SamplerRT() = default;
73 
74   // Get a list of sample ids.
75   // @note It is Sampler responsibility to make sure that the id is not out of bound.
76   // @param TensorRow to be returned to StorageOp
77   // @param int32_t workerId - not meant to be used
78   // @return Status The status code returned
79   virtual Status GetNextSample(TensorRow *out) = 0;
80 
81 // This function only called by python layer. Not needed by Android.
82 #ifdef ENABLE_PYTHON
83   // return all ids in one epoch as a numpy array, then call reset
84   Status GetAllIdsThenReset(py::array *data);
85 #endif
86 
87   // for next epoch of sampleIds
88   // @return Status The status code returned
89   virtual Status ResetSampler() = 0;
90 
91   // first handshake between leaf source op and Sampler. This func will determine the amount of data
92   // in the dataset that we can sample from.
93   // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is
94   // @return
95   virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op);
96 
97   // initialize sampler and perform checks on certain vars
InitSampler()98   virtual Status InitSampler() { return Status::OK(); }
99 
100   // setter for num samples
101   // @param num_samples - the number of samples to assign.
102   // @return status error code
103   Status SetNumSamples(int64_t num_samples);
104 
105   // getter for num samples
106   // @return number of samples
107   int64_t GetNumSamples() const;
108 
109   // Calculate num samples. Unlike GetNumSamples, it is not a getter and doesn't necessarily return the value of
110   // num_samples_
111   // @return number of samples, return -1 if sampler cannot determine this value (e.g. PKSampler)
112   virtual int64_t CalculateNumSamples(int64_t num_rows);
113 
114   // setter for num or records in the dataset
115   // @param num_rows - the number of records
116   // @return status error code
117   Status SetNumRowsInDataset(int64_t num_rows);
118 
119   // Adds a sampler to become our child.
120   // @param std::shared_ptr<DatasetOp> - The sampler to add as a child.
121   // @return Status The status code returned
122   Status AddChild(std::shared_ptr<SamplerRT> child);
123 
124   // A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
125   // @param std::shared_ptr<Tensor>* sampleIds
126   // @param int64_t numElements - must be a non 0 number
127   // @return Status The status code returned
128   Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);
129 
130   // A print method typically used for debugging
131   // @param out - The output stream to write output to
132   // @param show_all - A bool to control if you want to show all info or just a summary
133   virtual void SamplerPrint(std::ostream &out, bool show_all) const;
134 
135   // << Stream output operator overload
136   // @notes This allows you to write the debug print info using stream operators
137   // @param out - reference to the output stream being overloaded
138   // @param sampler - reference to teh sampler to print
139   // @return - the output stream must be returned
140   friend std::ostream &operator<<(std::ostream &out, const SamplerRT &sampler) {
141     sampler.SamplerPrint(out, false);
142     return out;
143   }
144 
145   // Checks if this sampler has a child sampler.
146   // @return - tre if there is a child sampler, false otherwise.
147   bool HasChildSampler() const;
148 
149   // Uses id as an index for the list of ids generated by the child sampler, and gets the
150   // associated id.
151   // @param int64_t* out_associated_id - Out parameter, contains the associated id.
152   // @param int64_t id - The id used as an index to get the associated child id.
153   // @return Status The status code returned
154   Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id);
155 
156   /// \brief Get the arguments of node
157   /// \param[out] out_json JSON string of all attributes
158   /// \return Status of the function
159   virtual Status to_json(nlohmann::json *out_json);
160 
161  protected:
162   // Number of rows of data from the place this sampler is sampling from. If this sampler
163   // has a child sampler, num_rows_ is the number of ids the child sampler will
164   // output. Otherwise, num_rows_ is the number of rows in the dataset.
165   int64_t num_rows_;
166 
167   // The user may want to sample less than the full amount of data.  num_samples_ reduces the number
168   // of id's returned as request by the user.  Derived classes will choose how to sample the smaller
169   // amount.
170   int64_t num_samples_;
171 
172   bool is_initialized;
173   int64_t samples_per_tensor_;
174   std::unique_ptr<ColDescriptor> col_desc_;
175   std::vector<std::shared_ptr<SamplerRT>> child_;  // Child nodes
176   TensorRow child_ids_;
177 };
178 }  // namespace dataset
179 }  // namespace mindspore
180 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_
181