1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MULTI30K_NODE_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MULTI30K_NODE_H_ 19 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" 25 26 namespace mindspore { 27 namespace dataset { 28 class Multi30kNode : public NonMappableSourceNode { 29 public: 30 /// \brief Constructor of Multi30kNode. 31 /// \param[in] dataset_dir Path to the root directory that contains the dataset. 32 /// \param[in] usage Part of dataset of MULTI30K, can be "train", "test", "valid" or "all". 33 /// \param[in] language_pair List containing text and translation language. 34 /// \param[in] num_samples The number of samples to be included in the dataset 35 /// \param[in] shuffle The mode for shuffling data every epoch. 36 /// Can be any of: 37 /// ShuffleMode::kFalse - No shuffling is performed. 38 /// ShuffleMode::kFiles - Shuffle files only. 39 /// ShuffleMode::kGlobal - Shuffle both the files and samples. 40 /// \param[in] num_shards Number of shards that the dataset should be divided into. 41 /// \param[in] shared_id The shard ID within num_shards. This argument should be 42 /// specified only when num_shards is also specified. 43 /// \param[in] cache Tensor cache to use. 44 Multi30kNode(const std::string &dataset_dir, const std::string &usage, const std::vector<std::string> &language_pair, 45 int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shared_id, 46 std::shared_ptr<DatasetCache> cache); 47 48 /// \brief Destructor of Multi30kNode. 49 ~Multi30kNode() = default; 50 51 /// \brief Node name getter. 52 /// \return Name of the current node. Name()53 std::string Name() const override { return kMulti30kNode; } 54 55 /// \brief Print the description. 56 /// \param[in] out The output stream to write output to. 57 void Print(std::ostream &out) const override; 58 59 /// \brief Copy the node to a new object. 60 /// \return A shared pointer to the new copy. 61 std::shared_ptr<DatasetNode> Copy() override; 62 63 /// \brief a base class override function to create the required runtime dataset op objects for this class. 64 /// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create. 65 /// \return Status Status::OK() if build successfully. 66 Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops); 67 68 /// \brief Parameters validation. 69 /// \return Status Status::OK() if all the parameters are valid. 70 Status ValidateParams() override; 71 72 /// \brief Get the shard id of node. 73 /// \param[in] shard_id The shard id. 74 /// \return Status Status::OK() if get shard id successfully. 75 Status GetShardId(int32_t *shard_id) override; 76 77 /// \brief Base-class override for GetDatasetSize. 78 /// \param[in] size_getter Shared pointer to DatasetSizeGetter. 79 /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting. 80 /// dataset size at the expense of accuracy. 81 /// \param[out] dataset_size the size of the dataset. 82 /// \return Status of the function. 83 Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, 84 int64_t *dataset_size) override; 85 86 /// \brief Get the arguments of node. 87 /// \param[out] out_json JSON string of all attributes. 88 /// \return Status of the function. 89 Status to_json(nlohmann::json *out_json) override; 90 91 /// \brief Multi30k by itself is a non-mappable dataset that does not support sampling. 92 /// However, if a cache operator is injected at some other place higher in the tree, that cache can 93 /// inherit this sampler from the leaf, providing sampling support from the caching layer. 94 /// That is why we setup the sampler for a leaf node that does not use sampling. 95 /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. 96 /// \param[in] sampler The sampler to setup. 97 /// \return Status of the function. 98 Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override; 99 100 /// \brief If a cache has been added into the ascendant tree over this Multi30k node, then the cache will be executing 101 /// a sampler for fetching the data. As such, any options in the Multi30k node need to be reset to its defaults 102 /// so that this Multi30k node will produce the full set of data into the cache. 103 /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. 104 /// \return Status of the function 105 Status MakeSimpleProducer() override; 106 107 /// \brief Getter functions NumSamples()108 int32_t NumSamples() const { return num_samples_; } NumShards()109 int32_t NumShards() const { return num_shards_; } ShardId()110 int32_t ShardId() const { return shard_id_; } Shuffle()111 ShuffleMode Shuffle() const { return shuffle_; } DatasetDir()112 const std::string &DatasetDir() const { return dataset_dir_; } Usage()113 const std::string &Usage() const { return usage_; } LanguagePair()114 const std::vector<std::string> &LanguagePair() const { return language_pair_; } 115 116 /// \brief Generate a list of read file names according to usage. 117 /// \param[in] usage Part of dataset of Multi30k. 118 /// \param[in] dataset_dir Path to the root directory that contains the dataset. 119 /// \return std::vector<std::string> A list of read file names. 120 std::vector<std::string> WalkAllFiles(const std::string &usage, const std::string &dataset_dir); 121 122 private: 123 std::string dataset_dir_; 124 std::string usage_; 125 std::vector<std::string> language_pair_; 126 int32_t num_samples_; 127 ShuffleMode shuffle_; 128 int32_t num_shards_; 129 int32_t shard_id_; 130 std::vector<std::string> multi30k_files_list_; 131 }; 132 } // namespace dataset 133 } // namespace mindspore 134 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MULTI30K_NODE_H_ 135