Skip to content

Commit

Permalink
Fix AwaitCompletion to yield results during source iteration
Browse files Browse the repository at this point in the history
Merge of PR #505 that closes #502
  • Loading branch information
atifaziz authored Jun 4, 2018
1 parent bad8004 commit 201dbb3
Showing 1 changed file with 194 additions and 141 deletions.
335 changes: 194 additions & 141 deletions MoreLinq/Experimental/Await.cs
Original file line number Diff line number Diff line change
Expand Up @@ -416,45 +416,124 @@ public static IAwaitQuery<TResult> AwaitCompletion<T, TTaskResult, TResult>(

return
AwaitQuery.Create(
options => _(options.MaxConcurrency ?? int.MaxValue,
options => _(options.MaxConcurrency,
options.Scheduler ?? TaskScheduler.Default,
options.PreserveOrder));

IEnumerable<TResult> _(int maxConcurrency, TaskScheduler scheduler, bool ordered)
IEnumerable<TResult> _(int? maxConcurrency, TaskScheduler scheduler, bool ordered)
{
// A separate task will enumerate the source and launch tasks.
// It will post all progress as notices to the collection below.
// A notice is essentially a discriminated union like:
//
// type Notice<'a, 'b> =
// | End
// | Result of (int * 'a * Task<'b>)
// | Error of ExceptionDispatchInfo
//
// Note that BlockingCollection.CompleteAdding is never used to
// to mark the end (which its own notice above) because
// BlockingCollection.Add throws if called after CompleteAdding
// and we want to deliberately tolerate the race condition.

var notices = new BlockingCollection<(Notice, (int, T, Task<TTaskResult>), ExceptionDispatchInfo)>();
var cancellationTokenSource = new CancellationTokenSource();
var cancellationToken = cancellationTokenSource.Token;
var completed = false;

var enumerator =
source.Index()
.Select(e => (e.Key, Item: e.Value, Task: evaluator(e.Value, cancellationToken)))
.GetEnumerator();
var consumerCancellationTokenSource = new CancellationTokenSource();
(Exception, Exception) lastCriticalErrors = default;

void PostNotice(Notice notice,
(int, T, Task<TTaskResult>) item,
Exception error)
{
// If a notice fails to post then assume critical error
// conditions (like low memory), capture the error without
// further allocation of resources and trip the cancellation
// token source used by the main loop waiting on notices.
// Note that only the "last" critical error is reported
// as maintaining a list would incur allocations. The idea
// here is to make a best effort attempt to report any of
// the error conditions that may be occuring, which is still
// better than nothing.

try
{
var edi = error != null
? ExceptionDispatchInfo.Capture(error)
: null;
notices.Add((notice, item, edi));
}
catch (Exception e)
{
// Don't use ExceptionDispatchInfo.Capture here to avoid
// inducing allocations if already under low memory
// conditions.

lastCriticalErrors = (e, error);
consumerCancellationTokenSource.Cancel();
throw;
}
}

var completed = false;
var cancellationTokenSource = new CancellationTokenSource();

var enumerator = source.Index().GetEnumerator();
IDisposable disposable = enumerator; // disables AccessToDisposedClosure warnings

try
{
var cancellationToken = cancellationTokenSource.Token;

// Fire-up a parallel loop to iterate through the source and
// launch tasks, posting a result-notice as each task
// completes and another, an end-notice, when all tasks have
// completed.

Task.Factory.StartNew(
() =>
CollectToAsync(
enumerator,
e => e.Task,
notices,
(e, r) => (Notice.Result, (e.Key, e.Item, e.Task), default),
ex => (Notice.Error, default, ExceptionDispatchInfo.Capture(ex)),
(Notice.End, default, default),
maxConcurrency, cancellationTokenSource),
async () =>
{
try
{
await enumerator.StartAsync(
e => evaluator(e.Value, cancellationToken),
(e, r) => PostNotice(Notice.Result, (e.Key, e.Value, r), default),
() => PostNotice(Notice.End, default, default),
maxConcurrency, cancellationToken);
}
catch (Exception e)
{
PostNotice(Notice.Error, default, e);
}
},
CancellationToken.None,
TaskCreationOptions.DenyChildAttach,
scheduler);

// Remainder here is the main loop that waits for and
// processes notices.

var nextKey = 0;
var holds = ordered ? new List<(int, T, Task<TTaskResult>)>() : null;

foreach (var (kind, result, error) in notices.GetConsumingEnumerable())
using (var notice = notices.GetConsumingEnumerable(consumerCancellationTokenSource.Token)
.GetEnumerator())
while (true)
{
try
{
if (!notice.MoveNext())
break;
}
catch (OperationCanceledException e) when (e.CancellationToken == consumerCancellationTokenSource.Token)
{
var (error1, error2) = lastCriticalErrors;
throw new Exception("One or more critical errors have occurred.",
error2 != null ? new AggregateException(error1, error2)
: new AggregateException(error1));
}

var (kind, result, error) = notice.Current;

if (kind == Notice.Error)
error.Throw();

Expand Down Expand Up @@ -531,149 +610,76 @@ IEnumerable<TResult> _(int maxConcurrency, TaskScheduler scheduler, bool ordered
}
}

enum Notice { Result, Error, End }

static async Task CollectToAsync<T, TResult, TNotice>(
this IEnumerator<T> e,
Func<T, Task<TResult>> taskSelector,
BlockingCollection<TNotice> collection,
Func<T, Task<TResult>, TNotice> completionNoticeSelector,
Func<Exception, TNotice> errorNoticeSelector,
TNotice endNotice,
int maxConcurrency,
CancellationTokenSource cancellationTokenSource)
enum Notice { End, Result, Error }

static async Task StartAsync<T, TResult>(
this IEnumerator<T> enumerator,
Func<T, Task<TResult>> starter,
Action<T, Task<TResult>> onTaskCompletion,
Action onEnd,
int? maxConcurrency,
CancellationToken cancellationToken)
{
Reader<T> reader = null;
if (enumerator == null) throw new ArgumentNullException(nameof(enumerator));
if (starter == null) throw new ArgumentNullException(nameof(starter));
if (onTaskCompletion == null) throw new ArgumentNullException(nameof(onTaskCompletion));
if (onEnd == null) throw new ArgumentNullException(nameof(onEnd));
if (maxConcurrency < 1) throw new ArgumentOutOfRangeException(nameof(maxConcurrency));

try
using (enumerator)
{
reader = new Reader<T>(e);

var cancellationToken = cancellationTokenSource.Token;
var cancellationTaskSource = new TaskCompletionSource<bool>();
cancellationToken.Register(() => cancellationTaskSource.TrySetResult(true));
var pendingCount = 1; // terminator

var tasks = new List<(T Item, Task<TResult> Task)>();

for (var i = 0; i < maxConcurrency; i++)
void OnPendingCompleted()
{
if (!reader.TryRead(out var item))
break;
tasks.Add((item, taskSelector(item)));
if (Interlocked.Decrement(ref pendingCount) == 0)
onEnd();
}

while (tasks.Count > 0)
var concurrencyGate = maxConcurrency is int count
? new ConcurrencyGate(count)
: ConcurrencyGate.Unbounded;

while (enumerator.MoveNext())
{
// Task.WaitAny is synchronous and blocking but allows the
// waiting to be cancelled via a CancellationToken.
// Task.WhenAny can be awaited so it is better since the
// thread won't be blocked and can return to the pool.
// However, it doesn't support cancellation so instead a
// task is built on top of the CancellationToken that
// completes when the CancellationToken trips.
//
// Also, Task.WhenAny returns the task (Task) object that
// completed but task objects may not be unique due to
// caching, e.g.:
//
// async Task<bool> Foo() => true;
// async Task<bool> Bar() => true;
// var foo = Foo();
// var bar = Bar();
// var same = foo.Equals(bar); // == true
//
// In this case, the task returned by Task.WhenAny will
// match `foo` and `bar`:
//
// var done = Task.WhenAny(foo, bar);
//
// Logically speaking, the uniqueness of a task does not
// matter but here it does, especially when Await (the main
// user of CollectAsync) needs to return results ordered.
// Fortunately, we compose our own task on top of the
// original that links each item with the task result and as
// a consequence generate new and unique task objects.

var completedTask = await
Task.WhenAny(tasks.Select(it => (Task) it.Task).Concat(cancellationTaskSource.Task))
.ConfigureAwait(continueOnCapturedContext: false);

if (completedTask == cancellationTaskSource.Task)
try
{
// Cancellation during the wait means the enumeration
// has been stopped by the user so the results of the
// remaining tasks are no longer needed. Those tasks
// should cancel as a result of sharing the same
// cancellation token and provided that they passed it
// on to any downstream asynchronous operations. Either
// way, this loop is done so exit hard here.

return;
await concurrencyGate.EnterAsync(cancellationToken);
}

var i = tasks.FindIndex(it => it.Task.Equals(completedTask));

catch (OperationCanceledException e) when (e.CancellationToken == cancellationToken)
{
var (item, task) = tasks[i];
tasks.RemoveAt(i);
return;
}

// Await the task rather than using its result directly
// to avoid having the task's exception bubble up as
// AggregateException if the task failed.
Interlocked.Increment(ref pendingCount);

collection.Add(completionNoticeSelector(item, task));
}
var item = enumerator.Current;
var task = starter(item);

{
if (reader.TryRead(out var item))
tasks.Add((item, taskSelector(item)));
}
}
// Add a continutation that notifies completion of the task,
// along with the necessary housekeeping, in case it
// completes before maximum concurrency is reached.

collection.Add(endNotice);
}
catch (Exception ex)
{
cancellationTokenSource.Cancel();
collection.Add(errorNoticeSelector(ex));
}
finally
{
reader?.Dispose();
}
#pragma warning disable 4014 // https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/compiler-messages/cs4014

collection.CompleteAdding();
}
task.ContinueWith(cancellationToken: cancellationToken,
continuationOptions: TaskContinuationOptions.ExecuteSynchronously,
scheduler: TaskScheduler.Current,
continuationAction: t =>
{
concurrencyGate.Exit();

sealed class Reader<T> : IDisposable
{
IEnumerator<T> _enumerator;
if (cancellationToken.IsCancellationRequested)
return;

public Reader(IEnumerator<T> enumerator) =>
_enumerator = enumerator;
onTaskCompletion(item, t);
OnPendingCompleted();
});

public bool TryRead(out T item)
{
var ended = false;
if (_enumerator == null || (ended = !_enumerator.MoveNext()))
{
if (ended)
Dispose();
item = default;
return false;
#pragma warning restore 4014
}

item = _enumerator.Current;
return true;
}

public void Dispose()
{
var e = _enumerator;
if (e == null)
return;
_enumerator = null;
e.Dispose();
OnPendingCompleted();
}
}

Expand Down Expand Up @@ -720,6 +726,53 @@ static class TupleComparer<T1, T2, T3>
public static readonly IComparer<(T1, T2, T3)> Item3 =
Comparer<(T1, T2, T3)>.Create((x, y) => Comparer<T3>.Default.Compare(x.Item3, y.Item3));
}

static class CompletedTask
{
#if NET451 || NETSTANDARD1_0

public static readonly Task Instance;

static CompletedTask()
{
var tcs = new TaskCompletionSource<object>();
tcs.SetResult(null);
Instance = tcs.Task;
}

#else

public static readonly Task Instance = Task.CompletedTask;

#endif
}

sealed class ConcurrencyGate
{
public static readonly ConcurrencyGate Unbounded = new ConcurrencyGate();

readonly SemaphoreSlim _semaphore;

ConcurrencyGate(SemaphoreSlim semaphore = null) =>
_semaphore = semaphore;

public ConcurrencyGate(int max) :
this(new SemaphoreSlim(max, max)) {}

public Task EnterAsync(CancellationToken token)
{
if (_semaphore == null)
{
token.ThrowIfCancellationRequested();
return CompletedTask.Instance;
}

return _semaphore.WaitAsync(token);
}

public void Exit() =>
_semaphore?.Release();
}
}
}

Expand Down

0 comments on commit 201dbb3

Please sign in to comment.