• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONVERT_TENSOR_ELIMINATE_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONVERT_TENSOR_ELIMINATE_H_
19 
20 #include "frontend/optimizer/anf_visitor.h"
21 #include "mindspore/core/ops/structure_ops.h"
22 #include "frontend/optimizer/irpass.h"
23 #include "frontend/optimizer/optimizer.h"
24 #include "pipeline/jit/ps/static_analysis/prim.h"
25 
26 namespace mindspore {
27 namespace opt {
28 namespace irpass {
29 class ConvertTensorEliminate : public AnfVisitor {
30  public:
operator()31   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
32     auto fg = node->func_graph();
33     MS_EXCEPTION_IF_NULL(fg);
34     auto cnode = node->cast<CNodePtr>();
35     MS_EXCEPTION_IF_NULL(cnode);
36     constexpr size_t tensor_index = 1;
37     auto x = cnode->input(tensor_index);
38     if (IsPrimitiveCNode(node, prim::kPrimConvertToAdapterTensor)) {
39       // {prim::kPrimConvertToAdapterTensor, {prim::kPrimConvertToMsTensor, inp}} ->
40       // {prim::kPrimConvertToAdapterTensor, inp}
41       if (IsPrimitiveCNode(x, prim::kPrimConvertToMsTensor)) {
42         auto x_cnode = x->cast<CNodePtr>();
43         auto inp = x_cnode->input(tensor_index);
44         auto new_node = fg->NewCNode({NewValueNode(prim::kPrimConvertToAdapterTensor), inp});
45         new_node->set_abstract(node->abstract());
46         return new_node;
47       }
48     }
49     if (IsPrimitiveCNode(x, prim::kPrimConvertToMsTensor)) {
50       // {prim::kPrimConvertToMsTensor, {prim::kPrimConvertToAdapterTensor, inp}} ->
51       // {prim::kPrimConvertToMsTensor, inp}
52       if (IsPrimitiveCNode(x, prim::kPrimConvertToAdapterTensor)) {
53         auto x_cnode = x->cast<CNodePtr>();
54         auto inp = x_cnode->input(tensor_index);
55         auto new_node = fg->NewCNode({NewValueNode(prim::kPrimConvertToMsTensor), inp});
56         new_node->set_abstract(node->abstract());
57         return new_node;
58       }
59     }
60     return nullptr;
61   }
62 };
63 
64 class ConvertTensorAllEliminate : public AnfVisitor {
65  public:
66   // {prim::kPrimConvertToAdapterTensor, x} -> x
67   // {prim::kPrimConvertToMsTensor, x} -> x
operator()68   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
69     if (!IsPrimitiveCNode(node, prim::kPrimConvertToAdapterTensor) &&
70         !IsPrimitiveCNode(node, prim::kPrimConvertToMsTensor)) {
71       return nullptr;
72     }
73     auto cnode = node->cast<CNodePtr>();
74     MS_EXCEPTION_IF_NULL(cnode);
75     constexpr size_t tensor_index = 1;
76     auto tensor = cnode->input(tensor_index);
77     tensor->set_abstract(node->abstract());
78     return tensor;
79   }
80 };
81 }  // namespace irpass
82 }  // namespace opt
83 }  // namespace mindspore
84 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONVERT_TENSOR_ELIMINATE_H_
85