/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_
18 
19 #include <limits>
20 #include <map>
21 #include <memory>
22 #include <random>
23 #include <vector>
24 
25 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
26 
27 namespace mindspore {
28 namespace dataset {
29 class PKSamplerRT : public SamplerRT {  // NOT YET FINISHED
30  public:
31   // @param int64_t num_val
32   // @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2
33   // @param num_samples - the number of samples to draw.  value of 0 means to take the full amount
34   // @param int64_t samples_per_tensor - Num of Sampler Ids to fetch via 1 GetNextSample call
35   PKSamplerRT(int64_t num_val, bool shuffle, int64_t num_samples,
36               int64_t samples_per_tensor = std::numeric_limits<int64_t>::max());
37 
38   // default destructor
39   ~PKSamplerRT() = default;
40 
41   // @param TensorRow
42   // @param int32_t workerId
43   // @return Status The status code returned
44   Status GetNextSample(TensorRow *out) override;
45 
46   // first handshake between leaf source op and Sampler. This func will determine the amount of data
47   // in the dataset that we can sample from.
48   // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is
49   // @return
50   Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
51 
52   // init sampler, to be called by python or Handshake
53   Status InitSampler() override;
54 
55   // for next epoch of sampleIds
56   // @return Status The status code returned
57   Status ResetSampler() override;
58 
59   // Printer for debugging purposes.
60   // @param out - output stream to write to
61   // @param show_all - bool to show detailed vs summary
62   void SamplerPrint(std::ostream &out, bool show_all) const override;
63 
64   /// \brief Get the arguments of node
65   /// \param[out] out_json JSON string of all attributes
66   /// \return Status of the function
67   Status to_json(nlohmann::json *out_json) override;
68 
69   /// \brief PK cannot return an exact value because num_classes is not known until runtime, hence -1 is used
70   /// \param[out] num_rows
71   /// \return -1, which means PKSampler doesn't know how much data
CalculateNumSamples(int64_t num_rows)72   int64_t CalculateNumSamples(int64_t num_rows) override { return -1; }
73 
74  private:
75   bool shuffle_;
76   uint32_t seed_;
77   int64_t next_id_;
78   int64_t samples_per_class_;
79   std::mt19937 rnd_;
80   std::vector<int64_t> labels_;
81   std::map<int32_t, std::vector<int64_t>> label_to_ids_;
82 };
83 }  // namespace dataset
84 }  // namespace mindspore
85 
86 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_
87