Searched refs:tf_axis (Results 1 – 2 of 2) sorted by relevance
/external/tensorflow/tensorflow/python/kernel_tests/
  gather_op_test.py
    106: tf_axis = constant_op.constant(axis)
    108: gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
    127: gather, [tf_params, tf_indices, tf_axis], gather_grad)

/external/tensorflow/tensorflow/compiler/tf2tensorrt/convert/
  convert_nodes.cc (all hits are inside ConvertAxis(); at line 504, tf_axis is a function argument)
    504: Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name,
    508: if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) {
    510: "Axis value of ", tf_axis, " is out of bounds, must be in range [",
    514: if (tf_axis < 0) tf_axis += tf_nb_dims;
    516: if (tf_axis == 0) {
    522: *trt_axis = tf_axis - 1;
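The ConvertAxis() hits outline how a TensorFlow axis index becomes a TensorRT axis index: bounds-check it, wrap a negative value, special-case axis 0, then subtract one. The following is a hypothetical Python sketch of that logic (the original is C++). It assumes tf_nb_dims = trt_nb_dims + 1, i.e. that the TensorFlow tensor carries one leading implicit batch dimension that TensorRT does not see; that definition is not among the hits. It also assumes the axis-0 branch rejects the batch dimension with an error. The exception types and the axis-0 message are placeholders, since the search results do not show what the C++ code returns.

```python
def convert_axis(tf_axis: int, trt_nb_dims: int) -> int:
    """Hypothetical sketch: map a TensorFlow axis to a TensorRT axis."""
    # Assumption: TF tensor = TRT tensor plus one leading implicit batch dim.
    tf_nb_dims = trt_nb_dims + 1

    # Bounds check (cf. line 508): valid range is [-tf_nb_dims, tf_nb_dims).
    if tf_axis < -tf_nb_dims or tf_axis >= tf_nb_dims:
        raise ValueError(
            f"Axis value of {tf_axis} is out of bounds, must be in range "
            f"[{-tf_nb_dims}, {tf_nb_dims})"
        )

    # Wrap a negative axis into the non-negative range (cf. line 514).
    if tf_axis < 0:
        tf_axis += tf_nb_dims

    # Axis 0 is the batch dimension (cf. line 516); assumed to be rejected.
    # The message here is a placeholder, not the TensorFlow error text.
    if tf_axis == 0:
        raise NotImplementedError("axis refers to the implicit batch dimension")

    # Drop the batch dimension to get the TensorRT axis (cf. line 522).
    return tf_axis - 1
```

For example, with a 4-D NCHW TensorFlow tensor (trt_nb_dims = 3), axis 1 maps to TensorRT axis 0, axis -1 wraps to 3 and maps to 2, and axis 4 is out of bounds.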