1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/framework/tensor_slice.h"

#include <limits>
#include <vector>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
23
24 namespace tensorflow {
25
TensorSlice(int dim)26 TensorSlice::TensorSlice(int dim) { SetFullSlice(dim); }
27
TensorSlice(const TensorSliceProto & proto)28 TensorSlice::TensorSlice(const TensorSliceProto& proto) {
29 starts_.reserve(proto.extent_size());
30 lengths_.reserve(proto.extent_size());
31 for (const auto& e : proto.extent()) {
32 starts_.push_back(e.start());
33 lengths_.push_back(GetExtentLength(e));
34 }
35 }
36
TensorSlice(std::initializer_list<std::pair<int64,int64>> extents)37 TensorSlice::TensorSlice(
38 std::initializer_list<std::pair<int64, int64>> extents) {
39 starts_.reserve(extents.size());
40 lengths_.reserve(extents.size());
41 for (const auto& e : extents) {
42 starts_.push_back(e.first);
43 lengths_.push_back(e.second);
44 }
45 }
46
Parse(const string & str,TensorSlice * slice)47 Status TensorSlice::Parse(const string& str, TensorSlice* slice) {
48 std::vector<string> items = str_util::Split(str, ':', str_util::SkipEmpty());
49 slice->starts_.reserve(items.size());
50 slice->lengths_.reserve(items.size());
51 for (const string& x : items) {
52 int64 s, l;
53 if (x == "-") {
54 // "everything"
55 s = 0;
56 l = kFullExtent;
57 } else {
58 std::vector<string> sl = str_util::Split(x, ',', str_util::SkipEmpty());
59 if (sl.size() != 2 || !strings::safe_strto64(sl[0], &s) ||
60 !strings::safe_strto64(sl[1], &l)) {
61 return errors::InvalidArgument(
62 "Expected a pair of numbers or '-' "
63 "but got '",
64 x, "': string = ", str);
65 }
66 if (s < 0 || l <= 0) {
67 return errors::InvalidArgument(
68 "Expected non-negative start and "
69 "positive length but got start = ",
70 s, ", length = ", l, ": string = ", str);
71 }
72 }
73 slice->starts_.push_back(s);
74 slice->lengths_.push_back(l);
75 }
76
77 return Status::OK();
78 }
79
Clear()80 void TensorSlice::Clear() {
81 starts_.clear();
82 lengths_.clear();
83 }
84
IsFull() const85 bool TensorSlice::IsFull() const {
86 for (int d = 0; d < dims(); ++d) {
87 if (!IsFullAt(d)) return false;
88 }
89 return true;
90 }
91
SetFullSlice(int dim)92 void TensorSlice::SetFullSlice(int dim) {
93 Clear();
94 starts_.reserve(dim);
95 lengths_.reserve(dim);
96 for (int d = 0; d < dim; ++d) {
97 starts_.push_back(0);
98 lengths_.push_back(kFullExtent);
99 }
100 }
101
Extend(int dim)102 void TensorSlice::Extend(int dim) {
103 int old_dim = dims();
104 DCHECK_LE(old_dim, dim);
105 starts_.resize(dim);
106 lengths_.resize(dim);
107 for (int d = old_dim; d < dim; ++d) {
108 starts_[d] = 0;
109 lengths_[d] = kFullExtent;
110 }
111 }
112
AsProto(TensorSliceProto * proto) const113 void TensorSlice::AsProto(TensorSliceProto* proto) const {
114 for (int d = 0; d < dims(); ++d) {
115 TensorSliceProto::Extent* e = proto->add_extent();
116 // We only need to record the explicit slice for non-full slices
117 if (!IsFullAt(d)) {
118 e->set_start(starts_[d]);
119 e->set_length(lengths_[d]);
120 }
121 }
122 }
123
DebugString() const124 string TensorSlice::DebugString() const {
125 string buffer;
126 bool first = true;
127 for (int d = 0; d < dims(); ++d) {
128 if (!first) {
129 buffer.append(":");
130 }
131 if (IsFullAt(d)) {
132 buffer.append("-");
133 } else {
134 strings::StrAppend(&buffer, starts_[d], ",", lengths_[d]);
135 }
136 first = false;
137 }
138 return buffer;
139 }
140
Intersect(const TensorSlice & other,TensorSlice * result) const141 bool TensorSlice::Intersect(const TensorSlice& other,
142 TensorSlice* result) const {
143 // First, if two slices have different ranks, they obviously don't overlap
144 // -- in fact they are not compatible.
145 if (dims() != other.dims()) {
146 return false;
147 }
148
149 // Setting the result to the right dimension
150 if (result) {
151 result->SetFullSlice(dims());
152 }
153 // The two slices overlap if they overlap in all dimensions.
154 for (int d = 0; d < dims(); ++d) {
155 if (IsFullAt(d)) {
156 if (result) {
157 result->set_start(d, other.start(d));
158 result->set_length(d, other.length(d));
159 }
160 } else if (other.IsFullAt(d)) {
161 if (result) {
162 result->set_start(d, start(d));
163 result->set_length(d, length(d));
164 }
165 } else {
166 // If we have an intersection here, it should have a start that is the
167 // max of the two starts and an end that is the min of the two ends.
168 int64 s = std::max(start(d), other.start(d));
169 int64 l = std::min(end(d), other.end(d)) - s;
170 if (l > 0) {
171 // We have a real intersection
172 if (result) {
173 result->set_start(d, s);
174 result->set_length(d, l);
175 }
176 } else {
177 // We don't have an intersection for this dimension -- thus we don't
178 // have any intersection at all.
179 if (result) {
180 result->Clear();
181 }
182 return false;
183 }
184 }
185 }
186 // If we are here, we know there is overlap in every dimension.
187 return true;
188 }
189
operator ==(const TensorSlice & other) const190 bool TensorSlice::operator==(const TensorSlice& other) const {
191 return dims() == other.dims() && starts_ == other.starts_ &&
192 lengths_ == other.lengths_;
193 }
194
ComputeRelative(const TensorSlice & sub,TensorSlice * relative) const195 void TensorSlice::ComputeRelative(const TensorSlice& sub,
196 TensorSlice* relative) const {
197 DCHECK_EQ(dims(), sub.dims());
198 relative->SetFullSlice(dims());
199 for (int d = 0; d < dims(); ++d) {
200 if (IsFullAt(d)) {
201 relative->set_start(d, sub.start(d));
202 relative->set_length(d, sub.length(d));
203 } else {
204 // Otherwise the relative start is the difference between the start of
205 // sub and the start of base
206 relative->set_start(d, sub.start(d) - start(d));
207 relative->set_length(d, sub.length(d));
208 }
209 }
210 }
211
UpdateToCover(const TensorSlice & other)212 void TensorSlice::UpdateToCover(const TensorSlice& other) {
213 DCHECK_EQ(dims(), other.dims());
214 for (int d = 0; d < dims(); ++d) {
215 if (!IsFullAt(d)) {
216 if (other.IsFullAt(d)) {
217 starts_[d] = 0;
218 lengths_[d] = kFullExtent;
219 } else {
220 const auto new_end = std::max(end(d), other.end(d));
221 set_start(d, std::min(start(d), other.start(d)));
222 set_length(d, new_end - start(d));
223 }
224 }
225 }
226 }
227
228 // static
HasExtentLength(const TensorSliceProto::Extent & extent)229 bool TensorSlice::HasExtentLength(const TensorSliceProto::Extent& extent) {
230 return extent.has_length_case() == TensorSliceProto::Extent::kLength;
231 }
232
233 // static
GetExtentLength(const TensorSliceProto::Extent & extent)234 int64 TensorSlice::GetExtentLength(const TensorSliceProto::Extent& extent) {
235 if (!HasExtentLength(extent)) return -1;
236 return extent.length();
237 }
238
SliceTensorShape(const TensorShape & shape,TensorShape * result_shape) const239 Status TensorSlice::SliceTensorShape(const TensorShape& shape,
240 TensorShape* result_shape) const {
241 result_shape->Clear();
242 // Mismatching ranks: we can't apply the slice at all.
243 if (shape.dims() != dims()) {
244 return errors::Internal("Mismatching ranks: shape = ", shape.DebugString(),
245 ", slice = ", DebugString());
246 }
247 for (int d = 0; d < dims(); ++d) {
248 if (IsFullAt(d)) {
249 result_shape->AddDim(shape.dim_size(d));
250 } else {
251 // Check if the extent applies to the dimension
252 if (end(d) <= shape.dim_size(d)) {
253 // Yes: the end is within the range of the dim -- we adjust the result
254 // shape so that its size along this dimension is the length of the
255 // slice.
256 result_shape->AddDim(length(d));
257 } else {
258 // The extent doesn't apply to the dimension
259 result_shape->Clear();
260 return errors::Internal("Extent in dimension ", d,
261 " out of bounds: shape = ", shape.DebugString(),
262 ", slice = ", DebugString());
263 }
264 }
265 }
266 // If we are here, we have successfully applied the shape.
267 return Status::OK();
268 }
269
270 const int64 TensorSlice::kFullExtent = -1;
271
272 } // namespace tensorflow
273