• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
18 
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "minddata/dataset/engine/datasetops/source/flickr_op.h"
25 #ifndef ENABLE_ANDROID
26 #include "minddata/dataset/engine/serdes.h"
27 #endif
28 
29 #include "minddata/dataset/util/status.h"
30 
31 namespace mindspore {
32 namespace dataset {
33 // Constructor for FlickrNode
FlickrNode(const std::string & dataset_dir,const std::string & annotation_file,bool decode,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetCache> cache)34 FlickrNode::FlickrNode(const std::string &dataset_dir, const std::string &annotation_file, bool decode,
35                        std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
36     : MappableSourceNode(std::move(cache)),
37       dataset_dir_(dataset_dir),
38       annotation_file_(annotation_file),
39       decode_(decode),
40       sampler_(sampler) {}
41 
Copy()42 std::shared_ptr<DatasetNode> FlickrNode::Copy() {
43   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
44   auto node = std::make_shared<FlickrNode>(dataset_dir_, annotation_file_, decode_, sampler, cache_);
45   return node;
46 }
47 
Print(std::ostream & out) const48 void FlickrNode::Print(std::ostream &out) const {
49   out << Name() + "(dataset dir:" + dataset_dir_;
50   out << ", annotation file:" + annotation_file_;
51   if (sampler_ != nullptr) {
52     out << ", sampler";
53   }
54   if (cache_ != nullptr) {
55     out << ", cache";
56   }
57   out << ")";
58 }
59 
ValidateParams()60 Status FlickrNode::ValidateParams() {
61   RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
62   RETURN_IF_NOT_OK(ValidateDatasetDirParam("FlickrNode", dataset_dir_));
63 
64   if (annotation_file_.empty()) {
65     std::string err_msg = "FlickrNode: annotation_file is not specified.";
66     MS_LOG(ERROR) << err_msg;
67     RETURN_STATUS_SYNTAX_ERROR(err_msg);
68   }
69 
70   std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'};
71   for (char c : annotation_file_) {
72     auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c);
73     if (p != forbidden_symbols.end()) {
74       std::string err_msg = "FlickrNode: annotation_file: [" + annotation_file_ + "] should not contain :*?\"<>|`&;\'.";
75       MS_LOG(ERROR) << err_msg;
76       RETURN_STATUS_SYNTAX_ERROR(err_msg);
77     }
78   }
79   Path annotation_file(annotation_file_);
80   if (!annotation_file.Exists()) {
81     std::string err_msg = "FlickrNode: annotation_file: [" + annotation_file_ + "] is invalid or not exist.";
82     MS_LOG(ERROR) << err_msg;
83     RETURN_STATUS_SYNTAX_ERROR(err_msg);
84   }
85 
86   RETURN_IF_NOT_OK(ValidateDatasetSampler("FlickrNode", sampler_));
87   return Status::OK();
88 }
89 
90 // Function to build FlickrOp for Flickr
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)91 Status FlickrNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
92   // Do internal Schema generation.
93   auto schema = std::make_unique<DataSchema>();
94   RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
95   TensorShape scalar = TensorShape::CreateScalar();
96   RETURN_IF_NOT_OK(
97     schema->AddColumn(ColDescriptor("annotation", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar)));
98   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
99   RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
100 
101   auto flickr_op = std::make_shared<FlickrOp>(num_workers_, dataset_dir_, annotation_file_, decode_,
102                                               connector_que_size_, std::move(schema), std::move(sampler_rt));
103   flickr_op->SetTotalRepeats(GetTotalRepeats());
104   flickr_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
105   node_ops->push_back(flickr_op);
106   return Status::OK();
107 }
108 
109 // Get the shard id of node
GetShardId(int32_t * shard_id)110 Status FlickrNode::GetShardId(int32_t *shard_id) {
111   *shard_id = sampler_->ShardId();
112   return Status::OK();
113 }
114 
115 // Get Dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)116 Status FlickrNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
117                                   int64_t *dataset_size) {
118   if (dataset_size_ > 0) {
119     *dataset_size = dataset_size_;
120     return Status::OK();
121   }
122 
123   int64_t num_rows, sample_size;
124   RETURN_IF_NOT_OK(FlickrOp::CountTotalRows(dataset_dir_, annotation_file_, &num_rows));
125   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
126   RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
127   sample_size = sampler_rt->CalculateNumSamples(num_rows);
128   if (sample_size == -1) {
129     RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
130   }
131 
132   *dataset_size = sample_size;
133   dataset_size_ = *dataset_size;
134   return Status::OK();
135 }
136 
to_json(nlohmann::json * out_json)137 Status FlickrNode::to_json(nlohmann::json *out_json) {
138   nlohmann::json args, sampler_args;
139   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
140   args["sampler"] = sampler_args;
141   args["num_parallel_workers"] = num_workers_;
142   args["dataset_dir"] = dataset_dir_;
143   args["annotation_file"] = annotation_file_;
144   args["decode"] = decode_;
145   if (cache_ != nullptr) {
146     nlohmann::json cache_args;
147     RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
148     args["cache"] = cache_args;
149   }
150   *out_json = args;
151   return Status::OK();
152 }
153 
from_json(nlohmann::json json_obj,std::shared_ptr<DatasetNode> * ds)154 Status FlickrNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
155   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
156                                "Failed to find num_parallel_workers");
157   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
158   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("annotation_file") != json_obj.end(), "Failed to find annotation_file");
159   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
160   std::string dataset_dir = json_obj["dataset_dir"];
161   std::string annotation_file = json_obj["annotation_file"];
162   bool decode = json_obj["decode"];
163   std::shared_ptr<SamplerObj> sampler;
164   RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
165   std::shared_ptr<DatasetCache> cache = nullptr;
166   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
167   *ds = std::make_shared<FlickrNode>(dataset_dir, annotation_file, decode, sampler, cache);
168   (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
169   return Status::OK();
170 }
171 }  // namespace dataset
172 }  // namespace mindspore
173