• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 package org.apache.commons.math3.ml.distance;
18 
19 import org.apache.commons.math3.exception.DimensionMismatchException;
20 import org.apache.commons.math3.util.FastMath;
21 import org.apache.commons.math3.util.MathArrays;
22 
23 /**
24  * Calculates the Earh Mover's distance (also known as Wasserstein metric) between two distributions.
25  *
26  * @see <a href="http://en.wikipedia.org/wiki/Earth_mover's_distance">Earth Mover's distance (Wikipedia)</a>
27  *
28  * @since 3.3
29  */
30 public class EarthMoversDistance implements DistanceMeasure {
31 
32     /** Serializable version identifier. */
33     private static final long serialVersionUID = -5406732779747414922L;
34 
35     /** {@inheritDoc} */
compute(double[] a, double[] b)36     public double compute(double[] a, double[] b)
37     throws DimensionMismatchException {
38         MathArrays.checkEqualLength(a, b);
39         double lastDistance = 0;
40         double totalDistance = 0;
41         for (int i = 0; i < a.length; i++) {
42             final double currentDistance = (a[i] + lastDistance) - b[i];
43             totalDistance += FastMath.abs(currentDistance);
44             lastDistance = currentDistance;
45         }
46         return totalDistance;
47     }
48 }
49