• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "mindspore/core/ops/symbol_ops_impl/common.h"
17 #include "mindspore/core/ops/symbol_ops_impl/scalar_add.h"
18 #include "mindspore/core/ops/symbol_ops_impl/scalar_mul.h"
19 
20 namespace mindspore {
21 namespace symshape {
22 namespace ops {
SetPositive(const ListSymbol * list)23 void InferShapeOp::SetPositive(const ListSymbol *list) {
24   for (auto &s : list->symbols()) {
25     auto list_s = s->as_noexcept<ListSymbol>();
26     if (list_s != nullptr) {
27       SetPositive(list_s);
28     } else {
29       auto int_s = s->as_noexcept<IntSymbol>();
30       MS_EXCEPTION_IF_NULL(int_s);
31       if (!int_s->is_positive()) {
32         int_s->SetRangeMin(1);
33       }
34     }
35   }
36 }
37 
TransValueToShape(OperationBuilder * b)38 SymbolPtr TransValueToShape(OperationBuilder *b) {
39   auto ret = TransparentInput(b);
40   if (ret == nullptr) {
41     return nullptr;
42   }
43   auto ret_shape = ret->as_noexcept<ListSymbol>();
44   MS_EXCEPTION_IF_NULL(ret_shape);
45   InferShapeOp::SetPositive(ret_shape);
46   return ret;
47 }
48 
49 template <typename OP>
Accumulate(const SymbolPtrList & symbols,const OperationEmitter & e)50 SymbolPtr Accumulate(const SymbolPtrList &symbols, const OperationEmitter &e) {
51   SymbolPtr vars = nullptr;
52   int64_t constv = std::is_same_v<OP, ScalarAdd> ? 0 : 1;
53   for (size_t i = 0; i < symbols.size(); i++) {
54     auto s = symbols[i]->as_sptr<IntSymbol>();
55     MS_EXCEPTION_IF_NULL(s);
56     if (s->HasData()) {
57       if (std::is_same_v<OP, ScalarAdd>) {
58         constv += s->value();
59       } else {
60         constv *= s->value();
61       }
62     } else if (vars == nullptr) {
63       vars = s;
64     } else {
65       vars = e.Emit(std::make_shared<OP>(vars, s));
66     }
67   }
68   if (vars == nullptr) {
69     return IntSymbol::Make(constv);
70   }
71   return e.Emit(std::make_shared<OP>(vars, IntSymbol::Make(constv)));
72 }
73 template SymbolPtr Accumulate<ScalarAdd>(const SymbolPtrList &, const OperationEmitter &);
74 template SymbolPtr Accumulate<ScalarMul>(const SymbolPtrList &, const OperationEmitter &);
75 }  // namespace ops
76 }  // namespace symshape
77 }  // namespace mindspore
78