1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_GRAPH_BUILDER_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_GRAPH_BUILDER_H_ 18 19 #include <vector> 20 #include <memory> 21 #include <string> 22 23 #include "ir/dtype.h" 24 #include "ir/tensor.h" 25 #include "mindapi/base/type_id.h" 26 #include "backend/common/graph_kernel/model/lite_graph.h" 27 28 namespace mindspore::graphkernel::inner { 29 class GraphBuilder : public LiteGraph::GraphBuilderBase { 30 public: GraphBuilderBase(name)31 explicit GraphBuilder(const std::string &name = "") : GraphBuilderBase(name) {} 32 ~GraphBuilder() = default; Add(const NodePtr & lhs,const NodePtr & rhs)33 NodePtr Add(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("Add", {lhs, rhs}); } Sub(const NodePtr & lhs,const NodePtr & rhs)34 NodePtr Sub(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("Sub", {lhs, rhs}); } Mul(const NodePtr & lhs,const NodePtr & rhs)35 NodePtr Mul(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("Mul", {lhs, rhs}); } Div(const NodePtr & lhs,const NodePtr & rhs)36 NodePtr Div(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("RealDiv", {lhs, rhs}); } Greater(const NodePtr & lhs,const NodePtr & rhs)37 NodePtr Greater(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("Greater", {lhs, rhs}); } Less(const NodePtr & lhs,const NodePtr & rhs)38 NodePtr Less(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("Less", {lhs, rhs}); } GreaterEqual(const NodePtr & lhs,const NodePtr & rhs)39 NodePtr GreaterEqual(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("GreaterEqual", {lhs, rhs}); } LessEqual(const NodePtr & lhs,const NodePtr & rhs)40 NodePtr LessEqual(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("LessEqual", {lhs, rhs}); } Equal(const NodePtr & lhs,const NodePtr & rhs)41 NodePtr Equal(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("Equal", {lhs, rhs}); } LogicalOr(const NodePtr & lhs,const NodePtr & rhs)42 NodePtr LogicalOr(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("LogicalOr", {lhs, rhs}); } Assign(const NodePtr & lhs,const NodePtr & rhs)43 NodePtr Assign(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("Assign", {lhs, rhs}); } Select(const NodePtr & cond,const NodePtr & lhs,const NodePtr & rhs)44 NodePtr Select(const NodePtr &cond, const NodePtr &lhs, const NodePtr &rhs) const { 45 return Emit("Select", {cond, lhs, rhs}); 46 } 47 NodePtr MatMul(const NodePtr &lhs, const NodePtr &rhs, const TypeId &type_id = kNumberTypeFloat16, 48 const bool &transpose_a = false, const bool &transpose_b = false) const { 49 return Emit("MatMul", {lhs, rhs}, 50 {{"transpose_a", MakeValue(transpose_a)}, 51 {"transpose_x1", MakeValue(transpose_a)}, 52 {"transpose_b", MakeValue(transpose_b)}, 53 {"transpose_x2", MakeValue(transpose_b)}, 54 {"dst_type", TypeIdToType(type_id)}}); 55 } Neg(const NodePtr & input)56 NodePtr Neg(const NodePtr &input) const { return Emit("Neg", {input}); } Exp(const NodePtr & input)57 NodePtr Exp(const NodePtr &input) const { return Emit("Exp", {input}); } Abs(const NodePtr & input)58 NodePtr Abs(const NodePtr &input) const { return Emit("Abs", {input}); } Log(const NodePtr & input)59 NodePtr Log(const NodePtr &input) const { return Emit("Log", {input}); } Sqrt(const NodePtr & input)60 NodePtr Sqrt(const NodePtr &input) const { return Emit("Sqrt", {input}); } IsInf(const NodePtr & input)61 NodePtr IsInf(const NodePtr &input) const { return Emit("IsInf", {input}); } IsNan(const NodePtr & input)62 NodePtr IsNan(const NodePtr &input) const { return Emit("IsNan", {input}); } Reciprocal(const NodePtr & input)63 NodePtr Reciprocal(const NodePtr &input) const { return Emit("Reciprocal", {input}); } 64 NodePtr StridedSlice(const NodePtr &input, const std::vector<int64_t> &begin, const std::vector<int64_t> &end, 65 const std::vector<int64_t> &strides) const; 66 NodePtr Tanh(const NodePtr &input) const; TensorScatterAdd(const NodePtr & input,const NodePtr & indices,const NodePtr & update)67 NodePtr TensorScatterAdd(const NodePtr &input, const NodePtr &indices, const NodePtr &update) const { 68 return Emit("TensorScatterAdd", {input, indices, update}); 69 } CReal(const NodePtr & input)70 NodePtr CReal(const NodePtr &input) const { return Emit("CReal", {input}); } CImag(const NodePtr & input)71 NodePtr CImag(const NodePtr &input) const { return Emit("CImag", {input}); } Complex(const NodePtr & lhs,const NodePtr & rhs)72 NodePtr Complex(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("Complex", {lhs, rhs}); } Custom(const NodePtrList & inputs,const NodeBase & baseinfo,const std::string & func_name,const std::string & func_type,const std::string & func_source_str,const size_t & inplace_assign_output,const std::string & func_compile_attrs)73 NodePtr Custom(const NodePtrList &inputs, const NodeBase &baseinfo, const std::string &func_name, 74 const std::string &func_type, const std::string &func_source_str, const size_t &inplace_assign_output, 75 const std::string &func_compile_attrs) const { 76 std::string write_from_output_to_input = "0 " + std::to_string(inplace_assign_output); 77 return Op("Custom", baseinfo, inputs, 78 {{"func_name", MakeValue(func_name)}, 79 {"func_type", MakeValue(func_type)}, 80 {"func_source_str", MakeValue(func_source_str)}, 81 {"inplace_assign_output", MakeValue(write_from_output_to_input)}, 82 {"func_compile_attrs", MakeValue(func_compile_attrs)}}); 83 } Cast(const NodePtr & input,const TypeId & type_id)84 NodePtr Cast(const NodePtr &input, const TypeId &type_id) const { 85 return Emit("Cast", {input}, {{"dst_type", TypeIdToType(type_id)}}); 86 } Shape(const NodePtr & input)87 NodePtr Shape(const NodePtr &input) const { return Emit("Shape", {input}); } 88 NodePtr Reshape(const NodePtr &input, const ShapeVector &shape) const; 89 NodePtr BroadcastTo(const NodePtr &input, const ShapeVector &shape) const; 90 NodePtr Gather(const NodePtr ¶m, const NodePtr &indice, int64_t axis, int64_t batch_dims = 0) const; 91 NodePtr Concat(const NodePtrList &inputs, const int64_t &axis) const; 92 NodePtr Transpose(const NodePtr &input, const ShapeVector &perm) const; 93 94 NodePtr ReduceSum(const NodePtr &input, const std::vector<int64_t> &axis, const bool &keep_dims = false) const; 95 NodePtr ReduceMax(const NodePtr &input, const std::vector<int64_t> &axis, const bool &keep_dims = false) const; 96 NodePtr ReduceMin(const NodePtr &input, const std::vector<int64_t> &axis, const bool &keep_dims = false) const; 97 98 NodePtr TupleGetItem(const NodePtr &input, int64_t index) const; 99 100 template <typename T> Tensor(T input,const TypeId & type_id)101 NodePtr Tensor(T input, const TypeId &type_id) const { 102 tensor::TensorPtr const_tensor; 103 switch (type_id) { 104 case kNumberTypeBool: 105 const_tensor = std::make_shared<tensor::Tensor>(static_cast<bool>(input), TypeIdToType(type_id)); 106 break; 107 case kNumberTypeInt: 108 case kNumberTypeInt8: 109 case kNumberTypeInt16: 110 case kNumberTypeInt32: 111 case kNumberTypeInt64: 112 const_tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(input), TypeIdToType(type_id)); 113 break; 114 case kNumberTypeUInt: 115 case kNumberTypeUInt8: 116 case kNumberTypeUInt16: 117 case kNumberTypeUInt32: 118 case kNumberTypeUInt64: 119 const_tensor = std::make_shared<tensor::Tensor>(static_cast<uint64_t>(input), TypeIdToType(type_id)); 120 break; 121 case kNumberTypeFloat: 122 case kNumberTypeFloat16: 123 case kNumberTypeFloat32: 124 case kNumberTypeFloat64: 125 case kNumberTypeBFloat16: 126 const_tensor = std::make_shared<tensor::Tensor>(static_cast<double>(input), TypeIdToType(type_id)); 127 break; 128 default: 129 MS_LOG(EXCEPTION) << "The input data type should be int, uint, float or bool, But Get :" 130 << TypeIdToString(type_id); 131 } 132 return Value(const_tensor); 133 } 134 Tensor(std::vector<int64_t> input)135 NodePtr Tensor(std::vector<int64_t> input) const { 136 auto const_tensor = std::make_shared<tensor::Tensor>(input); 137 return Value(const_tensor); 138 } 139 140 template <typename T> Scalar(const T & input)141 NodePtr Scalar(const T &input) const { 142 auto const_scalar = MakeValue(input); 143 return std::make_shared<ConstScalarNode>(const_scalar); 144 } 145 146 template <typename T> Tuple(const std::vector<T> & input)147 NodePtr Tuple(const std::vector<T> &input) const { 148 auto const_tuple = MakeValue(input); 149 return std::make_shared<ConstTupleNode>(const_tuple, input.size()); 150 } 151 }; 152 } // namespace mindspore::graphkernel::inner 153 #endif 154