• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
18 
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "minddata/dataset/engine/datasetops/source/flickr_op.h"
25 #ifndef ENABLE_ANDROID
26 #include "minddata/dataset/engine/serdes.h"
27 #endif
28 
29 #include "minddata/dataset/util/status.h"
30 
31 namespace mindspore {
32 namespace dataset {
33 // Constructor for FlickrNode
FlickrNode(const std::string & dataset_dir,const std::string & annotation_file,bool decode,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetCache> cache)34 FlickrNode::FlickrNode(const std::string &dataset_dir, const std::string &annotation_file, bool decode,
35                        std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
36     : MappableSourceNode(std::move(cache)),
37       dataset_dir_(dataset_dir),
38       annotation_file_(annotation_file),
39       decode_(decode),
40       sampler_(sampler) {}
41 
Copy()42 std::shared_ptr<DatasetNode> FlickrNode::Copy() {
43   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
44   auto node = std::make_shared<FlickrNode>(dataset_dir_, annotation_file_, decode_, sampler, cache_);
45   (void)node->SetNumWorkers(num_workers_);
46   (void)node->SetConnectorQueueSize(connector_que_size_);
47   return node;
48 }
49 
Print(std::ostream & out) const50 void FlickrNode::Print(std::ostream &out) const {
51   out << Name() + "(dataset dir:" + dataset_dir_;
52   out << ", annotation file:" + annotation_file_;
53   if (sampler_ != nullptr) {
54     out << ", sampler";
55   }
56   if (cache_ != nullptr) {
57     out << ", cache";
58   }
59   out << ")";
60 }
61 
ValidateParams()62 Status FlickrNode::ValidateParams() {
63   RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
64   RETURN_IF_NOT_OK(ValidateDatasetDirParam("FlickrDataset", dataset_dir_));
65 
66   if (annotation_file_.empty()) {
67     std::string err_msg = "FlickrDataset: 'annotation_file' is not specified.";
68     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
69   }
70 
71   std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'};
72   for (char c : annotation_file_) {
73     auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c);
74     if (p != forbidden_symbols.end()) {
75       std::string err_msg =
76         "FlickrDataset: 'annotation_file': [" + annotation_file_ + "] should not contain :*?\"<>|`&;\'.";
77       LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
78     }
79   }
80 
81   RETURN_IF_NOT_OK(ValidateDatasetFilesParam("FlickrDataset", {annotation_file_}, "annotation file"));
82   RETURN_IF_NOT_OK(ValidateDatasetSampler("FlickrDataset", sampler_));
83   return Status::OK();
84 }
85 
86 // Function to build FlickrOp for Flickr
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)87 Status FlickrNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
88   // Do internal Schema generation.
89   auto schema = std::make_unique<DataSchema>();
90   RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
91   TensorShape scalar = TensorShape::CreateScalar();
92   RETURN_IF_NOT_OK(
93     schema->AddColumn(ColDescriptor("annotation", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar)));
94   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
95   RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
96 
97   auto flickr_op = std::make_shared<FlickrOp>(num_workers_, dataset_dir_, annotation_file_, decode_,
98                                               connector_que_size_, std::move(schema), std::move(sampler_rt));
99   flickr_op->SetTotalRepeats(GetTotalRepeats());
100   flickr_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
101   node_ops->push_back(flickr_op);
102   return Status::OK();
103 }
104 
105 // Get the shard id of node
GetShardId(int32_t * shard_id)106 Status FlickrNode::GetShardId(int32_t *shard_id) {
107   *shard_id = sampler_->ShardId();
108   return Status::OK();
109 }
110 
111 // Get Dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)112 Status FlickrNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
113                                   int64_t *dataset_size) {
114   if (dataset_size_ > 0) {
115     *dataset_size = dataset_size_;
116     return Status::OK();
117   }
118 
119   int64_t num_rows, sample_size;
120   RETURN_IF_NOT_OK(FlickrOp::CountTotalRows(dataset_dir_, annotation_file_, &num_rows));
121   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
122   RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
123   sample_size = sampler_rt->CalculateNumSamples(num_rows);
124   if (sample_size == -1) {
125     RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
126   }
127 
128   *dataset_size = sample_size;
129   dataset_size_ = *dataset_size;
130   return Status::OK();
131 }
132 
to_json(nlohmann::json * out_json)133 Status FlickrNode::to_json(nlohmann::json *out_json) {
134   nlohmann::json args, sampler_args;
135   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
136   args["sampler"] = sampler_args;
137   args["num_parallel_workers"] = num_workers_;
138   args["connector_queue_size"] = connector_que_size_;
139   args["dataset_dir"] = dataset_dir_;
140   args["annotation_file"] = annotation_file_;
141   args["decode"] = decode_;
142   if (cache_ != nullptr) {
143     nlohmann::json cache_args;
144     RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
145     args["cache"] = cache_args;
146   }
147   *out_json = args;
148   return Status::OK();
149 }
150 
from_json(nlohmann::json json_obj,std::shared_ptr<DatasetNode> * ds)151 Status FlickrNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
152   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kFlickrNode));
153   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kFlickrNode));
154   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
155   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("annotation_file") != json_obj.end(), "Failed to find annotation_file");
156   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
157   std::string dataset_dir = json_obj["dataset_dir"];
158   std::string annotation_file = json_obj["annotation_file"];
159   bool decode = json_obj["decode"];
160   std::shared_ptr<SamplerObj> sampler;
161   RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
162   std::shared_ptr<DatasetCache> cache = nullptr;
163   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
164   *ds = std::make_shared<FlickrNode>(dataset_dir, annotation_file, decode, sampler, cache);
165   (void)(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
166   (void)(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
167   return Status::OK();
168 }
169 }  // namespace dataset
170 }  // namespace mindspore
171