• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_SYMBOL_ENGINE_JIT_SYNTAX_H_
17 #define MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_SYMBOL_ENGINE_JIT_SYNTAX_H_
18 #include <string>
19 #include <memory>
20 #include <vector>
21 
22 namespace mindspore::graphkernel::symshape::ast {
/*
 Term       ::= SingleTerm | Shape
 Shape      ::= List[SingleTerm]
 SingleTerm ::= IntImm int | Symbol
 Symbol     ::= BinOp BinOpType Term Term   // binary shape function; BinOpType is a tag indicating the operation type
             |  Input int int               // taken from the shape of an input: input_i.shape[j]
             |  Var id                      // a symbolic variable identified by id
*/
31 
/// Tag stored in BinOp::optype_ that selects which binary operation a BinOp node denotes.
enum class BinOpType {
  ScalarAdd,
  ScalarSub,
  ScalarMul,
  ScalarDiv,
  ScalarMin,
  ScalarMax,
};
33 
class Visitor;

/// Root of the AST node hierarchy; every node type in this header derives from it.
struct Term {
  /// Virtual so that destroying a derived node through a Term pointer runs the
  /// derived destructor (nodes own std::string / shared_ptr / vector members).
  virtual ~Term() = default;

  /// Visitor entry point. The base implementation does nothing; node types
  /// that a Visitor can handle override it.
  virtual void Accept(Visitor *visitor) {}

  /// Returns a textual form of the node; the base returns the literal "Term".
  virtual std::string ToString() const { return "Term"; }
};
using TermPtr = std::shared_ptr<Term>;
41 
42 struct SingleTerm : public Term {
ToStringSingleTerm43   std::string ToString() const override { return "SingleTerm"; }
44 };
45 using SingleTermPtr = std::shared_ptr<SingleTerm>;
46 
47 struct IntImm : public SingleTerm {
48   int64_t shape_int;
IntImmIntImm49   explicit IntImm(int64_t i) : shape_int(i) {}
50 
51   void Accept(Visitor *visitor);
52   std::string ToString() const override;
53 };
54 
55 struct Symbol : public SingleTerm {
56   virtual void Accept(Visitor *visitor) = 0;
ToStringSymbol57   std::string ToString() const override { return "Var"; }
58 };
59 
60 struct Var : public Symbol {
61   size_t id_;
62   std::string name_;
63 
VarVar64   Var(size_t id, const std::string &name) : id_(id), name_(name) {}
65   void Accept(Visitor *visitor) override;
66   std::string ToString() const override;
67 };
68 using VarPtr = std::shared_ptr<Var>;
69 
70 struct BinOp : public Symbol {
71   BinOpType optype_;
72   TermPtr a_;
73   TermPtr b_;
74 
75   void Accept(Visitor *visitor) override;
76   std::string ToString() const override;
77 };
78 
79 struct Input : public Symbol {
80   int64_t i_, j_;
81 
InputInput82   Input(int64_t i, int64_t j) : i_(i), j_(j) {}
83   void Accept(Visitor *visitor) override;
84   std::string ToString() const override;
85 };
86 
87 struct Shape : public Term {
88   std::vector<SingleTermPtr> smbls_;
89   void Accept(Visitor *visitor) override;
90   std::string ToString() const override;
91 };
92 using ShapePtr = std::shared_ptr<Shape>;
93 
94 class Visitor {
95  public:
96   virtual void Visit(const IntImm &intimm) = 0;
97   virtual void Visit(const BinOp &op) = 0;
98   virtual void Visit(const Input &input) = 0;
99   virtual void Visit(const Var &val) = 0;
Visit(const Shape & shape)100   virtual void Visit(const Shape &shape) {}
101 };
102 
103 using SymbolTable = std::vector<TermPtr>;
104 }  // namespace mindspore::graphkernel::symshape::ast
105 
106 #endif  // MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_SYMBOL_ENGINE_JIT_SYNTAX_H_
107