1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "include/common/fallback.h"
18
19 #include <queue>
20
21 #include "include/common/utils/python_adapter.h"
22 #include "utils/log_adapter.h"
23 #include "utils/ms_context.h"
24 #include "utils/phase.h"
25
26 namespace mindspore {
27 namespace fallback {
28 static std::queue<py::object> py_execute_output_queue = std::queue<py::object>();
29
HasPyExecuteOutput()30 bool HasPyExecuteOutput() { return !py_execute_output_queue.empty(); }
31
PopPyExecuteOutput()32 py::object PopPyExecuteOutput() {
33 auto output = py_execute_output_queue.front();
34 MS_LOG(DEBUG) << "output: " << output;
35 py_execute_output_queue.pop();
36 return output;
37 }
38
PushPyExecuteOutput(const py::object & output)39 void PushPyExecuteOutput(const py::object &output) {
40 MS_LOG(DEBUG) << "output: " << output;
41 py_execute_output_queue.push(output);
42 }
43
GetJitSyntaxLevel()44 int GetJitSyntaxLevel() {
45 // Get jit_syntax_level from environment variable 'MS_DEV_JIT_SYNTAX_LEVEL'.
46 std::string env_level_str = common::GetEnv("MS_DEV_JIT_SYNTAX_LEVEL");
47 if (env_level_str.size() == 1) {
48 int env_level = -1;
49 try {
50 env_level = std::stoi(env_level_str);
51 } catch (const std::invalid_argument &ia) {
52 MS_LOG(EXCEPTION) << "Invalid argument: " << ia.what() << " when parse " << env_level_str;
53 }
54 if (env_level >= kStrict && env_level <= kLax) {
55 return env_level;
56 }
57 }
58 if (!env_level_str.empty()) {
59 MS_LOG(EXCEPTION) << "JIT syntax level should be a number and from 0 to 2, but got " << env_level_str;
60 }
61
62 // Get jit_syntax_level from jit_config, default to an empty string.
63 const auto &jit_config = PhaseManager::GetInstance().jit_config();
64 auto iter = jit_config.find("jit_syntax_level");
65 if (iter != jit_config.end()) {
66 auto level = iter->second;
67 if (level == "STRICT") {
68 return kStrict;
69 } else if (level == "COMPATIBLE") {
70 return kCompatible;
71 } else if (level == "LAX") {
72 return kLax;
73 }
74 }
75 // Get jit_syntax_level from context.
76 return MsContext::GetInstance()->get_param<int>(MS_CTX_JIT_SYNTAX_LEVEL);
77 }
78
GetTypeElements(const TypePtr & type)79 TypePtrList GetTypeElements(const TypePtr &type) {
80 MS_EXCEPTION_IF_NULL(type);
81 if (type->isa<List>()) {
82 auto type_list = type->cast_ptr<List>();
83 return type_list->elements();
84 }
85 auto type_tuple = type->cast_ptr<Tuple>();
86 MS_EXCEPTION_IF_NULL(type_tuple);
87 return type_tuple->elements();
88 }
89
GenerateAbstractSequence(const BaseShapePtr & base_shape,const TypePtr & type,bool is_frontend)90 abstract::AbstractSequencePtr GenerateAbstractSequence(const BaseShapePtr &base_shape, const TypePtr &type,
91 bool is_frontend) {
92 // Generate AbstractSequence for PyExecute node.
93 MS_EXCEPTION_IF_NULL(base_shape);
94 MS_EXCEPTION_IF_NULL(type);
95 bool is_list = base_shape->isa<abstract::ListShape>() && type->isa<List>();
96 bool is_tuple = base_shape->isa<abstract::TupleShape>() && type->isa<Tuple>();
97 if (!is_list && !is_tuple) {
98 MS_INTERNAL_EXCEPTION(TypeError) << "For GenerateAbstractSequence, the input shape and type should be both "
99 << "list or tuple, but got shape: " << base_shape->ToString()
100 << " and type: " << type->ToString();
101 }
102 auto shape_seq = base_shape->cast_ptr<abstract::SequenceShape>();
103 MS_EXCEPTION_IF_NULL(shape_seq);
104 const auto &type_elements = GetTypeElements(type);
105 if (shape_seq->size() != type_elements.size()) {
106 MS_INTERNAL_EXCEPTION(ValueError) << "For GenerateAbstractSequence, the shape and type size should be the same, "
107 << "but got shape size: " << shape_seq->size()
108 << " and type size: " << type_elements.size();
109 }
110 AbstractBasePtrList ptr_list;
111 for (size_t it = 0; it < shape_seq->size(); ++it) {
112 auto element_shape = (*shape_seq)[it];
113 auto element_type = type_elements[it];
114 bool is_external = element_type->isa<External>();
115 bool is_tensor_or_scalar = element_type->isa<Number>() || element_type->isa<TensorType>();
116 if (!is_external && is_tensor_or_scalar) {
117 (void)ptr_list.emplace_back(abstract::MakeAbstract(element_shape, element_type));
118 } else {
119 if (is_frontend) {
120 (void)ptr_list.emplace_back(std::make_shared<abstract::AbstractAny>());
121 } else {
122 // In backend, the type is correctly fixed and the shape should be fixed.
123 const auto &infer_shape = std::make_shared<abstract::Shape>(ShapeVector({1}));
124 (void)ptr_list.emplace_back(abstract::MakeAbstract(infer_shape, kFloat64));
125 }
126 }
127 }
128 if (!is_frontend || is_tuple) {
129 return std::make_shared<abstract::AbstractTuple>(ptr_list);
130 }
131 return std::make_shared<abstract::AbstractList>(ptr_list);
132 }
133 } // namespace fallback
134 } // namespace mindspore
135