2
2
# All rights reserved. Use of this source code is governed by
3
3
# a BSD-style license that can be found in the LICENSE file.
4
4
5
+ from ..mpi import MPI , use_mpi
6
+
5
7
import sys
6
8
import time
7
9
8
10
import warnings
9
11
10
12
from unittest .signals import registerResult
13
+
11
14
from unittest import TestCase
12
15
from unittest import TestResult
13
16
14
17
15
18
class MPITestCase (TestCase ):
16
- """A simple wrapper around the standard TestCase which provides
17
- one extra method to set the communicator.
18
- """
19
+ """A simple wrapper around the standard TestCase which stores the communicator."""
19
20
20
21
def __init__ (self , * args , ** kwargs ):
21
- super (MPITestCase , self ).__init__ (* args , ** kwargs )
22
-
23
- def setComm ( self , comm ) :
24
- self .comm = comm
22
+ super ().__init__ (* args , ** kwargs )
23
+ self . comm = None
24
+ if use_mpi :
25
+ self .comm = MPI . COMM_WORLD
25
26
26
27
27
28
class MPITestResult (TestResult ):
28
29
"""A test result class that can print formatted text results to a stream.
29
30
30
31
The actions needed are coordinated across all processes.
31
32
32
- Used by MPITestRunner.
33
33
"""
34
34
35
35
separator1 = "=" * 70
36
36
separator2 = "-" * 70
37
37
38
- def __init__ (self , comm , stream = None , descriptions = None , verbosity = None , ** kwargs ):
39
- super (MPITestResult , self ).__init__ (
38
+ def __init__ (self , stream = None , descriptions = None , verbosity = None , ** kwargs ):
39
+ super ().__init__ (
40
40
stream = stream , descriptions = descriptions , verbosity = verbosity , ** kwargs
41
41
)
42
- self .comm = comm
42
+ self .comm = None
43
+ if use_mpi :
44
+ self .comm = MPI .COMM_WORLD
43
45
self .stream = stream
44
46
self .descriptions = descriptions
45
47
self .buffer = False
@@ -53,8 +55,7 @@ def getDescription(self, test):
53
55
return str (test )
54
56
55
57
def startTest (self , test ):
56
- if isinstance (test , MPITestCase ):
57
- test .setComm (self .comm )
58
+ super ().startTest (test )
58
59
self .stream .flush ()
59
60
if self .comm is not None :
60
61
self .comm .barrier ()
@@ -65,11 +66,10 @@ def startTest(self, test):
65
66
self .stream .flush ()
66
67
if self .comm is not None :
67
68
self .comm .barrier ()
68
- super (MPITestResult , self ).startTest (test )
69
69
return
70
70
71
71
def addSuccess (self , test ):
72
- super (MPITestResult , self ).addSuccess (test )
72
+ super ().addSuccess (test )
73
73
if self .comm is None :
74
74
self .stream .write ("ok " )
75
75
else :
@@ -78,7 +78,7 @@ def addSuccess(self, test):
78
78
return
79
79
80
80
def addError (self , test , err ):
81
- super (MPITestResult , self ).addError (test , err )
81
+ super ().addError (test , err )
82
82
if self .comm is None :
83
83
self .stream .write ("error " )
84
84
else :
@@ -87,7 +87,7 @@ def addError(self, test, err):
87
87
return
88
88
89
89
def addFailure (self , test , err ):
90
- super (MPITestResult , self ).addFailure (test , err )
90
+ super ().addFailure (test , err )
91
91
if self .comm is None :
92
92
self .stream .write ("fail " )
93
93
else :
@@ -96,7 +96,7 @@ def addFailure(self, test, err):
96
96
return
97
97
98
98
def addSkip (self , test , reason ):
99
- super (MPITestResult , self ).addSkip (test , reason )
99
+ super ().addSkip (test , reason )
100
100
if self .comm is None :
101
101
self .stream .write ("skipped({}) " .format (reason ))
102
102
else :
@@ -105,7 +105,7 @@ def addSkip(self, test, reason):
105
105
return
106
106
107
107
def addExpectedFailure (self , test , err ):
108
- super (MPITestResult , self ).addExpectedFailure (test , err )
108
+ super ().addExpectedFailure (test , err )
109
109
if self .comm is None :
110
110
self .stream .write ("expected-fail " )
111
111
else :
@@ -114,11 +114,11 @@ def addExpectedFailure(self, test, err):
114
114
return
115
115
116
116
def addUnexpectedSuccess (self , test ):
117
- super (MPITestResult , self ).addUnexpectedSuccess (test )
117
+ super ().addUnexpectedSuccess (test )
118
118
if self .comm is None :
119
- self .stream .writeln ("unexpected-success " )
119
+ self .stream .write ("unexpected-success " )
120
120
else :
121
- self .stream .writeln ("[{}]unexpected-success " .format (self .comm .rank ))
121
+ self .stream .write ("[{}]unexpected-success " .format (self .comm .rank ))
122
122
return
123
123
124
124
def printErrorList (self , flavour , errors ):
@@ -142,15 +142,13 @@ def printErrorList(self, flavour, errors):
142
142
def printErrors (self ):
143
143
if self .comm is None :
144
144
self .stream .writeln ()
145
- self .stream .flush ()
146
145
self .printErrorList ("ERROR" , self .errors )
147
146
self .printErrorList ("FAIL" , self .failures )
148
147
self .stream .flush ()
149
148
else :
150
149
self .comm .barrier ()
151
150
if self .comm .rank == 0 :
152
151
self .stream .writeln ()
153
- self .stream .flush ()
154
152
for p in range (self .comm .size ):
155
153
if p == self .comm .rank :
156
154
self .printErrorList ("ERROR" , self .errors )
@@ -203,15 +201,15 @@ class MPITestRunner(object):
203
201
204
202
resultclass = MPITestResult
205
203
206
- def __init__ (
207
- self , comm , stream = None , descriptions = True , verbosity = 2 , warnings = None
208
- ):
204
+ def __init__ (self , stream = None , descriptions = True , verbosity = 2 , warnings = None ):
209
205
"""Construct a MPITestRunner.
210
206
211
207
Subclasses should accept **kwargs to ensure compatibility as the
212
208
interface changes.
213
209
"""
214
- self .comm = comm
210
+ self .comm = None
211
+ if use_mpi :
212
+ self .comm = MPI .COMM_WORLD
215
213
if stream is None :
216
214
stream = sys .stderr
217
215
self .stream = _WritelnDecorator (stream )
@@ -221,9 +219,7 @@ def __init__(
221
219
222
220
def run (self , test ):
223
221
"Run the given test case or test suite."
224
- result = MPITestResult (
225
- self .comm , self .stream , self .descriptions , self .verbosity
226
- )
222
+ result = MPITestResult (self .stream , self .descriptions , self .verbosity )
227
223
registerResult (result )
228
224
with warnings .catch_warnings ():
229
225
if self .warnings :
0 commit comments