1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/engine/ir/datasetops/source/caltech256_node.h"
17
18 #include <map>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "minddata/dataset/engine/datasetops/source/caltech_op.h"
26 #ifndef ENABLE_ANDROID
27 #include "minddata/dataset/engine/serdes.h"
28 #endif
29 #include "minddata/dataset/util/status.h"
30
31 namespace mindspore {
32 namespace dataset {
// File extensions handed to CaltechOp::CountRowsAndClasses when GetDatasetSize() counts the rows under
// dataset_dir_. Both the lower-case ".jpg" and the upper-case ".JPEG" spellings are listed explicitly.
const std::set<std::string> kExts{".jpg", ".JPEG"};
34
Caltech256Node(const std::string & dataset_dir,bool decode,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetCache> cache=nullptr)35 Caltech256Node::Caltech256Node(const std::string &dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler,
36 std::shared_ptr<DatasetCache> cache = nullptr)
37 : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), decode_(decode), sampler_(sampler) {}
38
Copy()39 std::shared_ptr<DatasetNode> Caltech256Node::Copy() {
40 std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
41 auto node = std::make_shared<Caltech256Node>(dataset_dir_, decode_, sampler, cache_);
42 return node;
43 }
44
Print(std::ostream & out) const45 void Caltech256Node::Print(std::ostream &out) const {
46 out << (Name() + "(path: " + dataset_dir_ + ", decode: " + (decode_ ? "true" : "false") + ")");
47 }
48
ValidateParams()49 Status Caltech256Node::ValidateParams() {
50 RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
51 RETURN_IF_NOT_OK(ValidateDatasetDirParam("Caltech256Node", dataset_dir_));
52 RETURN_IF_NOT_OK(ValidateDatasetSampler("Caltech256Node", sampler_));
53 return Status::OK();
54 }
55
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)56 Status Caltech256Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
57 // Do internal Schema generation.
58 // This arg exists in CaltechOp, but is not externalized (in Python API).
59 std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
60 TensorShape scalar = TensorShape::CreateScalar();
61 RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
62 RETURN_IF_NOT_OK(
63 schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
64 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
65 RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
66
67 auto op = std::make_shared<CaltechOp>(num_workers_, dataset_dir_, connector_que_size_, decode_, std::move(schema),
68 std::move(sampler_rt));
69 op->SetTotalRepeats(GetTotalRepeats());
70 op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
71 node_ops->push_back(op);
72 return Status::OK();
73 }
74
75 // Get the shard id of node.
GetShardId(int32_t * shard_id)76 Status Caltech256Node::GetShardId(int32_t *shard_id) {
77 *shard_id = sampler_->ShardId();
78 return Status::OK();
79 }
80
81 // Get Dataset size.
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)82 Status Caltech256Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
83 int64_t *dataset_size) {
84 if (dataset_size_ > 0) {
85 *dataset_size = dataset_size_;
86 return Status::OK();
87 }
88 int64_t sample_size, num_rows;
89 RETURN_IF_NOT_OK(CaltechOp::CountRowsAndClasses(dataset_dir_, kExts, &num_rows, nullptr, {}));
90 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
91 RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
92 sample_size = sampler_rt->CalculateNumSamples(num_rows);
93 if (sample_size == -1) {
94 RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
95 }
96 *dataset_size = sample_size;
97 dataset_size_ = *dataset_size;
98 return Status::OK();
99 }
100
to_json(nlohmann::json * out_json)101 Status Caltech256Node::to_json(nlohmann::json *out_json) {
102 nlohmann::json args, sampler_args;
103 RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
104 args["sampler"] = sampler_args;
105 args["num_parallel_workers"] = num_workers_;
106 args["connector_queue_size"] = connector_que_size_;
107 args["dataset_dir"] = dataset_dir_;
108 args["decode"] = decode_;
109 if (cache_ != nullptr) {
110 nlohmann::json cache_args;
111 RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
112 args["cache"] = cache_args;
113 }
114 *out_json = args;
115 return Status::OK();
116 }
117
118 #ifndef ENABLE_ANDROID
from_json(nlohmann::json json_obj,std::shared_ptr<DatasetNode> * ds)119 Status Caltech256Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
120 RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kCaltech256Node));
121 RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kCaltech256Node));
122 RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kCaltech256Node));
123 RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "decode", kCaltech256Node));
124 RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kCaltech256Node));
125 std::string dataset_dir = json_obj["dataset_dir"];
126 bool decode = json_obj["decode"];
127 std::shared_ptr<SamplerObj> sampler;
128 RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
129 std::shared_ptr<DatasetCache> cache = nullptr;
130 RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
131 *ds = std::make_shared<Caltech256Node>(dataset_dir, decode, sampler, cache);
132 (void)(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
133 (void)(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
134 return Status::OK();
135 }
136 #endif
137 } // namespace dataset
138 } // namespace mindspore
139