/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is a simple include file used to simplify the splitting of the
// tf_ops.cc file. The helpers in here should be refactored and moved to
// tf_verifiers or tf_ops.
// TODO(jpienaar): Remove this file post refactoring.

//===----------------------------------------------------------------------===//
// TF op helper functions
//===----------------------------------------------------------------------===//

// Returns the RankedTensorType for the given operand. TensorFlow constant ops
// may have non-static shape because the shape is not propagated during constant
// folding. If the defining op for the given operand is a constant op, this
// routine uses the constant op's attribute to get the actual shape.
static RankedTensorType GetRankedTensorTypeForOperand(Value operand) {
  DenseElementsAttr attr;
  if (matchPattern(operand, m_Constant(&attr))) {
    return attr.getType().dyn_cast<RankedTensorType>();
  }
  return operand.getType().dyn_cast<RankedTensorType>();
}

// Returns true if the given `type` is a ranked float tensor type with the
// given `rank`.
static inline bool IsOfRankedFloatTensorType(RankedTensorType type, int rank) {
  return type && type.getRank() == rank &&
         type.getElementType().isa<FloatType>();
}

// Returns true if the given `value` has the specified rank or has unranked
// type.
static inline bool IsOfRankOrUnranked(Value value, int64_t rank) {
  RankedTensorType type = GetRankedTensorTypeForOperand(value);
  return !type || type.getRank() == rank;
}

// Returns true if the given `value` has at least the specified rank or has
// unranked type.
static inline bool HasRankAtLeast(Value value, int64_t rank) {
  RankedTensorType type = GetRankedTensorTypeForOperand(value);
  return !type || type.getRank() >= rank;
}

// Returns true if the given `value` has at most the specified rank or has
// unranked type.
static inline bool HasRankAtMost(Value value, int64_t rank) {
  RankedTensorType type = GetRankedTensorTypeForOperand(value);
  return !type || type.getRank() <= rank;
}

static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
  return dim_or_rank == -1;
}

// Returns the tf.Equal/tf.NotEqual result type given the `x` and `y` inputs.
// If `incompatible_shape_error` is true, reports an error if `x` and `y` have
// incompatible shapes. Otherwise, returns a tensor type with unknown rank.
static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x,
                                 Value y, BoolAttr incompatible_shape_error) {
  auto result_type =
      OpTrait::util::getBroadcastedType(x.getType(), y.getType());
  if (!result_type) {
    if (incompatible_shape_error.getValue()) {
      mlir::emitError(loc, "non-broadcastable operands");
    } else {
      return UnrankedTensorType::get(builder->getI1Type());
    }
  }

  auto ranked_type = result_type.dyn_cast<RankedTensorType>();
  if (!ranked_type) return UnrankedTensorType::get(builder->getI1Type());

  return RankedTensorType::get(ranked_type.getShape(), builder->getI1Type());
}
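
// Illustrative example: for
//   %0 = "tf.Equal"(%x, %y) : (tensor<2x3xf32>, tensor<3xf32>) -> tensor<*xi1>
// the deduced result type is tensor<2x3xi1>. If the operands are not
// broadcastable and `incompatible_shape_error` is false, the deduced type is
// tensor<*xi1>.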

// Returns dimension index for the given TensorFlow axis that supports negative
// indexing.
static int64_t GetDimForAxis(int64_t axis, int64_t rank) {
  return axis >= 0 ? axis : axis + rank;
}
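
// For example, with rank 4, axis -1 maps to dimension 3 and axis 2 stays at 2.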

// Infers output type for reduction ops such as SumOp, MaxOp etc.
// TODO(b/e667204a): Move this logic to shape inference once it supports custom
// inference functions.
static Type InferReductionOpType(Value input, Value reduction_indices,
                                 BoolAttr keep_dims, Builder *builder) {
  Type input_ty = input.getType();
  Type element_ty = getElementTypeOrSelf(input_ty);

  // Output type is unranked if input type is not ranked.
  auto ranked_ty = input_ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty) return UnrankedTensorType::get(element_ty);
  int64_t rank = ranked_ty.getRank();

  DenseIntElementsAttr indices;
  if (!matchPattern(reduction_indices, m_Constant(&indices))) {
    // Output type is unranked if reduction indices are not constant and reduced
    // dimensions are not kept.
    if (!keep_dims.getValue()) return UnrankedTensorType::get(element_ty);

    // Otherwise, output type has same rank as the input.
    return RankedTensorType::get(SmallVector<int64_t, 4>(rank, -1), element_ty);
  }

  int64_t num_reduce_dim = 0;
  llvm::SmallVector<bool, 4> is_reduce_dim(rank, false);
  for (const APInt &index : indices.getValues<APInt>()) {
    int64_t dim = GetDimForAxis(index.getSExtValue(), rank);
    // Invalid input.
    if (dim < 0 || dim >= rank) return UnrankedTensorType::get(element_ty);

    if (!is_reduce_dim[dim]) {
      is_reduce_dim[dim] = true;
      num_reduce_dim++;
    }
  }

  ArrayRef<int64_t> shape = ranked_ty.getShape();
  SmallVector<int64_t, 4> out_shape;
  out_shape.reserve(rank - (keep_dims.getValue() ? 0 : num_reduce_dim));
  for (int64_t i = 0; i < rank; ++i) {
    if (!is_reduce_dim[i])
      out_shape.push_back(shape[i]);
    else if (keep_dims.getValue())
      out_shape.push_back(1);
  }
  return RankedTensorType::get(out_shape, element_ty);
}
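
// Illustrative example: for an input of type tensor<2x3x4xf32> with constant
// reduction_indices [1], the inferred type is tensor<2x4xf32> when keep_dims
// is false and tensor<2x1x4xf32> when keep_dims is true.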

// Returns the equivalent Value skipping through identity nodes.
Value LookThroughIdentity(Value result) {
  while (isa_and_nonnull<IdentityOp, IdentityNOp>(result.getDefiningOp())) {
    auto op_result = result.cast<OpResult>();
    result = op_result.getOwner()->getOperand(op_result.getResultNumber());
  }
  return result;
}

// Verifies that the given types are cast compatible. If not, emits an
// appropriate error for the given op. If mask_one_dim is set to true, then the
// types are allowed to have one mismatching dimension. Masking one of the
// dimensions is useful for ops like Concat that require all ranked inputs to
// have the same rank and matching dimension sizes for all but one of the
// dimensions.
static LogicalResult VerifyTypesCompatibility(
    Operation::operand_type_range types, bool mask_one_dim, Operation *op) {
  constexpr int64_t kUninitialized = -1;
  int64_t common_rank = kUninitialized;
  llvm::SmallVector<int64_t, 4> common_dims;
  int64_t dim_to_mask = kUninitialized;

  // Initialize common_rank with rank of the first ranked type and verify that
  // following ranked types have the same rank.
  // Similarly, initialize each of the dimensions with the first type that has
  // the dimension size available and verify that all following types have the
  // same size for the dimension. However, if mask_one_dim is true, note down
  // the dimension index on the first mismatch and ignore the dimension at that
  // index in following types.
  for (Type ty : types) {
    RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
    if (!ranked_ty) continue;

    int64_t rank = ranked_ty.getRank();
    if (common_rank == kUninitialized) {
      common_rank = rank;
      common_dims.resize(common_rank, kUninitialized);
    } else if (common_rank != rank) {
      return op->emitError()
             << "operand type " << ranked_ty
             << " is not compatible with preceding operands; expected rank: "
             << common_rank;
    }

    for (int64_t i = 0, e = common_rank; i != e; i++) {
      if (i == dim_to_mask) continue;

      int64_t dim = ranked_ty.getDimSize(i);
      if (dim == kUninitialized) continue;

      int64_t &common_dim = common_dims[i];
      if (common_dim == kUninitialized) {
        common_dim = dim;
      } else if (common_dim != dim) {
        // If mask_one_dim is true, do not emit an error if this is the only
        // dimension with mismatches. Note down the dimension to mask it from
        // the following types.
        if (mask_one_dim && dim_to_mask == kUninitialized) {
          dim_to_mask = i;
          continue;
        }

        return op->emitError() << "operand type " << ranked_ty
                               << " is not compatible with preceding operands; "
                                  "expected dimension at index "
                               << i << ": " << common_dim;
      }
    }
  }
  return success();
}
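
// Illustrative example: with mask_one_dim = true, the operand types
// tensor<2x3xf32>, tensor<2x5xf32> and tensor<2x7xf32> are accepted (the
// mismatch in dimension 1 is masked), while an additional operand of type
// tensor<4x3xf32> would be rejected because dimension 0 no longer matches.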

//===----------------------------------------------------------------------===//
// Helper functions to detect device capabilities from RuntimeDevices.
//===----------------------------------------------------------------------===//

namespace {
using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName;

bool IsGpuDevice(const DeviceNameUtils::ParsedName &device) {
  return device.type == ::tensorflow::DEVICE_GPU;
}

}  // namespace

// Returns true if at least one GPU device is available at runtime.
bool CanUseGpuDevice(const RuntimeDevices &devices) {
  return llvm::any_of(devices.device_names(), IsGpuDevice);
}

// Returns true if all of the GPUs available at runtime support TensorCores
// (NVIDIA compute capability >= 7.0).
bool CanUseTensorCores(const RuntimeDevices &devices) {
  auto has_tensor_cores = [&](const DeviceNameUtils::ParsedName &device) {
    auto md = devices.GetGpuDeviceMetadata(device);
    return md ? md->cc_major().getInt() >= 7 : false;
  };
  return llvm::all_of(
      llvm::make_filter_range(devices.device_names(), IsGpuDevice),
      has_tensor_cores);
}

// Returns true if the operation does not have explicit device placement that
// would prevent it from running on a GPU device.
bool CanUseGpuDevice(Operation *op) {
  auto device_attr = op->getAttrOfType<StringAttr>("device");
  if (!device_attr || device_attr.getValue().empty()) return true;

  DeviceNameUtils::ParsedName device;
  if (!DeviceNameUtils::ParseFullName(device_attr.getValue().str(), &device))
    return false;

  // We can't use GPU if the operation is explicitly placed on a non-GPU device.
  return !device.has_type || device.type == ::tensorflow::DEVICE_GPU;
}
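
// For example, an op with device = "/job:worker/task:0/device:GPU:1" (or with
// no device attribute at all) can use a GPU, while an op explicitly placed on
// "/device:CPU:0" cannot.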

//===----------------------------------------------------------------------===//
// TF op helper functions to work with layout transformation.
//===----------------------------------------------------------------------===//

SmallVector<int64_t, 4> ReversePermutation(ArrayRef<int64_t> permutation) {
  SmallVector<int64_t, 4> reverse(permutation.size());
  for (size_t i = 0; i < permutation.size(); ++i) {
    reverse[permutation[i]] = i;
  }
  return reverse;
}
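
// For example, the reverse of the NHWC->NCHW permutation {0, 3, 1, 2} is the
// NCHW->NHWC permutation {0, 2, 3, 1}.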

SmallVector<int64_t, 4> GetDataFormatPermutation(StringRef from, StringRef to) {
  if (from == "NHWC" && to == "NCHW") {
    return {0, 3, 1, 2};
  } else if (from == "NCHW" && to == "NHWC") {
    return {0, 2, 3, 1};
  } else {
    return {};
  }
}

// Shuffle elements in the `attr` according to the permutation. The optional
// `inner_size` allows shuffling array attributes created from rank 2 tensors
// on the outer dimension only.
ArrayAttr ShuffleArrayAttr(ArrayAttr attr, ArrayRef<int64_t> permutation,
                           int inner_size = 1) {
  if (attr.empty()) return attr;

  assert(attr.size() % inner_size == 0);
  assert(attr.size() / inner_size == permutation.size());

  SmallVector<Attribute, 8> values{attr.begin(), attr.end()};
  SmallVector<Attribute, 8> shuffled(values.size());

  for (size_t i = 0; i < permutation.size(); ++i) {
    for (size_t j = 0; j < inner_size; ++j) {
      shuffled[i * inner_size + j] = values[permutation[i] * inner_size + j];
    }
  }

  return ArrayAttr::get(attr.getContext(), shuffled);
}
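
// Illustrative example: shuffling the strides attribute [1, 2, 3, 1] with the
// NHWC->NCHW permutation {0, 3, 1, 2} produces [1, 1, 2, 3]. With
// inner_size = 2 the attribute is shuffled in pairs, e.g. explicit paddings
// [0, 0, 1, 1, 2, 2, 0, 0] become [0, 0, 0, 0, 1, 1, 2, 2].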

// Shuffle ranked tensor dimensions according to the permutation.
Type ShuffleRankedTensorType(Type type, ArrayRef<int64_t> permutation) {
  if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
    ArrayRef<int64_t> shape = ranked_type.getShape();
    assert(permutation.size() == shape.size());

    SmallVector<int64_t, 4> new_shape(permutation.size());
    for (size_t i = 0; i < permutation.size(); ++i)
      new_shape[i] = shape[permutation[i]];

    return RankedTensorType::get(new_shape, ranked_type.getElementType());
  }

  return type;
}
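
// For example, shuffling tensor<1x112x112x64xf32> with the NHWC->NCHW
// permutation {0, 3, 1, 2} yields tensor<1x64x112x112xf32>; unranked types are
// returned unchanged.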

// Returns true if the two permutations cancel each other, i.e. applying perm0
// followed by perm1 is the identity permutation.
static bool AreCancellablePermutations(DenseIntElementsAttr perm0,
                                       DenseIntElementsAttr perm1) {
  if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false;
  if (perm0.getNumElements() != perm1.getNumElements()) return false;

  SmallVector<int64_t, 8> perm0_values;
  for (const auto &value : perm0.getIntValues())
    perm0_values.push_back(value.getSExtValue());

  SmallVector<int64_t, 8> perm1_values;
  for (const auto &value : perm1.getIntValues())
    perm1_values.push_back(value.getSExtValue());

  for (int i = 0; i < perm0_values.size(); ++i) {
    if (perm0_values[perm1_values[i]] != i) return false;
  }

  return true;
}
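
// For example, the NCHW->NHWC permutation [0, 2, 3, 1] and the NHWC->NCHW
// permutation [0, 3, 1, 2] cancel each other, so a pair of back-to-back
// transposes with these permutations can be folded away.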

// Default implementation of `LayoutSensitiveInterface::UpdateDataFormat` for
// layout sensitive operations that do not have any additional layout dependent
// attributes besides the `data_format` string.
template <typename Op>
LogicalResult UpdateDataFormat(StringRef data_format, Op *op) {
  auto perm = GetDataFormatPermutation(op->data_format(), data_format);
  if (perm.empty()) return failure();

  // Update data format attribute.
  (*op)->setAttr("data_format",
                 StringAttr::get(op->getContext(), data_format));

  // Update types for all layout sensitive results.
  auto layout_sensitive = cast<LayoutSensitiveInterface>(op->getOperation());
  for (unsigned idx : layout_sensitive.GetLayoutDependentResults()) {
    OpResult result = op->getOperation()->getResult(idx);
    result.setType(ShuffleRankedTensorType(result.getType(), perm));
  }

  return success();
}

// Default implementation for folding operand transpose into the operation.
// See `FoldOperandsTransposeInterface::FoldOperandsPermutation`.
template <typename Op>
LogicalResult FoldOperandsPermutation(
    ArrayRef<int64_t> permutation, Op *op,
    ArrayRef<std::pair<StringRef, ArrayAttr>> shuffle_attrs = {}) {
  MLIRContext *context =
      (*op)->template getParentOfType<ModuleOp>().getContext();

  // We only support NHWC <-> NCHW permutations.
  static constexpr std::array<int64_t, 4> kNchwToNhwc = {0, 2, 3, 1};
  static constexpr std::array<int64_t, 4> kNhwcToNchw = {0, 3, 1, 2};

  // Operation data format after folding `permutation`.
  StringRef target_data_format = [&]() -> StringRef {
    if (op->data_format() == "NHWC" && permutation.equals(kNchwToNhwc)) {
      return "NCHW";  // cancel NCHW->NHWC operand permutation
    } else if (op->data_format() == "NCHW" && permutation.equals(kNhwcToNchw)) {
      return "NHWC";  // cancel NHWC->NCHW operand permutation
    } else {
      return "";
    }
  }();
  if (target_data_format.empty()) return failure();

  // To fold the operand `permutation` into the `op` we need to shuffle all
  // layout dependent attributes and types with a reverse permutation, and
  // change the operation data format to `target_data_format`.
  //
  // Example:
  //   %1 = SomeOp(...)   {data_format = NHWC}
  //   %2 = Transpose(%1) {permutation = NHWC->NCHW}
  //   %3 = Op(%2)        {data_format = NCHW}
  //
  // To bypass %2 we have to shuffle %3's attributes and types from NCHW back
  // to NHWC, which is the reverse of the operand permutation (the function
  // argument).
  auto reverse_permutation =
      GetDataFormatPermutation(op->data_format(), target_data_format);
  if (reverse_permutation.empty()) return failure();

  (*op)->setAttr("data_format", StringAttr::get(context, target_data_format));

  for (auto pair : shuffle_attrs) {
    StringRef attr_name = pair.first;
    ArrayAttr attr_value = pair.second;
    (*op)->setAttr(attr_name,
                   ShuffleArrayAttr(attr_value, reverse_permutation));
  }

  auto fold = cast<FoldOperandsTransposeInterface>(op->getOperation());
  for (unsigned idx : fold.GetLayoutDependentResults()) {
    OpResult result = op->getOperation()->getResult(idx);
    result.setType(
        ShuffleRankedTensorType(result.getType(), reverse_permutation));
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Rewrite Pattern for removing trivial Arithmetic op.
//===----------------------------------------------------------------------===//

namespace {
// Fold Arithmetic Op if one of the operands is a constant known to be an
// Identity (e.g. X+0, X*1, etc.). For commutative operations fold if the
// known identity value is either lhs or rhs.
template <
    typename OpT,
    typename std::enable_if<llvm::is_one_of<
        OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr>
OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
                                        ArrayRef<Attribute> operands) {
  auto lhs_type = arithmetic_op.x().getType().template cast<ShapedType>();
  auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>();
  auto result_type =
      arithmetic_op.getResult().getType().template cast<ShapedType>();

  // We can fold the arithmetic operation only if we can prove that we will not
  // accidentally hide a broadcasting error.
  auto is_valid_broadcasting = [](ShapedType operand_ty, ShapedType identity_ty,
                                  ShapedType result_ty) -> bool {
    // Scalar identity is broadcastable to any operand shape, we only need to
    // check that the operand has the same shape as the result.
    bool scalar_identity = identity_ty.hasRank() && identity_ty.getRank() == 0;
    if (scalar_identity) return operand_ty == result_ty;

    // If identity is not a scalar, we must verify that all shapes are equal
    // and statically known.
    //
    // TODO(ezhulenev): Fold if identity shape is statically known to be
    // broadcastable to the operand shape.
    return operand_ty == result_ty && identity_ty == result_ty &&
           result_ty.hasStaticShape();
  };

  // Check that we have a constant operand on one side (candidate for
  // identity).
  const bool is_commutative =
      (std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
  auto lhs_attr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto rhs_attr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
  if (!rhs_attr && !(is_commutative && lhs_attr)) return {};

  // Mul and Div ops have an identity value of one, while AddV2 and Sub have an
  // identity value of zero.
  const int identity =
      (std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value ||
       std::is_same<OpT, RealDivOp>::value)
          ? 1
          : 0;

  Type element_ty = lhs_type.getElementType();
  Attribute identity_attr;
  if (auto ty = element_ty.template dyn_cast<FloatType>()) {
    identity_attr = FloatAttr::get(ty, static_cast<double>(identity));
  } else if (auto ty = element_ty.template dyn_cast<IntegerType>()) {
    identity_attr = IntegerAttr::get(ty, static_cast<int64_t>(identity));
  } else {
    return {};
  }

  // Fold: Op(Operand, Identity) -> Operand.
  if (rhs_attr && is_valid_broadcasting(lhs_type, rhs_type, result_type)) {
    if (rhs_attr.isSplat() && rhs_attr.getSplatValue() == identity_attr)
      return arithmetic_op.x();
  }

  // Fold: Op(Identity, Operand) -> Operand for commutative operations.
  if (lhs_attr && is_commutative &&
      is_valid_broadcasting(rhs_type, lhs_type, result_type)) {
    if (lhs_attr.isSplat() && lhs_attr.getSplatValue() == identity_attr)
      return arithmetic_op.y();
  }

  return {};
}
}  // namespace
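
// Illustrative example of the folding above (assuming matching static shapes):
//
//   %zero = "tf.Const"() {value = dense<0.0> : tensor<4xf32>}
//   %0 = "tf.AddV2"(%x, %zero) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//
// folds to %x. Mul(%one, %y) folds to %y because Mul is commutative, while
// Sub(%zero, %y) is not folded because Sub is not commutative.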

// Verifies a reduction op's `input` and reduction `dims`.
static LogicalResult VerifyReductionInputAndDims(Value input, Value dims,
                                                 Location loc) {
  auto dims_type = dims.getType().dyn_cast<RankedTensorType>();
  if (!dims_type) return success();
  if (dims_type.getRank() > 1)
    return emitError(loc, "dimensions can only be 0D or 1D tensor");

  auto input_type = input.getType().dyn_cast<RankedTensorType>();
  if (!input_type) return success();
  int64_t rank = input_type.getRank();

  DenseIntElementsAttr dims_attr;
  if (!matchPattern(dims, m_Constant(&dims_attr))) return success();
  for (const auto &dim_pair : llvm::enumerate(dims_attr)) {
    int64_t cur_dim = dim_pair.value().getSExtValue();
    if (cur_dim < -rank || cur_dim >= rank)
      return emitError(loc)
             << dim_pair.index() << "-th dimension should be in the range of [-"
             << rank << ", " << rank << ")";
  }

  return success();
}
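
// For example, for an input of rank 2 the reduction dims [0, -1] pass this
// check, while [2] is rejected with "0-th dimension should be in the range of
// [-2, 2)".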

// A type range with description (in singular form) attached to it.
using TypeRangeWithDesc = std::pair<TypeRange, StringRef>;

LogicalResult VerifyTypeRangesAreCompatible(Operation *op,
                                            TypeRangeWithDesc range0,
                                            TypeRangeWithDesc range1) {
  if (range0.first.size() != range1.first.size()) {
    return op->emitOpError()
           << range0.second << "s (size = " << range0.first.size() << ")"
           << " should have the same number of values as " << range1.second
           << "s (size = " << range1.first.size() << ")";
  }

  for (auto it : llvm::enumerate(llvm::zip(range0.first, range1.first))) {
    int index = it.index();
    Type type0 = std::get<0>(it.value());
    Type type1 = std::get<1>(it.value());
    if (!AreCastCompatible({type0, type1}))
      return op->emitOpError(llvm::formatv(
          "{0} type {1} is incompatible with {2} type {3} at index {4}",
          range0.second, type0, range1.second, type1, index));
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Function control flow canonicalization.
//===----------------------------------------------------------------------===//

// Eliminate attributes that are not needed, but can get attached to Ops
// during import.
template <typename Op>
struct DropAttributes : public OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;

  // Drop the "output_shapes" attribute.
  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    bool found = !!op->removeAttr("output_shapes");
    return success(found);
  }
};

// Helper function to create a TF op while copying all underscore attributes
// from another TF op.
// TODO(jpienaar): This is a workaround until behavior is established.
template <typename OpTy, typename... Args>
OpTy CreateTfOp(RewriterBase &b, Operation *op, Args &&... args) {
  auto ret = b.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
  CopyDeviceAndUnderscoredAttributes(op, ret.getOperation());
  return ret;
}

// Helper function to replace a TF op with another op while copying all
// underscore attributes from the TF op.
// TODO(jpienaar): This is a workaround until behavior is established.
template <typename OpTy, typename... Args>
OpTy ReplaceTfOpWithNewOp(RewriterBase &b, Operation *op, Args &&... args) {
  auto ret = CreateTfOp<OpTy>(b, op, std::forward<Args>(args)...);
  b.replaceOp(op, ret.getOperation()->getResults());
  return ret;
}