• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/common/pynative/abstract_converter.h"
18 #include <vector>
19 #include "mindspore/core/abstract/abstract_value.h"
20 
21 namespace mindspore {
22 namespace pynative {
CacheAbstract(const AbstractBasePtr & abstract)23 void AbstractConverter::CacheAbstract(const AbstractBasePtr &abstract) { abstract_cache_.Push(abstract); }
24 
ConvertAbstract(const ValuePtr & t)25 AbstractBasePtr AbstractConverter::ConvertAbstract(const ValuePtr &t) {
26   if (t->isa<BaseTensor>()) {
27     auto tensor = t->cast<BaseTensorPtr>();
28     return ConvertAbstract(tensor);
29   } else if (t->isa<ValueTuple>()) {
30     auto tuple = t->cast<ValueTuplePtr>();
31     return ConvertAbstract(tuple);
32   } else {
33     return t->ToAbstract();
34   }
35 }
36 
37 // Tensor is held by Abstract, may lead to memory leak.
ConvertAbstract(const BaseTensorPtr & t)38 AbstractBasePtr AbstractConverter::ConvertAbstract(const BaseTensorPtr &t) {
39   auto abs = t->ToAbstract();
40   abs->set_value(kValueAny);
41   t->set_abstract(abs);
42   abstract_cache_.Push(abs);
43   return abs;
44 }
45 
ConvertAbstract(const ValueTuplePtr & t)46 AbstractBasePtr AbstractConverter::ConvertAbstract(const ValueTuplePtr &t) {
47   AbstractBasePtrList abs_list(t->value().size());
48   for (size_t i = 0; i < t->value().size(); ++i) {
49     auto &val = t->value()[i];
50     auto abs = val->ToAbstract();
51     if (val->isa<tensor::BaseTensor>()) {
52       abs->set_value(kValueAny);
53       auto tensor = val->cast<tensor::BaseTensorPtr>();
54       tensor->set_abstract(abs);
55       abstract_cache_.Push(abs);
56     }
57     abs_list[i] = abs;
58   }
59   return std::make_shared<abstract::AbstractTuple>(abs_list);
60 }
61 }  // namespace pynative
62 }  // namespace mindspore
63