• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 package org.tensorflow;
17 
18 /**
19  * SavedModelBundle represents a model loaded from storage.
20  *
21  * <p>The model consists of a description of the computation (a {@link Graph}), a {@link Session}
22  * with tensors (e.g., parameters or variables in the graph) initialized to values saved in storage,
23  * and a description of the model (a serialized representation of a <a
24  * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto">MetaGraphDef
25  * protocol buffer</a>).
26  */
27 public class SavedModelBundle implements AutoCloseable {
28   /** Options for loading a SavedModel. */
29   public static final class Loader {
30     /** Load a <code>SavedModelBundle</code> with the configured options. */
load()31     public SavedModelBundle load() {
32       return SavedModelBundle.load(exportDir, tags, configProto, runOptions);
33     }
34 
35     /**
36      * Sets options to use when executing model initialization operations.
37      *
38      * @param options Serialized <a
39      *     href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions
40      *     protocol buffer</a>.
41      */
withRunOptions(byte[] options)42     public Loader withRunOptions(byte[] options) {
43       this.runOptions = options;
44       return this;
45     }
46 
47     /**
48      * Set configuration of the <code>Session</code> object created when loading the model.
49      *
50      * @param configProto Serialized <a
51      *     href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">ConfigProto
52      *     protocol buffer</a>.
53      */
withConfigProto(byte[] configProto)54     public Loader withConfigProto(byte[] configProto) {
55       this.configProto = configProto;
56       return this;
57     }
58 
59     /**
60      * Sets the set of tags that identify the specific graph in the saved model to load.
61      *
62      * @param tags the tags identifying the specific MetaGraphDef to load.
63      */
withTags(String... tags)64     public Loader withTags(String... tags) {
65       this.tags = tags;
66       return this;
67     }
68 
Loader(String exportDir)69     private Loader(String exportDir) {
70       this.exportDir = exportDir;
71     }
72 
73     private String exportDir = null;
74     private String[] tags = null;
75     private byte[] configProto = null;
76     private byte[] runOptions = null;
77   }
78 
79   /**
80    * Load a saved model from an export directory. The model that is being loaded should be created
81    * using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model
82    * API</a>.
83    *
84    * <p>This method is a shorthand for:
85    *
86    * <pre>{@code
87    * SavedModelBundle.loader().withTags(tags).load();
88    * }</pre>
89    *
90    * @param exportDir the directory path containing a saved model.
91    * @param tags the tags identifying the specific metagraphdef to load.
92    * @return a bundle containing the graph and associated session.
93    */
load(String exportDir, String... tags)94   public static SavedModelBundle load(String exportDir, String... tags) {
95     return loader(exportDir).withTags(tags).load();
96   }
97 
98   /**
99    * Load a saved model.
100    *
101    * <p/>Returns a <code>Loader</code> object that can set configuration options before actually
102    * loading the model,
103    *
104    * @param exportDir the directory path containing a saved model.
105    */
loader(String exportDir)106   public static Loader loader(String exportDir) {
107     return new Loader(exportDir);
108   }
109 
110   /**
111    * Returns the serialized <a
112    * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto">MetaGraphDef
113    * protocol buffer</a> associated with the saved model.
114    */
metaGraphDef()115   public byte[] metaGraphDef() {
116     return metaGraphDef;
117   }
118 
119   /** Returns the graph that describes the computation performed by the model. */
graph()120   public Graph graph() {
121     return graph;
122   }
123 
124   /**
125    * Returns the {@link Session} with which to perform computation using the model.
126    *
127    * @return the initialized session
128    */
session()129   public Session session() {
130     return session;
131   }
132 
133   /**
134    * Releases resources (the {@link Graph} and {@link Session}) associated with the saved model
135    * bundle.
136    */
137   @Override
close()138   public void close() {
139     session.close();
140     graph.close();
141   }
142 
143   private final Graph graph;
144   private final Session session;
145   private final byte[] metaGraphDef;
146 
SavedModelBundle(Graph graph, Session session, byte[] metaGraphDef)147   private SavedModelBundle(Graph graph, Session session, byte[] metaGraphDef) {
148     this.graph = graph;
149     this.session = session;
150     this.metaGraphDef = metaGraphDef;
151   }
152 
153   /**
154    * Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session
155    * object, plus the serialized MetaGraphDef.
156    *
157    * <p>Invoked from the native load method. Takes ownership of the handles.
158    */
fromHandle( long graphHandle, long sessionHandle, byte[] metaGraphDef)159   private static SavedModelBundle fromHandle(
160       long graphHandle, long sessionHandle, byte[] metaGraphDef) {
161     Graph graph = new Graph(graphHandle);
162     Session session = new Session(graph, sessionHandle);
163     return new SavedModelBundle(graph, session, metaGraphDef);
164   }
165 
load( String exportDir, String[] tags, byte[] config, byte[] runOptions)166   private static native SavedModelBundle load(
167       String exportDir, String[] tags, byte[] config, byte[] runOptions);
168 
169   static {
TensorFlow.init()170     TensorFlow.init();
171   }
172 }
173