1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18
19 namespace tensorflow {
20
21 using shape_inference::InferenceContext;
22 using shape_inference::ShapeHandle;
23
// A mutable variable held as a reference-typed tensor. Unlike the legacy
// "Variable" op below, the "shape" attr is honored exactly (including
// scalars) via ExplicitShape.
REGISTER_OP("VariableV2")
    .Output("ref: Ref(dtype)")
    .Attr("shape: shape")
    .Attr("dtype: type")
    // "container" and "shared_name" default to empty strings; non-empty
    // values let separate ops share the same underlying resource.
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    // Stateful: must not be constant-folded or deduplicated.
    .SetIsStateful()
    .SetShapeFn(shape_inference::ExplicitShape);
32
33 REGISTER_OP("Variable")
34 .Output("ref: Ref(dtype)")
35 .Attr("shape: shape")
36 .Attr("dtype: type")
37 .Attr("container: string = ''")
38 .Attr("shared_name: string = ''")
39 .SetIsStateful()
__anone3ce2a1d0102(InferenceContext* c) 40 .SetShapeFn([](InferenceContext* c) {
41 PartialTensorShape shape;
42 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
43
44 // Variable has legacy behavior where we cannot tell the difference
45 // between a scalar shape attribute and 'unknown shape'. So if the shape
46 // is a scalar, we return an unknown shape.
47 if (shape.dims() <= 0) {
48 return shape_inference::UnknownShape(c);
49 }
50
51 ShapeHandle out;
52 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
53 c->set_output(0, out);
54 return Status::OK();
55 });
56
// Reports whether a ref-typed variable has been assigned a value yet.
// Output is a scalar bool.
REGISTER_OP("IsVariableInitialized")
    .Input("ref: Ref(dtype)")
    .Output("is_initialized: bool")
    .Attr("dtype: type")
    // The whole point of this op is to probe an uninitialized variable, so
    // the usual initialized-input check is disabled.
    .SetAllowsUninitializedInput()
    .SetShapeFn(shape_inference::ScalarShape);
63
// Creates a mutable scratch tensor that lives only until the matching
// DestroyTemporaryVariable (keyed by "var_name") consumes it.
REGISTER_OP("TemporaryVariable")
    .Output("ref: Ref(dtype)")
    .Attr("shape: shape")
    .Attr("dtype: type")
    // Empty var_name means the runtime picks a name; otherwise it must match
    // the paired DestroyTemporaryVariable's var_name.
    .Attr("var_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ExplicitShape);
71
// Consumes a TemporaryVariable's ref and returns its final value as a
// regular (non-ref) tensor; the temporary storage is released afterwards.
REGISTER_OP("DestroyTemporaryVariable")
    .Input("ref: Ref(T)")
    .Output("value: T")
    .Attr("T: type")
    .Attr("var_name: string")
    // Output shape is exactly the input ref's shape.
    .SetShapeFn(shape_inference::UnchangedShape);
78
79 REGISTER_OP("Assign")
80 .Input("ref: Ref(T)")
81 .Input("value: T")
82 .Output("output_ref: Ref(T)")
83 .Attr("T: type")
84 .Attr("validate_shape: bool = true")
85 .Attr("use_locking: bool = true")
86 .SetAllowsUninitializedInput()
__anone3ce2a1d0202(InferenceContext* c) 87 .SetShapeFn([](InferenceContext* c) {
88 bool validate_shape;
89 TF_RETURN_IF_ERROR(c->GetAttr("validate_shape", &validate_shape));
90 if (validate_shape) {
91 return shape_inference::MergeBothInputsShapeFn(c);
92 }
93
94 c->set_output(0, c->input(1));
95 return Status::OK();
96 });
97
// In-place "ref += value"; output aliases the updated ref.
REGISTER_OP("AssignAdd")
    .Input("ref: Ref(T)")
    .Input("value: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    // ref and value shapes must be compatible; merged shape is the output.
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);

// In-place "ref -= value"; output aliases the updated ref.
REGISTER_OP("AssignSub")
    .Input("ref: Ref(T)")
    .Input("value: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
113
114 namespace {
115
ScatterUpdateShape(InferenceContext * c)116 Status ScatterUpdateShape(InferenceContext* c) {
117 ShapeHandle var_shape = c->input(0);
118 ShapeHandle indices_shape = c->input(1);
119
120 ShapeHandle unused_updates_shape;
121 ShapeHandle concat;
122 ShapeHandle var_subshape;
123 TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
124 TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
125 TF_RETURN_IF_ERROR(
126 InferenceContext::Rank(c->input(2)) == 0
127 ? Status::OK()
128 : c->Merge(c->input(2), concat, &unused_updates_shape));
129
130 c->set_output(0, var_shape);
131 return Status::OK();
132 }
133
134 } // namespace
135
// ref[indices, ...] = updates; forwards the updated ref.
REGISTER_OP("ScatterUpdate")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    // Defaults to locking, unlike the arithmetic scatter ops below.
    .Attr("use_locking: bool = true")
    .SetShapeFn(ScatterUpdateShape);
145
// Element-wise scatter arithmetic on a ref variable:
//   ScatterAdd:  ref[indices] += updates
//   ScatterSub:  ref[indices] -= updates
//   ScatterMul:  ref[indices] *= updates
//   ScatterDiv:  ref[indices] /= updates
//   ScatterMin:  ref[indices] = min(ref[indices], updates)
//   ScatterMax:  ref[indices] = max(ref[indices], updates)
// All share ScatterUpdateShape and forward the updated ref as output.
REGISTER_OP("ScatterAdd")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ScatterUpdateShape);

REGISTER_OP("ScatterSub")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ScatterUpdateShape);

REGISTER_OP("ScatterMul")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ScatterUpdateShape);

