Searched refs:grad_out_index (Results 1 – 1 of 1) sorted by relevance
985: int64 grad_out_index = argmax_flat(index);  (local, in launch())
988: grad_out_index += cur_batch * output_size_per_batch;  (in launch())
990: CHECK(grad_out_index >= output_start && grad_out_index < output_end)  (in launch())
991: << "Invalid output gradient index: " << grad_out_index << ", "  (in launch())
993: grad_out_flat(grad_out_index) += grad_in_flat(index);  (in launch())
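Taken together, these hits trace one scatter-add. Each incoming gradient element at `index` is routed to the position recorded in `argmax_flat`, shifted by `cur_batch * output_size_per_batch`, bounds-checked, and accumulated into `grad_out_flat`. The following is a minimal standalone C++ sketch of that pattern, not the original `launch()`. Several details are assumptions made for illustration: the function name `ScatterMaxPoolGrad`, the use of `std::vector`, and the premise that every argmax value is an offset within a single batch item (the lines guarding line 988 do not appear in the results). The sketch also throws `std::out_of_range` where the original uses `CHECK`, and it checks indices against the whole output rather than a shard's `output_start`/`output_end` range.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper: scatter-add pooled gradients (grad_in) back to the
// positions recorded in argmax, producing a gradient of the un-pooled shape
// (grad_out). Assumes each argmax value is an offset *within* one batch item.
std::vector<float> ScatterMaxPoolGrad(const std::vector<float>& grad_in,
                                      const std::vector<int64_t>& argmax,
                                      int64_t batch_size,
                                      int64_t output_size_per_batch) {
  const auto n = static_cast<int64_t>(grad_in.size());
  if (batch_size <= 0 || output_size_per_batch <= 0 ||
      static_cast<int64_t>(argmax.size()) != n || n % batch_size != 0) {
    throw std::invalid_argument("mismatched gradient/argmax shapes");
  }
  const int64_t input_size_per_batch = n / batch_size;

  // Zero-initialized: positions that were never selected as a max get no gradient.
  std::vector<float> grad_out(
      static_cast<std::size_t>(batch_size * output_size_per_batch), 0.0f);
  const int64_t output_start = 0;  // whole tensor treated as one "shard"
  const int64_t output_end = static_cast<int64_t>(grad_out.size());

  for (int64_t index = 0; index < n; ++index) {
    int64_t grad_out_index = argmax[static_cast<std::size_t>(index)];
    // Convert the per-item offset into a flat index over the whole batch.
    const int64_t cur_batch = index / input_size_per_batch;
    grad_out_index += cur_batch * output_size_per_batch;
    // Stand-in for the original CHECK: reject indices outside the output range.
    if (grad_out_index < output_start || grad_out_index >= output_end) {
      throw std::out_of_range("Invalid output gradient index: " +
                              std::to_string(grad_out_index));
    }
    // Accumulate, since several pooled outputs may share one argmax position.
    grad_out[static_cast<std::size_t>(grad_out_index)] +=
        grad_in[static_cast<std::size_t>(index)];
  }
  return grad_out;
}
```

The `+=` (rather than plain assignment) matters when overlapping pooling windows select the same input position: both gradients land on that position and are summed.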