1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <string>
18 #include <unordered_map>
19 #include <vector>
20
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/core/framework/kernel_def_builder.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor_types.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
31
32 namespace tensorflow {
33
34 namespace {
35 // Returning a Status instead of using OP_REQUIRES directly since that doesn't
36 // seem to work outside the main OpKernel functions.
RemapVectorToMap(const TTypes<const int64>::Vec & remapping,std::vector<bool> * id_present,std::unordered_map<int64,int64> * old_id_to_new_id)37 Status RemapVectorToMap(const TTypes<const int64>::Vec& remapping,
38 std::vector<bool>* id_present,
39 std::unordered_map<int64, int64>* old_id_to_new_id) {
40 id_present->clear();
41 id_present->resize(remapping.size(), false);
42 for (int i = 0; i < remapping.size(); ++i) {
43 const int64_t old_id = remapping(i);
44 if (old_id < 0) continue;
45 (*id_present)[i] = true;
46 if (!gtl::InsertIfNotPresent(old_id_to_new_id, old_id, i)) {
47 return errors::Unimplemented(
48 strings::StrCat("Old ID ", old_id, " is mapped to both new ID ",
49 old_id_to_new_id->at(old_id), " and ", i,
50 ", which is not supported."));
51 }
52 }
53 return Status::OK();
54 }
55 } // anonymous namespace
56
57 // This op loads a rank-2 Tensor (matrix) from a TensorFlow checkpoint (V2) and
58 // swaps around the rows/columns according to row_remapping/col_remapping.
59 // "Missing" cells are initialized with values from initializing_values.
class LoadAndRemapMatrixOp : public OpKernel {
 public:
  // Reads the geometry of the output matrix (num_rows, num_cols) and the
  // row-chunking cap (max_rows_in_memory; <= 0 disables chunking) from the
  // op's attributes.
  explicit LoadAndRemapMatrixOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_rows", &num_rows_));
    OP_REQUIRES_OK(context, context->GetAttr("num_cols", &num_cols_));
    OP_REQUIRES_OK(
        context, context->GetAttr("max_rows_in_memory", &max_rows_in_memory_));
  }

  // Loads the rank-2 float tensor `old_tensor_name` from the checkpoint at
  // `ckpt_path` (possibly in row slices, bounded by max_rows_in_memory_),
  // copies its cells into the output according to row/col remappings, and
  // fills every remaining cell from `initializing_values` in row-major order.
  void Compute(OpKernelContext* context) override {
    // Checks what we're remapping and inverts the relevant remapping Tensors to
    // be maps with key = old ID, value = new ID.
    std::unordered_map<int64, int64> old_row_to_new_row_map;
    std::vector<bool> row_id_present;
    const Tensor* row_remapping_t;
    OP_REQUIRES_OK(context, context->input("row_remapping", &row_remapping_t));
    const auto row_remapping = row_remapping_t->vec<int64>();
    OP_REQUIRES(context, row_remapping.size() == num_rows_,
                errors::InvalidArgument(strings::StrCat(
                    "Size of row_remapping is ", row_remapping.size(),
                    " instead of being equal to num_rows=", num_rows_)));
    OP_REQUIRES_OK(context, RemapVectorToMap(row_remapping, &row_id_present,
                                             &old_row_to_new_row_map));

    // Calculates the min/max old row ID that we need to read, to save us from
    // reading some unnecessary slices of the old tensor. -1 is the "unset"
    // sentinel; negative remapping entries (rows with no old counterpart) are
    // skipped by the `>= 0` guards once min/max have a non-negative value.
    int64_t min_old_row = -1;
    int64_t max_old_row = -1;
    for (int i = 0; i < row_remapping.size(); ++i) {
      if (min_old_row < 0 ||
          (row_remapping(i) >= 0 && row_remapping(i) < min_old_row)) {
        min_old_row = row_remapping(i);
      }
      if (max_old_row < 0 ||
          (row_remapping(i) >= 0 && row_remapping(i) > max_old_row)) {
        max_old_row = row_remapping(i);
      }
    }

    // Processes the remapping for columns.
    std::unordered_map<int64, int64> old_col_to_new_col_map;
    std::vector<bool> col_id_present;
    const Tensor* col_remapping_t;
    OP_REQUIRES_OK(context, context->input("col_remapping", &col_remapping_t));
    const auto col_remapping = col_remapping_t->vec<int64>();
    // Note that we always "remap rows", even when the row vocabulary does
    // not change, because partitioning requires a mapping from partitioned
    // Variables to the full checkpoints we load.
    const bool remap_cols = col_remapping.size() > 0;
    if (remap_cols) {
      OP_REQUIRES(
          context, col_remapping.size() == num_cols_,
          errors::InvalidArgument(strings::StrCat(
              "Provided col_remapping, but its size is ", col_remapping.size(),
              " instead of being equal to num_cols=", num_cols_)));
      OP_REQUIRES_OK(context, RemapVectorToMap(col_remapping, &col_id_present,
                                               &old_col_to_new_col_map));
    } else {
      // No column remapping: every output column comes from the checkpoint,
      // so mark them all present (no column gets an initializing value).
      col_id_present.clear();
      col_id_present.resize(num_cols_, true);
    }

    // Processes the checkpoint source and the provided Tensor name.
    const Tensor* ckpt_path_t;
    OP_REQUIRES_OK(context, context->input("ckpt_path", &ckpt_path_t));
    OP_REQUIRES(
        context, ckpt_path_t->NumElements() == 1,
        errors::InvalidArgument("The `ckpt_path` tensor must have exactly one "
                                "element, got tensor of shape ",
                                ckpt_path_t->shape().DebugString()));
    const string& ckpt_path = ckpt_path_t->scalar<tstring>()();
    const Tensor* old_tensor_name_t;
    OP_REQUIRES_OK(context,
                   context->input("old_tensor_name", &old_tensor_name_t));
    const string& old_tensor_name = old_tensor_name_t->scalar<tstring>()();

    LOG(INFO) << "Processing checkpoint : " << ckpt_path;
    BundleReader reader(context->env(), ckpt_path);
    OP_REQUIRES_OK(context, reader.status());

    // Validates the checkpoint tensor's dtype (must be float) and rank
    // (must be a matrix) before reading any data.
    DataType tensor_type;
    TensorShape tensor_shape;
    OP_REQUIRES_OK(context, reader.LookupDtypeAndShape(
                                old_tensor_name, &tensor_type, &tensor_shape));
    OP_REQUIRES(context, tensor_type == DT_FLOAT,
                errors::InvalidArgument(strings::StrCat(
                    "Tensor ", old_tensor_name, " has invalid type ",
                    DataTypeString(tensor_type), " instead of expected type ",
                    DataTypeString(DT_FLOAT))));
    // This op is limited to loading Tensors of rank 2 (matrices).
    OP_REQUIRES(
        context, tensor_shape.dims() == 2,
        errors::InvalidArgument(strings::StrCat(
            "Tensor ", old_tensor_name, " has shape ",
            tensor_shape.DebugString(), " of invalid rank ",
            tensor_shape.dims(), " instead of expected shape of rank 2.")));

    if (!remap_cols) {
      // TODO(weiho): Consider relaxing this restriction to allow partial column
      // loading (even when no column remapping is specified) if there turns out
      // to be a use case for it.
      OP_REQUIRES(context, num_cols_ == tensor_shape.dim_size(1),
                  errors::InvalidArgument(strings::StrCat(
                      "Tensor ", old_tensor_name, " has shape ",
                      tensor_shape.DebugString(),
                      ", where the size of its 2nd dimension is ",
                      tensor_shape.dim_size(1),
                      " instead of being equal to num_cols=", num_cols_)));
    }

    // Uses TensorSlice to potentially load the old tensor in chunks in case
    // memory usage is a concern. Each slice spans full columns and at most
    // max_rows_in_memory_ rows of [min_old_row, max_old_row]. If no old row
    // is referenced (min/max stayed -1), no slices are read at all.
    std::vector<TensorSlice> tensor_slices;
    TensorSlice slice(tensor_shape.dims());
    if (min_old_row >= 0 && max_old_row >= 0) {
      int64_t row_start = min_old_row;
      // TODO(weiho): Given the list of old row IDs of interest (the keys of
      // old_row_to_new_row_map), we could also try something smarter to
      // find some minimal set of covering ranges for the list of old row IDs
      // such that the size of each range is less than max_rows_in_memory_.
      while (row_start <= max_old_row) {
        const int64_t slice_length =
            max_rows_in_memory_ <= 0
                // If max_rows_in_memory_ <= 0, we just load the entire chunk.
                ? max_old_row - row_start + 1
                : std::min(max_rows_in_memory_, max_old_row - row_start + 1);
        slice.set_start(0, row_start);
        slice.set_length(0, slice_length);
        tensor_slices.push_back(slice);
        row_start += slice_length;
      }
    }

    // Allocates the output matrix.
    Tensor* output_matrix_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("output_matrix",
                                            TensorShape({num_rows_, num_cols_}),
                                            &output_matrix_t));
    auto output_matrix = output_matrix_t->matrix<float>();

    // Iterates through tensor slices and copies over values from the old tensor
    // to the output matrix. row_index tracks the absolute old row ID across
    // slices, starting at the first slice's first row (min_old_row).
    int64_t row_index = min_old_row;
    int64_t rows_copied = 0;
    Tensor loaded_tensor_t;
    for (const TensorSlice& tensor_slice : tensor_slices) {
      LOG(INFO) << "Loading slice " << tensor_slice.DebugString();
      TensorShape slice_shape;
      OP_REQUIRES_OK(context,
                     tensor_slice.SliceTensorShape(tensor_shape, &slice_shape));
      // Potentially re-allocates the tensor buffer since the last slice may
      // have fewer rows than the other slices.
      if (loaded_tensor_t.shape() != slice_shape) {
        loaded_tensor_t = Tensor(DT_FLOAT, slice_shape);
      }
      OP_REQUIRES_OK(context, reader.LookupSlice(old_tensor_name, tensor_slice,
                                                 &loaded_tensor_t));

      // Iterates through the old loaded tensor slice row-by-row.
      for (int row = 0; row < loaded_tensor_t.dim_size(0); ++row, ++row_index) {
        // NOTE(review): presumably intended as "log progress every 500k rows",
        // but the modulus is compared against min_old_row itself, so this
        // never fires when min_old_row >= 500000 — confirm intent.
        if (row_index % 500000 == min_old_row) {
          LOG(INFO) << "Processing old row " << row_index;
        }

        // If the old row ID is not found in old_row_to_new_row_map, continue
        // to the next row; otherwise, copy it to the output matrix.
        const int64* new_row_ptr =
            gtl::FindOrNull(old_row_to_new_row_map, row_index);
        if (new_row_ptr == nullptr) {
          continue;
        }
        ++rows_copied;
        const int64_t new_row = *new_row_ptr;

        // Copies over the row element-by-element, in case remapping is needed
        // along the column axis.
        const auto& loaded_tensor = loaded_tensor_t.matrix<float>();
        for (int old_col = 0; old_col < loaded_tensor_t.dim_size(1);
             ++old_col) {
          int64_t new_col = old_col;
          if (remap_cols) {
            const int64* new_col_ptr =
                gtl::FindOrNull(old_col_to_new_col_map, old_col);
            if (new_col_ptr == nullptr) {
              // Column remapping is specified, but this column is not found in
              // old_col_to_new_col_map, so we leave it uninitialized, to be
              // filled in with initializing_values later.
              continue;
            }
            new_col = *new_col_ptr;
          }

          // Internal invariant check: RemapVectorToMap only produces new IDs
          // in [0, num_rows_/num_cols_), so a violation here is a code bug,
          // not bad user input.
          OP_REQUIRES(context,
                      new_row < num_rows_ && new_col < num_cols_ &&
                          new_row >= 0 && new_col >= 0,
                      errors::Internal(strings::StrCat(
                          "new_row=", new_row, " and new_col=", new_col,
                          " should have been less than num_rows_=", num_rows_,
                          " and num_cols_=", num_cols_,
                          " and non-negative. This should never have happened "
                          "if the code were correct. Please file a bug.")));
          output_matrix(new_row, new_col) = loaded_tensor(row, old_col);
        }
      }
    }
    LOG(INFO) << "Copied " << rows_copied << " rows from old matrix (with "
              << tensor_shape.dim_size(0) << " rows) to new matrix (with "
              << num_rows_ << " rows).";

    // At this point, there are potentially whole rows/columns uninitialized
    // (corresponding to the indices where row_id_present/col_id_present are
    // false). We fill this in cell-by-cell using row_id_present and
    // col_id_present while dequeuing from the initializing_values vector.
    const Tensor* initializing_values_t;
    OP_REQUIRES_OK(
        context, context->input("initializing_values", &initializing_values_t));
    const auto initializing_values = initializing_values_t->flat<float>();
    int64_t initializing_values_index = 0;
    for (int i = 0; i < num_rows_; ++i) {
      for (int j = 0; j < num_cols_; ++j) {
        // A cell is "present" only when both its row and column were copied
        // from the checkpoint; every other cell consumes the next
        // initializing value, in row-major order.
        if (row_id_present[i] && col_id_present[j]) continue;
        OP_REQUIRES(
            context, initializing_values_index < initializing_values.size(),
            errors::InvalidArgument(
                "initializing_values contained ", initializing_values.size(),
                " elements, but more missing values remain."));
        output_matrix(i, j) = initializing_values(initializing_values_index);
        ++initializing_values_index;
      }
    }

    // Checks that we used all the given initializing values.
    OP_REQUIRES(
        context, initializing_values_index == initializing_values.size(),
        errors::InvalidArgument(
            "initializing_values contained ", initializing_values.size(),
            " elements, but only ", initializing_values_index,
            " elements were used to fill in missing values."));
  }

 private:
  // Number of rows in the output matrix (= size of row_remapping).
  int64 num_rows_;
  // Number of columns in the output matrix.
  int64 num_cols_;
  // Upper bound on rows loaded per checkpoint slice; <= 0 loads all at once.
  int64 max_rows_in_memory_;
};
307
// Registers the kernel for CPU only; the op reads checkpoints from disk, so a
// device-specific implementation is not needed.
REGISTER_KERNEL_BUILDER(Name("LoadAndRemapMatrix").Device(DEVICE_CPU),
                        LoadAndRemapMatrixOp);
310
311 } // namespace tensorflow
312