CombineLatest.tt 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. <#@ template debug="false" hostspecific="false" language="C#" #>
  2. <#@ assembly name="System.Core" #>
  3. <#@ import namespace="System.Linq" #>
  4. <#@ import namespace="System.Text" #>
  5. <#@ import namespace="System.Collections.Generic" #>
  6. <#@ output extension=".cs" #>
  7. <#
  8. var tMax = 15;
  9. Func<int, string> typeArgs = x => string.Join(", ", Enumerable.Range(1, x).Select(x => $"T{x}")) + ", TResult";
  10. Func<int, string> paramArgs = x => string.Join(", ", Enumerable.Range(1, x).Select(x => $"IUniTaskAsyncEnumerable<T{x}> source{x}"));
  11. Func<int, string> parameters = x => string.Join(", ", Enumerable.Range(1, x).Select(x => $"source{x}"));
  12. #>
  13. using Cysharp.Threading.Tasks.Internal;
  14. using System;
  15. using System.Threading;
  16. namespace Cysharp.Threading.Tasks.Linq
  17. {
  18. public static partial class UniTaskAsyncEnumerable
  19. {
  20. <# for(var i = 2; i <= tMax; i++) { #>
  21. public static IUniTaskAsyncEnumerable<TResult> CombineLatest<<#= typeArgs(i) #>>(this <#= paramArgs(i) #>, Func<<#= typeArgs(i) #>> resultSelector)
  22. {
  23. <# for(var j = 1; j <= i; j++) { #>
  24. Error.ThrowArgumentNullException(source<#= j #>, nameof(source<#= j #>));
  25. <# } #>
  26. Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector));
  27. return new CombineLatest<<#= typeArgs(i) #>>(<#= parameters(i) #>, resultSelector);
  28. }
  29. <# } #>
  30. }
  31. <# for(var i = 2; i <= tMax; i++) { #>
  32. internal class CombineLatest<<#= typeArgs(i) #>> : IUniTaskAsyncEnumerable<TResult>
  33. {
  34. <# for(var j = 1; j <= i; j++) { #>
  35. readonly IUniTaskAsyncEnumerable<T<#= j #>> source<#= j #>;
  36. <# } #>
  37. readonly Func<<#= typeArgs(i) #>> resultSelector;
  38. public CombineLatest(<#= paramArgs(i) #>, Func<<#= typeArgs(i) #>> resultSelector)
  39. {
  40. <# for(var j = 1; j <= i; j++) { #>
  41. this.source<#= j #> = source<#= j #>;
  42. <# } #>
  43. this.resultSelector = resultSelector;
  44. }
  45. public IUniTaskAsyncEnumerator<TResult> GetAsyncEnumerator(CancellationToken cancellationToken = default)
  46. {
  47. return new _CombineLatest(<#= parameters(i) #>, resultSelector, cancellationToken);
  48. }
  49. class _CombineLatest : MoveNextSource, IUniTaskAsyncEnumerator<TResult>
  50. {
  51. <# for(var j = 1; j <= i; j++) { #>
  52. static readonly Action<object> Completed<#= j #>Delegate = Completed<#= j #>;
  53. <# } #>
  54. const int CompleteCount = <#= i #>;
  55. <# for(var j = 1; j <= i; j++) { #>
  56. readonly IUniTaskAsyncEnumerable<T<#= j #>> source<#= j #>;
  57. <# } #>
  58. readonly Func<<#= typeArgs(i) #>> resultSelector;
  59. CancellationToken cancellationToken;
  60. <# for(var j = 1; j <= i; j++) { #>
  61. IUniTaskAsyncEnumerator<T<#= j #>> enumerator<#= j #>;
  62. UniTask<bool>.Awaiter awaiter<#= j #>;
  63. bool hasCurrent<#= j #>;
  64. bool running<#= j #>;
  65. T<#= j #> current<#= j #>;
  66. <# } #>
  67. int completedCount;
  68. bool syncRunning;
  69. TResult result;
  70. public _CombineLatest(<#= paramArgs(i) #>, Func<<#= typeArgs(i) #>> resultSelector, CancellationToken cancellationToken)
  71. {
  72. <# for(var j = 1; j <= i; j++) { #>
  73. this.source<#= j #> = source<#= j #>;
  74. <# } #>
  75. this.resultSelector = resultSelector;
  76. this.cancellationToken = cancellationToken;
  77. TaskTracker.TrackActiveTask(this, 3);
  78. }
  79. public TResult Current => result;
  80. public UniTask<bool> MoveNextAsync()
  81. {
  82. cancellationToken.ThrowIfCancellationRequested();
  83. if (completedCount == CompleteCount) return CompletedTasks.False;
  84. if (enumerator1 == null)
  85. {
  86. <# for(var j = 1; j <= i; j++) { #>
  87. enumerator<#= j #> = source<#= j #>.GetAsyncEnumerator(cancellationToken);
  88. <# } #>
  89. }
  90. completionSource.Reset();
  91. AGAIN:
  92. syncRunning = true;
  93. <# for(var j = 1; j <= i; j++) { #>
  94. if (!running<#= j #>)
  95. {
  96. running<#= j #> = true;
  97. awaiter<#= j #> = enumerator<#= j #>.MoveNextAsync().GetAwaiter();
  98. if (awaiter<#= j #>.IsCompleted)
  99. {
  100. Completed<#= j #>(this);
  101. }
  102. else
  103. {
  104. awaiter<#= j #>.SourceOnCompleted(Completed<#= j #>Delegate, this);
  105. }
  106. }
  107. <# } #>
  108. if (<#= string.Join(" || ", Enumerable.Range(1, i).Select(x => $"!running{x}")) #>)
  109. {
  110. goto AGAIN;
  111. }
  112. syncRunning = false;
  113. return new UniTask<bool>(this, completionSource.Version);
  114. }
  115. <# for(var j = 1; j <= i; j++) { #>
  116. static void Completed<#= j #>(object state)
  117. {
  118. var self = (_CombineLatest)state;
  119. self.running<#= j #> = false;
  120. try
  121. {
  122. if (self.awaiter<#= j #>.GetResult())
  123. {
  124. self.hasCurrent<#= j #> = true;
  125. self.current<#= j #> = self.enumerator<#= j #>.Current;
  126. goto SUCCESS;
  127. }
  128. else
  129. {
  130. self.running<#= j #> = true; // as complete, no more call MoveNextAsync.
  131. if (Interlocked.Increment(ref self.completedCount) == CompleteCount)
  132. {
  133. goto COMPLETE;
  134. }
  135. return;
  136. }
  137. }
  138. catch (Exception ex)
  139. {
  140. self.running<#= j #> = true; // as complete, no more call MoveNextAsync.
  141. self.completedCount = CompleteCount;
  142. self.completionSource.TrySetException(ex);
  143. return;
  144. }
  145. SUCCESS:
  146. if (!self.TrySetResult())
  147. {
  148. if (self.syncRunning) return;
  149. self.running<#= j #> = true; // as complete, no more call MoveNextAsync.
  150. try
  151. {
  152. self.awaiter<#= j #> = self.enumerator<#= j #>.MoveNextAsync().GetAwaiter();
  153. }
  154. catch (Exception ex)
  155. {
  156. self.completedCount = CompleteCount;
  157. self.completionSource.TrySetException(ex);
  158. return;
  159. }
  160. self.awaiter<#= j #>.SourceOnCompleted(Completed<#= j #>Delegate, self);
  161. }
  162. return;
  163. COMPLETE:
  164. self.completionSource.TrySetResult(false);
  165. return;
  166. }
  167. <# } #>
  168. bool TrySetResult()
  169. {
  170. if (<#= string.Join(" && ", Enumerable.Range(1, i).Select(x => $"hasCurrent{x}")) #>)
  171. {
  172. result = resultSelector(<#= string.Join(", ", Enumerable.Range(1, i).Select(x => $"current{x}")) #>);
  173. completionSource.TrySetResult(true);
  174. return true;
  175. }
  176. else
  177. {
  178. return false;
  179. }
  180. }
  181. public async UniTask DisposeAsync()
  182. {
  183. TaskTracker.RemoveTracking(this);
  184. <# for(var j = 1; j <= i; j++) { #>
  185. if (enumerator<#= j #> != null)
  186. {
  187. await enumerator<#= j #>.DisposeAsync();
  188. }
  189. <# } #>
  190. }
  191. }
  192. }
  193. <# } #>
  194. }