• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/kernel_compiler/tbe/tbe_json/tbe_json_utils.h"
18 #include "base/core_ops.h"
19 #include "backend/session/anf_runtime_algorithm.h"
20 #include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
21 #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
22 #include "runtime/dev.h"
23 #include "utils/json_operation_utils.h"
24 
25 namespace mindspore::kernel {
GetInputsRealNum(const AnfNodePtr & anf_node,const std::vector<OpIOInfoPtr> & inputs_ptr,std::vector<size_t> * inputs_num)26 bool TbeJsonUtils::GetInputsRealNum(const AnfNodePtr &anf_node, const std::vector<OpIOInfoPtr> &inputs_ptr,
27                                     std::vector<size_t> *inputs_num) {
28   MS_EXCEPTION_IF_NULL(anf_node);
29   MS_EXCEPTION_IF_NULL(inputs_num);
30   auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
31   // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
32   auto dyn_input_sizes_ptr = primitive->GetAttr(kAttrDynInputSizes);
33   std::vector<int64_t> dyn_input_sizes = (dyn_input_sizes_ptr != nullptr)
34                                            ? GetValue<const std::vector<int64_t>>(dyn_input_sizes_ptr)
35                                            : std::vector<int64_t>{};
36   size_t dyn_input_index = 0;
37   for (const auto &input_ptr : inputs_ptr) {
38     if (input_ptr->param_type() == kJParamDynamic) {
39       if (dyn_input_index >= dyn_input_sizes.size()) {
40         MS_LOG(ERROR) << "Dyn input index" << dyn_input_index << "is over dyn input num" << dyn_input_sizes.size();
41         return false;
42       } else {
43         (*inputs_num).emplace_back(LongToSize(dyn_input_sizes[dyn_input_index]));
44         dyn_input_index++;
45       }
46     } else {
47       (*inputs_num).emplace_back(1);
48     }
49   }
50   return true;
51 }
52 
GetOutputsRealNum(const AnfNodePtr & anf_node,const std::vector<OpIOInfoPtr> & outputs_ptr,std::vector<size_t> * outputs_num)53 bool TbeJsonUtils::GetOutputsRealNum(const AnfNodePtr &anf_node, const std::vector<OpIOInfoPtr> &outputs_ptr,
54                                      std::vector<size_t> *outputs_num) {
55   MS_EXCEPTION_IF_NULL(anf_node);
56   size_t real_output_num = AnfAlgo::GetOutputTensorNum(anf_node);
57   for (const auto &output_ptr : outputs_ptr) {
58     if (output_ptr->param_type() == kJParamDynamic) {
59       if (outputs_ptr.size() > 1) {
60         MS_LOG(ERROR) << "Dynamic output is unsupported multi output, node [ " << AnfAlgo::GetCNodeName(anf_node)
61                       << " ] has " << outputs_ptr.size() << "outputs, however one of the outputs param_type is "
62                       << output_ptr->param_type();
63         return false;
64       }
65       outputs_num->emplace_back(real_output_num);
66     } else {
67       outputs_num->emplace_back(1);
68     }
69   }
70   return true;
71 }
72 
IsNeedChangeDefaultFormat(const AnfNodePtr & anf_node)73 bool TbeJsonUtils::IsNeedChangeDefaultFormat(const AnfNodePtr &anf_node) {
74   MS_EXCEPTION_IF_NULL(anf_node);
75   return anf_node->isa<CNode>() && AnfAlgo::HasNodeAttr(kAttrFormat, anf_node->cast<CNodePtr>()) &&
76          AnfAlgo::GetNodeAttr<std::string>(anf_node, kAttrFormat) == kOpFormat_NCDHW;
77 }
78 
GetInputOriShapeForTbeBuild(const AnfNodePtr & anf_node,size_t real_idx)79 std::vector<int64_t> TbeJsonUtils::GetInputOriShapeForTbeBuild(const AnfNodePtr &anf_node, size_t real_idx) {
80   MS_EXCEPTION_IF_NULL(anf_node);
81   session::KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, real_idx);
82   return GetOutputOriShapeForTbeBuild(kernel_with_index.first, kernel_with_index.second);
83 }
84 
GetInputDeviceShapeForTbeBuild(const AnfNodePtr & anf_node,size_t real_idx)85 std::vector<int64_t> TbeJsonUtils::GetInputDeviceShapeForTbeBuild(const AnfNodePtr &anf_node, size_t real_idx) {
86   MS_EXCEPTION_IF_NULL(anf_node);
87   std::vector<int64_t> shape;
88   session::KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, real_idx);
89   auto format = AnfAlgo::GetInputFormat(anf_node, real_idx);
90   shape = AnfAlgo::GetOutputDeviceShapeForTbeBuild(kernel_with_index.first, kernel_with_index.second, format);
91   if (shape.empty()) {
92     shape.emplace_back(1);
93   }
94   return shape;
95 }
96 
GetOutputOriShapeForTbeBuild(const AnfNodePtr & anf_node,size_t real_idx)97 std::vector<int64_t> TbeJsonUtils::GetOutputOriShapeForTbeBuild(const AnfNodePtr &anf_node, size_t real_idx) {
98   MS_EXCEPTION_IF_NULL(anf_node);
99   std::vector<int64_t> shape;
100   auto out_shape = AnfAlgo::GetOutputDetailShape(anf_node, real_idx);
101   MS_EXCEPTION_IF_NULL(out_shape);
102   if (out_shape->isa<abstract::Shape>()) {
103     auto shape_ptr = out_shape->cast<abstract::ShapePtr>();
104     MS_EXCEPTION_IF_NULL(shape_ptr);
105     shape = shape_ptr->shape();
106   }
107   if (shape.empty()) {
108     shape.emplace_back(1);
109   }
110   return shape;
111 }
112 
GetOutputDeviceShapeForTbeBuild(const AnfNodePtr & anf_node,size_t real_idx)113 std::vector<int64_t> TbeJsonUtils::GetOutputDeviceShapeForTbeBuild(const AnfNodePtr &anf_node, size_t real_idx) {
114   MS_EXCEPTION_IF_NULL(anf_node);
115   std::vector<int64_t> shape;
116   auto format = AnfAlgo::GetOutputFormat(anf_node, real_idx);
117   shape = AnfAlgo::GetOutputDeviceShapeForTbeBuild(anf_node, real_idx, format);
118   if (shape.empty()) {
119     shape.emplace_back(1);
120   }
121   return shape;
122 }
123 }  // namespace mindspore::kernel
124