• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_IMPLEMENTATION_SELECTOR_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_IMPLEMENTATION_SELECTOR_H_
18 
19 #include <string>
20 
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/grappler/costs/graph_properties.h"
23 #include "tensorflow/core/grappler/grappler_item.h"
24 #include "tensorflow/core/grappler/op_types.h"
25 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
26 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
27 #include "tensorflow/core/grappler/optimizers/function_api_info.h"
28 #include "tensorflow/core/grappler/utils/graph_view.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/stringpiece.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/util/device_name_utils.h"
33 
34 namespace tensorflow {
35 namespace grappler {
36 
37 // Motivation: To achieve the same high level functionality, the underlying
38 // implementations sometimes are different for various devices where the
39 // function runs. In order to achieve the correct result and best performance,
40 // the proper implementation needs to be picked dynamically.
41 //
42 // Currently there are two approaches to do this.
43 // (1) Use a Case op and dynamically change its branch index.
44 // (2) Swap the function implementation (this approach will be deprecated).
45 //
46 // Idea for approach 1.
47 // This transformation replaces the DeviceIndex op with a Const op whose value
48 // is the index of the device on which the associated Case op runs.
49 // Example:
50 // def plus_one_gpu(x): return x + 1.0
51 // def plus_one_reference_implementation(x): return x + 1.0
52 // input = tf.constant(2.0, dtype=tf.float32)
53 // cpu_fn = lambda:plus_one_reference_implementation(input)
54 // gpu_fn = lambda:plus_one_gpu(input)
55 // control_flow_ops.execute_fn_for_device(
56 //     {"CPU": cpu_fn, "GPU": gpu_fn}, default_fn=cpu_fn)
57 //
58 // Idea for approach 2.
59 // This transformation replaces function calls by the appropriate function
60 // definition based on properties of the runtime system. For instance,
61 // we may choose one implementation over another if we have a GPU with
62 // enough memory available.
63 //
64 // It is a way for the programmer to specify alternative implementations
65 // of the same functionality in the graph, and let TensorFlow pick the
66 // most appropriate one at runtime.
67 //
68 // For instance, the python code might specify:
69 // @Defun(tf.float32,
70 //        api_implements='plus_one',
71 //        api_preferred_device='GPU')
72 // def plus_one_gpu(x): return x + 1.0
73 //
74 // @Defun(tf.float32,
75 //        api_implements='plus_one')
76 // def plus_one_reference_implementation(x): return x + 1.0
77 // input = tf.constant(2.0, dtype=tf.float32)
78 //
79 // z = plus_one_reference_implementation(input)
80 // z = plus_one_gpu(input)
81 // print(sess.run(z))
82 //
83 
84 // At runtime, we will select either `plus_one_gpu` or
85 // `plus_one_reference_implementation` based on the availability of the GPU.
86 //
87 // Available annotations:
88 //  - api_implements(string): all functions mapping to the same
89 //    string can be interchanged. For now, all functions must have the same
90 //    signature and overloads are not allowed. Defuns within defuns are
91 //    allowed.
92 //  - api_preferred_device(string): sets which device is preferred.
93 class ImplementationSelector : public CustomGraphOptimizer {
94  public:
95   ImplementationSelector() = default;
96   ~ImplementationSelector() override = default;
Init(const tensorflow::RewriterConfig_CustomGraphOptimizer * config)97   Status Init(
98       const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
99     return Status::OK();
100   }
name()101   string name() const override {
102     return "implementation_selector";
103   }
104 
UsesFunctionLibrary()105   bool UsesFunctionLibrary() const override { return false; }
106 
107   // This call is not thread-safe.
108   Status Optimize(Cluster* cluster, const GrapplerItem& item,
109                   GraphDef* optimized_graph) override;
110 
111   // Does not take any feedback.
Feedback(Cluster * cluster,const GrapplerItem & item,const GraphDef & optimized_graph,double result)112   void Feedback(Cluster* cluster, const GrapplerItem& item,
113                 const GraphDef& optimized_graph, double result) override {}
114 
115  private:
116   Status LoadFunctions(const GraphDef& graph);
117   Status MaybeOptimizeFunctionCall(utils::MutableNodeView* node_view) const;
118 
119   // Finds all call sites for functions, then replace with the appropriate
120   // implementation.
121   // There are two ways of calling functions:
122   //  1. By specifying an op name as a function name, and
123   //  2. Via the functional interface, where the function name appears as an
124   //  Attr.
125   //
126   // There may be multiple call sites for a given function. The function body
127   // may call into another function, so a function might have to be duplicated.
128   // For simplicity, we do not change function bodies. Also, we do not change
129   // gradients.
130   Status SelectImplementation(GraphDef* graph) const;
131 
132   // Rewrites the DeviceIndex op with a Const op with value of the index of the
133   // device the associcated Case op runs.
134 
135   // This function first looks up all the DeviceIndex ops.
136   // Then for each of these ops, it finds the device of the
137   // associated Case op that takes the DeviceIndex op as the input, and
138   // caculates the index of the device in the device list of DeviceIndex op.
139   // Lastly, it rewrites the DeviceIndex op with a Const op and sets the value
140   // to be the index.
141   //
142   // Example input nodes:
143   // node {
144   //   name: "x"
145   //   op: "DeviceIndex"
146   //   device: "/device:CPU:0"
147   //   attr {
148   //     key: "device_names"
149   //     value {
150   //       list {
151   //         s: "CPU"
152   //         s: "TPU_REPLICATED_CORE"
153   //         s: "GPU"
154   //       }
155   //     }
156   //   }
157   // }
158   // node {
159   //   name: "case"
160   //   op: "Case"
161   //   input: "x"
162   //   device: "/device:GPU:0"
163   //   ...
164   // }
165   // Example output nodes:
166   //
167   //  name: "x"
168   //  op: "Const"
169   //  device: "/device:CPU:0"
170   //  attr {
171   //    key: "dtype"
172   //    value {
173   //      type: DT_INT32
174   //    }
175   //  }
176   //  attr {
177   //    key: "value"
178   //    value {
179   //      tensor {
180   //        dtype: DT_INT32
181   //        int_val: 2
182   //      }
183   //    }
184   //  }
185   // node {
186   //   name: "case"
187   //   op: "Case"
188   //   input: "x"
189   //   device: "/device:GPU:0"
190   //   ...
191   // }
192   Status SelectDeviceIndex(GraphDef* graph) const;
193 
194   std::unique_ptr<FunctionLibraryApiInfo> lib_info_;
195 
196   TF_DISALLOW_COPY_AND_ASSIGN(ImplementationSelector);
197 };
198 
199 }  // namespace grappler
200 }  // namespace tensorflow
201 
202 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_IMPLEMENTATION_SELECTOR_H_
203