• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <vector>
18 
19 #include "absl/strings/str_cat.h"
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/eager/abstract_context.h"
22 #include "tensorflow/c/eager/c_api_internal.h"
23 #include "tensorflow/c/eager/c_api_unified_experimental.h"
24 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
25 #include "tensorflow/c/tf_datatype.h"
26 #include "tensorflow/c/tf_status.h"
27 #include "tensorflow/c/tf_status_helper.h"
28 #include "tensorflow/core/framework/shape_inference.h"
29 #include "tensorflow/core/framework/tensor_shape.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
32 #include "tensorflow/core/platform/errors.h"
33 #include "tensorflow/core/platform/strcat.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 using tensorflow::dyn_cast;
37 using tensorflow::string;
38 using tensorflow::gtl::ArraySlice;
39 
40 namespace tensorflow {
41 namespace tracing {
42 namespace graph {
43 
44 class GraphContext;
45 class GraphOperation;
46 class GraphTensor;
47 
48 auto& kUnknownDim = shape_inference::InferenceContext::kUnknownDim;
49 auto& kUnknownRank = shape_inference::InferenceContext::kUnknownRank;
50 
51 // GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
52 // into the list of outputs for the operation.
53 class GraphTensor : public TracingTensorHandle {
54  public:
GraphTensor(TF_Output output,TF_Graph * graph)55   explicit GraphTensor(TF_Output output, TF_Graph* graph)
56       : TracingTensorHandle(kGraph), output_(output), graph_(graph) {}
57 
DataType() const58   tensorflow::DataType DataType() const override {
59     return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
60   }
61 
Shape(tensorflow::PartialTensorShape * shape) const62   tensorflow::Status Shape(
63       tensorflow::PartialTensorShape* shape) const override {
64     DCHECK(shape != nullptr);
65     TF_Status status;
66     int num_dims = TF_GraphGetTensorNumDims(graph_, output_, &status);
67     DCHECK_GE(num_dims, -1);
68     TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
69     if (num_dims == kUnknownRank) {
70       return Status::OK();
71     }
72 
73     std::vector<int64> dims(num_dims, kUnknownDim);
74     TF_GraphGetTensorShape(graph_, output_,
75                            reinterpret_cast<int64_t*>(dims.data()), num_dims,
76                            &status);
77     TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
78     TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape));
79 
80     return Status::OK();
81   }
82 
83   TF_Output output_;
84 
85   // For LLVM style RTTI.
classof(const AbstractTensorHandle * ptr)86   static bool classof(const AbstractTensorHandle* ptr) {
87     return ptr->getKind() == kGraph;
88   }
89 
90  private:
91   TF_Graph* graph_;  // For shape inference.
92 };
93 
94 // GraphOperation wraps and populates a TF_OperationDescription.
95 class GraphOperation : public TracingOperation {
96  public:
GraphOperation(TF_Graph * g)97   explicit GraphOperation(TF_Graph* g) : TracingOperation(kGraph), g_(g) {}
Release()98   void Release() override { delete this; }
Reset(const char * op,const char * raw_device_name)99   Status Reset(const char* op, const char* raw_device_name) override {
100     if (op_) {
101       return errors::FailedPrecondition("Reset called on already built op.");
102     }
103     if (raw_device_name) {
104       device_name_ = raw_device_name;
105     }
106     op_type_ = op;
107     return Status::OK();
108   }
SetOpName(const char * const op_name)109   Status SetOpName(const char* const op_name) override {
110     if (op_) {
111       return errors::FailedPrecondition(
112           "SetOpName called on already built op.");
113     }
114     if (op_type_.empty()) {
115       return errors::FailedPrecondition(
116           "GraphOperation::Reset must be called before calling SetOpName.");
117     }
118     // TODO(b/145674566): We use Graph::NewName to get a unique name here but
119     // this may not be consistent with python's naming policy.
120     mutex_lock l(g_->mu);
121     op_.reset(new TF_OperationDescription(g_, op_type_.c_str(),
122                                           g_->graph.NewName(op_name).c_str()));
123     return Status::OK();
124   }
Name() const125   const string& Name() const override { return op_type_; }
DeviceName() const126   const string& DeviceName() const override { return device_name_; }
127 
SetDeviceName(const char * name)128   Status SetDeviceName(const char* name) override {
129     // TODO(srbs): Implement this.
130     device_name_ = name;
131     return Status::OK();
132   }
133 
AddInput(AbstractTensorHandle * input)134   Status AddInput(AbstractTensorHandle* input) override {
135     GraphTensor* t = dyn_cast<GraphTensor>(input);
136     if (!t) {
137       return tensorflow::errors::InvalidArgument(
138           "Unable to cast input to GraphTensor");
139     }
140     TF_AddInput(op_.get(), t->output_);
141     return Status::OK();
142   }
AddInputList(absl::Span<AbstractTensorHandle * const> inputs)143   Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override {
144     std::vector<TF_Output> tf_outputs(inputs.size());
145     for (int i = 0; i < inputs.size(); i++) {
146       GraphTensor* t = dyn_cast<GraphTensor>(inputs[i]);
147       if (!t) {
148         return tensorflow::errors::InvalidArgument(
149             "Unable to cast input to GraphTensor");
150       }
151       tf_outputs[i] = t->output_;
152     }
153     TF_AddInputList(op_.get(), tf_outputs.data(), tf_outputs.size());
154     return Status::OK();
155   }
Execute(absl::Span<AbstractTensorHandle * > retvals,int * num_retvals)156   Status Execute(absl::Span<AbstractTensorHandle*> retvals,
157                  int* num_retvals) override {
158     auto* tf_opdesc = op_.release();
159     if (tf_opdesc == nullptr) {
160       return errors::InvalidArgument("AbstractOp is incomplete.");
161     }
162     TF_Status* s = TF_NewStatus();
163     auto* operation = TF_FinishOperation(tf_opdesc, s);
164     TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
165     TF_DeleteStatus(s);
166     *num_retvals = TF_OperationNumOutputs(operation);
167     for (int i = 0; i < *num_retvals; ++i) {
168       retvals[i] = new GraphTensor({operation, i}, g_);
169     }
170     return Status::OK();
171   }
172 
SetAttrString(const char * attr_name,const char * data,size_t length)173   Status SetAttrString(const char* attr_name, const char* data,
174                        size_t length) override {
175     tensorflow::StringPiece s(data, length);
176     op_->node_builder.Attr(attr_name, s);
177     return Status::OK();
178   }
SetAttrInt(const char * attr_name,int64_t value)179   Status SetAttrInt(const char* attr_name, int64_t value) override {
180     static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
181                   "64-bit int types should match in size");
182     op_->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
183     return Status::OK();
184   }
SetAttrFloat(const char * attr_name,float value)185   Status SetAttrFloat(const char* attr_name, float value) override {
186     op_->node_builder.Attr(attr_name, value);
187     return Status::OK();
188   }
SetAttrBool(const char * attr_name,bool value)189   Status SetAttrBool(const char* attr_name, bool value) override {
190     op_->node_builder.Attr(attr_name, value);
191     return Status::OK();
192   }
SetAttrType(const char * const attr_name,DataType value)193   Status SetAttrType(const char* const attr_name, DataType value) override {
194     if (!op_) {
195       return Status(
196           error::Code::FAILED_PRECONDITION,
197           "op_type and op_name must be specified before specifying attrs.");
198     }
199     op_->node_builder.Attr(attr_name, value);
200     return Status::OK();
201   }
SetAttrShape(const char * attr_name,const int64_t * dims,const int num_dims)202   Status SetAttrShape(const char* attr_name, const int64_t* dims,
203                       const int num_dims) override {
204     PartialTensorShape shape;
205     if (num_dims >= 0) {
206       static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
207                     "64-bit int types should match in size");
208       shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
209           reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
210     }
211     op_->node_builder.Attr(attr_name, shape);
212     return Status::OK();
213   }
SetAttrFunction(const char * attr_name,const AbstractOperation * value)214   Status SetAttrFunction(const char* attr_name,
215                          const AbstractOperation* value) override {
216     return tensorflow::errors::Unimplemented(
217         "SetAttrFunction has not been implemented yet.");
218   }
SetAttrFunctionName(const char * attr_name,const char * value,size_t length)219   Status SetAttrFunctionName(const char* attr_name, const char* value,
220                              size_t length) override {
221     tensorflow::NameAttrList func_name;
222     func_name.set_name(string(value, value + length));
223     op_->node_builder.Attr(attr_name, func_name);
224     return Status::OK();
225   }
SetAttrTensor(const char * attr_name,AbstractTensorInterface * tensor)226   Status SetAttrTensor(const char* attr_name,
227                        AbstractTensorInterface* tensor) override {
228     return tensorflow::errors::Unimplemented(
229         "SetAttrTensor has not been implemented yet.");
230   }
SetAttrStringList(const char * attr_name,const void * const * values,const size_t * lengths,int num_values)231   Status SetAttrStringList(const char* attr_name, const void* const* values,
232                            const size_t* lengths, int num_values) override {
233     if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
234       op_->colocation_constraints.clear();
235       for (int i = 0; i < num_values; ++i) {
236         op_->colocation_constraints.emplace(static_cast<const char*>(values[i]),
237                                             lengths[i]);
238       }
239     } else {
240       std::vector<tensorflow::StringPiece> v;
241       v.reserve(num_values);
242       for (int i = 0; i < num_values; ++i) {
243         v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
244       }
245       op_->node_builder.Attr(attr_name, v);
246     }
247     return Status::OK();
248   }
SetAttrFloatList(const char * attr_name,const float * values,int num_values)249   Status SetAttrFloatList(const char* attr_name, const float* values,
250                           int num_values) override {
251     op_->node_builder.Attr(attr_name,
252                            ArraySlice<const float>(values, num_values));
253     return Status::OK();
254   }
SetAttrIntList(const char * attr_name,const int64_t * values,int num_values)255   Status SetAttrIntList(const char* attr_name, const int64_t* values,
256                         int num_values) override {
257     static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
258                   "64-bit int types should match in size");
259     op_->node_builder.Attr(
260         attr_name,
261         ArraySlice<const tensorflow::int64>(
262             reinterpret_cast<const tensorflow::int64*>(values), num_values));
263     return Status::OK();
264   }
SetAttrTypeList(const char * attr_name,const DataType * values,int num_values)265   Status SetAttrTypeList(const char* attr_name, const DataType* values,
266                          int num_values) override {
267     op_->node_builder.Attr(attr_name,
268                            ArraySlice<const DataType>(values, num_values));
269     return Status::OK();
270   }
SetAttrBoolList(const char * attr_name,const unsigned char * values,int num_values)271   Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
272                          int num_values) override {
273     std::unique_ptr<bool[]> b(new bool[num_values]);
274     for (int i = 0; i < num_values; ++i) {
275       b[i] = values[i];
276     }
277     op_->node_builder.Attr(attr_name,
278                            ArraySlice<const bool>(b.get(), num_values));
279 
280     return Status::OK();
281   }
SetAttrShapeList(const char * attr_name,const int64_t ** dims,const int * num_dims,int num_values)282   Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
283                           const int* num_dims, int num_values) override {
284     std::vector<PartialTensorShape> shapes;
285     shapes.reserve(num_values);
286     for (int i = 0; i < num_values; ++i) {
287       if (num_dims[i] < 0) {
288         shapes.emplace_back();
289       } else {
290         static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
291                       "64-bit int types should match in size");
292         shapes.emplace_back(ArraySlice<tensorflow::int64>(
293             reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
294       }
295     }
296     op_->node_builder.Attr(attr_name, shapes);
297     return Status::OK();
298   }
SetAttrFunctionList(const char * attr_name,absl::Span<const AbstractOperation * > values)299   Status SetAttrFunctionList(
300       const char* attr_name,
301       absl::Span<const AbstractOperation*> values) override {
302     return tensorflow::errors::Unimplemented(
303         "SetAttrFunctionList has not been implemented yet.");
304   }
305   // For LLVM style RTTI.
classof(const AbstractOperation * ptr)306   static bool classof(const AbstractOperation* ptr) {
307     return ptr->getKind() == kGraph;
308   }
~GraphOperation()309   ~GraphOperation() override {}
310 
311  private:
312   friend class GraphContext;  // For access to op_.
313   TF_Graph* g_;
314   std::unique_ptr<TF_OperationDescription> op_;
315   // Hold `op_type` and `op_name` till both are available since we need both
316   // to build a graph operation.
317   string op_type_;
318   const char* op_name_ = nullptr;
319   // TODO(srbs): Use this.
320   string device_name_;
321 };
322 
323 // GraphFunction is a thin wrapper over a TF_Function.
324 struct GraphFunction : public AbstractFunction {
325   TF_Function* func = nullptr;
GraphFunctiontensorflow::tracing::graph::GraphFunction326   GraphFunction() : AbstractFunction(kGraph) {}
GraphFunctiontensorflow::tracing::graph::GraphFunction327   explicit GraphFunction(TF_Function* func)
328       : AbstractFunction(kGraph), func(func) {}
~GraphFunctiontensorflow::tracing::graph::GraphFunction329   ~GraphFunction() override {
330     if (func) TF_DeleteFunction(func);
331   }
332 
GetFunctionDeftensorflow::tracing::graph::GraphFunction333   Status GetFunctionDef(FunctionDef** fdef) override {
334     *fdef = &func->fdef;
335     return Status::OK();
336   }
337 
338   // For LLVM style RTTI.
classoftensorflow::tracing::graph::GraphFunction339   static bool classof(const AbstractFunction* ptr) {
340     return ptr->getKind() == kGraph;
341   }
342 };
343 
344 // GraphContext wraps a TF_Graph modeling a single function and manages the
345 // "execution" of operation, i.e. adding them to the function.
346 class GraphContext : public TracingContext {
347  public:
GraphContext(const char * name)348   explicit GraphContext(const char* name)
349       : TracingContext(kGraph),
350         graph_(new TF_Graph(), TF_DeleteGraph),
351         name_(name) {}
352 
Release()353   void Release() override { delete this; }
354 
CreateOperation()355   TracingOperation* CreateOperation() override {
356     return new GraphOperation(graph_.get());
357   }
358 
AddParameter(DataType dtype,const PartialTensorShape & shape,TracingTensorHandle ** output)359   Status AddParameter(DataType dtype, const PartialTensorShape& shape,
360                       TracingTensorHandle** output) override {
361     TracingOperationPtr operation(CreateOperation());
362     TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr));
363     TF_RETURN_IF_ERROR(
364         operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str()));
365     TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype));
366     if (!shape.unknown_rank()) {
367       TF_RETURN_IF_ERROR(operation->SetAttrShape(
368           "shape", reinterpret_cast<int64_t*>(shape.dim_sizes().data()),
369           shape.dims()));
370     }
371     int num_outputs = 1;
372     std::vector<AbstractTensorHandle*> outputs(num_outputs);
373     TF_RETURN_IF_ERROR(operation->Execute(
374         absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
375 
376     if (num_outputs != 1) {
377       return errors::Internal("Expected 1 output but found ", num_outputs);
378     }
379     auto* t = dyn_cast<GraphTensor>(outputs[0]);
380     if (!t) {
381       return tensorflow::errors::InvalidArgument(
382           "Unable to cast input to GraphTensor");
383     }
384     inputs_.push_back(t->output_);
385     *output = tensorflow::down_cast<TracingTensorHandle*>(outputs[0]);
386     return Status::OK();
387   }
388 
Finalize(OutputList * outputs,AbstractFunction ** f)389   Status Finalize(OutputList* outputs, AbstractFunction** f) override {
390     std::unique_ptr<GraphFunction> func(new GraphFunction);
391     std::vector<TF_Output> graph_outputs;
392     graph_outputs.reserve(outputs->outputs.size());
393     for (auto* abstract_output : outputs->outputs) {
394       GraphTensor* output = dyn_cast<GraphTensor>(abstract_output);
395       if (!output) {
396         return errors::Unimplemented(
397             "Returning a non-graph tensor from a function has not "
398             "been implemented yet.");
399       }
400       graph_outputs.push_back(output->output_);
401     }
402 
403     auto s = TF_NewStatus();
404     func->func = TF_GraphToFunction(graph_.get(), name_.data(), 0, -1, nullptr,
405                                     inputs_.size(), inputs_.data(),
406                                     graph_outputs.size(), graph_outputs.data(),
407                                     nullptr, nullptr, name_.data(), s);
408     TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
409     TF_DeleteStatus(s);
410     *f = func.release();
411     return Status::OK();
412   }
413 
RegisterFunction(AbstractFunction * func)414   Status RegisterFunction(AbstractFunction* func) override {
415     return errors::Unimplemented(
416         "Registering graph functions has not been implemented yet.");
417   }
418 
RemoveFunction(const string & func)419   Status RemoveFunction(const string& func) override {
420     return errors::Unimplemented(
421         "GraphContext::RemoveFunction has not been implemented yet.");
422   }
423   // For LLVM style RTTI.
classof(const AbstractContext * ptr)424   static bool classof(const AbstractContext* ptr) {
425     return ptr->getKind() == kGraph;
426   }
427 
428  private:
429   std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
430   std::vector<TF_Output> inputs_;
431   string name_;
432 };
433 
GraphTracingFactory(const char * name,TF_Status * s)434 static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
435   return new GraphContext(name);
436 }
437 
438 // Register the tracing implemented in this file as the default tracing engine.
__anonb98a64720102null439 static bool register_tracing = [] {
440   RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
441   SetDefaultTracingEngine("graphdef").IgnoreError();
442   return true;
443 }();
444 
445 }  // namespace graph
446 }  // namespace tracing
447 }  // namespace tensorflow
448