• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "ops/topk.h"
18 #include <set>
19 #include <utility>
20 #include "mindapi/src/helper.h"
21 #include "mindspore/core/ops/array_ops.h"
22 #include "ops/op_utils.h"
23 #include "utils/check_convert_utils.h"
24 
25 namespace mindspore {
26 namespace ops {
// Expands the operator-implementation macro for TopK, naming BaseOperator as its second argument.
// NOTE(review): presumably this emits the boilerplate that ties ops::TopK to its base operator class;
// confirm against the macro definition (likely in mindapi/src/helper.h).
MIND_API_OPERATOR_IMPL(TopK, BaseOperator);
Init(const bool sorted)28 void TopK::Init(const bool sorted) { this->set_sorted(sorted); }
set_sorted(const bool sorted)29 void TopK::set_sorted(const bool sorted) { (void)this->AddAttr(kSorted, api::MakeValue(sorted)); }
30 
get_sorted() const31 bool TopK::get_sorted() const {
32   auto value_ptr = this->GetAttr(kSorted);
33   return GetValue<bool>(value_ptr);
34 }
35 
36 namespace {
TopKInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)37 abstract::TupleShapePtr TopKInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
38   auto prim_name = primitive->name();
39   auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape());
40   auto x_shape = shape_map[kShape];
41   if (IsDynamicRank(x_shape)) {
42     abstract::BaseShapePtr out_shape_ptr =
43       std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny});
44     return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{out_shape_ptr, out_shape_ptr});
45   }
46   int64_t k_v = 0;
47   if ((IsDynamicRank(x_shape)) || !IsValueKnown(input_args[kInputIndex1])) {
48     auto unknown_shape_p = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
49     return std::make_shared<abstract::TupleShape>(
50       std::vector<abstract::BaseShapePtr>{unknown_shape_p, unknown_shape_p});
51   }
52 
53   // 2rd input is a Tensor when TopK is a dynamic shape operator
54   if (CheckAndConvertUtils::IsTensor(input_args[kInputIndex1])) {
55     auto k_dim = input_args[kInputIndex1]->GetShape()->GetShapeVector().size();
56     if (k_dim > 1) {
57       MS_LOG(EXCEPTION) << "For '" << prim_name
58                         << "', the dimension of 'k' should only be 0 or 1 when 'k' is a Tensor, but got: " << k_dim
59                         << ".";
60     }
61     auto k_val = CheckAndConvertUtils::CheckTensorIntValue("k", input_args[kInputIndex1]->GetValue(), prim_name,
62                                                            input_args[kInputIndex1]->GetType());
63     k_v = k_val[0];
64   } else if (CheckAndConvertUtils::IsScalar(input_args[kInputIndex1])) {
65     k_v = GetScalarValue<int64_t>(input_args[kInputIndex1]->GetValue()).value();
66   } else {
67     MS_LOG(EXCEPTION) << "Invalid abstract type:" << input_args[kInputIndex1]->type_name();
68   }
69   if (!x_shape.empty()) {
70     auto ndims = x_shape.size() - 1;
71     if (x_shape[ndims] != abstract::Shape::kShapeDimAny) {
72       std::pair<int64_t, int64_t> k_range(0, x_shape[ndims]);
73       CheckAndConvertUtils::CheckInRange<int64_t>("k", k_v, kIncludeRight, k_range, prim_name);
74       x_shape[ndims] = k_v;
75     }
76   }
77 
78   auto out_shape_ptr = std::make_shared<abstract::Shape>(x_shape);
79   return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{out_shape_ptr, out_shape_ptr});
80 }
81 
TopKInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)82 TuplePtr TopKInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
83   auto prim_name = primitive->name();
84   auto output0_type = input_args[kInputIndex0]->GetType();
85   (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", output0_type, common_valid_types, prim_name);
86   auto k_type = input_args[kInputIndex1]->GetType();
87   const std::set<TypePtr> int_types = {kInt32, kInt64};
88   (void)CheckAndConvertUtils::CheckTypeValid("k", k_type, int_types, prim_name);
89   auto output1_type = kInt32;
90   return std::make_shared<Tuple>(std::vector<TypePtr>{output0_type, output1_type});
91 }
92 }  // namespace
93 
TopKInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)94 AbstractBasePtr TopKInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
95                           const std::vector<AbstractBasePtr> &input_args) {
96   MS_EXCEPTION_IF_NULL(primitive);
97   auto prim_name = primitive->name();
98   const int64_t input_num = 2;
99   CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
100   auto infer_type = TopKInferType(primitive, input_args);
101   auto infer_shape = TopKInferShape(primitive, input_args);
102   return abstract::MakeAbstract(infer_shape, infer_type);
103 }
104 
get_attr(const char * attr) const105 bool TopK::get_attr(const char *attr) const {
106   auto attr_ptr = GetAttr(attr);
107   return GetValue<bool>(attr_ptr);
108 }
109 
110 // AG means auto generated
111 class MIND_API AGTopKInfer : public abstract::OpInferBase {
112  public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const113   BaseShapePtr InferShape(const PrimitivePtr &primitive,
114                           const std::vector<AbstractBasePtr> &input_args) const override {
115     return TopKInferShape(primitive, input_args);
116   }
117 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const118   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
119     return TopKInferType(primitive, input_args);
120   }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const121   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
122                                     const std::vector<AbstractBasePtr> &input_args) const override {
123     return TopKInfer(engine, primitive, input_args);
124   }
125 
GetValueDependArgIndices() const126   std::set<int64_t> GetValueDependArgIndices() const override { return {1}; }
127 };
128 
// Registers AGTopKInfer as the infer implementation associated with prim::kPrimTopK under the name TopK.
// NOTE(review): the meaning of the trailing `false` argument is not visible in this file; confirm it
// against the definition of REGISTER_PRIMITIVE_OP_INFER_IMPL.
REGISTER_PRIMITIVE_OP_INFER_IMPL(TopK, prim::kPrimTopK, AGTopKInfer, false);
130 }  // namespace ops
131 }  // namespace mindspore
132