UniTask.WhenAll.Generated.tt 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. <#@ template debug="false" hostspecific="false" language="C#" #>
  2. <#@ assembly name="System.Core" #>
  3. <#@ import namespace="System.Linq" #>
  4. <#@ import namespace="System.Text" #>
  5. <#@ import namespace="System.Collections.Generic" #>
  6. <#@ output extension=".cs" #>
  7. #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
  8. using System;
  9. using System.Runtime.CompilerServices;
  10. using System.Runtime.ExceptionServices;
  11. using System.Threading;
  12. using Cysharp.Threading.Tasks.Internal;
  13. namespace Cysharp.Threading.Tasks
  14. {
  public partial struct UniTask
  {
<#
    // One WhenAll overload and one backing promise class are emitted per arity, from 2 through 15.
    for (var arity = 2; arity <= 15; arity++)
    {
        // 1..arity, materialised once and reused for every generated list below.
        var indices = Enumerable.Range(1, arity).ToArray();

        // "T1, T2, ..."                                   generic type arguments / tuple element types
        var typeArgs = string.Join(", ", indices.Select(k => $"T{k}"));
        // "UniTask<T1> task1, UniTask<T2> task2, ..."     formal parameter list
        var paramList = string.Join(", ", indices.Select(k => "UniTask<T" + k + "> task" + k));
        // "task1, task2, ..."                             argument list forwarded to the promise
        var argList = string.Join(", ", indices.Select(k => "task" + k));
        // "task1.GetAwaiter().GetResult(), ..."           tuple elements for the synchronous path
        var syncResults = string.Join(", ", indices.Select(k => "task" + k + ".GetAwaiter().GetResult()"));
        // "task1.Status.IsCompletedSuccessfully() && ..." condition guarding the synchronous path
        var allSucceeded = string.Join(" && ", indices.Select(k => "task" + k + ".Status.IsCompletedSuccessfully()"));
        // "self.t1, self.t2, ..."                         tuple elements read back from the promise fields
        var storedResults = string.Join(", ", indices.Select(k => "self.t" + k));
#>
      /// <summary>
      /// Combines <#= arity #> tasks into one task whose result is the tuple of their individual results.
      /// </summary>
      public static UniTask<(<#= typeArgs #>)> WhenAll<<#= typeArgs #>>(<#= paramList #>)
      {
          // Unless every input already reports successful completion, hand the tasks to a promise.
          if (!(<#= allSucceeded #>))
          {
              return new UniTask<(<#= typeArgs #>)>(new WhenAllPromise<<#= typeArgs #>>(<#= argList #>), 0);
          }

          // Every input has already succeeded: read the results directly and skip the promise.
          return new UniTask<(<#= typeArgs #>)>((<#= syncResults #>));
      }

      /// <summary>
      /// Task source behind the <#= arity #>-task WhenAll overload: stores each result as it arrives
      /// and completes once all <#= arity #> results are in, or forwards an exception thrown while
      /// reading a result.
      /// </summary>
      sealed class WhenAllPromise<<#= typeArgs #>> : IUniTaskSource<(<#= typeArgs #>)>
      {
          // One result slot per input task.
<# foreach (var n in indices) { #>
          T<#= n #> t<#= n #> = default;
<# } #>
          // Number of inputs whose result has been stored so far; updated via Interlocked.
          int completedCount;
          UniTaskCompletionSourceCore<(<#= typeArgs #>)> core;

          public WhenAllPromise(<#= paramList #>)
          {
              // NOTE(review): the meaning of the literal 3 is defined by TaskTracker, not visible here - confirm.
              TaskTracker.TrackActiveTask(this, 3);

              completedCount = 0;
<# foreach (var n in indices) { #>
              {
                  var taskAwaiter = task<#= n #>.GetAwaiter();

                  if (!taskAwaiter.IsCompleted)
                  {
                      // Still pending: register a callback. The promise and the awaiter travel in a
                      // StateTuple passed as callback state, so the lambda itself captures nothing.
                      taskAwaiter.SourceOnCompleted(boxed =>
                      {
                          // The using block disposes the state tuple once the continuation has run.
                          using (var pair = (StateTuple<WhenAllPromise<<#= typeArgs #>>, UniTask<T<#= n #>>.Awaiter>)boxed)
                          {
                              TryInvokeContinuationT<#= n #>(pair.Item1, pair.Item2);
                          }
                      }, StateTuple.Create(this, taskAwaiter));
                  }
                  else
                  {
                      // Already finished: record the outcome right away.
                      TryInvokeContinuationT<#= n #>(this, taskAwaiter);
                  }
              }
<# } #>
          }

<# foreach (var n in indices) { #>
          // Stores the result of task<#= n #>; completes the promise when this was the last outstanding input.
          static void TryInvokeContinuationT<#= n #>(WhenAllPromise<<#= typeArgs #>> self, in UniTask<T<#= n #>>.Awaiter awaiter)
          {
              try
              {
                  self.t<#= n #> = awaiter.GetResult();
              }
              catch (Exception error)
              {
                  // Reading the result threw: forward the exception and do not count this input.
                  self.core.TrySetException(error);
                  return;
              }

              var finished = Interlocked.Increment(ref self.completedCount);
              if (finished != <#= arity #>)
              {
                  return;
              }

              self.core.TrySetResult((<#= storedResults #>));
          }

<# } #>
          // Stops tracking this promise, then returns the combined result held by the core.
          public (<#= typeArgs #>) GetResult(short token)
          {
              TaskTracker.RemoveTracking(this);
              GC.SuppressFinalize(this);
              return core.GetResult(token);
          }

          // Non-generic interface entry point: same as above, result discarded.
          void IUniTaskSource.GetResult(short token) => GetResult(token);

          public UniTaskStatus GetStatus(short token) => core.GetStatus(token);

          public UniTaskStatus UnsafeGetStatus() => core.UnsafeGetStatus();

          public void OnCompleted(Action<object> continuation, object state, short token) => core.OnCompleted(continuation, state, token);
      }

<# } #>
  }
  108. }