1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <cstdint>
#include <string>

#include "re2/re2.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/ptr_util.h"
26
27 namespace tensorflow {
28 namespace {
29
30 // Execute the specified regex using the given context.
31 // Context requirements:
32 // - "input" string Tensor at input_index=0
33 // - "output" string Tensor at output_index=0
InternalCompute(const RE2 & regex,const string & rewrite,const bool replace_global,OpKernelContext * ctx)34 Status InternalCompute(const RE2& regex, const string& rewrite,
35 const bool replace_global, OpKernelContext* ctx) {
36 const Tensor* input_tensor;
37 TF_RETURN_IF_ERROR(ctx->input("input", &input_tensor));
38 Tensor* output_tensor;
39 std::unique_ptr<Tensor> maybe_forwarded =
40 ctx->forward_input(0 /*input_index*/, 0 /*output_index*/,
41 tensorflow::DT_STRING, input_tensor->shape(),
42 ctx->input_memory_type(0), ctx->input_alloc_attr(0));
43 if (maybe_forwarded) {
44 output_tensor = maybe_forwarded.get();
45 TF_RETURN_IF_ERROR(ctx->set_output("output", *output_tensor));
46 } else {
47 TF_RETURN_IF_ERROR(
48 ctx->allocate_output("output", input_tensor->shape(), &output_tensor));
49 output_tensor->flat<tstring>() = input_tensor->flat<tstring>();
50 }
51 auto output_flat = output_tensor->flat<tstring>();
52 for (size_t i = 0; i < output_flat.size(); ++i) {
53 // TODO(dero): Mitigate copy; Global and GlobalReplace below currently only
54 // accept std::string.
55 string buf = output_flat(i);
56 if (replace_global) {
57 RE2::GlobalReplace(&buf, regex, rewrite);
58 } else {
59 RE2::Replace(&buf, regex, rewrite);
60 }
61 output_flat(i) = std::move(buf);
62 }
63 return Status::OK();
64 }
65 } // namespace
66
67 class RegexReplaceOp : public OpKernel {
68 public:
RegexReplaceOp(OpKernelConstruction * ctx)69 explicit RegexReplaceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
70 OP_REQUIRES_OK(ctx, ctx->GetAttr("replace_global", &replace_global_));
71 }
72
~RegexReplaceOp()73 ~RegexReplaceOp() override {}
74
Compute(OpKernelContext * ctx)75 void Compute(OpKernelContext* ctx) override {
76 const Tensor* pattern_tensor;
77 OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor));
78 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()),
79 errors::InvalidArgument("Pattern must be scalar, but received ",
80 pattern_tensor->shape().DebugString()));
81 const string& pattern = pattern_tensor->scalar<tstring>()();
82 std::shared_ptr<RE2> regex = CachedRE2(pattern);
83 OP_REQUIRES(ctx, regex->ok(),
84 errors::InvalidArgument("Invalid pattern: ", pattern,
85 ", error: ", regex->error()));
86
87 const Tensor* rewrite_tensor;
88 OP_REQUIRES_OK(ctx, ctx->input("rewrite", &rewrite_tensor));
89 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rewrite_tensor->shape()),
90 errors::InvalidArgument("Rewrite must be scalar, but received ",
91 rewrite_tensor->shape().DebugString()));
92 const string& rewrite = rewrite_tensor->scalar<tstring>()();
93 OP_REQUIRES_OK(ctx, InternalCompute(*regex, rewrite, replace_global_, ctx));
94 }
95
96 private:
CachedRE2(const string & pattern)97 std::shared_ptr<RE2> CachedRE2(const string& pattern) {
98 {
99 tf_shared_lock l(mu_);
100 if (regex_ != nullptr && regex_->pattern() == pattern) {
101 return regex_;
102 }
103 }
104 // Construct the new RE2 object before acquiring the lock.
105 auto regex = std::make_shared<RE2>(pattern);
106 {
107 mutex_lock l(mu_);
108 // Swap instead of assigning so that we destruct the old
109 // RE2 object (when necessary) after releasing the lock.
110 regex_.swap(regex);
111 return regex_;
112 }
113 }
114
115 bool replace_global_;
116 mutex mu_;
117 std::shared_ptr<RE2> regex_ TF_GUARDED_BY(mu_);
118
119 TF_DISALLOW_COPY_AND_ASSIGN(RegexReplaceOp);
120 };
121
122 REGISTER_KERNEL_BUILDER(Name("RegexReplace").Device(DEVICE_CPU),
123 RegexReplaceOp);
124
125 class StaticRegexReplaceOp : public OpKernel {
126 public:
StaticRegexReplaceOp(OpKernelConstruction * ctx)127 explicit StaticRegexReplaceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
128 string pattern;
129 OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern));
130 re_ = MakeUnique<RE2>(pattern);
131 OP_REQUIRES(ctx, re_->ok(),
132 errors::InvalidArgument("Invalid pattern: ", pattern,
133 ", error: ", re_->error()));
134 OP_REQUIRES_OK(ctx, ctx->GetAttr("rewrite", &rewrite_str_));
135 OP_REQUIRES_OK(ctx, ctx->GetAttr("replace_global", &replace_global_));
136 }
137
Compute(OpKernelContext * ctx)138 void Compute(OpKernelContext* ctx) override {
139 OP_REQUIRES_OK(ctx,
140 InternalCompute(*re_, rewrite_str_, replace_global_, ctx));
141 }
142
143 private:
144 std::unique_ptr<RE2> re_;
145 string rewrite_str_;
146 bool replace_global_;
147 };
148
149 REGISTER_KERNEL_BUILDER(Name("StaticRegexReplace").Device(DEVICE_CPU),
150 StaticRegexReplaceOp);
151
152 } // namespace tensorflow
153