// Take.cs (3.8 KB)
  1. using Cysharp.Threading.Tasks.Internal;
  2. using System;
  3. using System.Threading;
  4. namespace Cysharp.Threading.Tasks.Linq
  5. {
  public static partial class UniTaskAsyncEnumerable
  {
      /// <summary>
      /// Produces a sequence consisting of at most <paramref name="count"/> leading elements
      /// of <paramref name="source"/>.
      /// </summary>
      /// <typeparam name="TSource">Element type of the sequence.</typeparam>
      /// <param name="source">
      /// Sequence to take elements from; checked for null via
      /// <c>Error.ThrowArgumentNullException</c> (presumably throws when null).
      /// </param>
      /// <param name="count">Upper bound on the number of elements yielded; zero or negative yields none.</param>
      /// <returns>
      /// A lazy sequence: construction only stores the arguments, and the source is not
      /// enumerated until the result itself is enumerated.
      /// </returns>
      public static IUniTaskAsyncEnumerable<TSource> Take<TSource>(this IUniTaskAsyncEnumerable<TSource> source, Int32 count)
      {
          // Validation happens here, at call time, not later at enumeration time.
          Error.ThrowArgumentNullException(source, nameof(source));

          var taken = new Take<TSource>(source, count);
          return taken;
      }
  }
  /// <summary>
  /// Async sequence backing <c>UniTaskAsyncEnumerable.Take</c>: yields at most
  /// <c>count</c> leading elements of the wrapped source sequence.
  /// </summary>
  /// <typeparam name="TSource">Element type of the source sequence.</typeparam>
  internal sealed class Take<TSource> : IUniTaskAsyncEnumerable<TSource>
  {
      readonly IUniTaskAsyncEnumerable<TSource> source;
      readonly int count;

      /// <param name="source">Sequence whose leading elements are yielded.</param>
      /// <param name="count">Maximum number of elements to yield; zero or negative yields none.</param>
      public Take(IUniTaskAsyncEnumerable<TSource> source, int count)
      {
          this.source = source;
          this.count = count;
      }

      /// <summary>
      /// Creates a new enumerator with its own independent progress state.
      /// </summary>
      /// <param name="cancellationToken">Checked on every <c>MoveNextAsync</c> and forwarded to the source enumerator.</param>
      public IUniTaskAsyncEnumerator<TSource> GetAsyncEnumerator(CancellationToken cancellationToken = default)
      {
          return new _Take(source, count, cancellationToken);
      }

      /// <summary>
      /// Enumerator that pulls from the source until <c>count</c> elements have been
      /// yielded or the source runs out.
      /// </summary>
      sealed class _Take : MoveNextSource, IUniTaskAsyncEnumerator<TSource>
      {
          // Cached so registering the continuation does not allocate a delegate per MoveNext.
          static readonly Action<object> MoveNextCoreDelegate = MoveNextCore;

          readonly IUniTaskAsyncEnumerable<TSource> source;
          readonly int count;
          CancellationToken cancellationToken;

          // Created lazily by the first MoveNextAsync that actually needs an element.
          IUniTaskAsyncEnumerator<TSource> enumerator;

          // Awaiter of the in-flight source MoveNextAsync, consumed by MoveNextCore.
          UniTask<bool>.Awaiter awaiter;

          // Elements yielded so far. It only advances after a check that index < count,
          // so it never exceeds count and the increment cannot overflow.
          int index;

          public _Take(IUniTaskAsyncEnumerable<TSource> source, int count, CancellationToken cancellationToken)
          {
              this.source = source;
              this.count = count;
              this.cancellationToken = cancellationToken;
              TaskTracker.TrackActiveTask(this, 3);
          }

          /// <summary>The element produced by the last successful <c>MoveNextAsync</c>.</summary>
          public TSource Current { get; private set; }

          /// <summary>
          /// Advances to the next element. Completes with <c>false</c> once <c>count</c>
          /// elements have been yielded or the source is exhausted.
          /// </summary>
          /// <exception cref="OperationCanceledException">Thrown synchronously if the token is already canceled.</exception>
          public UniTask<bool> MoveNextAsync()
          {
              cancellationToken.ThrowIfCancellationRequested();

              // Quota check comes before enumerator creation. With count <= 0 the source is
              // never asked for an enumerator, matching LINQ's Take semantics.
              if (index >= count)
              {
                  return CompletedTasks.False;
              }

              if (enumerator == null)
              {
                  enumerator = source.GetAsyncEnumerator(cancellationToken);
              }

              completionSource.Reset();
              SourceMoveNext();
              return new UniTask<bool>(this, completionSource.Version);
          }

          // Starts the source's MoveNextAsync. It completes inline when the source finished
          // synchronously; otherwise it registers MoveNextCore as the continuation.
          // Exceptions thrown while starting are routed into completionSource.
          void SourceMoveNext()
          {
              try
              {
                  awaiter = enumerator.MoveNextAsync().GetAwaiter();
                  if (awaiter.IsCompleted)
                  {
                      MoveNextCore(this);
                  }
                  else
                  {
                      awaiter.SourceOnCompleted(MoveNextCoreDelegate, this);
                  }
              }
              catch (Exception ex)
              {
                  completionSource.TrySetException(ex);
              }
          }

          // Continuation for the source MoveNextAsync. On a true result it records the
          // element and counts it; on a false result it signals end of sequence.
          // NOTE(review): when TryGetResult returns false, MoveNextSource presumably has
          // already faulted completionSource with the source's exception. Confirm there.
          static void MoveNextCore(object state)
          {
              var self = (_Take)state;
              if (self.TryGetResult(self.awaiter, out var result))
              {
                  if (result)
                  {
                      self.index++;
                      self.Current = self.enumerator.Current;
                      self.completionSource.TrySetResult(true);
                  }
                  else
                  {
                      self.completionSource.TrySetResult(false);
                  }
              }
          }

          /// <summary>
          /// Stops task tracking and disposes the source enumerator if one was created.
          /// </summary>
          public UniTask DisposeAsync()
          {
              TaskTracker.RemoveTracking(this);
              if (enumerator != null)
              {
                  return enumerator.DisposeAsync();
              }
              return default;
          }
      }
  }
  106. }