/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

namespace tensorflow {

namespace {
// Returning a Status instead of using OP_REQUIRES directly since that doesn't
// seem to work outside the main OpKernel functions.
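// Illustrative example of the inversion performed below: remapping = [9, -1, 3]
// yields id_present = {true, false, true} and
// old_id_to_new_id = {9 -> 0, 3 -> 2}; negative ("missing") entries are skipped.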
Status RemapVectorToMap(const TTypes<const int64>::Vec& remapping,
                        std::vector<bool>* id_present,
                        std::unordered_map<int64, int64>* old_id_to_new_id) {
  id_present->clear();
  id_present->resize(remapping.size(), false);
  for (int i = 0; i < remapping.size(); ++i) {
    const int64_t old_id = remapping(i);
    if (old_id < 0) continue;
    (*id_present)[i] = true;
    if (!gtl::InsertIfNotPresent(old_id_to_new_id, old_id, i)) {
      return errors::Unimplemented(
          strings::StrCat("Old ID ", old_id, " is mapped to both new ID ",
                          old_id_to_new_id->at(old_id), " and ", i,
                          ", which is not supported."));
    }
  }
  return Status::OK();
}
}  // anonymous namespace

// This op loads a rank-2 Tensor (matrix) from a TensorFlow checkpoint (V2) and
// swaps around the rows/columns according to row_remapping/col_remapping.
// "Missing" cells are initialized with values from initializing_values.
class LoadAndRemapMatrixOp : public OpKernel {
 public:
  explicit LoadAndRemapMatrixOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_rows", &num_rows_));
    OP_REQUIRES_OK(context, context->GetAttr("num_cols", &num_cols_));
    OP_REQUIRES_OK(
        context, context->GetAttr("max_rows_in_memory", &max_rows_in_memory_));
  }

  void Compute(OpKernelContext* context) override {
    // Checks what we're remapping and inverts the relevant remapping Tensors to
    // be maps with key = old ID, value = new ID.
    std::unordered_map<int64, int64> old_row_to_new_row_map;
    std::vector<bool> row_id_present;
    const Tensor* row_remapping_t;
    OP_REQUIRES_OK(context, context->input("row_remapping", &row_remapping_t));
    const auto row_remapping = row_remapping_t->vec<int64>();
    OP_REQUIRES(context, row_remapping.size() == num_rows_,
                errors::InvalidArgument(strings::StrCat(
                    "Size of row_remapping is ", row_remapping.size(),
                    " instead of being equal to num_rows=", num_rows_)));
    OP_REQUIRES_OK(context, RemapVectorToMap(row_remapping, &row_id_present,
                                             &old_row_to_new_row_map));

    // Calculates the min/max old row ID that we need to read, to save us from
    // reading some unnecessary slices of the old tensor.
    int64_t min_old_row = -1;
    int64_t max_old_row = -1;
    for (int i = 0; i < row_remapping.size(); ++i) {
      if (min_old_row < 0 ||
          (row_remapping(i) >= 0 && row_remapping(i) < min_old_row)) {
        min_old_row = row_remapping(i);
      }
      if (max_old_row < 0 ||
          (row_remapping(i) >= 0 && row_remapping(i) > max_old_row)) {
        max_old_row = row_remapping(i);
      }
    }
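    // For instance, with row_remapping = [6, -1, 2], the loop above yields
    // min_old_row = 2 and max_old_row = 6, so only old rows in [2, 6] need to
    // be read from the checkpoint.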

    // Processes the remapping for columns.
    std::unordered_map<int64, int64> old_col_to_new_col_map;
    std::vector<bool> col_id_present;
    const Tensor* col_remapping_t;
    OP_REQUIRES_OK(context, context->input("col_remapping", &col_remapping_t));
    const auto col_remapping = col_remapping_t->vec<int64>();
    // Note that we always "remap rows", even when the row vocabulary does
    // not change, because partitioning requires a mapping from partitioned
    // Variables to the full checkpoints we load.
    const bool remap_cols = col_remapping.size() > 0;
    if (remap_cols) {
      OP_REQUIRES(
          context, col_remapping.size() == num_cols_,
          errors::InvalidArgument(strings::StrCat(
              "Provided col_remapping, but its size is ", col_remapping.size(),
              " instead of being equal to num_cols=", num_cols_)));
      OP_REQUIRES_OK(context, RemapVectorToMap(col_remapping, &col_id_present,
                                               &old_col_to_new_col_map));
    } else {
      col_id_present.clear();
      col_id_present.resize(num_cols_, true);
    }

    // Processes the checkpoint source and the provided Tensor name.
    const Tensor* ckpt_path_t;
    OP_REQUIRES_OK(context, context->input("ckpt_path", &ckpt_path_t));
    OP_REQUIRES(
        context, ckpt_path_t->NumElements() == 1,
        errors::InvalidArgument("The `ckpt_path` tensor must have exactly one "
                                "element, got tensor of shape ",
                                ckpt_path_t->shape().DebugString()));
    const string& ckpt_path = ckpt_path_t->scalar<tstring>()();
    const Tensor* old_tensor_name_t;
    OP_REQUIRES_OK(context,
                   context->input("old_tensor_name", &old_tensor_name_t));
    const string& old_tensor_name = old_tensor_name_t->scalar<tstring>()();

    LOG(INFO) << "Processing checkpoint : " << ckpt_path;
    BundleReader reader(context->env(), ckpt_path);
    OP_REQUIRES_OK(context, reader.status());

    DataType tensor_type;
    TensorShape tensor_shape;
    OP_REQUIRES_OK(context, reader.LookupDtypeAndShape(
                                old_tensor_name, &tensor_type, &tensor_shape));
    OP_REQUIRES(context, tensor_type == DT_FLOAT,
                errors::InvalidArgument(strings::StrCat(
                    "Tensor ", old_tensor_name, " has invalid type ",
                    DataTypeString(tensor_type), " instead of expected type ",
                    DataTypeString(DT_FLOAT))));
    // This op is limited to loading Tensors of rank 2 (matrices).
    OP_REQUIRES(
        context, tensor_shape.dims() == 2,
        errors::InvalidArgument(strings::StrCat(
            "Tensor ", old_tensor_name, " has shape ",
            tensor_shape.DebugString(), " of invalid rank ",
            tensor_shape.dims(), " instead of expected shape of rank 2.")));

    if (!remap_cols) {
      // TODO(weiho): Consider relaxing this restriction to allow partial column
      // loading (even when no column remapping is specified) if there turns out
      // to be a use case for it.
      OP_REQUIRES(context, num_cols_ == tensor_shape.dim_size(1),
                  errors::InvalidArgument(strings::StrCat(
                      "Tensor ", old_tensor_name, " has shape ",
                      tensor_shape.DebugString(),
                      ", where the size of its 2nd dimension is ",
                      tensor_shape.dim_size(1),
                      " instead of being equal to num_cols=", num_cols_)));
    }

    // Uses TensorSlice to potentially load the old tensor in chunks in case
    // memory usage is a concern.
    std::vector<TensorSlice> tensor_slices;
    TensorSlice slice(tensor_shape.dims());
    if (min_old_row >= 0 && max_old_row >= 0) {
      int64_t row_start = min_old_row;
      // TODO(weiho): Given the list of old row IDs of interest (the keys of
      // old_row_to_new_row_map), we could also try something smarter to
      // find some minimal set of covering ranges for the list of old row IDs
      // such that the size of each range is less than max_rows_in_memory_.
      while (row_start <= max_old_row) {
        const int64_t slice_length =
            max_rows_in_memory_ <= 0
                // If max_rows_in_memory_ <= 0, we just load the entire chunk.
                ? max_old_row - row_start + 1
                : std::min(max_rows_in_memory_, max_old_row - row_start + 1);
        slice.set_start(0, row_start);
        slice.set_length(0, slice_length);
        tensor_slices.push_back(slice);
        row_start += slice_length;
      }
    }
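    // As a rough illustration of the chunking above: with min_old_row = 0,
    // max_old_row = 9, and max_rows_in_memory_ = 4, the slices cover rows
    // [0, 4), [4, 8), and [8, 10); with max_rows_in_memory_ <= 0 a single
    // slice covers rows [0, 10).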

    // Allocates the output matrix.
    Tensor* output_matrix_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("output_matrix",
                                            TensorShape({num_rows_, num_cols_}),
                                            &output_matrix_t));
    auto output_matrix = output_matrix_t->matrix<float>();

    // Iterates through tensor slices and copies over values from the old tensor
    // to the output matrix.
    int64_t row_index = min_old_row;
    int64_t rows_copied = 0;
    Tensor loaded_tensor_t;
    for (const TensorSlice& tensor_slice : tensor_slices) {
      LOG(INFO) << "Loading slice " << tensor_slice.DebugString();
      TensorShape slice_shape;
      OP_REQUIRES_OK(context,
                     tensor_slice.SliceTensorShape(tensor_shape, &slice_shape));
      // Potentially re-allocates the tensor buffer since the last slice may
      // have fewer rows than the other slices.
      if (loaded_tensor_t.shape() != slice_shape) {
        loaded_tensor_t = Tensor(DT_FLOAT, slice_shape);
      }
      OP_REQUIRES_OK(context, reader.LookupSlice(old_tensor_name, tensor_slice,
                                                 &loaded_tensor_t));

      // Iterates through the old loaded tensor slice row-by-row.
      for (int row = 0; row < loaded_tensor_t.dim_size(0); ++row, ++row_index) {
        if (row_index % 500000 == min_old_row) {
          LOG(INFO) << "Processing old row " << row_index;
        }

        // If the old row ID is not found in old_row_to_new_row_map, continue
        // to the next row; otherwise, copy it to the output matrix.
        const int64* new_row_ptr =
            gtl::FindOrNull(old_row_to_new_row_map, row_index);
        if (new_row_ptr == nullptr) {
          continue;
        }
        ++rows_copied;
        const int64_t new_row = *new_row_ptr;

        // Copies over the row element-by-element, in case remapping is needed
        // along the column axis.
        const auto& loaded_tensor = loaded_tensor_t.matrix<float>();
        for (int old_col = 0; old_col < loaded_tensor_t.dim_size(1);
             ++old_col) {
          int64_t new_col = old_col;
          if (remap_cols) {
            const int64* new_col_ptr =
                gtl::FindOrNull(old_col_to_new_col_map, old_col);
            if (new_col_ptr == nullptr) {
              // Column remapping is specified, but this column is not found in
              // old_col_to_new_col_map, so we leave it uninitialized, to be
              // filled in with initializing_values later.
              continue;
            }
            new_col = *new_col_ptr;
          }

          OP_REQUIRES(context,
                      new_row < num_rows_ && new_col < num_cols_ &&
                          new_row >= 0 && new_col >= 0,
                      errors::Internal(strings::StrCat(
                          "new_row=", new_row, " and new_col=", new_col,
                          " should have been less than num_rows_=", num_rows_,
                          " and num_cols_=", num_cols_,
                          " and non-negative. This should never have happened "
                          "if the code were correct. Please file a bug.")));
          output_matrix(new_row, new_col) = loaded_tensor(row, old_col);
        }
      }
    }
    LOG(INFO) << "Copied " << rows_copied << " rows from old matrix (with "
              << tensor_shape.dim_size(0) << " rows) to new matrix (with "
              << num_rows_ << " rows).";

    // At this point, there are potentially whole rows/columns uninitialized
    // (corresponding to the indices where row_id_present/col_id_present are
    // false). We fill this in cell-by-cell using row_id_present and
    // col_id_present while dequeuing from the initializing_values vector.
    const Tensor* initializing_values_t;
    OP_REQUIRES_OK(
        context, context->input("initializing_values", &initializing_values_t));
    const auto initializing_values = initializing_values_t->flat<float>();
    int64_t initializing_values_index = 0;
    for (int i = 0; i < num_rows_; ++i) {
      for (int j = 0; j < num_cols_; ++j) {
        if (row_id_present[i] && col_id_present[j]) continue;
        OP_REQUIRES(
            context, initializing_values_index < initializing_values.size(),
            errors::InvalidArgument(
                "initializing_values contained ", initializing_values.size(),
                " elements, but more missing values remain."));
        output_matrix(i, j) = initializing_values(initializing_values_index);
        ++initializing_values_index;
      }
    }
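    // Note: the fill above walks the output in row-major order, so the k-th
    // missing cell receives initializing_values(k). For example, if only
    // output row 2 is missing in a 2-column matrix, cells (2, 0) and (2, 1)
    // take initializing_values(0) and initializing_values(1) respectively.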

    // Checks that we used all the given initializing values.
    OP_REQUIRES(
        context, initializing_values_index == initializing_values.size(),
        errors::InvalidArgument(
            "initializing_values contained ", initializing_values.size(),
            " elements, but only ", initializing_values_index,
            " elements were used to fill in missing values."));
  }

 private:
  int64 num_rows_;
  int64 num_cols_;
  int64 max_rows_in_memory_;
};

REGISTER_KERNEL_BUILDER(Name("LoadAndRemapMatrix").Device(DEVICE_CPU),
                        LoadAndRemapMatrixOp);

}  // namespace tensorflow