• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_
18 
19 #include <memory>
20 #include <string>
21 #include <random>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 #include "minddata/dataset/engine/dataset_iterator.h"
26 #include "minddata/dataset/engine/datasetops/pipeline_op.h"
27 #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
28 #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
29 
30 namespace mindspore {
31 namespace dataset {
32 class ConcatOp : public PipelineOp {
33  public:
34   // Constructor of the ConcatOp.
35   // @note The builder class should be used to call it
36   // @param op_connector_size - connector size
37   ConcatOp();
38   ConcatOp(const std::shared_ptr<SamplerRT> &sampler, const std::vector<std::pair<int, int>> &children_flag_and_nums,
39            const std::vector<std::pair<int, int>> &children_start_end_index,
40            const std::vector<int64_t> &children_sizes);
41 
42   // Destructor
43   ~ConcatOp() = default;
44 
45   // A print method typically used for debugging
46   // @param out - The output stream to write output to
47   // @param show_all - A bool to control if you want to show all info or just a summary
48   void Print(std::ostream &out, bool show_all) const override;
49 
50   // << Stream output operator overload
51   // @notes This allows you to write the debug print info using stream operators
52   // @param out - reference to the output stream being overloaded
53   // @param ro - reference to the ConcatOp to display
54   // @return - the output stream must be returned
55   friend std::ostream &operator<<(std::ostream &out, const ConcatOp &ro) {
56     ro.Print(out, false);
57     return out;
58   }
59 
60   // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
61   // provide the master loop that drives the logic for performing the work
62   // @return Status The status code returned
63   Status operator()() override;
64 
65   // Op name getter
66   // @return Name of the current Op
Name()67   std::string Name() const override { return kConcatOp; }
68 
69   // Private function for computing the assignment of the column name map.
70   // @return - Status
71   Status ComputeColMap() override;
72 
73   /// \brief Gets the number of classes
74   /// \param[out] num_classes the number of classes
75   /// \return Status - The status code return
76   Status GetNumClasses(int64_t *num_classes) override;
77 
78   Status GetNextRow(TensorRow *row) override;
79 
80   Status GetNextRowPullMode(TensorRow *const row) override;
81 
82   Status SampleInSequence(TensorRow *row, bool is_pipeline_mode = true);
83 
84   Status SampleInGlobal(TensorRow *row, bool is_pipeline_mode = true);
85 
86   /// Check if the current sample will be taken or dropped
87   /// \return bool
88   bool IgnoreSample();
89 
90  protected:
91   /// \brief Gets the implementation status for operator in pull mode
92   /// \return implementation status
PullModeImplementationStatus()93   ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; }
94 
95  private:
96   Status Verify(int32_t id, const TensorRow &new_row);
97 
98   std::unordered_map<std::string, int32_t> column_name_id_;  // Mapping between col index and col name
99   std::vector<DataType> data_type_;
100   std::vector<dsize_t> data_rank_;
101   std::vector<std::pair<int, int>> children_flag_and_nums_;
102   std::vector<std::pair<int, int>> children_start_end_index_;
103   std::vector<int64_t> children_sizes_;
104   std::vector<int64_t> children_sizes_ori_;
105   std::vector<bool> children_exhausted_;
106 
107   size_t cur_child_;
108   bool verified_;
109   int64_t sample_number_;
110 
111   int32_t num_shard_;
112   int32_t shard_index_;
113 
114   std::unique_ptr<std::discrete_distribution<>> discrete_random_;
115   bool global_shuffle_;
116   uint32_t seed_;
117   std::mt19937 rnd_;
118 };
119 }  // namespace dataset
120 }  // namespace mindspore
121 
122 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_
123