/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADAPTER_CALLBACK_IMPL_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADAPTER_CALLBACK_IMPL_H_
19 #include <string>
20 #include <vector>
21 #include "backend/common/graph_kernel/core/graph_kernel_callback.h"
22 
23 namespace mindspore::graphkernel {
24 class BACKEND_EXPORT CallbackImpl : public Callback {
25  public:
26   virtual ~CallbackImpl() = default;
27   ShapeVector GetInputShape(const AnfNodePtr &node, size_t i) override;
28   ShapeVector GetOutputShape(const AnfNodePtr &node, size_t i) override;
29   ShapeVector GetInputInferShape(const AnfNodePtr &node, size_t i) override;
30   ShapeVector GetOutputInferShape(const AnfNodePtr &node, size_t i) override;
31   std::string GetInputFormat(const AnfNodePtr &node, size_t i) override;
32   std::string GetOutputFormat(const AnfNodePtr &node, size_t i) override;
33   std::string GetProcessor(const AnfNodePtr &node) override;
34   TypeId GetInputType(const AnfNodePtr &node, size_t i) override;
35   TypeId GetOutputType(const AnfNodePtr &node, size_t i) override;
36   TypeId GetInputInferType(const AnfNodePtr &node, size_t i) override;
37   TypeId GetOutputInferType(const AnfNodePtr &node, size_t i) override;
38   void SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) override;
39   void SetBasicNodeKernelInfo(const AnfNodePtr &node, const std::vector<inner::NodeBase> &outputs_info) override;
40   void SetEmptyKernelInfo(const AnfNodePtr &node) override;
41   void ResetKernelInfo(const AnfNodePtr &node) override;
42   void ResetKernelInfoInputs(const AnfNodePtr &node, const std::vector<size_t> &indices) override;
43 
44  private:
45   std::string GetTargetFromContextImpl(bool detail) override;
46   void CollectInputTypesAndFormats(const AnfNodePtr &node, std::vector<TypeId> *input_types,
47                                    std::vector<std::string> *input_formats, bool is_basic_node = false);
48 };
49 class BACKEND_EXPORT CallbackImplWithInferShape : public CallbackImpl {
50  public:
51   ShapeVector GetInputShape(const AnfNodePtr &node, size_t i) override;
52   ShapeVector GetOutputShape(const AnfNodePtr &node, size_t i) override;
53   TypeId GetInputType(const AnfNodePtr &node, size_t i) override;
54   TypeId GetOutputType(const AnfNodePtr &node, size_t i) override;
55   std::string GetInputFormat(const AnfNodePtr &, size_t) override;
56   std::string GetOutputFormat(const AnfNodePtr &, size_t) override;
57   std::string GetProcessor(const AnfNodePtr &) override;
SetGraphKernelNodeKernelInfo(const AnfNodePtr &)58   void SetGraphKernelNodeKernelInfo(const AnfNodePtr &) override {}
59   void SetBasicNodeKernelInfo(const AnfNodePtr &node, const std::vector<inner::NodeBase> &outputs_info) override;
ResetKernelInfo(const AnfNodePtr & node)60   void ResetKernelInfo(const AnfNodePtr &node) override {}
IsUseDeviceInfo()61   bool IsUseDeviceInfo() override { return false; }
ResetKernelInfoInputs(const AnfNodePtr &,const std::vector<size_t> &)62   void ResetKernelInfoInputs(const AnfNodePtr &, const std::vector<size_t> &) override {}
63 };
64 }  // namespace mindspore::graphkernel
65 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADAPTER_CALLBACK_IMPL_H_
66