• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_
19 
20 #include <limits>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 #include <nlohmann/json.hpp>
26 
27 #include "include/api/status.h"
28 #ifndef ENABLE_ANDROID
29 #include "minddata/mindrecord/include/shard_operator.h"
30 #endif
31 
32 namespace mindspore {
33 namespace dataset {
34 
35 // Internal Sampler class forward declaration
36 class SamplerRT;
37 
38 class SamplerObj {
39  public:
40   /// \brief Constructor
41   SamplerObj();
42 
43   /// \brief Destructor
44   ~SamplerObj();
45 
46   /// \brief Pure virtual function for derived class to implement parameters validation
47   /// \return The Status code of the function. It returns OK status if parameters are valid.
48   virtual Status ValidateParams() = 0;
49 
50   /// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
51   /// \param[out] sampler Shared pointers to the newly created Sampler
52   /// \return The Status code of the function. It returns OK status if sampler is created successfully.
53   virtual Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) = 0;
54 
55   /// \brief Pure virtual function to copy a SamplerObj class
56   /// \return Shared pointers to the newly copied SamplerObj
57   virtual std::shared_ptr<SamplerObj> SamplerCopy() = 0;
58 
59   /// \brief Function for derived class to get the shard id of sampler
60   /// \return The shard id of the derived sampler
ShardId()61   virtual int64_t ShardId() { return 0; }
62 
63   /// \brief Adds a child to the sampler
64   /// \param[in] child The sampler to be added as child
65   /// \return the Status code returned
66   Status AddChildSampler(std::shared_ptr<SamplerObj> child);
67 
68   virtual Status to_json(nlohmann::json *const out_json);
69 
70 #ifndef ENABLE_ANDROID
71   /// \brief Function to construct children samplers
72   /// \param[in] json_obj The JSON object to be deserialized
73   /// \param[out] parent_sampler given parent sampler, output constructed parent sampler with children samplers added
74   /// \return Status The status code returned
75   static Status from_json(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *parent_sampler);
76 #endif
77 
GetChild()78   std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; }
79 
80 #ifndef ENABLE_ANDROID
81   /// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
82   ///     only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
83   /// \return Shared pointers to the newly created Sampler
BuildForMindDataset()84   virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; }
85 #endif
86 
87  protected:
88   /// \brief A function that calls build on the children of this sampler
89   /// \param[in] sampler The samplerRT object built from this sampler
90   /// \return the Status code returned
91   Status BuildChildren(std::shared_ptr<SamplerRT> *const sampler);
92 
93   std::vector<std::shared_ptr<SamplerObj>> children_;
94 };
95 
96 }  // namespace dataset
97 }  // namespace mindspore
98 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_
99