1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
18
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "minddata/dataset/engine/datasetops/source/flickr_op.h"
25 #ifndef ENABLE_ANDROID
26 #include "minddata/dataset/engine/serdes.h"
27 #endif
28
29 #include "minddata/dataset/util/status.h"
30
31 namespace mindspore {
32 namespace dataset {
33 // Constructor for FlickrNode
FlickrNode(const std::string & dataset_dir,const std::string & annotation_file,bool decode,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetCache> cache)34 FlickrNode::FlickrNode(const std::string &dataset_dir, const std::string &annotation_file, bool decode,
35 std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
36 : MappableSourceNode(std::move(cache)),
37 dataset_dir_(dataset_dir),
38 annotation_file_(annotation_file),
39 decode_(decode),
40 sampler_(sampler) {}
41
Copy()42 std::shared_ptr<DatasetNode> FlickrNode::Copy() {
43 std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
44 auto node = std::make_shared<FlickrNode>(dataset_dir_, annotation_file_, decode_, sampler, cache_);
45 return node;
46 }
47
Print(std::ostream & out) const48 void FlickrNode::Print(std::ostream &out) const {
49 out << Name() + "(dataset dir:" + dataset_dir_;
50 out << ", annotation file:" + annotation_file_;
51 if (sampler_ != nullptr) {
52 out << ", sampler";
53 }
54 if (cache_ != nullptr) {
55 out << ", cache";
56 }
57 out << ")";
58 }
59
ValidateParams()60 Status FlickrNode::ValidateParams() {
61 RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
62 RETURN_IF_NOT_OK(ValidateDatasetDirParam("FlickrNode", dataset_dir_));
63
64 if (annotation_file_.empty()) {
65 std::string err_msg = "FlickrNode: annotation_file is not specified.";
66 MS_LOG(ERROR) << err_msg;
67 RETURN_STATUS_SYNTAX_ERROR(err_msg);
68 }
69
70 std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'};
71 for (char c : annotation_file_) {
72 auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c);
73 if (p != forbidden_symbols.end()) {
74 std::string err_msg = "FlickrNode: annotation_file: [" + annotation_file_ + "] should not contain :*?\"<>|`&;\'.";
75 MS_LOG(ERROR) << err_msg;
76 RETURN_STATUS_SYNTAX_ERROR(err_msg);
77 }
78 }
79 Path annotation_file(annotation_file_);
80 if (!annotation_file.Exists()) {
81 std::string err_msg = "FlickrNode: annotation_file: [" + annotation_file_ + "] is invalid or not exist.";
82 MS_LOG(ERROR) << err_msg;
83 RETURN_STATUS_SYNTAX_ERROR(err_msg);
84 }
85
86 RETURN_IF_NOT_OK(ValidateDatasetSampler("FlickrNode", sampler_));
87 return Status::OK();
88 }
89
90 // Function to build FlickrOp for Flickr
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)91 Status FlickrNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
92 // Do internal Schema generation.
93 auto schema = std::make_unique<DataSchema>();
94 RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
95 TensorShape scalar = TensorShape::CreateScalar();
96 RETURN_IF_NOT_OK(
97 schema->AddColumn(ColDescriptor("annotation", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar)));
98 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
99 RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
100
101 auto flickr_op = std::make_shared<FlickrOp>(num_workers_, dataset_dir_, annotation_file_, decode_,
102 connector_que_size_, std::move(schema), std::move(sampler_rt));
103 flickr_op->SetTotalRepeats(GetTotalRepeats());
104 flickr_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
105 node_ops->push_back(flickr_op);
106 return Status::OK();
107 }
108
109 // Get the shard id of node
GetShardId(int32_t * shard_id)110 Status FlickrNode::GetShardId(int32_t *shard_id) {
111 *shard_id = sampler_->ShardId();
112 return Status::OK();
113 }
114
115 // Get Dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)116 Status FlickrNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
117 int64_t *dataset_size) {
118 if (dataset_size_ > 0) {
119 *dataset_size = dataset_size_;
120 return Status::OK();
121 }
122
123 int64_t num_rows, sample_size;
124 RETURN_IF_NOT_OK(FlickrOp::CountTotalRows(dataset_dir_, annotation_file_, &num_rows));
125 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
126 RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
127 sample_size = sampler_rt->CalculateNumSamples(num_rows);
128 if (sample_size == -1) {
129 RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
130 }
131
132 *dataset_size = sample_size;
133 dataset_size_ = *dataset_size;
134 return Status::OK();
135 }
136
to_json(nlohmann::json * out_json)137 Status FlickrNode::to_json(nlohmann::json *out_json) {
138 nlohmann::json args, sampler_args;
139 RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
140 args["sampler"] = sampler_args;
141 args["num_parallel_workers"] = num_workers_;
142 args["dataset_dir"] = dataset_dir_;
143 args["annotation_file"] = annotation_file_;
144 args["decode"] = decode_;
145 if (cache_ != nullptr) {
146 nlohmann::json cache_args;
147 RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
148 args["cache"] = cache_args;
149 }
150 *out_json = args;
151 return Status::OK();
152 }
153
from_json(nlohmann::json json_obj,std::shared_ptr<DatasetNode> * ds)154 Status FlickrNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
155 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
156 "Failed to find num_parallel_workers");
157 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
158 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("annotation_file") != json_obj.end(), "Failed to find annotation_file");
159 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
160 std::string dataset_dir = json_obj["dataset_dir"];
161 std::string annotation_file = json_obj["annotation_file"];
162 bool decode = json_obj["decode"];
163 std::shared_ptr<SamplerObj> sampler;
164 RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
165 std::shared_ptr<DatasetCache> cache = nullptr;
166 RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
167 *ds = std::make_shared<FlickrNode>(dataset_dir, annotation_file, decode, sampler, cache);
168 (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
169 return Status::OK();
170 }
171 } // namespace dataset
172 } // namespace mindspore
173