• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 
19 namespace tensorflow {
20 
21 using shape_inference::InferenceContext;
22 using shape_inference::ShapeHandle;
23 
24 REGISTER_OP("VariableV2")
25     .Output("ref: Ref(dtype)")
26     .Attr("shape: shape")
27     .Attr("dtype: type")
28     .Attr("container: string = ''")
29     .Attr("shared_name: string = ''")
30     .SetIsStateful()
31     .SetShapeFn(shape_inference::ExplicitShape);
32 
33 REGISTER_OP("Variable")
34     .Output("ref: Ref(dtype)")
35     .Attr("shape: shape")
36     .Attr("dtype: type")
37     .Attr("container: string = ''")
38     .Attr("shared_name: string = ''")
39     .SetIsStateful()
__anone3ce2a1d0102(InferenceContext* c) 40     .SetShapeFn([](InferenceContext* c) {
41       PartialTensorShape shape;
42       TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
43 
44       // Variable has legacy behavior where we cannot tell the difference
45       // between a scalar shape attribute and 'unknown shape'.  So if the shape
46       // is a scalar, we return an unknown shape.
47       if (shape.dims() <= 0) {
48         return shape_inference::UnknownShape(c);
49       }
50 
51       ShapeHandle out;
52       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
53       c->set_output(0, out);
54       return Status::OK();
55     });
56 
57 REGISTER_OP("IsVariableInitialized")
58     .Input("ref: Ref(dtype)")
59     .Output("is_initialized: bool")
60     .Attr("dtype: type")
61     .SetAllowsUninitializedInput()
62     .SetShapeFn(shape_inference::ScalarShape);
63 
64 REGISTER_OP("TemporaryVariable")
65     .Output("ref: Ref(dtype)")
66     .Attr("shape: shape")
67     .Attr("dtype: type")
68     .Attr("var_name: string = ''")
69     .SetIsStateful()
70     .SetShapeFn(shape_inference::ExplicitShape);
71 
72 REGISTER_OP("DestroyTemporaryVariable")
73     .Input("ref: Ref(T)")
74     .Output("value: T")
75     .Attr("T: type")
76     .Attr("var_name: string")
77     .SetShapeFn(shape_inference::UnchangedShape);
78 
79 REGISTER_OP("Assign")
80     .Input("ref: Ref(T)")
81     .Input("value: T")
82     .Output("output_ref: Ref(T)")
83     .Attr("T: type")
84     .Attr("validate_shape: bool = true")
85     .Attr("use_locking: bool = true")
86     .SetAllowsUninitializedInput()
__anone3ce2a1d0202(InferenceContext* c) 87     .SetShapeFn([](InferenceContext* c) {
88       bool validate_shape;
89       TF_RETURN_IF_ERROR(c->GetAttr("validate_shape", &validate_shape));
90       if (validate_shape) {
91         return shape_inference::MergeBothInputsShapeFn(c);
92       }
93 
94       c->set_output(0, c->input(1));
95       return Status::OK();
96     });
97 
98 REGISTER_OP("AssignAdd")
99     .Input("ref: Ref(T)")
100     .Input("value: T")
101     .Output("output_ref: Ref(T)")
102     .Attr("T: numbertype")
103     .Attr("use_locking: bool = false")
104     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
105 
106 REGISTER_OP("AssignSub")
107     .Input("ref: Ref(T)")
108     .Input("value: T")
109     .Output("output_ref: Ref(T)")
110     .Attr("T: numbertype")
111     .Attr("use_locking: bool = false")
112     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
113 
114 namespace {
115 
ScatterUpdateShape(InferenceContext * c)116 Status ScatterUpdateShape(InferenceContext* c) {
117   ShapeHandle var_shape = c->input(0);
118   ShapeHandle indices_shape = c->input(1);
119 
120   ShapeHandle unused_updates_shape;
121   ShapeHandle concat;
122   ShapeHandle var_subshape;
123   TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
124   TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
125   TF_RETURN_IF_ERROR(
126       InferenceContext::Rank(c->input(2)) == 0
127           ? Status::OK()
128           : c->Merge(c->input(2), concat, &unused_updates_shape));
129 
130   c->set_output(0, var_shape);
131   return Status::OK();
132 }
133 
134 }  // namespace
135 
136 REGISTER_OP("ScatterUpdate")
137     .Input("ref: Ref(T)")
138     .Input("indices: Tindices")
139     .Input("updates: T")
140     .Output("output_ref: Ref(T)")
141     .Attr("T: type")
142     .Attr("Tindices: {int32, int64}")
143     .Attr("use_locking: bool = true")
144     .SetShapeFn(ScatterUpdateShape);
145 
146 REGISTER_OP("ScatterAdd")
147     .Input("ref: Ref(T)")
148     .Input("indices: Tindices")
149     .Input("updates: T")
150     .Output("output_ref: Ref(T)")
151     .Attr("T: numbertype")
152     .Attr("Tindices: {int32, int64}")
153     .Attr("use_locking: bool = false")
154     .SetShapeFn(ScatterUpdateShape);
155 
156 REGISTER_OP("ScatterSub")
157     .Input("ref: Ref(T)")
158     .Input("indices: Tindices")
159     .Input("updates: T")
160     .Output("output_ref: Ref(T)")
161     .Attr("T: numbertype")
162     .Attr("Tindices: {int32, int64}")
163     .Attr("use_locking: bool = false")
164     .SetShapeFn(ScatterUpdateShape);
165 
166 REGISTER_OP("ScatterMul")
167     .Input("ref: Ref(T)")
168     .Input("indices: Tindices")
169     .Input("updates: T")
170     .Output("output_ref: Ref(T)")
171     .Attr("T: numbertype")
172     .Attr("Tindices: {int32, int64}")
173     .Attr("use_locking: bool = false")
174     .SetShapeFn(ScatterUpdateShape);
175 
176 REGISTER_OP("ScatterDiv")
177     .Input("ref: Ref(T)")
178     .Input("indices: Tindices")
179     .Input("updates: T")
180     .Output("output_ref: Ref(T)")
181     .Attr("T: numbertype")
182     .Attr("Tindices: {int32, int64}")
183     .Attr("use_locking: bool = false")
184     .SetShapeFn(ScatterUpdateShape);
185 
186 REGISTER_OP("ScatterMin")
187     .Input("ref: Ref(T)")
188     .Input("indices: Tindices")
189     .Input("updates: T")
190     .Output("output_ref: Ref(T)")
191     .Attr("T: {half, bfloat16, float, double, int32, int64}")
192     .Attr("Tindices: {int32, int64}")
193     .Attr("use_locking: bool = false")
194     .SetShapeFn(ScatterUpdateShape);
195 
196 REGISTER_OP("ScatterMax")
197     .Input("ref: Ref(T)")
198     .Input("indices: Tindices")
199     .Input("updates: T")
200     .Output("output_ref: Ref(T)")
201     .Attr("T: {half, bfloat16, float, double, int32, int64}")
202     .Attr("Tindices: {int32, int64}")
203     .Attr("use_locking: bool = false")
204     .SetShapeFn(ScatterUpdateShape);
205 
206 REGISTER_OP("ScatterNdUpdate")
207     .Input("ref: Ref(T)")
208     .Input("indices: Tindices")
209     .Input("updates: T")
210     .Output("output_ref: Ref(T)")
211     .Attr("T: type")
212     .Attr("Tindices: {int32, int64}")
213     .Attr("use_locking: bool = true")
214     .SetShapeFn(shape_inference::ScatterNdUpdateShape);
215 
216 REGISTER_OP("ResourceScatterNdUpdate")
217     .Input("ref: resource")
218     .Input("indices: Tindices")
219     .Input("updates: T")
220     .Attr("T: type")
221     .Attr("Tindices: {int32, int64}")
222     .Attr("use_locking: bool = true")
223     .SetShapeFn(shape_inference::ScatterNdUpdateShape);
224 
225 REGISTER_OP("ResourceScatterNdAdd")
226     .Input("ref: resource")
227     .Input("indices: Tindices")
228     .Input("updates: T")
229     .Attr("T: type")
230     .Attr("Tindices: {int32, int64}")
231     .Attr("use_locking: bool = true")
232     .SetShapeFn(shape_inference::ScatterNdUpdateShape);
233 
234 REGISTER_OP("ResourceScatterNdSub")
235     .Input("ref: resource")
236     .Input("indices: Tindices")
237     .Input("updates: T")
238     .Attr("T: type")
239     .Attr("Tindices: {int32, int64}")
240     .Attr("use_locking: bool = true")
241     .SetShapeFn(shape_inference::ScatterNdUpdateShape);
242 
243 REGISTER_OP("ScatterNdAdd")
244     .Input("ref: Ref(T)")
245     .Input("indices: Tindices")
246     .Input("updates: T")
247     .Output("output_ref: Ref(T)")
248     .Attr("T: numbertype")
249     .Attr("Tindices: {int32, int64}")
250     .Attr("use_locking: bool = false")
251     .SetShapeFn(shape_inference::ScatterNdUpdateShape);
252 
253 REGISTER_OP("ScatterNdSub")
254     .Input("ref: Ref(T)")
255     .Input("indices: Tindices")
256     .Input("updates: T")
257     .Output("output_ref: Ref(T)")
258     .Attr("T: numbertype")
259     .Attr("Tindices: {int32, int64}")
260     .Attr("use_locking: bool = false")
261     .SetShapeFn(shape_inference::ScatterNdUpdateShape);
262 
263 REGISTER_OP("CountUpTo")
264     .Input("ref: Ref(T)")
265     .Output("output: T")
266     .Attr("limit: int")
267     .Attr("T: {int32, int64}")
__anone3ce2a1d0402(InferenceContext* c) 268     .SetShapeFn([](InferenceContext* c) {
269       ShapeHandle output;
270       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &output));
271       c->set_output(0, output);
272       return Status::OK();
273     });
274 
275 REGISTER_OP("ResourceCountUpTo")
276     .Input("resource: resource")
277     .Output("output: T")
278     .Attr("limit: int")
279     .Attr("T: {int32, int64}")
__anone3ce2a1d0502(InferenceContext* c) 280     .SetShapeFn([](InferenceContext* c) {
281       auto* handle_data = c->input_handle_shapes_and_types(0);
282       if (handle_data == nullptr || handle_data->empty()) {
283         return errors::InvalidArgument("Handle has no shape/type information.");
284       }
285       shape_inference::ShapeAndType shape_and_type = (*handle_data)[0];
286       DataType value_dtype;
287       TF_RETURN_IF_ERROR(c->GetAttr("T", &value_dtype));
288       if (value_dtype != shape_and_type.dtype) {
289         return errors::InvalidArgument(
290             "Data types do not match: ", DataTypeString(value_dtype), " and ",
291             DataTypeString(shape_and_type.dtype));
292       }
293       ShapeHandle output;
294       TF_RETURN_IF_ERROR(c->WithRank(shape_and_type.shape, 0, &output));
295       c->set_output(0, output);
296       return Status::OK();
297     });
298 
299 }  // namespace tensorflow
300