• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/tensor_slice_set.h"
17 
18 #include <vector>
19 #include "tensorflow/core/lib/core/errors.h"
20 #include "tensorflow/core/lib/gtl/map_util.h"
21 #include "tensorflow/core/platform/logging.h"
22 #include "tensorflow/core/util/tensor_slice_util.h"
23 
24 namespace tensorflow {
25 
26 namespace checkpoint {
27 
TensorSliceSet(const TensorShape & shape,DataType type)28 TensorSliceSet::TensorSliceSet(const TensorShape& shape, DataType type)
29     : shape_(shape), type_(type) {}
30 
~TensorSliceSet()31 TensorSliceSet::~TensorSliceSet() {}
32 
Register(const TensorSlice & slice,const string & tag)33 Status TensorSliceSet::Register(const TensorSlice& slice, const string& tag) {
34   TensorShape result_shape;
35   TF_RETURN_IF_ERROR(slice.SliceTensorShape(shape_, &result_shape));
36   string str = slice.DebugString();
37 
38   if (slices_.empty()) {
39     slices_hull_ = slice;
40   } else {
41     // We check if there is any intersection between this slice and any of the
42     // registered slices.
43     if (slices_hull_.Overlaps(slice)) {
44       for (const auto& x : slices_) {
45         if (slice.Overlaps(x.second.slice)) {
46           return errors::Internal("Overlapping slices: existing slice = ",
47                                   x.first, ", new slice = ", str);
48         }
49       }
50     }
51     // No overlap: we can now insert the slice
52     slices_hull_.UpdateToCover(slice);
53   }
54 
55   TensorSliceSet::SliceInfo info = {slice, tag, result_shape.num_elements()};
56   slices_.insert(std::make_pair(str, info));
57   return OkStatus();
58 }
59 
QueryMeta(const TensorSlice & slice,std::vector<std::pair<TensorSlice,string>> * results) const60 bool TensorSliceSet::QueryMeta(
61     const TensorSlice& slice,
62     std::vector<std::pair<TensorSlice, string>>* results) const {
63   results->clear();
64   Status s;
65   string str = slice.DebugString();
66   // First we check if there is an exactly match (this is the dominant case).
67   const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str);
68   if (info) {
69     results->emplace_back(std::make_pair(info->slice, info->tag));
70     return true;
71   } else {
72     // We didn't find any exact match but there is still a possibility that
73     // multiple existing slices can be patched together to output the slice.
74     // We figure this out by computing the intersection of each of the existing
75     // slices with the query slice, and check if the union of all these
76     // intersections cover the entire slice. We rely on the fact that the
77     // existing slices don't have any intersection among themselves.
78     TensorShape target_shape;
79     Status s;
80     s = slice.SliceTensorShape(shape_, &target_shape);
81     if (!s.ok()) {
82       LOG(WARNING) << s;
83       return false;
84     }
85     int64_t total_size = target_shape.num_elements();
86 
87     int64_t overlap_size = 0;
88     TensorSlice intersection;
89     TensorShape inter_shape;
90     for (const auto& x : slices_) {
91       if (slice.Intersect(x.second.slice, &intersection)) {
92         s = intersection.SliceTensorShape(shape_, &inter_shape);
93         if (!s.ok()) {
94           LOG(WARNING) << s;
95           return false;
96         }
97         overlap_size += inter_shape.num_elements();
98         results->emplace_back(std::make_pair(x.second.slice, x.second.tag));
99       }
100     }
101     if (total_size == overlap_size) {
102       // We have it!
103       return true;
104     } else {
105       // We don't have all the data for the asked tensor slice
106       results->clear();
107       return false;
108     }
109   }
110 }
111 
RegisterTensorSlice(const string & name,const TensorShape & shape,DataType type,const string & tag,const TensorSlice & slice,std::unordered_map<string,TensorSliceSet * > * tensor_slices)112 Status RegisterTensorSlice(
113     const string& name, const TensorShape& shape, DataType type,
114     const string& tag, const TensorSlice& slice,
115     std::unordered_map<string, TensorSliceSet*>* tensor_slices) {
116   DCHECK_NE(tensor_slices, nullptr);
117   TensorSliceSet* tss = gtl::FindPtrOrNull(*tensor_slices, name);
118   // Create a tensor slice set if needed
119   if (!tss) {
120     tss = new TensorSliceSet(shape, type);
121     tensor_slices->insert(std::make_pair(name, tss));
122   } else {
123     // Check if the shapes match
124     const TensorShape& tss_shape(tss->shape());
125     if (!shape.IsSameSize(tss_shape)) {
126       return errors::Internal("Incompatible tensor shapes detected for tensor ",
127                               name, ": existing = ", tss_shape.DebugString(),
128                               ", new = ", shape.DebugString());
129     }
130     if (type != tss->type()) {
131       return errors::Internal("Incompatible tensor types detected for tensor ",
132                               name,
133                               ": existing = ", DataTypeString(tss->type()),
134                               ", new = ", DataTypeString(type));
135     }
136   }
137   // Register the tensor slices without the actual data.
138   return tss->Register(slice, tag);
139 }
140 
141 }  // namespace checkpoint
142 
143 }  // namespace tensorflow
144