• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBEXPORT_LOAD_H_
16 #define TENSORFLOW_CC_EXPERIMENTAL_LIBEXPORT_LOAD_H_
17 
#include <string>
#include <vector>

#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
24 
25 namespace tensorflow {
26 namespace libexport {
27 
28 // A low-level representation of a SavedModel.
29 //
30 // This class should only ever be a thin wrapper around disk (or other storage)
31 // access for a SavedModel.  Higher level functionality should be layered on top
32 // by other functions and classes.
33 //
34 // In the future, this class can also provide a mechanism for automatic version
35 // migration.  This will allow the calling code to always work against the most
36 // recent version of SavedModel.
37 class TFPackage {
38  public:
39   // Load a SavedModel, parsing the associated protobuf for later access.
40   static tensorflow::StatusOr<TFPackage> Load(const std::string& path);
41 
42   // Reads and returns a list of variable checkpoint keys found in the
43   // SavedModel.
44   //
45   // RestoreV2 is the operation that will ultimately be responsible for reading
46   // and restoring the variable(s)' values.  Variable values are indexed in the
47   // checkpoint files by "checkpoint keys".  These keys along with dtype and
48   // shape / slice information allow RestoreV2 to look up a variable's value in
49   // the SavedModel and restore it into a tensor.
50   //
51   // In an ideal world, we wouldn't need this extra layer of indirection; this
52   // class would be responsible for reading the values and providing them to the
53   // caller for registration in the runtime.  We should explore whether that is
54   // feasible and migrate to it if possible.
55   //
56   // Regardless of what we decide to do, we should eventually split this out
57   // into its own checkpoint abstraction.
58   struct CheckpointKey {
59     std::string key;
60     DataType dtype;
61     // Use an empty string for a non-partitioned variable.
62     //
63     // TODO(danielellis): Create a better description around what valid values
64     // look like for this.
65     std::string shape_and_slice;
66   };
67   tensorflow::StatusOr<std::vector<CheckpointKey>> GetVariableCheckpointKeys();
68 
69   // Retrieves the object graph from the SavedModel.
70   //
71   // For now, we're returning the object graph directly (i.e. the parsed proto)
72   // rather than adding abstraction on top.  We may later find we would like an
73   // intermediate abstraction layer to make traversal easier, but for now the
74   // extra complexity doesn't seem justified.  Regardless of what we choose,
75   // that logic should live outside this class; this class should continue to
76   // have the clearly-defined, singular responsibility of reading and parsing
77   // the low-level, serialized format.
78   const SavedObjectGraph& GetObjectGraph();
79 
80   // Returns a list of function defs in the SavedModel.
81   const protobuf::RepeatedPtrField<FunctionDef>& GetFunctionDefs();
82 
83  private:
84   SavedModel saved_model_proto_;
85 };
86 
87 }  // namespace libexport
88 }  // namespace tensorflow
89 
90 #endif  // TENSORFLOW_CC_EXPERIMENTAL_LIBEXPORT_LOAD_H_
91