123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174 |
- using Cysharp.Threading.Tasks.Internal;
- using System;
- using System.Threading;
- namespace Cysharp.Threading.Tasks.Linq
- {
public static partial class UniTaskAsyncEnumerable
{
    /// <summary>
    /// Builds an async sequence from a producer delegate. For each enumeration the producer is invoked
    /// with an <see cref="IAsyncWriter{T}"/> used to push elements and with the token that was passed to
    /// GetAsyncEnumerator; the sequence ends when the <see cref="UniTask"/> returned by the producer completes.
    /// </summary>
    /// <typeparam name="T">Element type of the sequence.</typeparam>
    /// <param name="create">Producer delegate; validated with <c>Error.ThrowArgumentNullException</c> (presumably throws when null).</param>
    /// <returns>A lazily started sequence; nothing runs until it is enumerated.</returns>
    public static IUniTaskAsyncEnumerable<T> Create<T>(Func<IAsyncWriter<T>, CancellationToken, UniTask> create)
    {
        Error.ThrowArgumentNullException(create, nameof(create));

        var sequence = new Create<T>(create);
        return sequence;
    }
}

/// <summary>
/// Producer-side handle handed to the delegate passed to <see cref="UniTaskAsyncEnumerable.Create{T}"/>.
/// </summary>
/// <typeparam name="T">Element type of the sequence being produced.</typeparam>
public interface IAsyncWriter<T>
{
    /// <summary>
    /// Publishes <paramref name="value"/> as the next element of the sequence.
    /// </summary>
    /// <param name="value">The element to hand to the consumer.</param>
    /// <returns>
    /// A task the producer should await before producing the following element; it completes once the
    /// consumer requests another element.
    /// </returns>
    UniTask YieldAsync(T value);
}
/// <summary>
/// <see cref="IUniTaskAsyncEnumerable{T}"/> backed by a producer delegate that pushes values through an
/// <see cref="IAsyncWriter{T}"/>. Every call to <see cref="GetAsyncEnumerator"/> creates an independent
/// enumerator that invokes the delegate again on its first MoveNextAsync.
/// </summary>
internal sealed class Create<T> : IUniTaskAsyncEnumerable<T>
{
    readonly Func<IAsyncWriter<T>, CancellationToken, UniTask> create;

    public Create(Func<IAsyncWriter<T>, CancellationToken, UniTask> create)
    {
        this.create = create;
    }

    /// <summary>Creates a new enumerator; the producer is not started until the first MoveNextAsync.</summary>
    /// <param name="cancellationToken">Passed unchanged to the producer delegate.</param>
    public IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
    {
        return new _Create(create, cancellationToken);
    }

    /// <summary>
    /// Enumerator that alternates control between the consumer (MoveNextAsync) and the producer (YieldAsync).
    /// </summary>
    /// <remarks>
    /// <c>state</c> values: -1 = producer not started yet; 0 = producer started (running, or suspended in
    /// YieldAsync until the next MoveNextAsync); -2 = finished (producer completed or faulted, or the
    /// enumerator was disposed).
    /// </remarks>
    sealed class _Create : MoveNextSource, IUniTaskAsyncEnumerator<T>
    {
        readonly Func<IAsyncWriter<T>, CancellationToken, UniTask> create;
        readonly CancellationToken cancellationToken;

        int state = -1;
        AsyncWriter writer; // null until the first MoveNextAsync starts the producer

        public _Create(Func<IAsyncWriter<T>, CancellationToken, UniTask> create, CancellationToken cancellationToken)
        {
            this.create = create;
            this.cancellationToken = cancellationToken;
            TaskTracker.TrackActiveTask(this, 3);
        }

        /// <summary>The value most recently passed to YieldAsync.</summary>
        public T Current { get; private set; }

        /// <summary>
        /// Stops tracking the enumerator and releases the producer. A producer suspended in YieldAsync (or one
        /// that calls it later) observes cancellation instead of waiting forever for a MoveNextAsync that will
        /// never come, e.g. when the consumer stops enumerating early.
        /// </summary>
        public UniTask DisposeAsync()
        {
            TaskTracker.RemoveTracking(this);

            // Any later MoveNextAsync reports end-of-sequence instead of signalling a disposed writer.
            Volatile.Write(ref state, -2);

            // writer is still null when the enumerator is disposed before the producer was ever started.
            writer?.Dispose();
            return default;
        }

        /// <summary>
        /// Requests the next element. Returns default (a completed false) once the sequence has finished;
        /// otherwise the returned task completes when the producer yields (true), finishes (false) or throws.
        /// </summary>
        public UniTask<bool> MoveNextAsync()
        {
            if (state == -2) return default;

            completionSource.Reset();
            MoveNext();
            return new UniTask<bool>(this, completionSource.Version);
        }

        void MoveNext()
        {
            try
            {
                switch (state)
                {
                    case -1: // init
                        writer = new AsyncWriter(this);
                        RunWriterTask(create(writer, cancellationToken)).Forget();
                        // Enter the running state unless the producer already finished synchronously
                        // (RunWriterTask writes -2). CompareExchange never overwrites that -2, even if
                        // RunWriterTask completes concurrently with this line.
                        Interlocked.CompareExchange(ref state, 0, -1);
                        return;
                    case 0:
                        // Resume the producer awaiting YieldAsync; it completes completionSource by
                        // yielding again (true) or by finishing (false / exception).
                        writer.SignalWriter();
                        return;
                    default:
                        goto DONE;
                }
            }
            catch (Exception ex)
            {
                // The producer delegate itself threw synchronously.
                state = -2;
                completionSource.TrySetException(ex);
                return;
            }

            DONE:
            state = -2;
            completionSource.TrySetResult(false);
        }

        /// <summary>
        /// Awaits the producer task and reports its outcome to the consumer: normal completion ends the
        /// sequence (false), an exception (including cancellation after disposal) faults the pending MoveNextAsync.
        /// </summary>
        async UniTaskVoid RunWriterTask(UniTask task)
        {
            try
            {
                await task;
            }
            catch (Exception ex)
            {
                Volatile.Write(ref state, -2);
                completionSource.TrySetException(ex);
                return;
            }

            Volatile.Write(ref state, -2);
            completionSource.TrySetResult(false);
        }

        /// <summary>
        /// Called by the writer: publishes <paramref name="value"/> as <see cref="Current"/> and completes the
        /// pending MoveNextAsync with true.
        /// </summary>
        public void SetResult(T value)
        {
            Current = value;
            completionSource.TrySetResult(true);
        }
    }

    /// <summary>
    /// Producer-side handle. YieldAsync hands a value to the enumerator and returns a task that completes when
    /// the consumer asks for the next element (SignalWriter), or is canceled when the enumerator is disposed.
    /// </summary>
    sealed class AsyncWriter : IUniTaskSource, IAsyncWriter<T>, IDisposable
    {
        readonly _Create enumerator;

        // Mutable struct: must not be readonly, or Reset/TrySet* would operate on a copy.
        UniTaskCompletionSourceCore<AsyncUnit> core;
        bool disposed;

        public AsyncWriter(_Create enumerator)
        {
            this.enumerator = enumerator;
        }

        public void GetResult(short token)
        {
            core.GetResult(token);
        }

        public UniTaskStatus GetStatus(short token)
        {
            return core.GetStatus(token);
        }

        public UniTaskStatus UnsafeGetStatus()
        {
            return core.UnsafeGetStatus();
        }

        public void OnCompleted(Action<object> continuation, object state, short token)
        {
            core.OnCompleted(continuation, state, token);
        }

        public UniTask YieldAsync(T value)
        {
            // Once the consumer has disposed the enumerator no MoveNextAsync will ever signal us again,
            // so fail fast with cancellation rather than returning a task that never completes.
            if (disposed) return UniTask.FromCanceled();

            core.Reset();
            enumerator.SetResult(value);
            return new UniTask(this, core.Version);
        }

        /// <summary>Completes the task returned by the last YieldAsync, resuming the producer.</summary>
        public void SignalWriter()
        {
            core.TrySetResult(AsyncUnit.Default);
        }

        /// <summary>
        /// Marks the writer as disposed and cancels a YieldAsync that is still waiting for SignalWriter.
        /// </summary>
        public void Dispose()
        {
            if (disposed) return;
            disposed = true;

            if (core.UnsafeGetStatus() == UniTaskStatus.Pending)
            {
                core.TrySetCanceled();
            }
        }
    }
}
- }
|