1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/kernel_def_builder.h"
17 #include "tensorflow/core/framework/attr_value.pb.h"
18 #include "tensorflow/core/framework/kernel_def.pb.h"
19
20 namespace tensorflow {
21
KernelDefBuilder(const char * op_name)22 KernelDefBuilder::KernelDefBuilder(const char* op_name) {
23 kernel_def_ = new KernelDef;
24 kernel_def_->set_op(op_name);
25 }
26
~KernelDefBuilder()27 KernelDefBuilder::~KernelDefBuilder() {
28 DCHECK(kernel_def_ == nullptr) << "Did not call Build()";
29 }
30
Device(const char * device_type)31 KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) {
32 kernel_def_->set_device_type(device_type);
33 return *this;
34 }
35
36 template <>
AttrConstraint(const char * attr_name,gtl::ArraySlice<int64> allowed)37 KernelDefBuilder& KernelDefBuilder::AttrConstraint<int64>(
38 const char* attr_name, gtl::ArraySlice<int64> allowed) {
39 auto* constraint = kernel_def_->add_constraint();
40 constraint->set_name(attr_name);
41 auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
42 for (const int64 integer : allowed) {
43 LOG(INFO) << integer;
44 allowed_values->add_i(integer);
45 }
46 return *this;
47 }
48
49 template <>
AttrConstraint(const char * attr_name,int64 allowed)50 KernelDefBuilder& KernelDefBuilder::AttrConstraint<int64>(const char* attr_name,
51 int64 allowed) {
52 return AttrConstraint(
53 attr_name,
54 gtl::ArraySlice<int64>(std::initializer_list<int64>({allowed})));
55 }
56
57 template <>
AttrConstraint(const char * attr_name,gtl::ArraySlice<string> allowed)58 KernelDefBuilder& KernelDefBuilder::AttrConstraint<string>(
59 const char* attr_name, gtl::ArraySlice<string> allowed) {
60 auto* constraint = kernel_def_->add_constraint();
61 constraint->set_name(attr_name);
62 auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
63 for (const auto& str : allowed) {
64 allowed_values->add_s(str);
65 }
66 return *this;
67 }
68
69 template <>
AttrConstraint(const char * attr_name,string allowed)70 KernelDefBuilder& KernelDefBuilder::AttrConstraint<string>(
71 const char* attr_name, string allowed) {
72 return AttrConstraint(
73 attr_name,
74 gtl::ArraySlice<string>(std::initializer_list<string>({allowed})));
75 }
76
77 template <>
AttrConstraint(const char * attr_name,gtl::ArraySlice<const char * > allowed)78 KernelDefBuilder& KernelDefBuilder::AttrConstraint<const char*>(
79 const char* attr_name, gtl::ArraySlice<const char*> allowed) {
80 auto* constraint = kernel_def_->add_constraint();
81 constraint->set_name(attr_name);
82 auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
83 for (const auto& str : allowed) {
84 allowed_values->add_s(str);
85 }
86 return *this;
87 }
88
89 template <>
AttrConstraint(const char * attr_name,const char * allowed)90 KernelDefBuilder& KernelDefBuilder::AttrConstraint<const char*>(
91 const char* attr_name, const char* allowed) {
92 return AttrConstraint(attr_name,
93 gtl::ArraySlice<const char*>(
94 std::initializer_list<const char*>({allowed})));
95 }
96
97 template <>
AttrConstraint(const char * attr_name,bool allowed)98 KernelDefBuilder& KernelDefBuilder::AttrConstraint<bool>(const char* attr_name,
99 bool allowed) {
100 auto* constraint = kernel_def_->add_constraint();
101 constraint->set_name(attr_name);
102 auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
103 allowed_values->add_b(allowed);
104 return *this;
105 }
106
TypeConstraint(const char * attr_name,gtl::ArraySlice<DataType> allowed)107 KernelDefBuilder& KernelDefBuilder::TypeConstraint(
108 const char* attr_name, gtl::ArraySlice<DataType> allowed) {
109 auto* constraint = kernel_def_->add_constraint();
110 constraint->set_name(attr_name);
111 auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
112 for (DataType dt : allowed) {
113 allowed_values->add_type(dt);
114 }
115 return *this;
116 }
117
TypeConstraint(const char * attr_name,DataType allowed)118 KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name,
119 DataType allowed) {
120 auto* constraint = kernel_def_->add_constraint();
121 constraint->set_name(attr_name);
122 constraint->mutable_allowed_values()->mutable_list()->add_type(allowed);
123 return *this;
124 }
125
HostMemory(const char * arg_name)126 KernelDefBuilder& KernelDefBuilder::HostMemory(const char* arg_name) {
127 kernel_def_->add_host_memory_arg(arg_name);
128 return *this;
129 }
130
Label(const char * label)131 KernelDefBuilder& KernelDefBuilder::Label(const char* label) {
132 CHECK_EQ(kernel_def_->label(), "")
133 << "Trying to set a kernel's label a second time: '" << label
134 << "' in: " << kernel_def_->DebugString();
135 kernel_def_->set_label(label);
136 return *this;
137 }
138
Priority(int32 priority)139 KernelDefBuilder& KernelDefBuilder::Priority(int32 priority) {
140 kernel_def_->set_priority(priority);
141 return *this;
142 }
143
Build()144 const KernelDef* KernelDefBuilder::Build() {
145 KernelDef* r = kernel_def_;
146 kernel_def_ = nullptr;
147 return r;
148 }
149
150 } // namespace tensorflow
151