• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <string>
18 #include <unordered_map>
19 #include <vector>
20 
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/core/framework/kernel_def_builder.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor_types.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
31 
32 namespace tensorflow {
33 
34 namespace {
35 // Returning a Status instead of using OP_REQUIRES directly since that doesn't
36 // seem to work outside the main OpKernel functions.
RemapVectorToMap(const TTypes<const int64>::Vec & remapping,std::vector<bool> * id_present,std::unordered_map<int64,int64> * old_id_to_new_id)37 Status RemapVectorToMap(const TTypes<const int64>::Vec& remapping,
38                         std::vector<bool>* id_present,
39                         std::unordered_map<int64, int64>* old_id_to_new_id) {
40   id_present->clear();
41   id_present->resize(remapping.size(), false);
42   for (int i = 0; i < remapping.size(); ++i) {
43     const int64 old_id = remapping(i);
44     if (old_id < 0) continue;
45     (*id_present)[i] = true;
46     if (!gtl::InsertIfNotPresent(old_id_to_new_id, old_id, i)) {
47       return errors::Unimplemented(
48           strings::StrCat("Old ID ", old_id, " is mapped to both new ID ",
49                           old_id_to_new_id->at(old_id), " and ", i,
50                           ", which is not supported."));
51     }
52   }
53   return Status::OK();
54 }
55 }  // anonymous namespace
56 
57 // This op loads a rank-2 Tensor (matrix) from a TensorFlow checkpoint (V2) and
58 // swaps around the rows/columns according to row_remapping/col_remapping.
59 // "Missing" cells are initialized with values from initializing_values.
60 class LoadAndRemapMatrixOp : public OpKernel {
61  public:
LoadAndRemapMatrixOp(OpKernelConstruction * context)62   explicit LoadAndRemapMatrixOp(OpKernelConstruction* context)
63       : OpKernel(context) {
64     OP_REQUIRES_OK(context, context->GetAttr("num_rows", &num_rows_));
65     OP_REQUIRES_OK(context, context->GetAttr("num_cols", &num_cols_));
66     OP_REQUIRES_OK(
67         context, context->GetAttr("max_rows_in_memory", &max_rows_in_memory_));
68   }
69 
Compute(OpKernelContext * context)70   void Compute(OpKernelContext* context) override {
71     // Checks what we're remapping and inverts the relevant remapping Tensors to
72     // be maps with key = old ID, value = new ID.
73     std::unordered_map<int64, int64> old_row_to_new_row_map;
74     std::vector<bool> row_id_present;
75     const Tensor* row_remapping_t;
76     OP_REQUIRES_OK(context, context->input("row_remapping", &row_remapping_t));
77     const auto row_remapping = row_remapping_t->vec<int64>();
78     OP_REQUIRES(context, row_remapping.size() == num_rows_,
79                 errors::InvalidArgument(strings::StrCat(
80                     "Size of row_remapping is ", row_remapping.size(),
81                     " instead of being equal to num_rows=", num_rows_)));
82     OP_REQUIRES_OK(context, RemapVectorToMap(row_remapping, &row_id_present,
83                                              &old_row_to_new_row_map));
84 
85     // Calculates the min/max old row ID that we need to read, to save us from
86     // reading some unnecessary slices of the old tensor.
87     int64 min_old_row = -1;
88     int64 max_old_row = -1;
89     for (int i = 0; i < row_remapping.size(); ++i) {
90       if (min_old_row < 0 ||
91           (row_remapping(i) >= 0 && row_remapping(i) < min_old_row)) {
92         min_old_row = row_remapping(i);
93       }
94       if (max_old_row < 0 ||
95           (row_remapping(i) >= 0 && row_remapping(i) > max_old_row)) {
96         max_old_row = row_remapping(i);
97       }
98     }
99 
100     // Processes the remapping for columns.
101     std::unordered_map<int64, int64> old_col_to_new_col_map;
102     std::vector<bool> col_id_present;
103     const Tensor* col_remapping_t;
104     OP_REQUIRES_OK(context, context->input("col_remapping", &col_remapping_t));
105     const auto col_remapping = col_remapping_t->vec<int64>();
106     // Note that we always "remap rows", even when the row vocabulary does
107     // not change, because partitioning requires a mapping from partitioned
108     // Variables to the full checkpoints we load.
109     const bool remap_cols = col_remapping.size() > 0;
110     if (remap_cols) {
111       OP_REQUIRES(
112           context, col_remapping.size() == num_cols_,
113           errors::InvalidArgument(strings::StrCat(
114               "Provided col_remapping, but its size is ", col_remapping.size(),
115               " instead of being equal to num_cols=", num_cols_)));
116       OP_REQUIRES_OK(context, RemapVectorToMap(col_remapping, &col_id_present,
117                                                &old_col_to_new_col_map));
118     } else {
119       col_id_present.clear();
120       col_id_present.resize(num_cols_, true);
121     }
122 
123     // Processes the checkpoint source and the provided Tensor name.
124     const Tensor* ckpt_path_t;
125     OP_REQUIRES_OK(context, context->input("ckpt_path", &ckpt_path_t));
126     const string& ckpt_path = ckpt_path_t->scalar<tstring>()();
127     const Tensor* old_tensor_name_t;
128     OP_REQUIRES_OK(context,
129                    context->input("old_tensor_name", &old_tensor_name_t));
130     const string& old_tensor_name = old_tensor_name_t->scalar<tstring>()();
131 
132     LOG(INFO) << "Processing checkpoint : " << ckpt_path;
133     BundleReader reader(context->env(), ckpt_path);
134     OP_REQUIRES_OK(context, reader.status());
135 
136     DataType tensor_type;
137     TensorShape tensor_shape;
138     OP_REQUIRES_OK(context, reader.LookupDtypeAndShape(
139                                 old_tensor_name, &tensor_type, &tensor_shape));
140     OP_REQUIRES(context, tensor_type == DT_FLOAT,
141                 errors::InvalidArgument(strings::StrCat(
142                     "Tensor ", old_tensor_name, " has invalid type ",
143                     DataTypeString(tensor_type), " instead of expected type ",
144                     DataTypeString(DT_FLOAT))));
145     // This op is limited to loading Tensors of rank 2 (matrices).
146     OP_REQUIRES(
147         context, tensor_shape.dims() == 2,
148         errors::InvalidArgument(strings::StrCat(
149             "Tensor ", old_tensor_name, " has shape ",
150             tensor_shape.DebugString(), " of invalid rank ",
151             tensor_shape.dims(), " instead of expected shape of rank 2.")));
152 
153     if (!remap_cols) {
154       // TODO(weiho): Consider relaxing this restriction to allow partial column
155       // loading (even when no column remapping is specified) if there turns out
156       // to be a use case for it.
157       OP_REQUIRES(context, num_cols_ == tensor_shape.dim_size(1),
158                   errors::InvalidArgument(strings::StrCat(
159                       "Tensor ", old_tensor_name, " has shape ",
160                       tensor_shape.DebugString(),
161                       ", where the size of its 2nd dimension is ",
162                       tensor_shape.dim_size(1),
163                       " instead of being equal to num_cols=", num_cols_)));
164     }
165 
166     // Uses TensorSlice to potentially load the old tensor in chunks in case
167     // memory usage is a concern.
168     std::vector<TensorSlice> tensor_slices;
169     TensorSlice slice(tensor_shape.dims());
170     if (min_old_row >= 0 && max_old_row >= 0) {
171       int64 row_start = min_old_row;
172       // TODO(weiho): Given the list of old row IDs of interest (the keys of
173       // old_row_to_new_row_map), we could also try something smarter to
174       // find some minimal set of covering ranges for the list of old row IDs
175       // such that the size of each range is less than max_rows_in_memory_.
176       while (row_start <= max_old_row) {
177         const int64 slice_length =
178             max_rows_in_memory_ <= 0
179                 // If max_rows_in_memory_ <= 0, we just load the entire chunk.
180                 ? max_old_row - row_start + 1
181                 : std::min(max_rows_in_memory_, max_old_row - row_start + 1);
182         slice.set_start(0, row_start);
183         slice.set_length(0, slice_length);
184         tensor_slices.push_back(slice);
185         row_start += slice_length;
186       }
187     }
188 
189     // Allocates the output matrix.
190     Tensor* output_matrix_t = nullptr;
191     OP_REQUIRES_OK(context,
192                    context->allocate_output("output_matrix",
193                                             TensorShape({num_rows_, num_cols_}),
194                                             &output_matrix_t));
195     auto output_matrix = output_matrix_t->matrix<float>();
196 
197     // Iterates through tensor slices and copies over values from the old tensor
198     // to the output matrix.
199     int64 row_index = min_old_row;
200     int64 rows_copied = 0;
201     Tensor loaded_tensor_t;
202     for (const TensorSlice& tensor_slice : tensor_slices) {
203       LOG(INFO) << "Loading slice " << tensor_slice.DebugString();
204       TensorShape slice_shape;
205       OP_REQUIRES_OK(context,
206                      tensor_slice.SliceTensorShape(tensor_shape, &slice_shape));
207       // Potentially re-allocates the tensor buffer since the last slice may
208       // have fewer rows than the other slices.
209       if (loaded_tensor_t.shape() != slice_shape) {
210         loaded_tensor_t = Tensor(DT_FLOAT, slice_shape);
211       }
212       OP_REQUIRES_OK(context, reader.LookupSlice(old_tensor_name, tensor_slice,
213                                                  &loaded_tensor_t));
214 
215       // Iterates through the old loaded tensor slice row-by-row.
216       for (int row = 0; row < loaded_tensor_t.dim_size(0); ++row, ++row_index) {
217         if (row_index % 500000 == min_old_row) {
218           LOG(INFO) << "Processing old row " << row_index;
219         }
220 
221         // If the old row ID is not found in old_row_to_new_row_map, continue
222         // to the next row; otherwise, copy it to the output matrix.
223         const int64* new_row_ptr =
224             gtl::FindOrNull(old_row_to_new_row_map, row_index);
225         if (new_row_ptr == nullptr) {
226           continue;
227         }
228         ++rows_copied;
229         const int64 new_row = *new_row_ptr;
230 
231         // Copies over the row element-by-element, in case remapping is needed
232         // along the column axis.
233         const auto& loaded_tensor = loaded_tensor_t.matrix<float>();
234         for (int old_col = 0; old_col < loaded_tensor_t.dim_size(1);
235              ++old_col) {
236           int64 new_col = old_col;
237           if (remap_cols) {
238             const int64* new_col_ptr =
239                 gtl::FindOrNull(old_col_to_new_col_map, old_col);
240             if (new_col_ptr == nullptr) {
241               // Column remapping is specified, but this column is not found in
242               // old_col_to_new_col_map, so we leave it uninitialized, to be
243               // filled in with initializing_values later.
244               continue;
245             }
246             new_col = *new_col_ptr;
247           }
248 
249           OP_REQUIRES(context,
250                       new_row < num_rows_ && new_col < num_cols_ &&
251                           new_row >= 0 && new_col >= 0,
252                       errors::Internal(strings::StrCat(
253                           "new_row=", new_row, " and new_col=", new_col,
254                           " should have been less than num_rows_=", num_rows_,
255                           " and num_cols_=", num_cols_,
256                           " and non-negative. This should never have happened "
257                           "if the code were correct. Please file a bug.")));
258           output_matrix(new_row, new_col) = loaded_tensor(row, old_col);
259         }
260       }
261     }
262     LOG(INFO) << "Copied " << rows_copied << " rows from old matrix (with "
263               << tensor_shape.dim_size(0) << " rows) to new matrix (with "
264               << num_rows_ << " rows).";
265 
266     // At this point, there are potentially whole rows/columns uninitialized
267     // (corresponding to the indices where row_id_present/col_id_present are
268     // false). We fill this in cell-by-cell using row_id_present and
269     // col_id_present while dequeuing from the initializing_values vector.
270     const Tensor* initializing_values_t;
271     OP_REQUIRES_OK(
272         context, context->input("initializing_values", &initializing_values_t));
273     const auto initializing_values = initializing_values_t->flat<float>();
274     int64 initializing_values_index = 0;
275     for (int i = 0; i < num_rows_; ++i) {
276       for (int j = 0; j < num_cols_; ++j) {
277         if (row_id_present[i] && col_id_present[j]) continue;
278         OP_REQUIRES(
279             context, initializing_values_index < initializing_values.size(),
280             errors::InvalidArgument(
281                 "initializing_values contained ", initializing_values.size(),
282                 " elements, but more missing values remain."));
283         output_matrix(i, j) = initializing_values(initializing_values_index);
284         ++initializing_values_index;
285       }
286     }
287 
288     // Checks that we used all the given initializing values.
289     OP_REQUIRES(
290         context, initializing_values_index == initializing_values.size(),
291         errors::InvalidArgument(
292             "initializing_values contained ", initializing_values.size(),
293             " elements, but only ", initializing_values_index,
294             " elements were used to fill in missing values."));
295   }
296 
297  private:
298   int64 num_rows_;
299   int64 num_cols_;
300   int64 max_rows_in_memory_;
301 };
302 
// Registers LoadAndRemapMatrixOp as the kernel for the "LoadAndRemapMatrix" op
// on DEVICE_CPU.
303 REGISTER_KERNEL_BUILDER(Name("LoadAndRemapMatrix").Device(DEVICE_CPU),
304                         LoadAndRemapMatrixOp);
305 
306 }  // namespace tensorflow
307