1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <string>
18 #include <set>
19 #include <vector>
20 #include <algorithm>
21 #include <memory>
22 #include "ops/op_utils.h"
23 #include "abstract/primitive_infer_map.h"
24 #include "utils/check_convert_utils.h"
25
26 namespace mindspore {
27 namespace ops {
CalBroadCastShape(std::vector<int64_t> x_shape,std::vector<int64_t> y_shape,const std::string & op_name,const std::string & op_x_name,const std::string & op_y_name)28 std::vector<int64_t> CalBroadCastShape(std::vector<int64_t> x_shape, std::vector<int64_t> y_shape,
29 const std::string &op_name, const std::string &op_x_name,
30 const std::string &op_y_name) {
31 if (x_shape == y_shape) {
32 return x_shape;
33 }
34 auto x_length = static_cast<int64_t>(x_shape.size());
35 auto y_length = static_cast<int64_t>(y_shape.size());
36 auto length = x_length < y_length ? x_length : y_length;
37 std::vector<int64_t> broadcast_shape;
38 if (x_length == length) {
39 (void)std::copy(y_shape.begin(), y_shape.end() - length, std::back_inserter(broadcast_shape));
40 } else {
41 (void)std::copy(x_shape.begin(), x_shape.end() - length, std::back_inserter(broadcast_shape));
42 }
43 for (int64_t i = -length; i < 0; i++) {
44 if (x_shape[LongToSize(x_length + i)] == 1) {
45 broadcast_shape.push_back(y_shape[LongToSize(y_length + i)]);
46 } else if (y_shape[LongToSize(y_length + i)] == 1) {
47 broadcast_shape.push_back(x_shape[LongToSize(x_length + i)]);
48 } else if (x_shape[LongToSize(x_length + i)] == y_shape[LongToSize(y_length + i)]) {
49 broadcast_shape.push_back(x_shape[LongToSize(x_length + i)]);
50 } else {
51 MS_EXCEPTION(ValueError) << "For op " << op_name << ", the two input '" << op_x_name << "' and '" << op_y_name
52 << "' can not broadcast";
53 }
54 }
55 return broadcast_shape;
56 }
BroadCastInferShape(const std::string & op_name,const std::vector<AbstractBasePtr> & input_args)57 abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args) {
58 MS_LOG(INFO) << "Do infer shape for op " << op_name;
59 const int64_t input_num = 2;
60 CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
61 auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack());
62 auto y_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack());
63 auto x_shape = x_shape_map[kShape];
64 auto y_shape = y_shape_map[kShape];
65 auto x_min_shape = x_shape_map[kMinShape];
66 auto x_max_shape = x_shape_map[kMaxShape];
67 auto y_min_shape = y_shape_map[kMinShape];
68 auto y_max_shape = y_shape_map[kMaxShape];
69 if (x_shape == y_shape) {
70 return std::make_shared<abstract::Shape>(x_shape, x_min_shape, x_max_shape);
71 }
72 auto broadcast_shape = CalBroadCastShape(x_shape, y_shape, op_name);
73 auto min_broadcast_shape = CalBroadCastShape(x_min_shape, y_min_shape, op_name);
74 auto max_broadcast_shape = CalBroadCastShape(x_max_shape, y_max_shape, op_name);
75 return std::make_shared<abstract::Shape>(broadcast_shape, min_broadcast_shape, max_broadcast_shape);
76 }
77 } // namespace ops
78 } // namespace mindspore
79