• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"

#include <memory>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif

#include "minddata/dataset/core/config_manager.h"
24 
25 namespace mindspore {
26 namespace dataset {
27 
28 // Constructor
SamplerObj()29 SamplerObj::SamplerObj() {}
30 
31 // Destructor
32 SamplerObj::~SamplerObj() = default;
33 
BuildChildren(std::shared_ptr<SamplerRT> * const sampler)34 Status SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> *const sampler) {
35   for (auto child : children_) {
36     std::shared_ptr<SamplerRT> sampler_rt = nullptr;
37     RETURN_IF_NOT_OK(child->SamplerBuild(&sampler_rt));
38     RETURN_IF_NOT_OK((*sampler)->AddChild(sampler_rt));
39   }
40   return Status::OK();
41 }
42 
AddChildSampler(std::shared_ptr<SamplerObj> child)43 Status SamplerObj::AddChildSampler(std::shared_ptr<SamplerObj> child) {
44   if (child == nullptr) {
45     return Status::OK();
46   }
47 
48   // Only samplers can be added, not any other DatasetOp.
49   std::shared_ptr<SamplerObj> sampler = std::dynamic_pointer_cast<SamplerObj>(child);
50   if (!sampler) {
51     RETURN_STATUS_UNEXPECTED("Cannot add child, child is not a sampler object.");
52   }
53 
54   // Samplers can have at most 1 child.
55   if (!children_.empty()) {
56     RETURN_STATUS_UNEXPECTED("Cannot add child sampler, this sampler already has a child.");
57   }
58 
59   children_.push_back(child);
60 
61   return Status::OK();
62 }
63 
to_json(nlohmann::json * const out_json)64 Status SamplerObj::to_json(nlohmann::json *const out_json) {
65   nlohmann::json args;
66   if (!children_.empty()) {
67     std::vector<nlohmann::json> children_args;
68     for (auto child : children_) {
69       nlohmann::json child_arg;
70       RETURN_IF_NOT_OK(child->to_json(&child_arg));
71       children_args.push_back(child_arg);
72     }
73     args["child_sampler"] = children_args;
74   }
75   *out_json = args;
76   return Status::OK();
77 }
78 
79 #ifndef ENABLE_ANDROID
from_json(nlohmann::json json_obj,std::shared_ptr<SamplerObj> * parent_sampler)80 Status SamplerObj::from_json(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *parent_sampler) {
81   for (nlohmann::json child : json_obj["child_sampler"]) {
82     std::shared_ptr<SamplerObj> child_sampler;
83     RETURN_IF_NOT_OK(Serdes::ConstructSampler(child, &child_sampler));
84     (*parent_sampler)->AddChildSampler(child_sampler);
85   }
86   return Status::OK();
87 }
88 #endif
89 }  // namespace dataset
90 }  // namespace mindspore
91