• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/dataset_ops.h"
16 
17 // On mobile we do not provide this functionality because not all of its
18 // dependencies are available there.
19 #if !defined(IS_MOBILE_PLATFORM)
20 #include "tensorflow/core/common_runtime/graph_constructor.h"
21 #include "tensorflow/core/common_runtime/graph_runner.h"
22 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
23 #include "tensorflow/core/data/captured_function.h"
24 #include "tensorflow/core/data/dataset_utils.h"
25 #include "tensorflow/core/framework/dataset.h"
26 #include "tensorflow/core/framework/dataset_stateful_op_allowlist.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/graph/graph_def_builder.h"
30 #include "tensorflow/core/grappler/graph_topology_view.h"
31 #include "tensorflow/core/grappler/utils/traversal.h"
32 #include "tensorflow/core/util/device_name_utils.h"
33 
34 namespace tensorflow {
35 namespace data {
36 
37 /* static */ constexpr const char* const DatasetToGraphOp::kAllowStateful;
38 /* static */ constexpr const char* const
39     DatasetToGraphOp::kStripDeviceAssignment;
40 /* static */ constexpr const char* const DatasetToGraphOp::kExternalStatePolicy;
41 /* static */ constexpr const char* const DatasetToGraphOp::kDatasetToGraph;
42 /* static */ constexpr const char* const DatasetFromGraphOp::kGraphDef;
43 /* static */ constexpr const char* const DatasetFromGraphOp::kHandle;
44 
45 namespace {
46 constexpr char kPyFunc[] = "PyFunc";
47 }  // namespace
48 
49 // See documentation in ../../ops/dataset_ops.cc for a high-level
50 // description of the following op.
DatasetToGraphOp(OpKernelConstruction * ctx)51 DatasetToGraphOp::DatasetToGraphOp(OpKernelConstruction* ctx)
52     : OpKernel(ctx), op_version_(ctx->def().op() == kDatasetToGraph ? 1 : 2) {
53   if (op_version_ == 2) {
54     if (ctx->HasAttr(kExternalStatePolicy)) {
55       int64_t state_change_option;
56       OP_REQUIRES_OK(ctx,
57                      ctx->GetAttr(kExternalStatePolicy, &state_change_option));
58       external_state_policy_ =
59           SerializationContext::ExternalStatePolicy(state_change_option);
60     }
61   } else {
62     if (ctx->HasAttr(kAllowStateful)) {
63       bool allow_stateful;
64       OP_REQUIRES_OK(ctx, ctx->GetAttr(kAllowStateful, &allow_stateful));
65       if (allow_stateful) {
66         external_state_policy_ =
67             SerializationContext::ExternalStatePolicy::kWarn;
68       } else {
69         external_state_policy_ =
70             SerializationContext::ExternalStatePolicy::kFail;
71       }
72     }
73   }
74 
75   if (ctx->HasAttr(kStripDeviceAssignment)) {
76     OP_REQUIRES_OK(
77         ctx, ctx->GetAttr(kStripDeviceAssignment, &strip_device_assignment_));
78   }
79 }
80 
Compute(OpKernelContext * ctx)81 void DatasetToGraphOp::Compute(OpKernelContext* ctx) {
82   DatasetBase* dataset;
83   OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
84   if (dataset->options().optional_external_state_policy_case() ==
85       Options::kExternalStatePolicy) {
86     switch (dataset->options().external_state_policy()) {
87       case ExternalStatePolicy::POLICY_WARN:
88         external_state_policy_ =
89             SerializationContext::ExternalStatePolicy::kWarn;
90         break;
91       case ExternalStatePolicy::POLICY_IGNORE:
92         external_state_policy_ =
93             SerializationContext::ExternalStatePolicy::kIgnore;
94         break;
95       case ExternalStatePolicy::POLICY_FAIL:
96         external_state_policy_ =
97             SerializationContext::ExternalStatePolicy::kFail;
98         break;
99       default: {
100         LOG(ERROR) << "Dataset " << dataset->type_string()
101                    << " has an unknown external_state_policy enum value: "
102                    << dataset->options().external_state_policy();
103       }
104     }
105   }
106   SerializationContext::Params params(ctx);
107   params.external_state_policy = external_state_policy_;
108 
109   GraphDef graph_def;
110   Status s = AsGraphDef(ctx, dataset, SerializationContext(params), &graph_def);
111   if (!s.ok()) {
112     ctx->CtxFailure(errors::FailedPrecondition(
113         "Failed to serialize the input pipeline graph: ", s.error_message()));
114     return;
115   }
116   if (strip_device_assignment_) {
117     auto library = graph_def.mutable_library();
118     for (auto& function : (*library->mutable_function())) {
119       for (auto& node : (*function.mutable_node_def())) {
120         // We do not strip the device assignment from `PyFunc` ops because they
121         // need to be pinned to a host that is known to have Python interpreter.
122         if (!node.device().empty() && node.op() != kPyFunc) {
123           *node.mutable_device() = DeviceNameUtils::LocalName(node.device());
124         }
125       }
126     }
127   }
128 
129   Tensor* result;
130   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
131   result->scalar<tstring>()() = graph_def.SerializeAsString();
132 }
133 
Compute(OpKernelContext * ctx)134 void DatasetCardinalityOp::Compute(OpKernelContext* ctx) {
135   DatasetBase* dataset;
136   OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
137   Tensor* result;
138   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
139   result->scalar<int64>()() = dataset->Cardinality();
140 }
141 
Compute(OpKernelContext * ctx)142 void DatasetFromGraphOp::Compute(OpKernelContext* ctx) {
143   tstring graph_def_string;
144   OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kGraphDef, &graph_def_string));
145   GraphDef graph_def;
146   OP_REQUIRES(ctx, graph_def.ParseFromString(graph_def_string),
147               errors::InvalidArgument("Could not parse GraphDef"));
148   string output_node;
149   for (const auto& node : graph_def.node()) {
150     if (node.op() == FunctionLibraryDefinition::kRetOp) {
151       output_node = node.input(0);
152     }
153   }
154   Graph graph(OpRegistry::Global());
155   OP_REQUIRES_OK(ctx, ImportGraphDef({}, graph_def, &graph, nullptr));
156 
157   FunctionLibraryRuntime* flr;
158   std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
159   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
160   OP_REQUIRES_OK(ctx,
161                  ctx->function_library()->Clone(&flib_def, &pflr, &flr, true));
162 
163   // Some function names may be duplicated (for example, if the serialized
164   // graph has an optimized function that retains its original name). We
165   // override functions in flib_def in the event of conflict. It is
166   // safe to assume that any node in the serialized graph is referring to the
167   // serialized function when there is a conflict.
168   OP_REQUIRES_OK(ctx,
169                  AddToFunctionLibrary(flib_def.get(), graph_def.library()));
170 
171   std::vector<Tensor> outputs;
172   GraphRunner graph_runner(flr->device());
173   OP_REQUIRES_OK(ctx,
174                  graph_runner.Run(&graph, flr, {}, {output_node}, &outputs));
175   OP_REQUIRES_OK(ctx, ctx->set_output(kHandle, outputs[0]));
176 }
177 
178 REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU),
179                         DatasetToGraphOp);
180 REGISTER_KERNEL_BUILDER(Name("DatasetToGraphV2").Device(DEVICE_CPU),
181                         DatasetToGraphOp);
182 
183 REGISTER_KERNEL_BUILDER(Name("DatasetCardinality").Device(DEVICE_CPU),
184                         DatasetCardinalityOp);
185 REGISTER_KERNEL_BUILDER(
186     Name("ExperimentalDatasetCardinality").Device(DEVICE_CPU),
187     DatasetCardinalityOp);
188 
189 REGISTER_KERNEL_BUILDER(Name("DatasetFromGraph").Device(DEVICE_CPU),
190                         DatasetFromGraphOp);
191 
192 }  // namespace data
193 }  // namespace tensorflow
194 #endif  // !IS_MOBILE_PLATFORM
195