EnumeratorAsyncExtensions.cs 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
  2. using System;
  3. using System.Collections;
  4. using System.Reflection;
  5. using System.Runtime.ExceptionServices;
  6. using System.Threading;
  7. using Cysharp.Threading.Tasks.Internal;
  8. using UnityEngine;
  9. namespace Cysharp.Threading.Tasks
  10. {
  11. public static class EnumeratorAsyncExtensions
  12. {
  13. public static UniTask.Awaiter GetAwaiter<T>(this T enumerator)
  14. where T : IEnumerator
  15. {
  16. var e = (IEnumerator)enumerator;
  17. Error.ThrowArgumentNullException(e, nameof(enumerator));
  18. return new UniTask(EnumeratorPromise.Create(e, PlayerLoopTiming.Update, CancellationToken.None, out var token), token).GetAwaiter();
  19. }
  20. public static UniTask WithCancellation(this IEnumerator enumerator, CancellationToken cancellationToken)
  21. {
  22. Error.ThrowArgumentNullException(enumerator, nameof(enumerator));
  23. return new UniTask(EnumeratorPromise.Create(enumerator, PlayerLoopTiming.Update, cancellationToken, out var token), token);
  24. }
  25. public static UniTask ToUniTask(this IEnumerator enumerator, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken))
  26. {
  27. Error.ThrowArgumentNullException(enumerator, nameof(enumerator));
  28. return new UniTask(EnumeratorPromise.Create(enumerator, timing, cancellationToken, out var token), token);
  29. }
  30. public static UniTask ToUniTask(this IEnumerator enumerator, MonoBehaviour coroutineRunner)
  31. {
  32. var source = AutoResetUniTaskCompletionSource.Create();
  33. coroutineRunner.StartCoroutine(Core(enumerator, coroutineRunner, source));
  34. return source.Task;
  35. }
  36. static IEnumerator Core(IEnumerator inner, MonoBehaviour coroutineRunner, AutoResetUniTaskCompletionSource source)
  37. {
  38. yield return coroutineRunner.StartCoroutine(inner);
  39. source.TrySetResult();
  40. }
  41. sealed class EnumeratorPromise : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode<EnumeratorPromise>
  42. {
  43. static TaskPool<EnumeratorPromise> pool;
  44. EnumeratorPromise nextNode;
  45. public ref EnumeratorPromise NextNode => ref nextNode;
  46. static EnumeratorPromise()
  47. {
  48. TaskPool.RegisterSizeGetter(typeof(EnumeratorPromise), () => pool.Size);
  49. }
  50. IEnumerator innerEnumerator;
  51. CancellationToken cancellationToken;
  52. int initialFrame;
  53. bool loopRunning;
  54. bool calledGetResult;
  55. UniTaskCompletionSourceCore<object> core;
  56. EnumeratorPromise()
  57. {
  58. }
  59. public static IUniTaskSource Create(IEnumerator innerEnumerator, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token)
  60. {
  61. if (cancellationToken.IsCancellationRequested)
  62. {
  63. return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token);
  64. }
  65. if (!pool.TryPop(out var result))
  66. {
  67. result = new EnumeratorPromise();
  68. }
  69. TaskTracker.TrackActiveTask(result, 3);
  70. result.innerEnumerator = ConsumeEnumerator(innerEnumerator);
  71. result.cancellationToken = cancellationToken;
  72. result.loopRunning = true;
  73. result.calledGetResult = false;
  74. result.initialFrame = -1;
  75. token = result.core.Version;
  76. // run immediately.
  77. if (result.MoveNext())
  78. {
  79. PlayerLoopHelper.AddAction(timing, result);
  80. }
  81. return result;
  82. }
  83. public void GetResult(short token)
  84. {
  85. try
  86. {
  87. calledGetResult = true;
  88. core.GetResult(token);
  89. }
  90. finally
  91. {
  92. if (!loopRunning)
  93. {
  94. TryReturn();
  95. }
  96. }
  97. }
  98. public UniTaskStatus GetStatus(short token)
  99. {
  100. return core.GetStatus(token);
  101. }
  102. public UniTaskStatus UnsafeGetStatus()
  103. {
  104. return core.UnsafeGetStatus();
  105. }
  106. public void OnCompleted(Action<object> continuation, object state, short token)
  107. {
  108. core.OnCompleted(continuation, state, token);
  109. }
  110. public bool MoveNext()
  111. {
  112. if (calledGetResult)
  113. {
  114. loopRunning = false;
  115. TryReturn();
  116. return false;
  117. }
  118. if (innerEnumerator == null) // invalid status, returned but loop running?
  119. {
  120. return false;
  121. }
  122. if (cancellationToken.IsCancellationRequested)
  123. {
  124. loopRunning = false;
  125. core.TrySetCanceled(cancellationToken);
  126. return false;
  127. }
  128. if (initialFrame == -1)
  129. {
  130. // Time can not touch in threadpool.
  131. if (PlayerLoopHelper.IsMainThread)
  132. {
  133. initialFrame = Time.frameCount;
  134. }
  135. }
  136. else if (initialFrame == Time.frameCount)
  137. {
  138. return true; // already executed in first frame, skip.
  139. }
  140. try
  141. {
  142. if (innerEnumerator.MoveNext())
  143. {
  144. return true;
  145. }
  146. }
  147. catch (Exception ex)
  148. {
  149. loopRunning = false;
  150. core.TrySetException(ex);
  151. return false;
  152. }
  153. loopRunning = false;
  154. core.TrySetResult(null);
  155. return false;
  156. }
  157. bool TryReturn()
  158. {
  159. TaskTracker.RemoveTracking(this);
  160. core.Reset();
  161. innerEnumerator = default;
  162. cancellationToken = default;
  163. return pool.TryPush(this);
  164. }
  165. // Unwrap YieldInstructions
  166. static IEnumerator ConsumeEnumerator(IEnumerator enumerator)
  167. {
  168. while (enumerator.MoveNext())
  169. {
  170. var current = enumerator.Current;
  171. if (current == null)
  172. {
  173. yield return null;
  174. }
  175. else if (current is CustomYieldInstruction cyi)
  176. {
  177. // WWW, WaitForSecondsRealtime
  178. while (cyi.keepWaiting)
  179. {
  180. yield return null;
  181. }
  182. }
  183. else if (current is YieldInstruction)
  184. {
  185. IEnumerator innerCoroutine = null;
  186. switch (current)
  187. {
  188. case AsyncOperation ao:
  189. innerCoroutine = UnwrapWaitAsyncOperation(ao);
  190. break;
  191. case WaitForSeconds wfs:
  192. innerCoroutine = UnwrapWaitForSeconds(wfs);
  193. break;
  194. }
  195. if (innerCoroutine != null)
  196. {
  197. while (innerCoroutine.MoveNext())
  198. {
  199. yield return null;
  200. }
  201. }
  202. else
  203. {
  204. goto WARN;
  205. }
  206. }
  207. else if (current is IEnumerator e3)
  208. {
  209. var e4 = ConsumeEnumerator(e3);
  210. while (e4.MoveNext())
  211. {
  212. yield return null;
  213. }
  214. }
  215. else
  216. {
  217. goto WARN;
  218. }
  219. continue;
  220. WARN:
  221. // WaitForEndOfFrame, WaitForFixedUpdate, others.
  222. UnityEngine.Debug.LogWarning($"yield {current.GetType().Name} is not supported on await IEnumerator or IEnumerator.ToUniTask(), please use ToUniTask(MonoBehaviour coroutineRunner) instead.");
  223. yield return null;
  224. }
  225. }
  226. static readonly FieldInfo waitForSeconds_Seconds = typeof(WaitForSeconds).GetField("m_Seconds", BindingFlags.Instance | BindingFlags.GetField | BindingFlags.NonPublic);
  227. static IEnumerator UnwrapWaitForSeconds(WaitForSeconds waitForSeconds)
  228. {
  229. var second = (float)waitForSeconds_Seconds.GetValue(waitForSeconds);
  230. var elapsed = 0.0f;
  231. while (true)
  232. {
  233. yield return null;
  234. elapsed += Time.deltaTime;
  235. if (elapsed >= second)
  236. {
  237. break;
  238. }
  239. };
  240. }
  241. static IEnumerator UnwrapWaitAsyncOperation(AsyncOperation asyncOperation)
  242. {
  243. while (!asyncOperation.isDone)
  244. {
  245. yield return null;
  246. }
  247. }
  248. }
  249. }
  250. }