• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_MODIFIER_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_MODIFIER_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <utility>
25 #include <vector>
26 
27 #include "minddata/dataset/engine/execution_tree.h"
28 #include "minddata/dataset/engine/tree_adapter.h"
29 
/// Size handed to the Queue constructor for each AutotuneCallback's change-request queue.
constexpr int64_t queue_size{10};
31 
32 namespace mindspore {
33 namespace dataset {
34 class DatasetNode;
35 
36 /// A pure virtual class to be used as a base for all pipeline modification requests.
37 class ChangeRequest {
38  public:
39   /// Default constructor
40   ChangeRequest() = default;
41   virtual ~ChangeRequest() = default;
42 
43   /// Pure virtual method. Subclasses should override this function and implement the actual change to the give
44   /// operator.
45   /// \param op pointer to the operator that the change will be applied on
46   /// \return Status return Status code
47   virtual Status ApplyChange(DatasetOp *op) = 0;
48 };
49 
50 using ChangeRequestPtr = std::shared_ptr<ChangeRequest>;
51 
52 /// ChangeRequest to add n workers to an operator.
53 class ChangeNumWorkersRequest : public ChangeRequest {
54  public:
55   /// Constructor
56   /// \param num_workers number of workeres to be added to the operator. Default to 1.
num_workers_(num_workers)57   explicit ChangeNumWorkersRequest(int32_t num_workers = 1) : num_workers_(num_workers) {}
58   virtual ~ChangeNumWorkersRequest() = default;
59 
60   /// Actual change to add n workers
61   /// \param op pointer to the operator that the change will be applied on
62   /// \return Status return Status code
63   Status ApplyChange(DatasetOp *op) override;
64 
65  private:
66   int32_t num_workers_;
67 };
68 
69 /// ChangeRequest to change the size of the oupout connector of an operators.
70 class ResizeConnectorRequest : public ChangeRequest {
71  public:
72   /// Constructor
73   /// \param new_size new queue size.
ResizeConnectorRequest(int32_t new_size)74   explicit ResizeConnectorRequest(int32_t new_size) : new_size_(new_size) {}
75   virtual ~ResizeConnectorRequest() = default;
76 
77   /// Actual change to resize the output connector of the given operator
78   /// \param op pointer to the operator that the change will be applied on
79   /// \return Status return Status code
ApplyChange(DatasetOp * op)80   Status ApplyChange(DatasetOp *op) override {
81     RETURN_IF_NOT_OK(op->OutputConnector()->Resize(new_size_));
82     return Status::OK();
83   }
84 
85  private:
86   int32_t new_size_;
87 };
88 
89 /// A callback class used by Aututune to queue changes for operators
90 class AutotuneCallback : public DSCallback {
91  public:
AutotuneCallback(int32_t step_size,DatasetOp * op)92   AutotuneCallback(int32_t step_size, DatasetOp *op)
93       : DSCallback(step_size), op_(op), change_request_queue_(std::make_unique<Queue<ChangeRequestPtr>>(queue_size)) {}
94   virtual ~AutotuneCallback() = default;
95 
96   Status DSNStepBegin(const CallbackParam &cb_param) override;
97   Status DSBegin(const CallbackParam &cb_param) override;
98   Status DSEpochBegin(const CallbackParam &cb_param) override;
99   Status DSEnd(const CallbackParam &cb_param) override;
100   Status DSEpochEnd(const CallbackParam &cb_param) override;
101   Status DSNStepEnd(const CallbackParam &cb_param) override;
102 
103   bool IsBeginNeeded() override;
104   bool IsEpochBeginNeeded() override;
105   bool IsNStepBeginNeeded() override;
106   bool IsEndNeeded() override;
107   bool IsEpochEndNeeded() override;
108   bool IsNStepEndNeeded() override;
109 
110   ///  Push a change request to the queue of the callback.
111   /// \param change_request Shared pointer to the change request to be pushed to the queue.
112   /// \return Status return Status code
113   Status PushChangeRequest(ChangeRequestPtr change_request);
114 
115  private:
116   DatasetOp *op_;
117   std::unique_ptr<Queue<ChangeRequestPtr>> change_request_queue_;
118 };
119 
120 /// Main class to handle modification of the ExecutionTree used by AutoTune
121 class TreeModifier {
122   // friend with TreeAdapter to access the ExecutionTree
123   friend TreeAdapter;
124 
125  public:
126   /// Constructor to create a TreeModifier given a TreeAdapter
127   /// \param adapter TreeAdapter
128   explicit TreeModifier(const TreeAdapter *adapter);
129 
130   /// Constructor to create a TreeModifier given an ExecutionTree
131   /// \param tree ExecutionTree
TreeModifier(ExecutionTree * tree)132   explicit TreeModifier(ExecutionTree *tree) : tree_(tree) {
133     // loop over all ops to create AutotuneCallback and register it.
134     for (auto itr = tree_->begin(); itr != tree_->end(); ++itr) {
135       auto cb = std::make_shared<AutotuneCallback>(1, itr.get().get());
136       itr->AddCallbacks({cb});
137       (void)callbacks.insert(std::make_pair(itr->id(), cb));
138     }
139   }
140 
141   /// Add changeRequest to the callback associated with the op.
142   /// \param op_id Operator ID
143   /// \param change_request Pointer to the change request
144   /// \return Status return Status code
AddChangeRequest(int32_t op_id,const ChangeRequestPtr & change_request)145   Status AddChangeRequest(int32_t op_id, const ChangeRequestPtr &change_request) {
146     num_requests_++;
147     RETURN_IF_NOT_OK(callbacks[op_id]->PushChangeRequest(change_request));
148     return Status::OK();
149   }
150 
151   /// \brief Get the number of change requests received
152   /// \return Number of change requests received
GetRequestsCount()153   uint64_t GetRequestsCount() const { return num_requests_; }
154 
155  private:
156   ExecutionTree *tree_;
157   std::map<int32_t, std::shared_ptr<AutotuneCallback>> callbacks;
158   uint64_t num_requests_ = 0;  // counter for number of requests received
159 };
160 }  // namespace dataset
161 }  // namespace mindspore
162 
163 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_MODIFIER_H_
164