• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Helpers for working with TensorFlow exports and their signatures.
17 
18 #ifndef TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
19 #define TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
20 
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "tensorflow/contrib/session_bundle/manifest.pb.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/platform/types.h"
29 #include "tensorflow/core/protobuf/meta_graph.pb.h"
30 #include "tensorflow/core/protobuf/saver.pb.h"
31 #include "tensorflow/core/public/session.h"
32 
33 namespace tensorflow {
34 namespace serving {
35 
36 const char kSignaturesKey[] = "serving_signatures";
37 
38 // Get Signatures from a MetaGraphDef.
39 Status GetSignatures(const tensorflow::MetaGraphDef& meta_graph_def,
40                      Signatures* signatures);
41 
42 // (Re)set Signatures in a MetaGraphDef.
43 Status SetSignatures(const Signatures& signatures,
44                      tensorflow::MetaGraphDef* meta_graph_def);
45 
46 // Gets a ClassificationSignature from a MetaGraphDef's default signature.
47 // Returns an error if the default signature is not a ClassificationSignature,
48 // or does not exist.
49 Status GetClassificationSignature(
50     const tensorflow::MetaGraphDef& meta_graph_def,
51     ClassificationSignature* signature);
52 
53 // Gets a named ClassificationSignature from a MetaGraphDef.
54 // Returns an error if a ClassificationSignature with the given name does
55 // not exist.
56 Status GetNamedClassificationSignature(
57     const string& name, const tensorflow::MetaGraphDef& meta_graph_def,
58     ClassificationSignature* signature);
59 
60 // Gets a RegressionSignature from a MetaGraphDef's default signature.
61 // Returns an error if the default signature is not a RegressionSignature,
62 // or does not exist.
63 Status GetRegressionSignature(const tensorflow::MetaGraphDef& meta_graph_def,
64                               RegressionSignature* signature);
65 
66 // Runs a classification using the provided signature and initialized Session.
67 //   input: input batch of items to classify
68 //   classes: output batch of classes; may be null if not needed
69 //   scores: output batch of scores; may be null if not needed
70 // Validates sizes of the inputs and outputs are consistent (e.g., input
71 // batch size equals output batch sizes).
72 // Does not do any type validation.
73 Status RunClassification(const ClassificationSignature& signature,
74                          const Tensor& input, Session* session, Tensor* classes,
75                          Tensor* scores);
76 
77 // Runs regression using the provided signature and initialized Session.
78 //   input: input batch of items to run the regression model against
79 //   output: output targets
80 // Validates sizes of the inputs and outputs are consistent (e.g., input
81 // batch size equals output batch sizes).
82 // Does not do any type validation.
83 Status RunRegression(const RegressionSignature& signature, const Tensor& input,
84                      Session* session, Tensor* output);
85 
86 // Gets the named GenericSignature from a MetaGraphDef.
87 // Returns an error if a GenericSignature with the given name does not exist.
88 Status GetGenericSignature(const string& name,
89                            const tensorflow::MetaGraphDef& meta_graph_def,
90                            GenericSignature* signature);
91 
92 // Gets the default signature from a MetaGraphDef.
93 Status GetDefaultSignature(const tensorflow::MetaGraphDef& meta_graph_def,
94                            Signature* default_signature);
95 
96 // Gets a named Signature from a MetaGraphDef.
97 // Returns an error if a Signature with the given name does not exist.
98 Status GetNamedSignature(const string& name,
99                          const tensorflow::MetaGraphDef& meta_graph_def,
100                          Signature* default_signature);
101 
102 // Binds TensorFlow inputs specified by the caller using the logical names
103 // specified at Graph export time, to the actual Graph names.
104 // Returns an error if any of the inputs do not have a binding in the export's
105 // MetaGraphDef.
106 Status BindGenericInputs(const GenericSignature& signature,
107                          const std::vector<std::pair<string, Tensor>>& inputs,
108                          std::vector<std::pair<string, Tensor>>* bound_inputs);
109 
110 // Binds the input names specified by the caller using the logical names
111 // specified at Graph export time, to the actual Graph names. This is useful
112 // for binding names of both the TensorFlow output tensors and target nodes,
113 // with the latter (target nodes) being optional and rarely used (if ever) at
114 // serving time.
115 // Returns an error if any of the input names do not have a binding in the
116 // export's MetaGraphDef.
117 Status BindGenericNames(const GenericSignature& signature,
118                         const std::vector<string>& input_names,
119                         std::vector<string>* bound_names);
120 
121 }  // namespace serving
122 }  // namespace tensorflow
123 
124 #endif  // TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
125