Searched refs:tf_axis (Results 1 – 2 of 2) sorted by relevance

/external/tensorflow/tensorflow/python/kernel_tests/
gather_op_test.py
106 tf_axis = constant_op.constant(axis)
108 gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
127 gather, [tf_params, tf_indices, tf_axis], gather_grad)
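The test hits above build the axis as a tensor (`tf_axis = constant_op.constant(axis)`), pass it to `array_ops.gather`, and reference `tf_axis` again at line 127. As a rough illustration of what gathering along an axis computes, here is a minimal pure-Python sketch for 2-D nested lists. `gather_2d` is a hypothetical helper, not a TensorFlow API, and it only handles a 1-D list of indices.

```python
def gather_2d(params, indices, axis):
    """Hypothetical helper: gather rows (axis 0) or columns (axis 1) of a 2-D list."""
    ndims = 2
    if not -ndims <= axis < ndims:
        raise ValueError(f"axis {axis} out of range for a {ndims}-D input")
    if axis < 0:
        axis += ndims  # normalize a negative axis, e.g. -1 -> 1
    if axis == 0:
        # Select whole rows in the order given by indices.
        return [list(params[i]) for i in indices]
    # axis == 1: select the given columns from every row.
    return [[row[i] for i in indices] for row in params]


params = [[1, 2, 3],
          [4, 5, 6]]
print(gather_2d(params, [1, 0], axis=0))  # [[4, 5, 6], [1, 2, 3]]
print(gather_2d(params, [2, 0], axis=1))  # [[3, 1], [6, 4]]
```

For NumPy arrays, `np.take(params, indices, axis=axis)` produces the same values for these inputs.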
/external/tensorflow/tensorflow/compiler/tf2tensorrt/convert/
convert_nodes.cc (all hits are in ConvertAxis(); tf_axis is an argument)
504 Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name,
508 if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) {
510 "Axis value of ", tf_axis, " is out of bounds, must be in range [",
514 if (tf_axis < 0) tf_axis += tf_nb_dims;
516 if (tf_axis == 0) {
522 *trt_axis = tf_axis - 1;
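Read together, the ConvertAxis() hits trace a bounds check, negative-axis normalization, a special case for axis 0, and a shift by one. This suggests a mapping from a TensorFlow axis to a TensorRT axis when the TensorFlow tensor carries an extra leading batch dimension that TensorRT treats implicitly. Below is a minimal Python sketch of that mapping. `convert_axis` is a hypothetical name. The relation `tf_nb_dims = trt_nb_dims + 1` is an assumption, since the line computing `tf_nb_dims` is not among the hits. Raising on axis 0 is also assumed, because the body of the line 516 branch is not shown.

```python
def convert_axis(tf_axis, trt_nb_dims):
    """Hypothetical Python transliteration of the axis mapping suggested by
    the ConvertAxis() hits. Assumes one extra leading (implicit batch)
    dimension on the TensorFlow side that TensorRT does not see."""
    tf_nb_dims = trt_nb_dims + 1  # assumption: implicit batch dimension
    # Bounds check as on line 508: valid axes are [-tf_nb_dims, tf_nb_dims).
    # The message tail is reconstructed from that condition.
    if tf_axis < -tf_nb_dims or tf_axis >= tf_nb_dims:
        raise ValueError(
            f"Axis value of {tf_axis} is out of bounds, must be in range "
            f"[{-tf_nb_dims}, {tf_nb_dims})")
    # Normalize a negative axis, as on line 514.
    if tf_axis < 0:
        tf_axis += tf_nb_dims
    # Axis 0 is the batch dimension (line 516); rejecting it is an assumption.
    if tf_axis == 0:
        raise ValueError("axis 0 (the batch dimension) has no TensorRT counterpart")
    # Drop the batch dimension, as on line 522.
    return tf_axis - 1
```

For example, with `trt_nb_dims = 3` the TensorFlow axes 1, 2, 3 (or -3, -2, -1) map to TensorRT axes 0, 1, 2. Axis 0 and anything outside [-4, 4) raise.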