1/* 2Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package tensorflow 18 19// #include <stdlib.h> 20// #include "tensorflow/c/c_api.h" 21import "C" 22 23import "unsafe" 24 25// Operation that has been added to the graph. 26type Operation struct { 27 c *C.TF_Operation 28 // A reference to the Graph to prevent it from 29 // being GCed while the Operation is still alive. 30 g *Graph 31} 32 33// Name returns the name of the operation. 34func (op *Operation) Name() string { 35 return C.GoString(C.TF_OperationName(op.c)) 36} 37 38// Type returns the name of the operator used by this operation. 39func (op *Operation) Type() string { 40 return C.GoString(C.TF_OperationOpType(op.c)) 41} 42 43// NumOutputs returns the number of outputs of op. 44func (op *Operation) NumOutputs() int { 45 return int(C.TF_OperationNumOutputs(op.c)) 46} 47 48// Device returns a specification of the device on which this operation 49// will be executed, or the empty string if there is no such specification. 50func (op *Operation) Device() string { 51 return C.GoString(C.TF_OperationDevice(op.c)) 52} 53 54// OutputListSize returns the size of the list of Outputs that is produced by a 55// named output of op. 56// 57// An Operation has multiple named outputs, each of which produces either 58// a single tensor or a list of tensors. This method returns the size of 59// the list of tensors for a specific output of the operation, identified 60// by its name. 61func (op *Operation) OutputListSize(output string) (int, error) { 62 cname := C.CString(output) 63 defer C.free(unsafe.Pointer(cname)) 64 status := newStatus() 65 n := C.TF_OperationOutputListLength(op.c, cname, status.c) 66 return int(n), status.Err() 67} 68 69// Output returns the i-th output of op. 70func (op *Operation) Output(i int) Output { 71 return Output{op, i} 72} 73 74// NumInputs returns the number of inputs of op. 75func (op *Operation) NumInputs() int { 76 return int(C.TF_OperationNumInputs(op.c)) 77} 78 79// Output represents one of the outputs of an operation in the graph. Has a 80// DataType (and eventually a Shape). May be passed as an input argument to a 81// function for adding operations to a graph, or to a Session's Run() method to 82// fetch that output as a tensor. 83type Output struct { 84 // Op is the Operation that produces this Output. 85 Op *Operation 86 87 // Index specifies the index of the output within the Operation. 88 Index int 89} 90 91// DataType returns the type of elements in the tensor produced by p. 92func (p Output) DataType() DataType { 93 return DataType(C.TF_OperationOutputType(p.c())) 94} 95 96// Shape returns the (possibly incomplete) shape of the tensor produced p. 97func (p Output) Shape() Shape { 98 status := newStatus() 99 port := p.c() 100 ndims := C.TF_GraphGetTensorNumDims(p.Op.g.c, port, status.c) 101 if err := status.Err(); err != nil { 102 // This should not be possible since an error only occurs if 103 // the operation does not belong to the graph. It should not 104 // be possible to construct such an Operation object. 105 return Shape{} 106 } 107 if ndims < 0 { 108 return Shape{} 109 } 110 if ndims == 0 { 111 return ScalarShape() 112 } 113 dims := make([]C.int64_t, ndims) 114 C.TF_GraphGetTensorShape(p.Op.g.c, port, &dims[0], ndims, status.c) 115 if err := status.Err(); err != nil { 116 // Same as above, should not be possible. 117 return Shape{} 118 } 119 ret := Shape{dims: make([]int64, ndims)} 120 for i := 0; i < int(ndims); i++ { 121 ret.dims[i] = int64(dims[i]) 122 } 123 return ret 124} 125 126func (p Output) c() C.TF_Output { 127 if p.Op == nil { 128 // Attempt to provide a more useful panic message than "nil 129 // pointer dereference". 130 panic("nil-Operation. If the Output was created with a Scope object, see Scope.Err() for details.") 131 } 132 return C.TF_Output{oper: p.Op.c, index: C.int(p.Index)} 133} 134 135func (p Output) canBeAnInput() {} 136 137// Consumers returns the inputs that consume this output. 138func (p Output) Consumers() []Consumer { 139 max := int(C.TF_OperationOutputNumConsumers(p.c())) 140 if max == 0 { 141 return nil 142 } 143 inputs := make([]C.TF_Input, max) 144 n := C.TF_OperationOutputConsumers(p.c(), (*C.TF_Input)(unsafe.Pointer(&inputs[0])), C.int(max)) 145 inputs = inputs[:int(n)] 146 147 var consumers []Consumer 148 for _, consumer := range inputs { 149 consumers = append(consumers, Consumer{ 150 Index: int(consumer.index), 151 Op: &Operation{ 152 c: consumer.oper, 153 g: p.Op.g, 154 }, 155 }) 156 } 157 158 return consumers 159} 160 161// Consumer identifies a specific input of an operation that consumes the output 162// of another operation. 163type Consumer struct { 164 // Op is the Operation that is consuming the output of another operation. 165 Op *Operation 166 167 // Index is the index of the input within Op that the output of another 168 // operation is connected to. 169 Index int 170} 171 172func (p Consumer) c() C.TF_Input { 173 if p.Op == nil { 174 // Attempt to provide a more useful panic message than "nil 175 // pointer dereference". 176 panic("nil-Operation. Consumer objects should only be created by a call to Output.Consumers") 177 } 178 return C.TF_Input{oper: p.Op.c, index: C.int(p.Index)} 179} 180 181// DataType returns the type of the input. 182func (p Consumer) DataType() DataType { 183 return DataType(C.TF_OperationInputType(p.c())) 184} 185 186// Producer returns the Output that is connected to this Consumer. 187func (p Consumer) Producer() Output { 188 output := C.TF_OperationInput(p.c()) 189 return Output{ 190 Op: &Operation{ 191 c: output.oper, 192 g: p.Op.g, 193 }, 194 Index: int(output.index), 195 } 196} 197 198// Input is the interface for specifying inputs to an operation being added to 199// a Graph. 200// 201// Operations can have multiple inputs, each of which could be either a tensor 202// produced by another operation (an Output object), or a list of tensors 203// produced by other operations (an OutputList). Thus, this interface is 204// implemented by both Output and OutputList. 205// 206// See OpSpec.Input for more information. 207type Input interface { 208 // Unexported to preclude implementations outside this package. 209 canBeAnInput() 210} 211 212// OutputList represents a list of Outputs that can be provided as input to 213// another operation. 214type OutputList []Output 215 216func (l OutputList) canBeAnInput() {} 217