//! Cosine similarity
#![cfg(feature = "std")]
use crate::counter::Counter;
use crate::{Algorithm, Result};

/// [Cosine similarity] is the cosine of the angle between two vectors.
///
/// This is how many symbols the given strings have in common
/// (counted with multiplicity) divided by the square root of the product
/// of the strings' lengths.
///
/// Two empty inputs are considered identical (similarity `1.0`); if exactly
/// one input is empty the similarity is `0.0`.
///
/// [Cosine similarity]: https://en.wikipedia.org/wiki/Cosine_similarity
#[derive(Default)]
pub struct Cosine {}

impl Algorithm<f64> for Cosine {
    /// Computes the cosine similarity of two symbol sequences.
    ///
    /// The result is a similarity (not a distance) in the range `0.0..=1.0`.
    fn for_iter<C, E>(&self, s1: C, s2: C) -> Result<f64>
    where
        C: Iterator<Item = E>,
        E: Eq + core::hash::Hash,
    {
        let c1 = Counter::from_iter(s1);
        let c2 = Counter::from_iter(s2);
        let n1 = c1.count();
        let n2 = c2.count();
        let res = match (n1, n2) {
            (0, 0) => 1.,
            (_, 0) | (0, _) => 0.,
            (_, _) => {
                let ic = c1.intersect_count(&c2);
                // Multiply in floating point: `n1 * n2` in `usize` can overflow
                // (e.g. two ~65k-symbol inputs on a 32-bit target), which would
                // panic in debug builds and silently wrap in release builds.
                ic as f64 / (n1 as f64 * n2 as f64).sqrt()
            }
        };
        Result {
            abs: res,
            is_distance: false,
            max: 1.,
            // Reuse the counts computed above instead of recounting.
            len1: n1,
            len2: n2,
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::str::cosine;
    use assert2::assert;
    use rstest::rstest;

    fn is_close(a: f64, b: f64) -> bool {
        (a - b).abs() < 1E-5
    }

    #[rstest]
    #[case("", "", 1.)]
    #[case("nelson", "", 0.)]
    #[case("", "neilsen", 0.)]
    #[case("abc", "abc", 1.)]
    #[case("abc", "xyz", 0.)]
    // parity with textdistance
    #[case("test", "text", 3. / 4.)]
    #[case("nelson", "neilsen", 0.771516)]
    fn function_str(#[case] s1: &str, #[case] s2: &str, #[case] exp: f64) {
        let act = cosine(s1, s2);
        let ok = is_close(act, exp);
        assert!(ok, "cosine({}, {}) is {}, not {}", s1, s2, act, exp);
    }
}