1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #ifndef MINDSPORE_CORE_ABSTRACT_UTILS_H_
20 #define MINDSPORE_CORE_ABSTRACT_UTILS_H_
21
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "abstract/abstract_value.h"
#include "utils/any.h"
#include "utils/misc.h"
#include "utils/shape_utils.h"
#include "mindapi/base/macros.h"
32
33 namespace mindspore {
34 namespace abstract {
35 MS_CORE_API ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2);
36 MS_CORE_API TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2);
37 MS_CORE_API ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2);
38
39 MS_CORE_API AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_abs_list);
40 MS_CORE_API AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const AbstractBasePtrList &spec2);
41 MS_CORE_API AbstractBasePtr AbstractBroaden(const AbstractBasePtr &abs);
42
43 // Return an abstract value for the sensitivity of x.
44 // The sensitivity of a function is an Env
45 // The sensitivity of J(x) is x
46 // else self.Clone;
47 MS_CORE_API AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec);
48
49 ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy);
50 MS_CORE_API size_t TypeIdSize(const TypeId data_type);
// Returns the product of all dimensions in `shape`, i.e. its element count.
//
// The fold starts from 1, so an empty shape (a scalar) yields 1 and any zero
// dimension yields 0. No validation is done: negative entries are multiplied
// as-is, and the product is not checked for overflow (signed overflow is UB).
//
// T: element type of the shape vector (normally an integer type).
// shape: the dimension sizes.
// Requires <numeric> (std::accumulate) and <functional> (std::multiplies).
template <typename T>
T ShapeSize(const std::vector<T> &shape) {
  return std::accumulate(shape.begin(), shape.end(), static_cast<T>(1), std::multiplies<T>());
}
55
56 MS_CORE_API AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type);
57 MS_CORE_API AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type);
58 MS_CORE_API AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type);
59 MS_CORE_API std::vector<FuncGraphPtr> GetFuncGraphsFromCallNode(const CNodePtr &call_node);
60
61 MS_CORE_API void SetVariableFlag(const AbstractBasePtr &abs);
62
63 class MS_CORE_API EnvSetSparseResultMgr {
64 public:
GetInstance()65 static EnvSetSparseResultMgr &GetInstance() noexcept {
66 static EnvSetSparseResultMgr instance;
67 return instance;
68 }
69 EnvSetSparseResultMgr(const EnvSetSparseResultMgr &) = delete;
70 EnvSetSparseResultMgr &operator=(const EnvSetSparseResultMgr &) = delete;
71 ~EnvSetSparseResultMgr() = default;
72
Get()73 bool Get() const { return env_set_sparse_result_; }
Set(bool env_set_sparse_result)74 void Set(bool env_set_sparse_result) { env_set_sparse_result_ = env_set_sparse_result; }
75
76 private:
77 EnvSetSparseResultMgr() = default;
78 bool env_set_sparse_result_{false};
79 };
80 } // namespace abstract
81 } // namespace mindspore
82 #endif // MINDSPORE_CORE_ABSTRACT_UTILS_H_
83