• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_COMMON_EXPANDER_CORE_NODE_H_
18 #define MINDSPORE_CCSRC_COMMON_EXPANDER_CORE_NODE_H_
19 #include <vector>
20 #include <string>
21 #include <memory>
22 #include <utility>
23 #include "ir/anf.h"
24 #include "include/common/visible.h"
25 #include "include/common/utils/utils.h"
26 
27 namespace mindspore {
28 namespace expander {
29 class Emitter;
30 using DAttr = std::vector<std::pair<std::string, ValuePtr>>;
31 
32 class COMMON_EXPORT Node : public std::enable_shared_from_this<Node> {
33  public:
34   explicit Node(Emitter *emitter);
Node(Emitter * emitter,const ValuePtr & value)35   Node(Emitter *emitter, const ValuePtr &value) : emitter_(emitter), value_(value) {}
36   virtual ~Node() = default;
37 
get()38   virtual const AnfNodePtr &get() const { MS_EXCEPTION(NotImplementedError) << "Base Node not implement get() method"; }
39 
40   virtual InputType input_type();
41   virtual AbstractBasePtr abstract();
42 
SetValue(const ValuePtr & val)43   void SetValue(const ValuePtr &val) { value_ = val; }
Value()44   ValuePtr Value() { return value_; }
45   virtual ValuePtr BuildValue();
46   virtual bool HasAbstractValue();
47   virtual BaseShapePtr GetShape();
48   virtual TypePtr GetType();
49 
50   std::vector<int64_t> shape();
51   std::vector<std::vector<int64_t>> shapes();
52   TypePtr dtype();
53   std::vector<TypePtr> dtypes();
emitter()54   Emitter *emitter() { return emitter_; }
55   virtual std::string ToString() const;
set_debug_info(const std::string & debug_info)56   virtual void set_debug_info(const std::string &debug_info) {}
debug_info()57   virtual std::string debug_info() const { return ""; }
is_used_value()58   virtual bool is_used_value() const {
59     MS_EXCEPTION(NotImplementedError) << "Base Node not implement is_used_value() method";
60   }
need_compute_grad_out()61   virtual bool need_compute_grad_out() const { return true; }
62 
63  protected:
64   // hold the emitter who created this node.
65   Emitter *emitter_{nullptr};
66   // cache the output shape after first query
67   BaseShapePtr shape_{nullptr};
68   // cache the output dtype after first query
69   TypePtr type_{nullptr};
70   // cache the value of node
71   ValuePtr value_{nullptr};
72 };
73 using NodePtr = std::shared_ptr<Node>;
74 using NodePtrList = std::vector<NodePtr>;
75 
76 class COMMON_EXPORT IrNode : public Node {
77  public:
IrNode(const AnfNodePtr anfnode,Emitter * emitter)78   IrNode(const AnfNodePtr anfnode, Emitter *emitter) : Node(emitter), anf_node_(anfnode) {}
get()79   const AnfNodePtr &get() const override { return anf_node_; }
80   InputType input_type() override;
81   AbstractBasePtr abstract() override;
82 
83   ValuePtr BuildValue() override;
84   bool HasAbstractValue() override;
85   BaseShapePtr GetShape() override;
86   TypePtr GetType() override;
87 
88   std::string ToString() const override;
89   void set_debug_info(const std::string &debug_info) override;
90   std::string debug_info() const override;
is_used_value()91   bool is_used_value() const override { return is_used_value_; }
92 
93  private:
94   // the wrapped anfnode.
95   AnfNodePtr anf_node_{nullptr};
96   // whether use value
97   bool is_used_value_{false};
98 };
99 using IrNodePtr = std::shared_ptr<IrNode>;
100 
101 class COMMON_EXPORT FuncNode : public Node {
102  public:
FuncNode(const ValuePtr & value,const abstract::AbstractBasePtr & abs,InputType input_type,Emitter * emitter)103   FuncNode(const ValuePtr &value, const abstract::AbstractBasePtr &abs, InputType input_type, Emitter *emitter)
104       : Node(emitter, value), abstract_(abs), input_type_(input_type) {}
105   ValuePtr BuildValue() override;
106   InputType input_type() override;
set_node_type(InputType input_type)107   void set_node_type(InputType input_type) { input_type_ = input_type; }
108   AbstractBasePtr abstract() override;
set_abstract(const AbstractBasePtr & abs)109   void set_abstract(const AbstractBasePtr &abs) { abstract_ = abs; }
110   BaseShapePtr GetShape() override;
111   TypePtr GetType() override;
ToString()112   std::string ToString() const override { return value_->ToString(); }
set_debug_info(const std::string & debug_info)113   void set_debug_info(const std::string &debug_info) override {}
debug_info()114   std::string debug_info() const override { return ""; }
need_compute_grad_out()115   bool need_compute_grad_out() const override { return need_compute_grad_out_; }
set_need_compute_grad_out(bool need_compute_grad_out)116   void set_need_compute_grad_out(bool need_compute_grad_out) { need_compute_grad_out_ = need_compute_grad_out; }
117 
118  private:
119   AbstractBasePtr abstract_;
120   InputType input_type_;
121   bool need_compute_grad_out_{true};
122 };
123 using FuncNodePtr = std::shared_ptr<FuncNode>;
124 }  // namespace expander
125 }  // namespace mindspore
126 #endif  // MINDSPORE_CCSRC_COMMON_EXPANDER_CORE_NODE_H_
127