1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow; 17 18 /** 19 * SavedModelBundle represents a model loaded from storage. 20 * 21 * <p>The model consists of a description of the computation (a {@link Graph}), a {@link Session} 22 * with tensors (e.g., parameters or variables in the graph) initialized to values saved in storage, 23 * and a description of the model (a serialized representation of a <a 24 * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto">MetaGraphDef 25 * protocol buffer</a>). 26 */ 27 public class SavedModelBundle implements AutoCloseable { 28 /** Options for loading a SavedModel. */ 29 public static final class Loader { 30 /** Load a <code>SavedModelBundle</code> with the configured options. */ load()31 public SavedModelBundle load() { 32 return SavedModelBundle.load(exportDir, tags, configProto, runOptions); 33 } 34 35 /** 36 * Sets options to use when executing model initialization operations. 37 * 38 * @param options Serialized <a 39 * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions 40 * protocol buffer</a>. 41 */ withRunOptions(byte[] options)42 public Loader withRunOptions(byte[] options) { 43 this.runOptions = options; 44 return this; 45 } 46 47 /** 48 * Set configuration of the <code>Session</code> object created when loading the model. 49 * 50 * @param configProto Serialized <a 51 * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">ConfigProto 52 * protocol buffer</a>. 53 */ withConfigProto(byte[] configProto)54 public Loader withConfigProto(byte[] configProto) { 55 this.configProto = configProto; 56 return this; 57 } 58 59 /** 60 * Sets the set of tags that identify the specific graph in the saved model to load. 61 * 62 * @param tags the tags identifying the specific MetaGraphDef to load. 63 */ withTags(String... tags)64 public Loader withTags(String... tags) { 65 this.tags = tags; 66 return this; 67 } 68 Loader(String exportDir)69 private Loader(String exportDir) { 70 this.exportDir = exportDir; 71 } 72 73 private String exportDir = null; 74 private String[] tags = null; 75 private byte[] configProto = null; 76 private byte[] runOptions = null; 77 } 78 79 /** 80 * Load a saved model from an export directory. The model that is being loaded should be created 81 * using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model 82 * API</a>. 83 * 84 * <p>This method is a shorthand for: 85 * 86 * <pre>{@code 87 * SavedModelBundle.loader().withTags(tags).load(); 88 * }</pre> 89 * 90 * @param exportDir the directory path containing a saved model. 91 * @param tags the tags identifying the specific metagraphdef to load. 92 * @return a bundle containing the graph and associated session. 93 */ load(String exportDir, String... tags)94 public static SavedModelBundle load(String exportDir, String... tags) { 95 return loader(exportDir).withTags(tags).load(); 96 } 97 98 /** 99 * Load a saved model. 100 * 101 * <p/>Returns a <code>Loader</code> object that can set configuration options before actually 102 * loading the model, 103 * 104 * @param exportDir the directory path containing a saved model. 105 */ loader(String exportDir)106 public static Loader loader(String exportDir) { 107 return new Loader(exportDir); 108 } 109 110 /** 111 * Returns the serialized <a 112 * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto">MetaGraphDef 113 * protocol buffer</a> associated with the saved model. 114 */ metaGraphDef()115 public byte[] metaGraphDef() { 116 return metaGraphDef; 117 } 118 119 /** Returns the graph that describes the computation performed by the model. */ graph()120 public Graph graph() { 121 return graph; 122 } 123 124 /** 125 * Returns the {@link Session} with which to perform computation using the model. 126 * 127 * @return the initialized session 128 */ session()129 public Session session() { 130 return session; 131 } 132 133 /** 134 * Releases resources (the {@link Graph} and {@link Session}) associated with the saved model 135 * bundle. 136 */ 137 @Override close()138 public void close() { 139 session.close(); 140 graph.close(); 141 } 142 143 private final Graph graph; 144 private final Session session; 145 private final byte[] metaGraphDef; 146 SavedModelBundle(Graph graph, Session session, byte[] metaGraphDef)147 private SavedModelBundle(Graph graph, Session session, byte[] metaGraphDef) { 148 this.graph = graph; 149 this.session = session; 150 this.metaGraphDef = metaGraphDef; 151 } 152 153 /** 154 * Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session 155 * object, plus the serialized MetaGraphDef. 156 * 157 * <p>Invoked from the native load method. Takes ownership of the handles. 158 */ fromHandle( long graphHandle, long sessionHandle, byte[] metaGraphDef)159 private static SavedModelBundle fromHandle( 160 long graphHandle, long sessionHandle, byte[] metaGraphDef) { 161 Graph graph = new Graph(graphHandle); 162 Session session = new Session(graph, sessionHandle); 163 return new SavedModelBundle(graph, session, metaGraphDef); 164 } 165 load( String exportDir, String[] tags, byte[] config, byte[] runOptions)166 private static native SavedModelBundle load( 167 String exportDir, String[] tags, byte[] config, byte[] runOptions); 168 169 static { TensorFlow.init()170 TensorFlow.init(); 171 } 172 } 173