1/* 2Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package tensorflow 18 19import corepb "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto" 20 21// #include "tensorflow/c/c_api.h" 22import "C" 23 24// A Signature defines the signature of a computation supported by a TensorFlow 25// graph. 26// 27// For example, a model with two loss computations, sharing a single input, 28// might have the following signature_def map. 29// 30// Note that across the two Signatures "loss_A" and "loss_B", the input key, 31// output key, and method_name are identical, and will be used by system(s) that 32// implement or rely upon this particular loss method. The output tensor names 33// differ, demonstrating how different outputs can exist for the same method. 34// 35// signature_def { 36// key: "loss_A" 37// value { 38// inputs { 39// key: "input" 40// value { 41// name: "input:0" 42// dtype: DT_STRING 43// tensor_shape: ... 44// } 45// } 46// outputs { 47// key: "loss_output" 48// value { 49// name: "loss_output_A:0" 50// dtype: DT_FLOAT 51// tensor_shape: ... 52// } 53// } 54// } 55// ... 56// method_name: "some/package/compute_loss" 57// } 58// signature_def { 59// key: "loss_B" 60// value { 61// inputs { 62// key: "input" 63// value { 64// name: "input:0" 65// dtype: DT_STRING 66// tensor_shape: ... 67// } 68// } 69// outputs { 70// key: "loss_output" 71// value { 72// name: "loss_output_B:0" 73// dtype: DT_FLOAT 74// tensor_shape: ... 75// } 76// } 77// } 78// ... 79// method_name: "some/package/compute_loss" 80// } 81type Signature struct { 82 Inputs, Outputs map[string]TensorInfo 83 MethodName string 84} 85 86// A TensorInfo contains the information about a Tensor necessary for feeding or retrieval. 87type TensorInfo struct { 88 Name string 89 DType DataType 90 Shape Shape 91} 92 93func signatureDefFromProto(pb *corepb.SignatureDef) Signature { 94 inputs := make(map[string]TensorInfo) 95 for name, input := range pb.GetInputs() { 96 inputs[name] = tensorInfoFromProto(input) 97 } 98 outputs := make(map[string]TensorInfo) 99 for name, output := range pb.GetOutputs() { 100 outputs[name] = tensorInfoFromProto(output) 101 } 102 return Signature{ 103 Inputs: inputs, 104 Outputs: outputs, 105 MethodName: pb.GetMethodName(), 106 } 107} 108 109func tensorInfoFromProto(pb *corepb.TensorInfo) TensorInfo { 110 var dims []int64 111 for _, d := range pb.GetTensorShape().GetDim() { 112 dims = append(dims, d.GetSize()) 113 } 114 return TensorInfo{ 115 Name: pb.GetName(), 116 DType: DataType(C.TF_DataType(pb.GetDtype())), 117 Shape: MakeShape(dims...), 118 } 119} 120