• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/datasetops/source/samplers/sequential_sampler_ir.h"
18 #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
19 #include "minddata/dataset/core/config_manager.h"
20 
21 #ifndef ENABLE_ANDROID
22 #include "minddata/dataset/util/random.h"
23 #include "minddata/mindrecord/include/shard_distributed_sample.h"
24 #include "minddata/mindrecord/include/shard_operator.h"
25 #include "minddata/mindrecord/include/shard_pk_sample.h"
26 #include "minddata/mindrecord/include/shard_sample.h"
27 #include "minddata/mindrecord/include/shard_sequential_sample.h"
28 #include "minddata/mindrecord/include/shard_shuffle.h"
29 #endif
30 
31 namespace mindspore {
32 namespace dataset {
33 // Constructor
SequentialSamplerObj(int64_t start_index,int64_t num_samples)34 SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples)
35     : start_index_(start_index), num_samples_(num_samples) {}
36 
37 // Destructor
38 SequentialSamplerObj::~SequentialSamplerObj() = default;
39 
ValidateParams()40 Status SequentialSamplerObj::ValidateParams() {
41   if (num_samples_ < 0) {
42     RETURN_STATUS_UNEXPECTED("SequentialSampler: num_samples must be greater than or equal to 0, but got: " +
43                              std::to_string(num_samples_));
44   }
45 
46   if (start_index_ < 0) {
47     RETURN_STATUS_UNEXPECTED("SequentialSampler: start_index_ must be greater than or equal to 0, but got: " +
48                              std::to_string(start_index_));
49   }
50 
51   return Status::OK();
52 }
53 
to_json(nlohmann::json * const out_json)54 Status SequentialSamplerObj::to_json(nlohmann::json *const out_json) {
55   nlohmann::json args;
56   RETURN_IF_NOT_OK(SamplerObj::to_json(&args));
57   args["sampler_name"] = "SequentialSampler";
58   args["start_index"] = start_index_;
59   args["num_samples"] = num_samples_;
60   *out_json = args;
61   return Status::OK();
62 }
63 
64 #ifndef ENABLE_ANDROID
from_json(nlohmann::json json_obj,int64_t num_samples,std::shared_ptr<SamplerObj> * sampler)65 Status SequentialSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
66                                        std::shared_ptr<SamplerObj> *sampler) {
67   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "start_index", "SequentialSampler"));
68   int64_t start_index = json_obj["start_index"];
69   *sampler = std::make_shared<SequentialSamplerObj>(start_index, num_samples);
70   // Run common code in super class to add children samplers
71   RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
72   return Status::OK();
73 }
74 #endif
75 
SamplerBuild(std::shared_ptr<SamplerRT> * sampler)76 Status SequentialSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
77   // runtime sampler object
78   *sampler = std::make_shared<dataset::SequentialSamplerRT>(start_index_, num_samples_);
79   Status s = BuildChildren(sampler);
80   sampler = s.IsOk() ? sampler : nullptr;
81   return s;
82 }
83 
84 #ifndef ENABLE_ANDROID
BuildForMindDataset()85 std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDataset() {
86   // runtime mindrecord sampler object
87   auto mind_sampler = std::make_shared<mindrecord::ShardSequentialSample>(num_samples_, start_index_);
88 
89   return mind_sampler;
90 }
91 #endif
92 
SamplerCopy()93 std::shared_ptr<SamplerObj> SequentialSamplerObj::SamplerCopy() {
94   auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
95   for (const auto &child : children_) {
96     Status rc = sampler->AddChildSampler(child);
97     if (rc.IsError()) {
98       MS_LOG(ERROR) << "[Internal ERROR] Error in copying the sampler. Message: " << rc;
99     }
100   }
101   return sampler;
102 }
103 }  // namespace dataset
104 }  // namespace mindspore
105