• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
18 
19 #include <limits>
20 #include <memory>
21 #include <vector>
22 
23 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
24 
25 namespace mindspore {
26 namespace dataset {
27 /// Samples elements from a given list of indices.
28 class SubsetSamplerRT : public SamplerRT {
29  public:
30   /// Constructor.
31   /// \param indices List of indices.
32   /// \param num_samples The number of elements to sample. 0 for the full amount.
33   /// \param samples_per_tensor The number of ids we draw on each call to GetNextSample().
34   /// When samples_per_tensor=0, GetNextSample() will draw all the sample ids and return them at once.
35   SubsetSamplerRT(const std::vector<int64_t> &indices, int64_t num_samples,
36                   std::int64_t samples_per_tensor = std::numeric_limits<int64_t>::max());
37 
38   /// Destructor.
39   ~SubsetSamplerRT() = default;
40 
41   /// Initialize the sampler.
42   /// \return Status
43   Status InitSampler() override;
44 
45   /// Reset the internal variable to the initial state and reshuffle the indices.
46   /// \return Status
47   Status ResetSampler() override;
48 
49   /// Get the sample ids.
50   /// \param[out] TensorRow where the sample ids will be placed.
51   /// @note the sample ids (int64_t) will be placed in one Tensor
52   Status GetNextSample(TensorRow *out) override;
53 
54   /// Printer for debugging purposes.
55   /// \param out - output stream to write to
56   /// \param show_all - bool to show detailed vs summary
57   void SamplerPrint(std::ostream &out, bool show_all) const override;
58 
59   /// \brief Get the arguments of node
60   /// \param[out] out_json JSON string of all attributes
61   /// \return Status of the function
62   Status to_json(nlohmann::json *out_json) override;
63 
64   /// Calculate num samples. Unlike GetNumSamples, it is not a getter and doesn't necessarily return the value of
65   /// num_samples_
66   /// \param num_rows the size of the dataset this sampler will be applied to.
67   /// \return number of samples
68   int64_t CalculateNumSamples(int64_t num_rows) override;
69 
70  protected:
71   /// A list of indices (already randomized in constructor).
72   std::vector<int64_t> indices_;
73 
74  private:
75   /// Current sample id.
76   int64_t sample_id_;
77 };
78 }  // namespace dataset
79 }  // namespace mindspore
80 
81 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
82