• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 
20 namespace tensorflow {
21 
22 using shape_inference::InferenceContext;
23 using shape_inference::ShapeHandle;
24 
25 // --------------------------------------------------------------------------
26 namespace {
SwitchShape(InferenceContext * c)27 Status SwitchShape(InferenceContext* c) {
28   ShapeHandle unused;
29   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
30   ShapeHandle out = c->input(0);
31   c->set_output(0, out);
32   c->set_output(1, out);
33 
34   // Handle resource shape / dtype.
35   auto* handle_data = c->input_handle_shapes_and_types(0);
36   if (handle_data != nullptr) {
37     c->set_output_handle_shapes_and_types(0, *handle_data);
38     c->set_output_handle_shapes_and_types(1, *handle_data);
39   }
40   return Status::OK();
41 }
42 }  // namespace
43 
// Switch: takes `data` (type T) and a boolean `pred`, and declares two outputs
// of type T, `output_false` and `output_true`. Presumably `data` is forwarded
// to the output matching the runtime value of `pred` -- the kernel is not part
// of this file, so confirm there. Shape inference (SwitchShape) requires
// `pred` to be a scalar and gives both outputs the shape of `data`.
// The order of Input/Output/Attr calls below defines the registered OpDef.
REGISTER_OP("Switch")
    .Input("data: T")
    .Input("pred: bool")
    .Output("output_false: T")
    .Output("output_true: T")
    .Attr("T: type")
    .SetShapeFn(SwitchShape);

// RefSwitch: reference-typed variant of Switch (`data` and both outputs are
// Ref(T)); uses the same shape function. SetAllowsUninitializedInput() marks
// the op as accepting an input that has not been initialized.
// NOTE(review): presumably so an uninitialized variable ref can be routed;
// confirm against the kernel.
REGISTER_OP("RefSwitch")
    .Input("data: Ref(T)")
    .Input("pred: bool")
    .Output("output_false: Ref(T)")
    .Output("output_true: Ref(T)")
    .Attr("T: type")
    .SetAllowsUninitializedInput()
    .SetShapeFn(SwitchShape);
60 
61 // --------------------------------------------------------------------------
62 REGISTER_OP("RefSelect")
63     .Input("index: int32")
64     .Input("inputs: Ref(N * T)")
65     .Output("output: Ref(T)")
66     .Attr("T: type")
67     .Attr("N: int >= 1")
__anonad90f1540202(InferenceContext* c) 68     .SetShapeFn([](InferenceContext* c) {
69       ShapeHandle unused;
70       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
71       ShapeHandle first_input = c->input(1);
72       if (!c->FullyDefined(first_input)) {
73         c->set_output(0, c->UnknownShape());
74         return Status::OK();
75       }
76       // If any inputs aren't fully defined or don't match, we return unknown.
77       for (int i = 2; i < c->num_inputs(); ++i) {
78         ShapeHandle input = c->input(i);
79         if (!c->FullyDefined(input) ||
80             !c->Merge(first_input, input, &unused).ok()) {
81           c->set_output(0, c->UnknownShape());
82           return Status::OK();
83         }
84       }
85       c->set_output(0, first_input);
86       return Status::OK();
87     });
88 
89 // --------------------------------------------------------------------------
90 namespace {
MergeShape(InferenceContext * c)91 Status MergeShape(InferenceContext* c) {
92   ShapeHandle out = c->input(0);
93   if (!c->RankKnown(out)) {
94     out = c->UnknownShape();
95   } else {
96     int32 rank = c->Rank(out);
97     for (int i = 1; i < c->num_inputs(); ++i) {
98       ShapeHandle input = c->input(i);
99       if (!c->RankKnown(input) || c->Rank(input) != rank) {
100         out = c->UnknownShape();
101         break;
102       }
103 
104       for (int d = 0; d < rank; ++d) {
105         if (c->Value(c->Dim(input, d)) != c->Value(c->Dim(out, d))) {
106           TF_RETURN_IF_ERROR(c->ReplaceDim(out, d, c->UnknownDim(), &out));
107         }
108       }
109     }
110   }
111   c->set_output(0, out);
112   c->set_output(1, c->Scalar());
113   return Status::OK();
114 }
115 }  // namespace
116 
// Merge: takes N >= 1 inputs of type T and declares two outputs, `output`
// (type T) and `value_index` (int32). Presumably `output` is whichever input
// became available and `value_index` is its position -- the kernel lives
// elsewhere; confirm there. Shape inference (MergeShape) gives `output` a shape
// compatible with all inputs: mismatching dimensions become unknown, and
// unknown or mismatched rank yields unknown rank. `value_index` is a scalar.
// The order of Input/Output/Attr calls below defines the registered OpDef.
REGISTER_OP("Merge")
    .Input("inputs: N * T")
    .Output("output: T")
    .Output("value_index: int32")
    .Attr("T: type")
    .Attr("N: int >= 1")
    .SetShapeFn(MergeShape);

// RefMerge: reference-typed variant of Merge (inputs and `output` are refs);
// uses the same shape function.
REGISTER_OP("RefMerge")
    .Input("inputs: Ref(N * T)")
    .Output("output: Ref(T)")
    .Output("value_index: int32")
    .Attr("T: type")
    .Attr("N: int >= 1")
    .SetShapeFn(MergeShape);
132 
133 // --------------------------------------------------------------------------
134 REGISTER_OP("Enter")
135     .Input("data: T")
136     .Output("output: T")
137     .Attr("T: type")
138     .Attr("frame_name: string")
139     .Attr("is_constant: bool = false")
140     .Attr("parallel_iterations: int = 10")
__anonad90f1540402(InferenceContext* c) 141     .SetShapeFn([](InferenceContext* c) {
142       c->set_output(0, c->UnknownShape());
143 
144       // Handle resource shape / dtype, if present.
145       auto* handle_data = c->input_handle_shapes_and_types(0);
146       if (handle_data != nullptr) {
147         c->set_output_handle_shapes_and_types(0, *handle_data);
148       }
149       // Propagate shape if output is a constant.
150       bool is_constant;
151       TF_RETURN_IF_ERROR(c->GetAttr("is_constant", &is_constant));
152       if (is_constant) {
153         c->set_output(0, c->input(0));
154       }
155 
156       return Status::OK();
157     });
158 
// --------------------------------------------------------------------------
// RefEnter: reference-typed variant of Enter, with the same attrs
// (`frame_name`, `is_constant` defaulting to false, `parallel_iterations`
// defaulting to 10). The output shape is the input shape, unchanged.
// NOTE(review): unlike Enter, whose shape function reports an unknown output
// shape unless `is_constant` is true, RefEnter always forwards the input shape
// (UnchangedShape) regardless of `is_constant`. It also does not forward
// resource-handle metadata. Confirm this difference is intentional.
REGISTER_OP("RefEnter")
    .Input("data: Ref(T)")
    .Output("output: Ref(T)")
    .Attr("T: type")
    .Attr("frame_name: string")
    .Attr("is_constant: bool = false")
    .Attr("parallel_iterations: int = 10")
    .SetShapeFn(shape_inference::UnchangedShape);
168 
// --------------------------------------------------------------------------
// Exit: single input `data` and single output of type T. The output shape is
// the input shape, unchanged. Presumably makes a value computed inside a loop
// frame available to the parent frame -- confirm against the kernel.
REGISTER_OP("Exit")
    .Input("data: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);

// RefExit: reference-typed variant of Exit; shape passed through unchanged.
REGISTER_OP("RefExit")
    .Input("data: Ref(T)")
    .Output("output: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);
181 
// --------------------------------------------------------------------------
// NextIteration: single input `data` and single output of type T. The output
// shape is the input shape, unchanged. Presumably hands a value to the next
// iteration of the enclosing loop frame -- confirm against the kernel.
REGISTER_OP("NextIteration")
    .Input("data: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);

// RefNextIteration: reference-typed variant of NextIteration; shape passed
// through unchanged.
REGISTER_OP("RefNextIteration")
    .Input("data: Ref(T)")
    .Output("output: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);
194 
// --------------------------------------------------------------------------
// LoopCond: takes a boolean `input` and produces a boolean `output`. The shape
// function calls UnchangedShapeWithRank(c, 0), i.e. the input is expected to
// have rank 0 (a scalar) and its shape is passed through to the output.
// Presumably marks a while loop's termination condition -- confirm against
// the kernel.
REGISTER_OP("LoopCond")
    .Input("input: bool")
    .Output("output: bool")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::UnchangedShapeWithRank(c, 0);
    });
202 
// --------------------------------------------------------------------------
// ControlTrigger: registered with no inputs, outputs, or attrs; its shape
// function is NoOutputs.
REGISTER_OP("ControlTrigger").SetShapeFn(shape_inference::NoOutputs);

// --------------------------------------------------------------------------
// Abort: no inputs or outputs. Attrs: `error_msg` (string, default '') and
// `exit_without_error` (bool, default false). Presumably terminates the
// process when executed, reporting `error_msg` unless `exit_without_error` is
// set -- the kernel is not in this file; confirm there.
REGISTER_OP("Abort")
    .Attr("error_msg: string = ''")
    .Attr("exit_without_error: bool = false")
    .SetShapeFn(shape_inference::NoOutputs);
211 
212 }  // namespace tensorflow
213