• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "abstract/dshape.h"
20 
21 #include <exception>
22 #include <iostream>
23 
24 #include "utils/log_adapter.h"
25 
26 namespace mindspore {
27 namespace abstract {
28 namespace {
// Renders a dimension list as a parenthesised, comma-separated string,
// e.g. {2, -1, 8} -> "(2, -1, 8)" and {} -> "()".
std::string ShapeVectorToStr(const std::vector<int64_t> &shp) {
  std::ostringstream out;
  out << "(";
  // Empty before the first element, ", " afterwards.
  const char *sep = "";
  for (const int64_t dim : shp) {
    out << sep << dim;
    sep = ", ";
  }
  out << ")";
  return out.str();
}
44 }  // namespace
45 // used for print BaseShape content
operator <<(std::ostream & os,const BaseShape & bs)46 std::ostream &operator<<(std::ostream &os, const BaseShape &bs) {
47   os << bs.ToString();
48   return os;
49 }
50 
operator <<(std::ostream & os,const std::shared_ptr<BaseShape> bs)51 std::ostream &operator<<(std::ostream &os, const std::shared_ptr<BaseShape> bs) {
52   MS_EXCEPTION_IF_NULL(bs);
53   os << bs->ToString();
54   return os;
55 }
56 
operator ==(const BaseShape & other) const57 bool BaseShape::operator==(const BaseShape &other) const {
58   if (tid() != other.tid()) {
59     return false;
60   }
61   return true;
62 }
63 
operator !=(const BaseShape & other) const64 bool BaseShape::operator!=(const BaseShape &other) const { return !(*this == other); }
65 
ToString() const66 std::string Shape::ToString() const {
67   std::ostringstream buffer;
68   bool has_dyn_shape = IsDynamic();
69   if (has_dyn_shape) {
70     buffer << "{shape:";
71   }
72   buffer << ShapeVectorToStr(shape_);
73   if (has_dyn_shape) {
74     buffer << "|min shape:";
75     buffer << ShapeVectorToStr(min_shape_);
76     buffer << "|max shape:";
77     buffer << ShapeVectorToStr(max_shape_);
78     buffer << "}";
79   }
80   return buffer.str();
81 }
82 
DumpText() const83 std::string Shape::DumpText() const {
84   std::ostringstream buffer;
85   buffer << "[";
86   for (size_t i = 0; i < shape_.size(); i++) {
87     buffer << (i > 0 ? ", " : "") << shape_[i];
88     if (shape_[i] == SHP_ANY && min_shape_.size() == shape_.size() && max_shape_.size() == shape_.size()) {
89       buffer << "_" << min_shape_[i] << "^" << max_shape_[i];
90     }
91   }
92   buffer << "]";
93   return buffer.str();
94 }
95 
operator ==(const BaseShape & other) const96 bool Shape::operator==(const BaseShape &other) const {
97   if (tid() != other.tid()) {
98     return false;
99   }
100   return shape_ == static_cast<const Shape &>(other).shape_;
101 }
102 
103 const int64_t Shape::SHP_ANY;
Broaden()104 void Shape::Broaden() {
105   for (size_t i = 0; i < shape_.size(); i++) {
106     shape_[i] = SHP_ANY;
107   }
108 }
109 
ToString() const110 std::string SequeueShape::ToString() const {
111   std::ostringstream buffer;
112   bool f_begin = true;
113   for (const auto &p_shp : p_shapes_) {
114     if (!f_begin) {
115       buffer << ", ";
116     } else {
117       f_begin = false;
118     }
119     MS_EXCEPTION_IF_NULL(p_shp);
120     buffer << p_shp->ToString();
121   }
122   return buffer.str();
123 }
124 
ElementsClone() const125 BaseShapePtrList SequeueShape::ElementsClone() const {
126   BaseShapePtrList ele_list;
127   for (auto p_shp : p_shapes_) {
128     MS_EXCEPTION_IF_NULL(p_shp);
129     ele_list.push_back(p_shp->Clone());
130   }
131   return ele_list;
132 }
133 
134 template <typename T>
SequeueEqual(const BaseShape & other) const135 bool SequeueShape::SequeueEqual(const BaseShape &other) const {
136   if (tid() != other.tid()) {
137     return false;
138   }
139   auto other_shapes = static_cast<const T &>(other).p_shapes_;
140   if (other_shapes.size() != p_shapes_.size()) {
141     return false;
142   }
143   for (uint64_t i = 0; i < p_shapes_.size(); ++i) {
144     MS_EXCEPTION_IF_NULL(p_shapes_[i]);
145     MS_EXCEPTION_IF_NULL(other_shapes[i]);
146     if (!(*p_shapes_[i] == *other_shapes[i])) {
147       return false;
148     }
149   }
150   return true;
151 }
152 template bool SequeueShape::SequeueEqual<TupleShape>(const BaseShape &) const;
153 template bool SequeueShape::SequeueEqual<ListShape>(const BaseShape &) const;
154 }  // namespace abstract
155 }  // namespace mindspore
156