1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/checkpoint_reader.h"
17
18 #include <unordered_set>
19 #include <utility>
20
21 #include "tensorflow/core/platform/env.h"
22 #include "tensorflow/core/platform/status.h"
23 #include "tensorflow/core/platform/stringpiece.h"
24 #include "tensorflow/core/platform/types.h"
25 #include "tensorflow/core/util/saved_tensor_slice_util.h"
26
27 namespace tensorflow {
28 namespace checkpoint {
29
30 class TensorSliceReader;
31
CheckpointReader(const string & filename,TF_Status * status)32 CheckpointReader::CheckpointReader(const string& filename, TF_Status* status)
33 : reader_(nullptr),
34 v2_reader_(nullptr),
35 var_to_shape_map_(nullptr),
36 var_to_data_type_map_(nullptr) {
37 // Depending on whether this is a V2 ckpt, initializes "reader_" or
38 // "v2_reader_".
39 std::vector<string> v2_path;
40 if (Env::Default()->GetMatchingPaths(MetaFilename(filename), &v2_path).ok() &&
41 !v2_path.empty()) {
42 v2_reader_.reset(
43 new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */));
44 if (!v2_reader_->status().ok()) {
45 Set_TF_Status_from_Status(status, v2_reader_->status());
46 return;
47 }
48 auto result = BuildV2VarMaps();
49 var_to_shape_map_.swap(result.first);
50 var_to_data_type_map_.swap(result.second);
51 } else {
52 reader_.reset(new TensorSliceReader(filename));
53 if (!reader_->status().ok()) {
54 Set_TF_Status_from_Status(status, reader_->status());
55 return;
56 }
57 var_to_shape_map_.reset(
58 new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap()));
59 var_to_data_type_map_.reset(new TensorSliceReader::VarToDataTypeMap(
60 reader_->GetVariableToDataTypeMap()));
61 }
62 }
63
HasTensor(const string & name) const64 bool CheckpointReader::HasTensor(const string& name) const {
65 if (reader_ != nullptr) {
66 return reader_->HasTensor(name, nullptr, nullptr);
67 }
68 return v2_reader_->Contains(name);
69 }
70
71 const TensorSliceReader::VarToShapeMap&
GetVariableToShapeMap() const72 CheckpointReader::GetVariableToShapeMap() const {
73 CHECK(var_to_shape_map_);
74 return *var_to_shape_map_;
75 }
76
77 const TensorSliceReader::VarToDataTypeMap&
GetVariableToDataTypeMap() const78 CheckpointReader::GetVariableToDataTypeMap() const {
79 CHECK(var_to_data_type_map_);
80 return *var_to_data_type_map_;
81 }
82
DebugString() const83 const string CheckpointReader::DebugString() const {
84 if (reader_ != nullptr) return reader_->DebugString();
85 return v2_reader_->DebugString();
86 }
87
GetTensor(const string & name,std::unique_ptr<tensorflow::Tensor> * out_tensor,TF_Status * out_status) const88 void CheckpointReader::GetTensor(
89 const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor,
90 TF_Status* out_status) const {
91 Status status;
92 if (reader_ != nullptr) {
93 status = reader_->GetTensor(name, out_tensor);
94 } else {
95 tensorflow::DataType dtype;
96 tensorflow::TensorShape shape;
97 status = v2_reader_->LookupDtypeAndShape(name, &dtype, &shape);
98 if (status.ok()) {
99 out_tensor->reset(new Tensor(dtype, shape));
100 status = v2_reader_->Lookup(name, out_tensor->get());
101 if (!status.ok()) out_tensor->reset();
102 }
103 }
104 if (!status.ok()) {
105 Set_TF_Status_from_Status(out_status, status);
106 }
107 }
108
109 std::pair<std::unique_ptr<TensorSliceReader::VarToShapeMap>,
110 std::unique_ptr<TensorSliceReader::VarToDataTypeMap>>
BuildV2VarMaps()111 CheckpointReader::BuildV2VarMaps() {
112 CHECK(v2_reader_ != nullptr);
113 CHECK(v2_reader_->status().ok());
114
115 // First pass: filters out the entries of the slices.
116 std::unordered_set<string> filtered_keys;
117 BundleEntryProto entry;
118 v2_reader_->Seek(kHeaderEntryKey);
119 for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
120 CHECK(entry.ParseFromArray(v2_reader_->value().data(),
121 v2_reader_->value().size()))
122 << entry.InitializationErrorString();
123 for (int i = 0; i < entry.slices_size(); ++i) {
124 const auto& slice_proto = entry.slices(i);
125 CHECK(filtered_keys
126 .insert(EncodeTensorNameSlice(
127 string(v2_reader_->key()) /* full var's name */,
128 TensorSlice(slice_proto)))
129 .second);
130 }
131 }
132
133 // Second pass: adds the entries, ignoring the filtered keys.
134 std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map(
135 new TensorSliceReader::VarToShapeMap);
136 std::unique_ptr<TensorSliceReader::VarToDataTypeMap> var_to_data_type_map(
137 new TensorSliceReader::VarToDataTypeMap);
138 v2_reader_->Seek(kHeaderEntryKey);
139 for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
140 if (filtered_keys.count(string(v2_reader_->key())) > 0) continue;
141 CHECK(entry.ParseFromArray(v2_reader_->value().data(),
142 v2_reader_->value().size()))
143 << entry.InitializationErrorString();
144 string key(v2_reader_->key());
145 (*var_to_shape_map)[key] = TensorShape(entry.shape());
146 (*var_to_data_type_map)[key] = DataType(entry.dtype());
147 }
148 // The returned pointers are owned by the caller.
149 return std::make_pair(std::move(var_to_shape_map),
150 std::move(var_to_data_type_map));
151 }
152
153 } // namespace checkpoint
154 } // namespace tensorflow
155