1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TF_RECORD_NODE_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TF_RECORD_NODE_H_ 19 20 #include <memory> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 25 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" 26 27 namespace mindspore { 28 namespace dataset { 29 30 /// \class TFRecordNode 31 /// \brief A Dataset derived class to represent TFRecord dataset 32 class TFRecordNode : public NonMappableSourceNode { 33 friend class CacheValidationPass; 34 35 public: 36 /// \brief Constructor 37 /// \note Parameter 'schema' is the path to the schema file TFRecordNode(const std::vector<std::string> & dataset_files,std::string schema,const std::vector<std::string> & columns_list,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,bool shard_equal_rows,std::shared_ptr<DatasetCache> cache)38 TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema, 39 const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, 40 int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) 41 : NonMappableSourceNode(std::move(cache)), 42 dataset_files_(dataset_files), 43 schema_path_(schema), 44 columns_list_(columns_list), 45 num_samples_(num_samples), 46 shuffle_(shuffle), 47 num_shards_(num_shards), 48 shard_id_(shard_id), 49 shard_equal_rows_(shard_equal_rows) { 50 // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User 51 // discretion is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the 52 // num_shards_ isn't 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return 53 // num_shards. Once PreBuildSampler is phased out, this can be cleaned up. 54 GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); 55 } 56 57 /// \brief Constructor 58 /// \note Parameter 'schema' is shared pointer to Schema object TFRecordNode(const std::vector<std::string> & dataset_files,std::shared_ptr<SchemaObj> schema,const std::vector<std::string> & columns_list,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,bool shard_equal_rows,std::shared_ptr<DatasetCache> cache)59 TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema, 60 const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, 61 int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) 62 : NonMappableSourceNode(std::move(cache)), 63 dataset_files_(dataset_files), 64 schema_obj_(schema), 65 columns_list_(columns_list), 66 num_samples_(num_samples), 67 shuffle_(shuffle), 68 num_shards_(num_shards), 69 shard_id_(shard_id), 70 shard_equal_rows_(shard_equal_rows) {} 71 72 /// \brief Destructor 73 ~TFRecordNode() = default; 74 75 /// \brief Node name getter 76 /// \return Name of the current node Name()77 std::string Name() const override { return kTFRecordNode; } 78 79 /// \brief Print the description 80 /// \param out - The output stream to write output to 81 void Print(std::ostream &out) const override; 82 83 /// \brief Copy the node to a new object 84 /// \return A shared pointer to the new copy 85 std::shared_ptr<DatasetNode> Copy() override; 86 87 /// \brief a base class override function to create the required runtime dataset op objects for this class 88 /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create 89 /// \return Status Status::OK() if build successfully 90 Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override; 91 92 /// \brief Parameters validation 93 /// \return Status Status::OK() if all the parameters are valid 94 Status ValidateParams() override; 95 96 /// \brief Get the shard id of node 97 /// \return Status Status::OK() if get shard id successfully 98 Status GetShardId(int32_t *const shard_id) override; 99 100 /// \brief Base-class override for GetDatasetSize 101 /// \param[in] size_getter Shared pointer to DatasetSizeGetter 102 /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting 103 /// dataset size at the expense of accuracy. 104 /// \param[out] dataset_size the size of the dataset 105 /// \return Status of the function 106 Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, 107 int64_t *dataset_size) override; 108 109 /// \brief Get the file list of the specific shard ID 110 /// \param[out] shard_filenames the list of filenames for that specific shard ID 111 /// \return Status of the function 112 Status GetShardFileList(std::vector<std::string> *shard_filenames); 113 114 /// \brief Getter functions DatasetFiles()115 const std::vector<std::string> &DatasetFiles() const { return dataset_files_; } SchemaPath()116 const std::string &SchemaPath() const { return schema_path_; } GetSchemaObj()117 const std::shared_ptr<SchemaObj> &GetSchemaObj() const { return schema_obj_; } ColumnsList()118 const std::vector<std::string> &ColumnsList() const { return columns_list_; } NumSamples()119 int64_t NumSamples() const { return num_samples_; } Shuffle()120 ShuffleMode Shuffle() const { return shuffle_; } NumShards()121 int32_t NumShards() const { return num_shards_; } ShardEqualRows()122 bool ShardEqualRows() const { return shard_equal_rows_; } 123 124 /// \brief Get the arguments of node 125 /// \param[out] out_json JSON string of all attributes 126 /// \return Status of the function 127 Status to_json(nlohmann::json *out_json) override; 128 129 /// \brief Function to read dataset in json 130 /// \param[in] json_obj The JSON object to be deserialized 131 /// \param[out] ds Deserialized dataset 132 /// \return Status The status code returned 133 static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); 134 135 /// \brief TFRecord by itself is a non-mappable dataset that does not support sampling. 136 /// However, if a cache operator is injected at some other place higher in the tree, that cache can 137 /// inherit this sampler from the leaf, providing sampling support from the caching layer. 138 /// That is why we setup the sampler for a leaf node that does not use sampling. 139 /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. 140 /// \param[in] sampler The sampler to setup 141 /// \return Status of the function 142 Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override; 143 144 /// \brief If a cache has been added into the ascendant tree over this TFRecord node, then the cache will be executing 145 /// a sampler for fetching the data. As such, any options in the TFRecord node need to be reset to its defaults 146 /// so that this TFRecord node will produce the full set of data into the cache. 147 /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. 148 /// \return Status of the function 149 Status MakeSimpleProducer() override; 150 151 /// \brief Base-class override for accepting IRNodePass visitor 152 /// \param[in] p The node to visit 153 /// \param[out] modified Indicator if the node was modified 154 /// \return Status of the node visit 155 Status Accept(IRNodePass *p, bool *const modified) override; 156 157 /// \brief Base-class override for accepting IRNodePass visitor 158 /// \param[in] p The node to visit 159 /// \param[out] modified Indicator if the node was modified 160 /// \return Status of the node visit 161 Status AcceptAfter(IRNodePass *const p, bool *const modified) override; 162 163 private: 164 std::vector<std::string> dataset_files_; 165 std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string 166 std::shared_ptr<SchemaObj> schema_obj_; // schema_obj_ schema object. 167 std::vector<std::string> columns_list_; 168 int64_t num_samples_; 169 ShuffleMode shuffle_; 170 int32_t num_shards_; 171 int32_t shard_id_; 172 bool shard_equal_rows_; 173 }; 174 175 } // namespace dataset 176 } // namespace mindspore 177 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TF_RECORD_NODE_H_ 178