REGISTER_OP("ScatterDiv")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ScatterUpdateShape);

// Min/Max restrict T to an explicit ordered-type list rather than the full
// "numbertype" set (e.g. complex types have no ordering).
REGISTER_OP("ScatterMin")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: {half, bfloat16, float, double, int32, int64}")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ScatterUpdateShape);

REGISTER_OP("ScatterMax")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: {half, bfloat16, float, double, int32, int64}")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ScatterUpdateShape);
205
// Sparse N-dimensional assignment into a ref variable; forwards the ref.
REGISTER_OP("ScatterNdUpdate")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = true")
    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
215
// Resource-variable counterparts of the ScatterNd* ops. They mutate the
// variable behind the resource handle and therefore have no output.
REGISTER_OP("ResourceScatterNdUpdate")
    .Input("ref: resource")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = true")
    .SetShapeFn(shape_inference::ScatterNdUpdateShape);

REGISTER_OP("ResourceScatterNdAdd")
    .Input("ref: resource")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = true")
    .SetShapeFn(shape_inference::ScatterNdUpdateShape);

REGISTER_OP("ResourceScatterNdSub")
    .Input("ref: resource")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = true")
    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
242
// Sparse N-dimensional "+=" / "-=" into a ref variable; forwards the ref.
REGISTER_OP("ScatterNdAdd")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(shape_inference::ScatterNdUpdateShape);

REGISTER_OP("ScatterNdSub")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
262
263 REGISTER_OP("CountUpTo")
264 .Input("ref: Ref(T)")
265 .Output("output: T")
266 .Attr("limit: int")
267 .Attr("T: {int32, int64}")
__anone3ce2a1d0402(InferenceContext* c) 268 .SetShapeFn([](InferenceContext* c) {
269 ShapeHandle output;
270 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &output));
271 c->set_output(0, output);
272 return Status::OK();
273 });
274
275 REGISTER_OP("ResourceCountUpTo")
276 .Input("resource: resource")
277 .Output("output: T")
278 .Attr("limit: int")
279 .Attr("T: {int32, int64}")
__anone3ce2a1d0502(InferenceContext* c) 280 .SetShapeFn([](InferenceContext* c) {
281 auto* handle_data = c->input_handle_shapes_and_types(0);
282 if (handle_data == nullptr || handle_data->empty()) {
283 return errors::InvalidArgument("Handle has no shape/type information.");
284 }
285 shape_inference::ShapeAndType shape_and_type = (*handle_data)[0];
286 DataType value_dtype;
287 TF_RETURN_IF_ERROR(c->GetAttr("T", &value_dtype));
288 if (value_dtype != shape_and_type.dtype) {
289 return errors::InvalidArgument(
290 "Data types do not match: ", DataTypeString(value_dtype), " and ",
291 DataTypeString(shape_and_type.dtype));
292 }
293 ShapeHandle output;
294 TF_RETURN_IF_ERROR(c->WithRank(shape_and_type.shape, 0, &output));
295 c->set_output(0, output);
296 return Status::OK();
297 });
298
299 } // namespace tensorflow
300