1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONVERT_TENSOR_ELIMINATE_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONVERT_TENSOR_ELIMINATE_H_ 19 20 #include "frontend/optimizer/anf_visitor.h" 21 #include "mindspore/core/ops/structure_ops.h" 22 #include "frontend/optimizer/irpass.h" 23 #include "frontend/optimizer/optimizer.h" 24 #include "pipeline/jit/ps/static_analysis/prim.h" 25 26 namespace mindspore { 27 namespace opt { 28 namespace irpass { 29 class ConvertTensorEliminate : public AnfVisitor { 30 public: operator()31 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 32 auto fg = node->func_graph(); 33 MS_EXCEPTION_IF_NULL(fg); 34 auto cnode = node->cast<CNodePtr>(); 35 MS_EXCEPTION_IF_NULL(cnode); 36 constexpr size_t tensor_index = 1; 37 auto x = cnode->input(tensor_index); 38 if (IsPrimitiveCNode(node, prim::kPrimConvertToAdapterTensor)) { 39 // {prim::kPrimConvertToAdapterTensor, {prim::kPrimConvertToMsTensor, inp}} -> 40 // {prim::kPrimConvertToAdapterTensor, inp} 41 if (IsPrimitiveCNode(x, prim::kPrimConvertToMsTensor)) { 42 auto x_cnode = x->cast<CNodePtr>(); 43 auto inp = x_cnode->input(tensor_index); 44 auto new_node = fg->NewCNode({NewValueNode(prim::kPrimConvertToAdapterTensor), inp}); 45 new_node->set_abstract(node->abstract()); 46 return new_node; 47 } 48 } 49 if (IsPrimitiveCNode(x, prim::kPrimConvertToMsTensor)) { 50 // {prim::kPrimConvertToMsTensor, {prim::kPrimConvertToAdapterTensor, inp}} -> 51 // {prim::kPrimConvertToMsTensor, inp} 52 if (IsPrimitiveCNode(x, prim::kPrimConvertToAdapterTensor)) { 53 auto x_cnode = x->cast<CNodePtr>(); 54 auto inp = x_cnode->input(tensor_index); 55 auto new_node = fg->NewCNode({NewValueNode(prim::kPrimConvertToMsTensor), inp}); 56 new_node->set_abstract(node->abstract()); 57 return new_node; 58 } 59 } 60 return nullptr; 61 } 62 }; 63 64 class ConvertTensorAllEliminate : public AnfVisitor { 65 public: 66 // {prim::kPrimConvertToAdapterTensor, x} -> x 67 // {prim::kPrimConvertToMsTensor, x} -> x operator()68 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 69 if (!IsPrimitiveCNode(node, prim::kPrimConvertToAdapterTensor) && 70 !IsPrimitiveCNode(node, prim::kPrimConvertToMsTensor)) { 71 return nullptr; 72 } 73 auto cnode = node->cast<CNodePtr>(); 74 MS_EXCEPTION_IF_NULL(cnode); 75 constexpr size_t tensor_index = 1; 76 auto tensor = cnode->input(tensor_index); 77 tensor->set_abstract(node->abstract()); 78 return tensor; 79 } 80 }; 81 } // namespace irpass 82 } // namespace opt 83 } // namespace mindspore 84 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONVERT_TENSOR_ELIMINATE_H_ 85