1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ 18 19 #include <memory> 20 #include <queue> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 #include "minddata/dataset/engine/dataset_iterator.h" 25 #include "minddata/dataset/engine/datasetops/parallel_op.h" 26 #include "minddata/dataset/kernels/tensor_op.h" 27 #include "minddata/dataset/util/queue.h" 28 29 namespace mindspore { 30 namespace dataset { 31 32 class FilterOp : public ParallelOp { 33 public: 34 enum filterCtrl : int8_t { kFilterEmpty = 0, kFilterPartial = 1, kFilterFull = 2, kFilterEoe = 3, kFilterEof = 4 }; 35 36 // Constructor of FilterOp 37 // @note The builder class should be used to call it. 38 // @param in_col_names A list of input column names,when it is empty the predicate will be 39 // applied all columns in the dataset. 40 // @param num_workers The number of worker threads. 41 // @param op_connector_size The size of each queue in the connector. 42 // @param predicate_func python callable which returns a boolean value. 43 FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size, 44 std::shared_ptr<TensorOp> predicate_func); 45 46 // Destructor 47 ~FilterOp() = default; 48 49 // Class functor operator () override. 50 // All dataset ops operate by launching a thread (see ExecutionTree),This class functor will 51 // provide the master loop that drives the logic for performing the work. 52 // @return Status The status code returned 53 Status operator()() override; 54 55 // @param int32_t workerId. 56 // @return Status The status code returned. 57 Status EofReceived(int32_t) override; 58 59 // @param int32_t workerId. 60 // @return Status The status code returned. 61 Status EoeReceived(int32_t) override; 62 63 // A print method typically used for debugging. 64 // @param out The output stream to write output to. 65 // @param show_all A bool to control if you want to show all info or just a summary. 66 void Print(std::ostream &out, bool show_all) const override; 67 68 // Op name getter 69 // @return Name of the current Op Name()70 std::string Name() const override { return kFilterOp; } 71 72 int32_t NumConsumers() const override; 73 74 private: 75 // predicate_func python callable which returns a boolean value. 76 std::shared_ptr<TensorOp> predicate_func_; 77 78 // Variable to store the column name that will feed to predicate function. 79 std::vector<std::string> in_columns_; 80 81 // Internal queue for filter. 82 QueueList<std::pair<TensorRow, filterCtrl>> filter_queues_; 83 84 QueueList<TensorRow> worker_queues_; // internal queue for syncing worker 85 86 std::unique_ptr<ChildIterator> child_iterator_; 87 88 // Private function for worker/thread to loop continuously. It comprises the main 89 // logic of FilterOp, getting the data from previous Op, validating user specified column names, 90 // applying predicate to each of the data, filter the data when predicate result is false. 91 // @param worker_id The id assigned to this thread/worker upon creation. 92 // @return Status The status code returned 93 Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ 94 95 // Filter the data by predicate function . 96 // @param in_row input row. 97 // @param out_predicate result boolean to filter or not. 98 // @return Status The status code returned 99 Status WorkerCompute(const TensorRow &in_row, bool *out_predicate); 100 101 // Collector TensorRows. 102 // @return Status The status code returned 103 Status Collector(); 104 105 // @param input tensor vector. 106 // @return Status The status code returned. 107 Status CheckInput(const TensorRow &input) const; 108 109 // Invoke python func. 110 // @param input tensor vector. 111 // @param the result of predicate. 112 // @return Status The status code returned. 113 Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate); 114 115 // Private function for validating if each of the user specified input column names 116 // exist in column_name_id_map_. 117 // @param input_columns The vector of input column names used in the current thread. 118 // @return Status The status code returned 119 Status ValidateInColumns(const std::vector<std::string> &input_columns); 120 121 // Do the initialization of all queues then start all worker threads 122 // @return Status The status code returned 123 Status LaunchThreadsAndInitOp(); 124 }; 125 126 } // namespace dataset 127 } // namespace mindspore 128 #endif 129