1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_INSPECTING_PLACER_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_INSPECTING_PLACER_H_ 18 19 #include <vector> 20 21 #include "absl/strings/str_join.h" 22 #include "tensorflow/core/common_runtime/device.h" 23 #include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h" 24 #include "tensorflow/core/framework/function.h" 25 #include "tensorflow/core/framework/types.h" 26 #include "tensorflow/core/lib/core/stringpiece.h" 27 #include "tensorflow/core/util/device_name_utils.h" 28 #include "tensorflow/core/util/port.h" 29 30 namespace tensorflow { 31 32 // TODO(iga): Convert this struct into a class to ensure invariants between 33 // device names, i.e. 34 // DeviceNameUtils::IsSpecification(resource_device_name, 35 // requested_device_name) 36 // PossibleDevices does not contain assigned_device_name because we don't 37 // assign devices to nested functions. 38 struct PossibleDevices { 39 // The same as Member::requested_device_name_ in colocation_graph.cc. 40 DeviceNameUtils::ParsedName requested_device_name; 41 42 // The same as Member::resource_device_name_ in colocation_graph.cc. 43 DeviceNameUtils::ParsedName resource_device_name; 44 45 // A device type outside of this set will not be supported by some 46 // internal op. 47 PrioritizedDeviceTypeVector device_types; 48 }; 49 50 // A struct for communicating constraints on devices that can 51 // be chosen for inputs and outputs of an op requiring deep placer inspection. 52 struct IOColocationGroups { 53 // input_groups[i] contains the group id that i'th input belongs to. 54 // List inputs are not supported. 55 std::vector<int> input_groups; 56 // output_groups[i] contains the group id that i'th output belongs to. 57 // List inputs are not supported. 58 std::vector<int> output_groups; 59 // group_devices[i] contains possible devices for group with id i. 60 std::vector<PossibleDevices> group_devices; 61 62 string DebugString() const; 63 }; 64 65 class InspectingPlacer { 66 public: 67 // graph and device_set must not be null and must outlive this 68 // InspectingPlacer. default_device can be null. If not, must outlive this. 69 // TODO(iga): Add a "stack trace" to detect recursion and improve log 70 // messages. Currently, we will enter an infinite loop for recursive 71 // functions. 72 InspectingPlacer(const FunctionStack& stack, 73 const FunctionLibraryDefinition* flib_def, 74 const DeviceSet* device_set, const Device* default_device, 75 bool allow_soft_placement, bool log_device_placement); 76 77 // `node` must be 78 // PlacerInspectionRequiredOpsChecker::IsPlacerInspectionRequired. 79 Status ComputeIOColocationGroups(const Node& node, 80 IOColocationGroups* groups); 81 82 private: 83 const FunctionStack stack_; 84 const FunctionLibraryDefinition& flib_def_; 85 const DeviceSet& device_set_; 86 const Device* default_device_; 87 const bool allow_soft_placement_; 88 const bool log_device_placement_; 89 90 TF_DISALLOW_COPY_AND_ASSIGN(InspectingPlacer); 91 }; 92 93 } // namespace tensorflow 94 95 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_INSPECTING_PLACER_H_ 96