1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <string>
18 #include <unordered_map>
19 #include <vector>
20
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/core/framework/kernel_def_builder.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor_types.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
31
32 namespace tensorflow {
33
34 namespace {
35 // Returning a Status instead of using OP_REQUIRES directly since that doesn't
36 // seem to work outside the main OpKernel functions.
RemapVectorToMap(const TTypes<const int64>::Vec & remapping,std::vector<bool> * id_present,std::unordered_map<int64,int64> * old_id_to_new_id)37 Status RemapVectorToMap(const TTypes<const int64>::Vec& remapping,
38 std::vector<bool>* id_present,
39 std::unordered_map<int64, int64>* old_id_to_new_id) {
40 id_present->clear();
41 id_present->resize(remapping.size(), false);
42 for (int i = 0; i < remapping.size(); ++i) {
43 const int64 old_id = remapping(i);
44 if (old_id < 0) continue;
45 (*id_present)[i] = true;
46 if (!gtl::InsertIfNotPresent(old_id_to_new_id, old_id, i)) {
47 return errors::Unimplemented(
48 strings::StrCat("Old ID ", old_id, " is mapped to both new ID ",
49 old_id_to_new_id->at(old_id), " and ", i,
50 ", which is not supported."));
51 }
52 }
53 return Status::OK();
54 }
55 } // anonymous namespace
56
57 // This op loads a rank-2 Tensor (matrix) from a TensorFlow checkpoint (V2) and
58 // swaps around the rows/columns according to row_remapping/col_remapping.
59 // "Missing" cells are initialized with values from initializing_values.
60 class LoadAndRemapMatrixOp : public OpKernel {
61 public:
LoadAndRemapMatrixOp(OpKernelConstruction * context)62 explicit LoadAndRemapMatrixOp(OpKernelConstruction* context)
63 : OpKernel(context) {
64 OP_REQUIRES_OK(context, context->GetAttr("num_rows", &num_rows_));
65 OP_REQUIRES_OK(context, context->GetAttr("num_cols", &num_cols_));
66 OP_REQUIRES_OK(
67 context, context->GetAttr("max_rows_in_memory", &max_rows_in_memory_));
68 }
69
Compute(OpKernelContext * context)70 void Compute(OpKernelContext* context) override {
71 // Checks what we're remapping and inverts the relevant remapping Tensors to
72 // be maps with key = old ID, value = new ID.
73 std::unordered_map<int64, int64> old_row_to_new_row_map;
74 std::vector<bool> row_id_present;
75 const Tensor* row_remapping_t;
76 OP_REQUIRES_OK(context, context->input("row_remapping", &row_remapping_t));
77 const auto row_remapping = row_remapping_t->vec<int64>();
78 OP_REQUIRES(context, row_remapping.size() == num_rows_,
79 errors::InvalidArgument(strings::StrCat(
80 "Size of row_remapping is ", row_remapping.size(),
81 " instead of being equal to num_rows=", num_rows_)));
82 OP_REQUIRES_OK(context, RemapVectorToMap(row_remapping, &row_id_present,
83 &old_row_to_new_row_map));
84
85 // Calculates the min/max old row ID that we need to read, to save us from
86 // reading some unnecessary slices of the old tensor.
87 int64 min_old_row = -1;
88 int64 max_old_row = -1;
89 for (int i = 0; i < row_remapping.size(); ++i) {
90 if (min_old_row < 0 ||
91 (row_remapping(i) >= 0 && row_remapping(i) < min_old_row)) {
92 min_old_row = row_remapping(i);
93 }
94 if (max_old_row < 0 ||
95 (row_remapping(i) >= 0 && row_remapping(i) > max_old_row)) {
96 max_old_row = row_remapping(i);
97 }
98 }
99
100 // Processes the remapping for columns.
101 std::unordered_map<int64, int64> old_col_to_new_col_map;
102 std::vector<bool> col_id_present;
103 const Tensor* col_remapping_t;
104 OP_REQUIRES_OK(context, context->input("col_remapping", &col_remapping_t));
105 const auto col_remapping = col_remapping_t->vec<int64>();
106 // Note that we always "remap rows", even when the row vocabulary does
107 // not change, because partitioning requires a mapping from partitioned
108 // Variables to the full checkpoints we load.
109 const bool remap_cols = col_remapping.size() > 0;
110 if (remap_cols) {
111 OP_REQUIRES(
112 context, col_remapping.size() == num_cols_,
113 errors::InvalidArgument(strings::StrCat(
114 "Provided col_remapping, but its size is ", col_remapping.size(),
115 " instead of being equal to num_cols=", num_cols_)));
116 OP_REQUIRES_OK(context, RemapVectorToMap(col_remapping, &col_id_present,
117 &old_col_to_new_col_map));
118 } else {
119 col_id_present.clear();
120 col_id_present.resize(num_cols_, true);
121 }
122
123 // Processes the checkpoint source and the provided Tensor name.
124 const Tensor* ckpt_path_t;
125 OP_REQUIRES_OK(context, context->input("ckpt_path", &ckpt_path_t));
126 const string& ckpt_path = ckpt_path_t->scalar<tstring>()();
127 const Tensor* old_tensor_name_t;
128 OP_REQUIRES_OK(context,
129 context->input("old_tensor_name", &old_tensor_name_t));
130 const string& old_tensor_name = old_tensor_name_t->scalar<tstring>()();
131
132 LOG(INFO) << "Processing checkpoint : " << ckpt_path;
133 BundleReader reader(context->env(), ckpt_path);
134 OP_REQUIRES_OK(context, reader.status());
135
136 DataType tensor_type;
137 TensorShape tensor_shape;
138 OP_REQUIRES_OK(context, reader.LookupDtypeAndShape(
139 old_tensor_name, &tensor_type, &tensor_shape));
140 OP_REQUIRES(context, tensor_type == DT_FLOAT,
141 errors::InvalidArgument(strings::StrCat(
142 "Tensor ", old_tensor_name, " has invalid type ",
143 DataTypeString(tensor_type), " instead of expected type ",
144 DataTypeString(DT_FLOAT))));
145 // This op is limited to loading Tensors of rank 2 (matrices).
146 OP_REQUIRES(
147 context, tensor_shape.dims() == 2,
148 errors::InvalidArgument(strings::StrCat(
149 "Tensor ", old_tensor_name, " has shape ",
150 tensor_shape.DebugString(), " of invalid rank ",
151 tensor_shape.dims(), " instead of expected shape of rank 2.")));
152
153 if (!remap_cols) {
154 // TODO(weiho): Consider relaxing this restriction to allow partial column
155 // loading (even when no column remapping is specified) if there turns out
156 // to be a use case for it.
157 OP_REQUIRES(context, num_cols_ == tensor_shape.dim_size(1),
158 errors::InvalidArgument(strings::StrCat(
159 "Tensor ", old_tensor_name, " has shape ",
160 tensor_shape.DebugString(),
161 ", where the size of its 2nd dimension is ",
162 tensor_shape.dim_size(1),
163 " instead of being equal to num_cols=", num_cols_)));
164 }
165
166 // Uses TensorSlice to potentially load the old tensor in chunks in case
167 // memory usage is a concern.
168 std::vector<TensorSlice> tensor_slices;
169 TensorSlice slice(tensor_shape.dims());
170 if (min_old_row >= 0 && max_old_row >= 0) {
171 int64 row_start = min_old_row;
172 // TODO(weiho): Given the list of old row IDs of interest (the keys of
173 // old_row_to_new_row_map), we could also try something smarter to
174 // find some minimal set of covering ranges for the list of old row IDs
175 // such that the size of each range is less than max_rows_in_memory_.
176 while (row_start <= max_old_row) {
177 const int64 slice_length =
178 max_rows_in_memory_ <= 0
179 // If max_rows_in_memory_ <= 0, we just load the entire chunk.
180 ? max_old_row - row_start + 1
181 : std::min(max_rows_in_memory_, max_old_row - row_start + 1);
182 slice.set_start(0, row_start);
183 slice.set_length(0, slice_length);
184 tensor_slices.push_back(slice);
185 row_start += slice_length;
186 }
187 }
188
189 // Allocates the output matrix.
190 Tensor* output_matrix_t = nullptr;
191 OP_REQUIRES_OK(context,
192 context->allocate_output("output_matrix",
193 TensorShape({num_rows_, num_cols_}),
194 &output_matrix_t));
195 auto output_matrix = output_matrix_t->matrix<float>();
196
197 // Iterates through tensor slices and copies over values from the old tensor
198 // to the output matrix.
199 int64 row_index = min_old_row;
200 int64 rows_copied = 0;
201 Tensor loaded_tensor_t;
202 for (const TensorSlice& tensor_slice : tensor_slices) {
203 LOG(INFO) << "Loading slice " << tensor_slice.DebugString();
204 TensorShape slice_shape;
205 OP_REQUIRES_OK(context,
206 tensor_slice.SliceTensorShape(tensor_shape, &slice_shape));
207 // Potentially re-allocates the tensor buffer since the last slice may
208 // have fewer rows than the other slices.
209 if (loaded_tensor_t.shape() != slice_shape) {
210 loaded_tensor_t = Tensor(DT_FLOAT, slice_shape);
211 }
212 OP_REQUIRES_OK(context, reader.LookupSlice(old_tensor_name, tensor_slice,
213 &loaded_tensor_t));
214
215 // Iterates through the old loaded tensor slice row-by-row.
216 for (int row = 0; row < loaded_tensor_t.dim_size(0); ++row, ++row_index) {
217 if (row_index % 500000 == min_old_row) {
218 LOG(INFO) << "Processing old row " << row_index;
219 }
220
221 // If the old row ID is not found in old_row_to_new_row_map, continue
222 // to the next row; otherwise, copy it to the output matrix.
223 const int64* new_row_ptr =
224 gtl::FindOrNull(old_row_to_new_row_map, row_index);
225 if (new_row_ptr == nullptr) {
226 continue;
227 }
228 ++rows_copied;
229 const int64 new_row = *new_row_ptr;
230
231 // Copies over the row element-by-element, in case remapping is needed
232 // along the column axis.
233 const auto& loaded_tensor = loaded_tensor_t.matrix<float>();
234 for (int old_col = 0; old_col < loaded_tensor_t.dim_size(1);
235 ++old_col) {
236 int64 new_col = old_col;
237 if (remap_cols) {
238 const int64* new_col_ptr =
239 gtl::FindOrNull(old_col_to_new_col_map, old_col);
240 if (new_col_ptr == nullptr) {
241 // Column remapping is specified, but this column is not found in
242 // old_col_to_new_col_map, so we leave it uninitialized, to be
243 // filled in with initializing_values later.
244 continue;
245 }
246 new_col = *new_col_ptr;
247 }
248
249 OP_REQUIRES(context,
250 new_row < num_rows_ && new_col < num_cols_ &&
251 new_row >= 0 && new_col >= 0,
252 errors::Internal(strings::StrCat(
253 "new_row=", new_row, " and new_col=", new_col,
254 " should have been less than num_rows_=", num_rows_,
255 " and num_cols_=", num_cols_,
256 " and non-negative. This should never have happened "
257 "if the code were correct. Please file a bug.")));
258 output_matrix(new_row, new_col) = loaded_tensor(row, old_col);
259 }
260 }
261 }
262 LOG(INFO) << "Copied " << rows_copied << " rows from old matrix (with "
263 << tensor_shape.dim_size(0) << " rows) to new matrix (with "
264 << num_rows_ << " rows).";
265
266 // At this point, there are potentially whole rows/columns uninitialized
267 // (corresponding to the indices where row_id_present/col_id_present are
268 // false). We fill this in cell-by-cell using row_id_present and
269 // col_id_present while dequeuing from the initializing_values vector.
270 const Tensor* initializing_values_t;
271 OP_REQUIRES_OK(
272 context, context->input("initializing_values", &initializing_values_t));
273 const auto initializing_values = initializing_values_t->flat<float>();
274 int64 initializing_values_index = 0;
275 for (int i = 0; i < num_rows_; ++i) {
276 for (int j = 0; j < num_cols_; ++j) {
277 if (row_id_present[i] && col_id_present[j]) continue;
278 OP_REQUIRES(
279 context, initializing_values_index < initializing_values.size(),
280 errors::InvalidArgument(
281 "initializing_values contained ", initializing_values.size(),
282 " elements, but more missing values remain."));
283 output_matrix(i, j) = initializing_values(initializing_values_index);
284 ++initializing_values_index;
285 }
286 }
287
288 // Checks that we used all the given initializing values.
289 OP_REQUIRES(
290 context, initializing_values_index == initializing_values.size(),
291 errors::InvalidArgument(
292 "initializing_values contained ", initializing_values.size(),
293 " elements, but only ", initializing_values_index,
294 " elements were used to fill in missing values."));
295 }
296
297 private:
298 int64 num_rows_;
299 int64 num_cols_;
300 int64 max_rows_in_memory_;
301 };
302
303 REGISTER_KERNEL_BUILDER(Name("LoadAndRemapMatrix").Device(DEVICE_CPU),
304 LoadAndRemapMatrixOp);
305
306 } // namespace tensorflow
307