• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_OPS_DEFORMABLE_CONV2D_H_
18 #define MINDSPORE_CORE_OPS_DEFORMABLE_CONV2D_H_
19 
20 #include <vector>
21 #include <string>
22 #include "ops/base_operator.h"
23 #include "mindapi/base/format.h"
24 #include "include/common/utils/utils.h"
25 
26 namespace mindspore {
27 namespace ops {
/// Registered operator name of DeformableConv2d (type is `const char *const`, same as the `auto`-deduced form).
constexpr const char *kNameDeformableConv2d = "DeformableConv2d";
29 /// \brief DeformableConv2D. Refer to Python API @ref mindspore.ops.deformable_conv2d for more details.
30 class MIND_API DeformableConv2d : public BaseOperator {
31  public:
32   MIND_API_BASE_MEMBER(DeformableConv2d);
33   /// \brief Constructor.
DeformableConv2d()34   DeformableConv2d() : BaseOperator(kNameDeformableConv2d) { InitIOName({"x", "filter", "offsets", "bias"}, {"y"}); }
35   /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.deformable_conv2d for the inputs.
36   void Init(const std::vector<int64_t> &strides, const std::vector<int64_t> &pads,
37             const std::vector<int64_t> &dilations = {1, 1, 1, 1}, int64_t groups = 1,
38             const std::string &data_format = kOpFormat_NCHW, int64_t deformable_groups = 1, bool modulated = true);
39 
40   /// \brief Set strides.
41   void set_strides(const std::vector<int64_t> &strides);
42 
43   /// \brief Get strides.
44   ///
45   /// \return strides.
46   std::vector<int64_t> get_strides() const;
47 
48   /// \brief Set pads.
49   void set_pads(const std::vector<int64_t> &pads);
50 
51   /// \brief Get pads.
52   ///
53   /// \return pads.
54   std::vector<int64_t> get_pads() const;
55 
56   /// \brief Set dilations.
57   void set_dilations(const std::vector<int64_t> &dilations);
58 
59   /// \brief Get dilations.
60   ///
61   /// \return dilations.
62   std::vector<int64_t> get_dilations() const;
63 
64   /// \brief Set format.
65   void set_data_format(const std::string &data_format);
66 
67   /// \brief Get format.
68   ///
69   /// \return format.
70   std::string get_data_format() const;
71 
72   /// \brief Set number of blocked connection from input channels to output channels.
73   void set_groups(int64_t groups);
74 
75   /// \brief Get number of groups.
76   ///
77   /// \return groups.
78   int64_t get_groups() const;
79 
80   /// \brief Set deformable_groups.
81   void set_deformable_groups(int64_t deformable_groups);
82 
83   /// \brief Get deformable_groups.
84   ///
85   /// \return deformable_groups.
86   int64_t get_deformable_groups() const;
87 
88   /// \brief Set modulated.
89   void set_modulated(bool modulated);
90 
91   /// \brief Get modulated.
92   ///
93   /// \return modulated.
94   bool get_modulated() const;
95 };
96 MIND_API abstract::AbstractBasePtr DeformableConv2dInfer(const abstract::AnalysisEnginePtr &,
97                                                          const PrimitivePtr &primitive,
98                                                          const std::vector<abstract::AbstractBasePtr> &input_args);
99 }  // namespace ops
100 }  // namespace mindspore
101 #endif  // MINDSPORE_CORE_OPS_DEFORMABLE_CONV2D_H_
102