Skip to content

Commit 52ea166

Browse files
authored
Merge pull request #3 from reibitto/snapshot/dryRun
Prepare for new release
2 parents 008fcf0 + a078d07 commit 52ea166

File tree

11 files changed

+165
-74
lines changed

11 files changed

+165
-74
lines changed

.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ on:
44
branches:
55
- master
66
- main
7-
- snapshot/
7+
- snapshot/*
88
tags: ["*"]
99
jobs:
1010
publish:

README.md

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ bit easier for you.
2020
Add the following to `project/plugins.sbt`:
2121

2222
```scala
23-
addSbtPlugin("com.github.reibitto" % "sbt-test-shards" % "0.1.0")
23+
addSbtPlugin("com.github.reibitto" % "sbt-test-shards" % "0.2.0")
2424
```
2525

2626
## Configuration
@@ -72,19 +72,33 @@ shardingAlgorithm := ShardingAlgorithm.Balance(
7272
)
7373
```
7474

75-
As you can see, filling this out manually would be tedious. Ideally you'd want to derive
76-
this data structure from a test report. If that's not an option, you could also get away
77-
with only including your slowest test suites in this list and leave the rest to the fallback
78-
sharding algorithm.
75+
As you can see, filling this out manually would be tedious and would require constant maintenance
76+
as you add/remove tests (particularly if the tests are expensive). sbt automatically generates
77+
test report xml files (JUnit-compatible format) when tests are run, and sbt-test-shards can consume
78+
these reports so you don't have to manually manage this yourself. Example usage:
7979

80-
Eventually this plugin will be able to consume test reports itself so that you won't have to
81-
worry about it at all.
80+
```scala
81+
shardingAlgorithm := ShardingAlgorithm.Balance.fromJUnitReports(
82+
Seq(Paths.get(s"path-to-report-files")), // these will usually be located in the `target` folders
83+
shardsInfo = ShardingInfo(testShardCount.value)
84+
)
85+
```
86+
87+
For there to be test reports you have to first run `sbt test` on your entire project. And there's also
88+
the issue that these files won't exist in your CI environment unless you cache/store them somewhere.
89+
I'd recommend storing them remotely somewhere and then pulling them down in CI before running the tests.
90+
And upon successful CI completion, publish the newly generated test reports remotely to keep them up to date.
91+
This can be anywhere such as S3 or even storing them in an artifact as resources and publishing to a private
92+
Maven repo.
8293

8394
### Additional configuration
8495

8596
If you're debugging and want to see logs in CI of which suites are set to run and which
8697
are skipped, you can use `testShardDebug := true`
8798

99+
Also you can run `testDryRun` to see how each suite will be distributed without actually
100+
running all the tests and waiting for them to complete.
101+
88102
## CI Configuration
89103

90104
### GitHub Actions
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package sbttestshards
2+
3+
final case class ShardResult(testShard: Option[Int])

src/main/scala/sbttestshards/ShardingAlgorithm.scala

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package sbttestshards
22

3+
import sbttestshards.parsers.FullTestReport
34
import sbttestshards.parsers.JUnitReportParser
45

56
import java.nio.charset.StandardCharsets
@@ -10,8 +11,19 @@ import scala.util.hashing.MurmurHash3
1011
// This trait is open so that users can implement a custom `ShardingAlgorithm` if they'd like
1112
trait ShardingAlgorithm {
1213

13-
/** Determines whether the specified spec will run on this shard or not. */
14-
def shouldRun(specName: String, shardContext: ShardContext): Boolean
14+
/** Prior test report that can be used by some sharding algorithms to optimize
15+
* balancing tests across different shards.
16+
*/
17+
def priorReport: Option[FullTestReport]
18+
19+
/** Returns the result of whether the specified suite will run on this shard
20+
* or not.
21+
*/
22+
def check(suiteName: String, shardContext: ShardContext): ShardResult
23+
24+
/** Determines whether the specified suite will run on this shard or not. */
25+
def shouldRun(suiteName: String, shardContext: ShardContext): Boolean =
26+
check(suiteName, shardContext).testShard.contains(shardContext.testShard)
1527
}
1628

