• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_RUN_OP_OP_RUNTIME_INFO_H_
17 #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_RUN_OP_OP_RUNTIME_INFO_H_
18 
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 #include <string>
23 #include "include/backend/device_address.h"
24 #include "include/backend/kernel_info.h"
25 #include "include/backend/kernel_graph.h"
26 
27 namespace mindspore::runtime {
28 class AclRuntimeInfo {
29  public:
AclRuntimeInfo()30   AclRuntimeInfo() : is_dynamic_input_size_(true), is_dynamic_output_size_(true), use_(false) {}
SetUse(bool flag)31   void SetUse(bool flag) { use_ = flag; }
SetIsDynamicInputSize(bool flag)32   void SetIsDynamicInputSize(bool flag) {
33     CheckInUse();
34     is_dynamic_input_size_ = flag;
35   }
SetIsDynamicOutputSize(bool flag)36   void SetIsDynamicOutputSize(bool flag) {
37     CheckInUse();
38     is_dynamic_output_size_ = flag;
39   }
SetInputNames(std::vector<std::string> input_names)40   void SetInputNames(std::vector<std::string> input_names) {
41     CheckInUse();
42     input_names_ = std::move(input_names);
43   }
SetOutputNames(std::vector<std::string> output_names)44   void SetOutputNames(std::vector<std::string> output_names) {
45     CheckInUse();
46     output_names_ = std::move(output_names);
47   }
48 
use()49   bool use() const { return use_; }
is_dynamic_input_size()50   bool is_dynamic_input_size() const {
51     CheckInUse();
52     return is_dynamic_input_size_;
53   }
is_dynamic_output_size()54   bool is_dynamic_output_size() const {
55     CheckInUse();
56     return is_dynamic_output_size_;
57   }
input_names()58   const std::vector<std::string> &input_names() {
59     if (is_dynamic_input_size()) {
60       MS_LOG(EXCEPTION) << "This node has dynamic_input_size, should not get AclRuntimeInfo.";
61     }
62     return input_names_;
63   }
output_names()64   const std::vector<std::string> &output_names() {
65     if (is_dynamic_output_size()) {
66       MS_LOG(EXCEPTION) << "This node has dynamic_output_size, should not get AclRuntimeInfo.";
67     }
68     return output_names_;
69   }
70 
71  private:
CheckInUse()72   void CheckInUse() const {
73     if (!use()) {
74       MS_LOG(EXCEPTION) << "AclRuntimeInfo is not in use.";
75     }
76   }
77   std::vector<std::string> input_names_;
78   std::vector<std::string> output_names_;
79   bool is_dynamic_input_size_;
80   bool is_dynamic_output_size_;
81   bool use_;
82 };
83 using AclRuntimeInfoPtr = std::shared_ptr<AclRuntimeInfo>;
84 
85 class BACKEND_EXPORT OpRuntimeInfo {
86  public:
OpRuntimeInfo(std::vector<std::string> output_format,std::vector<TypeId> output_type,std::vector<size_t> output_tensor_size,std::vector<ShapeVector> output_infer_shape,std::vector<ShapeVector> output_device_shape,device::KernelInfo * kernel_info,std::vector<std::pair<device::KernelInfo *,size_t>> input_kernel_infos)87   OpRuntimeInfo(std::vector<std::string> output_format, std::vector<TypeId> output_type,
88                 std::vector<size_t> output_tensor_size, std::vector<ShapeVector> output_infer_shape,
89                 std::vector<ShapeVector> output_device_shape, device::KernelInfo *kernel_info,
90                 std::vector<std::pair<device::KernelInfo *, size_t>> input_kernel_infos)
91       : acl_runtime_info_(std::make_shared<AclRuntimeInfo>()),
92         output_format_(std::move(output_format)),
93         output_type_(std::move(output_type)),
94         output_tensor_size_(std::move(output_tensor_size)),
95         output_infer_shape_(std::move(output_infer_shape)),
96         output_device_shape_(std::move(output_device_shape)),
97         kernel_info_(kernel_info),
98         input_kernel_infos_(std::move(input_kernel_infos)) {}
99   ~OpRuntimeInfo() = default;
100 
101   // Key for user data.
102   constexpr static char key[] = "OpRuntimeInfo";
103 
104   std::string output_format(size_t index) const;
105   TypeId output_type(size_t index) const;
106   size_t output_tensor_size(size_t index) const;
107   const ShapeVector &output_infer_shape(size_t index) const;
108   const ShapeVector &output_device_shape(size_t index) const;
109   void SetOutputTensorSize(size_t index, size_t tensor_size);
110   void SetOutputInferShape(size_t index, const ShapeVector &shape);
111   void SetOutputDeviceShape(size_t index, const ShapeVector &shape);
112   device::DeviceAddressPtr GetOutputDeviceAddress(size_t index) const;
113   device::DeviceAddressPtr GetWorkspaceDeviceAddress(size_t index) const;
114   device::DeviceAddressPtr GetInputDeviceAddress(size_t index) const;
115   size_t GetInputSize() const;
116   size_t GetOutputSize() const;
117   size_t GetWorkspaceSize() const;
118   kernel::KernelMod *GetKernelMod() const;
119   void Resize(const AnfNodePtr &node);
120 
121   static void CacheGraphOpRuntimeInfo(const KernelGraphPtr &graph);
122   // for acl
123   AclRuntimeInfoPtr acl_runtime_info_;
124 
125  private:
126   std::vector<std::string> output_format_;
127   std::vector<TypeId> output_type_;
128   std::vector<size_t> output_tensor_size_;
129   std::vector<ShapeVector> output_infer_shape_;
130   std::vector<ShapeVector> output_device_shape_;
131   device::KernelInfo *kernel_info_;
132   std::vector<std::pair<device::KernelInfo *, size_t>> input_kernel_infos_;
133 };
134 using OpRuntimeInfoPtr = std::shared_ptr<OpRuntimeInfo>;
135 }  // namespace mindspore::runtime
136 #endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_RUN_OP_OP_RUNTIME_INFO_H_
137