• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_DEVICE_KERNEL_INFO_H_
18 #define MINDSPORE_DEVICE_KERNEL_INFO_H_
19 
20 #include <vector>
21 #include <memory>
22 #include "ir/kernel_info_dev.h"
23 #include "backend/kernel_compiler/kernel_build_info.h"
24 #include "runtime/device/ascend/ascend_device_address.h"
25 #include "backend/kernel_compiler/kernel.h"
26 
27 namespace mindspore {
// Sentinel graph id; KernelInfo::graph_id() holds this value until set_graph_id() is called.
constexpr uint32_t kInvalidGraphId = UINT32_MAX;
// Sentinel stream distinction label; KernelInfo starts with this value.
// NOTE: the identifier's spelling ("Distinc") is kept as-is because it is part of the public interface.
constexpr uint32_t kInvalidDistincLabel = UINT32_MAX;
30 namespace device {
31 class KernelInfo : public KernelInfoDevice {
32  public:
KernelInfo()33   KernelInfo() {
34     kernel_mod_ = nullptr;
35     is_feature_map_ = false;
36     select_kernel_build_info_ = nullptr;
37     output_address_list_ = {};
38     workspace_address_list_ = {};
39     stream_id_ = UINT32_MAX;
40     stream_distinction_label_ = kInvalidDistincLabel;
41     graph_id_ = kInvalidGraphId;
42   }
43   virtual ~KernelInfo() = default;
44 
has_build_info()45   bool has_build_info() const override { return select_kernel_build_info() != nullptr; }
46   const kernel::KernelBuildInfo *select_kernel_build_info() const;
47   kernel::KernelBuildInfoPtr GetMutableSelectKernelBuildInfo() const;
set_select_kernel_build_info(const kernel::KernelBuildInfoPtr & select_kernel_build_info)48   void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
49     select_kernel_build_info_ = select_kernel_build_info;
50   }
set_feature_map_flag(bool flag)51   void set_feature_map_flag(bool flag) { is_feature_map_ = flag; }
52   const DeviceAddress *GetOutputAddr(size_t index) const;
53   DeviceAddressPtr GetMutableOutputAddr(size_t index) const;
54   bool OutputAddrExist(size_t index) const;
55   bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index);
56   DeviceAddress *GetWorkspaceAddr(size_t index) const;
57   DeviceAddressPtr GetMutableWorkspaceAddr(size_t index) const;
58   bool WorkspaceAddrExist(size_t index) const;
59   bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index);
60   void set_kernel_mod(const kernel::KernelModPtr &kernel_mod);
61   kernel::KernelMod *MutableKernelMod() const;
62   const kernel::KernelMod *kernel_mod() const;
stream_id()63   uint32_t stream_id() const { return stream_id_; }
set_stream_id(uint32_t stream_id)64   void set_stream_id(uint32_t stream_id) { stream_id_ = stream_id; }
stream_distinction_label()65   uint32_t stream_distinction_label() const { return stream_distinction_label_; }
set_stream_distinction_label(uint32_t stream_distinction_label)66   void set_stream_distinction_label(uint32_t stream_distinction_label) {
67     stream_distinction_label_ = stream_distinction_label;
68   }
set_graph_id(uint32_t graph_id)69   void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
graph_id()70   uint32_t graph_id() const { return graph_id_; }
71   bool operator==(const KernelInfo &other) const;
is_feature_map()72   bool is_feature_map() const { return is_feature_map_; }
output_address_list()73   const std::vector<std::shared_ptr<DeviceAddress>> &output_address_list() const { return output_address_list_; }
workspace_address_list()74   const std::vector<std::shared_ptr<DeviceAddress>> &workspace_address_list() const { return workspace_address_list_; }
75 
76  private:
77   bool is_feature_map_;
78   kernel::KernelBuildInfoPtr select_kernel_build_info_;
79   std::vector<std::shared_ptr<DeviceAddress>> output_address_list_;
80   std::vector<std::shared_ptr<DeviceAddress>> workspace_address_list_;
81   kernel::KernelModPtr kernel_mod_;
82   // stream_id_ is the index of stream object vector
83   uint32_t stream_id_;
84   // stream_distinction_label_ is used mark different op in different stream
85   uint32_t stream_distinction_label_;
86   // record which graph the node belong to
87   uint32_t graph_id_;
88 };
89 }  // namespace device
90 }  // namespace mindspore
91 #endif  // MINDSPORE_DEVICE_KERNEL_INFO_H_
92