1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h"
18
19 #include "minddata/dataset/engine/datasetops/source/fake_image_op.h"
20 #include "minddata/dataset/util/status.h"
21
22 namespace mindspore {
23 namespace dataset {
FakeImageNode(int32_t num_images,const std::vector<int32_t> & image_size,int32_t num_classes,int32_t base_seed,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetCache> cache)24 FakeImageNode::FakeImageNode(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
25 int32_t base_seed, std::shared_ptr<SamplerObj> sampler,
26 std::shared_ptr<DatasetCache> cache)
27 : MappableSourceNode(std::move(cache)),
28 num_images_(num_images),
29 image_size_(image_size),
30 num_classes_(num_classes),
31 base_seed_(base_seed),
32 sampler_(sampler) {}
33
Copy()34 std::shared_ptr<DatasetNode> FakeImageNode::Copy() {
35 std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
36 auto node = std::make_shared<FakeImageNode>(num_images_, image_size_, num_classes_, base_seed_, sampler, cache_);
37 (void)node->SetNumWorkers(num_workers_);
38 (void)node->SetConnectorQueueSize(connector_que_size_);
39 return node;
40 }
41
Print(std::ostream & out) const42 void FakeImageNode::Print(std::ostream &out) const {
43 out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + ")");
44 }
45
ValidateParams()46 Status FakeImageNode::ValidateParams() {
47 RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
48
49 RETURN_IF_NOT_OK(ValidateDatasetSampler("FakeImageDataset", sampler_));
50 RETURN_IF_NOT_OK(ValidateScalar("FakeImageDataset", "num_images", num_images_, {0}, true));
51
52 if (image_size_.size() != 3) {
53 std::string err_msg = "FakeImageDataset: 'image_size' expecting size 3, but got image_size.size(): " +
54 std::to_string(image_size_.size());
55 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
56 }
57
58 for (auto i = 0; i < 3; i++) {
59 RETURN_IF_NOT_OK(
60 ValidateScalar("FakeImageDataset", "image_size[" + std::to_string(i) + "]", image_size_[i], {0}, true));
61 }
62
63 RETURN_IF_NOT_OK(ValidateScalar("FakeImageDataset", "num_classes", num_classes_, {0}, true));
64 return Status::OK();
65 }
66
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)67 Status FakeImageNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
68 // Do internal Schema generation.
69 auto schema = std::make_unique<DataSchema>();
70 RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
71 TensorShape scalar = TensorShape::CreateScalar();
72 RETURN_IF_NOT_OK(
73 schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
74 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
75 RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
76
77 auto op = std::make_shared<FakeImageOp>(num_images_, image_size_, num_classes_, base_seed_, num_workers_,
78 connector_que_size_, std::move(schema), std::move(sampler_rt));
79 op->SetTotalRepeats(GetTotalRepeats());
80 op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
81 node_ops->push_back(op);
82
83 return Status::OK();
84 }
85
86 // Get the shard id of node
GetShardId(int32_t * shard_id)87 Status FakeImageNode::GetShardId(int32_t *shard_id) {
88 *shard_id = sampler_->ShardId();
89
90 return Status::OK();
91 }
92
93 // Get Dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)94 Status FakeImageNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
95 int64_t *dataset_size) {
96 if (dataset_size_ > 0) {
97 *dataset_size = dataset_size_;
98 return Status::OK();
99 }
100
101 int64_t num_rows, sample_size;
102 num_rows = num_images_;
103 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
104 RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
105 sample_size = sampler_rt->CalculateNumSamples(num_rows);
106 if (sample_size == -1) {
107 RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
108 }
109 *dataset_size = sample_size;
110 dataset_size_ = *dataset_size;
111 return Status::OK();
112 }
113
to_json(nlohmann::json * out_json)114 Status FakeImageNode::to_json(nlohmann::json *out_json) {
115 nlohmann::json args, sampler_args;
116 RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
117 args["sampler"] = sampler_args;
118 args["num_parallel_workers"] = num_workers_;
119 args["connector_queue_size"] = connector_que_size_;
120 args["num_images"] = num_images_;
121 args["image_size"] = image_size_;
122 args["num_classes"] = num_classes_;
123 args["base_seed"] = base_seed_;
124 if (cache_ != nullptr) {
125 nlohmann::json cache_args;
126 RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
127 args["cache"] = cache_args;
128 }
129 *out_json = args;
130 return Status::OK();
131 }
132 } // namespace dataset
133 } // namespace mindspore
134