• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/datasetops/source/imdb_node.h"
18 
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "minddata/dataset/engine/datasetops/source/imdb_op.h"
27 #ifndef ENABLE_ANDROID
28 #include "minddata/dataset/engine/serdes.h"
29 #endif
30 #include "minddata/dataset/util/status.h"
31 
32 namespace mindspore {
33 namespace dataset {
IMDBNode(const std::string & dataset_dir,const std::string & usage,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetCache> cache=nullptr)34 IMDBNode::IMDBNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
35                    std::shared_ptr<DatasetCache> cache = nullptr)
36     : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), sampler_(sampler), usage_(usage) {}
37 
Copy()38 std::shared_ptr<DatasetNode> IMDBNode::Copy() {
39   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
40   auto node = std::make_shared<IMDBNode>(dataset_dir_, usage_, sampler, cache_);
41   (void)node->SetNumWorkers(num_workers_);
42   (void)node->SetConnectorQueueSize(connector_que_size_);
43   return node;
44 }
45 
Print(std::ostream & out) const46 void IMDBNode::Print(std::ostream &out) const {
47   out << (Name() + "(path: " + dataset_dir_ + ", usage: " + usage_ + ")");
48 }
49 
ValidateParams()50 Status IMDBNode::ValidateParams() {
51   RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
52   RETURN_IF_NOT_OK(ValidateDatasetDirParam("IMDBDataset", dataset_dir_));
53   RETURN_IF_NOT_OK(ValidateStringValue("IMDBDataset", usage_, {"train", "test", "all"}));
54   RETURN_IF_NOT_OK(ValidateDatasetSampler("IMDBDataset", sampler_));
55   return Status::OK();
56 }
57 
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)58 Status IMDBNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
59   RETURN_UNEXPECTED_IF_NULL(node_ops);
60   // Do internal Schema generation.
61   // This arg is exist in IMDBOp, but not externalized (in Python API).
62   std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
63   TensorShape scalar = TensorShape::CreateScalar();
64   RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
65   RETURN_IF_NOT_OK(
66     schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
67   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
68   RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
69 
70   auto op = std::make_shared<IMDBOp>(num_workers_, dataset_dir_, connector_que_size_, usage_, std::move(schema),
71                                      std::move(sampler_rt));
72   op->SetTotalRepeats(GetTotalRepeats());
73   op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
74   node_ops->push_back(op);
75   return Status::OK();
76 }
77 
78 // Get the shard id of node
GetShardId(int32_t * shard_id)79 Status IMDBNode::GetShardId(int32_t *shard_id) {
80   RETURN_UNEXPECTED_IF_NULL(shard_id);
81   *shard_id = sampler_->ShardId();
82   return Status::OK();
83 }
84 
85 // Get Dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)86 Status IMDBNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
87                                 int64_t *dataset_size) {
88   RETURN_UNEXPECTED_IF_NULL(dataset_size);
89   if (dataset_size_ > 0) {
90     *dataset_size = dataset_size_;
91     return Status::OK();
92   }
93   int64_t sample_size, num_rows;
94   RETURN_IF_NOT_OK(IMDBOp::CountRows(dataset_dir_, usage_, &num_rows));
95   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
96   RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
97   sample_size = sampler_rt->CalculateNumSamples(num_rows);
98   if (sample_size == -1) {
99     RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
100   }
101   *dataset_size = sample_size;
102   dataset_size_ = *dataset_size;
103   return Status::OK();
104 }
105 
to_json(nlohmann::json * out_json)106 Status IMDBNode::to_json(nlohmann::json *out_json) {
107   nlohmann::json args, sampler_args;
108   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
109   args["sampler"] = sampler_args;
110   args["num_parallel_workers"] = num_workers_;
111   args["connector_queue_size"] = connector_que_size_;
112   args["dataset_dir"] = dataset_dir_;
113   args["usage"] = usage_;
114   if (cache_ != nullptr) {
115     nlohmann::json cache_args;
116     RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
117     args["cache"] = cache_args;
118   }
119   *out_json = args;
120   return Status::OK();
121 }
122 
123 #ifndef ENABLE_ANDROID
from_json(nlohmann::json json_obj,std::shared_ptr<DatasetNode> * ds)124 Status IMDBNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
125   RETURN_UNEXPECTED_IF_NULL(ds);
126   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kIMDBNode));
127   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kIMDBNode));
128   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kIMDBNode));
129   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kIMDBNode));
130   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kIMDBNode));
131   std::string dataset_dir = json_obj["dataset_dir"];
132   std::string usage = json_obj["usage"];
133   std::shared_ptr<SamplerObj> sampler;
134   RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
135   std::shared_ptr<DatasetCache> cache = nullptr;
136   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
137   *ds = std::make_shared<IMDBNode>(dataset_dir, usage, sampler, cache);
138   (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
139   return Status::OK();
140 }
141 #endif
142 }  // namespace dataset
143 }  // namespace mindspore
144