Buffer.cs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. using Cysharp.Threading.Tasks.Internal;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Threading;
  5. namespace Cysharp.Threading.Tasks.Linq
  6. {
  7. public static partial class UniTaskAsyncEnumerable
  8. {
  9. public static IUniTaskAsyncEnumerable<IList<TSource>> Buffer<TSource>(this IUniTaskAsyncEnumerable<TSource> source, Int32 count)
  10. {
  11. Error.ThrowArgumentNullException(source, nameof(source));
  12. if (count <= 0) throw Error.ArgumentOutOfRange(nameof(count));
  13. return new Buffer<TSource>(source, count);
  14. }
  15. public static IUniTaskAsyncEnumerable<IList<TSource>> Buffer<TSource>(this IUniTaskAsyncEnumerable<TSource> source, Int32 count, Int32 skip)
  16. {
  17. Error.ThrowArgumentNullException(source, nameof(source));
  18. if (count <= 0) throw Error.ArgumentOutOfRange(nameof(count));
  19. if (skip <= 0) throw Error.ArgumentOutOfRange(nameof(skip));
  20. return new BufferSkip<TSource>(source, count, skip);
  21. }
  22. }
  23. internal sealed class Buffer<TSource> : IUniTaskAsyncEnumerable<IList<TSource>>
  24. {
  25. readonly IUniTaskAsyncEnumerable<TSource> source;
  26. readonly int count;
  27. public Buffer(IUniTaskAsyncEnumerable<TSource> source, int count)
  28. {
  29. this.source = source;
  30. this.count = count;
  31. }
  32. public IUniTaskAsyncEnumerator<IList<TSource>> GetAsyncEnumerator(CancellationToken cancellationToken = default)
  33. {
  34. return new _Buffer(source, count, cancellationToken);
  35. }
  36. sealed class _Buffer : MoveNextSource, IUniTaskAsyncEnumerator<IList<TSource>>
  37. {
  38. static readonly Action<object> MoveNextCoreDelegate = MoveNextCore;
  39. readonly IUniTaskAsyncEnumerable<TSource> source;
  40. readonly int count;
  41. CancellationToken cancellationToken;
  42. IUniTaskAsyncEnumerator<TSource> enumerator;
  43. UniTask<bool>.Awaiter awaiter;
  44. bool continueNext;
  45. bool completed;
  46. List<TSource> buffer;
  47. public _Buffer(IUniTaskAsyncEnumerable<TSource> source, int count, CancellationToken cancellationToken)
  48. {
  49. this.source = source;
  50. this.count = count;
  51. this.cancellationToken = cancellationToken;
  52. TaskTracker.TrackActiveTask(this, 3);
  53. }
  54. public IList<TSource> Current { get; private set; }
  55. public UniTask<bool> MoveNextAsync()
  56. {
  57. cancellationToken.ThrowIfCancellationRequested();
  58. if (enumerator == null)
  59. {
  60. enumerator = source.GetAsyncEnumerator(cancellationToken);
  61. buffer = new List<TSource>(count);
  62. }
  63. completionSource.Reset();
  64. SourceMoveNext();
  65. return new UniTask<bool>(this, completionSource.Version);
  66. }
  67. void SourceMoveNext()
  68. {
  69. if (completed)
  70. {
  71. if (buffer != null && buffer.Count > 0)
  72. {
  73. var ret = buffer;
  74. buffer = null;
  75. Current = ret;
  76. completionSource.TrySetResult(true);
  77. return;
  78. }
  79. else
  80. {
  81. completionSource.TrySetResult(false);
  82. return;
  83. }
  84. }
  85. try
  86. {
  87. LOOP:
  88. awaiter = enumerator.MoveNextAsync().GetAwaiter();
  89. if (awaiter.IsCompleted)
  90. {
  91. continueNext = true;
  92. MoveNextCore(this);
  93. if (continueNext)
  94. {
  95. continueNext = false;
  96. goto LOOP; // avoid recursive
  97. }
  98. }
  99. else
  100. {
  101. awaiter.SourceOnCompleted(MoveNextCoreDelegate, this);
  102. }
  103. }
  104. catch (Exception ex)
  105. {
  106. completionSource.TrySetException(ex);
  107. }
  108. }
  109. static void MoveNextCore(object state)
  110. {
  111. var self = (_Buffer)state;
  112. if (self.TryGetResult(self.awaiter, out var result))
  113. {
  114. if (result)
  115. {
  116. self.buffer.Add(self.enumerator.Current);
  117. if (self.buffer.Count == self.count)
  118. {
  119. self.Current = self.buffer;
  120. self.buffer = new List<TSource>(self.count);
  121. self.continueNext = false;
  122. self.completionSource.TrySetResult(true);
  123. return;
  124. }
  125. else
  126. {
  127. if (!self.continueNext)
  128. {
  129. self.SourceMoveNext();
  130. }
  131. }
  132. }
  133. else
  134. {
  135. self.continueNext = false;
  136. self.completed = true;
  137. self.SourceMoveNext();
  138. }
  139. }
  140. else
  141. {
  142. self.continueNext = false;
  143. }
  144. }
  145. public UniTask DisposeAsync()
  146. {
  147. TaskTracker.RemoveTracking(this);
  148. if (enumerator != null)
  149. {
  150. return enumerator.DisposeAsync();
  151. }
  152. return default;
  153. }
  154. }
  155. }
  156. internal sealed class BufferSkip<TSource> : IUniTaskAsyncEnumerable<IList<TSource>>
  157. {
  158. readonly IUniTaskAsyncEnumerable<TSource> source;
  159. readonly int count;
  160. readonly int skip;
  161. public BufferSkip(IUniTaskAsyncEnumerable<TSource> source, int count, int skip)
  162. {
  163. this.source = source;
  164. this.count = count;
  165. this.skip = skip;
  166. }
  167. public IUniTaskAsyncEnumerator<IList<TSource>> GetAsyncEnumerator(CancellationToken cancellationToken = default)
  168. {
  169. return new _BufferSkip(source, count, skip, cancellationToken);
  170. }
  171. sealed class _BufferSkip : MoveNextSource, IUniTaskAsyncEnumerator<IList<TSource>>
  172. {
  173. static readonly Action<object> MoveNextCoreDelegate = MoveNextCore;
  174. readonly IUniTaskAsyncEnumerable<TSource> source;
  175. readonly int count;
  176. readonly int skip;
  177. CancellationToken cancellationToken;
  178. IUniTaskAsyncEnumerator<TSource> enumerator;
  179. UniTask<bool>.Awaiter awaiter;
  180. bool continueNext;
  181. bool completed;
  182. Queue<List<TSource>> buffers;
  183. int index = 0;
  184. public _BufferSkip(IUniTaskAsyncEnumerable<TSource> source, int count, int skip, CancellationToken cancellationToken)
  185. {
  186. this.source = source;
  187. this.count = count;
  188. this.skip = skip;
  189. this.cancellationToken = cancellationToken;
  190. TaskTracker.TrackActiveTask(this, 3);
  191. }
  192. public IList<TSource> Current { get; private set; }
  193. public UniTask<bool> MoveNextAsync()
  194. {
  195. cancellationToken.ThrowIfCancellationRequested();
  196. if (enumerator == null)
  197. {
  198. enumerator = source.GetAsyncEnumerator(cancellationToken);
  199. buffers = new Queue<List<TSource>>();
  200. }
  201. completionSource.Reset();
  202. SourceMoveNext();
  203. return new UniTask<bool>(this, completionSource.Version);
  204. }
  205. void SourceMoveNext()
  206. {
  207. if (completed)
  208. {
  209. if (buffers.Count > 0)
  210. {
  211. Current = buffers.Dequeue();
  212. completionSource.TrySetResult(true);
  213. return;
  214. }
  215. else
  216. {
  217. completionSource.TrySetResult(false);
  218. return;
  219. }
  220. }
  221. try
  222. {
  223. LOOP:
  224. awaiter = enumerator.MoveNextAsync().GetAwaiter();
  225. if (awaiter.IsCompleted)
  226. {
  227. continueNext = true;
  228. MoveNextCore(this);
  229. if (continueNext)
  230. {
  231. continueNext = false;
  232. goto LOOP; // avoid recursive
  233. }
  234. }
  235. else
  236. {
  237. awaiter.SourceOnCompleted(MoveNextCoreDelegate, this);
  238. }
  239. }
  240. catch (Exception ex)
  241. {
  242. completionSource.TrySetException(ex);
  243. }
  244. }
  245. static void MoveNextCore(object state)
  246. {
  247. var self = (_BufferSkip)state;
  248. if (self.TryGetResult(self.awaiter, out var result))
  249. {
  250. if (result)
  251. {
  252. if (self.index++ % self.skip == 0)
  253. {
  254. self.buffers.Enqueue(new List<TSource>(self.count));
  255. }
  256. var item = self.enumerator.Current;
  257. foreach (var buffer in self.buffers)
  258. {
  259. buffer.Add(item);
  260. }
  261. if (self.buffers.Count > 0 && self.buffers.Peek().Count == self.count)
  262. {
  263. self.Current = self.buffers.Dequeue();
  264. self.continueNext = false;
  265. self.completionSource.TrySetResult(true);
  266. return;
  267. }
  268. else
  269. {
  270. if (!self.continueNext)
  271. {
  272. self.SourceMoveNext();
  273. }
  274. }
  275. }
  276. else
  277. {
  278. self.continueNext = false;
  279. self.completed = true;
  280. self.SourceMoveNext();
  281. }
  282. }
  283. else
  284. {
  285. self.continueNext = false;
  286. }
  287. }
  288. public UniTask DisposeAsync()
  289. {
  290. TaskTracker.RemoveTracking(this);
  291. if (enumerator != null)
  292. {
  293. return enumerator.DisposeAsync();
  294. }
  295. return default;
  296. }
  297. }
  298. }
  299. }