• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 
20 namespace tensorflow {
21 
22 using shape_inference::InferenceContext;
23 
24 REGISTER_OP("SymbolicGradient")
25     .Input("input: Tin")
26     .Output("output: Tout")
27     .Attr("Tin: list(type)")
28     .Attr("Tout: list(type)")
29     .Attr("f: func")
__anon22ceab0f0102(InferenceContext* c) 30     .SetShapeFn([](InferenceContext* c) {
31       if (c->num_inputs() < c->num_outputs()) {
32         return errors::InvalidArgument("len(inputs) < len(outputs)");
33       }
34       std::vector<DataType> types;
35       TF_RETURN_IF_ERROR(c->GetAttr("Tin", &types));
36       // Say, (u, v) = f(x, y, z), _symbolic_gradient(f) is a function of
37       // (x, y, z, du, dv) -> (dx, dy, dz). Therefore, shapes of its
38       // outputs (dx, dy, dz) are the same as (x, y, z).
39       for (int i = 0; i < c->num_outputs(); ++i) {
40         if (types[i] == DT_RESOURCE) {
41           const std::vector<shape_inference::ShapeAndType>* handle_type =
42               c->input_handle_shapes_and_types(i);
43           if (handle_type != nullptr) {
44             c->set_output(i, handle_type->at(0).shape);
45           } else {
46             c->set_output(i, c->UnknownShape());
47           }
48         } else {
49           c->set_output(i, c->input(i));
50         }
51       }
52       return Status::OK();
53     });
54 
55 REGISTER_OP("RemoteCall")
56     .Input("target: string")
57     .Input("args: Tin")
58     .Output("output: Tout")
59     .Attr("Tin: list(type)")
60     .Attr("Tout: list(type)")
61     .Attr("f: func")
62     .SetIsStateful()
63     .SetShapeFn(shape_inference::UnknownShape);
64 
65 // TODO(drpng): remove this.
66 REGISTER_OP("_If")
67     .Input("cond: Tcond")
68     .Input("input: Tin")
69     .Output("output: Tout")
70     .Attr("Tcond: type")
71     .Attr("Tin: list(type)")
72     .Attr("Tout: list(type)")
73     .Attr("then_branch: func")
74     .Attr("else_branch: func")
75     .SetIsStateful()
76     .SetShapeFn(shape_inference::UnknownShape)
77     .Doc(R"doc(
78 output = cond ? then_branch(input) : else_branch(input)
79 
80 cond: A Tensor. If the tensor is a scalar of non-boolean type, the
81     scalar is converted to a boolean according to the
82     following rule: if the scalar is a numerical value, non-zero means
83     True and zero means False; if the scalar is a string, non-empty
84     means True and empty means False. If the tensor is not a scalar,
85     being empty means False and being non-empty means True.
86 input: A list of input tensors.
87 then_branch: A function that takes 'inputs' and returns a list of
88     tensors, whose types are the same as what else_branch returns.
89 else_branch: A function that takes 'inputs' and returns a list of
90     tensors.  whose types are the same as what then_branch returns.
91 )doc");
92 
93 REGISTER_OP("StatelessIf")
94     .Input("cond: Tcond")
95     .Input("input: Tin")
96     .Output("output: Tout")
97     .Attr("Tcond: type")
98     .Attr("Tin: list(type) >= 0")
99     .Attr("Tout: list(type) >= 0")
100     .Attr("then_branch: func")
101     .Attr("else_branch: func")
102     .SetShapeFn(shape_inference::UnknownShape);
103 
104 REGISTER_OP("If")
105     .Input("cond: Tcond")
106     .Input("input: Tin")
107     .Output("output: Tout")
108     .Attr("Tcond: type")
109     .Attr("Tin: list(type) >= 0")
110     .Attr("Tout: list(type) >= 0")
111     .Attr("then_branch: func")
112     .Attr("else_branch: func")
113     .Attr("output_shapes: list(shape) = []")
114     .SetIsStateful()
__anon22ceab0f0202(shape_inference::InferenceContext* c) 115     .SetShapeFn([](shape_inference::InferenceContext* c) {
116       std::vector<PartialTensorShape> output_shapes;
117       TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
118       // If `output_shapes` attr is set use that as the shapes of the outputs
119       // else return unknown shapes.
120       if (output_shapes.empty()) return shape_inference::UnknownShape(c);
121       if (output_shapes.size() != c->num_outputs()) {
122         return errors::InvalidArgument(
123             "`output_shapes` must be the same length as num outputs (",
124             output_shapes.size(), " vs. ", c->num_outputs());
125       }
126       for (size_t i = 0; i < output_shapes.size(); ++i) {
127         shape_inference::ShapeHandle output_shape_handle;
128         TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
129             output_shapes[i], &output_shape_handle));
130         c->set_output(static_cast<int>(i), output_shape_handle);
131       }
132       return Status::OK();
133     });
134 
135 REGISTER_OP("Case")
136     .Input("branch_index: int32")
137     .Input("input: Tin")
138     .Output("output: Tout")
139     .Attr("Tin: list(type) >= 0")
140     .Attr("Tout: list(type) >= 0")
141     .Attr("branches: list(func) >= 1")
142     .Attr("output_shapes: list(shape) = []")
143     .SetIsStateful()
__anon22ceab0f0302(shape_inference::InferenceContext* c) 144     .SetShapeFn([](shape_inference::InferenceContext* c) {
145       std::vector<PartialTensorShape> output_shapes;
146       TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
147       // If `output_shapes` attr is set use that as the shapes of the outputs
148       // else return unknown shapes.
149       if (output_shapes.empty()) return shape_inference::UnknownShape(c);
150       if (output_shapes.size() != c->num_outputs()) {
151         return errors::InvalidArgument(
152             "`output_shapes` must be the same length as num outputs (",
153             output_shapes.size(), " vs. ", c->num_outputs());
154       }
155       for (size_t i = 0; i < output_shapes.size(); ++i) {
156         shape_inference::ShapeHandle output_shape_handle;
157         TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
158             output_shapes[i], &output_shape_handle));
159         c->set_output(static_cast<int>(i), output_shape_handle);
160       }
161       return Status::OK();
162     });
163 
164 // TODO(drpng): remove this.
165 REGISTER_OP("_While")
166     .Input("input: T")
167     .Output("output: T")
168     .Attr("T: list(type) >= 0")
169     .Attr("cond: func")
170     .Attr("body: func")
171     .SetIsStateful()
__anon22ceab0f0402(shape_inference::InferenceContext* c) 172     .SetShapeFn([](shape_inference::InferenceContext* c) {
173       for (int i = 0; i < c->num_outputs(); ++i) {
174         c->set_output(i, c->input(i));
175       }
176       return Status::OK();
177     })
178     .Doc(R"doc(
179 output = input; While (Cond(output)) { output = Body(output) }
180 
181 input: A list of input tensors whose types are T.
182 output: A list of output tensors whose types are T.
183 cond: A function takes 'input' and returns a tensor.  If the tensor is
184     a scalar of non-boolean, the scalar is converted to a boolean
185     according to the following rule: if the scalar is a numerical
186     value, non-zero means True and zero means False; if the scalar is
187     a string, non-empty means True and empty means False. If the
188     tensor is not a scalar, non-emptiness means True and False
189     otherwise.
190 body: A function that takes a list of tensors and returns another
191       list of tensors. Both lists have the same types as specified
192       by T.
193 )doc");
194 
195 REGISTER_OP("While")
196     .Input("input: T")
197     .Output("output: T")
198     .Attr("T: list(type) >= 0")
199     .Attr("cond: func")
200     .Attr("body: func")
201     .Attr("output_shapes: list(shape) = []")
202     .Attr("parallel_iterations: int = 10")
203     .SetIsStateful()
__anon22ceab0f0502(shape_inference::InferenceContext* c) 204     .SetShapeFn([](shape_inference::InferenceContext* c) {
205       std::vector<PartialTensorShape> output_shapes;
206       TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
207       // If `output_shapes` attr is set use that as the shapes of the outputs
208       // else use the input shapes.
209       if (!output_shapes.empty()) {
210         if (output_shapes.size() != c->num_outputs()) {
211           return errors::InvalidArgument(
212               "`output_shapes` must be the same length as num outputs (",
213               output_shapes.size(), " vs. ", c->num_outputs());
214         }
215         for (size_t i = 0; i < output_shapes.size(); ++i) {
216           shape_inference::ShapeHandle output_shape_handle;
217           TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
218               output_shapes[i], &output_shape_handle));
219           c->set_output(static_cast<int>(i), output_shape_handle);
220         }
221       } else {
222         for (int i = 0; i < c->num_outputs(); ++i) {
223           c->set_output(i, c->input(i));
224         }
225       }
226       return Status::OK();
227     });
228 
229 REGISTER_OP("StatelessWhile")
230     .Input("input: T")
231     .Output("output: T")
232     .Attr("T: list(type) >= 0")
233     .Attr("cond: func")
234     .Attr("body: func")
__anon22ceab0f0602(shape_inference::InferenceContext* c) 235     .SetShapeFn([](shape_inference::InferenceContext* c) {
236       for (int i = 0; i < c->num_outputs(); ++i) {
237         c->set_output(i, c->input(i));
238       }
239       return Status::OK();
240     });
241 
242 REGISTER_OP("For")
243     .Input("start: int32")
244     .Input("limit: int32")
245     .Input("delta: int32")
246     .Input("input: T")
247     .Output("output: T")
248     .Attr("T: list(type) >= 0")
249     .Attr("body: func")
250     .SetShapeFn(shape_inference::UnknownShape);
251 
252 REGISTER_OP("PartitionedCall")
253     .Input("args: Tin")
254     .Output("output: Tout")
255     .Attr("Tin: list(type) >= 0")
256     .Attr("Tout: list(type) >= 0")
257     .Attr("f: func")
258     .Attr("config: string = ''")
259     .Attr("config_proto: string = ''")
260     .Attr("executor_type: string = ''")
261     .SetShapeFn(shape_inference::UnknownShape);
262 
263 REGISTER_OP("StatefulPartitionedCall")
264     .Input("args: Tin")
265     .Output("output: Tout")
266     .Attr("Tin: list(type) >= 0")
267     .Attr("Tout: list(type) >= 0")
268     .Attr("f: func")
269     .Attr("config: string = ''")  // Deprecated in favor of config_proto
270     .Attr("config_proto: string = ''")
271     .Attr("executor_type: string = ''")
272     .SetIsStateful()
273     .SetShapeFn(shape_inference::UnknownShape);
274 
275 // This op is used as a placeholder in If branch functions. It doesn't provide a
276 // valid output when run, so must either be removed (e.g. replaced with a
277 // function input) or guaranteed not to be used (e.g. if mirroring an
278 // intermediate output needed for the gradient computation of the other branch).
279 REGISTER_OP("FakeParam")
280     .Output("output: dtype")
281     .Attr("dtype: type")
282     .Attr("shape: shape")
__anon22ceab0f0702(InferenceContext* c) 283     .SetShapeFn([](InferenceContext* c) {
284       PartialTensorShape shape;
285       TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
286       shape_inference::ShapeHandle out;
287       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
288       c->set_output(0, out);
289       return Status::OK();
290     });
291 
292 }  // end namespace tensorflow
293