1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/python/framework/cpp_shape_inference.h"
17
18 #include "tensorflow/core/framework/node_def.pb.h"
19 #include "tensorflow/core/framework/op.h"
20 #include "tensorflow/core/framework/shape_inference.h"
21 #include "tensorflow/core/framework/tensor_shape.pb.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/strings/strcat.h"
24 #include "tensorflow/python/framework/cpp_shape_inference.pb.h"
25 #include "tensorflow/python/lib/core/py_func.h"
26
27 namespace tensorflow {
28
29 namespace swig {
30 namespace {
31
ProtoFromShapeHandle(tensorflow::shape_inference::ShapeHandle s,tensorflow::shape_inference::InferenceContext * c,TensorShapeProto * out)32 void ProtoFromShapeHandle(tensorflow::shape_inference::ShapeHandle s,
33 tensorflow::shape_inference::InferenceContext* c,
34 TensorShapeProto* out) {
35 if (c->RankKnown(s)) {
36 const int32 rank = c->Rank(s);
37 for (int i = 0; i < rank; ++i) {
38 shape_inference::DimensionHandle d = c->Dim(s, i);
39 auto* out_dim = out->add_dim();
40 if (c->ValueKnown(d)) {
41 out_dim->set_size(c->Value(d));
42 } else {
43 out_dim->set_size(-1);
44 }
45 }
46 } else {
47 out->set_unknown_rank(true);
48 }
49 }
50
RunCppShapeInferenceImpl(int graph_def_version,const string & serialized_node_def,const std::vector<string> & input_serialized_shapes,const std::vector<PyObject * > & input_constant_tensor_values,const std::vector<string> & input_constant_tensor_as_shape_values,std::vector<string> * output_tensor_shape_protos,string * input_tensors_needed_out)51 Status RunCppShapeInferenceImpl(
52 int graph_def_version, const string& serialized_node_def,
53 const std::vector<string>& input_serialized_shapes,
54 const std::vector<PyObject*>& input_constant_tensor_values,
55 const std::vector<string>& input_constant_tensor_as_shape_values,
56 std::vector<string>* output_tensor_shape_protos,
57 string* input_tensors_needed_out) {
58 tensorflow::NodeDef node;
59 if (!node.ParseFromString(serialized_node_def)) {
60 return errors::InvalidArgument(
61 "Error parsing node_def during cpp shape inference");
62 }
63 DCHECK_EQ(output_tensor_shape_protos->size(), 0);
64
65 const OpRegistrationData* op_reg_data;
66 TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(node.op(), &op_reg_data));
67
68 if (op_reg_data->shape_inference_fn == nullptr) {
69 return errors::InvalidArgument(
70 "No shape inference function exists for op '", node.op(),
71 "', did you forget to define it?");
72 }
73
74 // Convert input shapes.
75 std::vector<TensorShapeProto> input_shapes;
76 std::vector<
77 std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>
78 input_handle_shapes_and_types;
79 input_shapes.resize(input_serialized_shapes.size());
80 input_handle_shapes_and_types.resize(input_serialized_shapes.size());
81 CppShapeInferenceResult tmp;
82 for (int i = 0; i < input_serialized_shapes.size(); ++i) {
83 tmp.Clear();
84 if (!tmp.ParseFromString(input_serialized_shapes[i])) {
85 return errors::InvalidArgument(
86 "Error parsing shape proto during cpp shape inference");
87 }
88
89 input_shapes[i].Swap(tmp.mutable_shape());
90
91 if (tmp.handle_data().is_set()) {
92 input_handle_shapes_and_types[i].reset(
93 new std::vector<std::pair<TensorShapeProto, DataType>>);
94 auto& v = *input_handle_shapes_and_types[i];
95 for (const auto& x : tmp.handle_data().shape_and_type()) {
96 v.emplace_back(x.shape(), x.dtype());
97 }
98 }
99 }
100
101 // Convert input tensor values;
102 std::vector<Tensor> input_tensor_values(input_constant_tensor_values.size());
103 std::vector<const Tensor*> input_tensors;
104 for (int i = 0; i < input_constant_tensor_values.size(); ++i) {
105 auto* py_val = input_constant_tensor_values[i];
106 if (py_val == Py_None) {
107 input_tensors.push_back(nullptr);
108 } else {
109 TF_RETURN_IF_ERROR(
110 ConvertNdarrayToTensor(py_val, &input_tensor_values[i]));
111 input_tensors.push_back(&input_tensor_values[i]);
112 }
113 }
114
115 // Convert input tensor-as-shape values;
116 std::vector<TensorShapeProto> input_tensor_as_shapes_protos(
117 input_constant_tensor_as_shape_values.size());
118 for (int i = 0; i < input_constant_tensor_as_shape_values.size(); ++i) {
119 if (!input_tensor_as_shapes_protos[i].ParseFromString(
120 input_constant_tensor_as_shape_values[i])) {
121 return errors::InvalidArgument(
122 "Error parsing shape proto during cpp shape inference");
123 }
124 }
125
126 // Run shape inference.
127 tensorflow::shape_inference::InferenceContext c(
128 graph_def_version, &node, op_reg_data->op_def, input_shapes,
129 input_tensors, input_tensor_as_shapes_protos,
130 input_handle_shapes_and_types);
131 TF_RETURN_IF_ERROR(c.construction_status());
132
133 TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
134
135 // Convert output shapes.
136 output_tensor_shape_protos->resize(c.num_outputs());
137 CppShapeInferenceResult out;
138 for (int i = 0; i < c.num_outputs(); ++i) {
139 out.Clear();
140 ProtoFromShapeHandle(c.output(i), &c, out.mutable_shape());
141
142 const auto* shapes_and_types = c.output_handle_shapes_and_types(i);
143 if (shapes_and_types != nullptr) {
144 auto* out_handle_data = out.mutable_handle_data();
145 out_handle_data->set_is_set(true);
146 for (const auto& p : *shapes_and_types) {
147 auto* out_shape_and_type = out_handle_data->add_shape_and_type();
148 ProtoFromShapeHandle(p.shape, &c, out_shape_and_type->mutable_shape());
149 out_shape_and_type->set_dtype(p.dtype);
150 }
151 }
152
153 CHECK(out.AppendToString(&(*output_tensor_shape_protos)[i]));
154 }
155
156 // Add info about requested inputs.
157 CppShapeInferenceInputsNeeded needed;
158 for (int i = 0; i < c.num_inputs(); ++i) {
159 if (c.requested_input_tensor(i)) {
160 needed.add_input_tensors_needed(i);
161 }
162 if (c.requested_input_tensor_as_partial_shape(i)) {
163 needed.add_input_tensors_as_shapes_needed(i);
164 }
165 }
166 *input_tensors_needed_out = needed.SerializeAsString();
167
168 return Status::OK();
169 }
170
171 } // namespace
172
RunCppShapeInference(int graph_def_version,const string & serialized_node_def,const std::vector<string> & input_serialized_shapes,PyObject * input_constant_tensor_values,const std::vector<string> & input_constant_tensor_as_shape_values,TF_Status * out_status)173 std::vector<string> RunCppShapeInference(
174 int graph_def_version, const string& serialized_node_def,
175 const std::vector<string>& input_serialized_shapes,
176 PyObject* input_constant_tensor_values,
177 const std::vector<string>& input_constant_tensor_as_shape_values,
178 TF_Status* out_status) {
179 if (!PyList_Check(input_constant_tensor_values)) {
180 TF_SetStatus(out_status, TF_INVALID_ARGUMENT, "Invalid python value");
181 return std::vector<string>();
182 }
183
184 std::vector<PyObject*> input_constant_tensor_values_v;
185 int cnt = PyList_Size(input_constant_tensor_values);
186 input_constant_tensor_values_v.reserve(cnt);
187 for (int i = 0; i < cnt; ++i) {
188 input_constant_tensor_values_v.push_back(
189 PyList_GetItem(input_constant_tensor_values, i));
190 }
191
192 std::vector<string> output;
193 string input_tensors_needed_out;
194 tensorflow::Status status = RunCppShapeInferenceImpl(
195 graph_def_version, serialized_node_def, input_serialized_shapes,
196 input_constant_tensor_values_v, input_constant_tensor_as_shape_values,
197 &output, &input_tensors_needed_out);
198
199 Set_TF_Status_from_Status(out_status, status);
200 if (!status.ok()) {
201 return std::vector<string>();
202 }
203 output.push_back(input_tensors_needed_out);
204 return output;
205 }
206
207 } // namespace swig
208 } // namespace tensorflow
209