• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Lint as: python2, python3
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Utilities for helping test ops."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import numpy as np
23from six.moves import range
24
25
def ConvertBetweenDataFormats(x, data_format_src, data_format_dst):
  """Converts a 4D array between data formats.

  Args:
    x: Array with exactly four dimensions, laid out according to
      `data_format_src`. Must expose a `.shape` attribute.
    data_format_src: Layout of `x`; one of "NHWC", "NCHW", "HWNC", "HWCN".
    data_format_dst: Desired layout; one of the same four strings.

  Returns:
    `x` itself (not a copy) when the two formats are equal; otherwise
    `np.transpose(x, perm)`, where `perm[i]` is the position in
    `data_format_src` of the i-th letter of `data_format_dst`.

  Raises:
    ValueError: If either format is not one of the supported strings, or if
      `x` is not 4D.
  """
  valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN"]
  if data_format_src not in valid_data_formats:
    raise ValueError("data_format_src must be of %s, got %s." %
                     (valid_data_formats, data_format_src))
  if data_format_dst not in valid_data_formats:
    raise ValueError("data_format_dst must be of %s, got %s." %
                     (valid_data_formats, data_format_dst))
  if len(x.shape) != 4:
    # Wrap the shape in a 1-tuple: `x.shape` is itself a tuple, and handing it
    # straight to `%` would spread its elements as separate format arguments,
    # raising TypeError (instead of this ValueError) for most ranks.
    raise ValueError("x must be 4D, got shape %s." % (x.shape,))

  if data_format_src == data_format_dst:
    return x

  # Map each axis letter to its index in the source layout, then read the
  # destination layout through that map to get the transpose permutation.
  dim_map = {d: i for i, d in enumerate(data_format_src)}
  transpose_dims = [dim_map[d] for d in data_format_dst]
  return np.transpose(x, transpose_dims)
45
46
def PermuteDimsBetweenDataFormats(dims, data_format_src, data_format_dst):
  """Get new shape for converting between data formats.

  Args:
    dims: Length-4 sequence of per-axis values (e.g. a shape) ordered
      according to `data_format_src`.
    data_format_src: Layout of `dims`; one of "NHWC", "NCHW", "HWNC", "HWCN".
    data_format_dst: Desired layout; one of the same four strings.

  Returns:
    `dims` itself (the same object) when the two formats are equal;
    otherwise a new list with the entries of `dims` reordered to follow
    `data_format_dst`.

  Raises:
    ValueError: If either format is not one of the supported strings, or if
      `dims` does not have length 4.
  """
  valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN"]
  if data_format_src not in valid_data_formats:
    raise ValueError("data_format_src must be of %s, got %s." %
                     (valid_data_formats, data_format_src))
  if data_format_dst not in valid_data_formats:
    raise ValueError("data_format_dst must be of %s, got %s." %
                     (valid_data_formats, data_format_dst))
  if len(dims) != 4:
    # Wrap in a 1-tuple so a tuple-valued `dims` is formatted as one value
    # rather than being spread across `%` (which would raise TypeError).
    raise ValueError("dims must be of length 4, got %s." % (dims,))

  if data_format_src == data_format_dst:
    return dims

  # Look up where each destination axis letter lives in the source layout.
  dim_map = {d: i for i, d in enumerate(data_format_src)}
  permuted_dims = [dims[dim_map[d]] for d in data_format_dst]
  return permuted_dims
66
67
68_JIT_WARMUP_ITERATIONS = 10
69
70
71def RunWithWarmup(sess, op_to_run, feed_dict, options=None, run_metadata=None):
72  """Runs a graph a few times to ensure that its clusters are compiled."""
73  for _ in range(0, _JIT_WARMUP_ITERATIONS):
74    sess.run(op_to_run, feed_dict, options=options)
75  return sess.run(
76      op_to_run, feed_dict, options=options, run_metadata=run_metadata)
77