• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
16 
17 #include "absl/types/span.h"
18 #include "tensorflow/c/eager/abstract_operation.h"
19 #include "tensorflow/c/eager/abstract_tensor_handle.h"
20 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
21 #include "tensorflow/c/tf_tensor_internal.h"
22 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
23 #include "tensorflow/core/common_runtime/eager/custom_device.h"
24 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
25 #include "tensorflow/core/platform/casts.h"
26 #include "tensorflow/core/platform/errors.h"
27 #include "tensorflow/core/platform/host_info.h"
28 
29 namespace tensorflow {
30 
31 // An EagerOperation object can be reused for a different op by calling
32 // Clear(), and then Reset(...) with the same arguments that would have
33 // been provided to the constructor.
Clear()34 void EagerOperation::Clear() {
35   for (ImmediateExecutionTensorHandle* h : inputs_) {
36     h->Unref();
37   }
38   inputs_.clear();
39   custom_device_tensor_handles_count_ = 0;
40   ClearInferenceState();
41 }
42 
SetAttrValue(const char * attr_name,const AttrValue & value)43 Status EagerOperation::SetAttrValue(const char* attr_name,
44                                     const AttrValue& value) {
45   MutableAttrs()->Set(attr_name, value);
46   return OkStatus();
47 }
48 
SetAttrString(const char * attr_name,const char * data,size_t length)49 Status EagerOperation::SetAttrString(const char* attr_name, const char* data,
50                                      size_t length) {
51   MutableAttrs()->Set(attr_name, StringPiece(data, length));
52   return OkStatus();
53 }
54 
SetAttrInt(const char * attr_name,int64_t value)55 Status EagerOperation::SetAttrInt(const char* attr_name, int64_t value) {
56   MutableAttrs()->Set(attr_name, static_cast<int64_t>(value));
57   return OkStatus();
58 }
59 
SetAttrFloat(const char * attr_name,float value)60 Status EagerOperation::SetAttrFloat(const char* attr_name, float value) {
61   MutableAttrs()->Set(attr_name, value);
62   return OkStatus();
63 }
64 
SetAttrBool(const char * attr_name,bool value)65 Status EagerOperation::SetAttrBool(const char* attr_name, bool value) {
66   MutableAttrs()->Set(attr_name, value);
67   return OkStatus();
68 }
69 
SetAttrType(const char * attr_name,DataType value)70 Status EagerOperation::SetAttrType(const char* attr_name, DataType value) {
71   MutableAttrs()->Set(attr_name, value);
72   return OkStatus();
73 }
74 
SetAttrShape(const char * attr_name,const int64_t * dims,const int num_dims)75 Status EagerOperation::SetAttrShape(const char* attr_name, const int64_t* dims,
76                                     const int num_dims) {
77   if (num_dims > TensorShape::MaxDimensions()) {
78     return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
79                                    num_dims,
80                                    " dimensions which is over the limit of ",
81                                    TensorShape::MaxDimensions(), ".");
82   }
83 
84   TensorShapeProto proto;
85   if (num_dims < 0) {
86     proto.set_unknown_rank(true);
87   } else {
88     for (int d = 0; d < num_dims; ++d) {
89       proto.add_dim()->set_size(dims[d]);
90     }
91   }
92 
93   MutableAttrs()->Set(attr_name, proto);
94 
95   return OkStatus();
96 }
97 
SetAttrFunction(const char * attr_name,const AbstractOperation * value)98 Status EagerOperation::SetAttrFunction(const char* attr_name,
99                                        const AbstractOperation* value) {
100   AttrValue attr_value;
101   NameAttrList* func = attr_value.mutable_func();
102   func->set_name(value->Name());
103   auto* value_operation = down_cast<const EagerOperation*>(value);
104   value_operation->Attrs().FillAttrValueMap(func->mutable_attr());
105   MutableAttrs()->Set(attr_name, attr_value);
106   return OkStatus();
107 }
108 
SetAttrFunctionName(const char * attr_name,const char * data,size_t length)109 Status EagerOperation::SetAttrFunctionName(const char* attr_name,
110                                            const char* data, size_t length) {
111   AttrValue attr_value;
112   NameAttrList* func = attr_value.mutable_func();
113   func->set_name(data, length);
114   MutableAttrs()->Set(attr_name, attr_value);
115   return OkStatus();
116 }
117 
SetAttrTensor(const char * attr_name,AbstractTensorInterface * tensor)118 Status EagerOperation::SetAttrTensor(const char* attr_name,
119                                      AbstractTensorInterface* tensor) {
120   Tensor t = TensorFromInterface(tensor);
121   MutableAttrs()->Set(attr_name, t);
122   return OkStatus();
123 }
124 
SetAttrStringList(const char * attr_name,const void * const * values,const size_t * lengths,int num_values)125 Status EagerOperation::SetAttrStringList(const char* attr_name,
126                                          const void* const* values,
127                                          const size_t* lengths,
128                                          int num_values) {
129   std::vector<StringPiece> v(num_values);
130   for (int i = 0; i < num_values; ++i) {
131     v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
132   }
133   MutableAttrs()->Set(attr_name, v);
134 
135   return OkStatus();
136 }
137 
SetAttrFloatList(const char * attr_name,const float * values,int num_values)138 Status EagerOperation::SetAttrFloatList(const char* attr_name,
139                                         const float* values, int num_values) {
140   MutableAttrs()->Set(attr_name,
141                       gtl::ArraySlice<const float>(values, num_values));
142   return OkStatus();
143 }
144 
SetAttrIntList(const char * attr_name,const int64_t * values,int num_values)145 Status EagerOperation::SetAttrIntList(const char* attr_name,
146                                       const int64_t* values, int num_values) {
147   MutableAttrs()->Set(
148       attr_name, gtl::ArraySlice<const int64_t>(
149                      reinterpret_cast<const int64_t*>(values), num_values));
150   return OkStatus();
151 }
152 
SetAttrTypeList(const char * attr_name,const DataType * values,int num_values)153 Status EagerOperation::SetAttrTypeList(const char* attr_name,
154                                        const DataType* values, int num_values) {
155   MutableAttrs()->Set(attr_name,
156                       gtl::ArraySlice<const DataType>(values, num_values));
157   return OkStatus();
158 }
159 
SetAttrBoolList(const char * attr_name,const unsigned char * values,int num_values)160 Status EagerOperation::SetAttrBoolList(const char* attr_name,
161                                        const unsigned char* values,
162                                        int num_values) {
163   std::unique_ptr<bool[]> b(new bool[num_values]);
164   for (int i = 0; i < num_values; ++i) {
165     b[i] = values[i];
166   }
167   MutableAttrs()->Set(attr_name,
168                       gtl::ArraySlice<const bool>(b.get(), num_values));
169   return OkStatus();
170 }
171 
SetAttrShapeList(const char * attr_name,const int64_t ** dims,const int * num_dims,int num_values)172 Status EagerOperation::SetAttrShapeList(const char* attr_name,
173                                         const int64_t** dims,
174                                         const int* num_dims, int num_values) {
175   std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
176   for (int i = 0; i < num_values; ++i) {
177     const auto num_dims_i = num_dims[i];
178 
179     if (num_dims_i > TensorShape::MaxDimensions()) {
180       return errors::InvalidArgument(
181           strings::StrCat("Value specified for `", attr_name, "` has ",
182                           num_dims_i, " dimensions which is over the limit of ",
183                           TensorShape::MaxDimensions(), "."));
184     }
185     if (num_dims_i < 0) {
186       proto[i].set_unknown_rank(true);
187     } else {
188       const int64_t* dims_i = dims[i];
189       auto proto_i = &proto[i];
190       for (int d = 0; d < num_dims_i; ++d) {
191         proto_i->add_dim()->set_size(dims_i[d]);
192       }
193     }
194   }
195   MutableAttrs()->Set(
196       attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
197   return OkStatus();
198 }
199 
SetAttrFunctionList(const char * attr_name,absl::Span<const AbstractOperation * > values)200 Status EagerOperation::SetAttrFunctionList(
201     const char* attr_name, absl::Span<const AbstractOperation*> values) {
202   size_t num_values = values.size();
203   std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
204   for (int i = 0; i < num_values; i++) {
205     auto* value_operation = down_cast<const EagerOperation*>(values[i]);
206     funcs[i].set_name(value_operation->Name());
207     value_operation->Attrs().FillAttrValueMap(funcs[i].mutable_attr());
208   }
209   MutableAttrs()->Set(
210       attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
211   return OkStatus();
212 }
213 
GetOpDef(Status * status)214 const OpDef* EagerOperation::GetOpDef(Status* status) {
215   const tensorflow::OpDef* op_def = OpDef();
216   if (op_def) return op_def;
217   *status = OpDefForOp(Name(), &op_def);
218   return op_def;
219 }
220 
InputLength(const char * input_name,int * length)221 Status EagerOperation::InputLength(const char* input_name, int* length) {
222   Status status;
223   const tensorflow::OpDef* op_def = GetOpDef(&status);
224   if (!status.ok()) {
225     return status;
226   }
227   AttrValueMap attrs;
228   Attrs().FillAttrValueMap(&attrs);
229   NameRangeMap name_ranges;
230   TF_RETURN_IF_ERROR(
231       NameRangesForNode(AttrSlice(&attrs), *op_def, &name_ranges, nullptr));
232   auto iter = name_ranges.find(input_name);
233   if (iter == name_ranges.end()) {
234     return errors::InvalidArgument("Input '", input_name, "' not found");
235   }
236   *length = iter->second.second - iter->second.first;
237   return OkStatus();
238 }
239 
GetInputs() const240 absl::Span<ImmediateExecutionTensorHandle* const> EagerOperation::GetInputs()
241     const {
242   // TODO(b/162536003): Remove reinterpret_cast.
243   return absl::MakeSpan(
244       reinterpret_cast<ImmediateExecutionTensorHandle* const*>(inputs_.data()),
245       inputs_.size());
246 }
247 
OutputLength(const char * output_name,int * length)248 Status EagerOperation::OutputLength(const char* output_name, int* length) {
249   Status status;
250   const tensorflow::OpDef* op_def = GetOpDef(&status);
251   if (!status.ok()) {
252     return status;
253   }
254   AttrValueMap attrs;
255   Attrs().FillAttrValueMap(&attrs);
256   NameRangeMap name_ranges;
257   TF_RETURN_IF_ERROR(
258       NameRangesForNode(AttrSlice(&attrs), *op_def, nullptr, &name_ranges));
259   auto iter = name_ranges.find(output_name);
260   if (iter == name_ranges.end()) {
261     return errors::InvalidArgument("Output '", output_name, "' not found");
262   }
263   *length = iter->second.second - iter->second.first;
264   return OkStatus();
265 }
266 
AddInput(AbstractTensorHandle * input)267 Status EagerOperation::AddInput(AbstractTensorHandle* input) {
268   ImmediateExecutionTensorHandle* h =
269       down_cast<ImmediateExecutionTensorHandle*>(input);
270   // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
271   if (CustomDeviceTensorHandle::classof(h)) {
272     custom_device_tensor_handles_count_++;
273   }
274   AddTensorHandle(h);
275   return MaybeInferSingleInputAttrs(h);
276 }
277 
AddInputList(absl::Span<AbstractTensorHandle * const> inputs)278 Status EagerOperation::AddInputList(
279     absl::Span<AbstractTensorHandle* const> inputs) {
280   for (auto& input : inputs) {
281     // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
282     // here.
283     if (CustomDeviceTensorHandle::classof(input)) {
284       custom_device_tensor_handles_count_++;
285     }
286     ImmediateExecutionTensorHandle* h =
287         down_cast<ImmediateExecutionTensorHandle*>(input);
288     AddTensorHandle(h);
289   }
290   return InferInputListAttrs(inputs.size());
291 }
292 
SetInput(size_t index,ImmediateExecutionTensorHandle * input)293 Status EagerOperation::SetInput(size_t index,
294                                 ImmediateExecutionTensorHandle* input) {
295   if (index >= inputs_.size()) {
296     return errors::InvalidArgument("Index >= inputs.size: %d >= %d", index,
297                                    inputs_.size());
298   }
299   auto* previous = inputs_[index];
300   if (CustomDeviceTensorHandle::classof(previous)) {
301     custom_device_tensor_handles_count_--;
302   }
303   if (CustomDeviceTensorHandle::classof(input)) {
304     custom_device_tensor_handles_count_++;
305   }
306   input->Ref();
307   inputs_[index] = input;
308   previous->Unref();
309   return OkStatus();
310 }
311 
Reset(const char * op,const char * device_name,bool remote,EagerExecutor * executor,const absl::optional<EagerFunctionParams> eager_func_params)312 Status EagerOperation::Reset(
313     const char* op, const char* device_name, bool remote,
314     EagerExecutor* executor,
315     const absl::optional<EagerFunctionParams> eager_func_params) {
316   DCHECK(inputs_.empty());
317   ClearInferenceState();
318   bool is_function = false;
319   TF_RETURN_IF_ERROR(AttrTypeMapForOp(op, &attr_types_, &is_function));
320 
321   // Don't update the device of direct function calls.
322   // Particularly, if the user did not explicitly request any device for this
323   // function, picking a device would result in this device being the default
324   // for nodes inside the function. This is undesirable for multi-device
325   // functions since the not-explicitly-placed nodes inside the body will all
326   // end up on this default device.
327   colocation_exempt_ = is_function;
328   if (!is_function) {
329     const auto& exempt_ops = InputColocationExemptionRegistry::Global()->Get();
330     colocation_exempt_ = exempt_ops.find(op) != exempt_ops.end();
331 
332     TF_RETURN_IF_ERROR(OpDefForOp(op, &op_def_));
333   } else if (!remote && !ctx_.FindFunctionByName(op)) {
334     return errors::NotFound(
335         "'", op,
336         "' is neither a type of a primitive operation nor a name "
337         "of a function registered in binary running on ",
338         port::Hostname(),
339         ". Make sure the operation or function is "
340         "registered in the binary running in this process.");
341   }
342   attrs_.Reset(op);
343   stack_trace_.reset();
344   is_function_ = is_function;
345   cancellation_manager_ = nullptr;
346   executor_ = executor ? executor : &ctx_.Executor();
347   if (eager_func_params.has_value()) {
348     eager_func_params_ = eager_func_params;
349   }
350   op_name_ = op;
351   return SetDeviceName(device_name);
352 }
353 
MaybeInferSingleInputAttrs(ImmediateExecutionTensorHandle * handle)354 Status EagerOperation::MaybeInferSingleInputAttrs(
355     ImmediateExecutionTensorHandle* handle) {
356   if (!op_def_) return OkStatus();
357 
358   const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
359   if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) {
360     // Some clients that are still setting their input attributes manually are
361     // adding input list to their op by calling `TFE_OpAddInput` for each of
362     // its elements instead of calling `TFE_OpAddInputList`. When this happens,
363     // we cannot detect the end of such list, thus lose track of the input
364     // arguments in the op definition. To guarantee backward compatibility with
365     // those clients, disable automatic inference in this case.
366     ClearInferenceState();
367     return OkStatus();
368   }
369   const std::string& type_attr = input_def.type_attr();
370   if (!type_attr.empty() &&
371       inference_attrs_.find(type_attr) == inference_attrs_.end()) {
372     MutableAttrs()->Set(type_attr, handle->DataType());
373     inference_attrs_.insert(type_attr);
374   }
375   return OkStatus();
376 }
377 
InferSingleTypeInputListAttrs(const OpDef::ArgDef & input_def,const DataType dtype,int num_inputs)378 void EagerOperation::InferSingleTypeInputListAttrs(
379     const OpDef::ArgDef& input_def, const DataType dtype, int num_inputs) {
380   if (inference_attrs_.find(input_def.number_attr()) ==
381       inference_attrs_.end()) {
382     MutableAttrs()->Set(input_def.number_attr(), num_inputs);
383     inference_attrs_.insert(input_def.number_attr());
384   }
385   if (inference_attrs_.find(input_def.type_attr()) == inference_attrs_.end()) {
386     MutableAttrs()->Set(input_def.type_attr(), dtype);
387     inference_attrs_.insert(input_def.type_attr());
388   }
389 }
390 
InferMixedTypeInputListAttrs(const OpDef::ArgDef & input_def,const std::vector<DataType> & dtypes)391 void EagerOperation::InferMixedTypeInputListAttrs(
392     const OpDef::ArgDef& input_def, const std::vector<DataType>& dtypes) {
393   if (inference_attrs_.find(input_def.type_list_attr()) ==
394       inference_attrs_.end()) {
395     MutableAttrs()->Set(
396         input_def.type_list_attr(),
397         gtl::ArraySlice<const DataType>(dtypes.data(), dtypes.size()));
398     inference_attrs_.insert(input_def.type_list_attr());
399   }
400 }
401 
InferInputListAttrs(int num_inputs)402 Status EagerOperation::InferInputListAttrs(int num_inputs) {
403   if (!op_def_) return OkStatus();
404 
405   int start = inference_arg_idx_;
406   const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
407   if (!input_def.type_list_attr().empty()) {
408     std::vector<DataType> dtypes(num_inputs);
409     for (int i = 0; i < num_inputs; ++i) {
410       dtypes[i] = inputs_[start + i]->DataType();
411     }
412     InferMixedTypeInputListAttrs(input_def, dtypes);
413   } else if (!input_def.type_attr().empty() &&
414              !input_def.number_attr().empty()) {
415     InferSingleTypeInputListAttrs(input_def, inputs_[start]->DataType(),
416                                   num_inputs);
417   } else if (!input_def.number_attr().empty()) {
418     if (inference_attrs_.find(input_def.number_attr()) ==
419         inference_attrs_.end()) {
420       MutableAttrs()->Set(input_def.number_attr(), num_inputs);
421       inference_attrs_.insert(input_def.number_attr());
422     }
423   } else {
424     return errors::InvalidArgument("Invalid input list definition");
425   }
426   return OkStatus();
427 }
428 
TensorHandleInputs(const absl::InlinedVector<TensorHandle *,4> ** inputs) const429 Status EagerOperation::TensorHandleInputs(
430     const absl::InlinedVector<TensorHandle*, 4>** inputs) const {
431   if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
432     *inputs = reinterpret_cast<const absl::InlinedVector<TensorHandle*, 4>*>(
433         &inputs_);
434     return OkStatus();
435   } else {
436     return errors::Internal("The operation unexpectedly had custom devices.");
437   }
438 }
439 
MutableTensorHandleInputs(absl::InlinedVector<TensorHandle *,4> ** inputs)440 Status EagerOperation::MutableTensorHandleInputs(
441     absl::InlinedVector<TensorHandle*, 4>** inputs) {
442   if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
443     *inputs =
444         reinterpret_cast<absl::InlinedVector<TensorHandle*, 4>*>(&inputs_);
445     return OkStatus();
446   } else {
447     return errors::Internal("The operation unexpectedly had custom devices.");
448   }
449 }
450 
SetDeviceName(const char * c_name)451 Status EagerOperation::SetDeviceName(const char* c_name) {
452   string name(c_name != nullptr ? c_name : "");
453   if (name != last_set_device_name_) {
454     if (!DeviceNameUtils::ParseFullName(name, &device_parsed_name_)) {
455       return errors::InvalidArgument("Malformed device specification '", name,
456                                      "' in eager op: ", DebugString());
457     }
458     last_set_device_name_ = name;
459     device_name_ = DeviceNameUtils::ParsedNameToString(device_parsed_name_);
460     device_ = kVariantDeviceNull;
461   }
462   return OkStatus();
463 }
464 
IsLocal() const465 bool EagerOperation::IsLocal() const {
466   if (ctx_.remote_device_mgr() == nullptr) return true;
467 
468   if (!device_parsed_name_.has_job && !device_parsed_name_.has_replica &&
469       !device_parsed_name_.has_task)
470     return true;
471   auto& host_cpu_name = ctx_.HostCPU()->parsed_name();
472   return device_parsed_name_.job == host_cpu_name.job &&
473          device_parsed_name_.replica == host_cpu_name.replica &&
474          device_parsed_name_.task == host_cpu_name.task;
475 }
476 
VariantDeviceDebugString(VariantDevice device)477 string VariantDeviceDebugString(VariantDevice device) {
478   if (device == kVariantDeviceNull) {
479     return "[]";
480   } else if (absl::holds_alternative<CustomDevice*>(device)) {
481     return absl::get<CustomDevice*>(device)->name();
482   } else {
483     return absl::get<Device*>(device)->DebugString();
484   }
485 }
GetOpAttrs() const486 const AbstractOpAttrs* EagerOperation::GetOpAttrs() const { return &attrs_; }
487 
AddAttrs(const AbstractOpAttrs * op_attrs)488 void EagerOperation::AddAttrs(const AbstractOpAttrs* op_attrs) {
489   attrs_.CopyAttributes(*(down_cast<const AttrBuilder*>(op_attrs)));
490 }
491 
DebugString() const492 string EagerOperation::DebugString() const {
493   string out;
494   VLOG(1) << "EagerOperation::DebugString() over " << this;
495 
496   strings::StrAppend(&out, "Name: ", Name(), "\n");
497   strings::StrAppend(&out, "Device Name: [", device_name_, "]\n");
498   strings::StrAppend(&out, "Device: ", VariantDeviceDebugString(Device()),
499                      "\n");
500   for (const auto& input : inputs_) {
501     VLOG(1) << "Input ptr: " << input;
502     strings::StrAppend(&out, "Input: ", input->DebugString(), "\n");
503   }
504 
505   NodeDef ndef;
506   Attrs().FillAttrValueMap(ndef.mutable_attr());
507   strings::StrAppend(&out, "Attrs: ", ndef.DebugString(), "\n");
508   return out;
509 }
510 
AddTensorHandle(ImmediateExecutionTensorHandle * h)511 void EagerOperation::AddTensorHandle(ImmediateExecutionTensorHandle* h) {
512   h->Ref();
513   inputs_.push_back(h);
514   attrs_.NumInputs(static_cast<int>(inputs_.size()));
515 }
516 }  // namespace tensorflow
517