1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_ 19 20 #include <limits> 21 #include <memory> 22 #include <string> 23 #include <utility> 24 #include <vector> 25 #include <nlohmann/json.hpp> 26 27 #include "include/api/status.h" 28 #ifndef ENABLE_ANDROID 29 #include "minddata/mindrecord/include/shard_operator.h" 30 #endif 31 32 namespace mindspore { 33 namespace dataset { 34 35 // Internal Sampler class forward declaration 36 class SamplerRT; 37 38 class SamplerObj { 39 public: 40 /// \brief Constructor 41 SamplerObj(); 42 43 /// \brief Destructor 44 ~SamplerObj(); 45 46 /// \brief Pure virtual function for derived class to implement parameters validation 47 /// \return The Status code of the function. It returns OK status if parameters are valid. 48 virtual Status ValidateParams() = 0; 49 50 /// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object 51 /// \param[out] sampler Shared pointers to the newly created Sampler 52 /// \return The Status code of the function. It returns OK status if sampler is created successfully. 53 virtual Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) = 0; 54 55 /// \brief Pure virtual function to copy a SamplerObj class 56 /// \return Shared pointers to the newly copied SamplerObj 57 virtual std::shared_ptr<SamplerObj> SamplerCopy() = 0; 58 59 /// \brief Function for derived class to get the shard id of sampler 60 /// \return The shard id of the derived sampler ShardId()61 virtual int64_t ShardId() { return 0; } 62 63 /// \brief Adds a child to the sampler 64 /// \param[in] child The sampler to be added as child 65 /// \return the Status code returned 66 Status AddChildSampler(std::shared_ptr<SamplerObj> child); 67 68 virtual Status to_json(nlohmann::json *const out_json); 69 70 #ifndef ENABLE_ANDROID 71 /// \brief Function to construct children samplers 72 /// \param[in] json_obj The JSON object to be deserialized 73 /// \param[out] parent_sampler given parent sampler, output constructed parent sampler with children samplers added 74 /// \return Status The status code returned 75 static Status from_json(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *parent_sampler); 76 #endif 77 GetChild()78 std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; } 79 80 #ifndef ENABLE_ANDROID 81 /// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object, 82 /// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler 83 /// \return Shared pointers to the newly created Sampler BuildForMindDataset()84 virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; } 85 #endif 86 87 protected: 88 /// \brief A function that calls build on the children of this sampler 89 /// \param[in] sampler The samplerRT object built from this sampler 90 /// \return the Status code returned 91 Status BuildChildren(std::shared_ptr<SamplerRT> *const sampler); 92 93 std::vector<std::shared_ptr<SamplerObj>> children_; 94 }; 95 96 } // namespace dataset 97 } // namespace mindspore 98 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_ 99