/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "include/common/fallback.h"

#include <queue>

#include "include/common/utils/python_adapter.h"
#include "utils/log_adapter.h"
#include "utils/ms_context.h"
#include "utils/phase.h"

namespace mindspore {
namespace fallback {
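// FIFO buffer that holds the Python objects produced when PyExecute nodes run; outputs are pushed
// with PushPyExecuteOutput and consumed in order with PopPyExecuteOutput.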
static std::queue<py::object> py_execute_output_queue = std::queue<py::object>();

bool HasPyExecuteOutput() { return !py_execute_output_queue.empty(); }

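// Removes and returns the earliest pushed PyExecute output. Callers are expected to check
// HasPyExecuteOutput() first, since popping from an empty queue is undefined behavior.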
py::object PopPyExecuteOutput() {
  auto output = py_execute_output_queue.front();
  MS_LOG(DEBUG) << "output: " << output;
  py_execute_output_queue.pop();
  return output;
}

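// Buffers a PyExecute output so that it can later be retrieved in FIFO order.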
void PushPyExecuteOutput(const py::object &output) {
  MS_LOG(DEBUG) << "output: " << output;
  py_execute_output_queue.push(output);
}

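// Resolves the effective JIT syntax level (kStrict, kCompatible or kLax). Precedence: the
// MS_DEV_JIT_SYNTAX_LEVEL environment variable, then the "jit_syntax_level" entry of the current
// jit_config, then the MS_CTX_JIT_SYNTAX_LEVEL context parameter.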
int GetJitSyntaxLevel() {
  // Get jit_syntax_level from the environment variable 'MS_DEV_JIT_SYNTAX_LEVEL'.
  std::string env_level_str = common::GetEnv("MS_DEV_JIT_SYNTAX_LEVEL");
  if (env_level_str.size() == 1) {
    int env_level = -1;
    try {
      env_level = std::stoi(env_level_str);
    } catch (const std::invalid_argument &ia) {
      MS_LOG(EXCEPTION) << "Invalid argument: " << ia.what() << " when parsing " << env_level_str;
    }
    if (env_level >= kStrict && env_level <= kLax) {
      return env_level;
    }
  }
  if (!env_level_str.empty()) {
    MS_LOG(EXCEPTION) << "JIT syntax level should be a number in the range 0 to 2, but got " << env_level_str;
  }

  // Otherwise, get jit_syntax_level from jit_config.
  const auto &jit_config = PhaseManager::GetInstance().jit_config();
  auto iter = jit_config.find("jit_syntax_level");
  if (iter != jit_config.end()) {
    auto level = iter->second;
    if (level == "STRICT") {
      return kStrict;
    } else if (level == "COMPATIBLE") {
      return kCompatible;
    } else if (level == "LAX") {
      return kLax;
    }
  }
  // Finally, fall back to the jit_syntax_level from the context.
  return MsContext::GetInstance()->get_param<int>(MS_CTX_JIT_SYNTAX_LEVEL);
}

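// Returns the element types of a List or Tuple type; any other type triggers an exception via the
// null check on the Tuple cast.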
TypePtrList GetTypeElements(const TypePtr &type) {
  MS_EXCEPTION_IF_NULL(type);
  if (type->isa<List>()) {
    auto type_list = type->cast_ptr<List>();
    return type_list->elements();
  }
  auto type_tuple = type->cast_ptr<Tuple>();
  MS_EXCEPTION_IF_NULL(type_tuple);
  return type_tuple->elements();
}

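// Builds the abstract value of a list/tuple PyExecute output from paired shape and type information.
// Elements with a concrete Number or TensorType keep their inferred shape and type; all other
// elements become AbstractAny on the frontend, or a fixed float64 shape-[1] abstract on the backend.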
abstract::AbstractSequencePtr GenerateAbstractSequence(const BaseShapePtr &base_shape, const TypePtr &type,
                                                       bool is_frontend) {
  // Generate an AbstractSequence for a PyExecute node.
  MS_EXCEPTION_IF_NULL(base_shape);
  MS_EXCEPTION_IF_NULL(type);
  bool is_list = base_shape->isa<abstract::ListShape>() && type->isa<List>();
  bool is_tuple = base_shape->isa<abstract::TupleShape>() && type->isa<Tuple>();
  if (!is_list && !is_tuple) {
    MS_INTERNAL_EXCEPTION(TypeError) << "For GenerateAbstractSequence, the input shape and type should both be "
                                     << "list or tuple, but got shape: " << base_shape->ToString()
                                     << " and type: " << type->ToString();
  }
  auto shape_seq = base_shape->cast_ptr<abstract::SequenceShape>();
  MS_EXCEPTION_IF_NULL(shape_seq);
  const auto &type_elements = GetTypeElements(type);
  if (shape_seq->size() != type_elements.size()) {
    MS_INTERNAL_EXCEPTION(ValueError) << "For GenerateAbstractSequence, the shape size and type size should be equal, "
                                      << "but got shape size: " << shape_seq->size()
                                      << " and type size: " << type_elements.size();
  }
  AbstractBasePtrList ptr_list;
  for (size_t it = 0; it < shape_seq->size(); ++it) {
    auto element_shape = (*shape_seq)[it];
    auto element_type = type_elements[it];
    bool is_external = element_type->isa<External>();
    bool is_tensor_or_scalar = element_type->isa<Number>() || element_type->isa<TensorType>();
    if (!is_external && is_tensor_or_scalar) {
      (void)ptr_list.emplace_back(abstract::MakeAbstract(element_shape, element_type));
    } else {
      if (is_frontend) {
        (void)ptr_list.emplace_back(std::make_shared<abstract::AbstractAny>());
      } else {
        // In the backend, the type and shape must be fixed, so use a float64 abstract with shape [1].
        const auto &infer_shape = std::make_shared<abstract::Shape>(ShapeVector({1}));
        (void)ptr_list.emplace_back(abstract::MakeAbstract(infer_shape, kFloat64));
      }
    }
  }
  if (!is_frontend || is_tuple) {
    return std::make_shared<abstract::AbstractTuple>(ptr_list);
  }
  return std::make_shared<abstract::AbstractList>(ptr_list);
}
}  // namespace fallback
}  // namespace mindspore