• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_SCHEMA_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_SCHEMA_H_
18 
19 #include <iostream>
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <vector>
25 #include <nlohmann/json.hpp>
26 #include "minddata/dataset/include/dataset/constants.h"
27 #include "minddata/dataset/core/data_type.h"
28 #include "minddata/dataset/core/tensor_shape.h"
29 #include "minddata/dataset/util/status.h"
30 
31 namespace mindspore {
32 namespace dataset {
33 /// \class ColDescriptor data_schema.h
34 /// \brief A simple class to provide meta info about a column.
35 class ColDescriptor {
36  public:
37   /// \brief Constructor 1: Simple constructor that leaves things uninitialized.
38   ColDescriptor();
39 
40   /// \brief Constructor 2: Main constructor
41   /// \param[in] col_name - The name of the column
42   /// \param[in] col_type - The DE Datatype of the column
43   /// \param[in] tensor_impl - The (initial) type of tensor implementation for the column
44   /// \param[in] rank - The number of dimension of the data
45   /// \param[in] in_shape - option argument for input shape
46   ColDescriptor(const std::string &col_name, DataType col_type, TensorImpl tensor_impl, int32_t rank,
47                 const TensorShape *in_shape = nullptr);
48 
49   /// \brief Explicit copy constructor is required
50   /// \param[in] in_cd - the source ColDescriptor
51   ColDescriptor(const ColDescriptor &in_cd);
52 
53   /// \brief Assignment overload
54   /// \param in_cd - the source ColDescriptor
55   ColDescriptor &operator=(const ColDescriptor &in_cd);
56 
57   /// \brief Destructor
58   ~ColDescriptor();
59 
60   /// \brief A print method typically used for debugging
61   /// \param out - The output stream to write output to
62   void Print(std::ostream &out) const;
63 
64   /// \brief Given a number of elements, this function will compute what the actual Tensor shape would be.
65   ///     If there is no starting TensorShape in this column, or if there is a shape but it contains
66   ///     an unknown dimension, then the output shape returned shall resolve dimensions as needed.
67   /// \param[in] num_elements - The number of elements in the data for a Tensor
68   /// \param[in/out] out_shape - The materialized output Tensor shape
69   /// \return Status The status code returned
70   Status MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const;
71 
72   /// \brief << Stream output operator overload
73   ///     This allows you to write the debug print info using stream operators
74   /// \param[in] out - reference to the output stream being overloaded
75   /// \param[in] cd - reference to the ColDescriptor to display
76   /// \return - the output stream must be returned
77   friend std::ostream &operator<<(std::ostream &out, const ColDescriptor &cd) {
78     cd.Print(out);
79     return out;
80   }
81 
82   /// \brief getter function
83   /// \return The column's DataType
Type()84   DataType Type() const { return type_; }
85 
86   /// \brief getter function
87   /// \return The column's rank
Rank()88   int32_t Rank() const { return rank_; }
89 
90   /// \brief getter function
91   /// \return The column's name
Name()92   std::string Name() const { return col_name_; }
93 
94   /// \brief getter function
95   /// \return The column's shape
96   TensorShape Shape() const;
97 
98   /// \brief Check if the column has a shape.
99   /// \return Whether the column has a shape.
HasShape()100   bool HasShape() const { return tensor_shape_ != nullptr; }
101 
102   /// \brief Check if the column has a known shape.
103   /// \return Whether the column has a known shape.
HasKnownShape()104   bool HasKnownShape() const { return HasShape() && Shape().known(); }
105 
106   /// \brief getter function
107   /// \return The column's tensor implementation type
GetTensorImpl()108   TensorImpl GetTensorImpl() const { return tensor_impl_; }
109 
110  private:
111   DataType type_;                              // The columns type
112   int32_t rank_;                               // The rank for this column (number of dimensions)
113   TensorImpl tensor_impl_;                     // The initial flavour of the tensor for this column
114   std::unique_ptr<TensorShape> tensor_shape_;  // The fixed shape (if given by user)
115   std::string col_name_;                       // The name of the column
116 };
117 
118 /// \class DataSchema data_schema.h
119 /// \brief A list of the columns.
120 class DataSchema {
121  public:
122   /// \brief Constructor
123   DataSchema();
124 
125   /// \brief Destructor
126   ~DataSchema();
127 
128   /// \brief Parses a schema json file and populates the columns and meta info.
129   /// \param[in] schema_file_path - the schema file that has the column's info to load
130   /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns.
131   /// \return Status The status code returned
132   Status LoadSchemaFile(const std::string &schema_file_path, const std::vector<std::string> &columns_to_load);
133 
134   /// \brief Parses a schema JSON string and populates the columns and meta info.
135   /// \param[in] schema_json_string - the schema file that has the column's info to load
136   /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns.
137   /// \return Status The status code returned
138   Status LoadSchemaString(const std::string &schema_json_string, const std::vector<std::string> &columns_to_load);
139 
140   /// \brief A print method typically used for debugging
141   /// \param[in] out - The output stream to write output to
142   void Print(std::ostream &out) const;
143 
144   /// \brief << Stream output operator overload. This allows you to write the debug print info using stream operators
145   /// \param[in] out - reference to the output stream being overloaded
146   /// \param[in] ds - reference to the DataSchema to display
147   /// \return - the output stream must be returned
148   friend std::ostream &operator<<(std::ostream &out, const DataSchema &ds) {
149     ds.Print(out);
150     return out;
151   }
152 
153   /// \brief Adds a column descriptor to the schema
154   /// \param[in] cd - The ColDescriptor to add
155   /// \return Status The status code returned
156   Status AddColumn(const ColDescriptor &cd);
157 
158   /// \brief getter
159   /// \return The reference to a ColDescriptor to get (const version)
160   const ColDescriptor &Column(int32_t idx) const;
161 
162   /// \brief getter
163   /// \return The number of columns in the schema
NumColumns()164   size_t NumColumns() const { return col_descs_.size(); }
165 
Empty()166   bool Empty() const { return NumColumns() == 0; }
167 
168   /// \brief getter
169   /// \return The number of rows read from schema
NumRows()170   int64_t NumRows() const { return num_rows_; }
171 
172   static const char DEFAULT_DATA_SCHEMA_FILENAME[];
173 
174   /// \brief Loops through all columns in the schema and returns a map with the column name to column index number.
175   /// \param[in/out] out_column_name_map - The output map of columns names to column index
176   /// \return Status The status code returned
177   Status GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map);
178 
179   /// \brief Get the column name list of the schema.
180   /// \param[out] column_names The column names in the schema.
181   /// \return The status code.
182   Status GetColumnName(std::vector<std::string> *column_names) const;
183 
184  private:
185   /// \brief Internal helper function. Parses the json schema file in any order and produces a schema that
186   ///     does not follow any particular order (json standard does not enforce any ordering protocol).
187   ///     This one produces a schema that contains all of the columns from the schema file.
188   /// \param[in] column_tree - The nlohmann tree from the json file to parse
189   /// \return Status The status code returned
190   Status AnyOrderLoad(nlohmann::json column_tree);
191 
192   /// \brief Internal helper function. For each input column name, perform a lookup to the json document to
193   ///     find the matching column.  When the match is found, process that column to build the column
194   ///     descriptor and add to the schema in the order in which the input column names are given.
195   /// \param[in] column_tree - The nlohmann tree from the json file to parse
196   /// \param[in] columns_to_load - list of strings for the columns to add to the schema
197   /// \return Status The status code returned
198   Status ColumnOrderLoad(nlohmann::json column_tree, const std::vector<std::string> &columns_to_load);
199 
200   /// \brief Internal helper function. Given the json tree for a given column, load it into our schema.
201   /// \param[in] columnTree - The nlohmann child tree for a given column to load.
202   /// \param[in] col_name - The string name of the column for that subtree.
203   /// \return Status The status code returned
204   Status ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name);
205 
206   /// \brief Internal helper function. Performs sanity checks on the json file setup.
207   /// \param[in] js - The nlohmann tree for the schema file
208   /// \return Status The status code returned
209   Status PreLoadExceptionCheck(const nlohmann::json &js);
210 
211   std::vector<ColDescriptor> col_descs_;  // Vector of column descriptors
212   int64_t num_rows_;
213 };
214 }  // namespace dataset
215 }  // namespace mindspore
216 
217 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_SCHEMA_H_
218