• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <string>
18 #include <set>
19 #include <vector>
20 #include <algorithm>
21 #include <memory>
22 #include "ops/op_utils.h"
23 #include "abstract/primitive_infer_map.h"
24 #include "utils/check_convert_utils.h"
25 
26 namespace mindspore {
27 namespace ops {
CalBroadCastShape(std::vector<int64_t> x_shape,std::vector<int64_t> y_shape,const std::string & op_name,const std::string & op_x_name,const std::string & op_y_name)28 std::vector<int64_t> CalBroadCastShape(std::vector<int64_t> x_shape, std::vector<int64_t> y_shape,
29                                        const std::string &op_name, const std::string &op_x_name,
30                                        const std::string &op_y_name) {
31   if (x_shape == y_shape) {
32     return x_shape;
33   }
34   auto x_length = static_cast<int64_t>(x_shape.size());
35   auto y_length = static_cast<int64_t>(y_shape.size());
36   auto length = x_length < y_length ? x_length : y_length;
37   std::vector<int64_t> broadcast_shape;
38   if (x_length == length) {
39     (void)std::copy(y_shape.begin(), y_shape.end() - length, std::back_inserter(broadcast_shape));
40   } else {
41     (void)std::copy(x_shape.begin(), x_shape.end() - length, std::back_inserter(broadcast_shape));
42   }
43   for (int64_t i = -length; i < 0; i++) {
44     if (x_shape[LongToSize(x_length + i)] == 1) {
45       broadcast_shape.push_back(y_shape[LongToSize(y_length + i)]);
46     } else if (y_shape[LongToSize(y_length + i)] == 1) {
47       broadcast_shape.push_back(x_shape[LongToSize(x_length + i)]);
48     } else if (x_shape[LongToSize(x_length + i)] == y_shape[LongToSize(y_length + i)]) {
49       broadcast_shape.push_back(x_shape[LongToSize(x_length + i)]);
50     } else {
51       MS_EXCEPTION(ValueError) << "For op " << op_name << ", the two input '" << op_x_name << "' and '" << op_y_name
52                                << "' can not broadcast";
53     }
54   }
55   return broadcast_shape;
56 }
BroadCastInferShape(const std::string & op_name,const std::vector<AbstractBasePtr> & input_args)57 abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args) {
58   MS_LOG(INFO) << "Do infer shape for op " << op_name;
59   const int64_t input_num = 2;
60   CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
61   auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack());
62   auto y_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack());
63   auto x_shape = x_shape_map[kShape];
64   auto y_shape = y_shape_map[kShape];
65   auto x_min_shape = x_shape_map[kMinShape];
66   auto x_max_shape = x_shape_map[kMaxShape];
67   auto y_min_shape = y_shape_map[kMinShape];
68   auto y_max_shape = y_shape_map[kMaxShape];
69   if (x_shape == y_shape) {
70     return std::make_shared<abstract::Shape>(x_shape, x_min_shape, x_max_shape);
71   }
72   auto broadcast_shape = CalBroadCastShape(x_shape, y_shape, op_name);
73   auto min_broadcast_shape = CalBroadCastShape(x_min_shape, y_min_shape, op_name);
74   auto max_broadcast_shape = CalBroadCastShape(x_max_shape, y_max_shape, op_name);
75   return std::make_shared<abstract::Shape>(broadcast_shape, min_broadcast_shape, max_broadcast_shape);
76 }
77 }  // namespace ops
78 }  // namespace mindspore
79