From e271d002b322c79139e71ab1298132699d63aa10 Mon Sep 17 00:00:00 2001 From: "xin.lei" Date: Thu, 12 Jan 2023 13:53:27 +0800 Subject: [PATCH] enable to handle nan value in VectorAssembler --- pom.xml | 26 +++++++++++-------- .../spark/ml/feature/MDLPDiscretizer.scala | 2 +- .../apache/spark/ml/feature/TestHelper.scala | 2 ++ .../ml/feature/ThresholdFinderSuite.scala | 2 +- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/pom.xml b/pom.xml index 6fa3b03..e147095 100644 --- a/pom.xml +++ b/pom.xml @@ -10,19 +10,23 @@ 1.6 UTF-8 2.11 - 2.11.6 + 2.11.8 + 2.4.2 + 2.11 - - org.apache.spark - spark-core_2.11 - 2.1.1 - - - org.apache.spark - spark-mllib_2.11 - 2.1.1 - + + org.apache.spark + spark-core_${scala.spark.version} + ${spark.version} + + + + org.apache.spark + spark-mllib_${scala.spark.version} + ${spark.version} + + org.scala-lang scala-library diff --git a/src/main/scala/org/apache/spark/ml/feature/MDLPDiscretizer.scala b/src/main/scala/org/apache/spark/ml/feature/MDLPDiscretizer.scala index 73ebb4e..dc5d896 100644 --- a/src/main/scala/org/apache/spark/ml/feature/MDLPDiscretizer.scala +++ b/src/main/scala/org/apache/spark/ml/feature/MDLPDiscretizer.scala @@ -260,7 +260,7 @@ object DiscretizerModel extends MLReadable[DiscretizerModel] { .select("splits") .head() val model = new DiscretizerModel(metadata.uid, splits) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/src/test/scala/org/apache/spark/ml/feature/TestHelper.scala b/src/test/scala/org/apache/spark/ml/feature/TestHelper.scala index fd82aa3..97553d1 100644 --- a/src/test/scala/org/apache/spark/ml/feature/TestHelper.scala +++ b/src/test/scala/org/apache/spark/ml/feature/TestHelper.scala @@ -37,6 +37,8 @@ object TestHelper { val featureAssembler = new VectorAssembler() .setInputCols(inputCols) .setOutputCol("features") + .setHandleInvalid("keep") + val processedDf = featureAssembler.transform(dataframe) val discretizer = new MDLPDiscretizer() diff --git a/src/test/scala/org/apache/spark/ml/feature/ThresholdFinderSuite.scala b/src/test/scala/org/apache/spark/ml/feature/ThresholdFinderSuite.scala index 3df75c8..5f693f4 100644 --- a/src/test/scala/org/apache/spark/ml/feature/ThresholdFinderSuite.scala +++ b/src/test/scala/org/apache/spark/ml/feature/ThresholdFinderSuite.scala @@ -1,6 +1,6 @@ package org.apache.spark.ml.feature -import org.apache.spark.mllib.feature.{BucketInfo, FeatureUtils, ThresholdFinder} +import org.apache.spark.mllib.feature.{BucketInfo, ThresholdFinder} import org.apache.spark.sql.SQLContext import org.junit.runner.RunWith import org.scalatest.{BeforeAndAfterAll, FunSuite}