Skip to content

Commit

Permalink
Fix to #35239 - EF9: SaveChanges() is significantly slower in .NET9 v…
Browse files Browse the repository at this point in the history
…s. .NET8 when using .ToJson() Mapping vs. PostgreSQL Legacy POCO mapping
  • Loading branch information
maumar committed Dec 20, 2024
1 parent 4b3e12f commit f2c36d8
Show file tree
Hide file tree
Showing 4 changed files with 593 additions and 15 deletions.
140 changes: 137 additions & 3 deletions src/EFCore.Cosmos/ChangeTracking/Internal/StringDictionaryComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,25 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.ChangeTracking.Internal;
/// </summary>
public sealed class StringDictionaryComparer<TDictionary, TElement> : ValueComparer<object>, IInfrastructure<ValueComparer>
{
private static readonly bool UseOldBehavior35239 =
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35239", out var enabled35239) && enabled35239;

private static readonly MethodInfo CompareMethod = typeof(StringDictionaryComparer<TDictionary, TElement>).GetMethod(
nameof(Compare), BindingFlags.Static | BindingFlags.NonPublic, [typeof(object), typeof(object), typeof(Func<TElement, TElement, bool>)])!;

private static readonly MethodInfo LegacyCompareMethod = typeof(StringDictionaryComparer<TDictionary, TElement>).GetMethod(
nameof(Compare), BindingFlags.Static | BindingFlags.NonPublic, [typeof(object), typeof(object), typeof(ValueComparer)])!;

private static readonly MethodInfo GetHashCodeMethod = typeof(StringDictionaryComparer<TDictionary, TElement>).GetMethod(
nameof(GetHashCode), BindingFlags.Static | BindingFlags.NonPublic, [typeof(IEnumerable), typeof(Func<TElement, int>)])!;

private static readonly MethodInfo LegacyGetHashCodeMethod = typeof(StringDictionaryComparer<TDictionary, TElement>).GetMethod(
nameof(GetHashCode), BindingFlags.Static | BindingFlags.NonPublic, [typeof(IEnumerable), typeof(ValueComparer)])!;

private static readonly MethodInfo SnapshotMethod = typeof(StringDictionaryComparer<TDictionary, TElement>).GetMethod(
nameof(Snapshot), BindingFlags.Static | BindingFlags.NonPublic, [typeof(object), typeof(Func<TElement, TElement>)])!;

private static readonly MethodInfo LegacySnapshotMethod = typeof(StringDictionaryComparer<TDictionary, TElement>).GetMethod(
nameof(Snapshot), BindingFlags.Static | BindingFlags.NonPublic, [typeof(object), typeof(ValueComparer)])!;

/// <summary>
Expand Down Expand Up @@ -52,9 +64,23 @@ ValueComparer IInfrastructure<ValueComparer>.Instance
var prm1 = Expression.Parameter(typeof(object), "a");
var prm2 = Expression.Parameter(typeof(object), "b");

if (elementComparer is ValueComparer<TElement> && !UseOldBehavior35239)
{
// (a, b) => Compare(a, b, elementComparer.Equals)
return Expression.Lambda<Func<object?, object?, bool>>(
Expression.Call(
CompareMethod,
prm1,
prm2,
elementComparer.EqualsExpression),
prm1,
prm2);
}

// (a, b) => Compare(a, b, new Comparer(...))
return Expression.Lambda<Func<object?, object?, bool>>(
Expression.Call(
CompareMethod,
LegacyCompareMethod,
prm1,
prm2,
#pragma warning disable EF9100
Expand All @@ -68,9 +94,23 @@ private static Expression<Func<object, int>> GetHashCodeLambda(ValueComparer ele
{
var prm = Expression.Parameter(typeof(object), "o");

if (elementComparer is ValueComparer<TElement> && !UseOldBehavior35239)
{
// o => GetHashCode((IEnumerable)o, elementComparer.GetHashCode)
return Expression.Lambda<Func<object, int>>(
Expression.Call(
GetHashCodeMethod,
Expression.Convert(
prm,
typeof(IEnumerable)),
elementComparer.HashCodeExpression),
prm);
}

// o => GetHashCode((IEnumerable)o, new Comparer(...))
return Expression.Lambda<Func<object, int>>(
Expression.Call(
GetHashCodeMethod,
LegacyGetHashCodeMethod,
Expression.Convert(
prm,
typeof(IEnumerable)),
Expand All @@ -84,16 +124,70 @@ private static Expression<Func<object, object>> SnapshotLambda(ValueComparer ele
{
var prm = Expression.Parameter(typeof(object), "source");

if (elementComparer is ValueComparer<TElement> && !UseOldBehavior35239)
{
// source => Snapshot(source, elementComparer.Snapshot)
return Expression.Lambda<Func<object, object>>(
Expression.Call(
SnapshotMethod,
prm,
elementComparer.SnapshotExpression),
prm);
}

// source => Snapshot(source, new Comparer(..))
return Expression.Lambda<Func<object, object>>(
Expression.Call(
SnapshotMethod,
LegacySnapshotMethod,
prm,
#pragma warning disable EF9100
elementComparer.ConstructorExpression),
#pragma warning restore EF9100
prm);
}

private static bool Compare(object? a, object? b, Func<TElement?, TElement?, bool> elementCompare)
{
if (ReferenceEquals(a, b))
{
return true;
}

if (a is null)
{
return b is null;
}

if (b is null)
{
return false;
}

if (a is IReadOnlyDictionary<string, TElement?> aDictionary && b is IReadOnlyDictionary<string, TElement?> bDictionary)
{
if (aDictionary.Count != bDictionary.Count)
{
return false;
}

foreach (var pair in aDictionary)
{
if (!bDictionary.TryGetValue(pair.Key, out var bValue)
|| !elementCompare(pair.Value, bValue))
{
return false;
}
}

return true;
}

throw new InvalidOperationException(
CosmosStrings.BadDictionaryType(
(a is IDictionary<string, TElement?> ? b : a).GetType().ShortDisplayName(),
typeof(IDictionary<,>).MakeGenericType(typeof(string), typeof(TElement)).ShortDisplayName()));
}

private static bool Compare(object? a, object? b, ValueComparer elementComparer)
{
if (ReferenceEquals(a, b))
Expand Down Expand Up @@ -136,6 +230,27 @@ private static bool Compare(object? a, object? b, ValueComparer elementComparer)
typeof(IDictionary<,>).MakeGenericType(typeof(string), elementComparer.Type).ShortDisplayName()));
}

private static int GetHashCode(IEnumerable source, Func<TElement?, int> elementGetHashCode)
{
if (source is not IReadOnlyDictionary<string, TElement?> sourceDictionary)
{
throw new InvalidOperationException(
CosmosStrings.BadDictionaryType(
source.GetType().ShortDisplayName(),
typeof(IList<>).MakeGenericType(typeof(TElement)).ShortDisplayName()));
}

var hash = new HashCode();

foreach (var pair in sourceDictionary)
{
hash.Add(pair.Key);
hash.Add(pair.Value == null ? 0 : elementGetHashCode(pair.Value));
}

return hash.ToHashCode();
}

private static int GetHashCode(IEnumerable source, ValueComparer elementComparer)
{
if (source is not IReadOnlyDictionary<string, TElement?> sourceDictionary)
Expand All @@ -157,6 +272,25 @@ private static int GetHashCode(IEnumerable source, ValueComparer elementComparer
return hash.ToHashCode();
}

private static IReadOnlyDictionary<string, TElement?> Snapshot(object source, Func<TElement?, TElement?> elementSnapshot)
{
if (source is not IReadOnlyDictionary<string, TElement?> sourceDictionary)
{
throw new InvalidOperationException(
CosmosStrings.BadDictionaryType(
source.GetType().ShortDisplayName(),
typeof(IDictionary<,>).MakeGenericType(typeof(string), typeof(TElement)).ShortDisplayName()));
}

var snapshot = new Dictionary<string, TElement?>();
foreach (var pair in sourceDictionary)
{
snapshot[pair.Key] = pair.Value == null ? default : (TElement?)elementSnapshot(pair.Value);
}

return snapshot;
}

private static IReadOnlyDictionary<string, TElement?> Snapshot(object source, ValueComparer elementComparer)
{
if (source is not IReadOnlyDictionary<string, TElement?> sourceDictionary)
Expand Down
Loading

0 comments on commit f2c36d8

Please sign in to comment.