// UniTask.WhenAll.cs (7.4 KB)
  1. #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Threading;
  5. using Cysharp.Threading.Tasks.Internal;
  6. namespace Cysharp.Threading.Tasks
  7. {
  8. public partial struct UniTask
  9. {
  /// <summary>
  /// Creates a task that completes once every task in <paramref name="tasks"/> has produced a result.
  /// Each result is stored at the same index as the task that produced it.
  /// </summary>
  /// <typeparam name="T">Result type of the awaited tasks.</typeparam>
  /// <param name="tasks">Tasks to wait for. An empty array yields an already-completed task with an empty array.</param>
  /// <returns>A task whose result holds one value per input task.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="tasks"/> is <c>null</c>.</exception>
  public static UniTask<T[]> WhenAll<T>(params UniTask<T>[] tasks)
  {
      // Fail with a descriptive exception instead of a NullReferenceException on tasks.Length.
      if (tasks == null)
      {
          throw new ArgumentNullException(nameof(tasks));
      }

      if (tasks.Length == 0)
      {
          // Nothing to await: skip allocating a promise.
          return UniTask.FromResult(Array.Empty<T>());
      }

      return new UniTask<T[]>(new WhenAllPromise<T>(tasks, tasks.Length), 0);
  }
  /// <summary>
  /// Creates a task that completes once every task in the <paramref name="tasks"/> sequence has produced a result.
  /// </summary>
  /// <typeparam name="T">Result type of the awaited tasks.</typeparam>
  /// <param name="tasks">Sequence of tasks; it is materialized into a temporary buffer via <c>ArrayPoolUtil.Materialize</c>.</param>
  /// <returns>A task whose result holds one value per input task.</returns>
  public static UniTask<T[]> WhenAll<T>(IEnumerable<UniTask<T>> tasks)
  {
      using (var buffer = ArrayPoolUtil.Materialize(tasks))
      {
          // The promise constructor reads every element of the buffer before returning,
          // so the buffer can be disposed as soon as the promise has been built.
          return new UniTask<T[]>(new WhenAllPromise<T>(buffer.Array, buffer.Length), 0);
      }
  }
  /// <summary>
  /// Creates a task that completes once every task in <paramref name="tasks"/> has completed.
  /// </summary>
  /// <param name="tasks">Tasks to wait for. An empty array yields <see cref="UniTask.CompletedTask"/>.</param>
  /// <returns>A task representing completion of all the input tasks.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="tasks"/> is <c>null</c>.</exception>
  public static UniTask WhenAll(params UniTask[] tasks)
  {
      // Fail with a descriptive exception instead of a NullReferenceException on tasks.Length.
      if (tasks == null)
      {
          throw new ArgumentNullException(nameof(tasks));
      }

      if (tasks.Length == 0)
      {
          // Nothing to await: skip allocating a promise.
          return UniTask.CompletedTask;
      }

      return new UniTask(new WhenAllPromise(tasks, tasks.Length), 0);
  }
  /// <summary>
  /// Creates a task that completes once every task in the <paramref name="tasks"/> sequence has completed.
  /// </summary>
  /// <param name="tasks">Sequence of tasks; it is materialized into a temporary buffer via <c>ArrayPoolUtil.Materialize</c>.</param>
  /// <returns>A task representing completion of all the input tasks.</returns>
  public static UniTask WhenAll(IEnumerable<UniTask> tasks)
  {
      using (var buffer = ArrayPoolUtil.Materialize(tasks))
      {
          // The promise constructor reads every element of the buffer before returning,
          // so the buffer can be disposed as soon as the promise has been built.
          return new UniTask(new WhenAllPromise(buffer.Array, buffer.Length), 0);
      }
  }
  /// <summary>
  /// Task source behind the generic <c>WhenAll</c> overloads. It stores each task's result in the
  /// slot matching the task's index and publishes the array once every task has reported a result.
  /// Any exception raised while obtaining an awaiter or a result is forwarded to the completion core.
  /// </summary>
  sealed class WhenAllPromise<T> : IUniTaskSource<T[]>
  {
      // Result slots, one per input task.
      T[] result;

      // Number of tasks whose result has been stored so far (updated with Interlocked).
      int completeCount;

      // Never reset: TrySetException may still be invoked after GetResult has been called.
      UniTaskCompletionSourceCore<T[]> core;

      /// <summary>
      /// Hooks up the first <paramref name="tasksLength"/> entries of <paramref name="tasks"/>.
      /// </summary>
      /// <param name="tasks">Buffer holding the tasks; only read inside this constructor.</param>
      /// <param name="tasksLength">Number of valid entries at the start of <paramref name="tasks"/>.</param>
      public WhenAllPromise(UniTask<T>[] tasks, int tasksLength)
      {
          TaskTracker.TrackActiveTask(this, 3);

          completeCount = 0;

          if (tasksLength == 0)
          {
              // No inputs: complete immediately with an empty array.
              result = Array.Empty<T>();
              core.TrySetResult(result);
              return;
          }

          result = new T[tasksLength];

          for (int index = 0; index < tasksLength; index++)
          {
              UniTask<T>.Awaiter awaiter;
              try
              {
                  awaiter = tasks[index].GetAwaiter();
              }
              catch (Exception error)
              {
                  // This slot never counts as completed; keep wiring up the remaining tasks.
                  core.TrySetException(error);
                  continue;
              }

              if (!awaiter.IsCompleted)
              {
                  // Non-capturing lambda: everything it needs travels through the state tuple,
                  // which is disposed once the continuation has run.
                  awaiter.SourceOnCompleted(boxed =>
                  {
                      using (var tuple = (StateTuple<WhenAllPromise<T>, UniTask<T>.Awaiter, int>)boxed)
                      {
                          TryInvokeContinuation(tuple.Item1, tuple.Item2, tuple.Item3);
                      }
                  }, StateTuple.Create(this, awaiter, index));
                  continue;
              }

              // Already finished: collect the result synchronously.
              TryInvokeContinuation(this, awaiter, index);
          }
      }

      /// <summary>
      /// Stores the result of one finished task and completes the promise when it was the last one.
      /// </summary>
      static void TryInvokeContinuation(WhenAllPromise<T> self, in UniTask<T>.Awaiter awaiter, int index)
      {
          try
          {
              self.result[index] = awaiter.GetResult();
          }
          catch (Exception error)
          {
              self.core.TrySetException(error);
              return;
          }

          var finished = Interlocked.Increment(ref self.completeCount);
          if (finished != self.result.Length)
          {
              return;
          }

          // Only the call that observes the final count publishes the result array.
          self.core.TrySetResult(self.result);
      }

      /// <summary>Stops tracking this promise and returns the collected results (or rethrows via the core).</summary>
      public T[] GetResult(short token)
      {
          TaskTracker.RemoveTracking(this);
          GC.SuppressFinalize(this);
          return core.GetResult(token);
      }

      // Non-generic interface entry point: forwards to the typed GetResult and discards the array.
      void IUniTaskSource.GetResult(short token) => GetResult(token);

      public UniTaskStatus GetStatus(short token) => core.GetStatus(token);

      public UniTaskStatus UnsafeGetStatus() => core.UnsafeGetStatus();

      public void OnCompleted(Action<object> continuation, object state, short token) =>
          core.OnCompleted(continuation, state, token);
  }
  /// <summary>
  /// Task source behind the non-generic <c>WhenAll</c> overloads. It counts finished tasks and
  /// completes once the count reaches the number of inputs. Any exception raised while obtaining
  /// an awaiter or a result is forwarded to the completion core.
  /// </summary>
  sealed class WhenAllPromise : IUniTaskSource
  {
      // Number of tasks that have finished without throwing (updated with Interlocked).
      int completeCount;

      // Total number of tasks being awaited.
      int tasksLength;

      // Never reset: TrySetException may still be invoked after GetResult has been called.
      UniTaskCompletionSourceCore<AsyncUnit> core;

      /// <summary>
      /// Hooks up the first <paramref name="tasksLength"/> entries of <paramref name="tasks"/>.
      /// </summary>
      /// <param name="tasks">Buffer holding the tasks; only read inside this constructor.</param>
      /// <param name="tasksLength">Number of valid entries at the start of <paramref name="tasks"/>.</param>
      public WhenAllPromise(UniTask[] tasks, int tasksLength)
      {
          TaskTracker.TrackActiveTask(this, 3);

          this.tasksLength = tasksLength;
          completeCount = 0;

          if (tasksLength == 0)
          {
              // No inputs: complete immediately.
              core.TrySetResult(AsyncUnit.Default);
              return;
          }

          for (int index = 0; index < tasksLength; index++)
          {
              UniTask.Awaiter awaiter;
              try
              {
                  awaiter = tasks[index].GetAwaiter();
              }
              catch (Exception error)
              {
                  // This task never counts as completed; keep wiring up the remaining tasks.
                  core.TrySetException(error);
                  continue;
              }

              if (!awaiter.IsCompleted)
              {
                  // Non-capturing lambda: everything it needs travels through the state tuple,
                  // which is disposed once the continuation has run.
                  awaiter.SourceOnCompleted(boxed =>
                  {
                      using (var tuple = (StateTuple<WhenAllPromise, UniTask.Awaiter>)boxed)
                      {
                          TryInvokeContinuation(tuple.Item1, tuple.Item2);
                      }
                  }, StateTuple.Create(this, awaiter));
                  continue;
              }

              // Already finished: observe the result synchronously.
              TryInvokeContinuation(this, awaiter);
          }
      }

      /// <summary>
      /// Observes the outcome of one finished task and completes the promise when it was the last one.
      /// </summary>
      static void TryInvokeContinuation(WhenAllPromise self, in UniTask.Awaiter awaiter)
      {
          try
          {
              awaiter.GetResult();
          }
          catch (Exception error)
          {
              self.core.TrySetException(error);
              return;
          }

          var finished = Interlocked.Increment(ref self.completeCount);
          if (finished != self.tasksLength)
          {
              return;
          }

          // Only the call that observes the final count signals completion.
          self.core.TrySetResult(AsyncUnit.Default);
      }

      /// <summary>Stops tracking this promise and observes the outcome held by the core.</summary>
      public void GetResult(short token)
      {
          TaskTracker.RemoveTracking(this);
          GC.SuppressFinalize(this);
          core.GetResult(token);
      }

      public UniTaskStatus GetStatus(short token) => core.GetStatus(token);

      public UniTaskStatus UnsafeGetStatus() => core.UnsafeGetStatus();

      public void OnCompleted(Action<object> continuation, object state, short token) =>
          core.OnCompleted(continuation, state, token);
  }
  203. }
  204. }