• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/function_api_info.h"
17 
18 #include <string>
19 #include <unordered_set>
20 #include <vector>
21 
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/function.pb.h"
24 #include "tensorflow/core/framework/op_def.pb.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/platform/test.h"
27 
28 namespace tensorflow {
29 namespace grappler {
30 namespace {
SetArg(const string & name,const string & type_name,OpDef::ArgDef * arg_def)31 void SetArg(const string& name, const string& type_name,
32             OpDef::ArgDef* arg_def) {
33   arg_def->set_name(name);
34   arg_def->set_type_attr(type_name);
35 }
36 
37 typedef std::pair<string, string> ArgSpec;  // name, type.
38 
SetArgs(const std::vector<ArgSpec> & input_args_spec,const std::vector<ArgSpec> & output_args_spec,OpDef * sig)39 void SetArgs(const std::vector<ArgSpec>& input_args_spec,
40              const std::vector<ArgSpec>& output_args_spec, OpDef* sig) {
41   for (const auto& arg_spec : input_args_spec)
42     SetArg(arg_spec.first, arg_spec.second, sig->add_input_arg());
43   for (const auto& arg_spec : output_args_spec)
44     SetArg(arg_spec.first, arg_spec.second, sig->add_output_arg());
45 }
46 
PopulateFunction(const string & name,const string & api_interface_name,const string & preferred_device,const std::vector<ArgSpec> & input_args,const std::vector<ArgSpec> & output_args,const string & forward_function_name,const string & backward_function_name,FunctionDef * func_def)47 void PopulateFunction(const string& name, const string& api_interface_name,
48                       const string& preferred_device,
49                       const std::vector<ArgSpec>& input_args,
50                       const std::vector<ArgSpec>& output_args,
51                       const string& forward_function_name,
52                       const string& backward_function_name,
53                       FunctionDef* func_def) {
54   OpDef* sig = func_def->mutable_signature();
55   sig->set_name(name);
56 
57   SetArgs(input_args, output_args, sig);
58 
59   auto* func_attr = func_def->mutable_attr();
60   if (!api_interface_name.empty())
61     (*func_attr)["api_implements"].set_s(api_interface_name);
62   if (!preferred_device.empty())
63     (*func_attr)["api_preferred_device"].set_s(preferred_device);
64   if (!forward_function_name.empty())
65     (*func_attr)["forward_function_name"].set_s(forward_function_name);
66   if (!backward_function_name.empty())
67     (*func_attr)["backward_function_name"].set_s(backward_function_name);
68 }
69 
PopulateSampleLibrary(const bool mismatch_args,FunctionDefLibrary * func_lib)70 void PopulateSampleLibrary(const bool mismatch_args,
71                            FunctionDefLibrary* func_lib) {
72   const std::vector<ArgSpec> func_args{{"in1", "float32"}, {"in2", "int32"}};
73   const std::vector<ArgSpec> func_wrong_args{{"in1", "int32"},
74                                              {"in2", "int32"}};
75   const std::vector<ArgSpec> output_args{{"out", "float32"}};
76   PopulateFunction("DoStuffCpu", "DoStuff", "CPU", func_args, output_args, "",
77                    "", func_lib->add_function());
78   PopulateFunction("DoStuffGpu", "DoStuff", "GPU",
79                    mismatch_args ? func_wrong_args : func_args, output_args, "",
80                    "", func_lib->add_function());
81   PopulateFunction("DoThings", "DoThings", "", func_args, output_args, "", "",
82                    func_lib->add_function());
83   PopulateFunction("OneOff", "", "", func_args, output_args, "", "",
84                    func_lib->add_function());
85   PopulateFunction("AnotherOneOff", "", "", func_args, output_args, "", "",
86                    func_lib->add_function());
87 }
88 
PopulateComplexLibrary(FunctionDefLibrary * func_lib)89 void PopulateComplexLibrary(FunctionDefLibrary* func_lib) {
90   const std::vector<ArgSpec> input_args{{"in1", "float32"}, {"in2", "int32"}};
91   const std::vector<ArgSpec> output_args{{"out", "float32"}};
92   const std::vector<ArgSpec> output_with_state{
93       {"out", "float32"}, {"state1", "int32"}, {"state2", "int32"}};
94 
95   PopulateFunction("DoStuffCpu", "DoStuff", "CPU", input_args, output_args, "",
96                    "DoStuffCpu_gradient", func_lib->add_function());
97   PopulateFunction("DoStuffCpu_gradient", "DoStuff", "CPU", output_args,
98                    input_args, "DoStuffCpu", "", func_lib->add_function());
99   PopulateFunction("DoStuffGpu", "DoStuff", "GPU", input_args,
100                    output_with_state, "", "DoStuffGpu_gradient",
101                    func_lib->add_function());
102   PopulateFunction("DoStuffGpu_gradient", "DoStuff", "GPU", output_with_state,
103                    input_args, "DoStuffGpu", "", func_lib->add_function());
104 }
105 
CheckEquivImpl(const FunctionLibraryApiInfo & lib_api_info,const string & func_name,const std::vector<string> & expected_other)106 bool CheckEquivImpl(const FunctionLibraryApiInfo& lib_api_info,
107                     const string& func_name,
108                     const std::vector<string>& expected_other) {
109   std::vector<string> other_impl;
110   Status status =
111       lib_api_info.GetEquivalentImplementations(func_name, &other_impl);
112   EXPECT_EQ(status, Status::OK());
113   const std::unordered_set<string> actual(other_impl.begin(), other_impl.end());
114   const std::unordered_set<string> expected(expected_other.begin(),
115                                             expected_other.end());
116   return actual == expected;
117 }
118 
// Returns the interface name that `lib_api_info` records for `func_name`.
//
// If the library holds no API info for the function, a non-fatal test failure
// naming the function is reported and an empty string is returned. This avoids
// a CHECK crash, which would abort the entire test binary and hide the
// results of every later test.
string GetInterfaceName(const FunctionLibraryApiInfo& lib_api_info,
                        const string& func_name) {
  const auto* info = lib_api_info.GetApiInfo(func_name);
  if (info == nullptr) {
    ADD_FAILURE() << "No API info found for function: " << func_name;
    return "";
  }
  return info->interface_name();
}
125 
// Returns the preferred device that `lib_api_info` records for `func_name`.
//
// If the library holds no API info for the function, a non-fatal test failure
// naming the function is reported and an empty string is returned. The empty
// string is also a legitimate "no preference" value, but the recorded failure
// still fails the test. Reporting a failure instead of CHECK-crashing keeps the
// remaining tests in the binary running.
string GetPreferredDevice(const FunctionLibraryApiInfo& lib_api_info,
                          const string& func_name) {
  const auto* info = lib_api_info.GetApiInfo(func_name);
  if (info == nullptr) {
    ADD_FAILURE() << "No API info found for function: " << func_name;
    return "";
  }
  return info->preferred_device();
}
132 
// Checks interface names, preferred devices and equivalent-implementation
// lookups on the sample library, whose two DoStuff implementations share the
// same signature.
TEST(FunctionApiInfoTest, ParseTags) {
  FunctionDefLibrary library;
  PopulateSampleLibrary(/*mismatch_args=*/false, &library);

  FunctionLibraryApiInfo api_info;
  TF_ASSERT_OK(api_info.Init(library));

  // Interface names match what PopulateSampleLibrary passed as the interface.
  EXPECT_EQ("DoStuff", GetInterfaceName(api_info, "DoStuffCpu"));
  EXPECT_EQ("DoStuff", GetInterfaceName(api_info, "DoStuffGpu"));
  EXPECT_EQ("DoThings", GetInterfaceName(api_info, "DoThings"));

  // DoThings was built with an empty preferred device.
  EXPECT_EQ("CPU", GetPreferredDevice(api_info, "DoStuffCpu"));
  EXPECT_EQ("GPU", GetPreferredDevice(api_info, "DoStuffGpu"));
  EXPECT_EQ("", GetPreferredDevice(api_info, "DoThings"));

  // The two DoStuff implementations are each other's only alternative.
  EXPECT_TRUE(CheckEquivImpl(api_info, "DoStuffCpu", {"DoStuffGpu"}));
  EXPECT_TRUE(CheckEquivImpl(api_info, "DoStuffGpu", {"DoStuffCpu"}));
  // Unknown functions, functions without an interface, and the sole
  // implementation of an interface have no alternatives.
  EXPECT_TRUE(CheckEquivImpl(api_info, "Undefined", {}));
  EXPECT_TRUE(CheckEquivImpl(api_info, "OneOff", {}));
  EXPECT_TRUE(CheckEquivImpl(api_info, "AnotherOneOff", {}));
  EXPECT_TRUE(CheckEquivImpl(api_info, "DoThings", {}));
}
154 
// Checks the library with CPU/GPU forward functions and their gradients:
// every function shares the "DoStuff" interface, keeps its own preferred
// device, and is equivalent only to its counterpart of the same kind
// (forward <-> forward, gradient <-> gradient).
TEST(FunctionApiInfoTest, ComplexFunctionLib) {
  FunctionDefLibrary library;
  PopulateComplexLibrary(&library);

  FunctionLibraryApiInfo api_info;
  TF_ASSERT_OK(api_info.Init(library));

  // All four functions implement the same interface.
  EXPECT_EQ("DoStuff", GetInterfaceName(api_info, "DoStuffCpu"));
  EXPECT_EQ("DoStuff", GetInterfaceName(api_info, "DoStuffCpu_gradient"));
  EXPECT_EQ("DoStuff", GetInterfaceName(api_info, "DoStuffGpu"));
  EXPECT_EQ("DoStuff", GetInterfaceName(api_info, "DoStuffGpu_gradient"));

  // Gradients share their forward function's preferred device.
  EXPECT_EQ("CPU", GetPreferredDevice(api_info, "DoStuffCpu"));
  EXPECT_EQ("CPU", GetPreferredDevice(api_info, "DoStuffCpu_gradient"));
  EXPECT_EQ("GPU", GetPreferredDevice(api_info, "DoStuffGpu"));
  EXPECT_EQ("GPU", GetPreferredDevice(api_info, "DoStuffGpu_gradient"));

  // Forward functions pair with forward functions, gradients with gradients.
  EXPECT_TRUE(CheckEquivImpl(api_info, "DoStuffCpu", {"DoStuffGpu"}));
  EXPECT_TRUE(CheckEquivImpl(api_info, "DoStuffGpu", {"DoStuffCpu"}));
  EXPECT_TRUE(CheckEquivImpl(api_info, "DoStuffCpu_gradient",
                             {"DoStuffGpu_gradient"}));
  EXPECT_TRUE(CheckEquivImpl(api_info, "DoStuffGpu_gradient",
                             {"DoStuffCpu_gradient"}));
  EXPECT_TRUE(CheckEquivImpl(api_info, "Undefined", {}));
}
179 
// Init must fail for a library in which two implementations of the same
// interface ("DoStuff") declare differently typed arguments.
TEST(FunctionApiInfoTest, MismatchedArguments) {
  FunctionDefLibrary library;
  PopulateSampleLibrary(/*mismatch_args=*/true, &library);

  FunctionLibraryApiInfo api_info;
  EXPECT_FALSE(api_info.Init(library).ok());
}
187 
188 }  // namespace
189 }  // namespace grappler
190 }  // namespace tensorflow
191