1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CSV_NODE_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CSV_NODE_H_ 19 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" 25 26 namespace mindspore { 27 namespace dataset { 28 /// \brief Record type for CSV 29 enum CsvType : uint8_t { INT = 0, FLOAT, STRING }; 30 31 /// \brief Base class of CSV Record 32 class CsvBase { 33 public: 34 CsvBase() = default; 35 CsvBase(CsvType t)36 explicit CsvBase(CsvType t) : type(t) {} 37 38 virtual ~CsvBase() = default; 39 40 CsvType type{INT}; 41 }; 42 43 /// \brief CSV Record that can represent integer, float and string. 44 template <typename T> 45 class CsvRecord : public CsvBase { 46 public: 47 CsvRecord() = default; 48 CsvRecord(CsvType t,T v)49 CsvRecord(CsvType t, T v) : CsvBase(t), value(v) {} 50 51 ~CsvRecord() override = default; 52 53 T value; 54 }; 55 56 class CSVNode : public NonMappableSourceNode { 57 public: 58 /// \brief Constructor 59 CSVNode(const std::vector<std::string> &dataset_files, char field_delim, 60 const std::vector<std::shared_ptr<CsvBase>> &column_defaults, const std::vector<std::string> &column_names, 61 int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, 62 std::shared_ptr<DatasetCache> cache); 63 64 /// \brief Destructor 65 ~CSVNode() override = default; 66 67 /// \brief Node name getter 68 /// \return Name of the current node Name()69 std::string Name() const override { return kCSVNode; } 70 71 /// \brief Print the description 72 /// \param out - The output stream to write output to 73 void Print(std::ostream &out) const override; 74 75 /// \brief Copy the node to a new object 76 /// \return A shared pointer to the new copy 77 std::shared_ptr<DatasetNode> Copy() override; 78 79 /// \brief a base class override function to create the required runtime dataset op objects for this class 80 /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create 81 /// \return Status Status::OK() if build successfully 82 Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override; 83 84 /// \brief Parameters validation 85 /// \return Status Status::OK() if all the parameters are valid 86 Status ValidateParams() override; 87 88 /// \brief Get the shard id of node 89 /// \return Status Status::OK() if get shard id successfully 90 Status GetShardId(int32_t *shard_id) override; 91 92 /// \brief Base-class override for GetDatasetSize 93 /// \param[in] size_getter Shared pointer to DatasetSizeGetter 94 /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting 95 /// dataset size at the expense of accuracy. 96 /// \param[out] dataset_size the size of the dataset 97 /// \return Status of the function 98 Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, 99 int64_t *dataset_size) override; 100 101 /// \brief Getter functions DatasetFiles()102 const std::vector<std::string> &DatasetFiles() const { return dataset_files_; } FieldDelim()103 char FieldDelim() const { return field_delim_; } ColumnDefaults()104 const std::vector<std::shared_ptr<CsvBase>> &ColumnDefaults() const { return column_defaults_; } ColumnNames()105 const std::vector<std::string> &ColumnNames() const { return column_names_; } NumSamples()106 int64_t NumSamples() const { return num_samples_; } Shuffle()107 ShuffleMode Shuffle() const { return shuffle_; } NumShards()108 int32_t NumShards() const { return num_shards_; } ShardId()109 int32_t ShardId() const { return shard_id_; } 110 111 /// \brief Get the arguments of node 112 /// \param[out] out_json JSON string of all attributes 113 /// \return Status of the function 114 Status to_json(nlohmann::json *out_json) override; 115 116 /// \brief Function to read dataset in json 117 /// \param[in] json_obj The JSON object to be deserialized 118 /// \param[out] ds Deserialized dataset 119 /// \return Status The status code returned 120 static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); 121 122 /// \brief CSV by itself is a non-mappable dataset that does not support sampling. 123 /// However, if a cache operator is injected at some other place higher in the tree, that cache can 124 /// inherit this sampler from the leaf, providing sampling support from the caching layer. 125 /// That is why we setup the sampler for a leaf node that does not use sampling. 126 /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. 127 /// \param[in] sampler The sampler to setup 128 /// \return Status of the function 129 Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override; 130 131 /// \brief If a cache has been added into the ascendant tree over this CSV node, then the cache will be executing 132 /// a sampler for fetching the data. As such, any options in the CSV node need to be reset to its defaults so 133 /// that this CSV node will produce the full set of data into the cache. 134 /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. 135 /// \return Status of the function 136 Status MakeSimpleProducer() override; 137 138 private: 139 std::vector<std::string> dataset_files_; 140 char field_delim_; 141 std::vector<std::shared_ptr<CsvBase>> column_defaults_; 142 std::vector<std::string> column_names_; 143 int64_t num_samples_; 144 ShuffleMode shuffle_; 145 int32_t num_shards_; 146 int32_t shard_id_; 147 }; 148 } // namespace dataset 149 } // namespace mindspore 150 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CSV_NODE_H_ 151