• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_
18 #pragma once
19 
#include <atomic>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <ostream>
#include <queue>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/status.h"
#include "minddata/mindrecord/include/shard_column.h"
#include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_reader.h"
#include "minddata/mindrecord/include/common/shard_utils.h"
#include "minddata/dataset/util/wait_post.h"
40 
41 namespace mindspore {
42 namespace dataset {
43 // Forward declares
44 template <typename T>
45 class Queue;
46 
47 using mindrecord::ShardOperator;
48 using mindrecord::ShardReader;
49 using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindrecord::json>>;  /// Row of data from ShardReader
50 
51 const int32_t LOG_INTERVAL = 19;
52 
53 class MindRecordOp : public MappableLeafOp {
54  public:
55   // Constructor of the MindRecordOp.
56   // @note The builder class should be used to call it
57   // @param num_mind_record_workers - The number of workers for the op (run by ShardReader)
58   // @param dataset_file - dataset files
59   // @param op_connector_queue_size - The output connector queue size
60   // @param columns_to_load - The list of columns to use (column name)
61   // @param operators - ShardOperators for Shuffle, Category, Sample
62   // @param sampler - sampler tells MindRecordOp what to read
63   MindRecordOp(int32_t num_mind_record_workers, std::vector<std::string> dataset_file, bool load_dataset,
64                int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
65                const std::vector<std::shared_ptr<ShardOperator>> &operators, int64_t num_padded_,
66                const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes_,
67                const ShuffleMode shuffle_mode_, std::unique_ptr<ShardReader> shard_reader,
68                std::shared_ptr<SamplerRT> sampler);
69 
70   /// Destructor
71   ~MindRecordOp() override;
72 
73   /// A print method typically used for debugging
74   /// @param out - The output stream to write output to
75   /// @param show_all - A bool to control if you want to show all info or just a summary
76   void Print(std::ostream &out, bool show_all) const override;
77 
78   /// << Stream output operator overload
79   /// @notes This allows you to write the debug print info using stream operators
80   /// @param out - reference to the output stream being overloaded
81   /// @param op - reference to the MindRecordOp to display
82   /// @return - the output stream must be returned
83   friend std::ostream &operator<<(std::ostream &out, const MindRecordOp &op) {
84     op.Print(out, false);
85     return out;
86   }
87 
88   // Worker thread pulls a number of IOBlock from IOBlock Queue, make a TensorRow and push it to Connector
89   // @param int32_t workerId - id of each worker
90   // @return Status The status code returned
91   Status WorkerEntry(int32_t worker_id) override;
92 
93   // Called first when function is called
94   // @return
95   Status RegisterAndLaunchThreads() override;
96 
97   /// Overrides base class reset method.  When an operator does a reset, it cleans up any state
98   /// info from it's previous execution and then initializes itself so that it can be executed
99   /// again.
100   /// @return Status The status code returned
101   Status Reset() override;
102 
103   static Status CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset,
104                                const std::shared_ptr<ShardOperator> &op, int64_t *count, int64_t num_padded);
105 
106   // Getter method
dataset_file()107   std::vector<std::string> dataset_file() const { return dataset_file_; }
108 
109   /// Getter method
columns_to_load()110   std::vector<std::string> columns_to_load() const { return columns_to_load_; }
111 
load_dataset()112   bool load_dataset() const { return load_dataset_; }
113 
114   Status Init();
115 
116   /// Op name getter
117   /// @return Name of the current Op
Name()118   std::string Name() const override { return "MindRecordOp"; }
119 
120  private:
121   Status GetRowFromReader(TensorRow *fetched_row, uint64_t row_id, int32_t worker_id);
122 
123   /// Parses a single cell and puts the data into a tensor
124   /// @param tensor_row - the tensor row to put the parsed data in
125   /// @param columns_blob - the blob data received from the reader
126   /// @param columns_json - the data for fields received from the reader
127   Status LoadTensorRow(TensorRow *tensor_row, const std::vector<uint8_t> &columns_blob,
128                        const mindrecord::json &columns_json, const mindrecord::TaskType task_type);
129 
LoadTensorRow(row_id_type row_id,TensorRow * row)130   Status LoadTensorRow(row_id_type row_id, TensorRow *row) override {
131     return Status(StatusCode::kMDSyntaxError, "[Internal ERROR] Cannot call this method.");
132   }
133   // Private function for computing the assignment of the column name map.
134   // @return - Status
135   Status ComputeColMap() override;
136 
137  protected:
138   Status PrepareData() override;
139 
140   /// Add a new worker to the MindRecordOp. The function will have to wait for all workers to process current rows.
141   /// It will then update the shard reader. Finally, it adds a new thread to the list.
142   /// \note The caller of this function has to be the main thread of the Op, since it's the only entity responsible to
143   /// push rows to workers_in_queue
144   /// \return Status The status code returned
145   Status AddNewWorkers(int32_t num_new_workers = 1) override;
146 
147   /// Remove a worker from MindRecordOp. The function will have to wait for all workers to process current rows.
148   /// It will then update the shard reader. Finally, it removes a thread from the list.
149   /// \note The caller of this function has to be the main thread of the Op, since it's the only entity responsible to
150   /// push rows to workers_in_queue
151   /// \return Status The status code returned
152   Status RemoveWorkers(int32_t num_workers = 1) override;
153 
154   /// Initialize pull mode, calls PrepareData() within
155   /// @return Status The status code returned
156   Status InitPullMode() override;
157 
158   /// Load a tensor row at location row_id for pull mode
159   /// \param row_id_type row_id - id for this tensor row
160   /// \param TensorRow row - loaded row
161   /// \return Status The status code returned
162   Status LoadTensorRowPullMode(row_id_type row_id, TensorRow *row) override;
163 
164  private:
165   std::vector<std::string> dataset_file_;                  // dataset files
166   bool load_dataset_;                                      // load dataset from single file or not
167   std::vector<std::string> columns_to_load_;               // Columns to load from dataset
168   std::vector<std::shared_ptr<ShardOperator>> operators_;  // ShardOperators to use
169   int32_t num_mind_record_workers_;                        // number of workers to be spawned by ShardReader
170   std::atomic<int32_t> ended_worker_;
171 
172   int64_t num_padded_;
173   mindrecord::json sample_json_;
174   std::map<std::string, std::string> sample_bytes_;
175 
176   std::unique_ptr<DataSchema> data_schema_;  // Data schema for column typing
177 
178   std::unique_ptr<ShardReader> shard_reader_;
179 
180   std::mutex ended_worker_mutex_;
181 
182   ShuffleMode shuffle_mode_;
183 };
184 }  // namespace dataset
185 }  // namespace mindspore
186 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_
187