1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/ir/datasetops/source/imdb_node.h"
18
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <utility>
24 #include <vector>
25
26 #include "minddata/dataset/engine/datasetops/source/imdb_op.h"
27 #ifndef ENABLE_ANDROID
28 #include "minddata/dataset/engine/serdes.h"
29 #endif
30 #include "minddata/dataset/util/status.h"
31
32 namespace mindspore {
33 namespace dataset {
IMDBNode(const std::string & dataset_dir,const std::string & usage,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetCache> cache=nullptr)34 IMDBNode::IMDBNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
35 std::shared_ptr<DatasetCache> cache = nullptr)
36 : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), sampler_(sampler), usage_(usage) {}
37
Copy()38 std::shared_ptr<DatasetNode> IMDBNode::Copy() {
39 std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
40 auto node = std::make_shared<IMDBNode>(dataset_dir_, usage_, sampler, cache_);
41 (void)node->SetNumWorkers(num_workers_);
42 (void)node->SetConnectorQueueSize(connector_que_size_);
43 return node;
44 }
45
Print(std::ostream & out) const46 void IMDBNode::Print(std::ostream &out) const {
47 out << (Name() + "(path: " + dataset_dir_ + ", usage: " + usage_ + ")");
48 }
49
ValidateParams()50 Status IMDBNode::ValidateParams() {
51 RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
52 RETURN_IF_NOT_OK(ValidateDatasetDirParam("IMDBDataset", dataset_dir_));
53 RETURN_IF_NOT_OK(ValidateStringValue("IMDBDataset", usage_, {"train", "test", "all"}));
54 RETURN_IF_NOT_OK(ValidateDatasetSampler("IMDBDataset", sampler_));
55 return Status::OK();
56 }
57
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)58 Status IMDBNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
59 RETURN_UNEXPECTED_IF_NULL(node_ops);
60 // Do internal Schema generation.
61 // This arg is exist in IMDBOp, but not externalized (in Python API).
62 std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
63 TensorShape scalar = TensorShape::CreateScalar();
64 RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
65 RETURN_IF_NOT_OK(
66 schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
67 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
68 RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
69
70 auto op = std::make_shared<IMDBOp>(num_workers_, dataset_dir_, connector_que_size_, usage_, std::move(schema),
71 std::move(sampler_rt));
72 op->SetTotalRepeats(GetTotalRepeats());
73 op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
74 node_ops->push_back(op);
75 return Status::OK();
76 }
77
78 // Get the shard id of node
GetShardId(int32_t * shard_id)79 Status IMDBNode::GetShardId(int32_t *shard_id) {
80 RETURN_UNEXPECTED_IF_NULL(shard_id);
81 *shard_id = sampler_->ShardId();
82 return Status::OK();
83 }
84
85 // Get Dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)86 Status IMDBNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
87 int64_t *dataset_size) {
88 RETURN_UNEXPECTED_IF_NULL(dataset_size);
89 if (dataset_size_ > 0) {
90 *dataset_size = dataset_size_;
91 return Status::OK();
92 }
93 int64_t sample_size, num_rows;
94 RETURN_IF_NOT_OK(IMDBOp::CountRows(dataset_dir_, usage_, &num_rows));
95 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
96 RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
97 sample_size = sampler_rt->CalculateNumSamples(num_rows);
98 if (sample_size == -1) {
99 RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
100 }
101 *dataset_size = sample_size;
102 dataset_size_ = *dataset_size;
103 return Status::OK();
104 }
105
to_json(nlohmann::json * out_json)106 Status IMDBNode::to_json(nlohmann::json *out_json) {
107 nlohmann::json args, sampler_args;
108 RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
109 args["sampler"] = sampler_args;
110 args["num_parallel_workers"] = num_workers_;
111 args["connector_queue_size"] = connector_que_size_;
112 args["dataset_dir"] = dataset_dir_;
113 args["usage"] = usage_;
114 if (cache_ != nullptr) {
115 nlohmann::json cache_args;
116 RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
117 args["cache"] = cache_args;
118 }
119 *out_json = args;
120 return Status::OK();
121 }
122
123 #ifndef ENABLE_ANDROID
from_json(nlohmann::json json_obj,std::shared_ptr<DatasetNode> * ds)124 Status IMDBNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
125 RETURN_UNEXPECTED_IF_NULL(ds);
126 RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kIMDBNode));
127 RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kIMDBNode));
128 RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kIMDBNode));
129 RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kIMDBNode));
130 RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kIMDBNode));
131 std::string dataset_dir = json_obj["dataset_dir"];
132 std::string usage = json_obj["usage"];
133 std::shared_ptr<SamplerObj> sampler;
134 RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
135 std::shared_ptr<DatasetCache> cache = nullptr;
136 RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
137 *ds = std::make_shared<IMDBNode>(dataset_dir, usage, sampler, cache);
138 (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
139 return Status::OK();
140 }
141 #endif
142 } // namespace dataset
143 } // namespace mindspore
144