• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ops/sparse_to_dense.h"
18 
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 
24 #include "abstract/ops/primitive_infer_map.h"
25 #include "mindapi/src/helper.h"
26 #include "mindspore/core/ops/sparse_ops.h"
27 #include "ops/op_utils.h"
28 #include "utils/check_convert_utils.h"
29 
30 namespace mindspore {
31 namespace ops {
32 namespace {
// Required rank of `indices` (compared with kEqual in SparseToDenseInferShape).
constexpr int64_t kSparseToDenseInputMaxDim{2};
// Required rank of `values` (compared with kEqual in SparseToDenseInferShape).
constexpr int64_t kSparseToDenseInputMinDim{1};
// Expected input count: indices, values, and the dense output shape.
constexpr int64_t kSparseToDenseInputsNum{3};
// NOTE(review): not referenced in the visible code of this file; candidate for removal.
constexpr int64_t kNumZero{0};
37 
SparseToDenseInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)38 abstract::ShapePtr SparseToDenseInferShape(const PrimitivePtr &primitive,
39                                            const std::vector<AbstractBasePtr> &input_args) {
40   MS_EXCEPTION_IF_NULL(primitive);
41   auto op_name = primitive->name();
42   auto indice_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape];
43   auto values_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape];
44 
45   std::vector<ShapeVector> all_shapes = {indice_shape, values_shape};
46   auto is_dynamic = std::any_of(all_shapes.begin(), all_shapes.end(), IsDynamic);
47 
48   (void)CheckAndConvertUtils::CheckInteger("dimension of 'values'", SizeToLong(values_shape.size()), kEqual,
49                                            kSparseToDenseInputMinDim, op_name);
50   if (!is_dynamic) {
51     (void)CheckAndConvertUtils::CheckInteger("dimension of 'indices'", SizeToLong(indice_shape.size()), kEqual,
52                                              kSparseToDenseInputMaxDim, op_name);
53     (void)CheckAndConvertUtils::CheckInteger("batch of 'indices'", indice_shape[kInputIndex0], kEqual,
54                                              values_shape[kInputIndex0], op_name);
55   }
56   auto shape_arg = input_args[kInputIndex2];
57   MS_EXCEPTION_IF_NULL(shape_arg);
58   auto output_shape = GetShapeValue(primitive, shape_arg);
59   return std::make_shared<abstract::Shape>(output_shape);
60 }
61 
SparseToDenseInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)62 TypePtr SparseToDenseInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
63   MS_EXCEPTION_IF_NULL(primitive);
64   auto op_name = primitive->name();
65   auto indice_type = input_args[kInputIndex0]->GetType();
66   auto values_type = input_args[kInputIndex1]->GetType();
67 
68   const std::set<TypePtr> valid_types = {kInt64, kInt32};
69   (void)CheckAndConvertUtils::CheckTensorTypeSame({{"indices", indice_type}}, valid_types, op_name);
70   const std::set<TypePtr> valid_types_value = {kInt64,  kInt32, kInt16,   kInt8,    kUInt64,  kUInt32,
71                                                kUInt16, kUInt8, kFloat16, kFloat32, kFloat64, kBool};
72   std::map<std::string, TypePtr> types_value;
73   (void)types_value.insert({"values", values_type});
74   (void)CheckAndConvertUtils::CheckTensorTypeSame(types_value, valid_types_value, op_name);
75   return values_type;
76 }
77 }  // namespace
78 
/// Combined shape and dtype inference for SparseToDense.
///
/// Null-checks the primitive and every input, requires exactly kSparseToDenseInputsNum inputs, then runs
/// dtype inference followed by shape inference and packs both into an abstract value.
///
/// @param primitive  The SparseToDense primitive; must not be null.
/// @param input_args Abstract inputs (indices, values, dense shape).
/// @return Abstract value built from the inferred shape and dtype.
abstract::AbstractBasePtr SparseToDenseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                             const std::vector<abstract::AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  // Bind by const reference: iterating by value would copy (and atomically ref-count) each shared_ptr.
  for (const auto &input : input_args) {
    MS_EXCEPTION_IF_NULL(input);
  }
  (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual,
                                           kSparseToDenseInputsNum, primitive->name());
  auto infer_type = SparseToDenseInferType(primitive, input_args);
  auto infer_shape = SparseToDenseInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
91 
92 MIND_API_OPERATOR_IMPL(SparseToDense, BaseOperator);
93 
94 // AG means auto generated
95 class MIND_API AGSparseToDenseInfer : public abstract::OpInferBase {
96  public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const97   BaseShapePtr InferShape(const PrimitivePtr &primitive,
98                           const std::vector<AbstractBasePtr> &input_args) const override {
99     return SparseToDenseInferShape(primitive, input_args);
100   }
101 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const102   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
103     return SparseToDenseInferType(primitive, input_args);
104   }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const105   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
106                                     const std::vector<AbstractBasePtr> &input_args) const override {
107     return SparseToDenseInfer(engine, primitive, input_args);
108   }
109 
GetValueDependArgIndices() const110   std::set<int64_t> GetValueDependArgIndices() const override { return {2}; }
111 };
112 
113 REGISTER_PRIMITIVE_OP_INFER_IMPL(SparseToDense, prim::kPrimSparseToDense, AGSparseToDenseInfer, false);
114 }  // namespace ops
115 }  // namespace mindspore
116