1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_DEVICE_KERNEL_INFO_H_ 18 #define MINDSPORE_DEVICE_KERNEL_INFO_H_ 19 20 #include <vector> 21 #include <memory> 22 #include "ir/kernel_info_dev.h" 23 #include "backend/kernel_compiler/kernel_build_info.h" 24 #include "runtime/device/ascend/ascend_device_address.h" 25 #include "backend/kernel_compiler/kernel.h" 26 27 namespace mindspore { 28 const uint32_t kInvalidGraphId = UINT32_MAX; 29 const uint32_t kInvalidDistincLabel = UINT32_MAX; 30 namespace device { 31 class KernelInfo : public KernelInfoDevice { 32 public: KernelInfo()33 KernelInfo() { 34 kernel_mod_ = nullptr; 35 is_feature_map_ = false; 36 select_kernel_build_info_ = nullptr; 37 output_address_list_ = {}; 38 workspace_address_list_ = {}; 39 stream_id_ = UINT32_MAX; 40 stream_distinction_label_ = kInvalidDistincLabel; 41 graph_id_ = kInvalidGraphId; 42 } 43 virtual ~KernelInfo() = default; 44 has_build_info()45 bool has_build_info() const override { return select_kernel_build_info() != nullptr; } 46 const kernel::KernelBuildInfo *select_kernel_build_info() const; 47 kernel::KernelBuildInfoPtr GetMutableSelectKernelBuildInfo() const; set_select_kernel_build_info(const kernel::KernelBuildInfoPtr & select_kernel_build_info)48 void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) { 49 select_kernel_build_info_ = select_kernel_build_info; 50 } set_feature_map_flag(bool flag)51 void set_feature_map_flag(bool flag) { is_feature_map_ = flag; } 52 const DeviceAddress *GetOutputAddr(size_t index) const; 53 DeviceAddressPtr GetMutableOutputAddr(size_t index) const; 54 bool OutputAddrExist(size_t index) const; 55 bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index); 56 DeviceAddress *GetWorkspaceAddr(size_t index) const; 57 DeviceAddressPtr GetMutableWorkspaceAddr(size_t index) const; 58 bool WorkspaceAddrExist(size_t index) const; 59 bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index); 60 void set_kernel_mod(const kernel::KernelModPtr &kernel_mod); 61 kernel::KernelMod *MutableKernelMod() const; 62 const kernel::KernelMod *kernel_mod() const; stream_id()63 uint32_t stream_id() const { return stream_id_; } set_stream_id(uint32_t stream_id)64 void set_stream_id(uint32_t stream_id) { stream_id_ = stream_id; } stream_distinction_label()65 uint32_t stream_distinction_label() const { return stream_distinction_label_; } set_stream_distinction_label(uint32_t stream_distinction_label)66 void set_stream_distinction_label(uint32_t stream_distinction_label) { 67 stream_distinction_label_ = stream_distinction_label; 68 } set_graph_id(uint32_t graph_id)69 void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; } graph_id()70 uint32_t graph_id() const { return graph_id_; } 71 bool operator==(const KernelInfo &other) const; is_feature_map()72 bool is_feature_map() const { return is_feature_map_; } output_address_list()73 const std::vector<std::shared_ptr<DeviceAddress>> &output_address_list() const { return output_address_list_; } workspace_address_list()74 const std::vector<std::shared_ptr<DeviceAddress>> &workspace_address_list() const { return workspace_address_list_; } 75 76 private: 77 bool is_feature_map_; 78 kernel::KernelBuildInfoPtr select_kernel_build_info_; 79 std::vector<std::shared_ptr<DeviceAddress>> output_address_list_; 80 std::vector<std::shared_ptr<DeviceAddress>> workspace_address_list_; 81 kernel::KernelModPtr kernel_mod_; 82 // stream_id_ is the index of stream object vector 83 uint32_t stream_id_; 84 // stream_distinction_label_ is used mark different op in different stream 85 uint32_t stream_distinction_label_; 86 // record which graph the node belong to 87 uint32_t graph_id_; 88 }; 89 } // namespace device 90 } // namespace mindspore 91 #endif // MINDSPORE_DEVICE_KERNEL_INFO_H_ 92