• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2023 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_DATA_CONVERTER_H_
20 #define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_DATA_CONVERTER_H_
21 
22 #include <deque>
23 #include <memory>
24 #include <utility>
25 #include <vector>
26 #include <string>
27 #include "utils/ordered_map.h"
28 #include "utils/hash_map.h"
29 #include "pipeline/jit/ps/parse/parse_base.h"
30 #include "include/common/utils/python_adapter.h"
31 #include "utils/log_adapter.h"
32 #include "ops/op_def.h"
33 
34 namespace mindspore {
35 namespace parse {
36 // data convert for parse
37 namespace data_converter {
38 void CacheObjectValue(const std::string &obj_key, const ValuePtr &data);
39 bool GetObjectValue(const std::string &obj_key, ValuePtr *const data);
40 
41 void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data);
42 
43 const mindspore::OrderedMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs();
44 
45 std::vector<std::string> GetObjKey(const py::object &obj);
46 ResolveType GetObjType(const py::object &obj);
47 ClassInstanceType GetClassInstanceType(const py::object &obj);
48 
49 bool IsCellInstance(const py::object &obj);
50 bool IsNumpyArrayInstance(const py::object &obj);
51 bool IsMsClassInstance(const py::object &obj);
52 bool IsJITForbiddenAPI(const py::object &obj);
53 bool IsClassType(const py::object &obj);
54 py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs);
55 py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs);
56 py::set GetPythonScriptIdAttrs(const py::object &script);
57 void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name);
58 ValuePtr PyDataToValue(const py::object &obj);
59 ValuePtr PyDataToStubNode(const py::object &obj);
60 void ClearObjectCache();
61 }  // namespace data_converter
62 
63 class DataConverter {
64  public:
DataConverter(ValuePtrList args_value_list,bool use_signature)65   DataConverter(ValuePtrList args_value_list, bool use_signature)
66       : args_value_list_(std::move(args_value_list)),
67         use_signature_(use_signature),
68         dtype_(nullptr),
69         forbid_reuse_(false) {}
70 
71   virtual ~DataConverter() = default;
72 
73   ValuePtr ConvertData(const py::object &obj);
74 
75  private:
76   ValuePtrList args_value_list_;
77   bool use_signature_;
78   TypePtr dtype_;
79   bool forbid_reuse_;
80 };
81 
82 FuncGraphPtr ConvertToBpropCut(const py::object &obj);
83 constexpr int32_t kTypeShiftBits = 16;
84 constexpr auto kDstMask = (1 << kTypeShiftBits) - 1;
CombineTypesForTypeCast(const mindspore::ops::OP_DTYPE & src,const mindspore::ops::OP_DTYPE & dst)85 inline int32_t CombineTypesForTypeCast(const mindspore::ops::OP_DTYPE &src, const mindspore::ops::OP_DTYPE &dst) {
86   return (static_cast<int32_t>(src) << kTypeShiftBits) | static_cast<int32_t>(dst);
87 }
88 // using OpDefConvertFunc = std::function<ValuePtr(const py::object &obj)>;
89 typedef ValuePtr (*OpDefConvertFunc)(const py::object &);
90 OpDefConvertFunc GetConverterByType(int32_t dtype);
91 ValuePtr ConvertTensor(const py::object &obj);
92 template <typename TS, typename TD, OpDefConvertFunc func>
93 ValuePtr ConvertSequence(const py::object &obj);
94 tensor::TensorPtr ConvertTensorValue(const py::object &obj);
95 }  // namespace parse
96 }  // namespace mindspore
97 
98 #endif  // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_DATA_CONVERTER_H_
99