1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
18
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "minddata/dataset/engine/datasetops/source/flickr_op.h"
25 #ifndef ENABLE_ANDROID
26 #include "minddata/dataset/engine/serdes.h"
27 #endif
28
29 #include "minddata/dataset/util/status.h"
30
31 namespace mindspore {
32 namespace dataset {
33 // Constructor for FlickrNode
FlickrNode(const std::string & dataset_dir,const std::string & annotation_file,bool decode,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetCache> cache)34 FlickrNode::FlickrNode(const std::string &dataset_dir, const std::string &annotation_file, bool decode,
35 std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
36 : MappableSourceNode(std::move(cache)),
37 dataset_dir_(dataset_dir),
38 annotation_file_(annotation_file),
39 decode_(decode),
40 sampler_(sampler) {}
41
Copy()42 std::shared_ptr<DatasetNode> FlickrNode::Copy() {
43 std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
44 auto node = std::make_shared<FlickrNode>(dataset_dir_, annotation_file_, decode_, sampler, cache_);
45 (void)node->SetNumWorkers(num_workers_);
46 (void)node->SetConnectorQueueSize(connector_que_size_);
47 return node;
48 }
49
Print(std::ostream & out) const50 void FlickrNode::Print(std::ostream &out) const {
51 out << Name() + "(dataset dir:" + dataset_dir_;
52 out << ", annotation file:" + annotation_file_;
53 if (sampler_ != nullptr) {
54 out << ", sampler";
55 }
56 if (cache_ != nullptr) {
57 out << ", cache";
58 }
59 out << ")";
60 }
61
ValidateParams()62 Status FlickrNode::ValidateParams() {
63 RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
64 RETURN_IF_NOT_OK(ValidateDatasetDirParam("FlickrDataset", dataset_dir_));
65
66 if (annotation_file_.empty()) {
67 std::string err_msg = "FlickrDataset: 'annotation_file' is not specified.";
68 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
69 }
70
71 std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'};
72 for (char c : annotation_file_) {
73 auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c);
74 if (p != forbidden_symbols.end()) {
75 std::string err_msg =
76 "FlickrDataset: 'annotation_file': [" + annotation_file_ + "] should not contain :*?\"<>|`&;\'.";
77 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
78 }
79 }
80
81 RETURN_IF_NOT_OK(ValidateDatasetFilesParam("FlickrDataset", {annotation_file_}, "annotation file"));
82 RETURN_IF_NOT_OK(ValidateDatasetSampler("FlickrDataset", sampler_));
83 return Status::OK();
84 }
85
86 // Function to build FlickrOp for Flickr
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)87 Status FlickrNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
88 // Do internal Schema generation.
89 auto schema = std::make_unique<DataSchema>();
90 RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
91 TensorShape scalar = TensorShape::CreateScalar();
92 RETURN_IF_NOT_OK(
93 schema->AddColumn(ColDescriptor("annotation", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar)));
94 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
95 RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
96
97 auto flickr_op = std::make_shared<FlickrOp>(num_workers_, dataset_dir_, annotation_file_, decode_,
98 connector_que_size_, std::move(schema), std::move(sampler_rt));
99 flickr_op->SetTotalRepeats(GetTotalRepeats());
100 flickr_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
101 node_ops->push_back(flickr_op);
102 return Status::OK();
103 }
104
105 // Get the shard id of node
GetShardId(int32_t * shard_id)106 Status FlickrNode::GetShardId(int32_t *shard_id) {
107 *shard_id = sampler_->ShardId();
108 return Status::OK();
109 }
110
111 // Get Dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)112 Status FlickrNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
113 int64_t *dataset_size) {
114 if (dataset_size_ > 0) {
115 *dataset_size = dataset_size_;
116 return Status::OK();
117 }
118
119 int64_t num_rows, sample_size;
120 RETURN_IF_NOT_OK(FlickrOp::CountTotalRows(dataset_dir_, annotation_file_, &num_rows));
121 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
122 RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
123 sample_size = sampler_rt->CalculateNumSamples(num_rows);
124 if (sample_size == -1) {
125 RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
126 }
127
128 *dataset_size = sample_size;
129 dataset_size_ = *dataset_size;
130 return Status::OK();
131 }
132
to_json(nlohmann::json * out_json)133 Status FlickrNode::to_json(nlohmann::json *out_json) {
134 nlohmann::json args, sampler_args;
135 RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
136 args["sampler"] = sampler_args;
137 args["num_parallel_workers"] = num_workers_;
138 args["connector_queue_size"] = connector_que_size_;
139 args["dataset_dir"] = dataset_dir_;
140 args["annotation_file"] = annotation_file_;
141 args["decode"] = decode_;
142 if (cache_ != nullptr) {
143 nlohmann::json cache_args;
144 RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
145 args["cache"] = cache_args;
146 }
147 *out_json = args;
148 return Status::OK();
149 }
150
from_json(nlohmann::json json_obj,std::shared_ptr<DatasetNode> * ds)151 Status FlickrNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
152 RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kFlickrNode));
153 RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kFlickrNode));
154 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
155 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("annotation_file") != json_obj.end(), "Failed to find annotation_file");
156 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
157 std::string dataset_dir = json_obj["dataset_dir"];
158 std::string annotation_file = json_obj["annotation_file"];
159 bool decode = json_obj["decode"];
160 std::shared_ptr<SamplerObj> sampler;
161 RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
162 std::shared_ptr<DatasetCache> cache = nullptr;
163 RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
164 *ds = std::make_shared<FlickrNode>(dataset_dir, annotation_file, decode, sampler, cache);
165 (void)(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
166 (void)(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
167 return Status::OK();
168 }
169 } // namespace dataset
170 } // namespace mindspore
171