diff --git a/build.sbt b/build.sbt
index b8e7638..6f6a337 100644
--- a/build.sbt
+++ b/build.sbt
@@ -8,18 +8,16 @@ lazy val root = Project("spark-knn", file(".")).
 lazy val core = knnProject("spark-knn-core").
   settings(
     name := "spark-knn",
-    spName := "saurfang/spark-knn",
     credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials"),
     licenses += "Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0")
   ).
   settings(Dependencies.core).
   settings(
-    scalafixDependencies in ThisBuild += "org.scalatest" %% "autofix" % "3.1.0.0",
+    ThisBuild / scalafixDependencies += "org.scalatest" %% "autofix" % "3.1.0.1",
     addCompilerPlugin(scalafixSemanticdb) // enable SemanticDB
   )
 
 lazy val examples = knnProject("spark-knn-examples").
   dependsOn(core).
   settings(fork in run := true, coverageExcludedPackages := ".*examples.*").
-  settings(Dependencies.examples).
-  settings(SparkSubmit.settings: _*)
+  settings(Dependencies.examples)
diff --git a/project/Common.scala b/project/Common.scala
index b9228fb..a92eb0a 100644
--- a/project/Common.scala
+++ b/project/Common.scala
@@ -2,7 +2,6 @@ import com.typesafe.sbt.GitVersioning
 import sbt._
 import Keys._
 import com.typesafe.sbt.GitPlugin.autoImport._
-import sbtsparkpackage.SparkPackagePlugin.autoImport._
 
 import scala.language.experimental.macros
 import scala.reflect.macros.Context
@@ -10,15 +9,12 @@ import scala.reflect.macros.Context
 object Common {
   val commonSettings = Seq(
     organization in ThisBuild := "com.github.saurfang",
-    javacOptions ++= Seq("-source", "1.8", "-target", "1.8"),
-    scalacOptions ++= Seq("-target:jvm-1.8", "-deprecation", "-feature"),
+    javacOptions ++= Seq("-source", "11", "-target", "11"),
+    scalacOptions ++= Seq("-deprecation", "-feature"),
     //git.useGitDescribe := true,
     git.baseVersion := "0.0.1",
     parallelExecution in test := false,
-    updateOptions := updateOptions.value.withCachedResolution(true),
-    sparkVersion := "3.0.1",
-    sparkComponents += "mllib",
-    spIgnoreProvided := true
+    updateOptions := updateOptions.value.withCachedResolution(true)
   )
 
   def knnProject(path: String): Project = macro knnProjectMacroImpl
diff --git a/project/Dependencies.scala b/project/Dependencies.scala
index e5d9109..dad5613 100644
--- a/project/Dependencies.scala
+++ b/project/Dependencies.scala
@@ -3,16 +3,20 @@ import Keys._
 
 object Dependencies {
   val Versions = Seq(
-    crossScalaVersions := Seq("2.12.8", "2.11.12"),
+    crossScalaVersions := Seq("2.12.18", "2.13.12"),
     scalaVersion := crossScalaVersions.value.head
   )
 
   object Compile {
-    val breeze_natives = "org.scalanlp" %% "breeze-natives" % "1.0" % "provided"
+    val spark_version = "3.4.1"
+    val spark_core = "org.apache.spark" %% "spark-core" % spark_version % "provided"
+    val spark_mllib = "org.apache.spark" %% "spark-mllib" % spark_version % "provided"
+    val breeze = "org.scalanlp" %% "breeze" % "2.1.0" % "provided"
+    val netlib = "com.github.fommil.netlib" % "core" % "1.1.2"
 
     object Test {
-      val scalatest = "org.scalatest" %% "scalatest" % "3.1.0" % "test"
-      val sparktest = "org.apache.spark" %% "spark-core" % "3.0.1" % "test" classifier "tests"
+      val scalatest = "org.scalatest" %% "scalatest" % "3.2.17" % "test"
+      val sparktest = "org.apache.spark" %% "spark-core" % spark_version % "test" classifier "tests"
     }
   }
 
@@ -20,6 +24,6 @@
   import Test._
 
   val l = libraryDependencies
-  val core = l ++= Seq(scalatest, sparktest)
-  val examples = core +: (l ++= Seq(breeze_natives))
+  val core = l ++= Seq(spark_core, spark_mllib, scalatest, sparktest)
+  val examples = core +: (l ++= Seq(breeze, netlib))
 }
diff --git a/project/SparkSubmit.scala b/project/SparkSubmit.scala
deleted file mode 100644
index 45ac23b..0000000
--- a/project/SparkSubmit.scala
+++ /dev/null
@@ -1,25 +0,0 @@
-import sbtsparksubmit.SparkSubmitPlugin.autoImport._
-
-object SparkSubmit {
-  lazy val settings =
-    SparkSubmitSetting(
-      SparkSubmitSetting("sparkMNIST",
-        Seq(
-          "--master", "local[3]",
-          "--class", "com.github.saurfang.spark.ml.knn.examples.MNIST"
-        )
-      ),
-      SparkSubmitSetting("sparkMNISTCross",
-        Seq(
-          "--master", "local[3]",
-          "--class", "com.github.saurfang.spark.ml.knn.examples.MNISTCrossValidation"
-        )
-      ),
-      SparkSubmitSetting("sparkMNISTBench",
-        Seq(
-          "--master", "local[3]",
-          "--class", "com.github.saurfang.spark.ml.knn.examples.MNISTBenchmark"
-        )
-      )
-    )
-}
diff --git a/project/build.properties b/project/build.properties
index 07d9935..331a838 100644
--- a/project/build.properties
+++ b/project/build.properties
@@ -1 +1 @@
-sbt.version = 0.13.18
\ No newline at end of file
+sbt.version = 1.9.7
\ No newline at end of file
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 75d074a..822a09d 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -1,20 +1,6 @@
-addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.3")
-
-addSbtPlugin("me.lessis" % "bintray-sbt" % "0.3.0")
-
-addSbtPlugin("com.typesafe.sbt" % "sbt-git" % "0.8.5")
-
-addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.3")
-
-addSbtPlugin("com.github.saurfang" % "sbt-spark-submit" % "0.0.4")
-
-addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.0")
-
-addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0"
-  excludeAll ExclusionRule(organization = "com.danieltrinh"))
-libraryDependencies += "org.scalariform" %% "scalariform" % "0.1.8"
-
-resolvers += "bintray-spark-packages" at "https://dl.bintray.com/spark-packages/maven/"
-addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.6")
-
-addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.9.4")
+addSbtPlugin("com.github.sbt" % "sbt-release" % "1.1.0")
+addSbtPlugin("com.typesafe.sbt" % "sbt-git" % "1.0.2")
+addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.1.4")
+addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2")
+addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.0.9")
+addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.11.1")
diff --git a/spark-knn-core/src/main/scala/org/apache/spark/ml/classification/KNNClassifier.scala b/spark-knn-core/src/main/scala/org/apache/spark/ml/classification/KNNClassifier.scala
index aedb43f..313728d 100644
--- a/spark-knn-core/src/main/scala/org/apache/spark/ml/classification/KNNClassifier.scala
+++ b/spark-knn-core/src/main/scala/org/apache/spark/ml/classification/KNNClassifier.scala
@@ -7,11 +7,13 @@ import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.stat.MultiClassSummarizer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.types.{DoubleType, StructType}
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.SparkException
+import org.apache.spark.sql.functions.col
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -76,6 +78,12 @@ with KNNParams with HasWeightCol {
   /** @group setParam */
   def setSeed(value: Long): this.type = set(seed, value)
 
+  /** Reimplemented because Predictor.extractLabeledPoints was removed in Spark 3.4.0; columns are aliased to LabeledPoint's field names for the implicit encoder. */
+  private def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
+    import dataset.sparkSession.implicits._
+    dataset.select(col($(labelCol)).cast(DoubleType).as("label"), col($(featuresCol)).as("features")).as[LabeledPoint].rdd
+  }
+
   override protected def train(dataset: Dataset[_]): KNNClassificationModel = {
     // Extract columns from data. If dataset is persisted, do not persist oldDataset.
     val instances = extractLabeledPoints(dataset).map {
diff --git a/spark-knn-core/src/main/scala/org/apache/spark/ml/knn/KNN.scala b/spark-knn-core/src/main/scala/org/apache/spark/ml/knn/KNN.scala
index 6ad5bf0..28772b0 100644
--- a/spark-knn-core/src/main/scala/org/apache/spark/ml/knn/KNN.scala
+++ b/spark-knn-core/src/main/scala/org/apache/spark/ml/knn/KNN.scala
@@ -368,6 +368,9 @@ class KNN(override val uid: String) extends Estimator[KNNModel] with KNNParams {
   /** @group setParam */
   def setFeaturesCol(value: String): this.type = set(featuresCol, value)
 
+  /** @group setParam */
+  def setDistanceCol(value: String): this.type = set(distanceCol, value)
+
   /** @group setParam */
   def setK(value: Int): this.type = set(k, value)
 
diff --git a/spark-knn-examples/src/main/scala/org/apache/spark/ml/classification/NaiveKNN.scala b/spark-knn-examples/src/main/scala/org/apache/spark/ml/classification/NaiveKNN.scala
index e967eeb..7003ef3 100644
--- a/spark-knn-examples/src/main/scala/org/apache/spark/ml/classification/NaiveKNN.scala
+++ b/spark-knn-examples/src/main/scala/org/apache/spark/ml/classification/NaiveKNN.scala
@@ -8,11 +8,13 @@ import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
 import org.apache.spark.ml.{Model, Predictor}
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.stat.MultiClassSummarizer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.types.{ArrayType, DoubleType, StructType}
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.mllib.rdd.MLPairRDDFunctions._
+import org.apache.spark.sql.functions.col
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -25,6 +27,12 @@ class NaiveKNNClassifier(override val uid: String, val distanceMetric: DistanceM
 
   override def copy(extra: ParamMap): NaiveKNNClassifier = defaultCopy(extra)
 
+  /** Reimplemented because Predictor.extractLabeledPoints was removed in Spark 3.4.0; columns are aliased to LabeledPoint's field names for the implicit encoder. */
+  private def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
+    import dataset.sparkSession.implicits._
+    dataset.select(col($(labelCol)).cast(DoubleType).as("label"), col($(featuresCol)).as("features")).as[LabeledPoint].rdd
+  }
+
   override protected def train(dataset: Dataset[_]): NaiveKNNClassifierModel = {
     // Extract columns from data. If dataset is persisted, do not persist oldDataset.
     val instances = extractLabeledPoints(dataset).map {
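
Migration notes (the snippets below are illustrative sketches, not code from the patch):

With sbt-spark-package removed, the Spark artifacts are now declared explicitly in Dependencies.scala and marked "provided": they are compiled against but left out of the assembly jar, and the Spark runtime supplies them at submit time. In a standalone build.sbt the equivalent stanza would look roughly like this (same coordinates as the patch):

    val sparkVersion = "3.4.1"

    libraryDependencies ++= Seq(
      // compile against Spark, but let spark-submit provide it at runtime
      "org.apache.spark" %% "spark-core"  % sparkVersion % Provided,
      "org.apache.spark" %% "spark-mllib" % sparkVersion % Provided
    )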
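The deleted SparkSubmit.scala only wrapped spark-submit invocations in sbt tasks (sparkMNIST, sparkMNISTCross, sparkMNISTBench). The same runs can be reproduced directly from the shell once the examples assembly jar is built; the jar path below is a guess that depends on the Scala and project versions actually built:

    spark-submit \
      --master "local[3]" \
      --class com.github.saurfang.spark.ml.knn.examples.MNIST \
      spark-knn-examples/target/scala-2.12/spark-knn-examples-assembly-<version>.jar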
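The reimplemented extractLabeledPoints decodes rows through the implicit case-class Encoder for LabeledPoint, which resolves columns by field name ("label", "features") and handles ml.linalg.Vector through its registered UDT. A self-contained sketch of the same mechanics on toy data (the object name is arbitrary):

    import org.apache.spark.ml.feature.LabeledPoint
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.DoubleType

    object ExtractLabeledPointsSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
        import spark.implicits._

        // Toy frame using the default Predictor column names.
        val df = Seq((0.0, Vectors.dense(1.0, 2.0)), (1.0, Vectors.dense(3.0, 4.0)))
          .toDF("label", "features")

        // Same mechanics as the helper in the patch: select and alias the two
        // columns, decode each row into a LabeledPoint, then drop to an RDD.
        val points = df
          .select(col("label").cast(DoubleType).as("label"), col("features").as("features"))
          .as[LabeledPoint]
          .rdd

        points.collect().foreach(println)
        spark.stop()
      }
    }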