• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
16
17package tensorflow
18
19import corepb "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"
20
21// #include "tensorflow/c/c_api.h"
22import "C"
23
24// A Signature defines the signature of a computation supported by a TensorFlow
25// graph.
26//
27// For example, a model with two loss computations, sharing a single input,
28// might have the following signature_def map.
29//
30// Note that across the two Signatures "loss_A" and "loss_B", the input key,
31// output key, and method_name are identical, and will be used by system(s) that
32// implement or rely upon this particular loss method. The output tensor names
33// differ, demonstrating how different outputs can exist for the same method.
34//
35// signature_def {
36//   key: "loss_A"
37//   value {
38//     inputs {
39//       key: "input"
40//       value {
41//         name: "input:0"
42//         dtype: DT_STRING
43//         tensor_shape: ...
44//       }
45//     }
46//     outputs {
47//       key: "loss_output"
48//       value {
49//         name: "loss_output_A:0"
50//         dtype: DT_FLOAT
51//         tensor_shape: ...
52//       }
53//     }
54//   }
55//   ...
56//   method_name: "some/package/compute_loss"
57// }
58// signature_def {
59//   key: "loss_B"
60//   value {
61//     inputs {
62//       key: "input"
63//       value {
64//         name: "input:0"
65//         dtype: DT_STRING
66//         tensor_shape: ...
67//       }
68//     }
69//     outputs {
70//       key: "loss_output"
71//       value {
72//         name: "loss_output_B:0"
73//         dtype: DT_FLOAT
74//         tensor_shape: ...
75//       }
76//     }
77//   }
78//   ...
79//   method_name: "some/package/compute_loss"
80// }
81type Signature struct {
82	Inputs, Outputs map[string]TensorInfo
83	MethodName      string
84}
85
86// A TensorInfo contains the information about a Tensor necessary for feeding or retrieval.
87type TensorInfo struct {
88	Name  string
89	DType DataType
90	Shape Shape
91}
92
93func signatureDefFromProto(pb *corepb.SignatureDef) Signature {
94	inputs := make(map[string]TensorInfo)
95	for name, input := range pb.GetInputs() {
96		inputs[name] = tensorInfoFromProto(input)
97	}
98	outputs := make(map[string]TensorInfo)
99	for name, output := range pb.GetOutputs() {
100		outputs[name] = tensorInfoFromProto(output)
101	}
102	return Signature{
103		Inputs:     inputs,
104		Outputs:    outputs,
105		MethodName: pb.GetMethodName(),
106	}
107}
108
109func tensorInfoFromProto(pb *corepb.TensorInfo) TensorInfo {
110	var dims []int64
111	for _, d := range pb.GetTensorShape().GetDim() {
112		dims = append(dims, d.GetSize())
113	}
114	return TensorInfo{
115		Name:  pb.GetName(),
116		DType: DataType(C.TF_DataType(pb.GetDtype())),
117		Shape: MakeShape(dims...),
118	}
119}
120