• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_FILTER_OP_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_FILTER_OP_H_
18 
19 #include <memory>
20 #include <queue>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 #include "minddata/dataset/engine/dataset_iterator.h"
25 #include "minddata/dataset/engine/datasetops/parallel_op.h"
26 #include "minddata/dataset/kernels/tensor_op.h"
27 #include "minddata/dataset/util/queue.h"
28 
29 namespace mindspore {
30 namespace dataset {
31 
32 class FilterOp : public ParallelOp {
33  public:
34   enum filterCtrl : int8_t { kFilterEmpty = 0, kFilterPartial = 1, kFilterFull = 2, kFilterEoe = 3, kFilterEof = 4 };
35 
36   // Constructor of FilterOp
37   // @note The builder class should be used to call it.
38   // @param in_col_names A list of input column names,when it is empty the predicate will be
39   //     applied all columns in the dataset.
40   // @param num_workers The number of worker threads.
41   // @param op_connector_size The size of each queue in the connector.
42   // @param predicate_func python callable which returns a boolean value.
43   FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
44            std::shared_ptr<TensorOp> predicate_func);
45 
46   // Destructor
47   ~FilterOp() = default;
48 
49   // Class functor operator () override.
50   // All dataset ops operate by launching a thread (see ExecutionTree),This class functor will
51   // provide the master loop that drives the logic for performing the work.
52   // @return Status The status code returned
53   Status operator()() override;
54 
55   // @param int32_t workerId.
56   // @return Status The status code returned.
57   Status EofReceived(int32_t) override;
58 
59   // @param int32_t workerId.
60   // @return Status The status code returned.
61   Status EoeReceived(int32_t) override;
62 
63   // A print method typically used for debugging.
64   // @param out The output stream to write output to.
65   // @param show_all A bool to control if you want to show all info or just a summary.
66   void Print(std::ostream &out, bool show_all) const override;
67 
68   // Op name getter
69   // @return Name of the current Op
Name()70   std::string Name() const override { return kFilterOp; }
71 
72   int32_t NumConsumers() const override;
73 
74  private:
75   // predicate_func python callable which returns a boolean value.
76   std::shared_ptr<TensorOp> predicate_func_;
77 
78   // Variable to store the column name that will feed to predicate function.
79   std::vector<std::string> in_columns_;
80 
81   // Internal queue for filter.
82   QueueList<std::pair<TensorRow, filterCtrl>> filter_queues_;
83 
84   QueueList<TensorRow> worker_queues_;  // internal queue for syncing worker
85 
86   std::unique_ptr<ChildIterator> child_iterator_;
87 
88   // Private function for worker/thread to loop continuously. It comprises the main
89   // logic of FilterOp, getting the data from previous Op, validating user specified column names,
90   // applying predicate to each of the data, filter the data when predicate result is false.
91   // @param worker_id The id assigned to this thread/worker upon creation.
92   // @return Status The status code returned
93   Status WorkerEntry(int32_t worker_id) override;  //  In: workerId assigned by tree_
94 
95   // Filter the data by  predicate function .
96   // @param in_row input row.
97   // @param out_predicate result boolean to filter or not.
98   // @return Status The status code returned
99   Status WorkerCompute(const TensorRow &in_row, bool *out_predicate);
100 
101   // Collector TensorRows.
102   // @return Status The status code returned
103   Status Collector();
104 
105   // @param input tensor vector.
106   // @return Status The status code returned.
107   Status CheckInput(const TensorRow &input) const;
108 
109   // Invoke python func.
110   // @param input tensor vector.
111   // @param the result of predicate.
112   // @return Status The status code returned.
113   Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate);
114 
115   // Private function for validating if each of the user specified input column names
116   // exist in column_name_id_map_.
117   // @param input_columns The vector of input column names used in the current thread.
118   // @return Status The status code returned
119   Status ValidateInColumns(const std::vector<std::string> &input_columns);
120 
121   // Do the initialization of all queues then start all worker threads
122   // @return Status The status code returned
123   Status LaunchThreadsAndInitOp();
124 };
125 
126 }  // namespace dataset
127 }  // namespace mindspore
128 #endif
129