• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/datasetops/source/lsun_node.h"
18 
19 #include "minddata/dataset/engine/datasetops/source/lsun_op.h"
20 #include "minddata/dataset/util/status.h"
21 
22 namespace mindspore {
23 namespace dataset {
LSUNNode(const std::string & dataset_dir,const std::string & usage,const std::vector<std::string> & classes,bool decode,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetCache> cache=nullptr)24 LSUNNode::LSUNNode(const std::string &dataset_dir, const std::string &usage, const std::vector<std::string> &classes,
25                    bool decode, std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache = nullptr)
26     : MappableSourceNode(std::move(cache)),
27       dataset_dir_(dataset_dir),
28       usage_(usage),
29       classes_(classes),
30       decode_(decode),
31       sampler_(sampler) {}
32 
Copy()33 std::shared_ptr<DatasetNode> LSUNNode::Copy() {
34   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
35   auto node = std::make_shared<LSUNNode>(dataset_dir_, usage_, classes_, decode_, sampler, cache_);
36   return node;
37 }
38 
Print(std::ostream & out) const39 void LSUNNode::Print(std::ostream &out) const {
40   out << (Name() + "(path: " + dataset_dir_ + ", decode: " + (decode_ ? "true" : "false") + ")");
41 }
42 
ValidateParams()43 Status LSUNNode::ValidateParams() {
44   RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
45   RETURN_IF_NOT_OK(ValidateDatasetDirParam("LSUNDataset", dataset_dir_));
46   RETURN_IF_NOT_OK(ValidateDatasetSampler("LSUNDataset", sampler_));
47   RETURN_IF_NOT_OK(ValidateStringValue("LSUNDataset", usage_, {"train", "test", "valid", "all"}));
48   for (auto class_name : classes_) {
49     RETURN_IF_NOT_OK(ValidateStringValue("LSUNDataset", class_name,
50                                          {"bedroom", "bridge", "church_outdoor", "classroom", "conference_room",
51                                           "dining_room", "kitchen", "living_room", "restaurant", "tower"}));
52   }
53 
54   return Status::OK();
55 }
56 
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)57 Status LSUNNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
58   // Do internal Schema generation.
59   // This arg is exist in LSUNOp, but not externalized (in Python API).
60   std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
61   TensorShape scalar = TensorShape::CreateScalar();
62   RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
63   RETURN_IF_NOT_OK(
64     schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
65   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
66   RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
67 
68   auto op = std::make_shared<LSUNOp>(num_workers_, dataset_dir_, connector_que_size_, usage_, classes_, decode_,
69                                      std::move(schema), std::move(sampler_rt));
70   op->SetTotalRepeats(GetTotalRepeats());
71   op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
72   RETURN_UNEXPECTED_IF_NULL(node_ops);
73   node_ops->push_back(op);
74   return Status::OK();
75 }
76 
77 // Get the shard id of node
GetShardId(int32_t * shard_id)78 Status LSUNNode::GetShardId(int32_t *shard_id) {
79   RETURN_UNEXPECTED_IF_NULL(shard_id);
80   *shard_id = sampler_->ShardId();
81 
82   return Status::OK();
83 }
84 
85 // Get Dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)86 Status LSUNNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
87                                 int64_t *dataset_size) {
88   if (dataset_size_ > 0) {
89     *dataset_size = dataset_size_;
90     return Status::OK();
91   }
92   int64_t sample_size, num_rows;
93   RETURN_IF_NOT_OK(LSUNOp::CountRowsAndClasses(dataset_dir_, usage_, classes_, &num_rows, nullptr));
94   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
95   RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
96   sample_size = sampler_rt->CalculateNumSamples(num_rows);
97   if (sample_size == -1) {
98     RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
99   }
100   *dataset_size = sample_size;
101   dataset_size_ = *dataset_size;
102   return Status::OK();
103 }
104 
to_json(nlohmann::json * out_json)105 Status LSUNNode::to_json(nlohmann::json *out_json) {
106   nlohmann::json args, sampler_args;
107   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
108   args["sampler"] = sampler_args;
109   args["num_parallel_workers"] = num_workers_;
110   args["dataset_dir"] = dataset_dir_;
111   args["usage"] = usage_;
112   args["classes"] = classes_;
113   args["decode"] = decode_;
114   if (cache_ != nullptr) {
115     nlohmann::json cache_args;
116     RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
117     args["cache"] = cache_args;
118   }
119   *out_json = args;
120   return Status::OK();
121 }
122 }  // namespace dataset
123 }  // namespace mindspore
124