1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/util/tensor_slice_set.h"
17
18 #include <vector>
19 #include "tensorflow/core/lib/core/errors.h"
20 #include "tensorflow/core/lib/gtl/map_util.h"
21 #include "tensorflow/core/platform/logging.h"
22 #include "tensorflow/core/util/tensor_slice_util.h"
23
24 namespace tensorflow {
25
26 namespace checkpoint {
27
TensorSliceSet(const TensorShape & shape,DataType type)28 TensorSliceSet::TensorSliceSet(const TensorShape& shape, DataType type)
29 : shape_(shape), type_(type) {}
30
~TensorSliceSet()31 TensorSliceSet::~TensorSliceSet() {}
32
Register(const TensorSlice & slice,const string & tag)33 Status TensorSliceSet::Register(const TensorSlice& slice, const string& tag) {
34 TensorShape result_shape;
35 TF_RETURN_IF_ERROR(slice.SliceTensorShape(shape_, &result_shape));
36 string str = slice.DebugString();
37
38 if (slices_.empty()) {
39 slices_hull_ = slice;
40 } else {
41 // We check if there is any intersection between this slice and any of the
42 // registered slices.
43 if (slices_hull_.Overlaps(slice)) {
44 for (const auto& x : slices_) {
45 if (slice.Overlaps(x.second.slice)) {
46 return errors::Internal("Overlapping slices: existing slice = ",
47 x.first, ", new slice = ", str);
48 }
49 }
50 }
51 // No overlap: we can now insert the slice
52 slices_hull_.UpdateToCover(slice);
53 }
54
55 TensorSliceSet::SliceInfo info = {slice, tag, result_shape.num_elements()};
56 slices_.insert(std::make_pair(str, info));
57 return OkStatus();
58 }
59
QueryMeta(const TensorSlice & slice,std::vector<std::pair<TensorSlice,string>> * results) const60 bool TensorSliceSet::QueryMeta(
61 const TensorSlice& slice,
62 std::vector<std::pair<TensorSlice, string>>* results) const {
63 results->clear();
64 Status s;
65 string str = slice.DebugString();
66 // First we check if there is an exactly match (this is the dominant case).
67 const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str);
68 if (info) {
69 results->emplace_back(std::make_pair(info->slice, info->tag));
70 return true;
71 } else {
72 // We didn't find any exact match but there is still a possibility that
73 // multiple existing slices can be patched together to output the slice.
74 // We figure this out by computing the intersection of each of the existing
75 // slices with the query slice, and check if the union of all these
76 // intersections cover the entire slice. We rely on the fact that the
77 // existing slices don't have any intersection among themselves.
78 TensorShape target_shape;
79 Status s;
80 s = slice.SliceTensorShape(shape_, &target_shape);
81 if (!s.ok()) {
82 LOG(WARNING) << s;
83 return false;
84 }
85 int64_t total_size = target_shape.num_elements();
86
87 int64_t overlap_size = 0;
88 TensorSlice intersection;
89 TensorShape inter_shape;
90 for (const auto& x : slices_) {
91 if (slice.Intersect(x.second.slice, &intersection)) {
92 s = intersection.SliceTensorShape(shape_, &inter_shape);
93 if (!s.ok()) {
94 LOG(WARNING) << s;
95 return false;
96 }
97 overlap_size += inter_shape.num_elements();
98 results->emplace_back(std::make_pair(x.second.slice, x.second.tag));
99 }
100 }
101 if (total_size == overlap_size) {
102 // We have it!
103 return true;
104 } else {
105 // We don't have all the data for the asked tensor slice
106 results->clear();
107 return false;
108 }
109 }
110 }
111
RegisterTensorSlice(const string & name,const TensorShape & shape,DataType type,const string & tag,const TensorSlice & slice,std::unordered_map<string,TensorSliceSet * > * tensor_slices)112 Status RegisterTensorSlice(
113 const string& name, const TensorShape& shape, DataType type,
114 const string& tag, const TensorSlice& slice,
115 std::unordered_map<string, TensorSliceSet*>* tensor_slices) {
116 DCHECK_NE(tensor_slices, nullptr);
117 TensorSliceSet* tss = gtl::FindPtrOrNull(*tensor_slices, name);
118 // Create a tensor slice set if needed
119 if (!tss) {
120 tss = new TensorSliceSet(shape, type);
121 tensor_slices->insert(std::make_pair(name, tss));
122 } else {
123 // Check if the shapes match
124 const TensorShape& tss_shape(tss->shape());
125 if (!shape.IsSameSize(tss_shape)) {
126 return errors::Internal("Incompatible tensor shapes detected for tensor ",
127 name, ": existing = ", tss_shape.DebugString(),
128 ", new = ", shape.DebugString());
129 }
130 if (type != tss->type()) {
131 return errors::Internal("Incompatible tensor types detected for tensor ",
132 name,
133 ": existing = ", DataTypeString(tss->type()),
134 ", new = ", DataTypeString(type));
135 }
136 }
137 // Register the tensor slices without the actual data.
138 return tss->Register(slice, tag);
139 }
140
141 } // namespace checkpoint
142
143 } // namespace tensorflow
144