• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 
18 #include "re2/re2.h"
19 #include "tensorflow/core/framework/op_kernel.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/status.h"
23 #include "tensorflow/core/platform/mutex.h"
24 #include "tensorflow/core/platform/thread_annotations.h"
25 #include "tensorflow/core/util/ptr_util.h"
26 
27 namespace tensorflow {
28 namespace {
29 
30 // Execute the specified regex using the given context.
31 // Context requirements:
32 //  - "input" string Tensor at input_index=0
33 //  - "output" string Tensor at output_index=0
InternalCompute(const RE2 & regex,const string & rewrite,const bool replace_global,OpKernelContext * ctx)34 Status InternalCompute(const RE2& regex, const string& rewrite,
35                        const bool replace_global, OpKernelContext* ctx) {
36   const Tensor* input_tensor;
37   TF_RETURN_IF_ERROR(ctx->input("input", &input_tensor));
38   Tensor* output_tensor;
39   std::unique_ptr<Tensor> maybe_forwarded =
40       ctx->forward_input(0 /*input_index*/, 0 /*output_index*/,
41                          tensorflow::DT_STRING, input_tensor->shape(),
42                          ctx->input_memory_type(0), ctx->input_alloc_attr(0));
43   if (maybe_forwarded) {
44     output_tensor = maybe_forwarded.get();
45     TF_RETURN_IF_ERROR(ctx->set_output("output", *output_tensor));
46   } else {
47     TF_RETURN_IF_ERROR(
48         ctx->allocate_output("output", input_tensor->shape(), &output_tensor));
49     output_tensor->flat<tstring>() = input_tensor->flat<tstring>();
50   }
51   auto output_flat = output_tensor->flat<tstring>();
52   for (size_t i = 0; i < output_flat.size(); ++i) {
53     // TODO(dero): Mitigate copy; Global and GlobalReplace below currently only
54     // accept std::string.
55     string buf = output_flat(i);
56     if (replace_global) {
57       RE2::GlobalReplace(&buf, regex, rewrite);
58     } else {
59       RE2::Replace(&buf, regex, rewrite);
60     }
61     output_flat(i) = std::move(buf);
62   }
63   return Status::OK();
64 }
65 }  // namespace
66 
67 class RegexReplaceOp : public OpKernel {
68  public:
RegexReplaceOp(OpKernelConstruction * ctx)69   explicit RegexReplaceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
70     OP_REQUIRES_OK(ctx, ctx->GetAttr("replace_global", &replace_global_));
71   }
72 
~RegexReplaceOp()73   ~RegexReplaceOp() override {}
74 
Compute(OpKernelContext * ctx)75   void Compute(OpKernelContext* ctx) override {
76     const Tensor* pattern_tensor;
77     OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor));
78     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()),
79                 errors::InvalidArgument("Pattern must be scalar, but received ",
80                                         pattern_tensor->shape().DebugString()));
81     const string& pattern = pattern_tensor->scalar<tstring>()();
82     std::shared_ptr<RE2> regex = CachedRE2(pattern);
83     OP_REQUIRES(ctx, regex->ok(),
84                 errors::InvalidArgument("Invalid pattern: ", pattern,
85                                         ", error: ", regex->error()));
86 
87     const Tensor* rewrite_tensor;
88     OP_REQUIRES_OK(ctx, ctx->input("rewrite", &rewrite_tensor));
89     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rewrite_tensor->shape()),
90                 errors::InvalidArgument("Rewrite must be scalar, but received ",
91                                         rewrite_tensor->shape().DebugString()));
92     const string& rewrite = rewrite_tensor->scalar<tstring>()();
93     OP_REQUIRES_OK(ctx, InternalCompute(*regex, rewrite, replace_global_, ctx));
94   }
95 
96  private:
CachedRE2(const string & pattern)97   std::shared_ptr<RE2> CachedRE2(const string& pattern) {
98     {
99       tf_shared_lock l(mu_);
100       if (regex_ != nullptr && regex_->pattern() == pattern) {
101         return regex_;
102       }
103     }
104     // Construct the new RE2 object before acquiring the lock.
105     auto regex = std::make_shared<RE2>(pattern);
106     {
107       mutex_lock l(mu_);
108       // Swap instead of assigning so that we destruct the old
109       // RE2 object (when necessary) after releasing the lock.
110       regex_.swap(regex);
111       return regex_;
112     }
113   }
114 
115   bool replace_global_;
116   mutex mu_;
117   std::shared_ptr<RE2> regex_ TF_GUARDED_BY(mu_);
118 
119   TF_DISALLOW_COPY_AND_ASSIGN(RegexReplaceOp);
120 };
121 
122 REGISTER_KERNEL_BUILDER(Name("RegexReplace").Device(DEVICE_CPU),
123                         RegexReplaceOp);
124 
125 class StaticRegexReplaceOp : public OpKernel {
126  public:
StaticRegexReplaceOp(OpKernelConstruction * ctx)127   explicit StaticRegexReplaceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
128     string pattern;
129     OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern));
130     re_ = MakeUnique<RE2>(pattern);
131     OP_REQUIRES(ctx, re_->ok(),
132                 errors::InvalidArgument("Invalid pattern: ", pattern,
133                                         ", error: ", re_->error()));
134     OP_REQUIRES_OK(ctx, ctx->GetAttr("rewrite", &rewrite_str_));
135     OP_REQUIRES_OK(ctx, ctx->GetAttr("replace_global", &replace_global_));
136   }
137 
Compute(OpKernelContext * ctx)138   void Compute(OpKernelContext* ctx) override {
139     OP_REQUIRES_OK(ctx,
140                    InternalCompute(*re_, rewrite_str_, replace_global_, ctx));
141   }
142 
143  private:
144   std::unique_ptr<RE2> re_;
145   string rewrite_str_;
146   bool replace_global_;
147 };
148 
149 REGISTER_KERNEL_BUILDER(Name("StaticRegexReplace").Device(DEVICE_CPU),
150                         StaticRegexReplaceOp);
151 
152 }  // namespace tensorflow
153