• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_KERNEL_INFER_ACTOR_H_
18 #define MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_KERNEL_INFER_ACTOR_H_
19 
20 #include <string>
21 #include <memory>
22 #include <vector>
23 #include "runtime/graph_scheduler/actor/kernel_actor.h"
24 #include "runtime/hardware/device_context.h"
25 
26 namespace mindspore {
27 namespace runtime {
28 // KernelInferActor is used to Infer the shape output scenario from the dynamic shape asynchronous operator, improving
29 // the concurrency between dynamic shape operators and improving the performance of the dynamic shape network.
30 class KernelInferActor : public KernelActor {
31  public:
32   KernelInferActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
33                    const AID &memory_manager_aid,
34                    const KernelTransformType &type = KernelTransformType::kKernelInferActor)
35       : KernelActor(name, kernel, device_context, memory_manager_aid, nullptr, nullptr,
36                     GraphExecutionStrategy::kPipeline, {}, {}, type) {}
37   ~KernelInferActor() override = default;
38 
39   // The actor run when receive the input data.
40   void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
41 
42   // The memory related operation interface.
43   void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
44 
45  protected:
46   void Run(OpContext<DeviceTensor> *const context) override;
47   void Init() override;
SendRecorderInfo(OpContext<DeviceTensor> * const context)48   void SendRecorderInfo(OpContext<DeviceTensor> *const context) const override {}
49 };
50 
51 using KernelInferActorPtr = std::shared_ptr<KernelInferActor>;
52 }  // namespace runtime
53 }  // namespace mindspore
54 #endif  // MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_KERNEL_INFER_ACTOR_H_
55