Channel.cs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Threading;
  4. namespace Cysharp.Threading.Tasks
  5. {
  6. public static class Channel
  7. {
  8. public static Channel<T> CreateSingleConsumerUnbounded<T>()
  9. {
  10. return new SingleConsumerUnboundedChannel<T>();
  11. }
  12. }
  13. public abstract class Channel<TWrite, TRead>
  14. {
  15. public ChannelReader<TRead> Reader { get; protected set; }
  16. public ChannelWriter<TWrite> Writer { get; protected set; }
  17. public static implicit operator ChannelReader<TRead>(Channel<TWrite, TRead> channel) => channel.Reader;
  18. public static implicit operator ChannelWriter<TWrite>(Channel<TWrite, TRead> channel) => channel.Writer;
  19. }
  20. public abstract class Channel<T> : Channel<T, T>
  21. {
  22. }
  23. public abstract class ChannelReader<T>
  24. {
  25. public abstract bool TryRead(out T item);
  26. public abstract UniTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default(CancellationToken));
  27. public abstract UniTask Completion { get; }
  28. public virtual UniTask<T> ReadAsync(CancellationToken cancellationToken = default(CancellationToken))
  29. {
  30. if (this.TryRead(out var item))
  31. {
  32. return UniTask.FromResult(item);
  33. }
  34. return ReadAsyncCore(cancellationToken);
  35. }
  36. async UniTask<T> ReadAsyncCore(CancellationToken cancellationToken = default(CancellationToken))
  37. {
  38. if (await WaitToReadAsync(cancellationToken))
  39. {
  40. if (TryRead(out var item))
  41. {
  42. return item;
  43. }
  44. }
  45. throw new ChannelClosedException();
  46. }
  47. public abstract IUniTaskAsyncEnumerable<T> ReadAllAsync(CancellationToken cancellationToken = default(CancellationToken));
  48. }
  49. public abstract class ChannelWriter<T>
  50. {
  51. public abstract bool TryWrite(T item);
  52. public abstract bool TryComplete(Exception error = null);
  53. public void Complete(Exception error = null)
  54. {
  55. if (!TryComplete(error))
  56. {
  57. throw new ChannelClosedException();
  58. }
  59. }
  60. }
  61. public partial class ChannelClosedException : InvalidOperationException
  62. {
  63. public ChannelClosedException() :
  64. base("Channel is already closed.")
  65. { }
  66. public ChannelClosedException(string message) : base(message) { }
  67. public ChannelClosedException(Exception innerException) :
  68. base("Channel is already closed", innerException)
  69. { }
  70. public ChannelClosedException(string message, Exception innerException) : base(message, innerException) { }
  71. }
  72. internal class SingleConsumerUnboundedChannel<T> : Channel<T>
  73. {
  74. readonly Queue<T> items;
  75. readonly SingleConsumerUnboundedChannelReader readerSource;
  76. UniTaskCompletionSource completedTaskSource;
  77. UniTask completedTask;
  78. Exception completionError;
  79. bool closed;
  80. public SingleConsumerUnboundedChannel()
  81. {
  82. items = new Queue<T>();
  83. Writer = new SingleConsumerUnboundedChannelWriter(this);
  84. readerSource = new SingleConsumerUnboundedChannelReader(this);
  85. Reader = readerSource;
  86. }
  87. sealed class SingleConsumerUnboundedChannelWriter : ChannelWriter<T>
  88. {
  89. readonly SingleConsumerUnboundedChannel<T> parent;
  90. public SingleConsumerUnboundedChannelWriter(SingleConsumerUnboundedChannel<T> parent)
  91. {
  92. this.parent = parent;
  93. }
  94. public override bool TryWrite(T item)
  95. {
  96. bool waiting;
  97. lock (parent.items)
  98. {
  99. if (parent.closed) return false;
  100. parent.items.Enqueue(item);
  101. waiting = parent.readerSource.isWaiting;
  102. }
  103. if (waiting)
  104. {
  105. parent.readerSource.SingalContinuation();
  106. }
  107. return true;
  108. }
  109. public override bool TryComplete(Exception error = null)
  110. {
  111. bool waiting;
  112. lock (parent.items)
  113. {
  114. if (parent.closed) return false;
  115. parent.closed = true;
  116. waiting = parent.readerSource.isWaiting;
  117. if (parent.items.Count == 0)
  118. {
  119. if (error == null)
  120. {
  121. if (parent.completedTaskSource != null)
  122. {
  123. parent.completedTaskSource.TrySetResult();
  124. }
  125. else
  126. {
  127. parent.completedTask = UniTask.CompletedTask;
  128. }
  129. }
  130. else
  131. {
  132. if (parent.completedTaskSource != null)
  133. {
  134. parent.completedTaskSource.TrySetException(error);
  135. }
  136. else
  137. {
  138. parent.completedTask = UniTask.FromException(error);
  139. }
  140. }
  141. if (waiting)
  142. {
  143. parent.readerSource.SingalCompleted(error);
  144. }
  145. }
  146. parent.completionError = error;
  147. }
  148. return true;
  149. }
  150. }
  151. sealed class SingleConsumerUnboundedChannelReader : ChannelReader<T>, IUniTaskSource<bool>
  152. {
  153. readonly Action<object> CancellationCallbackDelegate = CancellationCallback;
  154. readonly SingleConsumerUnboundedChannel<T> parent;
  155. CancellationToken cancellationToken;
  156. CancellationTokenRegistration cancellationTokenRegistration;
  157. UniTaskCompletionSourceCore<bool> core;
  158. internal bool isWaiting;
  159. public SingleConsumerUnboundedChannelReader(SingleConsumerUnboundedChannel<T> parent)
  160. {
  161. this.parent = parent;
  162. TaskTracker.TrackActiveTask(this, 4);
  163. }
  164. public override UniTask Completion
  165. {
  166. get
  167. {
  168. if (parent.completedTaskSource != null) return parent.completedTaskSource.Task;
  169. if (parent.closed)
  170. {
  171. return parent.completedTask;
  172. }
  173. parent.completedTaskSource = new UniTaskCompletionSource();
  174. return parent.completedTaskSource.Task;
  175. }
  176. }
  177. public override bool TryRead(out T item)
  178. {
  179. lock (parent.items)
  180. {
  181. if (parent.items.Count != 0)
  182. {
  183. item = parent.items.Dequeue();
  184. // complete when all value was consumed.
  185. if (parent.closed && parent.items.Count == 0)
  186. {
  187. if (parent.completionError != null)
  188. {
  189. if (parent.completedTaskSource != null)
  190. {
  191. parent.completedTaskSource.TrySetException(parent.completionError);
  192. }
  193. else
  194. {
  195. parent.completedTask = UniTask.FromException(parent.completionError);
  196. }
  197. }
  198. else
  199. {
  200. if (parent.completedTaskSource != null)
  201. {
  202. parent.completedTaskSource.TrySetResult();
  203. }
  204. else
  205. {
  206. parent.completedTask = UniTask.CompletedTask;
  207. }
  208. }
  209. }
  210. }
  211. else
  212. {
  213. item = default;
  214. return false;
  215. }
  216. }
  217. return true;
  218. }
  219. public override UniTask<bool> WaitToReadAsync(CancellationToken cancellationToken)
  220. {
  221. if (cancellationToken.IsCancellationRequested)
  222. {
  223. return UniTask.FromCanceled<bool>(cancellationToken);
  224. }
  225. lock (parent.items)
  226. {
  227. if (parent.items.Count != 0)
  228. {
  229. return CompletedTasks.True;
  230. }
  231. if (parent.closed)
  232. {
  233. if (parent.completionError == null)
  234. {
  235. return CompletedTasks.False;
  236. }
  237. else
  238. {
  239. return UniTask.FromException<bool>(parent.completionError);
  240. }
  241. }
  242. cancellationTokenRegistration.Dispose();
  243. core.Reset();
  244. isWaiting = true;
  245. this.cancellationToken = cancellationToken;
  246. if (this.cancellationToken.CanBeCanceled)
  247. {
  248. cancellationTokenRegistration = this.cancellationToken.RegisterWithoutCaptureExecutionContext(CancellationCallbackDelegate, this);
  249. }
  250. return new UniTask<bool>(this, core.Version);
  251. }
  252. }
  253. public void SingalContinuation()
  254. {
  255. core.TrySetResult(true);
  256. }
  257. public void SingalCancellation(CancellationToken cancellationToken)
  258. {
  259. TaskTracker.RemoveTracking(this);
  260. core.TrySetCanceled(cancellationToken);
  261. }
  262. public void SingalCompleted(Exception error)
  263. {
  264. if (error != null)
  265. {
  266. TaskTracker.RemoveTracking(this);
  267. core.TrySetException(error);
  268. }
  269. else
  270. {
  271. TaskTracker.RemoveTracking(this);
  272. core.TrySetResult(false);
  273. }
  274. }
  275. public override IUniTaskAsyncEnumerable<T> ReadAllAsync(CancellationToken cancellationToken = default)
  276. {
  277. return new ReadAllAsyncEnumerable(this, cancellationToken);
  278. }
  279. bool IUniTaskSource<bool>.GetResult(short token)
  280. {
  281. return core.GetResult(token);
  282. }
  283. void IUniTaskSource.GetResult(short token)
  284. {
  285. core.GetResult(token);
  286. }
  287. UniTaskStatus IUniTaskSource.GetStatus(short token)
  288. {
  289. return core.GetStatus(token);
  290. }
  291. void IUniTaskSource.OnCompleted(Action<object> continuation, object state, short token)
  292. {
  293. core.OnCompleted(continuation, state, token);
  294. }
  295. UniTaskStatus IUniTaskSource.UnsafeGetStatus()
  296. {
  297. return core.UnsafeGetStatus();
  298. }
  299. static void CancellationCallback(object state)
  300. {
  301. var self = (SingleConsumerUnboundedChannelReader)state;
  302. self.SingalCancellation(self.cancellationToken);
  303. }
  304. sealed class ReadAllAsyncEnumerable : IUniTaskAsyncEnumerable<T>, IUniTaskAsyncEnumerator<T>
  305. {
  306. readonly Action<object> CancellationCallback1Delegate = CancellationCallback1;
  307. readonly Action<object> CancellationCallback2Delegate = CancellationCallback2;
  308. readonly SingleConsumerUnboundedChannelReader parent;
  309. CancellationToken cancellationToken1;
  310. CancellationToken cancellationToken2;
  311. CancellationTokenRegistration cancellationTokenRegistration1;
  312. CancellationTokenRegistration cancellationTokenRegistration2;
  313. T current;
  314. bool cacheValue;
  315. bool running;
  316. public ReadAllAsyncEnumerable(SingleConsumerUnboundedChannelReader parent, CancellationToken cancellationToken)
  317. {
  318. this.parent = parent;
  319. this.cancellationToken1 = cancellationToken;
  320. }
  321. public IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
  322. {
  323. if (running)
  324. {
  325. throw new InvalidOperationException("Enumerator is already running, does not allow call GetAsyncEnumerator twice.");
  326. }
  327. if (this.cancellationToken1 != cancellationToken)
  328. {
  329. this.cancellationToken2 = cancellationToken;
  330. }
  331. if (this.cancellationToken1.CanBeCanceled)
  332. {
  333. this.cancellationTokenRegistration1 = this.cancellationToken1.RegisterWithoutCaptureExecutionContext(CancellationCallback1Delegate, this);
  334. }
  335. if (this.cancellationToken2.CanBeCanceled)
  336. {
  337. this.cancellationTokenRegistration2 = this.cancellationToken2.RegisterWithoutCaptureExecutionContext(CancellationCallback2Delegate, this);
  338. }
  339. running = true;
  340. return this;
  341. }
  342. public T Current
  343. {
  344. get
  345. {
  346. if (cacheValue)
  347. {
  348. return current;
  349. }
  350. parent.TryRead(out current);
  351. return current;
  352. }
  353. }
  354. public UniTask<bool> MoveNextAsync()
  355. {
  356. cacheValue = false;
  357. return parent.WaitToReadAsync(CancellationToken.None); // ok to use None, registered in ctor.
  358. }
  359. public UniTask DisposeAsync()
  360. {
  361. cancellationTokenRegistration1.Dispose();
  362. cancellationTokenRegistration2.Dispose();
  363. return default;
  364. }
  365. static void CancellationCallback1(object state)
  366. {
  367. var self = (ReadAllAsyncEnumerable)state;
  368. self.parent.SingalCancellation(self.cancellationToken1);
  369. }
  370. static void CancellationCallback2(object state)
  371. {
  372. var self = (ReadAllAsyncEnumerable)state;
  373. self.parent.SingalCancellation(self.cancellationToken2);
  374. }
  375. }
  376. }
  377. }
  378. }