• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CORE_ABSTRACT_UTILS_H_
20 #define MINDSPORE_CORE_ABSTRACT_UTILS_H_
21 
#include <vector>
#include <utility>
#include <memory>
#include <string>
#include <functional>
#include <numeric>
#include "abstract/abstract_value.h"
#include "utils/any.h"
#include "utils/misc.h"
#include "utils/shape_utils.h"
#include "mindapi/base/macros.h"
32 
33 namespace mindspore {
34 namespace abstract {
35 MS_CORE_API ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2);
36 MS_CORE_API TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2);
37 MS_CORE_API ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2);
38 
39 MS_CORE_API AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_abs_list);
40 MS_CORE_API AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const AbstractBasePtrList &spec2);
41 MS_CORE_API AbstractBasePtr AbstractBroaden(const AbstractBasePtr &abs);
42 
43 // Return an abstract value for the sensitivity of x.
44 // The sensitivity of a function is an Env
45 // The sensitivity of J(x) is x
46 // else self.Clone;
47 MS_CORE_API AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec);
48 
49 ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy);
50 MS_CORE_API size_t TypeIdSize(const TypeId data_type);
/// @brief Multiplies together all dimensions of `shape`.
/// @param shape Dimension sizes.
/// @return Product of every entry, computed in type T. An empty shape yields 1.
/// Entries are multiplied as-is: negative values and overflow of T are not checked.
/// Requires <numeric> for std::accumulate and <functional> for std::multiplies.
template <typename T>
T ShapeSize(const std::vector<T> &shape) {
  return std::accumulate(shape.begin(), shape.end(), static_cast<T>(1), std::multiplies<T>());
}
55 
56 MS_CORE_API AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type);
57 MS_CORE_API AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type);
58 MS_CORE_API AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type);
59 MS_CORE_API std::vector<FuncGraphPtr> GetFuncGraphsFromCallNode(const CNodePtr &call_node);
60 
61 MS_CORE_API void SetVariableFlag(const AbstractBasePtr &abs);
62 
63 class MS_CORE_API EnvSetSparseResultMgr {
64  public:
GetInstance()65   static EnvSetSparseResultMgr &GetInstance() noexcept {
66     static EnvSetSparseResultMgr instance;
67     return instance;
68   }
69   EnvSetSparseResultMgr(const EnvSetSparseResultMgr &) = delete;
70   EnvSetSparseResultMgr &operator=(const EnvSetSparseResultMgr &) = delete;
71   ~EnvSetSparseResultMgr() = default;
72 
Get()73   bool Get() const { return env_set_sparse_result_; }
Set(bool env_set_sparse_result)74   void Set(bool env_set_sparse_result) { env_set_sparse_result_ = env_set_sparse_result; }
75 
76  private:
77   EnvSetSparseResultMgr() = default;
78   bool env_set_sparse_result_{false};
79 };
80 }  // namespace abstract
81 }  // namespace mindspore
82 #endif  // MINDSPORE_CORE_ABSTRACT_UTILS_H_
83