1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/kernel_compiler/tbe/tbe_json/tbe_json_utils.h"
18 #include "base/core_ops.h"
19 #include "backend/session/anf_runtime_algorithm.h"
20 #include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
21 #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
22 #include "runtime/dev.h"
23 #include "utils/json_operation_utils.h"
24
25 namespace mindspore::kernel {
GetInputsRealNum(const AnfNodePtr & anf_node,const std::vector<OpIOInfoPtr> & inputs_ptr,std::vector<size_t> * inputs_num)26 bool TbeJsonUtils::GetInputsRealNum(const AnfNodePtr &anf_node, const std::vector<OpIOInfoPtr> &inputs_ptr,
27 std::vector<size_t> *inputs_num) {
28 MS_EXCEPTION_IF_NULL(anf_node);
29 MS_EXCEPTION_IF_NULL(inputs_num);
30 auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
31 // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
32 auto dyn_input_sizes_ptr = primitive->GetAttr(kAttrDynInputSizes);
33 std::vector<int64_t> dyn_input_sizes = (dyn_input_sizes_ptr != nullptr)
34 ? GetValue<const std::vector<int64_t>>(dyn_input_sizes_ptr)
35 : std::vector<int64_t>{};
36 size_t dyn_input_index = 0;
37 for (const auto &input_ptr : inputs_ptr) {
38 if (input_ptr->param_type() == kJParamDynamic) {
39 if (dyn_input_index >= dyn_input_sizes.size()) {
40 MS_LOG(ERROR) << "Dyn input index" << dyn_input_index << "is over dyn input num" << dyn_input_sizes.size();
41 return false;
42 } else {
43 (*inputs_num).emplace_back(LongToSize(dyn_input_sizes[dyn_input_index]));
44 dyn_input_index++;
45 }
46 } else {
47 (*inputs_num).emplace_back(1);
48 }
49 }
50 return true;
51 }
52
GetOutputsRealNum(const AnfNodePtr & anf_node,const std::vector<OpIOInfoPtr> & outputs_ptr,std::vector<size_t> * outputs_num)53 bool TbeJsonUtils::GetOutputsRealNum(const AnfNodePtr &anf_node, const std::vector<OpIOInfoPtr> &outputs_ptr,
54 std::vector<size_t> *outputs_num) {
55 MS_EXCEPTION_IF_NULL(anf_node);
56 size_t real_output_num = AnfAlgo::GetOutputTensorNum(anf_node);
57 for (const auto &output_ptr : outputs_ptr) {
58 if (output_ptr->param_type() == kJParamDynamic) {
59 if (outputs_ptr.size() > 1) {
60 MS_LOG(ERROR) << "Dynamic output is unsupported multi output, node [ " << AnfAlgo::GetCNodeName(anf_node)
61 << " ] has " << outputs_ptr.size() << "outputs, however one of the outputs param_type is "
62 << output_ptr->param_type();
63 return false;
64 }
65 outputs_num->emplace_back(real_output_num);
66 } else {
67 outputs_num->emplace_back(1);
68 }
69 }
70 return true;
71 }
72
IsNeedChangeDefaultFormat(const AnfNodePtr & anf_node)73 bool TbeJsonUtils::IsNeedChangeDefaultFormat(const AnfNodePtr &anf_node) {
74 MS_EXCEPTION_IF_NULL(anf_node);
75 return anf_node->isa<CNode>() && AnfAlgo::HasNodeAttr(kAttrFormat, anf_node->cast<CNodePtr>()) &&
76 AnfAlgo::GetNodeAttr<std::string>(anf_node, kAttrFormat) == kOpFormat_NCDHW;
77 }
78
GetInputOriShapeForTbeBuild(const AnfNodePtr & anf_node,size_t real_idx)79 std::vector<int64_t> TbeJsonUtils::GetInputOriShapeForTbeBuild(const AnfNodePtr &anf_node, size_t real_idx) {
80 MS_EXCEPTION_IF_NULL(anf_node);
81 session::KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, real_idx);
82 return GetOutputOriShapeForTbeBuild(kernel_with_index.first, kernel_with_index.second);
83 }
84
GetInputDeviceShapeForTbeBuild(const AnfNodePtr & anf_node,size_t real_idx)85 std::vector<int64_t> TbeJsonUtils::GetInputDeviceShapeForTbeBuild(const AnfNodePtr &anf_node, size_t real_idx) {
86 MS_EXCEPTION_IF_NULL(anf_node);
87 std::vector<int64_t> shape;
88 session::KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, real_idx);
89 auto format = AnfAlgo::GetInputFormat(anf_node, real_idx);
90 shape = AnfAlgo::GetOutputDeviceShapeForTbeBuild(kernel_with_index.first, kernel_with_index.second, format);
91 if (shape.empty()) {
92 shape.emplace_back(1);
93 }
94 return shape;
95 }
96
GetOutputOriShapeForTbeBuild(const AnfNodePtr & anf_node,size_t real_idx)97 std::vector<int64_t> TbeJsonUtils::GetOutputOriShapeForTbeBuild(const AnfNodePtr &anf_node, size_t real_idx) {
98 MS_EXCEPTION_IF_NULL(anf_node);
99 std::vector<int64_t> shape;
100 auto out_shape = AnfAlgo::GetOutputDetailShape(anf_node, real_idx);
101 MS_EXCEPTION_IF_NULL(out_shape);
102 if (out_shape->isa<abstract::Shape>()) {
103 auto shape_ptr = out_shape->cast<abstract::ShapePtr>();
104 MS_EXCEPTION_IF_NULL(shape_ptr);
105 shape = shape_ptr->shape();
106 }
107 if (shape.empty()) {
108 shape.emplace_back(1);
109 }
110 return shape;
111 }
112
GetOutputDeviceShapeForTbeBuild(const AnfNodePtr & anf_node,size_t real_idx)113 std::vector<int64_t> TbeJsonUtils::GetOutputDeviceShapeForTbeBuild(const AnfNodePtr &anf_node, size_t real_idx) {
114 MS_EXCEPTION_IF_NULL(anf_node);
115 std::vector<int64_t> shape;
116 auto format = AnfAlgo::GetOutputFormat(anf_node, real_idx);
117 shape = AnfAlgo::GetOutputDeviceShapeForTbeBuild(anf_node, real_idx, format);
118 if (shape.empty()) {
119 shape.emplace_back(1);
120 }
121 return shape;
122 }
123 } // namespace mindspore::kernel
124