1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_RUN_OP_OP_RUNTIME_INFO_H_ 17 #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_RUN_OP_OP_RUNTIME_INFO_H_ 18 19 #include <memory> 20 #include <utility> 21 #include <vector> 22 #include <string> 23 #include "include/backend/device_address.h" 24 #include "include/backend/kernel_info.h" 25 #include "include/backend/kernel_graph.h" 26 27 namespace mindspore::runtime { 28 class AclRuntimeInfo { 29 public: AclRuntimeInfo()30 AclRuntimeInfo() : is_dynamic_input_size_(true), is_dynamic_output_size_(true), use_(false) {} SetUse(bool flag)31 void SetUse(bool flag) { use_ = flag; } SetIsDynamicInputSize(bool flag)32 void SetIsDynamicInputSize(bool flag) { 33 CheckInUse(); 34 is_dynamic_input_size_ = flag; 35 } SetIsDynamicOutputSize(bool flag)36 void SetIsDynamicOutputSize(bool flag) { 37 CheckInUse(); 38 is_dynamic_output_size_ = flag; 39 } SetInputNames(std::vector<std::string> input_names)40 void SetInputNames(std::vector<std::string> input_names) { 41 CheckInUse(); 42 input_names_ = std::move(input_names); 43 } SetOutputNames(std::vector<std::string> output_names)44 void SetOutputNames(std::vector<std::string> output_names) { 45 CheckInUse(); 46 output_names_ = std::move(output_names); 47 } 48 use()49 bool use() const { return use_; } is_dynamic_input_size()50 bool is_dynamic_input_size() const { 51 CheckInUse(); 52 return is_dynamic_input_size_; 53 } is_dynamic_output_size()54 bool is_dynamic_output_size() const { 55 CheckInUse(); 56 return is_dynamic_output_size_; 57 } input_names()58 const std::vector<std::string> &input_names() { 59 if (is_dynamic_input_size()) { 60 MS_LOG(EXCEPTION) << "This node has dynamic_input_size, should not get AclRuntimeInfo."; 61 } 62 return input_names_; 63 } output_names()64 const std::vector<std::string> &output_names() { 65 if (is_dynamic_output_size()) { 66 MS_LOG(EXCEPTION) << "This node has dynamic_output_size, should not get AclRuntimeInfo."; 67 } 68 return output_names_; 69 } 70 71 private: CheckInUse()72 void CheckInUse() const { 73 if (!use()) { 74 MS_LOG(EXCEPTION) << "AclRuntimeInfo is not in use."; 75 } 76 } 77 std::vector<std::string> input_names_; 78 std::vector<std::string> output_names_; 79 bool is_dynamic_input_size_; 80 bool is_dynamic_output_size_; 81 bool use_; 82 }; 83 using AclRuntimeInfoPtr = std::shared_ptr<AclRuntimeInfo>; 84 85 class BACKEND_EXPORT OpRuntimeInfo { 86 public: OpRuntimeInfo(std::vector<std::string> output_format,std::vector<TypeId> output_type,std::vector<size_t> output_tensor_size,std::vector<ShapeVector> output_infer_shape,std::vector<ShapeVector> output_device_shape,device::KernelInfo * kernel_info,std::vector<std::pair<device::KernelInfo *,size_t>> input_kernel_infos)87 OpRuntimeInfo(std::vector<std::string> output_format, std::vector<TypeId> output_type, 88 std::vector<size_t> output_tensor_size, std::vector<ShapeVector> output_infer_shape, 89 std::vector<ShapeVector> output_device_shape, device::KernelInfo *kernel_info, 90 std::vector<std::pair<device::KernelInfo *, size_t>> input_kernel_infos) 91 : acl_runtime_info_(std::make_shared<AclRuntimeInfo>()), 92 output_format_(std::move(output_format)), 93 output_type_(std::move(output_type)), 94 output_tensor_size_(std::move(output_tensor_size)), 95 output_infer_shape_(std::move(output_infer_shape)), 96 output_device_shape_(std::move(output_device_shape)), 97 kernel_info_(kernel_info), 98 input_kernel_infos_(std::move(input_kernel_infos)) {} 99 ~OpRuntimeInfo() = default; 100 101 // Key for user data. 102 constexpr static char key[] = "OpRuntimeInfo"; 103 104 std::string output_format(size_t index) const; 105 TypeId output_type(size_t index) const; 106 size_t output_tensor_size(size_t index) const; 107 const ShapeVector &output_infer_shape(size_t index) const; 108 const ShapeVector &output_device_shape(size_t index) const; 109 void SetOutputTensorSize(size_t index, size_t tensor_size); 110 void SetOutputInferShape(size_t index, const ShapeVector &shape); 111 void SetOutputDeviceShape(size_t index, const ShapeVector &shape); 112 device::DeviceAddressPtr GetOutputDeviceAddress(size_t index) const; 113 device::DeviceAddressPtr GetWorkspaceDeviceAddress(size_t index) const; 114 device::DeviceAddressPtr GetInputDeviceAddress(size_t index) const; 115 size_t GetInputSize() const; 116 size_t GetOutputSize() const; 117 size_t GetWorkspaceSize() const; 118 kernel::KernelMod *GetKernelMod() const; 119 void Resize(const AnfNodePtr &node); 120 121 static void CacheGraphOpRuntimeInfo(const KernelGraphPtr &graph); 122 // for acl 123 AclRuntimeInfoPtr acl_runtime_info_; 124 125 private: 126 std::vector<std::string> output_format_; 127 std::vector<TypeId> output_type_; 128 std::vector<size_t> output_tensor_size_; 129 std::vector<ShapeVector> output_infer_shape_; 130 std::vector<ShapeVector> output_device_shape_; 131 device::KernelInfo *kernel_info_; 132 std::vector<std::pair<device::KernelInfo *, size_t>> input_kernel_infos_; 133 }; 134 using OpRuntimeInfoPtr = std::shared_ptr<OpRuntimeInfo>; 135 } // namespace mindspore::runtime 136 #endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_RUN_OP_OP_RUNTIME_INFO_H_ 137