• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/checkpoint_reader.h"
17 
18 #include <unordered_set>
19 #include <utility>
20 
21 #include "tensorflow/core/lib/core/status.h"
22 #include "tensorflow/core/lib/core/stringpiece.h"
23 #include "tensorflow/core/platform/env.h"
24 #include "tensorflow/core/platform/types.h"
25 #include "tensorflow/core/util/saved_tensor_slice_util.h"
26 
27 namespace tensorflow {
28 namespace checkpoint {
29 
30 class TensorSliceReader;
31 
CheckpointReader(const string & filename,TF_Status * out_status)32 CheckpointReader::CheckpointReader(const string& filename,
33                                    TF_Status* out_status)
34     : reader_(nullptr),
35       v2_reader_(nullptr),
36       var_to_shape_map_(nullptr),
37       var_to_data_type_map_(nullptr) {
38   // Depending on whether this is a V2 ckpt, initializes "reader_" or
39   // "v2_reader_".
40   std::vector<string> v2_path;
41   if (Env::Default()->GetMatchingPaths(MetaFilename(filename), &v2_path).ok() &&
42       !v2_path.empty()) {
43     v2_reader_.reset(
44         new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */));
45     if (!v2_reader_->status().ok()) {
46       Set_TF_Status_from_Status(out_status, v2_reader_->status());
47       return;
48     }
49     auto result = BuildV2VarMaps();
50     var_to_shape_map_.swap(result.first);
51     var_to_data_type_map_.swap(result.second);
52   } else {
53     reader_.reset(new TensorSliceReader(filename));
54     if (!reader_->status().ok()) {
55       Set_TF_Status_from_Status(out_status, reader_->status());
56       return;
57     }
58     var_to_shape_map_.reset(
59         new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap()));
60     var_to_data_type_map_.reset(new TensorSliceReader::VarToDataTypeMap(
61         reader_->GetVariableToDataTypeMap()));
62   }
63 }
64 
HasTensor(const string & name) const65 bool CheckpointReader::HasTensor(const string& name) const {
66   if (reader_ != nullptr) {
67     return reader_->HasTensor(name, nullptr, nullptr);
68   }
69   return v2_reader_->Contains(name);
70 }
71 
72 const TensorSliceReader::VarToShapeMap&
GetVariableToShapeMap() const73 CheckpointReader::GetVariableToShapeMap() const {
74   CHECK(var_to_shape_map_);
75   return *var_to_shape_map_;
76 }
77 
78 const TensorSliceReader::VarToDataTypeMap&
GetVariableToDataTypeMap() const79 CheckpointReader::GetVariableToDataTypeMap() const {
80   CHECK(var_to_data_type_map_);
81   return *var_to_data_type_map_;
82 }
83 
DebugString() const84 const string CheckpointReader::DebugString() const {
85   if (reader_ != nullptr) return reader_->DebugString();
86   return v2_reader_->DebugString();
87 }
88 
GetTensor(const string & name,std::unique_ptr<tensorflow::Tensor> * out_tensor,TF_Status * out_status) const89 void CheckpointReader::GetTensor(
90     const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor,
91     TF_Status* out_status) const {
92   Status status;
93   if (reader_ != nullptr) {
94     status = reader_->GetTensor(name, out_tensor);
95   } else {
96     tensorflow::DataType dtype;
97     tensorflow::TensorShape shape;
98     status = v2_reader_->LookupDtypeAndShape(name, &dtype, &shape);
99     if (status.ok()) {
100       out_tensor->reset(new Tensor(dtype, shape));
101       status = v2_reader_->Lookup(name, out_tensor->get());
102       if (!status.ok()) out_tensor->reset();
103     }
104   }
105   if (!status.ok()) {
106     Set_TF_Status_from_Status(out_status, status);
107   }
108 }
109 
110 std::pair<std::unique_ptr<TensorSliceReader::VarToShapeMap>,
111           std::unique_ptr<TensorSliceReader::VarToDataTypeMap>>
BuildV2VarMaps()112 CheckpointReader::BuildV2VarMaps() {
113   CHECK(v2_reader_ != nullptr);
114   CHECK(v2_reader_->status().ok());
115 
116   // First pass: filters out the entries of the slices.
117   std::unordered_set<string> filtered_keys;
118   BundleEntryProto entry;
119   v2_reader_->Seek(kHeaderEntryKey);
120   for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
121     CHECK(entry.ParseFromArray(v2_reader_->value().data(),
122                                v2_reader_->value().size()))
123         << entry.InitializationErrorString();
124     for (int i = 0; i < entry.slices_size(); ++i) {
125       const auto& slice_proto = entry.slices(i);
126       CHECK(filtered_keys
127                 .insert(EncodeTensorNameSlice(
128                     string(v2_reader_->key()) /* full var's name */,
129                     TensorSlice(slice_proto)))
130                 .second);
131     }
132   }
133 
134   // Second pass: adds the entries, ignoring the filtered keys.
135   std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map(
136       new TensorSliceReader::VarToShapeMap);
137   std::unique_ptr<TensorSliceReader::VarToDataTypeMap> var_to_data_type_map(
138       new TensorSliceReader::VarToDataTypeMap);
139   v2_reader_->Seek(kHeaderEntryKey);
140   for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
141     if (filtered_keys.count(string(v2_reader_->key())) > 0) continue;
142     CHECK(entry.ParseFromArray(v2_reader_->value().data(),
143                                v2_reader_->value().size()))
144         << entry.InitializationErrorString();
145     string key(v2_reader_->key());
146     (*var_to_shape_map)[key] = TensorShape(entry.shape());
147     (*var_to_data_type_map)[key] = DataType(entry.dtype());
148   }
149   // The returned pointers are owned by the caller.
150   return std::make_pair(std::move(var_to_shape_map),
151                         std::move(var_to_data_type_map));
152 }
153 
154 }  // namespace checkpoint
155 }  // namespace tensorflow
156