Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions scripts/utils.sh
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ run_test_suites() {
'++ 2.12.20 OtherTests:testOnly * -- -l SampleDataTest' \
'++ 2.13.16 OtherTests:testOnly * -- -l SampleDataTest' \
'++ 2.12.20 UDFTests:testOnly * -- -l SampleDataTest' \
'++ 2.13.16 UDFTests:testOnly * -- -l SampleDataTest -l com.snowflake.snowpark.UDFPackageTest' \
'++ 2.13.16 UDFTests:testOnly * -- -l SampleDataTest' \
'++ 2.12.20 UDTFTests:testOnly * -- -l SampleDataTest' \
'++ 2.13.16 UDTFTests:testOnly * -- -l SampleDataTest -l com.snowflake.snowpark.UDFPackageTest' \
'++ 2.13.16 UDTFTests:testOnly * -- -l SampleDataTest' \
+SprocTests:test
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.snowflake.snowpark.{
}

import java.io._
import com.snowflake.snowpark.internal.Utils.{ScalaCompatVersion, SnowparkPackageName}
import com.snowflake.snowpark.types.{Geography, Geometry, Variant}
import com.snowflake.snowpark_java.types.InternalUtils
import com.snowflake.snowpark_java.udtf._
Expand All @@ -33,6 +34,10 @@ import scala.collection.{JavaConverters, mutable}
import scala.collection.JavaConverters._

object JavaUtils {
def snowparkScalaCompatVersion(): String = ScalaCompatVersion

def snowparkPackageName(): String = SnowparkPackageName

def session_setJavaAPI(builder: SessionBuilder): SessionBuilder =
builder.setJavaAPI()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.snowflake.snowpark_test;

import com.snowflake.snowpark.SnowparkClientException;
import com.snowflake.snowpark.internal.JavaUtils;
import com.snowflake.snowpark_java.Row;
import com.snowflake.snowpark_java.Session;
import com.snowflake.snowpark_java.StoredProcedure;
Expand Down Expand Up @@ -37,7 +38,9 @@ public void call() {
+ "returns STRING\n"
+ "language scala\n"
+ "runtime_version=2.12\n"
+ "packages=('com.snowflake:snowpark:latest')\n"
+ "packages=('"
+ JavaUtils.snowparkPackageName()
+ ":latest')\n"
+ "handler='Test.run'\n"
+ "as\n"
+ "$$\n"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.snowflake.snowpark_test;

import com.snowflake.snowpark.TestUtils;
import com.snowflake.snowpark.internal.JavaUtils;
import com.snowflake.snowpark_java.Session;

public abstract class UDFTestBase extends TestFunctions {
Expand All @@ -18,7 +19,11 @@ public Session getSession() {
}

protected Session createSession() {
return Session.builder().configFile(defaultProfile).create();
Session newSession = Session.builder().configFile(defaultProfile).create();
if (JavaUtils.snowparkScalaCompatVersion().equals("2.13")) {
newSession.sql("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=true").collect();
}
return newSession;
}

protected void runQuery(String sql) {
Expand Down
12 changes: 3 additions & 9 deletions src/test/scala/com/snowflake/snowpark/TestUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ object TestUtils extends Logging {
sess.removeDependency(snClassDir)
if (usePackages) {
sess.removePackage(Utils.clientPackageName)
sess.addPackage("com.snowflake:snowpark:latest")
sess.addPackage(s"${Utils.SnowparkPackageName}:latest")
} else {
// Replace scoverage classes in classpath because they are instrumented.
sess.addDependency(snClassDir.replace("scoverage-", ""))
Expand All @@ -147,10 +147,7 @@ object TestUtils extends Logging {
classOf[org.scalatest.compatible.Assertion],
classOf[org.scalactic.TripleEquals], // scalactic jar
classOf[io.opentelemetry.exporters.inmemory.InMemorySpanExporter],
classOf[io.opentelemetry.sdk.trace.export.SpanExporter],
classOf[scala.Product],
classOf[scala.reflect.api.Universe],
classOf[scala.tools.nsc.Global])
classOf[io.opentelemetry.sdk.trace.export.SpanExporter])
.flatMap(UDFClassPath.getPathForClass(_))
.foreach(path => {
val file = new File(path)
Expand Down Expand Up @@ -179,10 +176,7 @@ object TestUtils extends Logging {
classOf[BeforeAndAfterAll], // scala test jar
classOf[org.scalactic.TripleEquals], // scalactic jar
classOf[io.opentelemetry.exporters.inmemory.InMemorySpanExporter],
classOf[io.opentelemetry.sdk.trace.export.SpanExporter],
classOf[scala.Product],
classOf[scala.reflect.api.Universe],
classOf[scala.tools.nsc.Global])
classOf[io.opentelemetry.sdk.trace.export.SpanExporter])
.flatMap(UDFClassPath.getPathForClass(_))
.foreach(path => {
val file = new File(path)
Expand Down
19 changes: 14 additions & 5 deletions src/test/scala/com/snowflake/snowpark/UDFClasspathSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package com.snowflake.snowpark
import java.io.File
import java.net.URLClassLoader
import com.snowflake.snowpark.internal.ScalaFunctions._toUdf
import com.snowflake.snowpark.internal.Utils.clientPackageName
import com.snowflake.snowpark.internal.{UDFClassPath, UDXRegistrationHandler, Utils}
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.{atLeastOnce, never, reset, spy, times, verify}
Expand All @@ -15,9 +14,19 @@ import scala.util.Random
@UDFTest
class UDFClasspathSuite extends SNTestBase {

override def beforeAll(): Unit = {
super.beforeAll()
override def beforeAll: Unit = {
super.beforeAll
TestUtils.addDepsToClassPath(session)
if (Utils.ScalaCompatVersion == "2.13") {
session.runQuery("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=true")
}
}

override def afterAll: Unit = {
if (Utils.ScalaCompatVersion == "2.13") {
session.runQuery("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=false")
}
super.afterAll
}

test("Test that jars uploaded to different stages") {
Expand Down Expand Up @@ -97,8 +106,8 @@ class UDFClasspathSuite extends SNTestBase {
val fixedPath = path.replace("scoverage-", "")
// Remove snowpark jar from classpath, The code will catch error and add the path and retry
newSession.removeDependency(fixedPath)
newSession.removePackage("com.snowflake:snowpark:latest")
newSession.removePackage(clientPackageName)
newSession.removePackage(s"${Utils.SnowparkPackageName}:latest")
newSession.removePackage(Utils.clientPackageName)
val func = "func_" + Random.nextInt().abs

ignoreClassNotFoundForScoverageClasses {
Expand Down
14 changes: 10 additions & 4 deletions src/test/scala/com/snowflake/snowpark/UDFInternalSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@ class UDFInternalSuite extends TestData {
TestUtils.addDepsToClassPath(session, Some(stageName))
TestUtils.addDepsToClassPath(newSession, Some(stageName))
}
if (Utils.ScalaCompatVersion == "2.13") {
session.runQuery("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=true")
}
}

override def afterAll: Unit = {
dropStage(stageName)
if (Utils.ScalaCompatVersion == "2.13") {
session.runQuery("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=false")
}
super.afterAll
}

Expand All @@ -46,7 +52,7 @@ class UDFInternalSuite extends TestData {
verify(mockSession, times(1)).addDependency(path)
}
verify(mockSession, times(1)).removeDependency(path)
verify(mockSession, times(1)).addPackage("com.snowflake:snowpark:latest")
verify(mockSession, times(1)).addPackage(s"${Utils.SnowparkPackageName}:latest")
}

test("Test permanent udf not failing back to upload jar", JavaStoredProcExclude) {
Expand Down Expand Up @@ -80,7 +86,7 @@ class UDFInternalSuite extends TestData {
verify(mockSession, times(1)).addDependency(path)
}
verify(mockSession, times(1)).removeDependency(path)
verify(mockSession, times(1)).addPackage("com.snowflake:snowpark:latest")
verify(mockSession, times(1)).addPackage(s"${Utils.SnowparkPackageName}:latest")
}

test("Test add version logic", JavaStoredProcExclude) {
Expand Down Expand Up @@ -189,7 +195,7 @@ class PackageUDFSuite extends UDFSuite {
val snClassDir = UDFClassPath.getPathForClass(classOf[Session]).get
session.removeDependency(snClassDir.replace("scoverage-", ""))
session.removePackage(Utils.clientPackageName)
session.addPackage("com.snowflake:snowpark:latest")
session.addPackage(s"${Utils.SnowparkPackageName}:latest")
}

override def afterAll: Unit = {
Expand All @@ -207,7 +213,7 @@ class PackageUDTFSuite extends UDTFSuite {
val snClassDir = UDFClassPath.getPathForClass(classOf[Session]).get
session.removeDependency(snClassDir.replace("scoverage-", ""))
session.removePackage(Utils.clientPackageName)
session.addPackage("com.snowflake:snowpark:latest")
session.addPackage(s"${Utils.SnowparkPackageName}:latest")
}

override def afterAll: Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import com.snowflake.snowpark.internal.Utils.clientPackageName

import java.io.{BufferedOutputStream, File, FileOutputStream}
import java.nio.file.{Files, NoSuchFileException}
import com.snowflake.snowpark.internal.{JavaUtils, UDFClassPath}
import com.snowflake.snowpark.internal.{JavaUtils, UDFClassPath, Utils}

import scala.reflect.internal.util.BatchSourceFile
import scala.reflect.io.{AbstractFile, VirtualDirectory}
Expand All @@ -17,12 +17,22 @@ class UDFRegistrationSuite extends SNTestBase with FileUtils {
private val tempStage = this.getClass.getSimpleName + Random.alphanumeric.take(5).mkString("")
private val stagePrefix = "udfJar"

override def beforeAll(): Unit = {
override def beforeAll: Unit = {
super.beforeAll()
session.runQuery(s"create or replace temporary stage $tempStage")
if (!isStoredProc(session)) {
TestUtils.addDepsToClassPath(session)
}
if (Utils.ScalaCompatVersion == "2.13") {
session.runQuery("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=true")
}
}

override def afterAll: Unit = {
if (Utils.ScalaCompatVersion == "2.13") {
session.runQuery("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=false")
}
super.afterAll
}

test("Test that jar files are uploaded to stage correctly") {
Expand Down
12 changes: 11 additions & 1 deletion src/test/scala/com/snowflake/snowpark/UDTFInternalSuite.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.snowflake.snowpark

import com.snowflake.snowpark.internal.ScalaFunctions
import com.snowflake.snowpark.internal.{ScalaFunctions, Utils}
import com.snowflake.snowpark.types._
import com.snowflake.snowpark.udtf._

Expand All @@ -13,6 +13,16 @@ class UDTFInternalSuite extends SNTestBase {
if (!isStoredProc(session)) {
TestUtils.addDepsToClassPath(session)
}
if (Utils.ScalaCompatVersion == "2.13") {
session.runQuery("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=true")
}
}

override def afterAll: Unit = {
if (Utils.ScalaCompatVersion == "2.13") {
session.runQuery("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=false")
}
super.afterAll
}

test("Unit test for UDTF0") {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.snowflake.snowpark_test

import com.snowflake.snowpark.TestUtils.{removeFile, writeFile}
import com.snowflake.snowpark.internal.Utils.ScalaCompatVersion
import com.snowflake.snowpark.functions._
import com.snowflake.snowpark._
import net.snowflake.client.jdbc.SnowflakeSQLException
Expand Down Expand Up @@ -30,12 +31,18 @@ class PermanentUDFSuite extends TestData {
// In stored procs mode, there is only one session
TestUtils.addDepsToClassPath(newSession, Some(stageName))
}
if (ScalaCompatVersion == "2.13") {
session.sql("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=true").collect()
}
}

override def afterAll: Unit = {
dropStage(stageName)
removeFile(tempDirectory1.getAbsolutePath, session)
removeFile(tempDirectory2.getAbsolutePath, session)
if (ScalaCompatVersion == "2.13") {
session.sql("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=false").collect()
}
super.afterAll
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,16 @@ class PermanentUDTFSuite extends TestData {
TestUtils.addDepsToClassPath(session, Some(stageName))
TestUtils.addDepsToClassPath(newSession, Some(stageName))
}
if (Utils.ScalaCompatVersion == "2.13") {
session.sql("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=true").collect()
}
}

override def afterAll: Unit = {
dropStage(stageName)
if (Utils.ScalaCompatVersion == "2.13") {
session.sql("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=false").collect()
}
super.afterAll
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.snowflake.snowpark_test

import com.snowflake.snowpark._
import com.snowflake.snowpark.internal.Utils.{ScalaCompatVersion, SnowparkPackageName}
import net.snowflake.client.jdbc.SnowflakeSQLException

import java.sql.{Date, Timestamp}
Expand All @@ -15,10 +16,16 @@ class StoredProcedureSuite extends SNTestBase {
if (!isStoredProc(session)) {
TestUtils.addDepsToClassPath(session, Some(testStage))
}
if (ScalaCompatVersion == "2.13") {
session.sql("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=true").collect()
}
}

override def afterAll: Unit = {
dropStage(testStage)
if (ScalaCompatVersion == "2.13") {
session.sql("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=false").collect()
}
super.afterAll
}

Expand All @@ -29,7 +36,7 @@ class StoredProcedureSuite extends SNTestBase {
|returns STRING
|language scala
|runtime_version=2.12
|packages=('com.snowflake:snowpark:latest')
|packages=('${SnowparkPackageName}:latest')
|handler='Test.run'
|as
|$$$$
Expand All @@ -56,7 +63,7 @@ class StoredProcedureSuite extends SNTestBase {
|returns STRING
|language scala
|runtime_version=2.12
|packages=('com.snowflake:snowpark:latest')
|packages=('${SnowparkPackageName}:latest')
|handler='Test.run'
|as
|$$$$
Expand All @@ -81,7 +88,7 @@ class StoredProcedureSuite extends SNTestBase {
|returns STRING
|language scala
|runtime_version=2.12
|packages=('com.snowflake:snowpark:latest')
|packages=('${SnowparkPackageName}:latest')
|handler='Test.run'
|as
|$$$$
Expand Down
7 changes: 7 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/UDFSuite.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.snowflake.snowpark_test

import com.snowflake.snowpark._
import com.snowflake.snowpark.internal.Utils.ScalaCompatVersion
import com.snowflake.snowpark.functions.{col, _}
import com.snowflake.snowpark.types.{Geography, Geometry, Variant}

Expand Down Expand Up @@ -53,6 +54,9 @@ trait UDFSuite extends TestData {
if (!isStoredProc(session)) {
TestUtils.addDepsToClassPath(session)
}
if (ScalaCompatVersion == "2.13") {
session.sql("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=true").collect()
}
}

override def afterAll: Unit = {
Expand All @@ -63,6 +67,9 @@ trait UDFSuite extends TestData {
dropTable(tableName)
dropTable(semiStructuredTable)
runQuery(s"DROP STAGE IF EXISTS $tmpStageName", session)
if (ScalaCompatVersion == "2.13") {
session.sql("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=false").collect()
}
super.afterAll()
}

Expand Down
7 changes: 7 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,18 @@ class UDTFSuite extends TestData {
.toDF("c1", "c2")
.write
.saveAsTable(wordCountTableName)

if (Utils.ScalaCompatVersion == "2.13") {
session.sql("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=true").collect()
}
}

override def afterAll: Unit = {
dropTable(wordCountTableName)
dropTable(tableName)
if (Utils.ScalaCompatVersion == "2.13") {
session.sql("alter session set ENABLE_SCALA_UDF_RUNTIME_2_13=false").collect()
}
super.afterAll
}

Expand Down
Loading
Loading