From 5b0f7e6e0a4e775a866bf6d3dab558e53ef4863f Mon Sep 17 00:00:00 2001 From: Alex Soffronow-Pagonidis Date: Wed, 11 Mar 2026 14:12:23 +0100 Subject: [PATCH 1/2] implement math function translation --- README.md | 21 +- .../ClickHouseMathMethodTranslator.cs | 161 ++++++++++++ .../ClickHouseMethodCallTranslatorProvider.cs | 5 +- .../MathTranslationTests.cs | 246 ++++++++++++++++++ 4 files changed, 429 insertions(+), 4 deletions(-) create mode 100644 src/EFCore.ClickHouse/Query/ExpressionTranslators/Internal/ClickHouseMathMethodTranslator.cs create mode 100644 test/EFCore.ClickHouse.Tests/MathTranslationTests.cs diff --git a/README.md b/README.md index ac3c83e..df7f81b 100644 --- a/README.md +++ b/README.md @@ -62,15 +62,30 @@ public class PageView ## Current Status -This provider is in early development. It only support **Read-only queries**. You can map entities to existing ClickHouse tables and query them with LINQ (`Where`, `OrderBy`, `Take`, `Skip`, `Select`, `First`, `Single`, `Any`, `Count`, `Sum`, `Min`, `Max`, `Average`, `Distinct`, `GroupBy`). +This provider is in early development. It supports **read-only queries** — you can map entities to existing ClickHouse tables and query them with LINQ. -String methods translate to ClickHouse equivalents: `Contains`, `StartsWith`, `EndsWith`, `IndexOf`, `Replace`, `Substring`, `Trim`, `ToLower`, `ToUpper`, `Length`, and string concatenation all work. +### LINQ Queries + +`Where`, `OrderBy`, `Take`, `Skip`, `Select`, `First`, `Single`, `Any`, `Count`, `Distinct`, `AsNoTracking` + +### GROUP BY & Aggregates + +`GroupBy` with `Count`, `LongCount`, `Sum`, `Average`, `Min`, `Max` — including `HAVING` (`.Where()` after `.GroupBy()`), multiple aggregates in a single projection, and `OrderBy` on aggregate results. + +### String Methods + +`Contains`, `StartsWith`, `EndsWith`, `IndexOf`, `Replace`, `Substring`, `Trim`/`TrimStart`/`TrimEnd`, `ToLower`, `ToUpper`, `Length`, `IsNullOrEmpty`, `Concat` (and `+` operator) + +### Math Functions + +`Math.Abs`, `Floor`, `Ceiling`, `Round`, `Truncate`, `Pow`, `Sqrt`, `Cbrt`, `Exp`, `Log`, `Log2`, `Log10`, `Sign`, `Sin`, `Cos`, `Tan`, `Asin`, `Acos`, `Atan`, `Atan2`, `RadiansToDegrees`, `DegreesToRadians`, `IsNaN`, `IsInfinity`, `IsFinite`, `IsPositiveInfinity`, `IsNegativeInfinity` — with both `Math` and `MathF` overloads. ### Not Yet Implemented - INSERT / UPDATE / DELETE (modification commands are stubbed) - Migrations -- Advanced types, collection types, TimeSpan / TimeOnly, Tuple, Nullable(T), LowCardinality, Nested, other decimal, etc. type mappings +- JOINs, subqueries, set operations +- Advanced types: Array, Tuple, Nullable(T), LowCardinality, Nested, TimeSpan/TimeOnly - Batched inserts ## Building diff --git a/src/EFCore.ClickHouse/Query/ExpressionTranslators/Internal/ClickHouseMathMethodTranslator.cs b/src/EFCore.ClickHouse/Query/ExpressionTranslators/Internal/ClickHouseMathMethodTranslator.cs new file mode 100644 index 0000000..6d0e9a8 --- /dev/null +++ b/src/EFCore.ClickHouse/Query/ExpressionTranslators/Internal/ClickHouseMathMethodTranslator.cs @@ -0,0 +1,161 @@ +using System.Reflection; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; +using Microsoft.EntityFrameworkCore.Storage; + +namespace ClickHouse.EntityFrameworkCore.Query.ExpressionTranslators.Internal; + +public class ClickHouseMathMethodTranslator : IMethodCallTranslator +{ + private static readonly Dictionary SupportedMethods = new() + { + { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(decimal)])!, "abs" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(double)])!, "abs" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(float)])!, "abs" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(int)])!, "abs" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(long)])!, "abs" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(short)])!, "abs" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Abs), [typeof(float)])!, "abs" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), [typeof(double)])!, "ceiling" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), [typeof(decimal)])!, "ceiling" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Ceiling), [typeof(float)])!, "ceiling" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Floor), [typeof(double)])!, "floor" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Floor), [typeof(decimal)])!, "floor" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Floor), [typeof(float)])!, "floor" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Round), [typeof(double)])!, "round" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Round), [typeof(double), typeof(int)])!, "round" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Round), [typeof(decimal)])!, "round" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Round), [typeof(decimal), typeof(int)])!, "round" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Round), [typeof(float)])!, "round" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Round), [typeof(float), typeof(int)])!, "round" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Truncate), [typeof(double)])!, "truncate" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Truncate), [typeof(decimal)])!, "truncate" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Truncate), [typeof(float)])!, "truncate" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Pow), [typeof(double), typeof(double)])!, "pow" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Pow), [typeof(float), typeof(float)])!, "pow" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Sqrt), [typeof(double)])!, "sqrt" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Sqrt), [typeof(float)])!, "sqrt" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Cbrt), [typeof(double)])!, "cbrt" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Exp), [typeof(double)])!, "exp" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Exp), [typeof(float)])!, "exp" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Log), [typeof(double)])!, "log" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Log), [typeof(float)])!, "log" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Log10), [typeof(double)])!, "log10" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Log10), [typeof(float)])!, "log10" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(double)])!, "sign" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(float)])!, "sign" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(int)])!, "sign" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(long)])!, "sign" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(decimal)])!, "sign" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(short)])!, "sign" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Sign), [typeof(float)])!, "sign" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Sin), [typeof(double)])!, "sin" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Sin), [typeof(float)])!, "sin" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Cos), [typeof(double)])!, "cos" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Cos), [typeof(float)])!, "cos" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Tan), [typeof(double)])!, "tan" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Tan), [typeof(float)])!, "tan" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Asin), [typeof(double)])!, "asin" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Asin), [typeof(float)])!, "asin" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Acos), [typeof(double)])!, "acos" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Acos), [typeof(float)])!, "acos" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Atan), [typeof(double)])!, "atan" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan), [typeof(float)])!, "atan" }, + { typeof(Math).GetRuntimeMethod(nameof(Math.Atan2), [typeof(double), typeof(double)])!, "atan2" }, + { typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan2), [typeof(float), typeof(float)])!, "atan2" }, + { typeof(double).GetRuntimeMethod(nameof(double.RadiansToDegrees), [typeof(double)])!, "degrees" }, + { typeof(float).GetRuntimeMethod(nameof(float.RadiansToDegrees), [typeof(float)])!, "degrees" }, + { typeof(double).GetRuntimeMethod(nameof(double.DegreesToRadians), [typeof(double)])!, "radians" }, + { typeof(float).GetRuntimeMethod(nameof(float.DegreesToRadians), [typeof(float)])!, "radians" }, + }; + + private readonly ISqlExpressionFactory _sqlExpressionFactory; + private readonly IRelationalTypeMappingSource _typeMappingSource; + + public ClickHouseMathMethodTranslator( + ISqlExpressionFactory sqlExpressionFactory, + IRelationalTypeMappingSource typeMappingSource) + { + _sqlExpressionFactory = sqlExpressionFactory; + _typeMappingSource = typeMappingSource; + } + + public SqlExpression? Translate( + SqlExpression? instance, + MethodInfo method, + IReadOnlyList arguments, + IDiagnosticsLogger logger) + { + if (SupportedMethods.TryGetValue(method, out var functionName)) + { + return _sqlExpressionFactory.Function( + functionName, + arguments, + nullable: false, + argumentsPropagateNullability: Enumerable.Repeat(true, arguments.Count), + method.ReturnType); + } + + // Math.Log(x, newBase) → log(x) / log(newBase) + if ((method.DeclaringType == typeof(Math) || method.DeclaringType == typeof(MathF)) + && method.Name == nameof(Math.Log) + && arguments.Count == 2) + { + return _sqlExpressionFactory.Divide( + _sqlExpressionFactory.Function( + "log", + [arguments[0]], + nullable: true, + argumentsPropagateNullability: [true], + method.ReturnType), + _sqlExpressionFactory.Function( + "log", + [arguments[1]], + nullable: true, + argumentsPropagateNullability: [true], + method.ReturnType)); + } + + // double/float.IsNegativeInfinity → isInfinite(x) AND x < 0 + if ((method.DeclaringType == typeof(double) && method.Name == nameof(double.IsNegativeInfinity)) + || (method.DeclaringType == typeof(float) && method.Name == nameof(float.IsNegativeInfinity))) + { + var zeroConstant = method.DeclaringType == typeof(float) + ? _sqlExpressionFactory.Constant(0f, _typeMappingSource.FindMapping(typeof(float))) + : _sqlExpressionFactory.Constant(0d, _typeMappingSource.FindMapping(typeof(double))); + + return _sqlExpressionFactory.AndAlso( + _sqlExpressionFactory.Function( + "isInfinite", + arguments, + nullable: false, + argumentsPropagateNullability: [true], + typeof(bool), + _typeMappingSource.FindMapping(typeof(bool))), + _sqlExpressionFactory.LessThan(arguments[0], zeroConstant)); + } + + // double/float.IsPositiveInfinity → isInfinite(x) AND x > 0 + if ((method.DeclaringType == typeof(double) && method.Name == nameof(double.IsPositiveInfinity)) + || (method.DeclaringType == typeof(float) && method.Name == nameof(float.IsPositiveInfinity))) + { + var zeroConstant = method.DeclaringType == typeof(float) + ? _sqlExpressionFactory.Constant(0f, _typeMappingSource.FindMapping(typeof(float))) + : _sqlExpressionFactory.Constant(0d, _typeMappingSource.FindMapping(typeof(double))); + + return _sqlExpressionFactory.AndAlso( + _sqlExpressionFactory.Function( + "isInfinite", + arguments, + nullable: false, + argumentsPropagateNullability: [true], + typeof(bool), + _typeMappingSource.FindMapping(typeof(bool))), + _sqlExpressionFactory.GreaterThan(arguments[0], zeroConstant)); + } + + return null; + } +} diff --git a/src/EFCore.ClickHouse/Query/ExpressionTranslators/Internal/ClickHouseMethodCallTranslatorProvider.cs b/src/EFCore.ClickHouse/Query/ExpressionTranslators/Internal/ClickHouseMethodCallTranslatorProvider.cs index f586525..a045bd4 100644 --- a/src/EFCore.ClickHouse/Query/ExpressionTranslators/Internal/ClickHouseMethodCallTranslatorProvider.cs +++ b/src/EFCore.ClickHouse/Query/ExpressionTranslators/Internal/ClickHouseMethodCallTranslatorProvider.cs @@ -1,11 +1,13 @@ using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Storage; namespace ClickHouse.EntityFrameworkCore.Query.ExpressionTranslators.Internal; public class ClickHouseMethodCallTranslatorProvider : RelationalMethodCallTranslatorProvider { public ClickHouseMethodCallTranslatorProvider( - RelationalMethodCallTranslatorProviderDependencies dependencies) + RelationalMethodCallTranslatorProviderDependencies dependencies, + IRelationalTypeMappingSource typeMappingSource) : base(dependencies) { var sqlExpressionFactory = dependencies.SqlExpressionFactory; @@ -14,6 +16,7 @@ public ClickHouseMethodCallTranslatorProvider( [ new ClickHouseStringMethodTranslator(sqlExpressionFactory), new ClickHouseLikeTranslator(sqlExpressionFactory), + new ClickHouseMathMethodTranslator(sqlExpressionFactory, typeMappingSource), ]); } } diff --git a/test/EFCore.ClickHouse.Tests/MathTranslationTests.cs b/test/EFCore.ClickHouse.Tests/MathTranslationTests.cs new file mode 100644 index 0000000..b3ead33 --- /dev/null +++ b/test/EFCore.ClickHouse.Tests/MathTranslationTests.cs @@ -0,0 +1,246 @@ +using Microsoft.EntityFrameworkCore; +using Xunit; + +namespace EFCore.ClickHouse.Tests; + +public class MathTranslationTests : IClassFixture +{ + private readonly ClickHouseFixture _fixture; + + public MathTranslationTests(ClickHouseFixture fixture) + { + _fixture = fixture; + } + + [Fact] + public async Task Math_Abs_TranslatesCorrectly() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + // Age - 30: Alice=0, Bob=-5, Charlie=5, ... + var results = await context.TestEntities + .Select(e => new { e.Name, AbsDiff = Math.Abs(e.Age - 30) }) + .OrderBy(x => x.Name) + .AsNoTracking() + .ToListAsync(); + + var alice = results.First(r => r.Name == "Alice"); + Assert.Equal(0, alice.AbsDiff); // |30-30| + var bob = results.First(r => r.Name == "Bob"); + Assert.Equal(5, bob.AbsDiff); // |25-30| + } + + [Fact] + public async Task Math_Floor_TranslatesCorrectly() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .Select(e => new { e.Name, Floored = Math.Floor((double)e.Age / 3.0) }) + .OrderBy(x => x.Name) + .AsNoTracking() + .ToListAsync(); + + var alice = results.First(r => r.Name == "Alice"); + Assert.Equal(10.0, alice.Floored); // floor(30/3) = 10 + var bob = results.First(r => r.Name == "Bob"); + Assert.Equal(8.0, bob.Floored); // floor(25/3) = 8 + } + + [Fact] + public async Task Math_Ceiling_TranslatesCorrectly() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .Select(e => new { e.Name, Ceiled = Math.Ceiling((double)e.Age / 3.0) }) + .OrderBy(x => x.Name) + .AsNoTracking() + .ToListAsync(); + + var alice = results.First(r => r.Name == "Alice"); + Assert.Equal(10.0, alice.Ceiled); // ceil(30/3) = 10 + var bob = results.First(r => r.Name == "Bob"); + Assert.Equal(9.0, bob.Ceiled); // ceil(25/3) = 9 (8.333... → 9) + } + + [Fact] + public async Task Math_Round_NoDecimals_TranslatesCorrectly() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .Select(e => new { e.Name, Rounded = Math.Round((double)e.Age / 3.0) }) + .OrderBy(x => x.Name) + .AsNoTracking() + .ToListAsync(); + + var alice = results.First(r => r.Name == "Alice"); + Assert.Equal(10.0, alice.Rounded); // round(10.0) = 10 + var bob = results.First(r => r.Name == "Bob"); + Assert.Equal(8.0, bob.Rounded); // round(8.333) = 8 + } + + [Fact] + public async Task Math_Round_WithDecimals_TranslatesCorrectly() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .Select(e => new { e.Name, Rounded = Math.Round((double)e.Age / 3.0, 2) }) + .OrderBy(x => x.Name) + .AsNoTracking() + .ToListAsync(); + + var bob = results.First(r => r.Name == "Bob"); + Assert.Equal(8.33, bob.Rounded); // round(25/3, 2) = 8.33 + } + + [Fact] + public async Task Math_Sqrt_TranslatesCorrectly() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .Select(e => new { e.Name, SqrtAge = Math.Sqrt(e.Age) }) + .OrderBy(x => x.Name) + .AsNoTracking() + .ToListAsync(); + + var eve = results.First(r => r.Name == "Eve"); + Assert.Equal(Math.Sqrt(22), eve.SqrtAge, 10); + var frank = results.First(r => r.Name == "Frank"); + Assert.Equal(Math.Sqrt(40), frank.SqrtAge, 10); + } + + [Fact] + public async Task Math_Pow_TranslatesCorrectly() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .Where(e => e.Name == "Bob") + .Select(e => new { e.Name, Squared = Math.Pow(e.Age, 2) }) + .AsNoTracking() + .ToListAsync(); + + Assert.Single(results); + Assert.Equal(625.0, results[0].Squared); // 25^2 + } + + [Fact] + public async Task Math_Log_TranslatesCorrectly() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .Where(e => e.Name == "Alice") + .Select(e => new { e.Name, LogAge = Math.Log(e.Age) }) + .AsNoTracking() + .ToListAsync(); + + Assert.Single(results); + Assert.Equal(Math.Log(30), results[0].LogAge, 8); + } + + [Fact] + public async Task Math_Log10_TranslatesCorrectly() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .Where(e => e.Name == "Alice") + .Select(e => new { e.Name, Log10Age = Math.Log10(e.Age) }) + .AsNoTracking() + .ToListAsync(); + + Assert.Single(results); + Assert.Equal(Math.Log10(30), results[0].Log10Age, 10); + } + + [Fact] + public async Task Math_Exp_TranslatesCorrectly() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + // Use a small value to avoid overflow + var results = await context.TestEntities + .Where(e => e.Name == "Bob") + .Select(e => new { e.Name, ExpVal = Math.Exp((double)e.Age / 10.0) }) + .AsNoTracking() + .ToListAsync(); + + Assert.Single(results); + Assert.Equal(Math.Exp(2.5), results[0].ExpVal, 10); + } + + [Fact] + public async Task Math_Sign_TranslatesCorrectly() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .Where(e => e.Name == "Alice") + .Select(e => new { e.Name, SignVal = Math.Sign(e.Age - 30) }) + .AsNoTracking() + .ToListAsync(); + + Assert.Single(results); + Assert.Equal(0, results[0].SignVal); // sign(30-30) = 0 + } + + [Fact] + public async Task Math_SinCos_TranslateCorrectly() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .Where(e => e.Name == "Alice") + .Select(e => new + { + e.Name, + SinVal = Math.Sin((double)e.Age), + CosVal = Math.Cos((double)e.Age), + }) + .AsNoTracking() + .ToListAsync(); + + Assert.Single(results); + Assert.Equal(Math.Sin(30.0), results[0].SinVal, 10); + Assert.Equal(Math.Cos(30.0), results[0].CosVal, 10); + } + + [Fact] + public async Task Math_Truncate_TranslatesCorrectly() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .Select(e => new { e.Name, Truncated = Math.Truncate((double)e.Age / 3.0) }) + .OrderBy(x => x.Name) + .AsNoTracking() + .ToListAsync(); + + var bob = results.First(r => r.Name == "Bob"); + Assert.Equal(8.0, bob.Truncated); // truncate(8.333) = 8 + } + + [Fact] + public async Task Math_Where_WithMathFunction() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + // Filter: sqrt(age) > 5.5 → age > 30.25 → Alice(30 no), Charlie(35 yes), Frank(40 yes), Grace(33 yes), Ivy(31 yes) + var results = await context.TestEntities + .Where(e => Math.Sqrt(e.Age) > 5.5) + .OrderBy(e => e.Name) + .AsNoTracking() + .ToListAsync(); + + Assert.Equal(4, results.Count); + Assert.Equal("Charlie", results[0].Name); + Assert.Equal("Frank", results[1].Name); + Assert.Equal("Grace", results[2].Name); + Assert.Equal("Ivy", results[3].Name); + } +} From f78e8708cdc26e6dacf845d1d6c28dee00a92f1b Mon Sep 17 00:00:00 2001 From: Alex Soffronow-Pagonidis Date: Wed, 11 Mar 2026 14:21:01 +0100 Subject: [PATCH 2/2] improve test coverage --- .../GroupByAggregateTests.cs | 67 +++++++ .../MathTranslationTests.cs | 189 ++++++++++++++++++ 2 files changed, 256 insertions(+) diff --git a/test/EFCore.ClickHouse.Tests/GroupByAggregateTests.cs b/test/EFCore.ClickHouse.Tests/GroupByAggregateTests.cs index c163ff1..3624d4c 100644 --- a/test/EFCore.ClickHouse.Tests/GroupByAggregateTests.cs +++ b/test/EFCore.ClickHouse.Tests/GroupByAggregateTests.cs @@ -165,4 +165,71 @@ public async Task GroupBy_OrderByAggregate_Sorts() Assert.False(results[1].IsActive); Assert.Equal(4, results[1].Count); } + + [Fact] + public async Task GroupBy_LongCount_ReturnsCorrectCounts() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .GroupBy(e => e.IsActive) + .Select(g => new { IsActive = g.Key, Count = g.LongCount() }) + .OrderBy(x => x.IsActive) + .AsNoTracking() + .ToListAsync(); + + Assert.Equal(2, results.Count); + Assert.Equal(4L, results[0].Count); + Assert.Equal(6L, results[1].Count); + // Verify it's actually long, not int + Assert.IsType(results[0].Count); + } + + [Fact] + public async Task GroupBy_CountWithPredicate_UsesConditionalAggregate() + { + // g.Count(x => x.Age > 30) exercises CombineTerms predicate path: + // COUNT(CASE WHEN age > 30 THEN 1 ELSE NULL END) + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .GroupBy(e => e.IsActive) + .Select(g => new { IsActive = g.Key, OlderThan30 = g.Count(e => e.Age > 30) }) + .OrderBy(x => x.IsActive) + .AsNoTracking() + .ToListAsync(); + + Assert.Equal(2, results.Count); + // Inactive: Charlie(35) > 30 = 1 + Assert.Equal(1, results[0].OlderThan30); + // Active: Frank(40), Grace(33), Ivy(31) > 30 = 3 + Assert.Equal(3, results[1].OlderThan30); + } + + [Fact] + public async Task GroupBy_SumWithPredicate_UsesConditionalAggregate() + { + // g.Sum(x => x.Age) with a Where predicate on the group exercises + // CombineTerms predicate path for SUM: + // SUM(CASE WHEN age > 30 THEN age ELSE NULL END) + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .GroupBy(e => e.IsActive) + .Select(g => new + { + IsActive = g.Key, + SumOlderThan30 = g.Where(e => e.Age > 30).Sum(e => e.Age), + }) + .OrderBy(x => x.IsActive) + .AsNoTracking() + .ToListAsync(); + + Assert.Equal(2, results.Count); + // Inactive: only Charlie(35) + Assert.Equal(35, results[0].SumOlderThan30); + // Active: Frank(40) + Grace(33) + Ivy(31) = 104 + Assert.Equal(104, results[1].SumOlderThan30); + } + } diff --git a/test/EFCore.ClickHouse.Tests/MathTranslationTests.cs b/test/EFCore.ClickHouse.Tests/MathTranslationTests.cs index b3ead33..12c4458 100644 --- a/test/EFCore.ClickHouse.Tests/MathTranslationTests.cs +++ b/test/EFCore.ClickHouse.Tests/MathTranslationTests.cs @@ -1,8 +1,92 @@ using Microsoft.EntityFrameworkCore; +using Testcontainers.ClickHouse; using Xunit; namespace EFCore.ClickHouse.Tests; +public class FloatEntity +{ + public long Id { get; set; } + public string Label { get; set; } = string.Empty; + public double ValFloat64 { get; set; } + public float ValFloat32 { get; set; } +} + +public class FloatDbContext : DbContext +{ + public DbSet Floats => Set(); + + private readonly string _connectionString; + + public FloatDbContext(string connectionString) + { + _connectionString = connectionString; + } + + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) + { + optionsBuilder.UseClickHouse(_connectionString); + } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity(entity => + { + entity.ToTable("float_test"); + entity.HasKey(e => e.Id); + entity.Property(e => e.Id).HasColumnName("id"); + entity.Property(e => e.Label).HasColumnName("label"); + entity.Property(e => e.ValFloat64).HasColumnName("val_f64"); + entity.Property(e => e.ValFloat32).HasColumnName("val_f32"); + }); + } +} + +public class FloatFixture : IAsyncLifetime +{ + private readonly ClickHouseContainer _container = new ClickHouseBuilder("clickhouse/clickhouse-server:latest").Build(); + + public string ConnectionString { get; private set; } = string.Empty; + + public async Task InitializeAsync() + { + await _container.StartAsync(); + ConnectionString = _container.GetConnectionString(); + + using var connection = new global::ClickHouse.Driver.ADO.ClickHouseConnection(ConnectionString); + await connection.OpenAsync(); + + using var createCmd = connection.CreateCommand(); + createCmd.CommandText = """ + CREATE TABLE float_test ( + id Int64, + label String, + val_f64 Float64, + val_f32 Float32 + ) ENGINE = MergeTree() + ORDER BY id + """; + await createCmd.ExecuteNonQueryAsync(); + + using var insertCmd = connection.CreateCommand(); + insertCmd.CommandText = """ + INSERT INTO float_test (id, label, val_f64, val_f32) VALUES + (1, 'normal', 42.5, 42.5), + (2, 'negative', -10.0, -10.0), + (3, 'pos_inf', inf, inf), + (4, 'neg_inf', -inf, -inf), + (5, 'nan', nan, nan), + (6, 'zero', 0.0, 0.0) + """; + await insertCmd.ExecuteNonQueryAsync(); + } + + public async Task DisposeAsync() + { + await _container.DisposeAsync(); + } +} + public class MathTranslationTests : IClassFixture { private readonly ClickHouseFixture _fixture; @@ -244,3 +328,108 @@ public async Task Math_Where_WithMathFunction() Assert.Equal("Ivy", results[3].Name); } } + +/// +/// Tests for math translator special cases that aren't in the simple +/// dictionary lookup: Log(x, base), IsPositiveInfinity, IsNegativeInfinity. +/// +public class MathSpecialCaseTests : IClassFixture +{ + private readonly FloatFixture _fixture; + + public MathSpecialCaseTests(FloatFixture fixture) + { + _fixture = fixture; + } + + [Fact] + public async Task Math_LogWithBase_TranslatesCorrectly() + { + // Math.Log(x, base) → log(x) / log(base) + await using var context = new FloatDbContext(_fixture.ConnectionString); + + var results = await context.Floats + .Where(e => e.Label == "normal") + .Select(e => new { e.Label, LogBase10 = Math.Log(e.ValFloat64, 10) }) + .AsNoTracking() + .ToListAsync(); + + Assert.Single(results); + Assert.Equal(Math.Log(42.5, 10), results[0].LogBase10, 8); + } + + [Fact] + public async Task Math_LogWithBase2_TranslatesCorrectly() + { + await using var context = new FloatDbContext(_fixture.ConnectionString); + + var results = await context.Floats + .Where(e => e.Label == "normal") + .Select(e => new { e.Label, LogBase2 = Math.Log(e.ValFloat64, 2) }) + .AsNoTracking() + .ToListAsync(); + + Assert.Single(results); + Assert.Equal(Math.Log(42.5, 2), results[0].LogBase2, 6); + } + + [Fact] + public async Task Double_IsPositiveInfinity_FiltersCorrectly() + { + // double.IsPositiveInfinity(x) → isInfinite(x) AND x > 0 + await using var context = new FloatDbContext(_fixture.ConnectionString); + + var results = await context.Floats + .Where(e => double.IsPositiveInfinity(e.ValFloat64)) + .AsNoTracking() + .ToListAsync(); + + Assert.Single(results); + Assert.Equal("pos_inf", results[0].Label); + } + + [Fact] + public async Task Double_IsNegativeInfinity_FiltersCorrectly() + { + // double.IsNegativeInfinity(x) → isInfinite(x) AND x < 0 + await using var context = new FloatDbContext(_fixture.ConnectionString); + + var results = await context.Floats + .Where(e => double.IsNegativeInfinity(e.ValFloat64)) + .AsNoTracking() + .ToListAsync(); + + Assert.Single(results); + Assert.Equal("neg_inf", results[0].Label); + } + + [Fact] + public async Task Float_IsPositiveInfinity_FiltersCorrectly() + { + // float.IsPositiveInfinity(x) → isInfinite(x) AND x > 0 + await using var context = new FloatDbContext(_fixture.ConnectionString); + + var results = await context.Floats + .Where(e => float.IsPositiveInfinity(e.ValFloat32)) + .AsNoTracking() + .ToListAsync(); + + Assert.Single(results); + Assert.Equal("pos_inf", results[0].Label); + } + + [Fact] + public async Task Float_IsNegativeInfinity_FiltersCorrectly() + { + // float.IsNegativeInfinity(x) → isInfinite(x) AND x < 0 + await using var context = new FloatDbContext(_fixture.ConnectionString); + + var results = await context.Floats + .Where(e => float.IsNegativeInfinity(e.ValFloat32)) + .AsNoTracking() + .ToListAsync(); + + Assert.Single(results); + Assert.Equal("neg_inf", results[0].Label); + } +}