<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ output extension=".cs" #>
<#
    // Generates the AverageAsync / AverageAwaitAsync / AverageAwaitWithCancellationAsync
    // overloads for every numeric element type listed below.
    //
    // Each entry pairs the element type of the source sequence with the C# type name
    // of the average that is returned for it.
    var types = new[]
    {
        (typeof(int), "double"),
        (typeof(long), "double"),
        (typeof(float),"float"),
        (typeof(double),"double"),
        (typeof(decimal),"decimal"),

        (typeof(int?),"double?"),
        (typeof(long?),"double?"),
        (typeof(float?),"float?"),
        (typeof(double?),"double?"),
        (typeof(decimal?),"decimal?"),
    };

    // The only generic types in the table above are Nullable<T>, so IsGenericType
    // is enough to tell the nullable entries apart.
    Func<Type, bool> IsNullable = x => x.IsGenericType;

    // Underlying element type: T for Nullable<T>, otherwise the type itself.
    Func<Type, Type> ElementType = x => IsNullable(x) ? x.GetGenericArguments()[0] : x;

    // Type name as emitted into the generated code (CLR name, with a trailing "?" for nullable entries).
    Func<Type, string> TypeName = x => IsNullable(x) ? x.GetGenericArguments()[0].Name + "?" : x.Name;

    // Text of the expression that turns the generated locals `sum` and `count` into the result.
    // Integer sums are converted to double before dividing; float results are cast back to float.
    //
    // NOTE(review): `count` is still 0 when the sequence is empty (or, for nullable element
    // types, contains no non-null value), so the generated division is by zero - presumably
    // NaN for floating-point results and DivideByZeroException for decimal. Confirm that this
    // is the intended contract.
    // NOTE(review): `sum` is declared with the element type and accumulated in a checked
    // context, so a large int/long sequence presumably overflows - confirm this is intended.
    Func<Type, string> CalcResult = x =>
    {
        var e = ElementType(x);
        return (e == typeof(int) || e == typeof(long)) ? "(double)sum / count"
            : (e == typeof(float)) ? "(float)(sum / count)"
            : "sum / count";
    };
#>
using System;
using System.Threading;
using Cysharp.Threading.Tasks.Internal;

namespace Cysharp.Threading.Tasks.Linq
{
    public static partial class UniTaskAsyncEnumerable
    {
<# foreach(var (t, ret) in types) { #>
        /// <summary>Asynchronously computes the average of the values in <paramref name="source"/>.</summary>
        public static UniTask<<#= ret #>> AverageAsync(this IUniTaskAsyncEnumerable<<#= TypeName(t) #>> source, CancellationToken cancellationToken = default)
        {
            Error.ThrowArgumentNullException(source, nameof(source));

            return Average.AverageAsync(source, cancellationToken);
        }

        /// <summary>Asynchronously computes the average of the values obtained by applying <paramref name="selector"/> to each element of <paramref name="source"/>.</summary>
        public static UniTask<<#= ret #>> AverageAsync<TSource>(this IUniTaskAsyncEnumerable<TSource> source, Func<TSource, <#= TypeName(t) #>> selector, CancellationToken cancellationToken = default)
        {
            Error.ThrowArgumentNullException(source, nameof(source));
            Error.ThrowArgumentNullException(selector, nameof(selector));

            return Average.AverageAsync(source, selector, cancellationToken);
        }

        /// <summary>Asynchronously computes the average of the values obtained by awaiting <paramref name="selector"/> for each element of <paramref name="source"/>.</summary>
        public static UniTask<<#= ret #>> AverageAwaitAsync<TSource>(this IUniTaskAsyncEnumerable<TSource> source, Func<TSource, UniTask<<#= TypeName(t) #>>> selector, CancellationToken cancellationToken = default)
        {
            Error.ThrowArgumentNullException(source, nameof(source));
            Error.ThrowArgumentNullException(selector, nameof(selector));

            return Average.AverageAwaitAsync(source, selector, cancellationToken);
        }

        /// <summary>Asynchronously computes the average of the values obtained by awaiting <paramref name="selector"/>, which also receives <paramref name="cancellationToken"/>, for each element of <paramref name="source"/>.</summary>
        public static UniTask<<#= ret #>> AverageAwaitWithCancellationAsync<TSource>(this IUniTaskAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, UniTask<<#= TypeName(t) #>>> selector, CancellationToken cancellationToken = default)
        {
            Error.ThrowArgumentNullException(source, nameof(source));
            Error.ThrowArgumentNullException(selector, nameof(selector));

            return Average.AverageAwaitWithCancellationAsync(source, selector, cancellationToken);
        }

<# } #>
    }

