Lines Matching refs:embedding_matrix
135 EmbeddingMatrix *const embedding_matrix = in ConcatEmbeddings() local
137 if (embedding_matrix == nullptr) { in ConcatEmbeddings()
149 if (!GetEmbeddingInternal(feature_vector, embedding_matrix, concat_offset, in ConcatEmbeddings()
160 EmbeddingMatrix *const embedding_matrix = embedding_matrices_[es_index].get(); in GetEmbedding() local
161 if (embedding_matrix == nullptr) { in GetEmbedding()
166 return GetEmbeddingInternal(feature_vector, embedding_matrix, 0, embedding, in GetEmbedding()
172 EmbeddingMatrix *const embedding_matrix, const int concat_offset, in GetEmbeddingInternal() argument
174 const int embedding_dim = embedding_matrix->dim(); in GetEmbeddingInternal()
176 embedding_matrix->quant_type() != QuantizationType::NONE; in GetEmbeddingInternal()
204 embedding_matrix->get_embedding(id, &embedding_data, &multiplier); in GetEmbeddingInternal()
210 embedding_matrix->get_embedding(feature_value, &embedding_data, in GetEmbeddingInternal()