1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
#include "abstract/dshape.h"

#include <exception>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "utils/log_adapter.h"
25
26 namespace mindspore {
27 namespace abstract {
28 namespace {
// Renders a dimension vector as a parenthesised, comma-separated list,
// e.g. {2, 3, 4} -> "(2, 3, 4)" and {} -> "()".
std::string ShapeVectorToStr(const std::vector<int64_t> &shp) {
  std::ostringstream out;
  out << "(";
  // The separator is empty before the first element and ", " afterwards.
  const char *separator = "";
  for (const int64_t dim : shp) {
    out << separator << dim;
    separator = ", ";
  }
  out << ")";
  return out.str();
}
44 } // namespace
45 // used for print BaseShape content
operator <<(std::ostream & os,const BaseShape & bs)46 std::ostream &operator<<(std::ostream &os, const BaseShape &bs) {
47 os << bs.ToString();
48 return os;
49 }
50
operator <<(std::ostream & os,const std::shared_ptr<BaseShape> bs)51 std::ostream &operator<<(std::ostream &os, const std::shared_ptr<BaseShape> bs) {
52 MS_EXCEPTION_IF_NULL(bs);
53 os << bs->ToString();
54 return os;
55 }
56
operator ==(const BaseShape & other) const57 bool BaseShape::operator==(const BaseShape &other) const {
58 if (tid() != other.tid()) {
59 return false;
60 }
61 return true;
62 }
63
operator !=(const BaseShape & other) const64 bool BaseShape::operator!=(const BaseShape &other) const { return !(*this == other); }
65
ToString() const66 std::string Shape::ToString() const {
67 std::ostringstream buffer;
68 bool has_dyn_shape = IsDynamic();
69 if (has_dyn_shape) {
70 buffer << "{shape:";
71 }
72 buffer << ShapeVectorToStr(shape_);
73 if (has_dyn_shape) {
74 buffer << "|min shape:";
75 buffer << ShapeVectorToStr(min_shape_);
76 buffer << "|max shape:";
77 buffer << ShapeVectorToStr(max_shape_);
78 buffer << "}";
79 }
80 return buffer.str();
81 }
82
DumpText() const83 std::string Shape::DumpText() const {
84 std::ostringstream buffer;
85 buffer << "[";
86 for (size_t i = 0; i < shape_.size(); i++) {
87 buffer << (i > 0 ? ", " : "") << shape_[i];
88 if (shape_[i] == SHP_ANY && min_shape_.size() == shape_.size() && max_shape_.size() == shape_.size()) {
89 buffer << "_" << min_shape_[i] << "^" << max_shape_[i];
90 }
91 }
92 buffer << "]";
93 return buffer.str();
94 }
95
operator ==(const BaseShape & other) const96 bool Shape::operator==(const BaseShape &other) const {
97 if (tid() != other.tid()) {
98 return false;
99 }
100 return shape_ == static_cast<const Shape &>(other).shape_;
101 }
102
103 const int64_t Shape::SHP_ANY;
Broaden()104 void Shape::Broaden() {
105 for (size_t i = 0; i < shape_.size(); i++) {
106 shape_[i] = SHP_ANY;
107 }
108 }
109
ToString() const110 std::string SequeueShape::ToString() const {
111 std::ostringstream buffer;
112 bool f_begin = true;
113 for (const auto &p_shp : p_shapes_) {
114 if (!f_begin) {
115 buffer << ", ";
116 } else {
117 f_begin = false;
118 }
119 MS_EXCEPTION_IF_NULL(p_shp);
120 buffer << p_shp->ToString();
121 }
122 return buffer.str();
123 }
124
ElementsClone() const125 BaseShapePtrList SequeueShape::ElementsClone() const {
126 BaseShapePtrList ele_list;
127 for (auto p_shp : p_shapes_) {
128 MS_EXCEPTION_IF_NULL(p_shp);
129 ele_list.push_back(p_shp->Clone());
130 }
131 return ele_list;
132 }
133
134 template <typename T>
SequeueEqual(const BaseShape & other) const135 bool SequeueShape::SequeueEqual(const BaseShape &other) const {
136 if (tid() != other.tid()) {
137 return false;
138 }
139 auto other_shapes = static_cast<const T &>(other).p_shapes_;
140 if (other_shapes.size() != p_shapes_.size()) {
141 return false;
142 }
143 for (uint64_t i = 0; i < p_shapes_.size(); ++i) {
144 MS_EXCEPTION_IF_NULL(p_shapes_[i]);
145 MS_EXCEPTION_IF_NULL(other_shapes[i]);
146 if (!(*p_shapes_[i] == *other_shapes[i])) {
147 return false;
148 }
149 }
150 return true;
151 }
152 template bool SequeueShape::SequeueEqual<TupleShape>(const BaseShape &) const;
153 template bool SequeueShape::SequeueEqual<ListShape>(const BaseShape &) const;
154 } // namespace abstract
155 } // namespace mindspore
156