    internal static class Average
    {
<# foreach(var (t, ret) in types) { #>
        // Averages the elements of the sequence directly.
        public static async UniTask<<#= ret #>> AverageAsync(IUniTaskAsyncEnumerable<<#= TypeName(t) #>> source, CancellationToken cancellationToken)
        {
            long count = 0;
            <#= TypeName(t) #> sum = 0;

            var e = source.GetAsyncEnumerator(cancellationToken);
            try
            {
                while (await e.MoveNextAsync())
                {
<# if (IsNullable(t)) { #>
                    // Null elements contribute to neither the sum nor the count.
                    var v = e.Current;
                    if (v.HasValue)
                    {
                        checked
                        {
                            sum += v.Value;
                            count++;
                        }
                    }
<# } else { #>
                    checked
                    {
                        sum += e.Current;
                        count++;
                    }
<# } #>
                }
            }
            finally
            {
                if (e != null)
                {
                    await e.DisposeAsync();
                }
            }

            return <#= CalcResult(t) #>;
        }

        // Averages the values produced by a synchronous selector.
        public static async UniTask<<#= ret #>> AverageAsync<TSource>(IUniTaskAsyncEnumerable<TSource> source, Func<TSource, <#= TypeName(t) #>> selector, CancellationToken cancellationToken)
        {
            long count = 0;
            <#= TypeName(t) #> sum = 0;

            var e = source.GetAsyncEnumerator(cancellationToken);
            try
            {
                while (await e.MoveNextAsync())
                {
<# if (IsNullable(t)) { #>
                    // Null selector results contribute to neither the sum nor the count.
                    var v = selector(e.Current);
                    if (v.HasValue)
                    {
                        checked
                        {
                            sum += v.Value;
                            count++;
                        }
                    }
<# } else { #>
                    checked
                    {
                        sum += selector(e.Current);
                        count++;
                    }
<# } #>
                }
            }
            finally
            {
                if (e != null)
                {
                    await e.DisposeAsync();
                }
            }

            return <#= CalcResult(t) #>;
        }

        // Averages the values produced by an awaited selector.
        public static async UniTask<<#= ret #>> AverageAwaitAsync<TSource>(IUniTaskAsyncEnumerable<TSource> source, Func<TSource, UniTask<<#= TypeName(t) #>>> selector, CancellationToken cancellationToken)
        {
            long count = 0;
            <#= TypeName(t) #> sum = 0;

            var e = source.GetAsyncEnumerator(cancellationToken);
            try
            {
                while (await e.MoveNextAsync())
                {
<# if (IsNullable(t)) { #>
                    // Null selector results contribute to neither the sum nor the count.
                    var v = await selector(e.Current);
                    if (v.HasValue)
                    {
                        checked
                        {
                            sum += v.Value;
                            count++;
                        }
                    }
<# } else { #>
                    checked
                    {
                        sum += await selector(e.Current);
                        count++;
                    }
<# } #>
                }
            }
            finally
            {
                if (e != null)
                {
                    await e.DisposeAsync();
                }
            }

            return <#= CalcResult(t) #>;
        }

        // Averages the values produced by an awaited selector that also receives the cancellation token.
        public static async UniTask<<#= ret #>> AverageAwaitWithCancellationAsync<TSource>(IUniTaskAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, UniTask<<#= TypeName(t) #>>> selector, CancellationToken cancellationToken)
        {
            long count = 0;
            <#= TypeName(t) #> sum = 0;

            var e = source.GetAsyncEnumerator(cancellationToken);
            try
            {
                while (await e.MoveNextAsync())
                {
<# if (IsNullable(t)) { #>
                    // Null selector results contribute to neither the sum nor the count.
                    var v = await selector(e.Current, cancellationToken);
                    if (v.HasValue)
                    {
                        checked
                        {
                            sum += v.Value;
                            count++;
                        }
                    }
<# } else { #>
                    checked
                    {
                        sum += await selector(e.Current, cancellationToken);
                        count++;
                    }
<# } #>
                }
            }
            finally
            {
                if (e != null)
                {
                    await e.DisposeAsync();
                }
            }

            return <#= CalcResult(t) #>;
        }

<# } #>
    }
}