1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CSV_NODE_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CSV_NODE_H_ 19 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" 25 26 namespace mindspore { 27 namespace dataset { 28 /// \brief Record type for CSV 29 enum CsvType : uint8_t { INT = 0, FLOAT, STRING }; 30 31 /// \brief Base class of CSV Record 32 class CsvBase { 33 public: 34 CsvBase() = default; CsvBase(CsvType t)35 explicit CsvBase(CsvType t) : type(t) {} ~CsvBase()36 virtual ~CsvBase() {} 37 CsvType type; 38 }; 39 40 /// \brief CSV Record that can represent integer, float and string. 41 template <typename T> 42 class CsvRecord : public CsvBase { 43 public: 44 CsvRecord() = default; CsvRecord(CsvType t,T v)45 CsvRecord(CsvType t, T v) : CsvBase(t), value(v) {} ~CsvRecord()46 ~CsvRecord() {} 47 T value; 48 }; 49 50 class CSVNode : public NonMappableSourceNode { 51 public: 52 /// \brief Constructor 53 CSVNode(const std::vector<std::string> &dataset_files, char field_delim, 54 const std::vector<std::shared_ptr<CsvBase>> &column_defaults, const std::vector<std::string> &column_names, 55 int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, 56 std::shared_ptr<DatasetCache> cache); 57 58 /// \brief Destructor 59 ~CSVNode() = default; 60 61 /// \brief Node name getter 62 /// \return Name of the current node Name()63 std::string Name() const override { return kCSVNode; } 64 65 /// \brief Print the description 66 /// \param out - The output stream to write output to 67 void Print(std::ostream &out) const override; 68 69 /// \brief Copy the node to a new object 70 /// \return A shared pointer to the new copy 71 std::shared_ptr<DatasetNode> Copy() override; 72 73 /// \brief a base class override function to create the required runtime dataset op objects for this class 74 /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create 75 /// \return Status Status::OK() if build successfully 76 Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override; 77 78 /// \brief Parameters validation 79 /// \return Status Status::OK() if all the parameters are valid 80 Status ValidateParams() override; 81 82 /// \brief Get the shard id of node 83 /// \return Status Status::OK() if get shard id successfully 84 Status GetShardId(int32_t *shard_id) override; 85 86 /// \brief Base-class override for GetDatasetSize 87 /// \param[in] size_getter Shared pointer to DatasetSizeGetter 88 /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting 89 /// dataset size at the expense of accuracy. 90 /// \param[out] dataset_size the size of the dataset 91 /// \return Status of the function 92 Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, 93 int64_t *dataset_size) override; 94 95 /// \brief Getter functions DatasetFiles()96 const std::vector<std::string> &DatasetFiles() const { return dataset_files_; } FieldDelim()97 char FieldDelim() const { return field_delim_; } ColumnDefaults()98 const std::vector<std::shared_ptr<CsvBase>> &ColumnDefaults() const { return column_defaults_; } ColumnNames()99 const std::vector<std::string> &ColumnNames() const { return column_names_; } NumSamples()100 int64_t NumSamples() const { return num_samples_; } Shuffle()101 ShuffleMode Shuffle() const { return shuffle_; } NumShards()102 int32_t NumShards() const { return num_shards_; } ShardId()103 int32_t ShardId() const { return shard_id_; } 104 105 /// \brief Get the arguments of node 106 /// \param[out] out_json JSON string of all attributes 107 /// \return Status of the function 108 Status to_json(nlohmann::json *out_json) override; 109 110 /// \brief Function to read dataset in json 111 /// \param[in] json_obj The JSON object to be deserialized 112 /// \param[out] ds Deserialized dataset 113 /// \return Status The status code returned 114 static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds); 115 116 /// \brief CSV by itself is a non-mappable dataset that does not support sampling. 117 /// However, if a cache operator is injected at some other place higher in the tree, that cache can 118 /// inherit this sampler from the leaf, providing sampling support from the caching layer. 119 /// That is why we setup the sampler for a leaf node that does not use sampling. 120 /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. 121 /// \param[in] sampler The sampler to setup 122 /// \return Status of the function 123 Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override; 124 125 /// \brief If a cache has been added into the ascendant tree over this CSV node, then the cache will be executing 126 /// a sampler for fetching the data. As such, any options in the CSV node need to be reset to its defaults so 127 /// that this CSV node will produce the full set of data into the cache. 128 /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. 129 /// \return Status of the function 130 Status MakeSimpleProducer() override; 131 132 private: 133 std::vector<std::string> dataset_files_; 134 char field_delim_; 135 std::vector<std::shared_ptr<CsvBase>> column_defaults_; 136 std::vector<std::string> column_names_; 137 int64_t num_samples_; 138 ShuffleMode shuffle_; 139 int32_t num_shards_; 140 int32_t shard_id_; 141 }; 142 143 } // namespace dataset 144 } // namespace mindspore 145 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CSV_NODE_H_ 146