1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/common_shape_fns.h" 17 #include "tensorflow/core/framework/op.h" 18 #include "tensorflow/core/framework/shape_inference.h" 19 20 namespace tensorflow { 21 22 using shape_inference::InferenceContext; 23 24 REGISTER_OP("SymbolicGradient") 25 .Input("input: Tin") 26 .Output("output: Tout") 27 .Attr("Tin: list(type)") 28 .Attr("Tout: list(type)") 29 .Attr("f: func") __anon22ceab0f0102(InferenceContext* c) 30 .SetShapeFn([](InferenceContext* c) { 31 if (c->num_inputs() < c->num_outputs()) { 32 return errors::InvalidArgument("len(inputs) < len(outputs)"); 33 } 34 std::vector<DataType> types; 35 TF_RETURN_IF_ERROR(c->GetAttr("Tin", &types)); 36 // Say, (u, v) = f(x, y, z), _symbolic_gradient(f) is a function of 37 // (x, y, z, du, dv) -> (dx, dy, dz). Therefore, shapes of its 38 // outputs (dx, dy, dz) are the same as (x, y, z). 39 for (int i = 0; i < c->num_outputs(); ++i) { 40 if (types[i] == DT_RESOURCE) { 41 const std::vector<shape_inference::ShapeAndType>* handle_type = 42 c->input_handle_shapes_and_types(i); 43 if (handle_type != nullptr) { 44 c->set_output(i, handle_type->at(0).shape); 45 } else { 46 c->set_output(i, c->UnknownShape()); 47 } 48 } else { 49 c->set_output(i, c->input(i)); 50 } 51 } 52 return Status::OK(); 53 }); 54 55 REGISTER_OP("RemoteCall") 56 .Input("target: string") 57 .Input("args: Tin") 58 .Output("output: Tout") 59 .Attr("Tin: list(type)") 60 .Attr("Tout: list(type)") 61 .Attr("f: func") 62 .SetIsStateful() 63 .SetShapeFn(shape_inference::UnknownShape); 64 65 // TODO(drpng): remove this. 66 REGISTER_OP("_If") 67 .Input("cond: Tcond") 68 .Input("input: Tin") 69 .Output("output: Tout") 70 .Attr("Tcond: type") 71 .Attr("Tin: list(type)") 72 .Attr("Tout: list(type)") 73 .Attr("then_branch: func") 74 .Attr("else_branch: func") 75 .SetIsStateful() 76 .SetShapeFn(shape_inference::UnknownShape) 77 .Doc(R"doc( 78 output = cond ? then_branch(input) : else_branch(input) 79 80 cond: A Tensor. If the tensor is a scalar of non-boolean type, the 81 scalar is converted to a boolean according to the 82 following rule: if the scalar is a numerical value, non-zero means 83 True and zero means False; if the scalar is a string, non-empty 84 means True and empty means False. If the tensor is not a scalar, 85 being empty means False and being non-empty means True. 86 input: A list of input tensors. 87 then_branch: A function that takes 'inputs' and returns a list of 88 tensors, whose types are the same as what else_branch returns. 89 else_branch: A function that takes 'inputs' and returns a list of 90 tensors. whose types are the same as what then_branch returns. 91 )doc"); 92 93 REGISTER_OP("StatelessIf") 94 .Input("cond: Tcond") 95 .Input("input: Tin") 96 .Output("output: Tout") 97 .Attr("Tcond: type") 98 .Attr("Tin: list(type) >= 0") 99 .Attr("Tout: list(type) >= 0") 100 .Attr("then_branch: func") 101 .Attr("else_branch: func") 102 .SetShapeFn(shape_inference::UnknownShape); 103 104 REGISTER_OP("If") 105 .Input("cond: Tcond") 106 .Input("input: Tin") 107 .Output("output: Tout") 108 .Attr("Tcond: type") 109 .Attr("Tin: list(type) >= 0") 110 .Attr("Tout: list(type) >= 0") 111 .Attr("then_branch: func") 112 .Attr("else_branch: func") 113 .Attr("output_shapes: list(shape) = []") 114 .SetIsStateful() __anon22ceab0f0202(shape_inference::InferenceContext* c) 115 .SetShapeFn([](shape_inference::InferenceContext* c) { 116 std::vector<PartialTensorShape> output_shapes; 117 TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); 118 // If `output_shapes` attr is set use that as the shapes of the outputs 119 // else return unknown shapes. 120 if (output_shapes.empty()) return shape_inference::UnknownShape(c); 121 if (output_shapes.size() != c->num_outputs()) { 122 return errors::InvalidArgument( 123 "`output_shapes` must be the same length as num outputs (", 124 output_shapes.size(), " vs. ", c->num_outputs()); 125 } 126 for (size_t i = 0; i < output_shapes.size(); ++i) { 127 shape_inference::ShapeHandle output_shape_handle; 128 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( 129 output_shapes[i], &output_shape_handle)); 130 c->set_output(static_cast<int>(i), output_shape_handle); 131 } 132 return Status::OK(); 133 }); 134 135 REGISTER_OP("Case") 136 .Input("branch_index: int32") 137 .Input("input: Tin") 138 .Output("output: Tout") 139 .Attr("Tin: list(type) >= 0") 140 .Attr("Tout: list(type) >= 0") 141 .Attr("branches: list(func) >= 1") 142 .Attr("output_shapes: list(shape) = []") 143 .SetIsStateful() __anon22ceab0f0302(shape_inference::InferenceContext* c) 144 .SetShapeFn([](shape_inference::InferenceContext* c) { 145 std::vector<PartialTensorShape> output_shapes; 146 TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); 147 // If `output_shapes` attr is set use that as the shapes of the outputs 148 // else return unknown shapes. 149 if (output_shapes.empty()) return shape_inference::UnknownShape(c); 150 if (output_shapes.size() != c->num_outputs()) { 151 return errors::InvalidArgument( 152 "`output_shapes` must be the same length as num outputs (", 153 output_shapes.size(), " vs. ", c->num_outputs()); 154 } 155 for (size_t i = 0; i < output_shapes.size(); ++i) { 156 shape_inference::ShapeHandle output_shape_handle; 157 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( 158 output_shapes[i], &output_shape_handle)); 159 c->set_output(static_cast<int>(i), output_shape_handle); 160 } 161 return Status::OK(); 162 }); 163 164 // TODO(drpng): remove this. 165 REGISTER_OP("_While") 166 .Input("input: T") 167 .Output("output: T") 168 .Attr("T: list(type) >= 0") 169 .Attr("cond: func") 170 .Attr("body: func") 171 .SetIsStateful() __anon22ceab0f0402(shape_inference::InferenceContext* c) 172 .SetShapeFn([](shape_inference::InferenceContext* c) { 173 for (int i = 0; i < c->num_outputs(); ++i) { 174 c->set_output(i, c->input(i)); 175 } 176 return Status::OK(); 177 }) 178 .Doc(R"doc( 179 output = input; While (Cond(output)) { output = Body(output) } 180 181 input: A list of input tensors whose types are T. 182 output: A list of output tensors whose types are T. 183 cond: A function takes 'input' and returns a tensor. If the tensor is 184 a scalar of non-boolean, the scalar is converted to a boolean 185 according to the following rule: if the scalar is a numerical 186 value, non-zero means True and zero means False; if the scalar is 187 a string, non-empty means True and empty means False. If the 188 tensor is not a scalar, non-emptiness means True and False 189 otherwise. 190 body: A function that takes a list of tensors and returns another 191 list of tensors. Both lists have the same types as specified 192 by T. 193 )doc"); 194 195 REGISTER_OP("While") 196 .Input("input: T") 197 .Output("output: T") 198 .Attr("T: list(type) >= 0") 199 .Attr("cond: func") 200 .Attr("body: func") 201 .Attr("output_shapes: list(shape) = []") 202 .Attr("parallel_iterations: int = 10") 203 .SetIsStateful() __anon22ceab0f0502(shape_inference::InferenceContext* c) 204 .SetShapeFn([](shape_inference::InferenceContext* c) { 205 std::vector<PartialTensorShape> output_shapes; 206 TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); 207 // If `output_shapes` attr is set use that as the shapes of the outputs 208 // else use the input shapes. 209 if (!output_shapes.empty()) { 210 if (output_shapes.size() != c->num_outputs()) { 211 return errors::InvalidArgument( 212 "`output_shapes` must be the same length as num outputs (", 213 output_shapes.size(), " vs. ", c->num_outputs()); 214 } 215 for (size_t i = 0; i < output_shapes.size(); ++i) { 216 shape_inference::ShapeHandle output_shape_handle; 217 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( 218 output_shapes[i], &output_shape_handle)); 219 c->set_output(static_cast<int>(i), output_shape_handle); 220 } 221 } else { 222 for (int i = 0; i < c->num_outputs(); ++i) { 223 c->set_output(i, c->input(i)); 224 } 225 } 226 return Status::OK(); 227 }); 228 229 REGISTER_OP("StatelessWhile") 230 .Input("input: T") 231 .Output("output: T") 232 .Attr("T: list(type) >= 0") 233 .Attr("cond: func") 234 .Attr("body: func") __anon22ceab0f0602(shape_inference::InferenceContext* c) 235 .SetShapeFn([](shape_inference::InferenceContext* c) { 236 for (int i = 0; i < c->num_outputs(); ++i) { 237 c->set_output(i, c->input(i)); 238 } 239 return Status::OK(); 240 }); 241 242 REGISTER_OP("For") 243 .Input("start: int32") 244 .Input("limit: int32") 245 .Input("delta: int32") 246 .Input("input: T") 247 .Output("output: T") 248 .Attr("T: list(type) >= 0") 249 .Attr("body: func") 250 .SetShapeFn(shape_inference::UnknownShape); 251 252 REGISTER_OP("PartitionedCall") 253 .Input("args: Tin") 254 .Output("output: Tout") 255 .Attr("Tin: list(type) >= 0") 256 .Attr("Tout: list(type) >= 0") 257 .Attr("f: func") 258 .Attr("config: string = ''") 259 .Attr("config_proto: string = ''") 260 .Attr("executor_type: string = ''") 261 .SetShapeFn(shape_inference::UnknownShape); 262 263 REGISTER_OP("StatefulPartitionedCall") 264 .Input("args: Tin") 265 .Output("output: Tout") 266 .Attr("Tin: list(type) >= 0") 267 .Attr("Tout: list(type) >= 0") 268 .Attr("f: func") 269 .Attr("config: string = ''") // Deprecated in favor of config_proto 270 .Attr("config_proto: string = ''") 271 .Attr("executor_type: string = ''") 272 .SetIsStateful() 273 .SetShapeFn(shape_inference::UnknownShape); 274 275 // This op is used as a placeholder in If branch functions. It doesn't provide a 276 // valid output when run, so must either be removed (e.g. replaced with a 277 // function input) or guaranteed not to be used (e.g. if mirroring an 278 // intermediate output needed for the gradient computation of the other branch). 279 REGISTER_OP("FakeParam") 280 .Output("output: dtype") 281 .Attr("dtype: type") 282 .Attr("shape: shape") __anon22ceab0f0702(InferenceContext* c) 283 .SetShapeFn([](InferenceContext* c) { 284 PartialTensorShape shape; 285 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); 286 shape_inference::ShapeHandle out; 287 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out)); 288 c->set_output(0, out); 289 return Status::OK(); 290 }); 291 292 } // end namespace tensorflow 293