11package sbttestshards
22
3+ import sbttestshards .parsers .FullTestReport
34import sbttestshards .parsers .JUnitReportParser
45
56import java .nio .charset .StandardCharsets
@@ -10,8 +11,19 @@ import scala.util.hashing.MurmurHash3
1011// This trait is open so that users can implement a custom `ShardingAlgorithm` if they'd like
1112trait ShardingAlgorithm {
1213
13- /** Determines whether the specified spec will run on this shard or not. */
14- def shouldRun (specName : String , shardContext : ShardContext ): Boolean
14+ /** Prior test report that can be used by some sharding algorithms to optimize
15+ * balancing tests across different shards.
16+ */
17+ def priorReport : Option [FullTestReport ]
18+
19+ /** Returns the result of whether the specified suite will run on this shard
20+ * or not.
21+ */
22+ def check (suiteName : String , shardContext : ShardContext ): ShardResult
23+
24+ /** Determines whether the specified suite will run on this shard or not. */
25+ def shouldRun (suiteName : String , shardContext : ShardContext ): Boolean =
26+ check(suiteName, shardContext).testShard.contains(shardContext.testShard)
1527}
1628
1729object ShardingAlgorithm {
@@ -21,24 +33,37 @@ object ShardingAlgorithm {
2133 */
2234 final case object SuiteName extends ShardingAlgorithm {
2335
24- override def shouldRun (specName : String , shardContext : ShardContext ): Boolean =
25- MurmurHash3
26- .bytesHash(specName.getBytes(StandardCharsets .UTF_8 ))
27- .abs % shardContext.testShardCount == shardContext.testShard
36+ def check (suiteName : String , shardContext : ShardContext ): ShardResult = {
37+ val testShard = MurmurHash3
38+ .bytesHash(suiteName.getBytes(StandardCharsets .UTF_8 ))
39+ .abs % shardContext.testShardCount
40+
41+ ShardResult (Some (testShard))
42+ }
43+
44+ def priorReport : Option [FullTestReport ] = None
2845 }
2946
3047 /** Will always mark the test to run on this shard. Useful for debugging or
3148 * for fallback algorithms.
3249 */
3350 final case object Always extends ShardingAlgorithm {
34- override def shouldRun (specName : String , shardContext : ShardContext ): Boolean = true
51+
52+ def check (suiteName : String , shardContext : ShardContext ): ShardResult =
53+ ShardResult (Some (shardContext.testShard))
54+
55+ def priorReport : Option [FullTestReport ] = None
3556 }
3657
3758 /** Will never mark the test to run on this shard. Useful for debugging or for
3859 * fallback algorithms.
3960 */
4061 final case object Never extends ShardingAlgorithm {
41- override def shouldRun (specName : String , shardContext : ShardContext ): Boolean = false
62+
63+ def check (suiteName : String , shardContext : ShardContext ): ShardResult =
64+ ShardResult (None )
65+
66+ def priorReport : Option [FullTestReport ] = None
4267 }
4368
4469 object Balance {
@@ -47,28 +72,33 @@ object ShardingAlgorithm {
4772 reportDirectories : Seq [Path ],
4873 shardsInfo : ShardingInfo ,
4974 fallbackShardingAlgorithm : ShardingAlgorithm = ShardingAlgorithm .SuiteName
50- ): Balance =
75+ ): Balance = {
76+ val priorReport = JUnitReportParser .parseDirectoriesRecursively(reportDirectories)
77+
5178 ShardingAlgorithm .Balance (
52- JUnitReportParser .parseDirectoriesRecursively(reportDirectories) .testReports.map { r =>
53- SpecInfo (r.name, Some (Duration .ofMillis((r.timeTaken * 1000 ).toLong)))
79+ priorReport .testReports.map { r =>
80+ SuiteInfo (r.name, Some (Duration .ofMillis((r.timeTaken * 1000 ).toLong)))
5481 },
5582 shardsInfo,
56- fallbackShardingAlgorithm
83+ fallbackShardingAlgorithm,
84+ Some (priorReport)
5785 )
86+ }
5887 }
5988
6089 /** Attempts to balance the shards by execution time so that no one shard
6190 * takes significantly longer to complete than another.
6291 */
6392 final case class Balance (
64- specs : Seq [SpecInfo ],
93+ suites : Seq [SuiteInfo ],
6594 shardsInfo : ShardingInfo ,
66- fallbackShardingAlgorithm : ShardingAlgorithm = ShardingAlgorithm .SuiteName
95+ fallbackShardingAlgorithm : ShardingAlgorithm = ShardingAlgorithm .SuiteName ,
96+ priorReport : Option [FullTestReport ] = None
6797 ) extends ShardingAlgorithm {
6898
6999 // TODO: Median might be better here?
70100 private val averageTime : Option [Duration ] = {
71- val allTimeTaken = specs .flatMap(_.timeTaken)
101+ val allTimeTaken = suites .flatMap(_.timeTaken)
72102
73103 allTimeTaken.reduceOption(_.plus(_)).map { d =>
74104 if (allTimeTaken.isEmpty) Duration .ZERO
@@ -80,7 +110,7 @@ object ShardingAlgorithm {
80110 // is NP-complete, there's a lot of room for improvement with other algorithms. Dynamic programming should be
81111 // possible here.
82112 def distributeEvenly : Map [TestSuiteInfoSimple , Int ] = {
83- val allTests = specs
113+ val allTests = suites
84114 .map(t => TestSuiteInfoSimple (t.name, t.timeTaken.getOrElse(averageTime.getOrElse(Duration .ZERO ))))
85115 .sortBy(_.timeTaken)(Orderings .duration.reverse)
86116
@@ -109,10 +139,15 @@ object ShardingAlgorithm {
109139 k.name -> v
110140 }
111141
112- def shouldRun (specName : String , shardContext : ShardContext ): Boolean =
113- bucketMap.get(specName) match {
114- case Some (bucketIndex) => bucketIndex == shardContext.testShard
115- case None => fallbackShardingAlgorithm.shouldRun(specName, shardContext)
142+ def check (suiteName : String , shardContext : ShardContext ): ShardResult =
143+ bucketMap.get(suiteName) match {
144+ case Some (bucketIndex) =>
145+ ShardResult (Some (bucketIndex))
146+
147+ case None =>
148+ shardContext.logger.warn(s " Using fallback algorithm for $suiteName" )
149+
150+ fallbackShardingAlgorithm.check(suiteName, shardContext)
116151 }
117152 }
118153
0 commit comments