1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/function_api_info.h"
17
18 #include <string>
19 #include <unordered_set>
20 #include <vector>
21
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/function.pb.h"
24 #include "tensorflow/core/framework/op_def.pb.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/platform/test.h"
27
28 namespace tensorflow {
29 namespace grappler {
30 namespace {
SetArg(const string & name,const string & type_name,OpDef::ArgDef * arg_def)31 void SetArg(const string& name, const string& type_name,
32 OpDef::ArgDef* arg_def) {
33 arg_def->set_name(name);
34 arg_def->set_type_attr(type_name);
35 }
36
37 typedef std::pair<string, string> ArgSpec; // name, type.
38
SetArgs(const std::vector<ArgSpec> & input_args_spec,const std::vector<ArgSpec> & output_args_spec,OpDef * sig)39 void SetArgs(const std::vector<ArgSpec>& input_args_spec,
40 const std::vector<ArgSpec>& output_args_spec, OpDef* sig) {
41 for (const auto& arg_spec : input_args_spec)
42 SetArg(arg_spec.first, arg_spec.second, sig->add_input_arg());
43 for (const auto& arg_spec : output_args_spec)
44 SetArg(arg_spec.first, arg_spec.second, sig->add_output_arg());
45 }
46
PopulateFunction(const string & name,const string & api_interface_name,const string & preferred_device,const std::vector<ArgSpec> & input_args,const std::vector<ArgSpec> & output_args,const string & forward_function_name,const string & backward_function_name,FunctionDef * func_def)47 void PopulateFunction(const string& name, const string& api_interface_name,
48 const string& preferred_device,
49 const std::vector<ArgSpec>& input_args,
50 const std::vector<ArgSpec>& output_args,
51 const string& forward_function_name,
52 const string& backward_function_name,
53 FunctionDef* func_def) {
54 OpDef* sig = func_def->mutable_signature();
55 sig->set_name(name);
56
57 SetArgs(input_args, output_args, sig);
58
59 auto* func_attr = func_def->mutable_attr();
60 if (!api_interface_name.empty())
61 (*func_attr)["api_implements"].set_s(api_interface_name);
62 if (!preferred_device.empty())
63 (*func_attr)["api_preferred_device"].set_s(preferred_device);
64 if (!forward_function_name.empty())
65 (*func_attr)["forward_function_name"].set_s(forward_function_name);
66 if (!backward_function_name.empty())
67 (*func_attr)["backward_function_name"].set_s(backward_function_name);
68 }
69
PopulateSampleLibrary(const bool mismatch_args,FunctionDefLibrary * func_lib)70 void PopulateSampleLibrary(const bool mismatch_args,
71 FunctionDefLibrary* func_lib) {
72 const std::vector<ArgSpec> func_args{{"in1", "float32"}, {"in2", "int32"}};
73 const std::vector<ArgSpec> func_wrong_args{{"in1", "int32"},
74 {"in2", "int32"}};
75 const std::vector<ArgSpec> output_args{{"out", "float32"}};
76 PopulateFunction("DoStuffCpu", "DoStuff", "CPU", func_args, output_args, "",
77 "", func_lib->add_function());
78 PopulateFunction("DoStuffGpu", "DoStuff", "GPU",
79 mismatch_args ? func_wrong_args : func_args, output_args, "",
80 "", func_lib->add_function());
81 PopulateFunction("DoThings", "DoThings", "", func_args, output_args, "", "",
82 func_lib->add_function());
83 PopulateFunction("OneOff", "", "", func_args, output_args, "", "",
84 func_lib->add_function());
85 PopulateFunction("AnotherOneOff", "", "", func_args, output_args, "", "",
86 func_lib->add_function());
87 }
88
PopulateComplexLibrary(FunctionDefLibrary * func_lib)89 void PopulateComplexLibrary(FunctionDefLibrary* func_lib) {
90 const std::vector<ArgSpec> input_args{{"in1", "float32"}, {"in2", "int32"}};
91 const std::vector<ArgSpec> output_args{{"out", "float32"}};
92 const std::vector<ArgSpec> output_with_state{
93 {"out", "float32"}, {"state1", "int32"}, {"state2", "int32"}};
94
95 PopulateFunction("DoStuffCpu", "DoStuff", "CPU", input_args, output_args, "",
96 "DoStuffCpu_gradient", func_lib->add_function());
97 PopulateFunction("DoStuffCpu_gradient", "DoStuff", "CPU", output_args,
98 input_args, "DoStuffCpu", "", func_lib->add_function());
99 PopulateFunction("DoStuffGpu", "DoStuff", "GPU", input_args,
100 output_with_state, "", "DoStuffGpu_gradient",
101 func_lib->add_function());
102 PopulateFunction("DoStuffGpu_gradient", "DoStuff", "GPU", output_with_state,
103 input_args, "DoStuffGpu", "", func_lib->add_function());
104 }
105
CheckEquivImpl(const FunctionLibraryApiInfo & lib_api_info,const string & func_name,const std::vector<string> & expected_other)106 bool CheckEquivImpl(const FunctionLibraryApiInfo& lib_api_info,
107 const string& func_name,
108 const std::vector<string>& expected_other) {
109 std::vector<string> other_impl;
110 Status status =
111 lib_api_info.GetEquivalentImplementations(func_name, &other_impl);
112 EXPECT_EQ(status, Status::OK());
113 const std::unordered_set<string> actual(other_impl.begin(), other_impl.end());
114 const std::unordered_set<string> expected(expected_other.begin(),
115 expected_other.end());
116 return actual == expected;
117 }
118
// Returns the interface name recorded in `lib_api_info` for `func_name`.
// The lookup result is passed through CHECK_NOTNULL before being dereferenced.
string GetInterfaceName(const FunctionLibraryApiInfo& lib_api_info,
                        const string& func_name) {
  auto* const info = lib_api_info.GetApiInfo(func_name);
  CHECK_NOTNULL(info);
  return info->interface_name();
}
125
// Returns the preferred device recorded in `lib_api_info` for `func_name`.
// The lookup result is passed through CHECK_NOTNULL before being dereferenced.
string GetPreferredDevice(const FunctionLibraryApiInfo& lib_api_info,
                          const string& func_name) {
  auto* const info = lib_api_info.GetApiInfo(func_name);
  CHECK_NOTNULL(info);
  return info->preferred_device();
}
132
// Checks the interface name, preferred device and equivalent implementations
// reported for the sample library built without mismatched arguments.
TEST(FunctionApiInfoTest, ParseTags) {
  FunctionDefLibrary func_lib;
  PopulateSampleLibrary(/* mismatch_args */ false, &func_lib);
  FunctionLibraryApiInfo lib_api_info;
  TF_ASSERT_OK(lib_api_info.Init(func_lib));

  // DoStuffCpu and DoStuffGpu share the "DoStuff" interface, so each one is
  // the other's only equivalent implementation.
  EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffCpu"));
  EXPECT_EQ("CPU", GetPreferredDevice(lib_api_info, "DoStuffCpu"));
  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffCpu", {"DoStuffGpu"}));

  EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffGpu"));
  EXPECT_EQ("GPU", GetPreferredDevice(lib_api_info, "DoStuffGpu"));
  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffGpu", {"DoStuffCpu"}));

  // DoThings is the sole implementation of its interface and has no preferred
  // device.
  EXPECT_EQ("DoThings", GetInterfaceName(lib_api_info, "DoThings"));
  EXPECT_EQ("", GetPreferredDevice(lib_api_info, "DoThings"));
  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoThings", {}));

  // Functions without an interface, and names missing from the library, have
  // no equivalent implementations.
  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "OneOff", {}));
  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "AnotherOneOff", {}));
  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "Undefined", {}));
}
154
// Checks the library containing forward/gradient function pairs: every
// function reports the "DoStuff" interface and its own device, and equivalent
// implementations pair the CPU and GPU functions and, separately, their
// gradients.
TEST(FunctionApiInfoTest, ComplexFunctionLib) {
  FunctionDefLibrary func_lib;
  PopulateComplexLibrary(&func_lib);
  FunctionLibraryApiInfo lib_api_info;
  TF_ASSERT_OK(lib_api_info.Init(func_lib));

  // CPU and GPU functions.
  EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffCpu"));
  EXPECT_EQ("CPU", GetPreferredDevice(lib_api_info, "DoStuffCpu"));
  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffCpu", {"DoStuffGpu"}));

  EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffGpu"));
  EXPECT_EQ("GPU", GetPreferredDevice(lib_api_info, "DoStuffGpu"));
  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffGpu", {"DoStuffCpu"}));

  // Their gradient functions.
  EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffCpu_gradient"));
  EXPECT_EQ("CPU", GetPreferredDevice(lib_api_info, "DoStuffCpu_gradient"));
  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffCpu_gradient",
                             {"DoStuffGpu_gradient"}));

  EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffGpu_gradient"));
  EXPECT_EQ("GPU", GetPreferredDevice(lib_api_info, "DoStuffGpu_gradient"));
  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffGpu_gradient",
                             {"DoStuffCpu_gradient"}));

  // A name that is not in the library has no equivalents.
  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "Undefined", {}));
}
179
// Init is expected to return a non-OK status when two functions tagged with
// the same interface (DoStuffCpu / DoStuffGpu) are built with different input
// signatures.
TEST(FunctionApiInfoTest, MismatchedArguments) {
  FunctionDefLibrary func_lib;
  PopulateSampleLibrary(/* mismatch_args */ true, &func_lib);

  FunctionLibraryApiInfo lib_api_info;
  const Status ret(lib_api_info.Init(func_lib));
  EXPECT_FALSE(ret.ok());
}
187
188 } // namespace
189 } // namespace grappler
190 } // namespace tensorflow
191