@@ -7,20 +7,41 @@ namespace DotNetty.Common.Concurrency
7
7
using System . Collections . Generic ;
8
8
using System . Runtime . CompilerServices ;
9
9
using System . Runtime . ExceptionServices ;
10
+ using System . Threading ;
10
11
using System . Threading . Tasks ;
11
12
using System . Threading . Tasks . Sources ;
12
13
13
14
public abstract class AbstractPromise : IPromise , IValueTaskSource
14
15
{
16
+ struct CompletionData
17
+ {
18
+ public Action < object > Continuation { get ; }
19
+ public object State { get ; }
20
+ public ExecutionContext ExecutionContext { get ; }
21
+ public SynchronizationContext SynchronizationContext { get ; }
22
+
23
+ public CompletionData ( Action < object > continuation , object state , ExecutionContext executionContext , SynchronizationContext synchronizationContext )
24
+ {
25
+ this . Continuation = continuation ;
26
+ this . State = state ;
27
+ this . ExecutionContext = executionContext ;
28
+ this . SynchronizationContext = synchronizationContext ;
29
+ }
30
+ }
31
+
15
32
const short SourceToken = 0 ;
33
+
34
+ static readonly ContextCallback ExecutionContextCallback = Execute ;
35
+ static readonly SendOrPostCallback SyncContextCallbackWithExecutionContext = ExecuteWithExecutionContext ;
36
+ static readonly SendOrPostCallback SyncContextCallback = Execute ;
16
37
17
38
static readonly Exception CanceledException = new OperationCanceledException ( ) ;
18
39
static readonly Exception CompletedNoException = new Exception ( ) ;
19
40
20
41
protected Exception exception ;
21
42
22
43
int callbackCount ;
23
- ( Action < object > , object ) [ ] callbacks ;
44
+ CompletionData [ ] completions ;
24
45
25
46
public bool TryComplete ( ) => this . TryComplete0 ( CompletedNoException ) ;
26
47
@@ -34,7 +55,7 @@ protected virtual bool TryComplete0(Exception exception)
34
55
{
35
56
// Set the exception object to the exception passed in or a sentinel value
36
57
this . exception = exception ;
37
- this . TryExecuteCallbacks ( ) ;
58
+ this . TryExecuteCompletions ( ) ;
38
59
return true ;
39
60
}
40
61
@@ -75,27 +96,31 @@ public virtual void GetResult(short token)
75
96
76
97
public virtual void OnCompleted ( Action < object > continuation , object state , short token , ValueTaskSourceOnCompletedFlags flags )
77
98
{
78
- //todo: context preservation
79
- if ( this . callbacks == null )
99
+ if ( this . completions == null )
80
100
{
81
- this . callbacks = new ( Action < object > , object ) [ 1 ] ;
101
+ this . completions = new CompletionData [ 1 ] ;
82
102
}
83
103
84
104
int newIndex = this . callbackCount ;
85
105
this . callbackCount ++ ;
86
106
87
- if ( newIndex == this . callbacks . Length )
107
+ if ( newIndex == this . completions . Length )
88
108
{
89
- var newArray = new ( Action < object > , object ) [ this . callbacks . Length * 2 ] ;
90
- Array . Copy ( this . callbacks , newArray , this . callbacks . Length ) ;
91
- this . callbacks = newArray ;
109
+ var newArray = new CompletionData [ this . completions . Length * 2 ] ;
110
+ Array . Copy ( this . completions , newArray , this . completions . Length ) ;
111
+ this . completions = newArray ;
92
112
}
93
113
94
- this . callbacks [ newIndex ] = ( continuation , state ) ;
114
+ this . completions [ newIndex ] = new CompletionData (
115
+ continuation ,
116
+ state ,
117
+ ( flags & ValueTaskSourceOnCompletedFlags . FlowExecutionContext ) != 0 ? ExecutionContext . Capture ( ) : null ,
118
+ ( flags & ValueTaskSourceOnCompletedFlags . UseSchedulingContext ) != 0 ? SynchronizationContext . Current : null
119
+ ) ;
95
120
96
121
if ( this . exception != null )
97
122
{
98
- this . TryExecuteCallbacks ( ) ;
123
+ this . TryExecuteCompletions ( ) ;
99
124
}
100
125
}
101
126
@@ -120,9 +145,9 @@ bool IsCompletedOrThrow()
120
145
[ MethodImpl ( MethodImplOptions . NoInlining ) ]
121
146
void ThrowLatchedException ( ) => ExceptionDispatchInfo . Capture ( this . exception ) . Throw ( ) ;
122
147
123
- bool TryExecuteCallbacks ( )
148
+ bool TryExecuteCompletions ( )
124
149
{
125
- if ( this . callbackCount == 0 || this . callbacks == null )
150
+ if ( this . callbackCount == 0 || this . completions == null )
126
151
{
127
152
return false ;
128
153
}
@@ -133,8 +158,8 @@ bool TryExecuteCallbacks()
133
158
{
134
159
try
135
160
{
136
- ( Action < object > callback , object state ) = this . callbacks [ i ] ;
137
- callback ( state ) ;
161
+ CompletionData completion = this . completions [ i ] ;
162
+ ExecuteCompletion ( completion ) ;
138
163
}
139
164
catch ( Exception ex )
140
165
{
@@ -154,15 +179,57 @@ bool TryExecuteCallbacks()
154
179
155
180
throw new AggregateException ( exceptions ) ;
156
181
}
157
-
182
+
158
183
[ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
159
184
protected void ClearCallbacks ( )
160
185
{
161
186
if ( this . callbackCount > 0 )
162
187
{
163
188
this . callbackCount = 0 ;
164
- Array . Clear ( this . callbacks , 0 , this . callbacks . Length ) ;
189
+ Array . Clear ( this . completions , 0 , this . completions . Length ) ;
165
190
}
191
+ }
192
+
193
+ [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
194
+ static void ExecuteCompletion ( CompletionData completion )
195
+ {
196
+ if ( completion . SynchronizationContext == null )
197
+ {
198
+ if ( completion . ExecutionContext == null )
199
+ {
200
+ completion . Continuation ( completion . State ) ;
201
+ }
202
+ else
203
+ {
204
+ //boxing
205
+ ExecutionContext . Run ( completion . ExecutionContext , ExecutionContextCallback , completion ) ;
206
+ }
207
+ }
208
+ else
209
+ {
210
+ if ( completion . ExecutionContext == null )
211
+ {
212
+ //boxing
213
+ completion . SynchronizationContext . Post ( SyncContextCallback , completion ) ;
214
+ }
215
+ else
216
+ {
217
+ //boxing
218
+ completion . SynchronizationContext . Post ( SyncContextCallbackWithExecutionContext , completion ) ;
219
+ }
220
+ }
221
+ }
222
+
223
+ static void Execute ( object state )
224
+ {
225
+ CompletionData completion = ( CompletionData ) state ;
226
+ completion . Continuation ( completion . State ) ;
227
+ }
228
+
229
+ static void ExecuteWithExecutionContext ( object state )
230
+ {
231
+ CompletionData completion = ( CompletionData ) state ;
232
+ ExecutionContext . Run ( completion . ExecutionContext , ExecutionContextCallback , state ) ;
166
233
}
167
234
}
168
235
}
0 commit comments