diff --git a/build.sbt b/build.sbt
index 8655395..7759e01 100644
--- a/build.sbt
+++ b/build.sbt
@@ -38,6 +38,8 @@
 libraryDependencies += "com.lambdaworks" %% "jacks" % "2.3.3"
 
 libraryDependencies += "org.scalatest" % "scalatest_2.10" % "2.2.1" % "test"
 
+libraryDependencies += "com.databricks" % "spark-csv_2.10" % "1.2.0"
+
 libraryDependencies ++= Seq(
   "org.eclipse.jetty.orbit" % "javax.servlet" % "3.0.0.v201112011016" artifacts Artifact("javax.servlet", "jar", "jar"),
diff --git a/src/main/scala/sampleclean/clean/outlierremoval/OutlierRemovalAlgorithm.scala b/src/main/scala/sampleclean/clean/outlierremoval/OutlierRemovalAlgorithm.scala
new file mode 100644
index 0000000..e17b2ee
--- /dev/null
+++ b/src/main/scala/sampleclean/clean/outlierremoval/OutlierRemovalAlgorithm.scala
@@ -0,0 +1,21 @@
+package sampleclean.clean.outlierremoval
+
+import org.apache.spark.sql.DataFrame
+import sampleclean.api.WorkingSet
+
+/**
+ * Common interface for algorithms that remove outliers from a dataset.
+ *
+ * @author Viraj Mahesh
+ */
+trait OutlierRemovalAlgorithm {
+
+  /**
+   * Removes outliers from a dataset. The algorithm for identifying outliers
+   * depends on the implementing class.
+   *
+   * @param dataFrame The dataset that we are cleaning
+   * @return A new DataFrame with outliers removed
+   */
+  def removeOutliers(dataFrame: DataFrame): DataFrame
+}
diff --git a/src/main/scala/sampleclean/clean/outlierremoval/StdDeviationFilter.scala b/src/main/scala/sampleclean/clean/outlierremoval/StdDeviationFilter.scala
new file mode 100644
index 0000000..a1f7592
--- /dev/null
+++ b/src/main/scala/sampleclean/clean/outlierremoval/StdDeviationFilter.scala
@@ -0,0 +1,78 @@
+package sampleclean.clean.outlierremoval
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, DataFrame}
+import sampleclean.api.SampleCleanContext
+
+/**
+ * Detects outliers based on deviation from the mean.
+ *
+ * @author Viraj Mahesh
+ *
+ * @param scc The sample clean context (not used by the filter itself)
+ * @param maxDev An observation x is an outlier iff: abs(x - mean) > maxDev * stdDev
+ * @param colName The name of the column that will be used to classify a row
+ *                as an outlier
+ */
+class StdDeviationFilter(scc: SampleCleanContext,
+                         maxDev: Double,
+                         colName: String) extends OutlierRemovalAlgorithm with Serializable {
+
+  /**
+   * Convert x to a Double.
+   *
+   * @param x A boxed numeric value taken from a row
+   * @return The value of x as a Double
+   * @throws IllegalArgumentException if x is null or not a number
+   */
+  def toDouble(x: Any): Double = x match {
+    case n: Number => n.doubleValue()
+    case other => throw new IllegalArgumentException(
+      s"Column '$colName' contains a non-numeric value: $other")
+  }
+
+  /**
+   * Removes every row whose value in colName lies more than maxDev standard
+   * deviations from the column mean. Rows with a null value in colName are
+   * ignored when computing the statistics and are not part of the result.
+   *
+   * @param dataFrame The dataset that we are cleaning
+   * @return A new DataFrame with outliers removed
+   * @throws IllegalArgumentException if dataFrame has no column named colName
+   */
+  override def removeOutliers(dataFrame: DataFrame): DataFrame = {
+    val colIdx = dataFrame.columns.indexOf(colName)
+    require(colIdx >= 0, s"Column '$colName' does not exist in the data frame")
+
+    // Null cells cannot be converted to a Double, so leave them out of the statistics
+    val column = dataFrame.col(colName)
+    val univariateData: RDD[Double] =
+      dataFrame.filter(column.isNotNull).map({ case x: Row => toDouble(x(colIdx)) })
+
+    // Calculate the mean and standard deviation of the column in a single pass
+    val stats = univariateData.stats()
+    val lowerBound = stats.mean - (maxDev * stats.stdev)
+    val upperBound = stats.mean + (maxDev * stats.stdev)
+
+    // Only keep those rows that are at most maxDev standard deviations from the mean.
+    // The bounds are inclusive so that a value exactly on the boundary (for example
+    // every value of a constant column, where stdDev is 0) is not treated as an outlier.
+    dataFrame.filter(column >= lowerBound && column <= upperBound)
+  }
+}
+
+object StdDeviationFilter {
+  /**
+   * Creates a new StdDeviationFilter and applies it on the dataset.
+   *
+   * @param scc The sample clean context
+   * @param maxDev Maximum number of standard deviations a value may lie from the mean
+   * @param columnName The name of the column used to classify a row as an outlier
+   * @param dataFrame The dataset that we are cleaning
+   * @return A new DataFrame with outliers removed
+   */
+  def removeOutliers(scc: SampleCleanContext, maxDev: Double,
+                     columnName: String, dataFrame: DataFrame): DataFrame = {
+    new StdDeviationFilter(scc, maxDev, columnName).removeOutliers(dataFrame)
+  }
+}
diff --git a/src/test/resources/students.csv b/src/test/resources/students.csv
new file mode 100644
index 0000000..8a06c6c
--- /dev/null
+++ b/src/test/resources/students.csv
@@ -0,0 +1,5 @@
+'A',20,3.00
+'B',22,3.23
+'C',25,4.00
+'D',30,4.25
+'E',70,3.10
\ No newline at end of file
diff --git a/src/test/scala/sampleclean/clean/outlierremoval/StdDeviationFilterTest.scala b/src/test/scala/sampleclean/clean/outlierremoval/StdDeviationFilterTest.scala
new file mode 100644
index 0000000..e2b1936
--- /dev/null
+++ b/src/test/scala/sampleclean/clean/outlierremoval/StdDeviationFilterTest.scala
@@ -0,0 +1,84 @@
+package sampleclean.clean.outlierremoval
+
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{Row, SQLContext}
+import org.scalatest.FunSuite
+import sampleclean.clean.LocalSCContext
+
+/**
+ * Tests for StdDeviationFilter.
+ *
+ * @author Viraj Mahesh
+ */
+class StdDeviationFilterTest extends FunSuite with LocalSCContext with Serializable {
+
+  // Schema of the students table
+  val SCHEMA = StructType(List(
+    StructField("name", StringType, true),
+    StructField("age", IntegerType , true),
+    StructField("gpa", DoubleType, true)))
+
+  // Properties passed into the databricks CSV loader
+  val PROPERTIES = Map("path" -> "./src/test/resources/students.csv")
+
+  test("outlier removal integer column") {
+    withSampleCleanContext { scc =>
+      val sc = scc.getSparkContext()
+      val sqlContext = new SQLContext(sc)
+
+      val data = sqlContext.load("com.databricks.spark.csv", schema = SCHEMA, PROPERTIES)
+      val filteredData = StdDeviationFilter.removeOutliers(scc, 1.0, "age", data)
+
+      // Find the index of the age column and only retain the age column
+      val colIdx = filteredData.columns.indexOf("age")
+      val remainingValues = filteredData.map(x => x.getInt(colIdx)).collect()
+
+      // 70 lies more than one standard deviation from the mean and is dropped
+      assert(remainingValues.sorted.toSeq == Seq(20, 22, 25, 30))
+    }
+  }
+
+  test("outlier removal double column") {
+    withSampleCleanContext { scc =>
+      val sc = scc.getSparkContext()
+      val sqlContext = new SQLContext(sc)
+
+      val data = sqlContext.load("com.databricks.spark.csv", schema = SCHEMA, PROPERTIES)
+      val filteredData = StdDeviationFilter.removeOutliers(scc, 1.0, "gpa", data)
+
+      // Find the index of the GPA column and only retain the gpa column
+      val colIdx = filteredData.columns.indexOf("gpa")
+      val remainingValues = filteredData.map(x => x.getDouble(colIdx)).collect()
+
+      // 3.00 and 4.25 lie more than one standard deviation from the mean and are dropped
+      assert(remainingValues.sorted.toSeq == Seq(3.10, 3.23, 4.00))
+    }
+  }
+
+  test("constant column has no outliers") {
+    withSampleCleanContext { scc =>
+      val sc = scc.getSparkContext()
+      val sqlContext = new SQLContext(sc)
+
+      // Every value equals the mean and the standard deviation is 0
+      val rows = sc.parallelize(Seq(Row("A", 20, 3.0), Row("B", 20, 3.0), Row("C", 20, 3.0)))
+      val data = sqlContext.createDataFrame(rows, SCHEMA)
+      val filteredData = StdDeviationFilter.removeOutliers(scc, 1.0, "age", data)
+
+      assert(filteredData.count() == 3)
+    }
+  }
+
+  test("unknown column is rejected") {
+    withSampleCleanContext { scc =>
+      val sc = scc.getSparkContext()
+      val sqlContext = new SQLContext(sc)
+
+      val data = sqlContext.load("com.databricks.spark.csv", schema = SCHEMA, PROPERTIES)
+
+      intercept[IllegalArgumentException] {
+        StdDeviationFilter.removeOutliers(scc, 1.0, "height", data)
+      }
+    }
+  }
+}