• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
16 
17 #include "absl/types/span.h"
18 #include "tensorflow/c/eager/abstract_operation.h"
19 #include "tensorflow/c/eager/abstract_tensor_handle.h"
20 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
21 #include "tensorflow/c/tf_tensor_internal.h"
22 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
23 #include "tensorflow/core/common_runtime/eager/custom_device.h"
24 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
25 #include "tensorflow/core/platform/casts.h"
26 #include "tensorflow/core/platform/errors.h"
27 #include "tensorflow/core/platform/host_info.h"
28 
29 namespace tensorflow {
30 
31 // An EagerOperation object can be reused for a different op by calling
32 // Clear(), and then Reset(...) with the same arguments that would have
33 // been provided to the constructor.
Clear()34 void EagerOperation::Clear() {
35   for (ImmediateExecutionTensorHandle* h : inputs_) {
36     h->Unref();
37   }
38   inputs_.clear();
39   custom_device_tensor_handles_count_ = 0;
40   ClearInferenceState();
41 }
42 
SetAttrValue(const char * attr_name,const AttrValue & value)43 Status EagerOperation::SetAttrValue(const char* attr_name,
44                                     const AttrValue& value) {
45   MutableAttrs()->Set(attr_name, value);
46   return Status::OK();
47 }
48 
SetAttrString(const char * attr_name,const char * data,size_t length)49 Status EagerOperation::SetAttrString(const char* attr_name, const char* data,
50                                      size_t length) {
51   MutableAttrs()->Set(attr_name, StringPiece(data, length));
52   return Status::OK();
53 }
54 
SetAttrInt(const char * attr_name,int64_t value)55 Status EagerOperation::SetAttrInt(const char* attr_name, int64_t value) {
56   MutableAttrs()->Set(attr_name, static_cast<int64>(value));
57   return Status::OK();
58 }
59 
SetAttrFloat(const char * attr_name,float value)60 Status EagerOperation::SetAttrFloat(const char* attr_name, float value) {
61   MutableAttrs()->Set(attr_name, value);
62   return Status::OK();
63 }
64 
SetAttrBool(const char * attr_name,bool value)65 Status EagerOperation::SetAttrBool(const char* attr_name, bool value) {
66   MutableAttrs()->Set(attr_name, value);
67   return Status::OK();
68 }
69 
SetAttrType(const char * attr_name,DataType value)70 Status EagerOperation::SetAttrType(const char* attr_name, DataType value) {
71   MutableAttrs()->Set(attr_name, value);
72   return Status::OK();
73 }
74 
SetAttrShape(const char * attr_name,const int64_t * dims,const int num_dims)75 Status EagerOperation::SetAttrShape(const char* attr_name, const int64_t* dims,
76                                     const int num_dims) {
77   if (num_dims > TensorShape::MaxDimensions()) {
78     return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
79                                    num_dims,
80                                    " dimensions which is over the limit of ",
81                                    TensorShape::MaxDimensions(), ".");
82   }
83 
84   TensorShapeProto proto;
85   if (num_dims < 0) {
86     proto.set_unknown_rank(true);
87   } else {
88     for (int d = 0; d < num_dims; ++d) {
89       proto.add_dim()->set_size(dims[d]);
90     }
91   }
92 
93   MutableAttrs()->Set(attr_name, proto);
94 
95   return Status::OK();
96 }
97 
SetAttrFunction(const char * attr_name,const AbstractOperation * value)98 Status EagerOperation::SetAttrFunction(const char* attr_name,
99                                        const AbstractOperation* value) {
100   AttrValue attr_value;
101   NameAttrList* func = attr_value.mutable_func();
102   func->set_name(value->Name());
103   auto* value_operation = down_cast<const EagerOperation*>(value);
104   value_operation->Attrs().FillAttrValueMap(func->mutable_attr());
105   MutableAttrs()->Set(attr_name, attr_value);
106   return Status::OK();
107 }
108 
SetAttrFunctionName(const char * attr_name,const char * data,size_t length)109 Status EagerOperation::SetAttrFunctionName(const char* attr_name,
110                                            const char* data, size_t length) {
111   AttrValue attr_value;
112   NameAttrList* func = attr_value.mutable_func();
113   func->set_name(data, length);
114   MutableAttrs()->Set(attr_name, attr_value);
115   return Status::OK();
116 }
117 
SetAttrTensor(const char * attr_name,AbstractTensorInterface * tensor)118 Status EagerOperation::SetAttrTensor(const char* attr_name,
119                                      AbstractTensorInterface* tensor) {
120   Tensor t = TensorFromInterface(tensor);
121   MutableAttrs()->Set(attr_name, t);
122   return Status::OK();
123 }
124 
SetAttrStringList(const char * attr_name,const void * const * values,const size_t * lengths,int num_values)125 Status EagerOperation::SetAttrStringList(const char* attr_name,
126                                          const void* const* values,
127                                          const size_t* lengths,
128                                          int num_values) {
129   std::vector<StringPiece> v(num_values);
130   for (int i = 0; i < num_values; ++i) {
131     v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
132   }
133   MutableAttrs()->Set(attr_name, v);
134 
135   return Status::OK();
136 }
137 
SetAttrFloatList(const char * attr_name,const float * values,int num_values)138 Status EagerOperation::SetAttrFloatList(const char* attr_name,
139                                         const float* values, int num_values) {
140   MutableAttrs()->Set(attr_name,
141                       gtl::ArraySlice<const float>(values, num_values));
142   return Status::OK();
143 }
144 
SetAttrIntList(const char * attr_name,const int64_t * values,int num_values)145 Status EagerOperation::SetAttrIntList(const char* attr_name,
146                                       const int64_t* values, int num_values) {
147   MutableAttrs()->Set(attr_name,
148                       gtl::ArraySlice<const int64>(
149                           reinterpret_cast<const int64*>(values), num_values));
150   return Status::OK();
151 }
152 
SetAttrTypeList(const char * attr_name,const DataType * values,int num_values)153 Status EagerOperation::SetAttrTypeList(const char* attr_name,
154                                        const DataType* values, int num_values) {
155   MutableAttrs()->Set(attr_name,
156                       gtl::ArraySlice<const DataType>(values, num_values));
157   return Status::OK();
158 }
159 
SetAttrBoolList(const char * attr_name,const unsigned char * values,int num_values)160 Status EagerOperation::SetAttrBoolList(const char* attr_name,
161                                        const unsigned char* values,
162                                        int num_values) {
163   std::unique_ptr<bool[]> b(new bool[num_values]);
164   for (int i = 0; i < num_values; ++i) {
165     b[i] = values[i];
166   }
167   MutableAttrs()->Set(attr_name,
168                       gtl::ArraySlice<const bool>(b.get(), num_values));
169   return Status::OK();
170 }
171 
SetAttrShapeList(const char * attr_name,const int64_t ** dims,const int * num_dims,int num_values)172 Status EagerOperation::SetAttrShapeList(const char* attr_name,
173                                         const int64_t** dims,
174                                         const int* num_dims, int num_values) {
175   std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
176   for (int i = 0; i < num_values; ++i) {
177     const auto num_dims_i = num_dims[i];
178 
179     if (num_dims_i > TensorShape::MaxDimensions()) {
180       return errors::InvalidArgument(
181           strings::StrCat("Value specified for `", attr_name, "` has ",
182                           num_dims_i, " dimensions which is over the limit of ",
183                           TensorShape::MaxDimensions(), "."));
184     }
185     if (num_dims_i < 0) {
186       proto[i].set_unknown_rank(true);
187     } else {
188       const int64_t* dims_i = dims[i];
189       auto proto_i = &proto[i];
190       for (int d = 0; d < num_dims_i; ++d) {
191         proto_i->add_dim()->set_size(dims_i[d]);
192       }
193     }
194   }
195   MutableAttrs()->Set(
196       attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
197   return Status::OK();
198 }
199 
SetAttrFunctionList(const char * attr_name,absl::Span<const AbstractOperation * > values)200 Status EagerOperation::SetAttrFunctionList(
201     const char* attr_name, absl::Span<const AbstractOperation*> values) {
202   size_t num_values = values.size();
203   std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
204   for (int i = 0; i < num_values; i++) {
205     auto* value_operation = down_cast<const EagerOperation*>(values[i]);
206     funcs[i].set_name(value_operation->Name());
207     value_operation->Attrs().FillAttrValueMap(funcs[i].mutable_attr());
208   }
209   MutableAttrs()->Set(
210       attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
211   return Status::OK();
212 }
213 
GetOpDef(Status * status)214 const OpDef* EagerOperation::GetOpDef(Status* status) {
215   const tensorflow::OpDef* op_def = OpDef();
216   if (op_def) return op_def;
217   *status = OpDefForOp(Name(), &op_def);
218   return op_def;
219 }
220 
InputLength(const char * input_name,int * length)221 Status EagerOperation::InputLength(const char* input_name, int* length) {
222   Status status;
223   const tensorflow::OpDef* op_def = GetOpDef(&status);
224   if (!status.ok()) {
225     return status;
226   }
227   AttrValueMap attrs;
228   Attrs().FillAttrValueMap(&attrs);
229   NameRangeMap name_ranges;
230   TF_RETURN_IF_ERROR(
231       NameRangesForNode(AttrSlice(&attrs), *op_def, &name_ranges, nullptr));
232   auto iter = name_ranges.find(input_name);
233   if (iter == name_ranges.end()) {
234     return errors::InvalidArgument("Input '", input_name, "' not found");
235   }
236   *length = iter->second.second - iter->second.first;
237   return Status::OK();
238 }
239 
GetInputs() const240 absl::Span<ImmediateExecutionTensorHandle* const> EagerOperation::GetInputs()
241     const {
242   // TODO(b/162536003): Remove reinterpret_cast.
243   return absl::MakeSpan(
244       reinterpret_cast<ImmediateExecutionTensorHandle* const*>(inputs_.data()),
245       inputs_.size());
246 }
247 
OutputLength(const char * output_name,int * length)248 Status EagerOperation::OutputLength(const char* output_name, int* length) {
249   Status status;
250   const tensorflow::OpDef* op_def = GetOpDef(&status);
251   if (!status.ok()) {
252     return status;
253   }
254   AttrValueMap attrs;
255   Attrs().FillAttrValueMap(&attrs);
256   NameRangeMap name_ranges;
257   TF_RETURN_IF_ERROR(
258       NameRangesForNode(AttrSlice(&attrs), *op_def, nullptr, &name_ranges));
259   auto iter = name_ranges.find(output_name);
260   if (iter == name_ranges.end()) {
261     return errors::InvalidArgument("Output '", output_name, "' not found");
262   }
263   *length = iter->second.second - iter->second.first;
264   return Status::OK();
265 }
266 
AddInput(AbstractTensorHandle * input)267 Status EagerOperation::AddInput(AbstractTensorHandle* input) {
268   ImmediateExecutionTensorHandle* h =
269       down_cast<ImmediateExecutionTensorHandle*>(input);
270   // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
271   if (CustomDeviceTensorHandle::classof(h)) {
272     custom_device_tensor_handles_count_++;
273   }
274   AddTensorHandle(h);
275   return MaybeInferSingleInputAttrs(h);
276 }
277 
AddInputList(absl::Span<AbstractTensorHandle * const> inputs)278 Status EagerOperation::AddInputList(
279     absl::Span<AbstractTensorHandle* const> inputs) {
280   for (auto& input : inputs) {
281     // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
282     // here.
283     if (CustomDeviceTensorHandle::classof(input)) {
284       custom_device_tensor_handles_count_++;
285     }
286     ImmediateExecutionTensorHandle* h =
287         down_cast<ImmediateExecutionTensorHandle*>(input);
288     AddTensorHandle(h);
289   }
290   return InferInputListAttrs(inputs.size());
291 }
292 
SetInput(size_t index,ImmediateExecutionTensorHandle * input)293 Status EagerOperation::SetInput(size_t index,
294                                 ImmediateExecutionTensorHandle* input) {
295   if (index >= inputs_.size()) {
296     return errors::InvalidArgument("Index >= inputs.size: %d >= %d", index,
297                                    inputs_.size());
298   }
299   auto* previous = inputs_[index];
300   if (CustomDeviceTensorHandle::classof(previous)) {
301     custom_device_tensor_handles_count_--;
302   }
303   if (CustomDeviceTensorHandle::classof(input)) {
304     custom_device_tensor_handles_count_++;
305   }
306   input->Ref();
307   inputs_[index] = input;
308   previous->Unref();
309   return Status::OK();
310 }
311 
Reset(const char * op,const char * device_name,bool remote,EagerExecutor * executor,const absl::optional<EagerRemoteFunctionParams> remote_func_params)312 Status EagerOperation::Reset(
313     const char* op, const char* device_name, bool remote,
314     EagerExecutor* executor,
315     const absl::optional<EagerRemoteFunctionParams> remote_func_params) {
316   DCHECK(inputs_.empty());
317   ClearInferenceState();
318   bool is_function = false;
319   TF_RETURN_IF_ERROR(AttrTypeMapForOp(op, &attr_types_, &is_function));
320 
321   // Don't update the device of direct function calls.
322   // Particularly, if the user did not explicitly request any device for this
323   // function, picking a device would result in this device being the default
324   // for nodes inside the function. This is undesirable for multi-device
325   // functions since the not-explicitly-placed nodes inside the body will all
326   // end up on this default device.
327   colocation_exempt_ = is_function;
328   if (!is_function) {
329     const auto& exempt_ops = InputColocationExemptionRegistry::Global()->Get();
330     colocation_exempt_ = exempt_ops.find(op) != exempt_ops.end();
331 
332     TF_RETURN_IF_ERROR(OpDefForOp(op, &op_def_));
333   } else if (!remote && !ctx_.FindFunctionByName(op)) {
334     return errors::NotFound(
335         "'", op,
336         "' is neither a type of a primitive operation nor a name "
337         "of a function registered in binary running on ",
338         port::Hostname(),
339         ". Make sure the operation or function is "
340         "registered in the binary running in this process.");
341   }
342   attrs_.Reset(op);
343   stack_trace_.reset();
344   is_function_ = is_function;
345   cancellation_manager_ = nullptr;
346   executor_ = executor ? executor : &ctx_.Executor();
347   remote_func_params_ = remote_func_params;
348   op_name_ = op;
349   return SetDeviceName(device_name);
350 }
351 
MaybeInferSingleInputAttrs(ImmediateExecutionTensorHandle * handle)352 Status EagerOperation::MaybeInferSingleInputAttrs(
353     ImmediateExecutionTensorHandle* handle) {
354   if (!op_def_) return Status::OK();
355 
356   const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
357   if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) {
358     // Some clients that are still setting their input attributes manually are
359     // adding input list to their op by calling `TFE_OpAddInput` for each of
360     // its elements instead of calling `TFE_OpAddInputList`. When this happens,
361     // we cannot detect the end of such list, thus lose track of the input
362     // arguments in the op definition. To guarantee backward compatibility with
363     // those clients, disable automatic inference in this case.
364     ClearInferenceState();
365     return Status::OK();
366   }
367   const std::string& type_attr = input_def.type_attr();
368   if (!type_attr.empty() &&
369       inference_attrs_.find(type_attr) == inference_attrs_.end()) {
370     MutableAttrs()->Set(type_attr, handle->DataType());
371     inference_attrs_.insert(type_attr);
372   }
373   return Status::OK();
374 }
375 
InferSingleTypeInputListAttrs(const OpDef::ArgDef & input_def,const DataType dtype,int num_inputs)376 void EagerOperation::InferSingleTypeInputListAttrs(
377     const OpDef::ArgDef& input_def, const DataType dtype, int num_inputs) {
378   if (inference_attrs_.find(input_def.number_attr()) ==
379       inference_attrs_.end()) {
380     MutableAttrs()->Set(input_def.number_attr(), num_inputs);
381     inference_attrs_.insert(input_def.number_attr());
382   }
383   if (inference_attrs_.find(input_def.type_attr()) == inference_attrs_.end()) {
384     MutableAttrs()->Set(input_def.type_attr(), dtype);
385     inference_attrs_.insert(input_def.type_attr());
386   }
387 }
388 
InferMixedTypeInputListAttrs(const OpDef::ArgDef & input_def,const std::vector<DataType> & dtypes)389 void EagerOperation::InferMixedTypeInputListAttrs(
390     const OpDef::ArgDef& input_def, const std::vector<DataType>& dtypes) {
391   if (inference_attrs_.find(input_def.type_list_attr()) ==
392       inference_attrs_.end()) {
393     MutableAttrs()->Set(
394         input_def.type_list_attr(),
395         gtl::ArraySlice<const DataType>(dtypes.data(), dtypes.size()));
396     inference_attrs_.insert(input_def.type_list_attr());
397   }
398 }
399 
InferInputListAttrs(int num_inputs)400 Status EagerOperation::InferInputListAttrs(int num_inputs) {
401   if (!op_def_) return Status::OK();
402 
403   int start = inference_arg_idx_;
404   const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
405   if (!input_def.type_list_attr().empty()) {
406     std::vector<DataType> dtypes(num_inputs);
407     for (int i = 0; i < num_inputs; ++i) {
408       dtypes[i] = inputs_[start + i]->DataType();
409     }
410     InferMixedTypeInputListAttrs(input_def, dtypes);
411   } else if (!input_def.type_attr().empty() &&
412              !input_def.number_attr().empty()) {
413     InferSingleTypeInputListAttrs(input_def, inputs_[start]->DataType(),
414                                   num_inputs);
415   } else if (!input_def.number_attr().empty()) {
416     if (inference_attrs_.find(input_def.number_attr()) ==
417         inference_attrs_.end()) {
418       MutableAttrs()->Set(input_def.number_attr(), num_inputs);
419       inference_attrs_.insert(input_def.number_attr());
420     }
421   } else {
422     return errors::InvalidArgument("Invalid input list definition");
423   }
424   return Status::OK();
425 }
426 
TensorHandleInputs(const absl::InlinedVector<TensorHandle *,4> ** inputs) const427 Status EagerOperation::TensorHandleInputs(
428     const absl::InlinedVector<TensorHandle*, 4>** inputs) const {
429   if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
430     *inputs = reinterpret_cast<const absl::InlinedVector<TensorHandle*, 4>*>(
431         &inputs_);
432     return Status::OK();
433   } else {
434     return errors::Internal("The operation unexpectedly had custom devices.");
435   }
436 }
437 
MutableTensorHandleInputs(absl::InlinedVector<TensorHandle *,4> ** inputs)438 Status EagerOperation::MutableTensorHandleInputs(
439     absl::InlinedVector<TensorHandle*, 4>** inputs) {
440   if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
441     *inputs =
442         reinterpret_cast<absl::InlinedVector<TensorHandle*, 4>*>(&inputs_);
443     return Status::OK();
444   } else {
445     return errors::Internal("The operation unexpectedly had custom devices.");
446   }
447 }
448 
SetDeviceName(const char * c_name)449 Status EagerOperation::SetDeviceName(const char* c_name) {
450   string name(c_name != nullptr ? c_name : "");
451   if (name != last_set_device_name_) {
452     if (!DeviceNameUtils::ParseFullName(name, &device_parsed_name_)) {
453       return errors::InvalidArgument("Malformed device specification '", name,
454                                      "' in eager op: ", DebugString());
455     }
456     last_set_device_name_ = name;
457     device_name_ = DeviceNameUtils::ParsedNameToString(device_parsed_name_);
458     device_ = kVariantDeviceNull;
459   }
460   return Status::OK();
461 }
462 
IsLocal() const463 bool EagerOperation::IsLocal() const {
464   if (ctx_.remote_device_mgr() == nullptr) return true;
465 
466   if (!device_parsed_name_.has_job && !device_parsed_name_.has_replica &&
467       !device_parsed_name_.has_task)
468     return true;
469   auto& host_cpu_name = ctx_.HostCPU()->parsed_name();
470   return device_parsed_name_.job == host_cpu_name.job &&
471          device_parsed_name_.replica == host_cpu_name.replica &&
472          device_parsed_name_.task == host_cpu_name.task;
473 }
474 
VariantDeviceDebugString(VariantDevice device)475 string VariantDeviceDebugString(VariantDevice device) {
476   if (device == kVariantDeviceNull) {
477     return "[]";
478   } else if (absl::holds_alternative<CustomDevice*>(device)) {
479     return absl::get<CustomDevice*>(device)->name();
480   } else {
481     return absl::get<Device*>(device)->DebugString();
482   }
483 }
GetOpAttrs() const484 const AbstractOpAttrs* EagerOperation::GetOpAttrs() const { return &attrs_; }
485 
AddAttrs(const AbstractOpAttrs * op_attrs)486 void EagerOperation::AddAttrs(const AbstractOpAttrs* op_attrs) {
487   attrs_.CopyAttributes(*(down_cast<const AttrBuilder*>(op_attrs)));
488 }
489 
DebugString() const490 string EagerOperation::DebugString() const {
491   string out;
492   VLOG(1) << "EagerOperation::DebugString() over " << this;
493 
494   strings::StrAppend(&out, "Name: ", Name(), "\n");
495   strings::StrAppend(&out, "Device Name: [", device_name_, "]\n");
496   strings::StrAppend(&out, "Device: ", VariantDeviceDebugString(Device()),
497                      "\n");
498   for (const auto& input : inputs_) {
499     VLOG(1) << "Input ptr: " << input;
500     strings::StrAppend(&out, "Input: ", input->DebugString(), "\n");
501   }
502 
503   NodeDef ndef;
504   Attrs().FillAttrValueMap(ndef.mutable_attr());
505   strings::StrAppend(&out, "Attrs: ", ndef.DebugString(), "\n");
506   return out;
507 }
508 
AddTensorHandle(ImmediateExecutionTensorHandle * h)509 void EagerOperation::AddTensorHandle(ImmediateExecutionTensorHandle* h) {
510   h->Ref();
511   inputs_.push_back(h);
512   attrs_.NumInputs(static_cast<int>(inputs_.size()));
513 }
514 }  // namespace tensorflow
515