1 /** 2 * Copyright 2020-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ 18 19 #include <memory> 20 #include <string> 21 #include <random> 22 #include <unordered_map> 23 #include <utility> 24 #include <vector> 25 #include "minddata/dataset/engine/dataset_iterator.h" 26 #include "minddata/dataset/engine/datasetops/pipeline_op.h" 27 #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" 28 #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" 29 30 namespace mindspore { 31 namespace dataset { 32 class ConcatOp : public PipelineOp { 33 public: 34 // Constructor of the ConcatOp. 35 // @note The builder class should be used to call it 36 // @param op_connector_size - connector size 37 ConcatOp(); 38 ConcatOp(const std::shared_ptr<SamplerRT> &sampler, const std::vector<std::pair<int, int>> &children_flag_and_nums, 39 const std::vector<std::pair<int, int>> &children_start_end_index, 40 const std::vector<int64_t> &children_sizes); 41 42 // Destructor 43 ~ConcatOp() = default; 44 45 // A print method typically used for debugging 46 // @param out - The output stream to write output to 47 // @param show_all - A bool to control if you want to show all info or just a summary 48 void Print(std::ostream &out, bool show_all) const override; 49 50 // << Stream output operator overload 51 // @notes This allows you to write the debug print info using stream operators 52 // @param out - reference to the output stream being overloaded 53 // @param ro - reference to the ConcatOp to display 54 // @return - the output stream must be returned 55 friend std::ostream &operator<<(std::ostream &out, const ConcatOp &ro) { 56 ro.Print(out, false); 57 return out; 58 } 59 60 // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will 61 // provide the master loop that drives the logic for performing the work 62 // @return Status The status code returned 63 Status operator()() override; 64 65 // Op name getter 66 // @return Name of the current Op Name()67 std::string Name() const override { return kConcatOp; } 68 69 // Private function for computing the assignment of the column name map. 70 // @return - Status 71 Status ComputeColMap() override; 72 73 /// \brief Gets the number of classes 74 /// \param[out] num_classes the number of classes 75 /// \return Status - The status code return 76 Status GetNumClasses(int64_t *num_classes) override; 77 78 Status GetNextRow(TensorRow *row) override; 79 80 Status GetNextRowPullMode(TensorRow *const row) override; 81 82 Status SampleInSequence(TensorRow *row, bool is_pipeline_mode = true); 83 84 Status SampleInGlobal(TensorRow *row, bool is_pipeline_mode = true); 85 86 /// Check if the current sample will be taken or dropped 87 /// \return bool 88 bool IgnoreSample(); 89 90 protected: 91 /// \brief Gets the implementation status for operator in pull mode 92 /// \return implementation status PullModeImplementationStatus()93 ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; } 94 95 private: 96 Status Verify(int32_t id, const TensorRow &new_row); 97 98 std::unordered_map<std::string, int32_t> column_name_id_; // Mapping between col index and col name 99 std::vector<DataType> data_type_; 100 std::vector<dsize_t> data_rank_; 101 std::vector<std::pair<int, int>> children_flag_and_nums_; 102 std::vector<std::pair<int, int>> children_start_end_index_; 103 std::vector<int64_t> children_sizes_; 104 std::vector<int64_t> children_sizes_ori_; 105 std::vector<bool> children_exhausted_; 106 107 size_t cur_child_; 108 bool verified_; 109 int64_t sample_number_; 110 111 int32_t num_shard_; 112 int32_t shard_index_; 113 114 std::unique_ptr<std::discrete_distribution<>> discrete_random_; 115 bool global_shuffle_; 116 uint32_t seed_; 117 std::mt19937 rnd_; 118 }; 119 } // namespace dataset 120 } // namespace mindspore 121 122 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ 123