/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_GRAPH_BUILDER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_GRAPH_BUILDER_H_

#include <vector>
#include <memory>
#include <string>

#include "ir/dtype.h"
#include "ir/tensor.h"
#include "mindapi/base/type_id.h"
#include "backend/common/graph_kernel/model/lite_graph.h"

namespace mindspore::graphkernel::inner {
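// GraphBuilder extends GraphBuilderBase with typed helpers that emit common
// arithmetic, comparison, shape, and constant-creation nodes into a LiteGraph.
//
// Minimal usage sketch (assumes `a` and `b` are NodePtr inputs already obtained
// elsewhere, e.g. as graph parameters; the constant is created with Tensor()):
//   GraphBuilder gb("example");
//   auto one = gb.Tensor(1, kNumberTypeFloat32);
//   auto c = gb.Add(gb.Mul(a, b), one);  // c = a * b + 1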
class GraphBuilder : public LiteGraph::GraphBuilderBase {
 public:
  explicit GraphBuilder(const std::string &name = "") : GraphBuilderBase(name) {}
  ~GraphBuilder() = default;
  NodePtr Add(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("Add", {lhs, rhs}); }
  NodePtr Sub(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("Sub", {lhs, rhs}); }
  NodePtr Mul(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("Mul", {lhs, rhs}); }
  NodePtr Div(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("RealDiv", {lhs, rhs}); }
  NodePtr Greater(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("Greater", {lhs, rhs}); }
  NodePtr Less(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("Less", {lhs, rhs}); }
  NodePtr GreaterEqual(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("GreaterEqual", {lhs, rhs}); }
  NodePtr LessEqual(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("LessEqual", {lhs, rhs}); }
  NodePtr Equal(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("Equal", {lhs, rhs}); }
  NodePtr LogicalOr(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("LogicalOr", {lhs, rhs}); }
  NodePtr Assign(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("Assign", {lhs, rhs}); }
  NodePtr Select(const NodePtr &cond, const NodePtr &lhs, const NodePtr &rhs) const {
    return Emit("Select", {cond, lhs, rhs});
  }
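  // Emits a MatMul node; both the (transpose_a, transpose_b) and the legacy
  // (transpose_x1, transpose_x2) attribute names are set, and dst_type records
  // the output data type (float16 by default).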
  NodePtr MatMul(const NodePtr &lhs, const NodePtr &rhs, const TypeId &type_id = kNumberTypeFloat16,
                 const bool &transpose_a = false, const bool &transpose_b = false) const {
    return Emit("MatMul", {lhs, rhs},
                {{"transpose_a", MakeValue(transpose_a)},
                 {"transpose_x1", MakeValue(transpose_a)},
                 {"transpose_b", MakeValue(transpose_b)},
                 {"transpose_x2", MakeValue(transpose_b)},
                 {"dst_type", TypeIdToType(type_id)}});
  }
  NodePtr Neg(const NodePtr &input) const { return Emit("Neg", {input}); }
  NodePtr Exp(const NodePtr &input) const { return Emit("Exp", {input}); }
  NodePtr Abs(const NodePtr &input) const { return Emit("Abs", {input}); }
  NodePtr Log(const NodePtr &input) const { return Emit("Log", {input}); }
  NodePtr Sqrt(const NodePtr &input) const { return Emit("Sqrt", {input}); }
  NodePtr IsInf(const NodePtr &input) const { return Emit("IsInf", {input}); }
  NodePtr IsNan(const NodePtr &input) const { return Emit("IsNan", {input}); }
  NodePtr Reciprocal(const NodePtr &input) const { return Emit("Reciprocal", {input}); }
  NodePtr StridedSlice(const NodePtr &input, const std::vector<int64_t> &begin, const std::vector<int64_t> &end,
                       const std::vector<int64_t> &strides) const;
  NodePtr Tanh(const NodePtr &input) const;
  NodePtr TensorScatterAdd(const NodePtr &input, const NodePtr &indices, const NodePtr &update) const {
    return Emit("TensorScatterAdd", {input, indices, update});
  }
  NodePtr CReal(const NodePtr &input) const { return Emit("CReal", {input}); }
  NodePtr CImag(const NodePtr &input) const { return Emit("CImag", {input}); }
  NodePtr Complex(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("Complex", {lhs, rhs}); }
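  // Emits a Custom op node whose kernel is described by the func_* attributes;
  // the inplace_assign_output index is encoded into the attribute string
  // "0 <inplace_assign_output>" (see write_from_output_to_input below).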
  NodePtr Custom(const NodePtrList &inputs, const NodeBase &baseinfo, const std::string &func_name,
                 const std::string &func_type, const std::string &func_source_str, const size_t &inplace_assign_output,
                 const std::string &func_compile_attrs) const {
    std::string write_from_output_to_input = "0 " + std::to_string(inplace_assign_output);
    return Op("Custom", baseinfo, inputs,
              {{"func_name", MakeValue(func_name)},
               {"func_type", MakeValue(func_type)},
               {"func_source_str", MakeValue(func_source_str)},
               {"inplace_assign_output", MakeValue(write_from_output_to_input)},
               {"func_compile_attrs", MakeValue(func_compile_attrs)}});
  }
  NodePtr Cast(const NodePtr &input, const TypeId &type_id) const {
    return Emit("Cast", {input}, {{"dst_type", TypeIdToType(type_id)}});
  }
  NodePtr Shape(const NodePtr &input) const { return Emit("Shape", {input}); }
  NodePtr Reshape(const NodePtr &input, const ShapeVector &shape) const;
  NodePtr BroadcastTo(const NodePtr &input, const ShapeVector &shape) const;
  NodePtr Gather(const NodePtr &param, const NodePtr &indice, int64_t axis, int64_t batch_dims = 0) const;
  NodePtr Concat(const NodePtrList &inputs, const int64_t &axis) const;
  NodePtr Transpose(const NodePtr &input, const ShapeVector &perm) const;

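  // Reduction ops over the given axes; when keep_dims is true the reduced
  // dimensions are kept with size 1.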
  NodePtr ReduceSum(const NodePtr &input, const std::vector<int64_t> &axis, const bool &keep_dims = false) const;
  NodePtr ReduceMax(const NodePtr &input, const std::vector<int64_t> &axis, const bool &keep_dims = false) const;
  NodePtr ReduceMin(const NodePtr &input, const std::vector<int64_t> &axis, const bool &keep_dims = false) const;

  NodePtr TupleGetItem(const NodePtr &input, int64_t index) const;

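  // Builds a constant tensor of the given type id from a C++ value. Bool,
  // signed/unsigned integer, and floating-point type ids are supported; any
  // other type id raises an exception.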
  template <typename T>
  NodePtr Tensor(T input, const TypeId &type_id) const {
    tensor::TensorPtr const_tensor;
    switch (type_id) {
      case kNumberTypeBool:
        const_tensor = std::make_shared<tensor::Tensor>(static_cast<bool>(input), TypeIdToType(type_id));
        break;
      case kNumberTypeInt:
      case kNumberTypeInt8:
      case kNumberTypeInt16:
      case kNumberTypeInt32:
      case kNumberTypeInt64:
        const_tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(input), TypeIdToType(type_id));
        break;
      case kNumberTypeUInt:
      case kNumberTypeUInt8:
      case kNumberTypeUInt16:
      case kNumberTypeUInt32:
      case kNumberTypeUInt64:
        const_tensor = std::make_shared<tensor::Tensor>(static_cast<uint64_t>(input), TypeIdToType(type_id));
        break;
      case kNumberTypeFloat:
      case kNumberTypeFloat16:
      case kNumberTypeFloat32:
      case kNumberTypeFloat64:
      case kNumberTypeBFloat16:
        const_tensor = std::make_shared<tensor::Tensor>(static_cast<double>(input), TypeIdToType(type_id));
        break;
      default:
        MS_LOG(EXCEPTION) << "The input data type should be int, uint, float or bool, but got: "
                          << TypeIdToString(type_id);
    }
    return Value(const_tensor);
  }

  NodePtr Tensor(std::vector<int64_t> input) const {
    auto const_tensor = std::make_shared<tensor::Tensor>(input);
    return Value(const_tensor);
  }

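  // Wraps a C++ value as a constant scalar node.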
  template <typename T>
  NodePtr Scalar(const T &input) const {
    auto const_scalar = MakeValue(input);
    return std::make_shared<ConstScalarNode>(const_scalar);
  }

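  // Wraps a vector of values as a constant tuple node of the same length.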
  template <typename T>
  NodePtr Tuple(const std::vector<T> &input) const {
    auto const_tuple = MakeValue(input);
    return std::make_shared<ConstTupleNode>(const_tuple, input.size());
  }
};
}  // namespace mindspore::graphkernel::inner
#endif