/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/eager_operation.h"

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/host_info.h"

namespace tensorflow {

// An EagerOperation object can be reused for a different op by calling
// Clear(), and then Reset(...) with the same arguments that would have
// been provided to the constructor.
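//
// A minimal reuse sketch (the op name here is illustrative, not taken from
// this file; Reset's parameters follow the signature defined below):
//
//   op->Clear();
//   TF_RETURN_IF_ERROR(op->Reset("Identity", /*device_name=*/nullptr,
//                                /*remote=*/false, /*executor=*/nullptr,
//                                /*remote_func_params=*/absl::nullopt));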
void EagerOperation::Clear() {
  for (ImmediateExecutionTensorHandle* h : inputs_) {
    h->Unref();
  }
  inputs_.clear();
  custom_device_tensor_handles_count_ = 0;
  ClearInferenceState();
}

Status EagerOperation::SetAttrValue(const char* attr_name,
                                    const AttrValue& value) {
  MutableAttrs()->Set(attr_name, value);
  return Status::OK();
}

Status EagerOperation::SetAttrString(const char* attr_name, const char* data,
                                     size_t length) {
  MutableAttrs()->Set(attr_name, StringPiece(data, length));
  return Status::OK();
}

Status EagerOperation::SetAttrInt(const char* attr_name, int64_t value) {
  MutableAttrs()->Set(attr_name, static_cast<int64>(value));
  return Status::OK();
}

Status EagerOperation::SetAttrFloat(const char* attr_name, float value) {
  MutableAttrs()->Set(attr_name, value);
  return Status::OK();
}

Status EagerOperation::SetAttrBool(const char* attr_name, bool value) {
  MutableAttrs()->Set(attr_name, value);
  return Status::OK();
}

Status EagerOperation::SetAttrType(const char* attr_name, DataType value) {
  MutableAttrs()->Set(attr_name, value);
  return Status::OK();
}

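// Records a shape-valued attribute. A negative `num_dims` encodes an unknown
// rank; otherwise the rank must not exceed TensorShape::MaxDimensions().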
Status EagerOperation::SetAttrShape(const char* attr_name, const int64_t* dims,
                                    const int num_dims) {
  if (num_dims > TensorShape::MaxDimensions()) {
    return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
                                   num_dims,
                                   " dimensions which is over the limit of ",
                                   TensorShape::MaxDimensions(), ".");
  }

  TensorShapeProto proto;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }
  }

  MutableAttrs()->Set(attr_name, proto);

  return Status::OK();
}

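// Stores a function-valued attribute by copying the function's name and
// attributes from `value` (which must be an EagerOperation) into a
// NameAttrList.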
Status EagerOperation::SetAttrFunction(const char* attr_name,
                                       const AbstractOperation* value) {
  AttrValue attr_value;
  NameAttrList* func = attr_value.mutable_func();
  func->set_name(value->Name());
  auto* value_operation = down_cast<const EagerOperation*>(value);
  value_operation->Attrs().FillAttrValueMap(func->mutable_attr());
  MutableAttrs()->Set(attr_name, attr_value);
  return Status::OK();
}

Status EagerOperation::SetAttrFunctionName(const char* attr_name,
                                           const char* data, size_t length) {
  AttrValue attr_value;
  NameAttrList* func = attr_value.mutable_func();
  func->set_name(data, length);
  MutableAttrs()->Set(attr_name, attr_value);
  return Status::OK();
}

Status EagerOperation::SetAttrTensor(const char* attr_name,
                                     AbstractTensorInterface* tensor) {
  Tensor t = TensorFromInterface(tensor);
  MutableAttrs()->Set(attr_name, t);
  return Status::OK();
}

Status EagerOperation::SetAttrStringList(const char* attr_name,
                                         const void* const* values,
                                         const size_t* lengths,
                                         int num_values) {
  std::vector<StringPiece> v(num_values);
  for (int i = 0; i < num_values; ++i) {
    v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
  }
  MutableAttrs()->Set(attr_name, v);

  return Status::OK();
}

Status EagerOperation::SetAttrFloatList(const char* attr_name,
                                        const float* values, int num_values) {
  MutableAttrs()->Set(attr_name,
                      gtl::ArraySlice<const float>(values, num_values));
  return Status::OK();
}

Status EagerOperation::SetAttrIntList(const char* attr_name,
                                      const int64_t* values, int num_values) {
  MutableAttrs()->Set(attr_name,
                      gtl::ArraySlice<const int64>(
                          reinterpret_cast<const int64*>(values), num_values));
  return Status::OK();
}

Status EagerOperation::SetAttrTypeList(const char* attr_name,
                                       const DataType* values, int num_values) {
  MutableAttrs()->Set(attr_name,
                      gtl::ArraySlice<const DataType>(values, num_values));
  return Status::OK();
}

Status EagerOperation::SetAttrBoolList(const char* attr_name,
                                       const unsigned char* values,
                                       int num_values) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  MutableAttrs()->Set(attr_name,
                      gtl::ArraySlice<const bool>(b.get(), num_values));
  return Status::OK();
}

Status EagerOperation::SetAttrShapeList(const char* attr_name,
                                        const int64_t** dims,
                                        const int* num_dims, int num_values) {
  std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > TensorShape::MaxDimensions()) {
      return errors::InvalidArgument(
          strings::StrCat("Value specified for `", attr_name, "` has ",
                          num_dims_i, " dimensions which is over the limit of ",
                          TensorShape::MaxDimensions(), "."));
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  MutableAttrs()->Set(
      attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
  return Status::OK();
}

Status EagerOperation::SetAttrFunctionList(
    const char* attr_name, absl::Span<const AbstractOperation*> values) {
  size_t num_values = values.size();
  std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
  for (size_t i = 0; i < num_values; i++) {
    auto* value_operation = down_cast<const EagerOperation*>(values[i]);
    funcs[i].set_name(value_operation->Name());
    value_operation->Attrs().FillAttrValueMap(funcs[i].mutable_attr());
  }
  MutableAttrs()->Set(
      attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
  return Status::OK();
}

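// Returns the cached OpDef if one is already known; otherwise looks it up by
// op name, reporting any lookup failure through `status`.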
const OpDef* EagerOperation::GetOpDef(Status* status) {
  const tensorflow::OpDef* op_def = OpDef();
  if (op_def) return op_def;
  *status = OpDefForOp(Name(), &op_def);
  return op_def;
}

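// Computes the number of tensors expected for the named input argument by
// consulting the op definition's name ranges (OutputLength below does the
// same for outputs).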
Status EagerOperation::InputLength(const char* input_name, int* length) {
  Status status;
  const tensorflow::OpDef* op_def = GetOpDef(&status);
  if (!status.ok()) {
    return status;
  }
  AttrValueMap attrs;
  Attrs().FillAttrValueMap(&attrs);
  NameRangeMap name_ranges;
  TF_RETURN_IF_ERROR(
      NameRangesForNode(AttrSlice(&attrs), *op_def, &name_ranges, nullptr));
  auto iter = name_ranges.find(input_name);
  if (iter == name_ranges.end()) {
    return errors::InvalidArgument("Input '", input_name, "' not found");
  }
  *length = iter->second.second - iter->second.first;
  return Status::OK();
}

absl::Span<ImmediateExecutionTensorHandle* const> EagerOperation::GetInputs()
    const {
  // TODO(b/162536003): Remove reinterpret_cast.
  return absl::MakeSpan(
      reinterpret_cast<ImmediateExecutionTensorHandle* const*>(inputs_.data()),
      inputs_.size());
}

Status EagerOperation::OutputLength(const char* output_name, int* length) {
  Status status;
  const tensorflow::OpDef* op_def = GetOpDef(&status);
  if (!status.ok()) {
    return status;
  }
  AttrValueMap attrs;
  Attrs().FillAttrValueMap(&attrs);
  NameRangeMap name_ranges;
  TF_RETURN_IF_ERROR(
      NameRangesForNode(AttrSlice(&attrs), *op_def, nullptr, &name_ranges));
  auto iter = name_ranges.find(output_name);
  if (iter == name_ranges.end()) {
    return errors::InvalidArgument("Output '", output_name, "' not found");
  }
  *length = iter->second.second - iter->second.first;
  return Status::OK();
}

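// Appends a single input handle, taking a reference on it, keeps the
// custom-device handle count up to date, and tries to infer type attributes
// from the handle's dtype.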
Status EagerOperation::AddInput(AbstractTensorHandle* input) {
  ImmediateExecutionTensorHandle* h =
      down_cast<ImmediateExecutionTensorHandle*>(input);
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
  // here.
  if (CustomDeviceTensorHandle::classof(h)) {
    custom_device_tensor_handles_count_++;
  }
  AddTensorHandle(h);
  return MaybeInferSingleInputAttrs(h);
}

Status EagerOperation::AddInputList(
    absl::Span<AbstractTensorHandle* const> inputs) {
  for (auto& input : inputs) {
    // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
    // here.
    if (CustomDeviceTensorHandle::classof(input)) {
      custom_device_tensor_handles_count_++;
    }
    ImmediateExecutionTensorHandle* h =
        down_cast<ImmediateExecutionTensorHandle*>(input);
    AddTensorHandle(h);
  }
  return InferInputListAttrs(inputs.size());
}

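// Replaces the input at `index`, adjusting the custom-device handle count and
// swapping reference counts so the previous handle is released only after the
// new one is retained.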
Status EagerOperation::SetInput(size_t index,
                                ImmediateExecutionTensorHandle* input) {
  if (index >= inputs_.size()) {
    return errors::InvalidArgument("Index >= inputs.size: ", index,
                                   " >= ", inputs_.size());
  }
  auto* previous = inputs_[index];
  if (CustomDeviceTensorHandle::classof(previous)) {
    custom_device_tensor_handles_count_--;
  }
  if (CustomDeviceTensorHandle::classof(input)) {
    custom_device_tensor_handles_count_++;
  }
  input->Ref();
  inputs_[index] = input;
  previous->Unref();
  return Status::OK();
}

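// Re-initializes this operation for op `op` on `device_name`. Expects inputs_
// to be empty (i.e. Clear() was called first) and re-resolves the attribute
// types, OpDef, colocation exemption, executor, and device name.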
Status EagerOperation::Reset(
    const char* op, const char* device_name, bool remote,
    EagerExecutor* executor,
    const absl::optional<EagerRemoteFunctionParams> remote_func_params) {
  DCHECK(inputs_.empty());
  ClearInferenceState();
  bool is_function = false;
  TF_RETURN_IF_ERROR(AttrTypeMapForOp(op, &attr_types_, &is_function));

  // Don't update the device of direct function calls.
  // Particularly, if the user did not explicitly request any device for this
  // function, picking a device would result in this device being the default
  // for nodes inside the function. This is undesirable for multi-device
  // functions since the not-explicitly-placed nodes inside the body will all
  // end up on this default device.
  colocation_exempt_ = is_function;
  if (!is_function) {
    const auto& exempt_ops = InputColocationExemptionRegistry::Global()->Get();
    colocation_exempt_ = exempt_ops.find(op) != exempt_ops.end();

    TF_RETURN_IF_ERROR(OpDefForOp(op, &op_def_));
  } else if (!remote && !ctx_.FindFunctionByName(op)) {
    return errors::NotFound(
        "'", op,
        "' is neither a type of a primitive operation nor a name "
        "of a function registered in binary running on ",
        port::Hostname(),
        ". Make sure the operation or function is "
        "registered in the binary running in this process.");
  }
  attrs_.Reset(op);
  stack_trace_.reset();
  is_function_ = is_function;
  cancellation_manager_ = nullptr;
  executor_ = executor ? executor : &ctx_.Executor();
  remote_func_params_ = remote_func_params;
  op_name_ = op;
  return SetDeviceName(device_name);
}

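// Infers the type attribute associated with a single (non-list) input from
// the handle's dtype, unless the caller already set it explicitly.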
Status EagerOperation::MaybeInferSingleInputAttrs(
    ImmediateExecutionTensorHandle* handle) {
  if (!op_def_) return Status::OK();

  const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
  if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) {
    // Some clients that are still setting their input attributes manually are
    // adding input list to their op by calling `TFE_OpAddInput` for each of
    // its elements instead of calling `TFE_OpAddInputList`. When this happens,
    // we cannot detect the end of such list, thus lose track of the input
    // arguments in the op definition. To guarantee backward compatibility with
    // those clients, disable automatic inference in this case.
    ClearInferenceState();
    return Status::OK();
  }
  const std::string& type_attr = input_def.type_attr();
  if (!type_attr.empty() &&
      inference_attrs_.find(type_attr) == inference_attrs_.end()) {
    MutableAttrs()->Set(type_attr, handle->DataType());
    inference_attrs_.insert(type_attr);
  }
  return Status::OK();
}

void EagerOperation::InferSingleTypeInputListAttrs(
    const OpDef::ArgDef& input_def, const DataType dtype, int num_inputs) {
  if (inference_attrs_.find(input_def.number_attr()) ==
      inference_attrs_.end()) {
    MutableAttrs()->Set(input_def.number_attr(), num_inputs);
    inference_attrs_.insert(input_def.number_attr());
  }
  if (inference_attrs_.find(input_def.type_attr()) == inference_attrs_.end()) {
    MutableAttrs()->Set(input_def.type_attr(), dtype);
    inference_attrs_.insert(input_def.type_attr());
  }
}

void EagerOperation::InferMixedTypeInputListAttrs(
    const OpDef::ArgDef& input_def, const std::vector<DataType>& dtypes) {
  if (inference_attrs_.find(input_def.type_list_attr()) ==
      inference_attrs_.end()) {
    MutableAttrs()->Set(
        input_def.type_list_attr(),
        gtl::ArraySlice<const DataType>(dtypes.data(), dtypes.size()));
    inference_attrs_.insert(input_def.type_list_attr());
  }
}

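// Infers the list-related attributes (element count and element type(s)) for
// an input list of `num_inputs` handles, based on how the op definition
// declares the list (type_list_attr vs. number_attr + type_attr).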
Status EagerOperation::InferInputListAttrs(int num_inputs) {
  if (!op_def_) return Status::OK();

  int start = inference_arg_idx_;
  const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
  if (!input_def.type_list_attr().empty()) {
    std::vector<DataType> dtypes(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      dtypes[i] = inputs_[start + i]->DataType();
    }
    InferMixedTypeInputListAttrs(input_def, dtypes);
  } else if (!input_def.type_attr().empty() &&
             !input_def.number_attr().empty()) {
    InferSingleTypeInputListAttrs(input_def, inputs_[start]->DataType(),
                                  num_inputs);
  } else if (!input_def.number_attr().empty()) {
    if (inference_attrs_.find(input_def.number_attr()) ==
        inference_attrs_.end()) {
      MutableAttrs()->Set(input_def.number_attr(), num_inputs);
      inference_attrs_.insert(input_def.number_attr());
    }
  } else {
    return errors::InvalidArgument("Invalid input list definition");
  }
  return Status::OK();
}

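// Exposes the inputs as TensorHandles. This is only valid when no input lives
// on a custom device; otherwise an Internal error is returned.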
Status EagerOperation::TensorHandleInputs(
    const absl::InlinedVector<TensorHandle*, 4>** inputs) const {
  if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
    *inputs = reinterpret_cast<const absl::InlinedVector<TensorHandle*, 4>*>(
        &inputs_);
    return Status::OK();
  } else {
    return errors::Internal("The operation unexpectedly had custom devices.");
  }
}

Status EagerOperation::MutableTensorHandleInputs(
    absl::InlinedVector<TensorHandle*, 4>** inputs) {
  if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
    *inputs =
        reinterpret_cast<absl::InlinedVector<TensorHandle*, 4>*>(&inputs_);
    return Status::OK();
  } else {
    return errors::Internal("The operation unexpectedly had custom devices.");
  }
}

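// Parses and caches the requested device name. A change of name invalidates
// the previously resolved device (device_ is reset to kVariantDeviceNull).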
Status EagerOperation::SetDeviceName(const char* c_name) {
  string name(c_name != nullptr ? c_name : "");
  if (name != last_set_device_name_) {
    if (!DeviceNameUtils::ParseFullName(name, &device_parsed_name_)) {
      return errors::InvalidArgument("Malformed device specification '", name,
                                     "' in eager op: ", DebugString());
    }
    last_set_device_name_ = name;
    device_name_ = DeviceNameUtils::ParsedNameToString(device_parsed_name_);
    device_ = kVariantDeviceNull;
  }
  return Status::OK();
}

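// An op is local if there is no remote device manager, if the requested
// device name does not specify a job/replica/task, or if it specifies the
// same job/replica/task as the host CPU device.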
bool EagerOperation::IsLocal() const {
  if (ctx_.remote_device_mgr() == nullptr) return true;

  if (!device_parsed_name_.has_job && !device_parsed_name_.has_replica &&
      !device_parsed_name_.has_task)
    return true;
  auto& host_cpu_name = ctx_.HostCPU()->parsed_name();
  return device_parsed_name_.job == host_cpu_name.job &&
         device_parsed_name_.replica == host_cpu_name.replica &&
         device_parsed_name_.task == host_cpu_name.task;
}

string VariantDeviceDebugString(VariantDevice device) {
  if (device == kVariantDeviceNull) {
    return "[]";
  } else if (absl::holds_alternative<CustomDevice*>(device)) {
    return absl::get<CustomDevice*>(device)->name();
  } else {
    return absl::get<Device*>(device)->DebugString();
  }
}

const AbstractOpAttrs* EagerOperation::GetOpAttrs() const { return &attrs_; }

void EagerOperation::AddAttrs(const AbstractOpAttrs* op_attrs) {
  attrs_.CopyAttributes(*(down_cast<const AttrBuilder*>(op_attrs)));
}

string EagerOperation::DebugString() const {
  string out;
  VLOG(1) << "EagerOperation::DebugString() over " << this;

  strings::StrAppend(&out, "Name: ", Name(), "\n");
  strings::StrAppend(&out, "Device Name: [", device_name_, "]\n");
  strings::StrAppend(&out, "Device: ", VariantDeviceDebugString(Device()),
                     "\n");
  for (const auto& input : inputs_) {
    VLOG(1) << "Input ptr: " << input;
    strings::StrAppend(&out, "Input: ", input->DebugString(), "\n");
  }

  NodeDef ndef;
  Attrs().FillAttrValueMap(ndef.mutable_attr());
  strings::StrAppend(&out, "Attrs: ", ndef.DebugString(), "\n");
  return out;
}

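// Takes a reference on `h`, appends it to the input list, and keeps the
// attribute builder's recorded input count in sync.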
void EagerOperation::AddTensorHandle(ImmediateExecutionTensorHandle* h) {
  h->Ref();
  inputs_.push_back(h);
  attrs_.NumInputs(static_cast<int>(inputs_.size()));
}
}  // namespace tensorflow