1 /** 2 * Copyright 2020-2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_SCHEMA_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_SCHEMA_H_ 18 19 #include <iostream> 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <unordered_map> 24 #include <vector> 25 #include <nlohmann/json.hpp> 26 #include "minddata/dataset/include/dataset/constants.h" 27 #include "minddata/dataset/core/data_type.h" 28 #include "minddata/dataset/core/tensor_shape.h" 29 #include "minddata/dataset/util/status.h" 30 31 namespace mindspore { 32 namespace dataset { 33 /// \class ColDescriptor data_schema.h 34 /// \brief A simple class to provide meta info about a column. 35 class ColDescriptor { 36 public: 37 /// \brief Constructor 1: Simple constructor that leaves things uninitialized. 38 ColDescriptor(); 39 40 /// \brief Constructor 2: Main constructor 41 /// \param[in] col_name - The name of the column 42 /// \param[in] col_type - The DE Datatype of the column 43 /// \param[in] tensor_impl - The (initial) type of tensor implementation for the column 44 /// \param[in] rank - The number of dimension of the data 45 /// \param[in] in_shape - option argument for input shape 46 ColDescriptor(const std::string &col_name, DataType col_type, TensorImpl tensor_impl, int32_t rank, 47 const TensorShape *in_shape = nullptr); 48 49 /// \brief Explicit copy constructor is required 50 /// \param[in] in_cd - the source ColDescriptor 51 ColDescriptor(const ColDescriptor &in_cd); 52 53 /// \brief Assignment overload 54 /// \param in_cd - the source ColDescriptor 55 ColDescriptor &operator=(const ColDescriptor &in_cd); 56 57 /// \brief Destructor 58 ~ColDescriptor(); 59 60 /// \brief A print method typically used for debugging 61 /// \param out - The output stream to write output to 62 void Print(std::ostream &out) const; 63 64 /// \brief Given a number of elements, this function will compute what the actual Tensor shape would be. 65 /// If there is no starting TensorShape in this column, or if there is a shape but it contains 66 /// an unknown dimension, then the output shape returned shall resolve dimensions as needed. 67 /// \param[in] num_elements - The number of elements in the data for a Tensor 68 /// \param[in/out] out_shape - The materialized output Tensor shape 69 /// \return Status The status code returned 70 Status MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const; 71 72 /// \brief << Stream output operator overload 73 /// This allows you to write the debug print info using stream operators 74 /// \param[in] out - reference to the output stream being overloaded 75 /// \param[in] cd - reference to the ColDescriptor to display 76 /// \return - the output stream must be returned 77 friend std::ostream &operator<<(std::ostream &out, const ColDescriptor &cd) { 78 cd.Print(out); 79 return out; 80 } 81 82 /// \brief getter function 83 /// \return The column's DataType Type()84 DataType Type() const { return type_; } 85 86 /// \brief getter function 87 /// \return The column's rank Rank()88 int32_t Rank() const { return rank_; } 89 90 /// \brief getter function 91 /// \return The column's name Name()92 std::string Name() const { return col_name_; } 93 94 /// \brief getter function 95 /// \return The column's shape 96 TensorShape Shape() const; 97 98 /// \brief Check if the column has a shape. 99 /// \return Whether the column has a shape. HasShape()100 bool HasShape() const { return tensor_shape_ != nullptr; } 101 102 /// \brief Check if the column has a known shape. 103 /// \return Whether the column has a known shape. HasKnownShape()104 bool HasKnownShape() const { return HasShape() && Shape().known(); } 105 106 /// \brief getter function 107 /// \return The column's tensor implementation type GetTensorImpl()108 TensorImpl GetTensorImpl() const { return tensor_impl_; } 109 110 private: 111 DataType type_; // The columns type 112 int32_t rank_; // The rank for this column (number of dimensions) 113 TensorImpl tensor_impl_; // The initial flavour of the tensor for this column 114 std::unique_ptr<TensorShape> tensor_shape_; // The fixed shape (if given by user) 115 std::string col_name_; // The name of the column 116 }; 117 118 /// \class DataSchema data_schema.h 119 /// \brief A list of the columns. 120 class DataSchema { 121 public: 122 /// \brief Constructor 123 DataSchema(); 124 125 /// \brief Destructor 126 ~DataSchema(); 127 128 /// \brief Parses a schema json file and populates the columns and meta info. 129 /// \param[in] schema_file_path - the schema file that has the column's info to load 130 /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. 131 /// \return Status The status code returned 132 Status LoadSchemaFile(const std::string &schema_file_path, const std::vector<std::string> &columns_to_load); 133 134 /// \brief Parses a schema JSON string and populates the columns and meta info. 135 /// \param[in] schema_json_string - the schema file that has the column's info to load 136 /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. 137 /// \return Status The status code returned 138 Status LoadSchemaString(const std::string &schema_json_string, const std::vector<std::string> &columns_to_load); 139 140 /// \brief A print method typically used for debugging 141 /// \param[in] out - The output stream to write output to 142 void Print(std::ostream &out) const; 143 144 /// \brief << Stream output operator overload. This allows you to write the debug print info using stream operators 145 /// \param[in] out - reference to the output stream being overloaded 146 /// \param[in] ds - reference to the DataSchema to display 147 /// \return - the output stream must be returned 148 friend std::ostream &operator<<(std::ostream &out, const DataSchema &ds) { 149 ds.Print(out); 150 return out; 151 } 152 153 /// \brief Adds a column descriptor to the schema 154 /// \param[in] cd - The ColDescriptor to add 155 /// \return Status The status code returned 156 Status AddColumn(const ColDescriptor &cd); 157 158 /// \brief getter 159 /// \return The reference to a ColDescriptor to get (const version) 160 const ColDescriptor &Column(int32_t idx) const; 161 162 /// \brief getter 163 /// \return The number of columns in the schema NumColumns()164 size_t NumColumns() const { return col_descs_.size(); } 165 Empty()166 bool Empty() const { return NumColumns() == 0; } 167 168 /// \brief getter 169 /// \return The number of rows read from schema NumRows()170 int64_t NumRows() const { return num_rows_; } 171 172 static const char DEFAULT_DATA_SCHEMA_FILENAME[]; 173 174 /// \brief Loops through all columns in the schema and returns a map with the column name to column index number. 175 /// \param[in/out] out_column_name_map - The output map of columns names to column index 176 /// \return Status The status code returned 177 Status GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map); 178 179 /// \brief Get the column name list of the schema. 180 /// \param[out] column_names The column names in the schema. 181 /// \return The status code. 182 Status GetColumnName(std::vector<std::string> *column_names) const; 183 184 private: 185 /// \brief Internal helper function. Parses the json schema file in any order and produces a schema that 186 /// does not follow any particular order (json standard does not enforce any ordering protocol). 187 /// This one produces a schema that contains all of the columns from the schema file. 188 /// \param[in] column_tree - The nlohmann tree from the json file to parse 189 /// \return Status The status code returned 190 Status AnyOrderLoad(nlohmann::json column_tree); 191 192 /// \brief Internal helper function. For each input column name, perform a lookup to the json document to 193 /// find the matching column. When the match is found, process that column to build the column 194 /// descriptor and add to the schema in the order in which the input column names are given. 195 /// \param[in] column_tree - The nlohmann tree from the json file to parse 196 /// \param[in] columns_to_load - list of strings for the columns to add to the schema 197 /// \return Status The status code returned 198 Status ColumnOrderLoad(nlohmann::json column_tree, const std::vector<std::string> &columns_to_load); 199 200 /// \brief Internal helper function. Given the json tree for a given column, load it into our schema. 201 /// \param[in] columnTree - The nlohmann child tree for a given column to load. 202 /// \param[in] col_name - The string name of the column for that subtree. 203 /// \return Status The status code returned 204 Status ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name); 205 206 /// \brief Internal helper function. Performs sanity checks on the json file setup. 207 /// \param[in] js - The nlohmann tree for the schema file 208 /// \return Status The status code returned 209 Status PreLoadExceptionCheck(const nlohmann::json &js); 210 211 std::vector<ColDescriptor> col_descs_; // Vector of column descriptors 212 int64_t num_rows_; 213 }; 214 } // namespace dataset 215 } // namespace mindspore 216 217 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_SCHEMA_H_ 218