1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/checkpoint_reader.h"
17
18 #include <unordered_set>
19 #include <utility>
20
21 #include "tensorflow/core/lib/core/status.h"
22 #include "tensorflow/core/lib/core/stringpiece.h"
23 #include "tensorflow/core/platform/env.h"
24 #include "tensorflow/core/platform/types.h"
25 #include "tensorflow/core/util/saved_tensor_slice_util.h"
26
27 namespace tensorflow {
28 namespace checkpoint {
29
30 class TensorSliceReader;
31
CheckpointReader(const string & filename,TF_Status * out_status)32 CheckpointReader::CheckpointReader(const string& filename,
33 TF_Status* out_status)
34 : reader_(nullptr),
35 v2_reader_(nullptr),
36 var_to_shape_map_(nullptr),
37 var_to_data_type_map_(nullptr) {
38 // Depending on whether this is a V2 ckpt, initializes "reader_" or
39 // "v2_reader_".
40 std::vector<string> v2_path;
41 if (Env::Default()->GetMatchingPaths(MetaFilename(filename), &v2_path).ok() &&
42 !v2_path.empty()) {
43 v2_reader_.reset(
44 new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */));
45 if (!v2_reader_->status().ok()) {
46 Set_TF_Status_from_Status(out_status, v2_reader_->status());
47 return;
48 }
49 auto result = BuildV2VarMaps();
50 var_to_shape_map_.swap(result.first);
51 var_to_data_type_map_.swap(result.second);
52 } else {
53 reader_.reset(new TensorSliceReader(filename));
54 if (!reader_->status().ok()) {
55 Set_TF_Status_from_Status(out_status, reader_->status());
56 return;
57 }
58 var_to_shape_map_.reset(
59 new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap()));
60 var_to_data_type_map_.reset(new TensorSliceReader::VarToDataTypeMap(
61 reader_->GetVariableToDataTypeMap()));
62 }
63 }
64
HasTensor(const string & name) const65 bool CheckpointReader::HasTensor(const string& name) const {
66 if (reader_ != nullptr) {
67 return reader_->HasTensor(name, nullptr, nullptr);
68 }
69 return v2_reader_->Contains(name);
70 }
71
72 const TensorSliceReader::VarToShapeMap&
GetVariableToShapeMap() const73 CheckpointReader::GetVariableToShapeMap() const {
74 CHECK(var_to_shape_map_);
75 return *var_to_shape_map_;
76 }
77
78 const TensorSliceReader::VarToDataTypeMap&
GetVariableToDataTypeMap() const79 CheckpointReader::GetVariableToDataTypeMap() const {
80 CHECK(var_to_data_type_map_);
81 return *var_to_data_type_map_;
82 }
83
DebugString() const84 const string CheckpointReader::DebugString() const {
85 if (reader_ != nullptr) return reader_->DebugString();
86 return v2_reader_->DebugString();
87 }
88
GetTensor(const string & name,std::unique_ptr<tensorflow::Tensor> * out_tensor,TF_Status * out_status) const89 void CheckpointReader::GetTensor(
90 const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor,
91 TF_Status* out_status) const {
92 Status status;
93 if (reader_ != nullptr) {
94 status = reader_->GetTensor(name, out_tensor);
95 } else {
96 tensorflow::DataType dtype;
97 tensorflow::TensorShape shape;
98 status = v2_reader_->LookupDtypeAndShape(name, &dtype, &shape);
99 if (status.ok()) {
100 out_tensor->reset(new Tensor(dtype, shape));
101 status = v2_reader_->Lookup(name, out_tensor->get());
102 if (!status.ok()) out_tensor->reset();
103 }
104 }
105 if (!status.ok()) {
106 Set_TF_Status_from_Status(out_status, status);
107 }
108 }
109
110 std::pair<std::unique_ptr<TensorSliceReader::VarToShapeMap>,
111 std::unique_ptr<TensorSliceReader::VarToDataTypeMap>>
BuildV2VarMaps()112 CheckpointReader::BuildV2VarMaps() {
113 CHECK(v2_reader_ != nullptr);
114 CHECK(v2_reader_->status().ok());
115
116 // First pass: filters out the entries of the slices.
117 std::unordered_set<string> filtered_keys;
118 BundleEntryProto entry;
119 v2_reader_->Seek(kHeaderEntryKey);
120 for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
121 CHECK(entry.ParseFromArray(v2_reader_->value().data(),
122 v2_reader_->value().size()))
123 << entry.InitializationErrorString();
124 for (int i = 0; i < entry.slices_size(); ++i) {
125 const auto& slice_proto = entry.slices(i);
126 CHECK(filtered_keys
127 .insert(EncodeTensorNameSlice(
128 string(v2_reader_->key()) /* full var's name */,
129 TensorSlice(slice_proto)))
130 .second);
131 }
132 }
133
134 // Second pass: adds the entries, ignoring the filtered keys.
135 std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map(
136 new TensorSliceReader::VarToShapeMap);
137 std::unique_ptr<TensorSliceReader::VarToDataTypeMap> var_to_data_type_map(
138 new TensorSliceReader::VarToDataTypeMap);
139 v2_reader_->Seek(kHeaderEntryKey);
140 for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
141 if (filtered_keys.count(string(v2_reader_->key())) > 0) continue;
142 CHECK(entry.ParseFromArray(v2_reader_->value().data(),
143 v2_reader_->value().size()))
144 << entry.InitializationErrorString();
145 string key(v2_reader_->key());
146 (*var_to_shape_map)[key] = TensorShape(entry.shape());
147 (*var_to_data_type_map)[key] = DataType(entry.dtype());
148 }
149 // The returned pointers are owned by the caller.
150 return std::make_pair(std::move(var_to_shape_map),
151 std::move(var_to_data_type_map));
152 }
153
154 } // namespace checkpoint
155 } // namespace tensorflow
156