1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBEXPORT_LOAD_H_ 16 #define TENSORFLOW_CC_EXPERIMENTAL_LIBEXPORT_LOAD_H_ 17 18 #include "tensorflow/core/framework/function.pb.h" 19 #include "tensorflow/core/framework/types.pb.h" 20 #include "tensorflow/core/platform/protobuf.h" 21 #include "tensorflow/core/platform/statusor.h" 22 #include "tensorflow/core/protobuf/saved_model.pb.h" 23 #include "tensorflow/core/protobuf/saved_object_graph.pb.h" 24 25 namespace tensorflow { 26 namespace libexport { 27 28 // A low-level representation of a SavedModel. 29 // 30 // This class should only ever be a thin wrapper around disk (or other storage) 31 // access for a SavedModel. Higher level functionality should be layered on top 32 // by other functions and classes. 33 // 34 // In the future, this class can also provide a mechanism for automatic version 35 // migration. This will allow the calling code to always work against the most 36 // recent version of SavedModel. 37 class TFPackage { 38 public: 39 // Load a SavedModel, parsing the associated protobuf for later access. 40 static tensorflow::StatusOr<TFPackage> Load(const std::string& path); 41 42 // Reads and returns a list of variable checkpoint keys found in the 43 // SavedModel. 44 // 45 // RestoreV2 is the operation that will ultimately be responsible for reading 46 // and restoring the variable(s)' values. Variable values are indexed in the 47 // checkpoint files by "checkpoint keys". These keys along with dtype and 48 // shape / slice information allow RestoreV2 to look up a variable's value in 49 // the SavedModel and restore it into a tensor. 50 // 51 // In an ideal world, we wouldn't need this extra layer of indirection; this 52 // class would be responsible for reading the values and providing them to the 53 // caller for registration in the runtime. We should explore whether that is 54 // feasible and migrate to it if possible. 55 // 56 // Regardless of what we decide to do, we should eventually split this out 57 // into its own checkpoint abstraction. 58 struct CheckpointKey { 59 std::string key; 60 DataType dtype; 61 // Use an empty string for a non-partitioned variable. 62 // 63 // TODO(danielellis): Create a better description around what valid values 64 // look like for this. 65 std::string shape_and_slice; 66 }; 67 tensorflow::StatusOr<std::vector<CheckpointKey>> GetVariableCheckpointKeys(); 68 69 // Retrieves the object graph from the SavedModel. 70 // 71 // For now, we're returning the object graph directly (i.e. the parsed proto) 72 // rather than adding abstraction on top. We may later find we would like an 73 // intermediate abstraction layer to make traversal easier, but for now the 74 // extra complexity doesn't seem justified. Regardless of what we choose, 75 // that logic should live outside this class; this class should continue to 76 // have the clearly-defined, singular responsibility of reading and parsing 77 // the low-level, serialized format. 78 const SavedObjectGraph& GetObjectGraph(); 79 80 // Returns a list of function defs in the SavedModel. 81 const protobuf::RepeatedPtrField<FunctionDef>& GetFunctionDefs(); 82 83 private: 84 SavedModel saved_model_proto_; 85 }; 86 87 } // namespace libexport 88 } // namespace tensorflow 89 90 #endif // TENSORFLOW_CC_EXPERIMENTAL_LIBEXPORT_LOAD_H_ 91