• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_
19 
20 #include <string>
21 #include <memory>
22 #include <unordered_map>
23 #include <vector>
24 #include "ir/value.h"
25 #include "frontend/parallel/auto_parallel/operator_costmodel.h"
26 #include "frontend/parallel/ops_info/arithmetic_info.h"
27 #include "frontend/parallel/strategy.h"
28 
29 namespace mindspore {
30 namespace parallel {
31 class EqualInfo : public ArithmeticBase {
32  public:
EqualInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)33   EqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
34             const PrimitiveAttrs &attrs)
35       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<EqualCost>()) {}
36   ~EqualInfo() override = default;
37 };
38 
39 class ApproximateEqualInfo : public ArithmeticBase {
40  public:
ApproximateEqualInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)41   ApproximateEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
42                        const PrimitiveAttrs &attrs)
43       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ApproximateEqualCost>()) {}
44   ~ApproximateEqualInfo() override = default;
45 };
46 
47 class NotEqualInfo : public ArithmeticBase {
48  public:
NotEqualInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)49   NotEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
50                const PrimitiveAttrs &attrs)
51       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<NotEqualCost>()) {}
52   ~NotEqualInfo() override = default;
53 };
54 
55 class MaximumInfo : public ArithmeticBase {
56  public:
MaximumInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)57   MaximumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
58               const PrimitiveAttrs &attrs)
59       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MaximumCost>()) {}
60   ~MaximumInfo() override = default;
61 };
62 
63 class MinimumInfo : public ArithmeticBase {
64  public:
MinimumInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)65   MinimumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
66               const PrimitiveAttrs &attrs)
67       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MinimumCost>()) {}
68   ~MinimumInfo() override = default;
69 };
70 
71 class GreaterInfo : public ArithmeticBase {
72  public:
GreaterInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)73   GreaterInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
74               const PrimitiveAttrs &attrs)
75       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<GreaterCost>()) {}
76   ~GreaterInfo() override = default;
77 };
78 
79 class GreaterEqualInfo : public ArithmeticBase {
80  public:
GreaterEqualInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)81   GreaterEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
82                    const PrimitiveAttrs &attrs)
83       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<GreaterEqualCost>()) {}
84   ~GreaterEqualInfo() override = default;
85 };
86 
87 class LessInfo : public ArithmeticBase {
88  public:
LessInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)89   LessInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
90            const PrimitiveAttrs &attrs)
91       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LessCost>()) {}
92   ~LessInfo() override = default;
93 };
94 
95 class LessEqualInfo : public ArithmeticBase {
96  public:
LessEqualInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)97   LessEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
98                 const PrimitiveAttrs &attrs)
99       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LessEqualCost>()) {}
100   ~LessEqualInfo() override = default;
101 };
102 }  // namespace parallel
103 }  // namespace mindspore
104 
105 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_
106