Queue.cs 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. using System;
  2. using System.Threading;
  3. namespace Cysharp.Threading.Tasks.Linq
  4. {
  public static partial class UniTaskAsyncEnumerable
  {
      /// <summary>
      /// Wraps <paramref name="source"/> so that, once enumeration starts, the source is drained by a
      /// background loop into an unbounded channel, and the consumer reads from that buffer instead of
      /// pulling the source directly.
      /// </summary>
      /// <typeparam name="TSource">Element type of the sequence.</typeparam>
      /// <param name="source">The sequence to buffer.</param>
      /// <returns>A sequence that yields the elements of <paramref name="source"/> through the buffer.</returns>
      /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
      public static IUniTaskAsyncEnumerable<TSource> Queue<TSource>(this IUniTaskAsyncEnumerable<TSource> source)
      {
          // Fail fast at the call site; otherwise a null source would only surface later as a
          // NullReferenceException on the first MoveNextAsync of the returned sequence.
          if (source == null)
          {
              throw new ArgumentNullException(nameof(source));
          }

          return new QueueOperator<TSource>(source);
      }
  }
  /// <summary>
  /// Async sequence that, on the first MoveNextAsync, starts a background loop copying every element of
  /// the source into an unbounded single-consumer channel; the consumer then reads from that channel.
  /// </summary>
  internal sealed class QueueOperator<TSource> : IUniTaskAsyncEnumerable<TSource>
  {
      readonly IUniTaskAsyncEnumerable<TSource> source;

      public QueueOperator(IUniTaskAsyncEnumerable<TSource> source)
      {
          this.source = source;
      }

      /// <summary>Creates a new, not-yet-started enumerator bound to <paramref name="cancellationToken"/>.</summary>
      public IUniTaskAsyncEnumerator<TSource> GetAsyncEnumerator(CancellationToken cancellationToken = default)
      {
          return new _Queue(source, cancellationToken);
      }

      sealed class _Queue : IUniTaskAsyncEnumerator<TSource>
      {
          readonly IUniTaskAsyncEnumerable<TSource> source;
          CancellationToken cancellationToken;

          // The following three are created together on the first MoveNextAsync and stay null until then.
          Channel<TSource> channel;
          IUniTaskAsyncEnumerator<TSource> channelEnumerator;
          IUniTaskAsyncEnumerator<TSource> sourceEnumerator;

          // True once the channel's writer has been completed, either by the pump or by DisposeAsync.
          bool channelClosed;

          // True once the source enumerator's DisposeAsync has been requested; both the pump's
          // finally block and DisposeAsync want to dispose it, and it must happen only once.
          bool sourceDisposed;

          public _Queue(IUniTaskAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
          {
              this.source = source;
              this.cancellationToken = cancellationToken;
          }

          public TSource Current => channelEnumerator.Current;

          /// <summary>
          /// Starts the background pump on the first call, then advances the reader side of the channel.
          /// </summary>
          /// <exception cref="OperationCanceledException">The enumerator's token is already cancelled.</exception>
          public UniTask<bool> MoveNextAsync()
          {
              cancellationToken.ThrowIfCancellationRequested();

              if (sourceEnumerator == null)
              {
                  sourceEnumerator = source.GetAsyncEnumerator(cancellationToken);
                  channel = Channel.CreateSingleConsumerUnbounded<TSource>();
                  channelEnumerator = channel.Reader.ReadAllAsync().GetAsyncEnumerator(cancellationToken);

                  // Fire-and-forget: the pump runs independently of the consumer's MoveNextAsync calls.
                  ConsumeAll(this, sourceEnumerator, channel).Forget();
              }

              return channelEnumerator.MoveNextAsync();
          }

          /// <summary>
          /// Pump loop: copies every source element into <paramref name="writer"/>, then completes the
          /// writer (with the failure, if the source threw) and disposes the source enumerator.
          /// </summary>
          static async UniTaskVoid ConsumeAll(_Queue self, IUniTaskAsyncEnumerator<TSource> enumerator, ChannelWriter<TSource> writer)
          {
              try
              {
                  while (await enumerator.MoveNextAsync())
                  {
                      // A rejected write means the channel no longer accepts items (presumably completed
                      // by DisposeAsync, since the channel is unbounded); nobody will read further values,
                      // so stop pulling from a source that may already have been disposed.
                      if (!writer.TryWrite(enumerator.Current))
                      {
                          break;
                      }
                  }
                  writer.TryComplete();
              }
              catch (Exception ex)
              {
                  // Forward the source's failure to the reader side.
                  writer.TryComplete(ex);
              }
              finally
              {
                  self.channelClosed = true;
                  await self.DisposeSourceAsync();
              }
          }

          /// <summary>Disposes the source enumerator, at most once and only if it was created.</summary>
          async UniTask DisposeSourceAsync()
          {
              if (sourceEnumerator == null || sourceDisposed)
              {
                  return;
              }

              sourceDisposed = true;
              await sourceEnumerator.DisposeAsync();
          }

          /// <summary>
          /// Disposes the source and channel enumerators and, if the pump has not already done so,
          /// completes the channel with an <see cref="OperationCanceledException"/>.
          /// Safe to call when MoveNextAsync was never called or failed before setup.
          /// </summary>
          public async UniTask DisposeAsync()
          {
              await DisposeSourceAsync();

              if (channelEnumerator != null)
              {
                  await channelEnumerator.DisposeAsync();
              }

              // channel is null if enumeration never started; there is then no writer to complete.
              if (!channelClosed && channel != null)
              {
                  channelClosed = true;
                  channel.Writer.TryComplete(new OperationCanceledException());
              }
          }
      }
  }
  87. }