1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "mindspore/core/ops/symbol_ops_impl/common.h"
17 #include "mindspore/core/ops/symbol_ops_impl/scalar_add.h"
18 #include "mindspore/core/ops/symbol_ops_impl/scalar_mul.h"
19
20 namespace mindspore {
21 namespace symshape {
22 namespace ops {
SetPositive(const ListSymbol * list)23 void InferShapeOp::SetPositive(const ListSymbol *list) {
24 for (auto &s : list->symbols()) {
25 auto list_s = s->as_noexcept<ListSymbol>();
26 if (list_s != nullptr) {
27 SetPositive(list_s);
28 } else {
29 auto int_s = s->as_noexcept<IntSymbol>();
30 MS_EXCEPTION_IF_NULL(int_s);
31 if (!int_s->is_positive()) {
32 int_s->SetRangeMin(1);
33 }
34 }
35 }
36 }
37
TransValueToShape(OperationBuilder * b)38 SymbolPtr TransValueToShape(OperationBuilder *b) {
39 auto ret = TransparentInput(b);
40 if (ret == nullptr) {
41 return nullptr;
42 }
43 auto ret_shape = ret->as_noexcept<ListSymbol>();
44 MS_EXCEPTION_IF_NULL(ret_shape);
45 InferShapeOp::SetPositive(ret_shape);
46 return ret;
47 }
48
49 template <typename OP>
Accumulate(const SymbolPtrList & symbols,const OperationEmitter & e)50 SymbolPtr Accumulate(const SymbolPtrList &symbols, const OperationEmitter &e) {
51 SymbolPtr vars = nullptr;
52 int64_t constv = std::is_same_v<OP, ScalarAdd> ? 0 : 1;
53 for (size_t i = 0; i < symbols.size(); i++) {
54 auto s = symbols[i]->as_sptr<IntSymbol>();
55 MS_EXCEPTION_IF_NULL(s);
56 if (s->HasData()) {
57 if (std::is_same_v<OP, ScalarAdd>) {
58 constv += s->value();
59 } else {
60 constv *= s->value();
61 }
62 } else if (vars == nullptr) {
63 vars = s;
64 } else {
65 vars = e.Emit(std::make_shared<OP>(vars, s));
66 }
67 }
68 if (vars == nullptr) {
69 return IntSymbol::Make(constv);
70 }
71 return e.Emit(std::make_shared<OP>(vars, IntSymbol::Make(constv)));
72 }
73 template SymbolPtr Accumulate<ScalarAdd>(const SymbolPtrList &, const OperationEmitter &);
74 template SymbolPtr Accumulate<ScalarMul>(const SymbolPtrList &, const OperationEmitter &);
75 } // namespace ops
76 } // namespace symshape
77 } // namespace mindspore
78