Skip to content

Commit d0d827a

Browse files
committed
execution and sync context preservation
1 parent 9e5d322 commit d0d827a

File tree

1 file changed

+84
-17
lines changed

1 file changed

+84
-17
lines changed

src/DotNetty.Common/Concurrency/AbstractPromise.cs

Lines changed: 84 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,41 @@ namespace DotNetty.Common.Concurrency
77
using System.Collections.Generic;
88
using System.Runtime.CompilerServices;
99
using System.Runtime.ExceptionServices;
10+
using System.Threading;
1011
using System.Threading.Tasks;
1112
using System.Threading.Tasks.Sources;
1213

1314
public abstract class AbstractPromise : IPromise, IValueTaskSource
1415
{
16+
struct CompletionData
17+
{
18+
public Action<object> Continuation { get; }
19+
public object State { get; }
20+
public ExecutionContext ExecutionContext { get; }
21+
public SynchronizationContext SynchronizationContext { get; }
22+
23+
public CompletionData(Action<object> continuation, object state, ExecutionContext executionContext, SynchronizationContext synchronizationContext)
24+
{
25+
this.Continuation = continuation;
26+
this.State = state;
27+
this.ExecutionContext = executionContext;
28+
this.SynchronizationContext = synchronizationContext;
29+
}
30+
}
31+
1532
const short SourceToken = 0;
33+
34+
static readonly ContextCallback ExecutionContextCallback = Execute;
35+
static readonly SendOrPostCallback SyncContextCallbackWithExecutionContext = ExecuteWithExecutionContext;
36+
static readonly SendOrPostCallback SyncContextCallback = Execute;
1637

1738
static readonly Exception CanceledException = new OperationCanceledException();
1839
static readonly Exception CompletedNoException = new Exception();
1940

2041
protected Exception exception;
2142

2243
int callbackCount;
23-
(Action<object>, object)[] callbacks;
44+
CompletionData[] completions;
2445

2546
public bool TryComplete() => this.TryComplete0(CompletedNoException);
2647

@@ -34,7 +55,7 @@ protected virtual bool TryComplete0(Exception exception)
3455
{
3556
// Set the exception object to the exception passed in or a sentinel value
3657
this.exception = exception;
37-
this.TryExecuteCallbacks();
58+
this.TryExecuteCompletions();
3859
return true;
3960
}
4061

@@ -75,27 +96,31 @@ public virtual void GetResult(short token)
7596

7697
public virtual void OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags)
7798
{
78-
//todo: context preservation
79-
if (this.callbacks == null)
99+
if (this.completions == null)
80100
{
81-
this.callbacks = new (Action<object>, object)[1];
101+
this.completions = new CompletionData[1];
82102
}
83103

84104
int newIndex = this.callbackCount;
85105
this.callbackCount++;
86106

87-
if (newIndex == this.callbacks.Length)
107+
if (newIndex == this.completions.Length)
88108
{
89-
var newArray = new (Action<object>, object)[this.callbacks.Length * 2];
90-
Array.Copy(this.callbacks, newArray, this.callbacks.Length);
91-
this.callbacks = newArray;
109+
var newArray = new CompletionData[this.completions.Length * 2];
110+
Array.Copy(this.completions, newArray, this.completions.Length);
111+
this.completions = newArray;
92112
}
93113

94-
this.callbacks[newIndex] = (continuation, state);
114+
this.completions[newIndex] = new CompletionData(
115+
continuation,
116+
state,
117+
(flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) != 0 ? ExecutionContext.Capture() : null,
118+
(flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) != 0 ? SynchronizationContext.Current : null
119+
);
95120

96121
if (this.exception != null)
97122
{
98-
this.TryExecuteCallbacks();
123+
this.TryExecuteCompletions();
99124
}
100125
}
101126

@@ -120,9 +145,9 @@ bool IsCompletedOrThrow()
120145
[MethodImpl(MethodImplOptions.NoInlining)]
121146
void ThrowLatchedException() => ExceptionDispatchInfo.Capture(this.exception).Throw();
122147

123-
bool TryExecuteCallbacks()
148+
bool TryExecuteCompletions()
124149
{
125-
if (this.callbackCount == 0 || this.callbacks == null)
150+
if (this.callbackCount == 0 || this.completions == null)
126151
{
127152
return false;
128153
}
@@ -133,8 +158,8 @@ bool TryExecuteCallbacks()
133158
{
134159
try
135160
{
136-
(Action<object> callback, object state) = this.callbacks[i];
137-
callback(state);
161+
CompletionData completion = this.completions[i];
162+
ExecuteCompletion(completion);
138163
}
139164
catch (Exception ex)
140165
{
@@ -154,15 +179,57 @@ bool TryExecuteCallbacks()
154179

155180
throw new AggregateException(exceptions);
156181
}
157-
182+
158183
[MethodImpl(MethodImplOptions.AggressiveInlining)]
159184
protected void ClearCallbacks()
160185
{
161186
if (this.callbackCount > 0)
162187
{
163188
this.callbackCount = 0;
164-
Array.Clear(this.callbacks, 0, this.callbacks.Length);
189+
Array.Clear(this.completions, 0, this.completions.Length);
165190
}
191+
}
192+
193+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
194+
static void ExecuteCompletion(CompletionData completion)
195+
{
196+
if (completion.SynchronizationContext == null)
197+
{
198+
if (completion.ExecutionContext == null)
199+
{
200+
completion.Continuation(completion.State);
201+
}
202+
else
203+
{
204+
//boxing
205+
ExecutionContext.Run(completion.ExecutionContext, ExecutionContextCallback, completion);
206+
}
207+
}
208+
else
209+
{
210+
if (completion.ExecutionContext == null)
211+
{
212+
//boxing
213+
completion.SynchronizationContext.Post(SyncContextCallback, completion);
214+
}
215+
else
216+
{
217+
//boxing
218+
completion.SynchronizationContext.Post(SyncContextCallbackWithExecutionContext, completion);
219+
}
220+
}
221+
}
222+
223+
static void Execute(object state)
224+
{
225+
CompletionData completion = (CompletionData)state;
226+
completion.Continuation(completion.State);
227+
}
228+
229+
static void ExecuteWithExecutionContext(object state)
230+
{
231+
CompletionData completion = (CompletionData)state;
232+
ExecutionContext.Run(completion.ExecutionContext, ExecutionContextCallback, state);
166233
}
167234
}
168235
}

0 commit comments

Comments
 (0)