• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 package org.tensorflow;
17 
18 /**
19  * A Graph node that performs computation on Tensors.
20  *
21  * <p>An Operation is a node in a {@link Graph} that takes zero or more {@link Tensor}s (produced by
22  * other Operations in the Graph) as input, and produces zero or more {@link Tensor}s as output.
23  *
24  * <p>Operation instances are valid only as long as the Graph they are a part of is valid. Thus, if
25  * {@link Graph#close()} has been invoked, then methods on the Operation instance may fail with an
26  * {@code IllegalStateException}.
27  *
28  * <p>Operation instances are immutable and thread-safe.
29  */
30 public final class Operation {
31 
32   // Create an Operation instance referring to an operation in g, with the given handle to the C
33   // TF_Operation object.  The handle is valid only as long as g has not been closed, hence it is
34   // called unsafeHandle.  Graph.ref() is used to safely use the unsafeHandle.
Operation(Graph g, long unsafeNativeHandle)35   Operation(Graph g, long unsafeNativeHandle) {
36     this.graph = g;
37     this.unsafeNativeHandle = unsafeNativeHandle;
38   }
39 
40   /** Returns the full name of the Operation. */
name()41   public String name() {
42     Graph.Reference r = graph.ref();
43     try {
44       return name(unsafeNativeHandle);
45     } finally {
46       r.close();
47     }
48   }
49 
50   /**
51    * Returns the type of the operation, i.e., the name of the computation performed by the
52    * operation.
53    */
type()54   public String type() {
55     Graph.Reference r = graph.ref();
56     try {
57       return type(unsafeNativeHandle);
58     } finally {
59       r.close();
60     }
61   }
62 
63   /** Returns the number of tensors produced by this operation. */
numOutputs()64   public int numOutputs() {
65     Graph.Reference r = graph.ref();
66     try {
67       return numOutputs(unsafeNativeHandle);
68     } finally {
69       r.close();
70     }
71   }
72 
73   /**
74    * Returns the size of the list of Tensors produced by this operation.
75    *
76    * <p>An Operation has multiple named outputs, each of which produces either a single tensor or a
77    * list of tensors. This method returns the size of the list of tensors for a specific named
78    * output of the operation.
79    *
80    * @param name identifier of the list of tensors (of which there may be many) produced by this
81    *     operation.
82    * @return the size of the list of Tensors produced by this named output.
83    * @throws IllegalArgumentException if this operation has no output with the provided name.
84    */
outputListLength(final String name)85   public int outputListLength(final String name) {
86     Graph.Reference r = graph.ref();
87     try {
88       return outputListLength(unsafeNativeHandle, name);
89     } finally {
90       r.close();
91     }
92   }
93 
94   /**
95    * Returns symbolic handles to a list of tensors produced by this operation.
96    *
97    * @param idx index of the first tensor of the list
98    * @param length number of tensors in the list
99    * @return array of {@code Output}
100    */
outputList(int idx, int length)101   public Output<?>[] outputList(int idx, int length) {
102     Output<?>[] outputs = new Output<?>[length];
103     for (int i = 0; i < length; ++i) {
104       outputs[i] = output(idx + i);
105     }
106     return outputs;
107   }
108 
109   /**
110    * Returns a symbolic handle to one of the tensors produced by this operation.
111    *
112    * <p>Warning: Does not check that the type of the tensor matches T. It is recommended to call
113    * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code
114    * operation.<Integer>output(0)}
115    *
116    * @param <T> The expected element type of the tensors produced by this output.
117    * @param idx The index of the output among the outputs produced by this operation.
118    */
119   @SuppressWarnings({"rawtypes", "unchecked"})
output(int idx)120   public <T> Output<T> output(int idx) {
121     return new Output(this, idx);
122   }
123 
124   @Override
hashCode()125   public int hashCode() {
126     return Long.valueOf(unsafeNativeHandle).hashCode();
127   }
128 
129   @Override
equals(Object o)130   public boolean equals(Object o) {
131     if (o == this) {
132       return true;
133     }
134     if (!(o instanceof Operation)) {
135       return false;
136     }
137     Operation that = (Operation) o;
138     if (graph != that.graph) {
139       return false;
140     }
141 
142     // The graph object is known to be identical here, so this one
143     // reference is sufficient to validate the use of native pointers
144     // in both objects.
145     Graph.Reference r = graph.ref();
146     try {
147       return unsafeNativeHandle == that.unsafeNativeHandle;
148     } finally {
149       r.close();
150     }
151   }
152 
153   @Override
toString()154   public String toString() {
155     return String.format("<%s '%s'>", type(), name());
156   }
157 
158   /**
159    * Returns the size of the given inputs list of Tensors for this operation.
160    *
161    * <p>An Operation has multiple named inputs, each of which contains either a single tensor or a
162    * list of tensors. This method returns the size of the list of tensors for a specific named input
163    * of the operation.
164    *
165    * @param name identifier of the list of tensors (of which there may be many) inputs to this
166    *     operation.
167    * @return the size of the list of Tensors produced by this named input.
168    * @throws IllegalArgumentException if this operation has no input with the provided name.
169    */
inputListLength(final String name)170   public int inputListLength(final String name) {
171     Graph.Reference r = graph.ref();
172     try {
173       return inputListLength(unsafeNativeHandle, name);
174     } finally {
175       r.close();
176     }
177   }
178 
getUnsafeNativeHandle()179   long getUnsafeNativeHandle() {
180     return unsafeNativeHandle;
181   }
182 
183   // Package private, meant primarily for the public Output.shape() method.
shape(int output)184   long[] shape(int output) {
185     Graph.Reference r = graph.ref();
186     try {
187       return shape(r.nativeHandle(), unsafeNativeHandle, output);
188     } finally {
189       r.close();
190     }
191   }
192 
193   // Package private, meant primarily for the public Output.dataType() method.
dtype(int output)194   DataType dtype(int output) {
195     Graph.Reference r = graph.ref();
196     try {
197       return DataType.fromC(dtype(r.nativeHandle(), unsafeNativeHandle, output));
198     } finally {
199       r.close();
200     }
201   }
202 
203   private final long unsafeNativeHandle;
204 
205   private final Graph graph;
206 
name(long handle)207   private static native String name(long handle);
208 
type(long handle)209   private static native String type(long handle);
210 
numOutputs(long handle)211   private static native int numOutputs(long handle);
212 
outputListLength(long handle, String name)213   private static native int outputListLength(long handle, String name);
214 
inputListLength(long handle, String name)215   private static native int inputListLength(long handle, String name);
216 
shape(long graphHandle, long opHandle, int output)217   private static native long[] shape(long graphHandle, long opHandle, int output);
218 
dtype(long graphHandle, long opHandle, int output)219   private static native int dtype(long graphHandle, long opHandle, int output);
220 }
221