/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <memory>
#include <unordered_map>

#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
using str_util::Join;
using str_util::Split;
using str_util::StringReplace;
using strings::StrCat;

namespace graph_transforms {

// Sparsifies a Tensor of shape [N, 1]. Returns the indices and values vectors
// for the non-zero tensor content.
Status SparsifyWeights(const Tensor& tensor, Tensor* indices_tensor,
                       Tensor* values_tensor) {
  if (tensor.dims() != 2 || tensor.dim_size(1) != 1) {
    return tensorflow::errors::FailedPrecondition(
        "Transform is only applicable to subgraphs with a 'Const' "
        "tensor of shape [N, 1]. Instead, got shape ",
        tensor.shape().DebugString(), ".");
  }

  auto flat = tensor.flat<float>();
  std::vector<int64> indices;
  std::vector<float> values;

  for (int64 i = 0; i < flat.size(); i++) {
    float val = flat(i);
    if (std::abs(val) >= 1.0e-5) {
      indices.push_back(i);
      values.push_back(val);
    }
  }

  // During model initialization, InitializeTableOp makes use of
  // KeyValueTensorIterator, which does not accept empty keys or values.
  // Consequently, add a dummy pair of indices and values as a workaround.
  if (indices.empty() || values.empty()) {
    indices.push_back(0);
    values.push_back(0);
  }
  *indices_tensor = Tensor(DataTypeToEnum<int64>::value,
                           {static_cast<int64>(indices.size())});
  std::copy_n(indices.begin(), indices.size(),
              indices_tensor->flat<int64>().data());

  *values_tensor =
      Tensor(DataTypeToEnum<float>::value, {static_cast<int64>(values.size())});
  std::copy_n(values.begin(), values.size(),
              values_tensor->flat<float>().data());

  return Status::OK();
}

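// Fills `node_def` with a "Const" op named `name` whose value attribute holds
// `tensor`.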
void CreateConstNode(const Tensor& tensor, const string& name,
                     NodeDef* node_def) {
  node_def->set_op("Const");
  node_def->set_name(name);
  SetNodeTensorAttr<float>("value", tensor, node_def);
}

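// Maps the name of a partitioned variable slice (e.g. "w/part_0") to the key
// of the monolithic tensor in the checkpoint (e.g. "w") by stripping a
// trailing "part_*" path component, if present.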
string GetMonolithicTensorKey(const string& tensor_slice_name) {
  std::vector<string> names = Split(tensor_slice_name, "/");
  if (str_util::StartsWith(names[names.size() - 1], "part_")) {
    CHECK_GE(names.size(), 2);
    names.pop_back();
  }
  return Join(names, "/");
}

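// Retrieves the shape-and-slice string for `target_name` by walking the
// graph's save/restore subgraph: it finds the "save*/Assign*" node that
// writes to the variable, follows it back to the producing "RestoreV2" node,
// and returns the matching entry of that node's shape_and_slices input.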
Status ObtainTensorSlice(const GraphDef& input_graph_def,
                         const string& target_name,
                         string* shape_slice_string) {
  string restore_node_name;
  for (const auto& node : input_graph_def.node()) {
    std::vector<string> node_name_parts = Split(node.name(), "/");
    if (node_name_parts.size() == 2 &&
        str_util::StartsWith(node_name_parts[0], "save") &&
        str_util::StartsWith(node_name_parts[1], "Assign") &&
        node.input(0) == target_name) {
      restore_node_name = node.input(1);
      break;
    }
  }

  std::vector<string> restore_node_parts = Split(restore_node_name, ":");
  CHECK_LE(restore_node_parts.size(), 2);
  string tensor_names_node;
  string shape_and_slices_node;
  for (const auto& node : input_graph_def.node()) {
    if ((node.name() == restore_node_parts[0]) && (node.op() == "RestoreV2")) {
      tensor_names_node = node.input(1);
      shape_and_slices_node = node.input(2);
      break;
    }
  }

  int offset = -1;
  for (const auto& node : input_graph_def.node()) {
    if (node.name() == tensor_names_node) {
      Tensor tensor_names_tensor;
      TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor_names_tensor));
      const auto& tensor_names_value = tensor_names_tensor.flat<string>();
      for (int i = 0; i < tensor_names_value.size(); i++) {
        if (tensor_names_value(i) == GetMonolithicTensorKey(target_name)) {
          offset = i;
          break;
        }
      }
    }
  }
  if (offset == -1) {
    return errors::Internal("Unable to find RestoreV2 entry for variable: ",
                            target_name);
  }
  for (const auto& node : input_graph_def.node()) {
    if (node.name() == shape_and_slices_node) {
      Tensor shape_and_slices_tensor;
      TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &shape_and_slices_tensor));
      const auto& shape_and_slices_value =
          shape_and_slices_tensor.flat<string>();
      *shape_slice_string = shape_and_slices_value(offset);
      return Status::OK();
    }
  }
  return errors::Internal("Unable to find slice for variable: ", target_name);
}

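// Reads `tensor_name` from the checkpoint. If `shape_and_slice` is non-empty
// and describes a proper slice of the full tensor, only that slice is looked
// up; otherwise the monolithic tensor is read.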
Status ReadTensorFromCheckpoint(
    const string& tensor_name, const std::unique_ptr<BundleReader>& ckpt_reader,
    const string& shape_and_slice, Tensor* tensor) {
  if (ckpt_reader) {
    TensorShape parsed_full_shape;
    TensorSlice parsed_slice;
    TensorShape parsed_slice_shape;

    bool get_slice = false;
    if (!shape_and_slice.empty()) {
      TF_RETURN_IF_ERROR(
          checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
                                         &parsed_slice, &parsed_slice_shape));
      get_slice = (parsed_full_shape != parsed_slice_shape);
    }
    if (get_slice) {
      TF_RETURN_IF_ERROR(ckpt_reader->LookupSlice(
          GetMonolithicTensorKey(tensor_name), parsed_slice, tensor));
    } else {
      TF_RETURN_IF_ERROR(
          ckpt_reader->Lookup(GetMonolithicTensorKey(tensor_name), tensor));
    }
    return Status::OK();
  }
  return errors::Internal("Checkpoint reader was not initialized.");
}

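// Opens a BundleReader over the checkpoint named by the transform's
// "input_checkpoint" parameter, if one was supplied.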
Status InitializeCheckpointReader(const TransformFuncContext& context,
                                  std::unique_ptr<BundleReader>* ckpt_reader) {
  if (context.params.count("input_checkpoint")) {
    const string input_checkpoint = context.params.at("input_checkpoint")[0];
    ckpt_reader->reset(new BundleReader(Env::Default(), input_checkpoint));
    TF_RETURN_IF_ERROR((*ckpt_reader)->status());
  }
  return Status::OK();
}

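// Records the shape-and-slice string of every "Variable"/"VariableV2" node in
// the graph, keyed by node name.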
Status ObtainVariableInfo(
    const GraphDef& input_graph_def,
    std::unique_ptr<std::unordered_map<string, string> >* shapes_and_slices) {
  shapes_and_slices->reset(new std::unordered_map<string, string>());
  for (const auto& node : input_graph_def.node()) {
    if ((node.op() == "Variable") || (node.op() == "VariableV2")) {
      string s;
      TF_RETURN_IF_ERROR(ObtainTensorSlice(input_graph_def, node.name(), &s));
      (**shapes_and_slices)[node.name()] = s;
    }
  }
  return Status::OK();
}

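// Removes the input at `index` from `n`, preserving the relative order of the
// remaining inputs by bubbling the element to the end before dropping it.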
Status RemoveInputAtIndex(NodeDef* n, int index) {
  for (int i = index; i < n->input_size() - 1; i++) {
    n->mutable_input()->SwapElements(i, i + 1);
  }
  n->mutable_input()->RemoveLast();
  return Status::OK();
}

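// Removes the node at `index` from `g`, preserving the order of the remaining
// nodes in the same bubble-to-the-end fashion.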
Status RemoveNodeAtIndex(GraphDef* g, int index) {
  for (int i = index; i < g->node_size() - 1; i++) {
    g->mutable_node()->SwapElements(i, i + 1);
  }
  g->mutable_node()->RemoveLast();
  return Status::OK();
}

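// Runs repeated sparsification passes for `pattern`: each match of a
// (Const|Variable) -> Identity -> Gather subgraph is replaced with a
// HashTable-backed lookup, and nodes left without any remaining references
// are pruned afterwards.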
Status SparsifyGatherInternal(
    const GraphDef& input_graph_def,
    const std::unique_ptr<std::unordered_map<string, string> >&
        shapes_and_slices,
    const TransformFuncContext& context, const OpTypePattern& pattern,
    const std::unique_ptr<BundleReader>& ckpt_reader,
    GraphDef* output_graph_def) {
  string group_init_node = "group_deps";
  if (context.params.count("group_init_node")) {
    group_init_node = context.params.at("group_init_node")[0];
  }
  GraphDef current_graph_def = input_graph_def;
  bool any_match_found = false;

  // Populate references.
  std::unordered_map<string, int> refs;
  for (const auto& node : current_graph_def.node()) {
    for (const auto& input : node.input()) {
      auto parsed_input = StringReplace(input, "^", "", true);
      refs[parsed_input] += 1;
    }
  }

  // The subgraphs may have overlapping components, so GraphMatcher doesn't
  // return all of them in one round; the update has to run over multiple
  // rounds until no more matches are found.
  do {
    any_match_found = false;
    GraphDef replaced_graph_def = current_graph_def;
    std::vector<string> init_table_node_names;
    std::vector<string> removed_node_names;

    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def, pattern,
        [&ckpt_reader, &any_match_found, &init_table_node_names,
         &shapes_and_slices, &removed_node_names,
         &refs](const NodeMatch& match, const std::set<string>& input_nodes,
                const std::set<string>& output_nodes,
                std::vector<NodeDef>* new_nodes) {
          any_match_found = true;

          // The captured subgraph should be of the following pattern:
          //
          //   Const --> Identity --> Gather --> ...
          //                            ^
          //                            |
          //                          (ids)
          //
          // After transform, it becomes:
          //
          //                    --> NoOp(group_deps)
          //                    |
          //   Const --> InitializeTable --> HashTable
          //                    ^                |
          //                    |                |
          //   Const ------------                |
          //                                     v
          //   (ids) ---> LookupTableFind <--- Const(default)
          //                      |
          //                      v
          //                     ...

          // clang-format off
          // For each subgraph, do the following:
          // 1. Sparsify the `Const`, creating two `Const`s for the hashtable
          //    keys/values.
          // 2. Create an `InitializeTable` op connecting to the above two
          //    `Const`s.
          // 3. Create a `HashTable` op connecting to the `InitializeTable` op.
          // 4. Replace the `Gather` with a `LookupTableFind` op.
          // 5. Connect the `LookupTableFind` with
          //    a. `HashTable`
          //    b. `Gather`'s ids input
          //    c. a `default_val` arg, valued at 0
          // clang-format on
          const NodeDef& gather_node = match.node;

          // GatherV2 adds an "axis" parameter. sparsify_gather only supports
          // axis 0 gathers.
          if (gather_node.op() == "GatherV2") {
            // Per the OpTypePattern, the 3rd input to Gather must be a Const.
            const NodeDef& axis_node = match.inputs[2].node;

            Tensor axis_t;
            TF_RETURN_IF_ERROR(GetNodeAttr(axis_node, "value", &axis_t));
            int64 axis = 0;
            if (axis_t.dtype() == DT_INT32) {
              axis = axis_t.scalar<int32>()();
            } else if (axis_t.dtype() == DT_INT64) {
              axis = axis_t.scalar<int64>()();
            } else {
              return tensorflow::errors::FailedPrecondition(
                  "Gather axis was not int32 or int64.");
            }

            if (axis != 0) {
              return tensorflow::errors::FailedPrecondition(
                  "Transform only applicable to subgraph with GatherV2 over "
                  "axis 0. Found axis ",
                  axis, ".");
            }
          }

          const NodeDef& weights_node = match.inputs[0].inputs[0].node;

          DataType data_type;
          TF_RETURN_IF_ERROR(GetNodeAttr(weights_node, "dtype", &data_type));
          if (data_type != DT_FLOAT) {
            return tensorflow::errors::FailedPrecondition(
                "Transform only applicable to subgraph with 'Const', "
                "'Variable', or 'VariableV2' of dtype "
                "'DT_FLOAT'. Found '" +
                    weights_node.op() + "' with name '",
                weights_node.name(), "' and dtype '", data_type, "'.");
          }

          Tensor weight;
          if (weights_node.op() == "Const") {
            weight = GetNodeTensorAttr(weights_node, "value");
          } else {
            TF_RETURN_IF_ERROR(ReadTensorFromCheckpoint(
                weights_node.name(), ckpt_reader,
                (*shapes_and_slices)[weights_node.name()], &weight));
          }
          // Add both the weight and identity node names.
          removed_node_names.push_back(weights_node.name());
          removed_node_names.push_back(match.inputs[0].node.name());
          for (auto input_node : match.inputs[0].node.input()) {
            auto parsed_input = StringReplace(input_node, "^", "", true);
            refs[parsed_input]--;
          }
          Tensor indices_tensor;
          Tensor values_tensor;
          TF_RETURN_IF_ERROR(
              SparsifyWeights(weight, &indices_tensor, &values_tensor));

          // Indices and values of the sparsified `Const`.
          DataType key_dtype = DT_INT64;
          NodeDef indices_node;
          CreateConstNode(indices_tensor,
                          StrCat(weights_node.name(), "/indices"),
                          &indices_node);
          SetNodeAttr("dtype", key_dtype, &indices_node);

          NodeDef values_node;
          CreateConstNode(values_tensor, StrCat(weights_node.name(), "/values"),
                          &values_node);
          SetNodeAttr("dtype", data_type, &values_node);

          // HashTable node.
          NodeDef hashtable_node;
          hashtable_node.set_op("HashTable");
          hashtable_node.set_name(StrCat(weights_node.name(), "/HashTable"));
          SetNodeAttr("key_dtype", key_dtype, &hashtable_node);
          SetNodeAttr("value_dtype", data_type, &hashtable_node);

          // InitializeTable node.
          NodeDef init_table_node;
          init_table_node.set_op("InitializeTable");
          init_table_node.set_name(
              StrCat(weights_node.name(), "/InitializeTable"));
          SetNodeAttr("Tkey", key_dtype, &init_table_node);
          SetNodeAttr("Tval", data_type, &init_table_node);
          init_table_node_names.push_back(init_table_node.name());

          // LookupTableFind node.
          NodeDef lookup_node;
          lookup_node.set_op("LookupTableFind");
          lookup_node.set_name(StrCat(gather_node.name(), "/LookupTableFind"));
          SetNodeAttr("Tin", key_dtype, &lookup_node);
          SetNodeAttr("Tout", data_type, &lookup_node);

          // Default return value of hashtable lookup.
          Tensor zero_tensor(data_type, TensorShape({}));
          zero_tensor.flat<float>()(0) = 0.0;
          NodeDef default_value_node;
          CreateConstNode(zero_tensor, StrCat(gather_node.name(), "/Const"),
                          &default_value_node);
          SetNodeAttr("dtype", data_type, &default_value_node);

          // ExpandDims argument.
          Tensor dim_idx(DT_INT32, TensorShape({}));
          dim_idx.flat<int32>()(0) = -1;
          NodeDef dim_idx_node;
          dim_idx_node.set_op("Const");
          dim_idx_node.set_name(
              StrCat(gather_node.name(), "/ExpandDims/Const"));
          SetNodeAttr("value", dim_idx, &dim_idx_node);
          SetNodeAttr("dtype", DT_INT32, &dim_idx_node);

          // ExpandDims node.
          NodeDef expand_dims_node;
          expand_dims_node.set_op("ExpandDims");
          // Reuse gather_node's name so as not to change dependents' inputs.
          expand_dims_node.set_name(gather_node.name());
          SetNodeAttr("T", data_type, &expand_dims_node);

          // Connect nodes.
          AddNodeInput(hashtable_node.name(), &init_table_node);
          refs[hashtable_node.name()]++;
          AddNodeInput(indices_node.name(), &init_table_node);
          refs[indices_node.name()]++;
          AddNodeInput(values_node.name(), &init_table_node);
          refs[values_node.name()]++;

          AddNodeInput(hashtable_node.name(), &lookup_node);
          refs[hashtable_node.name()]++;
          AddNodeInput(gather_node.input(1), &lookup_node);
          refs[gather_node.input(1)]++;
          AddNodeInput(default_value_node.name(), &lookup_node);
          refs[default_value_node.name()]++;

          AddNodeInput(lookup_node.name(), &expand_dims_node);
          refs[lookup_node.name()]++;
          AddNodeInput(dim_idx_node.name(), &expand_dims_node);
          refs[dim_idx_node.name()]++;

          // Copy the 'ids' input of the original 'Gather'.
          new_nodes->push_back(match.inputs[1].node);
          new_nodes->push_back(indices_node);
          new_nodes->push_back(values_node);
          new_nodes->push_back(hashtable_node);
          new_nodes->push_back(init_table_node);
          new_nodes->push_back(lookup_node);
          new_nodes->push_back(default_value_node);
          new_nodes->push_back(dim_idx_node);
          new_nodes->push_back(expand_dims_node);

          return Status::OK();
        },
        {true}, &replaced_graph_def));

    NodeDef* init_op = nullptr;
    for (int i = 0; i < replaced_graph_def.node_size(); i++) {
      if (replaced_graph_def.node(i).name() == group_init_node &&
          replaced_graph_def.node(i).op() == "NoOp") {
        init_op = replaced_graph_def.mutable_node(i);
        break;
      }
    }
    if (!init_op) {
      // No init node found; create one.
      init_op = replaced_graph_def.mutable_node()->Add();
      init_op->set_op("NoOp");
      init_op->set_name(group_init_node);
    }
    for (const string& name : init_table_node_names) {
      // Add a control dependency from the init_table node to the group_deps
      // node.
      AddNodeInput(StrCat("^", name), init_op);
      refs[name]++;
    }

    // Erase inputs and outputs, as they are not considered for deletion.
    for (const auto& output : context.output_names) {
      refs.erase(output);
    }

    for (const auto& input : context.input_names) {
      refs.erase(input);
    }

    // Add nodes with a reference count of 0 for deletion.
    for (auto entry : refs) {
      if (entry.second == 0) {
        removed_node_names.push_back(entry.first);
      }
    }

    while (!removed_node_names.empty()) {
      auto name = removed_node_names.back();
      removed_node_names.pop_back();

      int i = 0;
      while (i < replaced_graph_def.node_size()) {
        // Revisit this to see if we can safely remove RestoreV2 nodes.
        if ((replaced_graph_def.node(i).name() == name) &&
            (replaced_graph_def.node(i).op() != "RestoreV2")) {
          for (const auto& input : replaced_graph_def.node(i).input()) {
            auto parsed_input = StringReplace(input, "^", "", true);
            refs[parsed_input] -= 1;
            if (refs[parsed_input] == 0) {
              removed_node_names.push_back(parsed_input);
            }
          }
          TF_RETURN_IF_ERROR(RemoveNodeAtIndex(&replaced_graph_def, i));
          continue;
        }
        int j = 0;
        bool deleted_inputs = false;
        while (j < replaced_graph_def.node(i).input_size()) {
          if (replaced_graph_def.node(i).input(j) == name ||
              replaced_graph_def.node(i).input(j) == ("^" + name)) {
            TF_RETURN_IF_ERROR(
                RemoveInputAtIndex(replaced_graph_def.mutable_node(i), j));
            deleted_inputs = true;
            continue;
          }
          j++;
        }
        if (deleted_inputs) {
          if (replaced_graph_def.node(i).op() == "ConcatV2") {
            if (replaced_graph_def.node(i).input_size() > 2) {
              SetNodeAttr("N", replaced_graph_def.node(i).input_size() - 1,
                          replaced_graph_def.mutable_node(i));
            } else if (replaced_graph_def.node(i).input_size() == 2) {
              if (refs[replaced_graph_def.node(i).input(1)] != 1) {
                return errors::Internal(
                    "Expected the axis tensor of a ConcatV2 node to be "
                    "referenced only once.");
              }
              refs[replaced_graph_def.node(i).input(1)] -= 1;
              removed_node_names.push_back(replaced_graph_def.node(i).input(1));
              replaced_graph_def.mutable_node(i)->mutable_input()->RemoveLast();
              replaced_graph_def.mutable_node(i)->mutable_attr()->erase("N");
              replaced_graph_def.mutable_node(i)->set_op("Identity");
            } else {
              return errors::Internal(
                  "ConcatV2 should have at least two inputs.");
            }
          }
          if ((replaced_graph_def.node(i).op() == "Assign" ||
               replaced_graph_def.node(i).op() == "Reshape" ||
               replaced_graph_def.node(i).op() == "Equal" ||
               replaced_graph_def.node(i).op() == "Mean" ||
               replaced_graph_def.node(i).op() == "ScalarSummary") &&
              replaced_graph_def.node(i).input_size() == 1) {
            removed_node_names.push_back(replaced_graph_def.node(i).name());
          }
          if (!replaced_graph_def.node(i).input_size()) {
            removed_node_names.push_back(replaced_graph_def.node(i).name());
          }
        }
        i++;
      }
    }
    current_graph_def = replaced_graph_def;
  } while (any_match_found);
  *output_graph_def = current_graph_def;
  return Status::OK();
}

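// Registered "sparsify_gather" transform: applies the sparsification above to
// both Gather and GatherV2 subgraphs whose weights come from a Const,
// Variable, or VariableV2 node via an Identity.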
Status SparsifyGather(const GraphDef& input_graph_def,
                      const TransformFuncContext& context,
                      GraphDef* output_graph_def) {
  // clang-format off
  const OpTypePattern gather_pattern =
    {"Gather",
     {
       {"Identity",
        {
          {"Const|Variable|VariableV2"}
        }
       },
       {"*"},
     }
    };
  const OpTypePattern gather_v2_pattern =
    {"GatherV2",
     {
       {"Identity",
        {
          {"Const|Variable|VariableV2"}
        }
       },
       {"*"},
       // GatherV2's axis must be constant.
       {"Const"},
     }
    };
  // clang-format on

  GraphDef cleaned_input_graph_def;
  RemoveAttributes(input_graph_def, {"_output_shapes"},
                   &cleaned_input_graph_def);

  GraphDef temp_output;

  std::unique_ptr<BundleReader> ckpt_reader;
  TF_RETURN_IF_ERROR(InitializeCheckpointReader(context, &ckpt_reader));

  std::unique_ptr<std::unordered_map<string, string> > shapes_and_slices;
  TF_RETURN_IF_ERROR(
      ObtainVariableInfo(cleaned_input_graph_def, &shapes_and_slices));

  TF_RETURN_IF_ERROR(SparsifyGatherInternal(
      cleaned_input_graph_def, shapes_and_slices, context, gather_pattern,
      ckpt_reader, &temp_output));

  TF_RETURN_IF_ERROR(SparsifyGatherInternal(temp_output, shapes_and_slices,
                                            context, gather_v2_pattern,
                                            ckpt_reader, output_graph_def));

  return Status::OK();
}

REGISTER_GRAPH_TRANSFORM("sparsify_gather", SparsifyGather);

}  // namespace graph_transforms
}  // namespace tensorflow