1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/ir/datasetops/source/lsun_node.h"
18
19 #include "minddata/dataset/engine/datasetops/source/lsun_op.h"
20 #include "minddata/dataset/util/status.h"
21
22 namespace mindspore {
23 namespace dataset {
LSUNNode(const std::string & dataset_dir,const std::string & usage,const std::vector<std::string> & classes,bool decode,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetCache> cache=nullptr)24 LSUNNode::LSUNNode(const std::string &dataset_dir, const std::string &usage, const std::vector<std::string> &classes,
25 bool decode, std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache = nullptr)
26 : MappableSourceNode(std::move(cache)),
27 dataset_dir_(dataset_dir),
28 usage_(usage),
29 classes_(classes),
30 decode_(decode),
31 sampler_(sampler) {}
32
Copy()33 std::shared_ptr<DatasetNode> LSUNNode::Copy() {
34 std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
35 auto node = std::make_shared<LSUNNode>(dataset_dir_, usage_, classes_, decode_, sampler, cache_);
36 return node;
37 }
38
Print(std::ostream & out) const39 void LSUNNode::Print(std::ostream &out) const {
40 out << (Name() + "(path: " + dataset_dir_ + ", decode: " + (decode_ ? "true" : "false") + ")");
41 }
42
ValidateParams()43 Status LSUNNode::ValidateParams() {
44 RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
45 RETURN_IF_NOT_OK(ValidateDatasetDirParam("LSUNDataset", dataset_dir_));
46 RETURN_IF_NOT_OK(ValidateDatasetSampler("LSUNDataset", sampler_));
47 RETURN_IF_NOT_OK(ValidateStringValue("LSUNDataset", usage_, {"train", "test", "valid", "all"}));
48 for (auto class_name : classes_) {
49 RETURN_IF_NOT_OK(ValidateStringValue("LSUNDataset", class_name,
50 {"bedroom", "bridge", "church_outdoor", "classroom", "conference_room",
51 "dining_room", "kitchen", "living_room", "restaurant", "tower"}));
52 }
53
54 return Status::OK();
55 }
56
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)57 Status LSUNNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
58 // Do internal Schema generation.
59 // This arg is exist in LSUNOp, but not externalized (in Python API).
60 std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
61 TensorShape scalar = TensorShape::CreateScalar();
62 RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
63 RETURN_IF_NOT_OK(
64 schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
65 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
66 RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
67
68 auto op = std::make_shared<LSUNOp>(num_workers_, dataset_dir_, connector_que_size_, usage_, classes_, decode_,
69 std::move(schema), std::move(sampler_rt));
70 op->SetTotalRepeats(GetTotalRepeats());
71 op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
72 RETURN_UNEXPECTED_IF_NULL(node_ops);
73 node_ops->push_back(op);
74 return Status::OK();
75 }
76
77 // Get the shard id of node
GetShardId(int32_t * shard_id)78 Status LSUNNode::GetShardId(int32_t *shard_id) {
79 RETURN_UNEXPECTED_IF_NULL(shard_id);
80 *shard_id = sampler_->ShardId();
81
82 return Status::OK();
83 }
84
85 // Get Dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)86 Status LSUNNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
87 int64_t *dataset_size) {
88 if (dataset_size_ > 0) {
89 *dataset_size = dataset_size_;
90 return Status::OK();
91 }
92 int64_t sample_size, num_rows;
93 RETURN_IF_NOT_OK(LSUNOp::CountRowsAndClasses(dataset_dir_, usage_, classes_, &num_rows, nullptr));
94 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
95 RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
96 sample_size = sampler_rt->CalculateNumSamples(num_rows);
97 if (sample_size == -1) {
98 RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
99 }
100 *dataset_size = sample_size;
101 dataset_size_ = *dataset_size;
102 return Status::OK();
103 }
104
to_json(nlohmann::json * out_json)105 Status LSUNNode::to_json(nlohmann::json *out_json) {
106 nlohmann::json args, sampler_args;
107 RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
108 args["sampler"] = sampler_args;
109 args["num_parallel_workers"] = num_workers_;
110 args["dataset_dir"] = dataset_dir_;
111 args["usage"] = usage_;
112 args["classes"] = classes_;
113 args["decode"] = decode_;
114 if (cache_ != nullptr) {
115 nlohmann::json cache_args;
116 RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
117 args["cache"] = cache_args;
118 }
119 *out_json = args;
120 return Status::OK();
121 }
122 } // namespace dataset
123 } // namespace mindspore
124