1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/topk.h"
18 #include <set>
19 #include <utility>
20 #include "mindapi/src/helper.h"
21 #include "mindspore/core/ops/array_ops.h"
22 #include "ops/op_utils.h"
23 #include "utils/check_convert_utils.h"
24
25 namespace mindspore {
26 namespace ops {
// Operator boilerplate for TopK with BaseOperator as its base. NOTE(review): presumably this expands to the
// constructors / type registration for the mindapi operator — confirm against the macro definition.
MIND_API_OPERATOR_IMPL(TopK, BaseOperator);
Init(const bool sorted)28 void TopK::Init(const bool sorted) { this->set_sorted(sorted); }
set_sorted(const bool sorted)29 void TopK::set_sorted(const bool sorted) { (void)this->AddAttr(kSorted, api::MakeValue(sorted)); }
30
get_sorted() const31 bool TopK::get_sorted() const {
32 auto value_ptr = this->GetAttr(kSorted);
33 return GetValue<bool>(value_ptr);
34 }
35
36 namespace {
TopKInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)37 abstract::TupleShapePtr TopKInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
38 auto prim_name = primitive->name();
39 auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape());
40 auto x_shape = shape_map[kShape];
41 if (IsDynamicRank(x_shape)) {
42 abstract::BaseShapePtr out_shape_ptr =
43 std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny});
44 return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{out_shape_ptr, out_shape_ptr});
45 }
46 int64_t k_v = 0;
47 if ((IsDynamicRank(x_shape)) || !IsValueKnown(input_args[kInputIndex1])) {
48 auto unknown_shape_p = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
49 return std::make_shared<abstract::TupleShape>(
50 std::vector<abstract::BaseShapePtr>{unknown_shape_p, unknown_shape_p});
51 }
52
53 // 2rd input is a Tensor when TopK is a dynamic shape operator
54 if (CheckAndConvertUtils::IsTensor(input_args[kInputIndex1])) {
55 auto k_dim = input_args[kInputIndex1]->GetShape()->GetShapeVector().size();
56 if (k_dim > 1) {
57 MS_LOG(EXCEPTION) << "For '" << prim_name
58 << "', the dimension of 'k' should only be 0 or 1 when 'k' is a Tensor, but got: " << k_dim
59 << ".";
60 }
61 auto k_val = CheckAndConvertUtils::CheckTensorIntValue("k", input_args[kInputIndex1]->GetValue(), prim_name,
62 input_args[kInputIndex1]->GetType());
63 k_v = k_val[0];
64 } else if (CheckAndConvertUtils::IsScalar(input_args[kInputIndex1])) {
65 k_v = GetScalarValue<int64_t>(input_args[kInputIndex1]->GetValue()).value();
66 } else {
67 MS_LOG(EXCEPTION) << "Invalid abstract type:" << input_args[kInputIndex1]->type_name();
68 }
69 if (!x_shape.empty()) {
70 auto ndims = x_shape.size() - 1;
71 if (x_shape[ndims] != abstract::Shape::kShapeDimAny) {
72 std::pair<int64_t, int64_t> k_range(0, x_shape[ndims]);
73 CheckAndConvertUtils::CheckInRange<int64_t>("k", k_v, kIncludeRight, k_range, prim_name);
74 x_shape[ndims] = k_v;
75 }
76 }
77
78 auto out_shape_ptr = std::make_shared<abstract::Shape>(x_shape);
79 return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{out_shape_ptr, out_shape_ptr});
80 }
81
TopKInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)82 TuplePtr TopKInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
83 auto prim_name = primitive->name();
84 auto output0_type = input_args[kInputIndex0]->GetType();
85 (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", output0_type, common_valid_types, prim_name);
86 auto k_type = input_args[kInputIndex1]->GetType();
87 const std::set<TypePtr> int_types = {kInt32, kInt64};
88 (void)CheckAndConvertUtils::CheckTypeValid("k", k_type, int_types, prim_name);
89 auto output1_type = kInt32;
90 return std::make_shared<Tuple>(std::vector<TypePtr>{output0_type, output1_type});
91 }
92 } // namespace
93
TopKInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)94 AbstractBasePtr TopKInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
95 const std::vector<AbstractBasePtr> &input_args) {
96 MS_EXCEPTION_IF_NULL(primitive);
97 auto prim_name = primitive->name();
98 const int64_t input_num = 2;
99 CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
100 auto infer_type = TopKInferType(primitive, input_args);
101 auto infer_shape = TopKInferShape(primitive, input_args);
102 return abstract::MakeAbstract(infer_shape, infer_type);
103 }
104
get_attr(const char * attr) const105 bool TopK::get_attr(const char *attr) const {
106 auto attr_ptr = GetAttr(attr);
107 return GetValue<bool>(attr_ptr);
108 }
109
110 // AG means auto generated
111 class MIND_API AGTopKInfer : public abstract::OpInferBase {
112 public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const113 BaseShapePtr InferShape(const PrimitivePtr &primitive,
114 const std::vector<AbstractBasePtr> &input_args) const override {
115 return TopKInferShape(primitive, input_args);
116 }
117
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const118 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
119 return TopKInferType(primitive, input_args);
120 }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const121 AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
122 const std::vector<AbstractBasePtr> &input_args) const override {
123 return TopKInfer(engine, primitive, input_args);
124 }
125
GetValueDependArgIndices() const126 std::set<int64_t> GetValueDependArgIndices() const override { return {1}; }
127 };
128
// Associates AGTopKInfer with prim::kPrimTopK. NOTE(review): presumably this registers it as the primitive's
// infer implementation; the meaning of the trailing `false` flag is defined by the macro — confirm before relying on it.
REGISTER_PRIMITIVE_OP_INFER_IMPL(TopK, prim::kPrimTopK, AGTopKInfer, false);
130 } // namespace ops
131 } // namespace mindspore
132