1729
object ShardingAlgorithm {
@@ -21,24 +33,37 @@ object ShardingAlgorithm {
2133
*/
2234
final case object SuiteName extends ShardingAlgorithm {
2335

24-
override def shouldRun(specName: String, shardContext: ShardContext): Boolean =
25-
MurmurHash3
26-
.bytesHash(specName.getBytes(StandardCharsets.UTF_8))
27-
.abs % shardContext.testShardCount == shardContext.testShard
36+
def check(suiteName: String, shardContext: ShardContext): ShardResult = {
37+
val testShard = MurmurHash3
38+
.bytesHash(suiteName.getBytes(StandardCharsets.UTF_8))
39+
.abs % shardContext.testShardCount
40+
41+
ShardResult(Some(testShard))
42+
}
43+
44+
def priorReport: Option[FullTestReport] = None
2845
}
2946

3047
/** Will always mark the test to run on this shard. Useful for debugging or
3148
* for fallback algorithms.
3249
*/
3350
final case object Always extends ShardingAlgorithm {
34-
override def shouldRun(specName: String, shardContext: ShardContext): Boolean = true
51+
52+
def check(suiteName: String, shardContext: ShardContext): ShardResult =
53+
ShardResult(Some(shardContext.testShard))
54+
55+
def priorReport: Option[FullTestReport] = None
3556
}
3657

3758
/** Will never mark the test to run on this shard. Useful for debugging or for
3859
* fallback algorithms.
3960
*/
4061
final case object Never extends ShardingAlgorithm {
41-
override def shouldRun(specName: String, shardContext: ShardContext): Boolean = false
62+
63+
def check(suiteName: String, shardContext: ShardContext): ShardResult =
64+
ShardResult(None)
65+
66+
def priorReport: Option[FullTestReport] = None
4267
}
4368

4469
object Balance {
@@ -47,28 +72,33 @@ object ShardingAlgorithm {
4772
reportDirectories: Seq[Path],
4873
shardsInfo: ShardingInfo,
4974
fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName
50-
): Balance =
75+
): Balance = {
76+
val priorReport = JUnitReportParser.parseDirectoriesRecursively(reportDirectories)
77+
5178
ShardingAlgorithm.Balance(
52-
JUnitReportParser.parseDirectoriesRecursively(reportDirectories).testReports.map { r =>
53-
SpecInfo(r.name, Some(Duration.ofMillis((r.timeTaken * 1000).toLong)))
79+
priorReport.testReports.map { r =>
80+
SuiteInfo(r.name, Some(Duration.ofMillis((r.timeTaken * 1000).toLong)))
5481
},
5582
shardsInfo,
56-
fallbackShardingAlgorithm
83+
fallbackShardingAlgorithm,
84+
Some(priorReport)
5785
)
86+
}
5887
}
5988

6089
/** Attempts to balance the shards by execution time so that no one shard
6190
* takes significantly longer to complete than another.
6291
*/
6392
final case class Balance(
64-
specs: Seq[SpecInfo],
93+
suites: Seq[SuiteInfo],
6594
shardsInfo: ShardingInfo,
66-
fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName
95+
fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName,
96+
priorReport: Option[FullTestReport] = None
6797
) extends ShardingAlgorithm {
6898

6999
// TODO: Median might be better here?
70100
private val averageTime: Option[Duration] = {
71-
val allTimeTaken = specs.flatMap(_.timeTaken)
101+
val allTimeTaken = suites.flatMap(_.timeTaken)
72102

73103
allTimeTaken.reduceOption(_.plus(_)).map { d =>
74104
if (allTimeTaken.isEmpty) Duration.ZERO
@@ -80,7 +110,7 @@ object ShardingAlgorithm {
80110
// is NP-complete, there's a lot of room for improvement with other algorithms. Dynamic programming should be
81111
// possible here.
82112
def distributeEvenly: Map[TestSuiteInfoSimple, Int] = {
83-
val allTests = specs
113+
val allTests = suites
84114
.map(t => TestSuiteInfoSimple(t.name, t.timeTaken.getOrElse(averageTime.getOrElse(Duration.ZERO))))
85115
.sortBy(_.timeTaken)(Orderings.duration.reverse)
86116

@@ -109,10 +139,15 @@ object ShardingAlgorithm {
109139
k.name -> v
110140
}
111141

112-
def shouldRun(specName: String, shardContext: ShardContext): Boolean =
113-
bucketMap.get(specName) match {
114-
case Some(bucketIndex) => bucketIndex == shardContext.testShard
115-
case None => fallbackShardingAlgorithm.shouldRun(specName, shardContext)
142+
def check(suiteName: String, shardContext: ShardContext): ShardResult =
143+
bucketMap.get(suiteName) match {
144+
case Some(bucketIndex) =>
145+
ShardResult(Some(bucketIndex))
146+
147+
case None =>
148+
shardContext.logger.warn(s"Using fallback algorithm for $suiteName")
149+
150+
fallbackShardingAlgorithm.check(suiteName, shardContext)
116151
}
117152
}
118153

src/main/scala/sbttestshards/SpecBucketItem.scala

Lines changed: 0 additions & 5 deletions
This file was deleted.

src/main/scala/sbttestshards/SpecInfo.scala

Lines changed: 0 additions & 5 deletions
This file was deleted.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package sbttestshards
2+
3+
import java.time.Duration
4+
5+
final case class SuiteBucketItem(name: String, timeTaken: Duration)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package sbttestshards
2+
3+
import java.time.Duration
4+
5+
final case class SuiteInfo(name: String, timeTaken: Option[Duration])

src/main/scala/sbttestshards/TestShardsPlugin.scala

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package sbttestshards
22

33
import sbt.*
44
import sbt.Keys.*
5+
import sbttestshards.parsers.FullTestReport
56

67
object TestShardsPlugin extends AutoPlugin {
78

@@ -10,17 +11,13 @@ object TestShardsPlugin extends AutoPlugin {
1011
val testShardCount = settingKey[Int]("testShardCount")
1112
val shardingAlgorithm = settingKey[ShardingAlgorithm]("shardingAlgorithm")
1213
val testShardDebug = settingKey[Boolean]("testShardDebug")
14+
val testDryRun = inputKey[Unit]("testDryRun")
1315
}
1416

1517
import autoImport.*
1618

1719
override def trigger = allRequirements
1820

19-
def stringConfig(key: String, default: String): String = {
20-
val propertyKey = key.replace('_', '.').toLowerCase
21-
sys.props.get(propertyKey).orElse(sys.env.get(key)).getOrElse(default)
22-
}
23-
2421
override lazy val projectSettings: Seq[Def.Setting[?]] =
2522
Seq(
2623
testShard := stringConfig("TEST_SHARD", "0").toInt,
@@ -30,18 +27,56 @@ object TestShardsPlugin extends AutoPlugin {
3027
Test / testOptions += {
3128
val shardContext = ShardContext(testShard.value, testShardCount.value, sLog.value)
3229

33-
Tests.Filter { specName =>
34-
val isInShard = shardingAlgorithm.value.shouldRun(specName, shardContext)
30+
Tests.Filter { suiteName =>
31+
val isInShard = shardingAlgorithm.value.shouldRun(suiteName, shardContext)
3532

3633
if (testShardDebug.value)
3734
if (isInShard)
38-
sLog.value.info(s"`$specName` set to run on this shard (#${testShard.value}).")
35+
sLog.value.info(s"`$suiteName` set to run on this shard (#${testShard.value}).")
3936
else
40-
sLog.value.warn(s"`$specName` skipped because it will run on another shard.")
37+
sLog.value.warn(s"`$suiteName` skipped because it will run on another shard.")
4138

4239
isInShard
4340
}
41+
},
42+
testDryRun := {
43+
val shardContext = ShardContext(testShard.value, testShardCount.value, sLog.value)
44+
val logger = shardContext.logger
45+
val algorithm = shardingAlgorithm.value
46+
val priorReport = algorithm.priorReport.getOrElse(FullTestReport.empty)
47+
val sbtSuiteNames = (Test / definedTestNames).value.toSet
48+
val missingSuiteNames = sbtSuiteNames diff priorReport.testReports.map(_.name).toSet
49+
50+
val results = priorReport.testReports.map { suiteReport =>
51+
val shardResult = algorithm.check(suiteReport.name, shardContext)
52+
53+
shardResult.testShard -> suiteReport
54+
}.collect { case (Some(shard), report) => shard -> report }
55+
.groupBy(_._1)
56+
57+
results.toSeq.sortBy(_._1).foreach { case (k, v) =>
58+
val totalTime = BigDecimal(v.map(_._2.timeTaken).sum).setScale(3, BigDecimal.RoundingMode.HALF_UP)
59+
60+
logger.info(s"[${moduleName.value}] Shard $k expected to take $totalTime s")
61+
62+
v.map(_._2).foreach { suiteReport =>
63+
logger.info(s"* ${suiteReport.name} = ${suiteReport.timeTaken} s")
64+
}
65+
}
66+
67+
if (missingSuiteNames.nonEmpty) {
68+
logger.warn(s"Detected ${missingSuiteNames.size} suites that don't have a test report")
69+
70+
missingSuiteNames.foreach { s =>
71+
logger.warn(s"- $s")
72+
}
73+
}
4474
}
4575
)
4676

77+
private def stringConfig(key: String, default: String): String = {
78+
val propertyKey = key.replace('_', '.').toLowerCase
79+
sys.props.get(propertyKey).orElse(sys.env.get(key)).getOrElse(default)
80+
}
81+
4782
}

src/main/scala/sbttestshards/parsers/JUnitReportParser.scala

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
package sbttestshards.parsers
22

3-
import java.nio.file.{Files, Path, Paths}
3+
import java.nio.file.Files
4+
import java.nio.file.Path
45
import scala.jdk.CollectionConverters.*
56
import scala.xml.XML
67

7-
final case class FullTestReport(testReports: Seq[SpecTestReport]) {
8-
def specCount: Int = testReports.length
8+
final case class FullTestReport(testReports: Seq[SuiteReport]) {
9+
def suiteCount: Int = testReports.length
910

1011
def testCount: Int = testReports.map(_.testCount).sum
1112

@@ -18,7 +19,11 @@ final case class FullTestReport(testReports: Seq[SpecTestReport]) {
1819
def ++(other: FullTestReport): FullTestReport = FullTestReport(testReports ++ other.testReports)
1920
}
2021

21-
final case class SpecTestReport(
22+
object FullTestReport {
23+
def empty: FullTestReport = FullTestReport(Seq.empty)
24+
}
25+
26+
final case class SuiteReport(
2227
name: String,
2328
testCount: Int,
2429
errorCount: Int,
@@ -66,10 +71,10 @@ object JUnitReportParser {
6671
}
6772
)
6873

69-
def parseReport(reportFile: Path): SpecTestReport = {
74+
def parseReport(reportFile: Path): SuiteReport = {
7075
val xml = XML.loadFile(reportFile.toFile)
7176

72-
val specName = xml \@ "name"
77+
val suiteName = xml \@ "name"
7378
val testCount = (xml \@ "tests").toInt
7479
val errorCount = (xml \@ "errors").toInt
7580
val failureCount = (xml \@ "failures").toInt
@@ -83,15 +88,14 @@ object JUnitReportParser {
8388
val testName = (node \@ "name").trim
8489

8590
Some(testName)
86-
} else {
91+
} else
8792
None
88-
}
8993
}.collect { case Some(testName) =>
9094
testName
9195
}
9296

93-
SpecTestReport(
94-
name = specName,
97+
SuiteReport(
98+
name = suiteName,
9599
testCount = testCount,
96100
errorCount = errorCount,
97101
failureCount = failureCount,

0 commit comments

Comments
 (0)