1 /** 2 * Copyright 2021-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_MODIFIER_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_MODIFIER_H_ 19 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <unordered_map> 24 #include <utility> 25 #include <vector> 26 27 #include "minddata/dataset/engine/execution_tree.h" 28 #include "minddata/dataset/engine/tree_adapter.h" 29 30 constexpr int64_t queue_size = 10; 31 32 namespace mindspore { 33 namespace dataset { 34 class DatasetNode; 35 36 /// A pure virtual class to be used as a base for all pipeline modification requests. 37 class ChangeRequest { 38 public: 39 /// Default constructor 40 ChangeRequest() = default; 41 virtual ~ChangeRequest() = default; 42 43 /// Pure virtual method. Subclasses should override this function and implement the actual change to the give 44 /// operator. 45 /// \param op pointer to the operator that the change will be applied on 46 /// \return Status return Status code 47 virtual Status ApplyChange(DatasetOp *op) = 0; 48 }; 49 50 using ChangeRequestPtr = std::shared_ptr<ChangeRequest>; 51 52 /// ChangeRequest to add n workers to an operator. 53 class ChangeNumWorkersRequest : public ChangeRequest { 54 public: 55 /// Constructor 56 /// \param num_workers number of workeres to be added to the operator. Default to 1. num_workers_(num_workers)57 explicit ChangeNumWorkersRequest(int32_t num_workers = 1) : num_workers_(num_workers) {} 58 virtual ~ChangeNumWorkersRequest() = default; 59 60 /// Actual change to add n workers 61 /// \param op pointer to the operator that the change will be applied on 62 /// \return Status return Status code 63 Status ApplyChange(DatasetOp *op) override; 64 65 private: 66 int32_t num_workers_; 67 }; 68 69 /// ChangeRequest to change the size of the oupout connector of an operators. 70 class ResizeConnectorRequest : public ChangeRequest { 71 public: 72 /// Constructor 73 /// \param new_size new queue size. ResizeConnectorRequest(int32_t new_size)74 explicit ResizeConnectorRequest(int32_t new_size) : new_size_(new_size) {} 75 virtual ~ResizeConnectorRequest() = default; 76 77 /// Actual change to resize the output connector of the given operator 78 /// \param op pointer to the operator that the change will be applied on 79 /// \return Status return Status code ApplyChange(DatasetOp * op)80 Status ApplyChange(DatasetOp *op) override { 81 RETURN_IF_NOT_OK(op->OutputConnector()->Resize(new_size_)); 82 return Status::OK(); 83 } 84 85 private: 86 int32_t new_size_; 87 }; 88 89 /// A callback class used by Aututune to queue changes for operators 90 class AutotuneCallback : public DSCallback { 91 public: AutotuneCallback(int32_t step_size,DatasetOp * op)92 AutotuneCallback(int32_t step_size, DatasetOp *op) 93 : DSCallback(step_size), op_(op), change_request_queue_(std::make_unique<Queue<ChangeRequestPtr>>(queue_size)) {} 94 virtual ~AutotuneCallback() = default; 95 96 Status DSNStepBegin(const CallbackParam &cb_param) override; 97 Status DSBegin(const CallbackParam &cb_param) override; 98 Status DSEpochBegin(const CallbackParam &cb_param) override; 99 Status DSEnd(const CallbackParam &cb_param) override; 100 Status DSEpochEnd(const CallbackParam &cb_param) override; 101 Status DSNStepEnd(const CallbackParam &cb_param) override; 102 103 bool IsBeginNeeded() override; 104 bool IsEpochBeginNeeded() override; 105 bool IsNStepBeginNeeded() override; 106 bool IsEndNeeded() override; 107 bool IsEpochEndNeeded() override; 108 bool IsNStepEndNeeded() override; 109 110 /// Push a change request to the queue of the callback. 111 /// \param change_request Shared pointer to the change request to be pushed to the queue. 112 /// \return Status return Status code 113 Status PushChangeRequest(ChangeRequestPtr change_request); 114 115 private: 116 DatasetOp *op_; 117 std::unique_ptr<Queue<ChangeRequestPtr>> change_request_queue_; 118 }; 119 120 /// Main class to handle modification of the ExecutionTree used by AutoTune 121 class TreeModifier { 122 // friend with TreeAdapter to access the ExecutionTree 123 friend TreeAdapter; 124 125 public: 126 /// Constructor to create a TreeModifier given a TreeAdapter 127 /// \param adapter TreeAdapter 128 explicit TreeModifier(const TreeAdapter *adapter); 129 130 /// Constructor to create a TreeModifier given an ExecutionTree 131 /// \param tree ExecutionTree TreeModifier(ExecutionTree * tree)132 explicit TreeModifier(ExecutionTree *tree) : tree_(tree) { 133 // loop over all ops to create AutotuneCallback and register it. 134 for (auto itr = tree_->begin(); itr != tree_->end(); ++itr) { 135 auto cb = std::make_shared<AutotuneCallback>(1, itr.get().get()); 136 itr->AddCallbacks({cb}); 137 (void)callbacks.insert(std::make_pair(itr->id(), cb)); 138 } 139 } 140 141 /// Add changeRequest to the callback associated with the op. 142 /// \param op_id Operator ID 143 /// \param change_request Pointer to the change request 144 /// \return Status return Status code AddChangeRequest(int32_t op_id,const ChangeRequestPtr & change_request)145 Status AddChangeRequest(int32_t op_id, const ChangeRequestPtr &change_request) { 146 num_requests_++; 147 RETURN_IF_NOT_OK(callbacks[op_id]->PushChangeRequest(change_request)); 148 return Status::OK(); 149 } 150 151 /// \brief Get the number of change requests received 152 /// \return Number of change requests received GetRequestsCount()153 uint64_t GetRequestsCount() const { return num_requests_; } 154 155 private: 156 ExecutionTree *tree_; 157 std::map<int32_t, std::shared_ptr<AutotuneCallback>> callbacks; 158 uint64_t num_requests_ = 0; // counter for number of requests received 159 }; 160 } // namespace dataset 161 } // namespace mindspore 162 163 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_MODIFIER_H_ 164