1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
16
17 #include "absl/types/span.h"
18 #include "tensorflow/c/eager/abstract_operation.h"
19 #include "tensorflow/c/eager/abstract_tensor_handle.h"
20 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
21 #include "tensorflow/c/tf_tensor_internal.h"
22 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
23 #include "tensorflow/core/common_runtime/eager/custom_device.h"
24 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
25 #include "tensorflow/core/platform/casts.h"
26 #include "tensorflow/core/platform/errors.h"
27 #include "tensorflow/core/platform/host_info.h"
28
29 namespace tensorflow {
30
31 // An EagerOperation object can be reused for a different op by calling
32 // Clear(), and then Reset(...) with the same arguments that would have
33 // been provided to the constructor.
Clear()34 void EagerOperation::Clear() {
35 for (ImmediateExecutionTensorHandle* h : inputs_) {
36 h->Unref();
37 }
38 inputs_.clear();
39 custom_device_tensor_handles_count_ = 0;
40 ClearInferenceState();
41 }
42
SetAttrValue(const char * attr_name,const AttrValue & value)43 Status EagerOperation::SetAttrValue(const char* attr_name,
44 const AttrValue& value) {
45 MutableAttrs()->Set(attr_name, value);
46 return OkStatus();
47 }
48
SetAttrString(const char * attr_name,const char * data,size_t length)49 Status EagerOperation::SetAttrString(const char* attr_name, const char* data,
50 size_t length) {
51 MutableAttrs()->Set(attr_name, StringPiece(data, length));
52 return OkStatus();
53 }
54
SetAttrInt(const char * attr_name,int64_t value)55 Status EagerOperation::SetAttrInt(const char* attr_name, int64_t value) {
56 MutableAttrs()->Set(attr_name, static_cast<int64_t>(value));
57 return OkStatus();
58 }
59
SetAttrFloat(const char * attr_name,float value)60 Status EagerOperation::SetAttrFloat(const char* attr_name, float value) {
61 MutableAttrs()->Set(attr_name, value);
62 return OkStatus();
63 }
64
SetAttrBool(const char * attr_name,bool value)65 Status EagerOperation::SetAttrBool(const char* attr_name, bool value) {
66 MutableAttrs()->Set(attr_name, value);
67 return OkStatus();
68 }
69
SetAttrType(const char * attr_name,DataType value)70 Status EagerOperation::SetAttrType(const char* attr_name, DataType value) {
71 MutableAttrs()->Set(attr_name, value);
72 return OkStatus();
73 }
74
SetAttrShape(const char * attr_name,const int64_t * dims,const int num_dims)75 Status EagerOperation::SetAttrShape(const char* attr_name, const int64_t* dims,
76 const int num_dims) {
77 if (num_dims > TensorShape::MaxDimensions()) {
78 return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
79 num_dims,
80 " dimensions which is over the limit of ",
81 TensorShape::MaxDimensions(), ".");
82 }
83
84 TensorShapeProto proto;
85 if (num_dims < 0) {
86 proto.set_unknown_rank(true);
87 } else {
88 for (int d = 0; d < num_dims; ++d) {
89 proto.add_dim()->set_size(dims[d]);
90 }
91 }
92
93 MutableAttrs()->Set(attr_name, proto);
94
95 return OkStatus();
96 }
97
SetAttrFunction(const char * attr_name,const AbstractOperation * value)98 Status EagerOperation::SetAttrFunction(const char* attr_name,
99 const AbstractOperation* value) {
100 AttrValue attr_value;
101 NameAttrList* func = attr_value.mutable_func();
102 func->set_name(value->Name());
103 auto* value_operation = down_cast<const EagerOperation*>(value);
104 value_operation->Attrs().FillAttrValueMap(func->mutable_attr());
105 MutableAttrs()->Set(attr_name, attr_value);
106 return OkStatus();
107 }
108
SetAttrFunctionName(const char * attr_name,const char * data,size_t length)109 Status EagerOperation::SetAttrFunctionName(const char* attr_name,
110 const char* data, size_t length) {
111 AttrValue attr_value;
112 NameAttrList* func = attr_value.mutable_func();
113 func->set_name(data, length);
114 MutableAttrs()->Set(attr_name, attr_value);
115 return OkStatus();
116 }
117
SetAttrTensor(const char * attr_name,AbstractTensorInterface * tensor)118 Status EagerOperation::SetAttrTensor(const char* attr_name,
119 AbstractTensorInterface* tensor) {
120 Tensor t = TensorFromInterface(tensor);
121 MutableAttrs()->Set(attr_name, t);
122 return OkStatus();
123 }
124
SetAttrStringList(const char * attr_name,const void * const * values,const size_t * lengths,int num_values)125 Status EagerOperation::SetAttrStringList(const char* attr_name,
126 const void* const* values,
127 const size_t* lengths,
128 int num_values) {
129 std::vector<StringPiece> v(num_values);
130 for (int i = 0; i < num_values; ++i) {
131 v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
132 }
133 MutableAttrs()->Set(attr_name, v);
134
135 return OkStatus();
136 }
137
SetAttrFloatList(const char * attr_name,const float * values,int num_values)138 Status EagerOperation::SetAttrFloatList(const char* attr_name,
139 const float* values, int num_values) {
140 MutableAttrs()->Set(attr_name,
141 gtl::ArraySlice<const float>(values, num_values));
142 return OkStatus();
143 }
144
SetAttrIntList(const char * attr_name,const int64_t * values,int num_values)145 Status EagerOperation::SetAttrIntList(const char* attr_name,
146 const int64_t* values, int num_values) {
147 MutableAttrs()->Set(
148 attr_name, gtl::ArraySlice<const int64_t>(
149 reinterpret_cast<const int64_t*>(values), num_values));
150 return OkStatus();
151 }
152
SetAttrTypeList(const char * attr_name,const DataType * values,int num_values)153 Status EagerOperation::SetAttrTypeList(const char* attr_name,
154 const DataType* values, int num_values) {
155 MutableAttrs()->Set(attr_name,
156 gtl::ArraySlice<const DataType>(values, num_values));
157 return OkStatus();
158 }
159
SetAttrBoolList(const char * attr_name,const unsigned char * values,int num_values)160 Status EagerOperation::SetAttrBoolList(const char* attr_name,
161 const unsigned char* values,
162 int num_values) {
163 std::unique_ptr<bool[]> b(new bool[num_values]);
164 for (int i = 0; i < num_values; ++i) {
165 b[i] = values[i];
166 }
167 MutableAttrs()->Set(attr_name,
168 gtl::ArraySlice<const bool>(b.get(), num_values));
169 return OkStatus();
170 }
171
SetAttrShapeList(const char * attr_name,const int64_t ** dims,const int * num_dims,int num_values)172 Status EagerOperation::SetAttrShapeList(const char* attr_name,
173 const int64_t** dims,
174 const int* num_dims, int num_values) {
175 std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
176 for (int i = 0; i < num_values; ++i) {
177 const auto num_dims_i = num_dims[i];
178
179 if (num_dims_i > TensorShape::MaxDimensions()) {
180 return errors::InvalidArgument(
181 strings::StrCat("Value specified for `", attr_name, "` has ",
182 num_dims_i, " dimensions which is over the limit of ",
183 TensorShape::MaxDimensions(), "."));
184 }
185 if (num_dims_i < 0) {
186 proto[i].set_unknown_rank(true);
187 } else {
188 const int64_t* dims_i = dims[i];
189 auto proto_i = &proto[i];
190 for (int d = 0; d < num_dims_i; ++d) {
191 proto_i->add_dim()->set_size(dims_i[d]);
192 }
193 }
194 }
195 MutableAttrs()->Set(
196 attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
197 return OkStatus();
198 }
199
SetAttrFunctionList(const char * attr_name,absl::Span<const AbstractOperation * > values)200 Status EagerOperation::SetAttrFunctionList(
201 const char* attr_name, absl::Span<const AbstractOperation*> values) {
202 size_t num_values = values.size();
203 std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
204 for (int i = 0; i < num_values; i++) {
205 auto* value_operation = down_cast<const EagerOperation*>(values[i]);
206 funcs[i].set_name(value_operation->Name());
207 value_operation->Attrs().FillAttrValueMap(funcs[i].mutable_attr());
208 }
209 MutableAttrs()->Set(
210 attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
211 return OkStatus();
212 }
213
GetOpDef(Status * status)214 const OpDef* EagerOperation::GetOpDef(Status* status) {
215 const tensorflow::OpDef* op_def = OpDef();
216 if (op_def) return op_def;
217 *status = OpDefForOp(Name(), &op_def);
218 return op_def;
219 }
220
InputLength(const char * input_name,int * length)221 Status EagerOperation::InputLength(const char* input_name, int* length) {
222 Status status;
223 const tensorflow::OpDef* op_def = GetOpDef(&status);
224 if (!status.ok()) {
225 return status;
226 }
227 AttrValueMap attrs;
228 Attrs().FillAttrValueMap(&attrs);
229 NameRangeMap name_ranges;
230 TF_RETURN_IF_ERROR(
231 NameRangesForNode(AttrSlice(&attrs), *op_def, &name_ranges, nullptr));
232 auto iter = name_ranges.find(input_name);
233 if (iter == name_ranges.end()) {
234 return errors::InvalidArgument("Input '", input_name, "' not found");
235 }
236 *length = iter->second.second - iter->second.first;
237 return OkStatus();
238 }
239
GetInputs() const240 absl::Span<ImmediateExecutionTensorHandle* const> EagerOperation::GetInputs()
241 const {
242 // TODO(b/162536003): Remove reinterpret_cast.
243 return absl::MakeSpan(
244 reinterpret_cast<ImmediateExecutionTensorHandle* const*>(inputs_.data()),
245 inputs_.size());
246 }
247
OutputLength(const char * output_name,int * length)248 Status EagerOperation::OutputLength(const char* output_name, int* length) {
249 Status status;
250 const tensorflow::OpDef* op_def = GetOpDef(&status);
251 if (!status.ok()) {
252 return status;
253 }
254 AttrValueMap attrs;
255 Attrs().FillAttrValueMap(&attrs);
256 NameRangeMap name_ranges;
257 TF_RETURN_IF_ERROR(
258 NameRangesForNode(AttrSlice(&attrs), *op_def, nullptr, &name_ranges));
259 auto iter = name_ranges.find(output_name);
260 if (iter == name_ranges.end()) {
261 return errors::InvalidArgument("Output '", output_name, "' not found");
262 }
263 *length = iter->second.second - iter->second.first;
264 return OkStatus();
265 }
266
AddInput(AbstractTensorHandle * input)267 Status EagerOperation::AddInput(AbstractTensorHandle* input) {
268 ImmediateExecutionTensorHandle* h =
269 down_cast<ImmediateExecutionTensorHandle*>(input);
270 // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
271 if (CustomDeviceTensorHandle::classof(h)) {
272 custom_device_tensor_handles_count_++;
273 }
274 AddTensorHandle(h);
275 return MaybeInferSingleInputAttrs(h);
276 }
277
AddInputList(absl::Span<AbstractTensorHandle * const> inputs)278 Status EagerOperation::AddInputList(
279 absl::Span<AbstractTensorHandle* const> inputs) {
280 for (auto& input : inputs) {
281 // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
282 // here.
283 if (CustomDeviceTensorHandle::classof(input)) {
284 custom_device_tensor_handles_count_++;
285 }
286 ImmediateExecutionTensorHandle* h =
287 down_cast<ImmediateExecutionTensorHandle*>(input);
288 AddTensorHandle(h);
289 }
290 return InferInputListAttrs(inputs.size());
291 }
292
SetInput(size_t index,ImmediateExecutionTensorHandle * input)293 Status EagerOperation::SetInput(size_t index,
294 ImmediateExecutionTensorHandle* input) {
295 if (index >= inputs_.size()) {
296 return errors::InvalidArgument("Index >= inputs.size: %d >= %d", index,
297 inputs_.size());
298 }
299 auto* previous = inputs_[index];
300 if (CustomDeviceTensorHandle::classof(previous)) {
301 custom_device_tensor_handles_count_--;
302 }
303 if (CustomDeviceTensorHandle::classof(input)) {
304 custom_device_tensor_handles_count_++;
305 }
306 input->Ref();
307 inputs_[index] = input;
308 previous->Unref();
309 return OkStatus();
310 }
311
Reset(const char * op,const char * device_name,bool remote,EagerExecutor * executor,const absl::optional<EagerFunctionParams> eager_func_params)312 Status EagerOperation::Reset(
313 const char* op, const char* device_name, bool remote,
314 EagerExecutor* executor,
315 const absl::optional<EagerFunctionParams> eager_func_params) {
316 DCHECK(inputs_.empty());
317 ClearInferenceState();
318 bool is_function = false;
319 TF_RETURN_IF_ERROR(AttrTypeMapForOp(op, &attr_types_, &is_function));
320
321 // Don't update the device of direct function calls.
322 // Particularly, if the user did not explicitly request any device for this
323 // function, picking a device would result in this device being the default
324 // for nodes inside the function. This is undesirable for multi-device
325 // functions since the not-explicitly-placed nodes inside the body will all
326 // end up on this default device.
327 colocation_exempt_ = is_function;
328 if (!is_function) {
329 const auto& exempt_ops = InputColocationExemptionRegistry::Global()->Get();
330 colocation_exempt_ = exempt_ops.find(op) != exempt_ops.end();
331
332 TF_RETURN_IF_ERROR(OpDefForOp(op, &op_def_));
333 } else if (!remote && !ctx_.FindFunctionByName(op)) {
334 return errors::NotFound(
335 "'", op,
336 "' is neither a type of a primitive operation nor a name "
337 "of a function registered in binary running on ",
338 port::Hostname(),
339 ". Make sure the operation or function is "
340 "registered in the binary running in this process.");
341 }
342 attrs_.Reset(op);
343 stack_trace_.reset();
344 is_function_ = is_function;
345 cancellation_manager_ = nullptr;
346 executor_ = executor ? executor : &ctx_.Executor();
347 if (eager_func_params.has_value()) {
348 eager_func_params_ = eager_func_params;
349 }
350 op_name_ = op;
351 return SetDeviceName(device_name);
352 }
353
MaybeInferSingleInputAttrs(ImmediateExecutionTensorHandle * handle)354 Status EagerOperation::MaybeInferSingleInputAttrs(
355 ImmediateExecutionTensorHandle* handle) {
356 if (!op_def_) return OkStatus();
357
358 const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
359 if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) {
360 // Some clients that are still setting their input attributes manually are
361 // adding input list to their op by calling `TFE_OpAddInput` for each of
362 // its elements instead of calling `TFE_OpAddInputList`. When this happens,
363 // we cannot detect the end of such list, thus lose track of the input
364 // arguments in the op definition. To guarantee backward compatibility with
365 // those clients, disable automatic inference in this case.
366 ClearInferenceState();
367 return OkStatus();
368 }
369 const std::string& type_attr = input_def.type_attr();
370 if (!type_attr.empty() &&
371 inference_attrs_.find(type_attr) == inference_attrs_.end()) {
372 MutableAttrs()->Set(type_attr, handle->DataType());
373 inference_attrs_.insert(type_attr);
374 }
375 return OkStatus();
376 }
377
InferSingleTypeInputListAttrs(const OpDef::ArgDef & input_def,const DataType dtype,int num_inputs)378 void EagerOperation::InferSingleTypeInputListAttrs(
379 const OpDef::ArgDef& input_def, const DataType dtype, int num_inputs) {
380 if (inference_attrs_.find(input_def.number_attr()) ==
381 inference_attrs_.end()) {
382 MutableAttrs()->Set(input_def.number_attr(), num_inputs);
383 inference_attrs_.insert(input_def.number_attr());
384 }
385 if (inference_attrs_.find(input_def.type_attr()) == inference_attrs_.end()) {
386 MutableAttrs()->Set(input_def.type_attr(), dtype);
387 inference_attrs_.insert(input_def.type_attr());
388 }
389 }
390
InferMixedTypeInputListAttrs(const OpDef::ArgDef & input_def,const std::vector<DataType> & dtypes)391 void EagerOperation::InferMixedTypeInputListAttrs(
392 const OpDef::ArgDef& input_def, const std::vector<DataType>& dtypes) {
393 if (inference_attrs_.find(input_def.type_list_attr()) ==
394 inference_attrs_.end()) {
395 MutableAttrs()->Set(
396 input_def.type_list_attr(),
397 gtl::ArraySlice<const DataType>(dtypes.data(), dtypes.size()));
398 inference_attrs_.insert(input_def.type_list_attr());
399 }
400 }
401
InferInputListAttrs(int num_inputs)402 Status EagerOperation::InferInputListAttrs(int num_inputs) {
403 if (!op_def_) return OkStatus();
404
405 int start = inference_arg_idx_;
406 const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
407 if (!input_def.type_list_attr().empty()) {
408 std::vector<DataType> dtypes(num_inputs);
409 for (int i = 0; i < num_inputs; ++i) {
410 dtypes[i] = inputs_[start + i]->DataType();
411 }
412 InferMixedTypeInputListAttrs(input_def, dtypes);
413 } else if (!input_def.type_attr().empty() &&
414 !input_def.number_attr().empty()) {
415 InferSingleTypeInputListAttrs(input_def, inputs_[start]->DataType(),
416 num_inputs);
417 } else if (!input_def.number_attr().empty()) {
418 if (inference_attrs_.find(input_def.number_attr()) ==
419 inference_attrs_.end()) {
420 MutableAttrs()->Set(input_def.number_attr(), num_inputs);
421 inference_attrs_.insert(input_def.number_attr());
422 }
423 } else {
424 return errors::InvalidArgument("Invalid input list definition");
425 }
426 return OkStatus();
427 }
428
TensorHandleInputs(const absl::InlinedVector<TensorHandle *,4> ** inputs) const429 Status EagerOperation::TensorHandleInputs(
430 const absl::InlinedVector<TensorHandle*, 4>** inputs) const {
431 if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
432 *inputs = reinterpret_cast<const absl::InlinedVector<TensorHandle*, 4>*>(
433 &inputs_);
434 return OkStatus();
435 } else {
436 return errors::Internal("The operation unexpectedly had custom devices.");
437 }
438 }
439
MutableTensorHandleInputs(absl::InlinedVector<TensorHandle *,4> ** inputs)440 Status EagerOperation::MutableTensorHandleInputs(
441 absl::InlinedVector<TensorHandle*, 4>** inputs) {
442 if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
443 *inputs =
444 reinterpret_cast<absl::InlinedVector<TensorHandle*, 4>*>(&inputs_);
445 return OkStatus();
446 } else {
447 return errors::Internal("The operation unexpectedly had custom devices.");
448 }
449 }
450
SetDeviceName(const char * c_name)451 Status EagerOperation::SetDeviceName(const char* c_name) {
452 string name(c_name != nullptr ? c_name : "");
453 if (name != last_set_device_name_) {
454 if (!DeviceNameUtils::ParseFullName(name, &device_parsed_name_)) {
455 return errors::InvalidArgument("Malformed device specification '", name,
456 "' in eager op: ", DebugString());
457 }
458 last_set_device_name_ = name;
459 device_name_ = DeviceNameUtils::ParsedNameToString(device_parsed_name_);
460 device_ = kVariantDeviceNull;
461 }
462 return OkStatus();
463 }
464
IsLocal() const465 bool EagerOperation::IsLocal() const {
466 if (ctx_.remote_device_mgr() == nullptr) return true;
467
468 if (!device_parsed_name_.has_job && !device_parsed_name_.has_replica &&
469 !device_parsed_name_.has_task)
470 return true;
471 auto& host_cpu_name = ctx_.HostCPU()->parsed_name();
472 return device_parsed_name_.job == host_cpu_name.job &&
473 device_parsed_name_.replica == host_cpu_name.replica &&
474 device_parsed_name_.task == host_cpu_name.task;
475 }
476
VariantDeviceDebugString(VariantDevice device)477 string VariantDeviceDebugString(VariantDevice device) {
478 if (device == kVariantDeviceNull) {
479 return "[]";
480 } else if (absl::holds_alternative<CustomDevice*>(device)) {
481 return absl::get<CustomDevice*>(device)->name();
482 } else {
483 return absl::get<Device*>(device)->DebugString();
484 }
485 }
GetOpAttrs() const486 const AbstractOpAttrs* EagerOperation::GetOpAttrs() const { return &attrs_; }
487
AddAttrs(const AbstractOpAttrs * op_attrs)488 void EagerOperation::AddAttrs(const AbstractOpAttrs* op_attrs) {
489 attrs_.CopyAttributes(*(down_cast<const AttrBuilder*>(op_attrs)));
490 }
491
DebugString() const492 string EagerOperation::DebugString() const {
493 string out;
494 VLOG(1) << "EagerOperation::DebugString() over " << this;
495
496 strings::StrAppend(&out, "Name: ", Name(), "\n");
497 strings::StrAppend(&out, "Device Name: [", device_name_, "]\n");
498 strings::StrAppend(&out, "Device: ", VariantDeviceDebugString(Device()),
499 "\n");
500 for (const auto& input : inputs_) {
501 VLOG(1) << "Input ptr: " << input;
502 strings::StrAppend(&out, "Input: ", input->DebugString(), "\n");
503 }
504
505 NodeDef ndef;
506 Attrs().FillAttrValueMap(ndef.mutable_attr());
507 strings::StrAppend(&out, "Attrs: ", ndef.DebugString(), "\n");
508 return out;
509 }
510
AddTensorHandle(ImmediateExecutionTensorHandle * h)511 void EagerOperation::AddTensorHandle(ImmediateExecutionTensorHandle* h) {
512 h->Ref();
513 inputs_.push_back(h);
514 attrs_.NumInputs(static_cast<int>(inputs_.size()));
515 }
516 } // namespace tensorflow
517