1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19
20 namespace tensorflow {
21
22 using shape_inference::InferenceContext;
23 using shape_inference::ShapeHandle;
24
25 // --------------------------------------------------------------------------
26 namespace {
SwitchShape(InferenceContext * c)27 Status SwitchShape(InferenceContext* c) {
28 ShapeHandle unused;
29 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
30 ShapeHandle out = c->input(0);
31 c->set_output(0, out);
32 c->set_output(1, out);
33
34 // Handle resource shape / dtype.
35 auto* handle_data = c->input_handle_shapes_and_types(0);
36 if (handle_data != nullptr) {
37 c->set_output_handle_shapes_and_types(0, *handle_data);
38 c->set_output_handle_shapes_and_types(1, *handle_data);
39 }
40 return Status::OK();
41 }
42 } // namespace
43
44 REGISTER_OP("Switch")
45 .Input("data: T")
46 .Input("pred: bool")
47 .Output("output_false: T")
48 .Output("output_true: T")
49 .Attr("T: type")
50 .SetShapeFn(SwitchShape);
51
52 REGISTER_OP("RefSwitch")
53 .Input("data: Ref(T)")
54 .Input("pred: bool")
55 .Output("output_false: Ref(T)")
56 .Output("output_true: Ref(T)")
57 .Attr("T: type")
58 .SetAllowsUninitializedInput()
59 .SetShapeFn(SwitchShape);
60
61 // --------------------------------------------------------------------------
62 REGISTER_OP("RefSelect")
63 .Input("index: int32")
64 .Input("inputs: Ref(N * T)")
65 .Output("output: Ref(T)")
66 .Attr("T: type")
67 .Attr("N: int >= 1")
__anonad90f1540202(InferenceContext* c) 68 .SetShapeFn([](InferenceContext* c) {
69 ShapeHandle unused;
70 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
71 ShapeHandle first_input = c->input(1);
72 if (!c->FullyDefined(first_input)) {
73 c->set_output(0, c->UnknownShape());
74 return Status::OK();
75 }
76 // If any inputs aren't fully defined or don't match, we return unknown.
77 for (int i = 2; i < c->num_inputs(); ++i) {
78 ShapeHandle input = c->input(i);
79 if (!c->FullyDefined(input) ||
80 !c->Merge(first_input, input, &unused).ok()) {
81 c->set_output(0, c->UnknownShape());
82 return Status::OK();
83 }
84 }
85 c->set_output(0, first_input);
86 return Status::OK();
87 });
88
89 // --------------------------------------------------------------------------
90 namespace {
MergeShape(InferenceContext * c)91 Status MergeShape(InferenceContext* c) {
92 ShapeHandle out = c->input(0);
93 if (!c->RankKnown(out)) {
94 out = c->UnknownShape();
95 } else {
96 int32 rank = c->Rank(out);
97 for (int i = 1; i < c->num_inputs(); ++i) {
98 ShapeHandle input = c->input(i);
99 if (!c->RankKnown(input) || c->Rank(input) != rank) {
100 out = c->UnknownShape();
101 break;
102 }
103
104 for (int d = 0; d < rank; ++d) {
105 if (c->Value(c->Dim(input, d)) != c->Value(c->Dim(out, d))) {
106 TF_RETURN_IF_ERROR(c->ReplaceDim(out, d, c->UnknownDim(), &out));
107 }
108 }
109 }
110 }
111 c->set_output(0, out);
112 c->set_output(1, c->Scalar());
113 return Status::OK();
114 }
115 } // namespace
116
117 REGISTER_OP("Merge")
118 .Input("inputs: N * T")
119 .Output("output: T")
120 .Output("value_index: int32")
121 .Attr("T: type")
122 .Attr("N: int >= 1")
123 .SetShapeFn(MergeShape);
124
125 REGISTER_OP("RefMerge")
126 .Input("inputs: Ref(N * T)")
127 .Output("output: Ref(T)")
128 .Output("value_index: int32")
129 .Attr("T: type")
130 .Attr("N: int >= 1")
131 .SetShapeFn(MergeShape);
132
133 // --------------------------------------------------------------------------
134 REGISTER_OP("Enter")
135 .Input("data: T")
136 .Output("output: T")
137 .Attr("T: type")
138 .Attr("frame_name: string")
139 .Attr("is_constant: bool = false")
140 .Attr("parallel_iterations: int = 10")
__anonad90f1540402(InferenceContext* c) 141 .SetShapeFn([](InferenceContext* c) {
142 c->set_output(0, c->UnknownShape());
143
144 // Handle resource shape / dtype, if present.
145 auto* handle_data = c->input_handle_shapes_and_types(0);
146 if (handle_data != nullptr) {
147 c->set_output_handle_shapes_and_types(0, *handle_data);
148 }
149 // Propagate shape if output is a constant.
150 bool is_constant;
151 TF_RETURN_IF_ERROR(c->GetAttr("is_constant", &is_constant));
152 if (is_constant) {
153 c->set_output(0, c->input(0));
154 }
155
156 return Status::OK();
157 });
158
159 // --------------------------------------------------------------------------
// Ref-typed variant of Enter. Declares the same attrs as Enter (frame_name,
// is_constant, parallel_iterations).
// NOTE(review): Enter's shape function reports an unknown output shape unless
// is_constant is true. This registration instead always uses
// shape_inference::UnchangedShape, whatever the value of is_constant.
// Confirm this asymmetry is intended.
REGISTER_OP("RefEnter")
    .Input("data: Ref(T)")
    .Output("output: Ref(T)")
    .Attr("T: type")
    .Attr("frame_name: string")
    .Attr("is_constant: bool = false")
    .Attr("parallel_iterations: int = 10")
    .SetShapeFn(shape_inference::UnchangedShape);
168
169 // --------------------------------------------------------------------------
// Exit and its Ref-typed variant each take one input `data` and produce one
// output of the same type T. Shape inference is delegated to
// shape_inference::UnchangedShape (declared in common_shape_fns.h), which,
// as its name indicates, is expected to forward the input shape unchanged.
REGISTER_OP("Exit")
    .Input("data: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("RefExit")
    .Input("data: Ref(T)")
    .Output("output: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);
181
182 // --------------------------------------------------------------------------
// NextIteration and its Ref-typed variant each take one input `data` and
// produce one output of the same type T. Shape inference is delegated to
// shape_inference::UnchangedShape (declared in common_shape_fns.h), which,
// as its name indicates, is expected to forward the input shape unchanged.
REGISTER_OP("NextIteration")
    .Input("data: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("RefNextIteration")
    .Input("data: Ref(T)")
    .Output("output: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);
194
195 // --------------------------------------------------------------------------
196 REGISTER_OP("LoopCond")
197 .Input("input: bool")
198 .Output("output: bool")
__anonad90f1540502(InferenceContext* c) 199 .SetShapeFn([](InferenceContext* c) {
200 return shape_inference::UnchangedShapeWithRank(c, 0);
201 });
202
203 // --------------------------------------------------------------------------
// ControlTrigger is registered with no inputs, outputs, or attrs; its shape
// function is shape_inference::NoOutputs.
REGISTER_OP("ControlTrigger").SetShapeFn(shape_inference::NoOutputs);
205
206 // --------------------------------------------------------------------------
// Abort declares no inputs or outputs. It has two attrs:
//   error_msg:          string, defaults to the empty string.
//   exit_without_error: bool, defaults to false.
// Runtime behavior lives in the kernel, which is not part of this file;
// presumably the attrs select the message and the exit mode. Confirm
// against the kernel implementation.
REGISTER_OP("Abort")
    .Attr("error_msg: string = ''")
    .Attr("exit_without_error: bool = false")
    .SetShapeFn(shape_inference::NoOutputs);
211
212 } // namespace tensorflow
213