1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <string>
18 #include <unordered_map>
19 #include <vector>
20
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/core/framework/kernel_def_builder.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor_types.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
31
32 namespace tensorflow {
33
34 namespace {
35 // Returning a Status instead of using OP_REQUIRES directly since that doesn't
36 // seem to work outside the main OpKernel functions.
RemapVectorToMap(const TTypes<const int64>::Vec & remapping,std::vector<bool> * id_present,std::unordered_map<int64,int64> * old_id_to_new_id)37 Status RemapVectorToMap(const TTypes<const int64>::Vec& remapping,
38 std::vector<bool>* id_present,
39 std::unordered_map<int64, int64>* old_id_to_new_id) {
40 id_present->clear();
41 id_present->resize(remapping.size(), false);
42 for (int i = 0; i < remapping.size(); ++i) {
43 const int64 old_id = remapping(i);
44 if (old_id < 0) continue;
45 (*id_present)[i] = true;
46 if (!gtl::InsertIfNotPresent(old_id_to_new_id, old_id, i)) {
47 return errors::Unimplemented(
48 strings::StrCat("Old ID ", old_id, " is mapped to both new ID ",
49 old_id_to_new_id->at(old_id), " and ", i,
50 ", which is not supported."));
51 }
52 }
53 return Status::OK();
54 }
55 } // anonymous namespace
56
57 // This op loads a rank-2 Tensor (matrix) from a TensorFlow checkpoint (V2) and
58 // swaps around the rows/columns according to row_remapping/col_remapping.
59 // "Missing" cells are initialized with values from initializing_values.
60 class LoadAndRemapMatrixOp : public OpKernel {
61 public:
LoadAndRemapMatrixOp(OpKernelConstruction * context)62 explicit LoadAndRemapMatrixOp(OpKernelConstruction* context)
63 : OpKernel(context) {
64 OP_REQUIRES_OK(context, context->GetAttr("num_rows", &num_rows_));
65 OP_REQUIRES_OK(context, context->GetAttr("num_cols", &num_cols_));
66 OP_REQUIRES_OK(
67 context, context->GetAttr("max_rows_in_memory", &max_rows_in_memory_));
68 }
69
Compute(OpKernelContext * context)70 void Compute(OpKernelContext* context) override {
71 // Checks what we're remapping and inverts the relevant remapping Tensors to
72 // be maps with key = old ID, value = new ID.
73 std::unordered_map<int64, int64> old_row_to_new_row_map;
74 std::vector<bool> row_id_present;
75 const Tensor* row_remapping_t;
76 OP_REQUIRES_OK(context, context->input("row_remapping", &row_remapping_t));
77 const auto row_remapping = row_remapping_t->vec<int64>();
78 OP_REQUIRES(context, row_remapping.size() == num_rows_,
79 errors::InvalidArgument(strings::StrCat(
80 "Size of row_remapping is ", row_remapping.size(),
81 " instead of being equal to num_rows=", num_rows_)));
82 OP_REQUIRES_OK(context, RemapVectorToMap(row_remapping, &row_id_present,
83 &old_row_to_new_row_map));
84
85 // Calculates the min/max old row ID that we need to read, to save us from
86 // reading some unnecessary slices of the old tensor.
87 int64 min_old_row = -1;
88 int64 max_old_row = -1;
89 for (int i = 0; i < row_remapping.size(); ++i) {
90 if (min_old_row < 0 ||
91 (row_remapping(i) >= 0 && row_remapping(i) < min_old_row)) {
92 min_old_row = row_remapping(i);
93 }
94 if (max_old_row < 0 ||
95 (row_remapping(i) >= 0 && row_remapping(i) > max_old_row)) {
96 max_old_row = row_remapping(i);
97 }
98 }
99
100 // Processes the remapping for columns.
101 std::unordered_map<int64, int64> old_col_to_new_col_map;
102 std::vector<bool> col_id_present;
103 const Tensor* col_remapping_t;
104 OP_REQUIRES_OK(context, context->input("col_remapping", &col_remapping_t));
105 const auto col_remapping = col_remapping_t->vec<int64>();
106 // Note that we always "remap rows", even when the row vocabulary does
107 // not change, because partitioning requires a mapping from partitioned
108 // Variables to the full checkpoints we load.
109 const bool remap_cols = col_remapping.size() > 0;
110 if (remap_cols) {
111 OP_REQUIRES(
112 context, col_remapping.size() == num_cols_,
113 errors::InvalidArgument(strings::StrCat(
114 "Provided col_remapping, but its size is ", col_remapping.size(),
115 " instead of being equal to num_cols=", num_cols_)));
116 OP_REQUIRES_OK(context, RemapVectorToMap(col_remapping, &col_id_present,
117 &old_col_to_new_col_map));
118 } else {
119 col_id_present.clear();
120 col_id_present.resize(num_cols_, true);
121 }
122
123 // Processes the checkpoint source and the provided Tensor name.
124 const Tensor* ckpt_path_t;
125 OP_REQUIRES_OK(context, context->input("ckpt_path", &ckpt_path_t));
126 const string ckpt_path = *(ckpt_path_t->scalar<string>().data());
127 const Tensor* old_tensor_name_t;
128 OP_REQUIRES_OK(context,
129 context->input("old_tensor_name", &old_tensor_name_t));
130 const string old_tensor_name =
131 *(old_tensor_name_t->scalar<string>().data());
132
133 LOG(INFO) << "Processing checkpoint : " << ckpt_path;
134 BundleReader reader(context->env(), ckpt_path);
135 OP_REQUIRES_OK(context, reader.status());
136
137 DataType tensor_type;
138 TensorShape tensor_shape;
139 OP_REQUIRES_OK(context, reader.LookupDtypeAndShape(
140 old_tensor_name, &tensor_type, &tensor_shape));
141 OP_REQUIRES(context, tensor_type == DT_FLOAT,
142 errors::InvalidArgument(strings::StrCat(
143 "Tensor ", old_tensor_name, " has invalid type ",
144 DataTypeString(tensor_type), " instead of expected type ",
145 DataTypeString(DT_FLOAT))));
146 // This op is limited to loading Tensors of rank 2 (matrices).
147 OP_REQUIRES(
148 context, tensor_shape.dims() == 2,
149 errors::InvalidArgument(strings::StrCat(
150 "Tensor ", old_tensor_name, " has shape ",
151 tensor_shape.DebugString(), " of invalid rank ",
152 tensor_shape.dims(), " instead of expected shape of rank 2.")));
153
154 if (!remap_cols) {
155 // TODO(weiho): Consider relaxing this restriction to allow partial column
156 // loading (even when no column remapping is specified) if there turns out
157 // to be a use case for it.
158 OP_REQUIRES(context, num_cols_ == tensor_shape.dim_size(1),
159 errors::InvalidArgument(strings::StrCat(
160 "Tensor ", old_tensor_name, " has shape ",
161 tensor_shape.DebugString(),
162 ", where the size of its 2nd dimension is ",
163 tensor_shape.dim_size(1),
164 " instead of being equal to num_cols=", num_cols_)));
165 }
166
167 // Uses TensorSlice to potentially load the old tensor in chunks in case
168 // memory usage is a concern.
169 std::vector<TensorSlice> tensor_slices;
170 TensorSlice slice(tensor_shape.dims());
171 if (min_old_row >= 0 && max_old_row >= 0) {
172 int64 row_start = min_old_row;
173 // TODO(weiho): Given the list of old row IDs of interest (the keys of
174 // old_row_to_new_row_map), we could also try something smarter to
175 // find some minimal set of covering ranges for the list of old row IDs
176 // such that the size of each range is less than max_rows_in_memory_.
177 while (row_start <= max_old_row) {
178 const int64 slice_length =
179 max_rows_in_memory_ <= 0
180 // If max_rows_in_memory_ <= 0, we just load the entire chunk.
181 ? max_old_row - row_start + 1
182 : std::min(max_rows_in_memory_, max_old_row - row_start + 1);
183 slice.set_start(0, row_start);
184 slice.set_length(0, slice_length);
185 tensor_slices.push_back(slice);
186 row_start += slice_length;
187 }
188 }
189
190 // Allocates the output matrix.
191 Tensor* output_matrix_t = nullptr;
192 OP_REQUIRES_OK(context,
193 context->allocate_output("output_matrix",
194 TensorShape({num_rows_, num_cols_}),
195 &output_matrix_t));
196 auto output_matrix = output_matrix_t->matrix<float>();
197
198 // Iterates through tensor slices and copies over values from the old tensor
199 // to the output matrix.
200 int64 row_index = min_old_row;
201 int64 rows_copied = 0;
202 Tensor loaded_tensor_t;
203 for (const TensorSlice& tensor_slice : tensor_slices) {
204 LOG(INFO) << "Loading slice " << tensor_slice.DebugString();
205 TensorShape slice_shape;
206 OP_REQUIRES_OK(context,
207 tensor_slice.SliceTensorShape(tensor_shape, &slice_shape));
208 // Potentially re-allocates the tensor buffer since the last slice may
209 // have fewer rows than the other slices.
210 if (loaded_tensor_t.shape() != slice_shape) {
211 loaded_tensor_t = Tensor(DT_FLOAT, slice_shape);
212 }
213 OP_REQUIRES_OK(context, reader.LookupSlice(old_tensor_name, tensor_slice,
214 &loaded_tensor_t));
215
216 // Iterates through the old loaded tensor slice row-by-row.
217 for (int row = 0; row < loaded_tensor_t.dim_size(0); ++row, ++row_index) {
218 if (row_index % 500000 == min_old_row) {
219 LOG(INFO) << "Processing old row " << row_index;
220 }
221
222 // If the old row ID is not found in old_row_to_new_row_map, continue
223 // to the next row; otherwise, copy it to the output matrix.
224 const int64* new_row_ptr =
225 gtl::FindOrNull(old_row_to_new_row_map, row_index);
226 if (new_row_ptr == nullptr) {
227 continue;
228 }
229 ++rows_copied;
230 const int64 new_row = *new_row_ptr;
231
232 // Copies over the row element-by-element, in case remapping is needed
233 // along the column axis.
234 const auto& loaded_tensor = loaded_tensor_t.matrix<float>();
235 for (int old_col = 0; old_col < loaded_tensor_t.dim_size(1);
236 ++old_col) {
237 int64 new_col = old_col;
238 if (remap_cols) {
239 const int64* new_col_ptr =
240 gtl::FindOrNull(old_col_to_new_col_map, old_col);
241 if (new_col_ptr == nullptr) {
242 // Column remapping is specified, but this column is not found in
243 // old_col_to_new_col_map, so we leave it uninitialized, to be
244 // filled in with initializing_values later.
245 continue;
246 }
247 new_col = *new_col_ptr;
248 }
249
250 OP_REQUIRES(context,
251 new_row < num_rows_ && new_col < num_cols_ &&
252 new_row >= 0 && new_col >= 0,
253 errors::Internal(strings::StrCat(
254 "new_row=", new_row, " and new_col=", new_col,
255 " should have been less than num_rows_=", num_rows_,
256 " and num_cols_=", num_cols_,
257 " and non-negative. This should never have happened "
258 "if the code were correct. Please file a bug.")));
259 output_matrix(new_row, new_col) = loaded_tensor(row, old_col);
260 }
261 }
262 }
263 LOG(INFO) << "Copied " << rows_copied << " rows from old matrix (with "
264 << tensor_shape.dim_size(0) << " rows) to new matrix (with "
265 << num_rows_ << " rows).";
266
267 // At this point, there are potentially whole rows/columns uninitialized
268 // (corresponding to the indices where row_id_present/col_id_present are
269 // false). We fill this in cell-by-cell using row_id_present and
270 // col_id_present while dequeuing from the initializing_values vector.
271 const Tensor* initializing_values_t;
272 OP_REQUIRES_OK(
273 context, context->input("initializing_values", &initializing_values_t));
274 const auto initializing_values = initializing_values_t->flat<float>();
275 int64 initializing_values_index = 0;
276 for (int i = 0; i < num_rows_; ++i) {
277 for (int j = 0; j < num_cols_; ++j) {
278 if (row_id_present[i] && col_id_present[j]) continue;
279 OP_REQUIRES(
280 context, initializing_values_index < initializing_values.size(),
281 errors::InvalidArgument(
282 "initializing_values contained ", initializing_values.size(),
283 " elements, but more missing values remain."));
284 output_matrix(i, j) = initializing_values(initializing_values_index);
285 ++initializing_values_index;
286 }
287 }
288
289 // Checks that we used all the given initializing values.
290 OP_REQUIRES(
291 context, initializing_values_index == initializing_values.size(),
292 errors::InvalidArgument(
293 "initializing_values contained ", initializing_values.size(),
294 " elements, but only ", initializing_values_index,
295 " elements were used to fill in missing values."));
296 }
297
298 private:
299 int64 num_rows_;
300 int64 num_cols_;
301 int64 max_rows_in_memory_;
302 };
303
// Registers LoadAndRemapMatrixOp as the CPU kernel for the
// "LoadAndRemapMatrix" op.
REGISTER_KERNEL_BUILDER(Name("LoadAndRemapMatrix").Device(DEVICE_CPU),
                        LoadAndRemapMatrixOp);
306
307 } // namespace tensorflow
308