• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/engine/ir/datasetops/source/caltech256_node.h"
17 
18 #include <map>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "minddata/dataset/engine/datasetops/source/caltech_op.h"
26 #ifndef ENABLE_ANDROID
27 #include "minddata/dataset/engine/serdes.h"
28 #endif
29 #include "minddata/dataset/util/status.h"
30 
31 namespace mindspore {
32 namespace dataset {
// File extensions handed to CaltechOp::CountRowsAndClasses() when GetDatasetSize() counts rows.
const std::set<std::string> kExts{".jpg", ".JPEG"};
34 
Caltech256Node(const std::string & dataset_dir,bool decode,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetCache> cache=nullptr)35 Caltech256Node::Caltech256Node(const std::string &dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler,
36                                std::shared_ptr<DatasetCache> cache = nullptr)
37     : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), decode_(decode), sampler_(sampler) {}
38 
Copy()39 std::shared_ptr<DatasetNode> Caltech256Node::Copy() {
40   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
41   auto node = std::make_shared<Caltech256Node>(dataset_dir_, decode_, sampler, cache_);
42   return node;
43 }
44 
Print(std::ostream & out) const45 void Caltech256Node::Print(std::ostream &out) const {
46   out << (Name() + "(path: " + dataset_dir_ + ", decode: " + (decode_ ? "true" : "false") + ")");
47 }
48 
ValidateParams()49 Status Caltech256Node::ValidateParams() {
50   RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
51   RETURN_IF_NOT_OK(ValidateDatasetDirParam("Caltech256Node", dataset_dir_));
52   RETURN_IF_NOT_OK(ValidateDatasetSampler("Caltech256Node", sampler_));
53   return Status::OK();
54 }
55 
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)56 Status Caltech256Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
57   // Do internal Schema generation.
58   // This arg exists in CaltechOp, but is not externalized (in Python API).
59   std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
60   TensorShape scalar = TensorShape::CreateScalar();
61   RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
62   RETURN_IF_NOT_OK(
63     schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
64   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
65   RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
66 
67   auto op = std::make_shared<CaltechOp>(num_workers_, dataset_dir_, connector_que_size_, decode_, std::move(schema),
68                                         std::move(sampler_rt));
69   op->SetTotalRepeats(GetTotalRepeats());
70   op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
71   node_ops->push_back(op);
72   return Status::OK();
73 }
74 
75 // Get the shard id of node.
GetShardId(int32_t * shard_id)76 Status Caltech256Node::GetShardId(int32_t *shard_id) {
77   *shard_id = sampler_->ShardId();
78   return Status::OK();
79 }
80 
81 // Get Dataset size.
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)82 Status Caltech256Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
83                                       int64_t *dataset_size) {
84   if (dataset_size_ > 0) {
85     *dataset_size = dataset_size_;
86     return Status::OK();
87   }
88   int64_t sample_size, num_rows;
89   RETURN_IF_NOT_OK(CaltechOp::CountRowsAndClasses(dataset_dir_, kExts, &num_rows, nullptr, {}));
90   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
91   RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
92   sample_size = sampler_rt->CalculateNumSamples(num_rows);
93   if (sample_size == -1) {
94     RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
95   }
96   *dataset_size = sample_size;
97   dataset_size_ = *dataset_size;
98   return Status::OK();
99 }
100 
to_json(nlohmann::json * out_json)101 Status Caltech256Node::to_json(nlohmann::json *out_json) {
102   nlohmann::json args, sampler_args;
103   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
104   args["sampler"] = sampler_args;
105   args["num_parallel_workers"] = num_workers_;
106   args["connector_queue_size"] = connector_que_size_;
107   args["dataset_dir"] = dataset_dir_;
108   args["decode"] = decode_;
109   if (cache_ != nullptr) {
110     nlohmann::json cache_args;
111     RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
112     args["cache"] = cache_args;
113   }
114   *out_json = args;
115   return Status::OK();
116 }
117 
118 #ifndef ENABLE_ANDROID
from_json(nlohmann::json json_obj,std::shared_ptr<DatasetNode> * ds)119 Status Caltech256Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
120   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kCaltech256Node));
121   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kCaltech256Node));
122   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kCaltech256Node));
123   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "decode", kCaltech256Node));
124   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kCaltech256Node));
125   std::string dataset_dir = json_obj["dataset_dir"];
126   bool decode = json_obj["decode"];
127   std::shared_ptr<SamplerObj> sampler;
128   RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
129   std::shared_ptr<DatasetCache> cache = nullptr;
130   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
131   *ds = std::make_shared<Caltech256Node>(dataset_dir, decode, sampler, cache);
132   (void)(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
133   (void)(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
134   return Status::OK();
135 }
136 #endif
137 }  // namespace dataset
138 }  // namespace mindspore
139