1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/dataset_ops.h"
16
17 // On mobile we do not provide this functionality because not all of its
18 // dependencies are available there.
19 #if !defined(IS_MOBILE_PLATFORM)
20 #include "tensorflow/core/common_runtime/graph_constructor.h"
21 #include "tensorflow/core/common_runtime/graph_runner.h"
22 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
23 #include "tensorflow/core/data/captured_function.h"
24 #include "tensorflow/core/data/dataset_utils.h"
25 #include "tensorflow/core/framework/dataset.h"
26 #include "tensorflow/core/framework/dataset_stateful_op_allowlist.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/graph/graph_def_builder.h"
30 #include "tensorflow/core/grappler/graph_topology_view.h"
31 #include "tensorflow/core/grappler/utils/traversal.h"
32 #include "tensorflow/core/util/device_name_utils.h"
33
34 namespace tensorflow {
35 namespace data {
36
37 /* static */ constexpr const char* const DatasetToGraphOp::kAllowStateful;
38 /* static */ constexpr const char* const
39 DatasetToGraphOp::kStripDeviceAssignment;
40 /* static */ constexpr const char* const DatasetToGraphOp::kExternalStatePolicy;
41 /* static */ constexpr const char* const DatasetToGraphOp::kDatasetToGraph;
42 /* static */ constexpr const char* const DatasetFromGraphOp::kGraphDef;
43 /* static */ constexpr const char* const DatasetFromGraphOp::kHandle;
44
namespace {

// Op type of Python-function nodes. DatasetToGraphOp::Compute leaves the
// device assignment of nodes with this op type untouched when stripping
// devices from the serialized function library.
constexpr const char kPyFunc[] = "PyFunc";

}  // namespace
48
49 // See documentation in ../../ops/dataset_ops.cc for a high-level
50 // description of the following op.
DatasetToGraphOp(OpKernelConstruction * ctx)51 DatasetToGraphOp::DatasetToGraphOp(OpKernelConstruction* ctx)
52 : OpKernel(ctx), op_version_(ctx->def().op() == kDatasetToGraph ? 1 : 2) {
53 if (op_version_ == 2) {
54 if (ctx->HasAttr(kExternalStatePolicy)) {
55 int64_t state_change_option;
56 OP_REQUIRES_OK(ctx,
57 ctx->GetAttr(kExternalStatePolicy, &state_change_option));
58 external_state_policy_ =
59 SerializationContext::ExternalStatePolicy(state_change_option);
60 }
61 } else {
62 if (ctx->HasAttr(kAllowStateful)) {
63 bool allow_stateful;
64 OP_REQUIRES_OK(ctx, ctx->GetAttr(kAllowStateful, &allow_stateful));
65 if (allow_stateful) {
66 external_state_policy_ =
67 SerializationContext::ExternalStatePolicy::kWarn;
68 } else {
69 external_state_policy_ =
70 SerializationContext::ExternalStatePolicy::kFail;
71 }
72 }
73 }
74
75 if (ctx->HasAttr(kStripDeviceAssignment)) {
76 OP_REQUIRES_OK(
77 ctx, ctx->GetAttr(kStripDeviceAssignment, &strip_device_assignment_));
78 }
79 }
80
Compute(OpKernelContext * ctx)81 void DatasetToGraphOp::Compute(OpKernelContext* ctx) {
82 DatasetBase* dataset;
83 OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
84 if (dataset->options().optional_external_state_policy_case() ==
85 Options::kExternalStatePolicy) {
86 switch (dataset->options().external_state_policy()) {
87 case ExternalStatePolicy::POLICY_WARN:
88 external_state_policy_ =
89 SerializationContext::ExternalStatePolicy::kWarn;
90 break;
91 case ExternalStatePolicy::POLICY_IGNORE:
92 external_state_policy_ =
93 SerializationContext::ExternalStatePolicy::kIgnore;
94 break;
95 case ExternalStatePolicy::POLICY_FAIL:
96 external_state_policy_ =
97 SerializationContext::ExternalStatePolicy::kFail;
98 break;
99 default: {
100 LOG(ERROR) << "Dataset " << dataset->type_string()
101 << " has an unknown external_state_policy enum value: "
102 << dataset->options().external_state_policy();
103 }
104 }
105 }
106 SerializationContext::Params params(ctx);
107 params.external_state_policy = external_state_policy_;
108
109 GraphDef graph_def;
110 Status s = AsGraphDef(ctx, dataset, SerializationContext(params), &graph_def);
111 if (!s.ok()) {
112 ctx->CtxFailure(errors::FailedPrecondition(
113 "Failed to serialize the input pipeline graph: ", s.error_message()));
114 return;
115 }
116 if (strip_device_assignment_) {
117 auto library = graph_def.mutable_library();
118 for (auto& function : (*library->mutable_function())) {
119 for (auto& node : (*function.mutable_node_def())) {
120 // We do not strip the device assignment from `PyFunc` ops because they
121 // need to be pinned to a host that is known to have Python interpreter.
122 if (!node.device().empty() && node.op() != kPyFunc) {
123 *node.mutable_device() = DeviceNameUtils::LocalName(node.device());
124 }
125 }
126 }
127 }
128
129 Tensor* result;
130 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
131 result->scalar<tstring>()() = graph_def.SerializeAsString();
132 }
133
Compute(OpKernelContext * ctx)134 void DatasetCardinalityOp::Compute(OpKernelContext* ctx) {
135 DatasetBase* dataset;
136 OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
137 Tensor* result;
138 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
139 result->scalar<int64>()() = dataset->Cardinality();
140 }
141
Compute(OpKernelContext * ctx)142 void DatasetFromGraphOp::Compute(OpKernelContext* ctx) {
143 tstring graph_def_string;
144 OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kGraphDef, &graph_def_string));
145 GraphDef graph_def;
146 OP_REQUIRES(ctx, graph_def.ParseFromString(graph_def_string),
147 errors::InvalidArgument("Could not parse GraphDef"));
148 string output_node;
149 for (const auto& node : graph_def.node()) {
150 if (node.op() == FunctionLibraryDefinition::kRetOp) {
151 output_node = node.input(0);
152 }
153 }
154 Graph graph(OpRegistry::Global());
155 OP_REQUIRES_OK(ctx, ImportGraphDef({}, graph_def, &graph, nullptr));
156
157 FunctionLibraryRuntime* flr;
158 std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
159 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
160 OP_REQUIRES_OK(ctx,
161 ctx->function_library()->Clone(&flib_def, &pflr, &flr, true));
162
163 // Some function names may be duplicated (for example, if the serialized
164 // graph has an optimized function that retains its original name). We
165 // override functions in flib_def in the event of conflict. It is
166 // safe to assume that any node in the serialized graph is referring to the
167 // serialized function when there is a conflict.
168 OP_REQUIRES_OK(ctx,
169 AddToFunctionLibrary(flib_def.get(), graph_def.library()));
170
171 std::vector<Tensor> outputs;
172 GraphRunner graph_runner(flr->device());
173 OP_REQUIRES_OK(ctx,
174 graph_runner.Run(&graph, flr, {}, {output_node}, &outputs));
175 OP_REQUIRES_OK(ctx, ctx->set_output(kHandle, outputs[0]));
176 }
177
178 REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU),
179 DatasetToGraphOp);
180 REGISTER_KERNEL_BUILDER(Name("DatasetToGraphV2").Device(DEVICE_CPU),
181 DatasetToGraphOp);
182
183 REGISTER_KERNEL_BUILDER(Name("DatasetCardinality").Device(DEVICE_CPU),
184 DatasetCardinalityOp);
185 REGISTER_KERNEL_BUILDER(
186 Name("ExperimentalDatasetCardinality").Device(DEVICE_CPU),
187 DatasetCardinalityOp);
188
189 REGISTER_KERNEL_BUILDER(Name("DatasetFromGraph").Device(DEVICE_CPU),
190 DatasetFromGraphOp);
191
192 } // namespace data
193 } // namespace tensorflow
194 #endif // !IS_MOBILE_PLATFORM
195