/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "tensorflow/core/framework/tensor_slice.h"
17 #include <vector>
18 #include "tensorflow/core/lib/core/errors.h"
19 #include "tensorflow/core/lib/strings/numbers.h"
20 #include "tensorflow/core/lib/strings/str_util.h"
21 #include "tensorflow/core/lib/strings/strcat.h"
22 #include "tensorflow/core/platform/logging.h"
23 
24 namespace tensorflow {
25 
// Constructs a slice of `dim` dimensions in which every dimension covers
// its full extent (start 0, length kFullExtent).
TensorSlice::TensorSlice(int dim) { SetFullSlice(dim); }
27 
TensorSlice(const TensorSliceProto & proto)28 TensorSlice::TensorSlice(const TensorSliceProto& proto) {
29   starts_.reserve(proto.extent_size());
30   lengths_.reserve(proto.extent_size());
31   for (const auto& e : proto.extent()) {
32     starts_.push_back(e.start());
33     lengths_.push_back(GetExtentLength(e));
34   }
35 }
36 
TensorSlice(std::initializer_list<std::pair<int64,int64>> extents)37 TensorSlice::TensorSlice(
38     std::initializer_list<std::pair<int64, int64>> extents) {
39   starts_.reserve(extents.size());
40   lengths_.reserve(extents.size());
41   for (const auto& e : extents) {
42     starts_.push_back(e.first);
43     lengths_.push_back(e.second);
44   }
45 }
46 
Parse(const string & str,TensorSlice * slice)47 Status TensorSlice::Parse(const string& str, TensorSlice* slice) {
48   std::vector<string> items = str_util::Split(str, ':', str_util::SkipEmpty());
49   slice->starts_.reserve(items.size());
50   slice->lengths_.reserve(items.size());
51   for (const string& x : items) {
52     int64 s, l;
53     if (x == "-") {
54       // "everything"
55       s = 0;
56       l = kFullExtent;
57     } else {
58       std::vector<string> sl = str_util::Split(x, ',', str_util::SkipEmpty());
59       if (sl.size() != 2 || !strings::safe_strto64(sl[0], &s) ||
60           !strings::safe_strto64(sl[1], &l)) {
61         return errors::InvalidArgument(
62             "Expected a pair of numbers or '-' "
63             "but got '",
64             x, "': string = ", str);
65       }
66       if (s < 0 || l <= 0) {
67         return errors::InvalidArgument(
68             "Expected non-negative start and "
69             "positive length but got start = ",
70             s, ", length = ", l, ": string = ", str);
71       }
72     }
73     slice->starts_.push_back(s);
74     slice->lengths_.push_back(l);
75   }
76 
77   return Status::OK();
78 }
79 
// Resets the slice to zero dimensions by discarding all per-dimension
// starts and lengths.
void TensorSlice::Clear() {
  starts_.clear();
  lengths_.clear();
}
84 
IsFull() const85 bool TensorSlice::IsFull() const {
86   for (int d = 0; d < dims(); ++d) {
87     if (!IsFullAt(d)) return false;
88   }
89   return true;
90 }
91 
SetFullSlice(int dim)92 void TensorSlice::SetFullSlice(int dim) {
93   Clear();
94   starts_.reserve(dim);
95   lengths_.reserve(dim);
96   for (int d = 0; d < dim; ++d) {
97     starts_.push_back(0);
98     lengths_.push_back(kFullExtent);
99   }
100 }
101 
Extend(int dim)102 void TensorSlice::Extend(int dim) {
103   int old_dim = dims();
104   DCHECK_LE(old_dim, dim);
105   starts_.resize(dim);
106   lengths_.resize(dim);
107   for (int d = old_dim; d < dim; ++d) {
108     starts_[d] = 0;
109     lengths_[d] = kFullExtent;
110   }
111 }
112 
AsProto(TensorSliceProto * proto) const113 void TensorSlice::AsProto(TensorSliceProto* proto) const {
114   for (int d = 0; d < dims(); ++d) {
115     TensorSliceProto::Extent* e = proto->add_extent();
116     // We only need to record the explicit slice for non-full slices
117     if (!IsFullAt(d)) {
118       e->set_start(starts_[d]);
119       e->set_length(lengths_[d]);
120     }
121   }
122 }
123 
DebugString() const124 string TensorSlice::DebugString() const {
125   string buffer;
126   bool first = true;
127   for (int d = 0; d < dims(); ++d) {
128     if (!first) {
129       buffer.append(":");
130     }
131     if (IsFullAt(d)) {
132       buffer.append("-");
133     } else {
134       strings::StrAppend(&buffer, starts_[d], ",", lengths_[d]);
135     }
136     first = false;
137   }
138   return buffer;
139 }
140 
Intersect(const TensorSlice & other,TensorSlice * result) const141 bool TensorSlice::Intersect(const TensorSlice& other,
142                             TensorSlice* result) const {
143   // First, if two slices have different ranks, they obviously don't overlap
144   // -- in fact they are not compatible.
145   if (dims() != other.dims()) {
146     return false;
147   }
148 
149   // Setting the result to the right dimension
150   if (result) {
151     result->SetFullSlice(dims());
152   }
153   // The two slices overlap if they overlap in all dimensions.
154   for (int d = 0; d < dims(); ++d) {
155     if (IsFullAt(d)) {
156       if (result) {
157         result->set_start(d, other.start(d));
158         result->set_length(d, other.length(d));
159       }
160     } else if (other.IsFullAt(d)) {
161       if (result) {
162         result->set_start(d, start(d));
163         result->set_length(d, length(d));
164       }
165     } else {
166       // If we have an intersection here, it should have a start that is the
167       // max of the two starts and an end that is the min of the two ends.
168       int64 s = std::max(start(d), other.start(d));
169       int64 l = std::min(end(d), other.end(d)) - s;
170       if (l > 0) {
171         // We have a real intersection
172         if (result) {
173           result->set_start(d, s);
174           result->set_length(d, l);
175         }
176       } else {
177         // We don't have an intersection for this dimension -- thus we don't
178         // have any intersection at all.
179         if (result) {
180           result->Clear();
181         }
182         return false;
183       }
184     }
185   }
186   // If we are here, we know there is overlap in every dimension.
187   return true;
188 }
189 
operator ==(const TensorSlice & other) const190 bool TensorSlice::operator==(const TensorSlice& other) const {
191   return dims() == other.dims() && starts_ == other.starts_ &&
192          lengths_ == other.lengths_;
193 }
194 
ComputeRelative(const TensorSlice & sub,TensorSlice * relative) const195 void TensorSlice::ComputeRelative(const TensorSlice& sub,
196                                   TensorSlice* relative) const {
197   DCHECK_EQ(dims(), sub.dims());
198   relative->SetFullSlice(dims());
199   for (int d = 0; d < dims(); ++d) {
200     if (IsFullAt(d)) {
201       relative->set_start(d, sub.start(d));
202       relative->set_length(d, sub.length(d));
203     } else {
204       // Otherwise the relative start is the difference between the start of
205       // sub and the start of base
206       relative->set_start(d, sub.start(d) - start(d));
207       relative->set_length(d, sub.length(d));
208     }
209   }
210 }
211 
UpdateToCover(const TensorSlice & other)212 void TensorSlice::UpdateToCover(const TensorSlice& other) {
213   DCHECK_EQ(dims(), other.dims());
214   for (int d = 0; d < dims(); ++d) {
215     if (!IsFullAt(d)) {
216       if (other.IsFullAt(d)) {
217         starts_[d] = 0;
218         lengths_[d] = kFullExtent;
219       } else {
220         const auto new_end = std::max(end(d), other.end(d));
221         set_start(d, std::min(start(d), other.start(d)));
222         set_length(d, new_end - start(d));
223       }
224     }
225   }
226 }
227 
228 // static
// static
// Returns true iff `extent` carries an explicit length, i.e. the proto's
// `has_length` oneof is set to its `length` field.
bool TensorSlice::HasExtentLength(const TensorSliceProto::Extent& extent) {
  return extent.has_length_case() == TensorSliceProto::Extent::kLength;
}
232 
233 // static
GetExtentLength(const TensorSliceProto::Extent & extent)234 int64 TensorSlice::GetExtentLength(const TensorSliceProto::Extent& extent) {
235   if (!HasExtentLength(extent)) return -1;
236   return extent.length();
237 }
238 
SliceTensorShape(const TensorShape & shape,TensorShape * result_shape) const239 Status TensorSlice::SliceTensorShape(const TensorShape& shape,
240                                      TensorShape* result_shape) const {
241   result_shape->Clear();
242   // Mismatching ranks: we can't apply the slice at all.
243   if (shape.dims() != dims()) {
244     return errors::Internal("Mismatching ranks: shape = ", shape.DebugString(),
245                             ", slice = ", DebugString());
246   }
247   for (int d = 0; d < dims(); ++d) {
248     if (IsFullAt(d)) {
249       result_shape->AddDim(shape.dim_size(d));
250     } else {
251       // Check if the extent applies to the dimension
252       if (end(d) <= shape.dim_size(d)) {
253         // Yes: the end is within the range of the dim -- we adjust the result
254         // shape so that its size along this dimension is the length of the
255         // slice.
256         result_shape->AddDim(length(d));
257       } else {
258         // The extent doesn't apply to the dimension
259         result_shape->Clear();
260         return errors::Internal("Extent in dimension ", d,
261                                 " out of bounds: shape = ", shape.DebugString(),
262                                 ", slice = ", DebugString());
263       }
264     }
265   }
266   // If we are here, we have successfully applied the shape.
267   return Status::OK();
268 }
269 
// Sentinel length meaning "the full extent of this dimension".
const int64 TensorSlice::kFullExtent = -1;
271 
272 }  // namespace tensorflow
273