# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for helping test ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

try:
  # Only needed for a lazy `range` on Python 2; on Python 3 the builtin is
  # identical, so the module stays importable when six is not installed.
  from six.moves import range  # pylint: disable=redefined-builtin
except ImportError:
  pass

# Every 4D layout these helpers understand. Kept as a list so the error
# messages render it exactly as before.
_VALID_DATA_FORMATS = ["NHWC", "NCHW", "HWNC", "HWCN"]


def _ValidateDataFormats(data_format_src, data_format_dst):
  """Raises ValueError unless both formats are in _VALID_DATA_FORMATS."""
  if data_format_src not in _VALID_DATA_FORMATS:
    raise ValueError("data_format_src must be of %s, got %s." %
                     (_VALID_DATA_FORMATS, data_format_src))
  if data_format_dst not in _VALID_DATA_FORMATS:
    raise ValueError("data_format_dst must be of %s, got %s." %
                     (_VALID_DATA_FORMATS, data_format_dst))


def _SourceAxesForFormat(data_format_src, data_format_dst):
  """Returns, for each axis of data_format_dst, its index in data_format_src."""
  src_axis_of = {dim: axis for axis, dim in enumerate(data_format_src)}
  return [src_axis_of[dim] for dim in data_format_dst]


def ConvertBetweenDataFormats(x, data_format_src, data_format_dst):
  """Converts 4D tensor between data formats.

  Args:
    x: A 4D array (anything with a `.shape` that `np.transpose` accepts).
    data_format_src: Layout of `x`; one of "NHWC", "NCHW", "HWNC", "HWCN".
    data_format_dst: Desired layout; one of the same four strings.

  Returns:
    `x` itself when the formats are equal, otherwise `np.transpose(x, ...)`
    with its axes reordered from `data_format_src` to `data_format_dst`.

  Raises:
    ValueError: If either format is unknown or `x` is not 4D.
  """
  _ValidateDataFormats(data_format_src, data_format_dst)
  if len(x.shape) != 4:
    # Wrap the shape in a 1-tuple: `% x.shape` would unpack it as several
    # format arguments and raise TypeError instead of this ValueError.
    raise ValueError("x must be 4D, got shape %s." % (x.shape,))

  if data_format_src == data_format_dst:
    return x

  return np.transpose(x, _SourceAxesForFormat(data_format_src,
                                              data_format_dst))


def PermuteDimsBetweenDataFormats(dims, data_format_src, data_format_dst):
  """Get new shape for converting between data formats.

  Args:
    dims: A length-4 sequence of sizes laid out as `data_format_src`.
    data_format_src: Layout of `dims`; one of "NHWC", "NCHW", "HWNC", "HWCN".
    data_format_dst: Desired layout; one of the same four strings.

  Returns:
    `dims` unchanged (same object) when the formats are equal, otherwise a new
    list with the sizes reordered into `data_format_dst`.

  Raises:
    ValueError: If either format is unknown or `dims` does not have length 4.
  """
  _ValidateDataFormats(data_format_src, data_format_dst)
  if len(dims) != 4:
    # 1-tuple so a tuple-valued `dims` is formatted, not unpacked.
    raise ValueError("dims must be of length 4, got %s." % (dims,))

  if data_format_src == data_format_dst:
    return dims

  return [dims[axis] for axis in _SourceAxesForFormat(data_format_src,
                                                      data_format_dst)]


_JIT_WARMUP_ITERATIONS = 10


def RunWithWarmup(sess, op_to_run, feed_dict, options=None, run_metadata=None):
  """Runs a graph a few times to ensure that its clusters are compiled.

  Args:
    sess: Object exposing `run(fetches, feed_dict, options=..., ...)`;
      presumably a TF1 `Session` — confirm at call sites.
    op_to_run: Fetches passed through to `sess.run`.
    feed_dict: Feed dictionary passed through to `sess.run`.
    options: Optional run options forwarded to every call.
    run_metadata: Optional metadata object, forwarded only to the final call.

  Returns:
    Whatever the final `sess.run` call returns.
  """
  for _ in range(_JIT_WARMUP_ITERATIONS):
    # Warm-up calls deliberately omit run_metadata so only the measured
    # (final) run is recorded.
    sess.run(op_to_run, feed_dict, options=options)
  return sess.run(
      op_to_run, feed_dict, options=options, run_metadata=run_metadata)