• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h"
18 
19 #include "minddata/dataset/engine/datasetops/source/fake_image_op.h"
20 #include "minddata/dataset/util/status.h"
21 
22 namespace mindspore {
23 namespace dataset {
FakeImageNode(int32_t num_images,const std::vector<int32_t> & image_size,int32_t num_classes,int32_t base_seed,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetCache> cache)24 FakeImageNode::FakeImageNode(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
25                              int32_t base_seed, std::shared_ptr<SamplerObj> sampler,
26                              std::shared_ptr<DatasetCache> cache)
27     : MappableSourceNode(std::move(cache)),
28       num_images_(num_images),
29       image_size_(image_size),
30       num_classes_(num_classes),
31       base_seed_(base_seed),
32       sampler_(sampler) {}
33 
Copy()34 std::shared_ptr<DatasetNode> FakeImageNode::Copy() {
35   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
36   auto node = std::make_shared<FakeImageNode>(num_images_, image_size_, num_classes_, base_seed_, sampler, cache_);
37   (void)node->SetNumWorkers(num_workers_);
38   (void)node->SetConnectorQueueSize(connector_que_size_);
39   return node;
40 }
41 
Print(std::ostream & out) const42 void FakeImageNode::Print(std::ostream &out) const {
43   out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + ")");
44 }
45 
ValidateParams()46 Status FakeImageNode::ValidateParams() {
47   RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
48 
49   RETURN_IF_NOT_OK(ValidateDatasetSampler("FakeImageDataset", sampler_));
50   RETURN_IF_NOT_OK(ValidateScalar("FakeImageDataset", "num_images", num_images_, {0}, true));
51 
52   if (image_size_.size() != 3) {
53     std::string err_msg = "FakeImageDataset: 'image_size' expecting size 3, but got image_size.size(): " +
54                           std::to_string(image_size_.size());
55     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
56   }
57 
58   for (auto i = 0; i < 3; i++) {
59     RETURN_IF_NOT_OK(
60       ValidateScalar("FakeImageDataset", "image_size[" + std::to_string(i) + "]", image_size_[i], {0}, true));
61   }
62 
63   RETURN_IF_NOT_OK(ValidateScalar("FakeImageDataset", "num_classes", num_classes_, {0}, true));
64   return Status::OK();
65 }
66 
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)67 Status FakeImageNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
68   // Do internal Schema generation.
69   auto schema = std::make_unique<DataSchema>();
70   RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
71   TensorShape scalar = TensorShape::CreateScalar();
72   RETURN_IF_NOT_OK(
73     schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
74   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
75   RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
76 
77   auto op = std::make_shared<FakeImageOp>(num_images_, image_size_, num_classes_, base_seed_, num_workers_,
78                                           connector_que_size_, std::move(schema), std::move(sampler_rt));
79   op->SetTotalRepeats(GetTotalRepeats());
80   op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
81   node_ops->push_back(op);
82 
83   return Status::OK();
84 }
85 
86 // Get the shard id of node
GetShardId(int32_t * shard_id)87 Status FakeImageNode::GetShardId(int32_t *shard_id) {
88   *shard_id = sampler_->ShardId();
89 
90   return Status::OK();
91 }
92 
93 // Get Dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)94 Status FakeImageNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
95                                      int64_t *dataset_size) {
96   if (dataset_size_ > 0) {
97     *dataset_size = dataset_size_;
98     return Status::OK();
99   }
100 
101   int64_t num_rows, sample_size;
102   num_rows = num_images_;
103   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
104   RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
105   sample_size = sampler_rt->CalculateNumSamples(num_rows);
106   if (sample_size == -1) {
107     RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
108   }
109   *dataset_size = sample_size;
110   dataset_size_ = *dataset_size;
111   return Status::OK();
112 }
113 
to_json(nlohmann::json * out_json)114 Status FakeImageNode::to_json(nlohmann::json *out_json) {
115   nlohmann::json args, sampler_args;
116   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
117   args["sampler"] = sampler_args;
118   args["num_parallel_workers"] = num_workers_;
119   args["connector_queue_size"] = connector_que_size_;
120   args["num_images"] = num_images_;
121   args["image_size"] = image_size_;
122   args["num_classes"] = num_classes_;
123   args["base_seed"] = base_seed_;
124   if (cache_ != nullptr) {
125     nlohmann::json cache_args;
126     RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
127     args["cache"] = cache_args;
128   }
129   *out_json = args;
130   return Status::OK();
131 }
132 }  // namespace dataset
133 }  // namespace mindspore
134