/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/types.h"

using tensorflow::dyn_cast;
using tensorflow::string;
using tensorflow::gtl::ArraySlice;

namespace tensorflow {
namespace tracing {
namespace graph {

class GraphContext;
class GraphOperation;
class GraphTensor;

auto& kUnknownDim = shape_inference::InferenceContext::kUnknownDim;
auto& kUnknownRank = shape_inference::InferenceContext::kUnknownRank;

// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the
// index into the list of outputs for the operation.
class GraphTensor : public TracingTensorHandle {
 public:
  explicit GraphTensor(TF_Output output, TF_Graph* graph)
      : TracingTensorHandle(kGraph), output_(output), graph_(graph) {}

  tensorflow::DataType DataType() const override {
    return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
  }

  tensorflow::Status Shape(
      tensorflow::PartialTensorShape* shape) const override {
    DCHECK(shape != nullptr);
    TF_Status status;
    int num_dims = TF_GraphGetTensorNumDims(graph_, output_, &status);
    DCHECK_GE(num_dims, -1);
    TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
    if (num_dims == kUnknownRank) {
      return Status::OK();
    }

    std::vector<int64> dims(num_dims, kUnknownDim);
    TF_GraphGetTensorShape(graph_, output_,
                           reinterpret_cast<int64_t*>(dims.data()), num_dims,
                           &status);
    TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
    TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape));

    return Status::OK();
  }

  TF_Output output_;

  // For LLVM style RTTI.
  static bool classof(const AbstractTensorHandle* ptr) {
    return ptr->getKind() == kGraph;
  }

 private:
  TF_Graph* graph_;  // For shape inference.
};
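
// A minimal sketch of how a GraphTensor is typically consumed (assuming `t`
// is a GraphTensor* produced by GraphOperation::Execute below). The handle
// carries no runtime data, so dtype and shape queries are answered by graph
// shape inference over the wrapped TF_Output:
//
//   tensorflow::PartialTensorShape shape;
//   TF_RETURN_IF_ERROR(t->Shape(&shape));        // TF_GraphGetTensorShape
//   tensorflow::DataType dtype = t->DataType();  // TF_OperationOutputType
//   TF_Output raw = t->output_;                  // {TF_Operation*, index}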

// GraphOperation wraps and populates a TF_OperationDescription.
class GraphOperation : public TracingOperation {
 public:
  explicit GraphOperation(TF_Graph* g) : TracingOperation(kGraph), g_(g) {}
  void Release() override { delete this; }
  Status Reset(const char* op, const char* raw_device_name) override {
    if (op_) {
      return errors::FailedPrecondition("Reset called on already built op.");
    }
    if (raw_device_name) {
      device_name_ = raw_device_name;
    }
    op_type_ = op;
    return Status::OK();
  }
  Status SetOpName(const char* const op_name) override {
    if (op_) {
      return errors::FailedPrecondition(
          "SetOpName called on already built op.");
    }
    if (op_type_.empty()) {
      return errors::FailedPrecondition(
          "GraphOperation::Reset must be called before calling SetOpName.");
    }
    // TODO(b/145674566): We use Graph::NewName to get a unique name here but
    // this may not be consistent with python's naming policy.
    mutex_lock l(g_->mu);
    op_.reset(new TF_OperationDescription(g_, op_type_.c_str(),
                                          g_->graph.NewName(op_name).c_str()));
    return Status::OK();
  }
  const string& Name() const override { return op_type_; }
  const string& DeviceName() const override { return device_name_; }

  Status SetDeviceName(const char* name) override {
    // TODO(srbs): Implement this.
    device_name_ = name;
    return Status::OK();
  }

  Status AddInput(AbstractTensorHandle* input) override {
    GraphTensor* t = dyn_cast<GraphTensor>(input);
    if (!t) {
      return tensorflow::errors::InvalidArgument(
          "Unable to cast input to GraphTensor");
    }
    TF_AddInput(op_.get(), t->output_);
    return Status::OK();
  }
  Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override {
    std::vector<TF_Output> tf_outputs(inputs.size());
    for (int i = 0; i < inputs.size(); i++) {
      GraphTensor* t = dyn_cast<GraphTensor>(inputs[i]);
      if (!t) {
        return tensorflow::errors::InvalidArgument(
            "Unable to cast input to GraphTensor");
      }
      tf_outputs[i] = t->output_;
    }
    TF_AddInputList(op_.get(), tf_outputs.data(), tf_outputs.size());
    return Status::OK();
  }
  Status Execute(absl::Span<AbstractTensorHandle*> retvals,
                 int* num_retvals) override {
    auto* tf_opdesc = op_.release();
    if (tf_opdesc == nullptr) {
      return errors::InvalidArgument("AbstractOp is incomplete.");
    }
    TF_Status* s = TF_NewStatus();
    auto* operation = TF_FinishOperation(tf_opdesc, s);
    Status status = StatusFromTF_Status(s);
    TF_DeleteStatus(s);
    TF_RETURN_IF_ERROR(status);
    *num_retvals = TF_OperationNumOutputs(operation);
    for (int i = 0; i < *num_retvals; ++i) {
      retvals[i] = new GraphTensor({operation, i}, g_);
    }
    return Status::OK();
  }

  Status SetAttrString(const char* attr_name, const char* data,
                       size_t length) override {
    tensorflow::StringPiece s(data, length);
    op_->node_builder.Attr(attr_name, s);
    return Status::OK();
  }
  Status SetAttrInt(const char* attr_name, int64_t value) override {
    static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                  "64-bit int types should match in size");
    op_->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
    return Status::OK();
  }
  Status SetAttrFloat(const char* attr_name, float value) override {
    op_->node_builder.Attr(attr_name, value);
    return Status::OK();
  }
  Status SetAttrBool(const char* attr_name, bool value) override {
    op_->node_builder.Attr(attr_name, value);
    return Status::OK();
  }
  Status SetAttrType(const char* const attr_name, DataType value) override {
    if (!op_) {
      return Status(
          error::Code::FAILED_PRECONDITION,
          "op_type and op_name must be specified before specifying attrs.");
    }
    op_->node_builder.Attr(attr_name, value);
    return Status::OK();
  }
  Status SetAttrShape(const char* attr_name, const int64_t* dims,
                      const int num_dims) override {
    PartialTensorShape shape;
    if (num_dims >= 0) {
      static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                    "64-bit int types should match in size");
      shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
          reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
    }
    op_->node_builder.Attr(attr_name, shape);
    return Status::OK();
  }
  Status SetAttrFunction(const char* attr_name,
                         const AbstractOperation* value) override {
    return tensorflow::errors::Unimplemented(
        "SetAttrFunction has not been implemented yet.");
  }
  Status SetAttrFunctionName(const char* attr_name, const char* value,
                             size_t length) override {
    tensorflow::NameAttrList func_name;
    func_name.set_name(string(value, value + length));
    op_->node_builder.Attr(attr_name, func_name);
    return Status::OK();
  }
  Status SetAttrTensor(const char* attr_name,
                       AbstractTensorInterface* tensor) override {
    return tensorflow::errors::Unimplemented(
        "SetAttrTensor has not been implemented yet.");
  }
  Status SetAttrStringList(const char* attr_name, const void* const* values,
                           const size_t* lengths, int num_values) override {
    if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
      op_->colocation_constraints.clear();
      for (int i = 0; i < num_values; ++i) {
        op_->colocation_constraints.emplace(static_cast<const char*>(values[i]),
                                            lengths[i]);
      }
    } else {
      std::vector<tensorflow::StringPiece> v;
      v.reserve(num_values);
      for (int i = 0; i < num_values; ++i) {
        v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
      }
      op_->node_builder.Attr(attr_name, v);
    }
    return Status::OK();
  }
  Status SetAttrFloatList(const char* attr_name, const float* values,
                          int num_values) override {
    op_->node_builder.Attr(attr_name,
                           ArraySlice<const float>(values, num_values));
    return Status::OK();
  }
  Status SetAttrIntList(const char* attr_name, const int64_t* values,
                        int num_values) override {
    static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                  "64-bit int types should match in size");
    op_->node_builder.Attr(
        attr_name,
        ArraySlice<const tensorflow::int64>(
            reinterpret_cast<const tensorflow::int64*>(values), num_values));
    return Status::OK();
  }
  Status SetAttrTypeList(const char* attr_name, const DataType* values,
                         int num_values) override {
    op_->node_builder.Attr(attr_name,
                           ArraySlice<const DataType>(values, num_values));
    return Status::OK();
  }
  Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
                         int num_values) override {
    std::unique_ptr<bool[]> b(new bool[num_values]);
    for (int i = 0; i < num_values; ++i) {
      b[i] = values[i];
    }
    op_->node_builder.Attr(attr_name,
                           ArraySlice<const bool>(b.get(), num_values));

    return Status::OK();
  }
  Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
                          const int* num_dims, int num_values) override {
    std::vector<PartialTensorShape> shapes;
    shapes.reserve(num_values);
    for (int i = 0; i < num_values; ++i) {
      if (num_dims[i] < 0) {
        shapes.emplace_back();
      } else {
        static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                      "64-bit int types should match in size");
        shapes.emplace_back(ArraySlice<tensorflow::int64>(
            reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
      }
    }
    op_->node_builder.Attr(attr_name, shapes);
    return Status::OK();
  }
  Status SetAttrFunctionList(
      const char* attr_name,
      absl::Span<const AbstractOperation*> values) override {
    return tensorflow::errors::Unimplemented(
        "SetAttrFunctionList has not been implemented yet.");
  }
  // For LLVM style RTTI.
  static bool classof(const AbstractOperation* ptr) {
    return ptr->getKind() == kGraph;
  }
  ~GraphOperation() override {}

 private:
  friend class GraphContext;  // For access to op_.
  TF_Graph* g_;
  std::unique_ptr<TF_OperationDescription> op_;
  // Hold `op_type` and `op_name` till both are available since we need both
  // to build a graph operation.
  string op_type_;
  const char* op_name_ = nullptr;
  // TODO(srbs): Use this.
  string device_name_;
};
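
// A minimal usage sketch for GraphOperation (assuming a TF_Graph* `graph` and
// two GraphTensor inputs `a` and `b` that already belong to that graph). The
// required call order mirrors the checks above: Reset, then SetOpName, then
// inputs/attrs, then Execute.
//
//   GraphOperation op(graph);
//   TF_RETURN_IF_ERROR(op.Reset("AddV2", /*raw_device_name=*/nullptr));
//   TF_RETURN_IF_ERROR(op.SetOpName("add"));
//   TF_RETURN_IF_ERROR(op.AddInput(a));
//   TF_RETURN_IF_ERROR(op.AddInput(b));
//   int num_retvals = 1;
//   std::vector<AbstractTensorHandle*> retvals(num_retvals);
//   TF_RETURN_IF_ERROR(op.Execute(absl::MakeSpan(retvals), &num_retvals));
//   // retvals[0] is a new GraphTensor wrapping output 0 of the added node.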

// GraphFunction is a thin wrapper over a TF_Function.
struct GraphFunction : public AbstractFunction {
  TF_Function* func = nullptr;
  GraphFunction() : AbstractFunction(kGraph) {}
  explicit GraphFunction(TF_Function* func)
      : AbstractFunction(kGraph), func(func) {}
  ~GraphFunction() override {
    if (func) TF_DeleteFunction(func);
  }

  Status GetFunctionDef(FunctionDef** fdef) override {
    *fdef = &func->fdef;
    return Status::OK();
  }

  // For LLVM style RTTI.
  static bool classof(const AbstractFunction* ptr) {
    return ptr->getKind() == kGraph;
  }
};
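
// A short sketch of consuming a GraphFunction (assuming `f` is the
// AbstractFunction* produced by GraphContext::Finalize below). The FunctionDef
// stays owned by the wrapped TF_Function, so the caller must not free it:
//
//   tensorflow::FunctionDef* fdef = nullptr;
//   TF_RETURN_IF_ERROR(f->GetFunctionDef(&fdef));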

// GraphContext wraps a TF_Graph modeling a single function and manages the
// "execution" of operations, i.e. adding them to the function.
class GraphContext : public TracingContext {
 public:
  explicit GraphContext(const char* name)
      : TracingContext(kGraph),
        graph_(new TF_Graph(), TF_DeleteGraph),
        name_(name) {}

  void Release() override { delete this; }

  TracingOperation* CreateOperation() override {
    return new GraphOperation(graph_.get());
  }

  Status AddParameter(DataType dtype, const PartialTensorShape& shape,
                      TracingTensorHandle** output) override {
    TracingOperationPtr operation(CreateOperation());
    TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr));
    TF_RETURN_IF_ERROR(
        operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str()));
    TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype));
    if (!shape.unknown_rank()) {
      TF_RETURN_IF_ERROR(operation->SetAttrShape(
          "shape", reinterpret_cast<int64_t*>(shape.dim_sizes().data()),
          shape.dims()));
    }
    int num_outputs = 1;
    std::vector<AbstractTensorHandle*> outputs(num_outputs);
    TF_RETURN_IF_ERROR(operation->Execute(
        absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));

    if (num_outputs != 1) {
      return errors::Internal("Expected 1 output but found ", num_outputs);
    }
    auto* t = dyn_cast<GraphTensor>(outputs[0]);
    if (!t) {
      return tensorflow::errors::InvalidArgument(
          "Unable to cast input to GraphTensor");
    }
    inputs_.push_back(t->output_);
    *output = tensorflow::down_cast<TracingTensorHandle*>(outputs[0]);
    return Status::OK();
  }

  Status Finalize(OutputList* outputs, AbstractFunction** f) override {
    std::unique_ptr<GraphFunction> func(new GraphFunction);
    std::vector<TF_Output> graph_outputs;
    graph_outputs.reserve(outputs->outputs.size());
    for (auto* abstract_output : outputs->outputs) {
      GraphTensor* output = dyn_cast<GraphTensor>(abstract_output);
      if (!output) {
        return errors::Unimplemented(
            "Returning a non-graph tensor from a function has not "
            "been implemented yet.");
      }
      graph_outputs.push_back(output->output_);
    }

    auto s = TF_NewStatus();
    func->func = TF_GraphToFunction(graph_.get(), name_.data(), 0, -1, nullptr,
                                    inputs_.size(), inputs_.data(),
                                    graph_outputs.size(), graph_outputs.data(),
                                    nullptr, nullptr, name_.data(), s);
    Status status = StatusFromTF_Status(s);
    TF_DeleteStatus(s);
    TF_RETURN_IF_ERROR(status);
    *f = func.release();
    return Status::OK();
  }

  Status RegisterFunction(AbstractFunction* func) override {
    return errors::Unimplemented(
        "Registering graph functions has not been implemented yet.");
  }

  Status RemoveFunction(const string& func) override {
    return errors::Unimplemented(
        "GraphContext::RemoveFunction has not been implemented yet.");
  }
  // For LLVM style RTTI.
  static bool classof(const AbstractContext* ptr) {
    return ptr->getKind() == kGraph;
  }

 private:
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
  std::vector<TF_Output> inputs_;
  string name_;
};
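
// An end-to-end tracing sketch using the classes above (assuming we want to
// trace `f(x) = Identity(x)` with a float input of unknown rank):
//
//   GraphContext ctx("my_fn");
//   TracingTensorHandle* x = nullptr;
//   TF_RETURN_IF_ERROR(ctx.AddParameter(DT_FLOAT, PartialTensorShape(), &x));
//   TracingOperation* op = ctx.CreateOperation();
//   TF_RETURN_IF_ERROR(op->Reset("Identity", /*raw_device_name=*/nullptr));
//   TF_RETURN_IF_ERROR(op->SetOpName("identity"));
//   TF_RETURN_IF_ERROR(op->AddInput(x));
//   int num_retvals = 1;
//   std::vector<AbstractTensorHandle*> retvals(num_retvals);
//   TF_RETURN_IF_ERROR(op->Execute(absl::MakeSpan(retvals), &num_retvals));
//   op->Release();
//   OutputList outputs;
//   outputs.outputs.push_back(retvals[0]);
//   AbstractFunction* fn = nullptr;
//   TF_RETURN_IF_ERROR(ctx.Finalize(&outputs, &fn));  // wraps a TF_Function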

static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
  return new GraphContext(name);
}

// Register the tracing implemented in this file as the default tracing engine.
static bool register_tracing = [] {
  RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
  SetDefaultTracingEngine("graphdef").IgnoreError();
  return true;
}();
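
// Callers that want a different tracing engine can register their own factory
// and switch to it by name (sketch, with a hypothetical MyTracingFactory):
//
//   RegisterTracingEngineFactory("my_engine", MyTracingFactory);
//   TF_RETURN_IF_ERROR(SetDefaultTracingEngine("my_engine"));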

}  // namespace graph
}  // namespace tracing
}  // namespace tensorflow