diff --git a/.github/workflows/dotnet.yml b/.github/workflows/dotnet.yml index e26254b..91822b4 100644 --- a/.github/workflows/dotnet.yml +++ b/.github/workflows/dotnet.yml @@ -2,20 +2,20 @@ name: .NET on: push: - branches: [ "master" ] + branches: [ "main" ] pull_request: - branches: [ "master" ] + branches: [ "main" ] permissions: contents: read jobs: build: - runs-on: windows-latest + runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v4 - name: Restore dependencies run: dotnet restore @@ -23,5 +23,44 @@ jobs: - name: Build run: dotnet build --configuration Release --no-restore - - name: Test - run: dotnet test --configuration Release --no-build --verbosity normal + test-sqlserver: + runs-on: ubuntu-latest + needs: build + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Restore dependencies + run: dotnet restore + + - name: Test SqlServer + run: dotnet test "N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.SqlServer.Test/N.EntityFramework.Extensions.SqlServer.Test.csproj" --configuration Release --settings N.EntityFramework.Extensions.SqlServer.runsettings --verbosity normal + + test-mysql: + runs-on: ubuntu-latest + needs: build + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Restore dependencies + run: dotnet restore + + - name: Test MySQL + run: dotnet test "N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/N.EntityFramework.Extensions.MySql.Test.csproj" --configuration Release --settings N.EntityFramework.Extensions.MySql.runsettings --verbosity normal + + test-postgresql: + runs-on: ubuntu-latest + needs: build + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Restore dependencies + run: dotnet restore + + - name: Test PostgreSQL + run: dotnet test "N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.PostgreSql.Test/N.EntityFramework.Extensions.PostgreSql.Test.csproj" --configuration Release --settings N.EntityFramework.Extensions.PostgreSql.runsettings --verbosity normal diff --git a/N.EntityFrameworkCore.PostgreSQL.Extensions.runsettings b/N.EntityFramework.Extensions.MySql.runsettings similarity index 91% rename from N.EntityFrameworkCore.PostgreSQL.Extensions.runsettings rename to N.EntityFramework.Extensions.MySql.runsettings index e5061ec..6f2f423 100644 --- a/N.EntityFrameworkCore.PostgreSQL.Extensions.runsettings +++ b/N.EntityFramework.Extensions.MySql.runsettings @@ -1,7 +1,7 @@ - 1 + 4 diff --git a/N.EntityFramework.Extensions.MySql/Common/Constants.cs b/N.EntityFramework.Extensions.MySql/Common/Constants.cs new file mode 100644 index 0000000..d756e1d --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Common/Constants.cs @@ -0,0 +1,6 @@ +namespace N.EntityFrameworkCore.Extensions.Common; + +public static class Constants +{ + public static readonly string InternalId_ColumnName = "_be_xx_id"; +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/BulkDeleteOptions.cs b/N.EntityFramework.Extensions.MySql/Data/BulkDeleteOptions.cs new file mode 100644 index 0000000..412d3b4 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/BulkDeleteOptions.cs @@ -0,0 +1,9 @@ +using System; +using System.Linq.Expressions; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkDeleteOptions : BulkOptions +{ + public Expression> DeleteOnCondition { get; set; } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/BulkFetchOptions.cs b/N.EntityFramework.Extensions.MySql/Data/BulkFetchOptions.cs new file mode 100644 index 0000000..790c7f2 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/BulkFetchOptions.cs @@ -0,0 +1,11 @@ +using System; +using System.Linq.Expressions; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkFetchOptions : BulkOptions +{ + public Expression> IgnoreColumns { get; set; } + public Expression> InputColumns { get; set; } + public Expression> JoinOnCondition { get; set; } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/BulkInsertOptions.cs b/N.EntityFramework.Extensions.MySql/Data/BulkInsertOptions.cs new file mode 100644 index 0000000..9daa53c --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/BulkInsertOptions.cs @@ -0,0 +1,27 @@ +using System; +using System.Linq; +using System.Linq.Expressions; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkInsertOptions : BulkOptions +{ + public bool AutoMapOutput { get; set; } + public Expression> IgnoreColumns { get; set; } + public Expression> InputColumns { get; set; } + public bool InsertIfNotExists { get; set; } + public Expression> InsertOnCondition { get; set; } + public bool KeepIdentity { get; set; } + + public string[] GetInputColumns() => + InputColumns?.Body.Type.GetProperties().Select(o => o.Name).ToArray(); + + public BulkInsertOptions() + { + AutoMapOutput = true; + } + internal BulkInsertOptions(BulkOptions options) + { + EntityType = options.EntityType; + } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/BulkInsertResult.cs b/N.EntityFramework.Extensions.MySql/Data/BulkInsertResult.cs new file mode 100644 index 0000000..29315ff --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/BulkInsertResult.cs @@ -0,0 +1,9 @@ +using System.Collections.Generic; + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed class BulkInsertResult +{ + internal int RowsAffected { get; set; } + internal Dictionary EntityMap { get; set; } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/BulkMergeOption.cs b/N.EntityFramework.Extensions.MySql/Data/BulkMergeOption.cs new file mode 100644 index 0000000..e388ee3 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/BulkMergeOption.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkMergeOptions : BulkOptions +{ + public Expression> MergeOnCondition { get; set; } + public Expression> IgnoreColumnsOnInsert { get; set; } + public Expression> IgnoreColumnsOnUpdate { get; set; } + public bool AutoMapOutput { get; set; } + internal bool DeleteIfNotMatched { get; set; } + + public BulkMergeOptions() + { + AutoMapOutput = true; + } + public List GetIgnoreColumnsOnInsert() => + IgnoreColumnsOnInsert?.Body.Type.GetProperties().Select(o => o.Name).ToList() ?? []; + public List GetIgnoreColumnsOnUpdate() => + IgnoreColumnsOnUpdate?.Body.Type.GetProperties().Select(o => o.Name).ToList() ?? []; +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/BulkMergeOutputRow.cs b/N.EntityFramework.Extensions.MySql/Data/BulkMergeOutputRow.cs new file mode 100644 index 0000000..a3f5703 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/BulkMergeOutputRow.cs @@ -0,0 +1,11 @@ +namespace N.EntityFrameworkCore.Extensions; + +public class BulkMergeOutputRow +{ + public string Action { get; set; } + + public BulkMergeOutputRow(string action) + { + Action = action; + } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/BulkMergeResult.cs b/N.EntityFramework.Extensions.MySql/Data/BulkMergeResult.cs new file mode 100644 index 0000000..b28592e --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/BulkMergeResult.cs @@ -0,0 +1,12 @@ +using System.Collections.Generic; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkMergeResult +{ + public IEnumerable> Output { get; set; } + public int RowsAffected { get; set; } + public int RowsDeleted { get; internal set; } + public int RowsInserted { get; internal set; } + public int RowsUpdated { get; internal set; } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/BulkOperation.cs b/N.EntityFramework.Extensions.MySql/Data/BulkOperation.cs new file mode 100644 index 0000000..b54414c --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/BulkOperation.cs @@ -0,0 +1,276 @@ +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.IO; +using System.Linq; +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; +using Microsoft.EntityFrameworkCore.Metadata; +using N.EntityFrameworkCore.Extensions.Common; +using N.EntityFrameworkCore.Extensions.Sql; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed partial class BulkOperation : IDisposable +{ + internal DbConnection Connection => DbTransactionContext.Connection; + internal DbContext Context { get; } + internal bool StagingTableCreated { get; set; } + internal string StagingTableName { get; } + internal string[] PrimaryKeyColumnNames { get; } + internal BulkOptions Options { get; } + internal Expression> InputColumns { get; } + internal Expression> IgnoreColumns { get; } + internal DbTransactionContext DbTransactionContext { get; } + internal Type EntityType => typeof(T); + internal DbTransaction Transaction => DbTransactionContext.CurrentTransaction; + internal TableMapping TableMapping { get; } + internal IEnumerable SchemaQualifiedTableNames => TableMapping.GetSchemaQualifiedTableNames(); + + + public BulkOperation(DbContext dbContext, BulkOptions options, Expression> inputColumns = null, Expression> ignoreColumns = null) + { + Context = dbContext; + Options = options; + InputColumns = inputColumns; + IgnoreColumns = ignoreColumns; + + DbTransactionContext = new DbTransactionContext(dbContext, options.CommandTimeout); + TableMapping = dbContext.GetTableMapping(typeof(T), options.EntityType); + StagingTableName = CommonUtil.GetStagingTableName(TableMapping, options.UsePermanentTable, Connection); + PrimaryKeyColumnNames = TableMapping.GetPrimaryKeyColumns().ToArray(); + } + public void Dispose() + { + if (StagingTableCreated) + { + // For MySQL temporary staging tables, use DROP TEMPORARY TABLE to avoid implicit transaction commit + bool isTemporary = Context.Database.IsMySql() && !Options.UsePermanentTable; + Context.Database.DropTable(StagingTableName, true, isTemporary); + } + } + internal bool ShouldKeepIdentityForMerge() + { + // MySQL does not support keeping identity values for merge in the same way as PostgreSQL + return false; + } + internal bool ShouldPreallocateIdentityValues(bool autoMapOutput, bool keepIdentity, IEnumerable entities) + { + // MySQL does not support preallocating identity values + return false; + } + internal void PreallocateIdentityValues(IEnumerable entities) + { + // No-op for MySQL + } + internal BulkInsertResult BulkInsertStagingData(IEnumerable entities, bool keepIdentity = true, bool useInternalId = false) + { + IEnumerable columnsToInsert = GetColumnNames(keepIdentity); + string internalIdColumn = useInternalId ? Common.Constants.InternalId_ColumnName : null; + Context.Database.CloneTable(SchemaQualifiedTableNames, StagingTableName, TableMapping.GetQualifiedColumnNames(columnsToInsert), internalIdColumn, isTemporary: !Options.UsePermanentTable); + StagingTableCreated = true; + // ALTER TABLE on temporary tables causes an implicit commit in MySQL, which would break + // any active user transaction. Skip adding the index when inside a user-provided transaction. + if (keepIdentity && PrimaryKeyColumnNames.Length > 0 && Context.Database.IsMySql() && DbTransactionContext.OwnsTransaction) + { + string indexColumns = string.Join(",", PrimaryKeyColumnNames.Select(c => Context.DelimitIdentifier(c))); + Context.Database.ExecuteSqlInternal($"ALTER TABLE {StagingTableName} ADD INDEX idx_pk ({indexColumns})"); + } + return DbContextExtensions.BulkInsert(entities, Options, TableMapping, Connection, Transaction, StagingTableName, columnsToInsert, useInternalId); + } + internal BulkMergeResult ExecuteMerge(Dictionary entityMap, Expression> mergeOnCondition, + bool autoMapOutput, bool keepIdentity, bool insertIfNotExists, bool update = false, bool delete = false, bool preallocatedIds = false) + { + return ExecuteMergeMySql(entityMap, mergeOnCondition, autoMapOutput, keepIdentity, insertIfNotExists, update, delete, preallocatedIds); + } + + private IEnumerable GetMergeOutputColumns(IEnumerable autoGeneratedColumns, bool delete = false) + { + List columnsToOutput = ["$action", $"[s].[{Constants.InternalId_ColumnName}]"]; + columnsToOutput.AddRange(autoGeneratedColumns.Select(o => $"[inserted].[{o}]")); + return columnsToOutput; + } + private object[] GetMergeOutputValues(IEnumerable columns, object[] values, IEnumerable properties) + { + var columnList = columns.ToList(); + var valuesIndex = properties.Select(o => columnList.IndexOf($"[inserted].[{o.GetColumnName()}]")); + return valuesIndex.Select(i => values[i]).ToArray(); + } + internal int ExecuteUpdate(IEnumerable entities, Expression> updateOnCondition) + { + return ExecuteUpdateMySql(updateOnCondition); + } + private BulkMergeResult ExecuteMergeMySql(Dictionary entityMap, Expression> mergeOnCondition, + bool autoMapOutput, bool keepIdentity, bool insertIfNotExists, bool update, bool delete, bool preallocatedIds = false) + { + Dictionary rowsInserted = []; + Dictionary rowsUpdated = []; + Dictionary rowsDeleted = []; + List> outputRows = []; + + foreach (var entityType in TableMapping.EntityTypes) + { + var targetTableName = Context.DelimitIdentifier(entityType.GetTableName(), entityType.GetSchema() ?? Context.Database.GetDefaultSchema()); + var columnsToInsert = GetColumnNames(entityType, keepIdentity).ToList(); + var columnsToUpdate = update ? GetColumnNames(entityType).ToList() : []; + var autoGeneratedColumns = autoMapOutput ? TableMapping.GetAutoGeneratedColumns(entityType).ToList() : []; + var allProperties = autoMapOutput + ? TableMapping.GetEntityProperties(entityType, ValueGenerated.OnAdd).Concat(TableMapping.GetEntityProperties(entityType, ValueGenerated.OnAddOrUpdate)).ToList() + : []; + + string matchJoinCondition = CommonUtil.GetJoinConditionSql(Context, mergeOnCondition, PrimaryKeyColumnNames, "s", "t"); + string pkJoinCondition = CommonUtil.GetJoinConditionSql(Context, null, PrimaryKeyColumnNames, "s", "t"); + string joinCondition = insertIfNotExists ? matchJoinCondition : "1=2"; + + HashSet matchedIds = autoMapOutput && update + ? GetMatchedInternalIds(targetTableName, matchJoinCondition) + : []; + + rowsUpdated[entityType] = 0; + if (columnsToUpdate.Count > 0) + { + // MySQL UPDATE returns "rows changed" by default, not "rows matched". + // Count matched rows first to get reliable "rows found" semantics. + string matchCountSql = $"SELECT COUNT(*) FROM {StagingTableName} AS s INNER JOIN {targetTableName} AS t ON {joinCondition}"; + rowsUpdated[entityType] = Convert.ToInt32(Context.Database.ExecuteScalar(matchCountSql, null, Options.CommandTimeout)); + // MySQL UPDATE with JOIN syntax + string updateSetExpression = string.Join(",", columnsToUpdate.Select(c => $"t.{Context.DelimitIdentifier(c)}=s.{Context.DelimitIdentifier(c)}")); + string updateSql = $"UPDATE {StagingTableName} AS s INNER JOIN {targetTableName} AS t ON {joinCondition} SET {updateSetExpression}"; + Context.Database.ExecuteSqlInternal(updateSql, Options.CommandTimeout); + } + + rowsDeleted[entityType] = 0; + if (TableMapping.EntityType == entityType && delete) + { + string deleteJoinCondition = mergeOnCondition != null ? matchJoinCondition : pkJoinCondition; + // MySQL multi-table DELETE syntax: DELETE alias FROM table AS alias WHERE ... + string deleteSql = $"DELETE t FROM {targetTableName} AS t WHERE NOT EXISTS (SELECT 1 FROM {StagingTableName} AS s WHERE {deleteJoinCondition})"; + rowsDeleted[entityType] = Context.Database.ExecuteSqlInternal(deleteSql, Options.CommandTimeout); + for (int i = 0; i < rowsDeleted[entityType]; i++) + outputRows.Add(new BulkMergeOutputRow(SqlMergeAction.Delete)); + } + + string insertColumnsSql = string.Join(",", columnsToInsert.Select(Context.DelimitIdentifier)); + string sourceColumnsSql = string.Join(",", columnsToInsert.Select(c => Context.DelimitMemberAccess("s", c))); + // When staging rows have Id=0 (keepIdentity=false, no explicit merge condition), ORDER BY _be_xx_id so + // LAST_INSERT_ID() + k correctly maps to each entity in entityMap order. + bool useLastInsertId = autoMapOutput && !keepIdentity && mergeOnCondition == null && autoGeneratedColumns.Any(); + string insertOrderBy = useLastInsertId ? $" ORDER BY s.{Context.DelimitIdentifier(Constants.InternalId_ColumnName)}" : ""; + string insertSql = $"INSERT INTO {targetTableName} ({insertColumnsSql}) SELECT {sourceColumnsSql} FROM {StagingTableName} AS s WHERE NOT EXISTS (SELECT 1 FROM {targetTableName} AS t WHERE {joinCondition}){insertOrderBy}"; + rowsInserted[entityType] = Context.Database.ExecuteSqlInternal(insertSql, Options.CommandTimeout); + if (keepIdentity && rowsInserted[entityType] > 0) + SyncMySqlAutoIncrement(entityType); + + if (autoMapOutput) + { + if (useLastInsertId && rowsInserted[entityType] > 0) + { + // When keepIdentity=false and no merge condition, staging PKs are all 0. + // JOIN on PK would return no rows, so use LAST_INSERT_ID() instead. + // MySQL guarantees consecutive auto-increment IDs for a single INSERT statement. + var lastInsertResult = Context.BulkQuery("SELECT LAST_INSERT_ID()", Options); + long firstAutoId = Convert.ToInt64(lastInsertResult.Results.First()[0]); + var insertedEntities = entityMap + .Where(kvp => !matchedIds.Contains((int)kvp.Key)) + .OrderBy(kvp => kvp.Key) + .ToList(); + for (int k = 0; k < insertedEntities.Count; k++) + { + var entity = insertedEntities[k].Value; + long generatedId = firstAutoId + k; + var generatedPkProperty = GetGeneratedPrimaryKeyProperty(); + if (generatedPkProperty != null) + { + object pkValue = Convert.ChangeType(generatedId, generatedPkProperty.ClrType); + Context.SetStoreGeneratedValues(entity, [generatedPkProperty], [pkValue]); + } + outputRows.Add(new BulkMergeOutputRow(SqlMergeAction.Insert)); + } + } + else + { + string outputColumnsSql = autoGeneratedColumns.Any() + ? "," + string.Join(",", autoGeneratedColumns.Select(c => Context.DelimitMemberAccess("t", c))) + : string.Empty; + var outputQuery = $"SELECT {Context.DelimitMemberAccess("s", Constants.InternalId_ColumnName)}{outputColumnsSql} FROM {StagingTableName} AS s JOIN {targetTableName} AS t ON {matchJoinCondition}"; + var bulkQueryResult = Context.BulkQuery(outputQuery, Options); + var autoGeneratedColumnList = autoGeneratedColumns.ToList(); + foreach (var result in bulkQueryResult.Results) + { + int entityId = Convert.ToInt32(result[0]); + bool wasMatched = matchedIds.Contains(entityId); + string action = wasMatched ? SqlMergeAction.Update : SqlMergeAction.Insert; + outputRows.Add(new BulkMergeOutputRow(action)); + + if (entityMap.TryGetValue(entityId, out var entity) && allProperties.Count > 0) + { + object[] entityValues = allProperties.Select(p => result[1 + autoGeneratedColumnList.IndexOf(p.GetColumnName())]).ToArray(); + Context.SetStoreGeneratedValues(entity, allProperties, entityValues); + } + } + } + } + } + + return new BulkMergeResult + { + Output = outputRows, + RowsAffected = rowsInserted.Values.FirstOrDefault() + rowsUpdated.Values.FirstOrDefault() + rowsDeleted.Values.Sum(), + RowsDeleted = rowsDeleted.Values.Sum(), + RowsInserted = rowsInserted.Values.FirstOrDefault(), + RowsUpdated = rowsUpdated.Values.FirstOrDefault() + }; + } + private int ExecuteUpdateMySql(Expression> updateOnCondition) + { + int rowsUpdated = 0; + foreach (var entityType in TableMapping.EntityTypes) + { + IEnumerable columnsToUpdate = GetColumnNames(entityType); + string updateSetExpression = string.Join(",", columnsToUpdate.Select(c => $"t.{Context.DelimitIdentifier(c)}=s.{Context.DelimitIdentifier(c)}")); + string targetTableName = Context.DelimitIdentifier(entityType.GetTableName(), entityType.GetSchema() ?? Context.Database.GetDefaultSchema()); + // MySQL UPDATE with JOIN syntax + string updateSql = $"UPDATE {StagingTableName} AS s INNER JOIN {targetTableName} AS t ON {CommonUtil.GetJoinConditionSql(Context, updateOnCondition, PrimaryKeyColumnNames, "s", "t")} SET {updateSetExpression}"; + rowsUpdated = Context.Database.ExecuteSqlInternal(updateSql, Options.CommandTimeout); + } + return rowsUpdated; + } + private HashSet GetMatchedInternalIds(string targetTableName, string joinCondition) + { + var results = Context.BulkQuery( + $"SELECT {Context.DelimitMemberAccess("s", Constants.InternalId_ColumnName)} FROM {StagingTableName} AS s JOIN {targetTableName} AS t ON {joinCondition}", + Options); + return results.Results.Select(r => Convert.ToInt32(r[0])).ToHashSet(); + } + private IProperty GetGeneratedPrimaryKeyProperty() + { + return TableMapping.EntityType.GetProperties().SingleOrDefault(o => o.IsPrimaryKey() && o.ValueGenerated != ValueGenerated.Never); + } + private void SyncMySqlAutoIncrement(IEntityType entityType) + { + var tableName = Context.DelimitIdentifier(entityType.GetTableName(), entityType.GetSchema() ?? Context.Database.GetDefaultSchema()); + // Reset AUTO_INCREMENT to resync after bulk insert with explicit IDs + Context.Database.ExecuteSqlInternal($"ALTER TABLE {tableName} AUTO_INCREMENT = 1", Options.CommandTimeout); + } + internal void ValidateBulkMerge(Expression> mergeOnCondition) + { + if (PrimaryKeyColumnNames.Length == 0 && mergeOnCondition == null) + throw new InvalidDataException("BulkMerge requires that the entity have a primary key or that Options.MergeOnCondition be set"); + } + internal void ValidateBulkUpdate(Expression> updateOnCondition) + { + if (PrimaryKeyColumnNames.Length == 0 && updateOnCondition == null) + throw new InvalidDataException("BulkUpdate requires that the entity have a primary key or the Options.UpdateOnCondition must be set."); + + } + internal IEnumerable GetColumnNames(bool includePrimaryKeys = false) + { + return GetColumnNames(null, includePrimaryKeys); + } + internal IEnumerable GetColumnNames(IEntityType entityType, bool includePrimaryKeys = false) + { + return CommonUtil.FilterColumns(TableMapping.GetColumnNames(entityType, includePrimaryKeys), PrimaryKeyColumnNames, InputColumns, IgnoreColumns); + } +} diff --git a/N.EntityFramework.Extensions.MySql/Data/BulkOperationAsync.cs b/N.EntityFramework.Extensions.MySql/Data/BulkOperationAsync.cs new file mode 100644 index 0000000..f2092a2 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/BulkOperationAsync.cs @@ -0,0 +1,184 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Metadata; +using N.EntityFrameworkCore.Extensions.Common; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed partial class BulkOperation +{ + internal async Task> BulkInsertStagingDataAsync(IEnumerable entities, bool keepIdentity = true, bool useInternalId = false, CancellationToken cancellationToken = default) + { + IEnumerable columnsToInsert = GetColumnNames(keepIdentity); + string internalIdColumn = useInternalId ? Common.Constants.InternalId_ColumnName : null; + await Context.Database.CloneTableAsync(SchemaQualifiedTableNames, StagingTableName, TableMapping.GetQualifiedColumnNames(columnsToInsert), internalIdColumn, cancellationToken, isTemporary: !Options.UsePermanentTable); + StagingTableCreated = true; + // ALTER TABLE on temporary tables causes an implicit commit in MySQL, which would break + // any active user transaction. Skip adding the index when inside a user-provided transaction. + if (keepIdentity && PrimaryKeyColumnNames.Length > 0 && Context.Database.IsMySql() && DbTransactionContext.OwnsTransaction) + { + string indexColumns = string.Join(",", PrimaryKeyColumnNames.Select(c => Context.DelimitIdentifier(c))); + await Context.Database.ExecuteSqlAsync($"ALTER TABLE {StagingTableName} ADD INDEX idx_pk ({indexColumns})", Options.CommandTimeout, cancellationToken); + } + return await DbContextExtensionsAsync.BulkInsertAsync(entities, Options, TableMapping, Connection, Transaction, StagingTableName, columnsToInsert, useInternalId, cancellationToken); + } + + internal async Task> ExecuteMergeAsync(Dictionary entityMap, Expression> mergeOnCondition, + bool autoMapOutput, bool keepIdentity, bool insertIfNotExists, bool update = false, bool delete = false, bool preallocatedIds = false, CancellationToken cancellationToken = default) + { + return await ExecuteMergeMySqlAsync(entityMap, mergeOnCondition, autoMapOutput, keepIdentity, insertIfNotExists, update, delete, preallocatedIds, cancellationToken); + } + internal async Task ExecuteUpdateAsync(IEnumerable entities, Expression> updateOnCondition, CancellationToken cancellationToken = default) + { + return await ExecuteUpdateMySqlAsync(updateOnCondition, cancellationToken); + } + private async Task> ExecuteMergeMySqlAsync(Dictionary entityMap, Expression> mergeOnCondition, + bool autoMapOutput, bool keepIdentity, bool insertIfNotExists, bool update, bool delete, bool preallocatedIds = false, CancellationToken cancellationToken = default) + { + Dictionary rowsInserted = []; + Dictionary rowsUpdated = []; + Dictionary rowsDeleted = []; + List> outputRows = []; + + foreach (var entityType in TableMapping.EntityTypes) + { + var targetTableName = Context.DelimitIdentifier(entityType.GetTableName(), entityType.GetSchema() ?? Context.Database.GetDefaultSchema()); + var columnsToInsert = GetColumnNames(entityType, keepIdentity).ToList(); + var columnsToUpdate = update ? GetColumnNames(entityType).ToList() : []; + var autoGeneratedColumns = autoMapOutput ? TableMapping.GetAutoGeneratedColumns(entityType).ToList() : []; + var allProperties = autoMapOutput + ? TableMapping.GetEntityProperties(entityType, ValueGenerated.OnAdd).Concat(TableMapping.GetEntityProperties(entityType, ValueGenerated.OnAddOrUpdate)).ToList() + : []; + + string matchJoinCondition = CommonUtil.GetJoinConditionSql(Context, mergeOnCondition, PrimaryKeyColumnNames, "s", "t"); + string pkJoinCondition = CommonUtil.GetJoinConditionSql(Context, null, PrimaryKeyColumnNames, "s", "t"); + string joinCondition = insertIfNotExists ? matchJoinCondition : "1=2"; + + HashSet matchedIds = autoMapOutput && update + ? await GetMatchedInternalIdsAsync(targetTableName, matchJoinCondition, cancellationToken) + : []; + + rowsUpdated[entityType] = 0; + if (columnsToUpdate.Count > 0) + { + // MySQL UPDATE returns "rows changed" by default, not "rows matched". + // Count matched rows first to get reliable "rows found" semantics. + string matchCountSql = $"SELECT COUNT(*) FROM {StagingTableName} AS s INNER JOIN {targetTableName} AS t ON {joinCondition}"; + rowsUpdated[entityType] = Convert.ToInt32(await Context.Database.ExecuteScalarAsync(matchCountSql, null, Options.CommandTimeout, cancellationToken)); + string updateSetExpression = string.Join(",", columnsToUpdate.Select(c => $"t.{Context.DelimitIdentifier(c)}=s.{Context.DelimitIdentifier(c)}")); + string updateSql = $"UPDATE {StagingTableName} AS s INNER JOIN {targetTableName} AS t ON {joinCondition} SET {updateSetExpression}"; + await Context.Database.ExecuteSqlAsync(updateSql, Options.CommandTimeout, cancellationToken); + } + + rowsDeleted[entityType] = 0; + if (TableMapping.EntityType == entityType && delete) + { + string deleteJoinCondition = mergeOnCondition != null ? matchJoinCondition : pkJoinCondition; + string deleteSql = $"DELETE t FROM {targetTableName} AS t WHERE NOT EXISTS (SELECT 1 FROM {StagingTableName} AS s WHERE {deleteJoinCondition})"; + rowsDeleted[entityType] = await Context.Database.ExecuteSqlAsync(deleteSql, Options.CommandTimeout, cancellationToken); + for (int i = 0; i < rowsDeleted[entityType]; i++) + outputRows.Add(new BulkMergeOutputRow(SqlMergeAction.Delete)); + } + + string insertColumnsSql = string.Join(",", columnsToInsert.Select(Context.DelimitIdentifier)); + string sourceColumnsSql = string.Join(",", columnsToInsert.Select(c => Context.DelimitMemberAccess("s", c))); + bool useLastInsertId = autoMapOutput && !keepIdentity && mergeOnCondition == null && autoGeneratedColumns.Any(); + string insertOrderBy = useLastInsertId ? $" ORDER BY s.{Context.DelimitIdentifier(Constants.InternalId_ColumnName)}" : ""; + string insertSql = $"INSERT INTO {targetTableName} ({insertColumnsSql}) SELECT {sourceColumnsSql} FROM {StagingTableName} AS s WHERE NOT EXISTS (SELECT 1 FROM {targetTableName} AS t WHERE {joinCondition}){insertOrderBy}"; + rowsInserted[entityType] = await Context.Database.ExecuteSqlAsync(insertSql, Options.CommandTimeout, cancellationToken); + if (keepIdentity && rowsInserted[entityType] > 0) + await SyncMySqlAutoIncrementAsync(entityType, cancellationToken); + + if (autoMapOutput) + { + if (useLastInsertId && rowsInserted[entityType] > 0) + { + var lastInsertResult = await Context.BulkQueryAsync("SELECT LAST_INSERT_ID()", Connection, Transaction, Options, cancellationToken); + long firstAutoId = Convert.ToInt64(lastInsertResult.Results.First()[0]); + var insertedEntities = entityMap + .Where(kvp => !matchedIds.Contains((int)kvp.Key)) + .OrderBy(kvp => kvp.Key) + .ToList(); + for (int k = 0; k < insertedEntities.Count; k++) + { + var entity = insertedEntities[k].Value; + long generatedId = firstAutoId + k; + var generatedPkProperty = GetGeneratedPrimaryKeyProperty(); + if (generatedPkProperty != null) + { + object pkValue = Convert.ChangeType(generatedId, generatedPkProperty.ClrType); + Context.SetStoreGeneratedValues(entity, [generatedPkProperty], [pkValue]); + } + outputRows.Add(new BulkMergeOutputRow(SqlMergeAction.Insert)); + } + } + else + { + string outputColumnsSql = autoGeneratedColumns.Any() + ? "," + string.Join(",", autoGeneratedColumns.Select(c => Context.DelimitMemberAccess("t", c))) + : string.Empty; + string outputQuery = $"SELECT {Context.DelimitMemberAccess("s", Constants.InternalId_ColumnName)}{outputColumnsSql} FROM {StagingTableName} AS s JOIN {targetTableName} AS t ON {matchJoinCondition}"; + var bulkQueryResult = await Context.BulkQueryAsync(outputQuery, Connection, Transaction, Options, cancellationToken); + var autoGeneratedColumnList = autoGeneratedColumns.ToList(); + foreach (var result in bulkQueryResult.Results) + { + int entityId = Convert.ToInt32(result[0]); + string action = matchedIds.Contains(entityId) ? SqlMergeAction.Update : SqlMergeAction.Insert; + outputRows.Add(new BulkMergeOutputRow(action)); + + if (entityMap.TryGetValue(entityId, out var entity) && allProperties.Count > 0) + { + object[] entityValues = allProperties.Select(p => result[1 + autoGeneratedColumnList.IndexOf(p.GetColumnName())]).ToArray(); + Context.SetStoreGeneratedValues(entity, allProperties, entityValues); + } + } + } + } + } + + return new BulkMergeResult + { + Output = outputRows, + RowsAffected = rowsInserted.Values.FirstOrDefault() + rowsUpdated.Values.FirstOrDefault() + rowsDeleted.Values.Sum(), + RowsDeleted = rowsDeleted.Values.Sum(), + RowsInserted = rowsInserted.Values.FirstOrDefault(), + RowsUpdated = rowsUpdated.Values.FirstOrDefault() + }; + } + private async Task ExecuteUpdateMySqlAsync(Expression> updateOnCondition, CancellationToken cancellationToken) + { + int rowsUpdated = 0; + foreach (var entityType in TableMapping.EntityTypes) + { + IEnumerable columnsToUpdate = GetColumnNames(entityType); + string updateSetExpression = string.Join(",", columnsToUpdate.Select(c => $"t.{Context.DelimitIdentifier(c)}=s.{Context.DelimitIdentifier(c)}")); + string targetTableName = Context.DelimitIdentifier(entityType.GetTableName(), entityType.GetSchema() ?? Context.Database.GetDefaultSchema()); + string updateSql = $"UPDATE {StagingTableName} AS s INNER JOIN {targetTableName} AS t ON {CommonUtil.GetJoinConditionSql(Context, updateOnCondition, PrimaryKeyColumnNames, "s", "t")} SET {updateSetExpression}"; + rowsUpdated = await Context.Database.ExecuteSqlAsync(updateSql, Options.CommandTimeout, cancellationToken); + } + return rowsUpdated; + } + internal async Task PreallocateIdentityValuesAsync(IEnumerable entities, CancellationToken cancellationToken) + { + // No-op for MySQL + await Task.CompletedTask; + } + private async Task> GetMatchedInternalIdsAsync(string targetTableName, string joinCondition, CancellationToken cancellationToken) + { + var results = await Context.BulkQueryAsync( + $"SELECT {Context.DelimitMemberAccess("s", Constants.InternalId_ColumnName)} FROM {StagingTableName} AS s JOIN {targetTableName} AS t ON {joinCondition}", + Connection, Transaction, Options, cancellationToken); + return results.Results.Select(r => Convert.ToInt32(r[0])).ToHashSet(); + } + private async Task SyncMySqlAutoIncrementAsync(IEntityType entityType, CancellationToken cancellationToken) + { + var tableName = Context.DelimitIdentifier(entityType.GetTableName(), entityType.GetSchema() ?? Context.Database.GetDefaultSchema()); + await Context.Database.ExecuteSqlAsync($"ALTER TABLE {tableName} AUTO_INCREMENT = 1", Options.CommandTimeout, cancellationToken); + } +} diff --git a/N.EntityFramework.Extensions.MySql/Data/BulkOptions.cs b/N.EntityFramework.Extensions.MySql/Data/BulkOptions.cs new file mode 100644 index 0000000..6b9ae99 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/BulkOptions.cs @@ -0,0 +1,18 @@ +using Microsoft.EntityFrameworkCore.Metadata; +using N.EntityFrameworkCore.Extensions.Enums; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkOptions +{ + public int BatchSize { get; set; } + public bool UsePermanentTable { get; set; } + public int? CommandTimeout { get; set; } + internal ConnectionBehavior ConnectionBehavior { get; set; } + internal IEntityType EntityType { get; set; } + + public BulkOptions() + { + ConnectionBehavior = ConnectionBehavior.Default; + } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/BulkQueryResult.cs b/N.EntityFramework.Extensions.MySql/Data/BulkQueryResult.cs new file mode 100644 index 0000000..267af6f --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/BulkQueryResult.cs @@ -0,0 +1,10 @@ +using System.Collections.Generic; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkQueryResult +{ + public IEnumerable Results { get; internal set; } + public IEnumerable Columns { get; internal set; } + public int RowsAffected { get; internal set; } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/BulkSyncOptions.cs b/N.EntityFramework.Extensions.MySql/Data/BulkSyncOptions.cs new file mode 100644 index 0000000..a255a2c --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/BulkSyncOptions.cs @@ -0,0 +1,10 @@ + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkSyncOptions : BulkMergeOptions +{ + public BulkSyncOptions() + { + DeleteIfNotMatched = true; + } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/BulkSyncResult.cs b/N.EntityFramework.Extensions.MySql/Data/BulkSyncResult.cs new file mode 100644 index 0000000..852b5c7 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/BulkSyncResult.cs @@ -0,0 +1,18 @@ + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkSyncResult : BulkMergeResult +{ + public new int RowsDeleted { get; set; } + public static BulkSyncResult Map(BulkMergeResult result) + { + return new BulkSyncResult() + { + Output = result.Output, + RowsAffected = result.RowsAffected, + RowsDeleted = result.RowsDeleted, + RowsInserted = result.RowsInserted, + RowsUpdated = result.RowsUpdated + }; + } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/BulkUpdateOptions.cs b/N.EntityFramework.Extensions.MySql/Data/BulkUpdateOptions.cs new file mode 100644 index 0000000..f71fe82 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/BulkUpdateOptions.cs @@ -0,0 +1,11 @@ +using System; +using System.Linq.Expressions; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkUpdateOptions : BulkOptions +{ + public Expression> InputColumns { get; set; } + public Expression> IgnoreColumns { get; set; } + public Expression> UpdateOnCondition { get; set; } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/DatabaseFacadeExtensions.cs b/N.EntityFramework.Extensions.MySql/Data/DatabaseFacadeExtensions.cs new file mode 100644 index 0000000..602cf86 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/DatabaseFacadeExtensions.cs @@ -0,0 +1,175 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Linq; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Storage; +using N.EntityFrameworkCore.Extensions.Enums; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +public static class DatabaseFacadeExtensions +{ + public static SqlQuery FromSqlQuery(this DatabaseFacade database, string sqlText, params object[] parameters) + { + return new SqlQuery(database, sqlText, parameters); + } + public static int ClearTable(this DatabaseFacade database, string tableName) + { + return database.ExecuteSqlRaw($"DELETE FROM {database.DelimitTableName(tableName)}"); + } + public static int DropTable(this DatabaseFacade database, string tableName, bool ifExists = false, bool isTemporary = false) + { + string formattedTableName = database.DelimitTableName(tableName); + // Use DROP TEMPORARY TABLE for MySQL temporary staging tables to avoid implicit transaction commit + string temporaryKeyword = isTemporary ? "TEMPORARY " : ""; + string sql = ifExists ? $"DROP {temporaryKeyword}TABLE IF EXISTS {formattedTableName}" : $"DROP {temporaryKeyword}TABLE {formattedTableName}"; + return database.ExecuteSqlInternal(sql, null, ConnectionBehavior.Default); + } + public static void TruncateTable(this DatabaseFacade database, string tableName, bool ifExists = false) + { + bool truncateTable = !ifExists || database.TableExists(tableName); + if (!truncateTable) + return; + + string formattedTableName = database.DelimitTableName(tableName); + // MySQL TRUNCATE automatically resets AUTO_INCREMENT; PostgreSQL needs RESTART IDENTITY + string sql = database.IsPostgreSql() + ? $"TRUNCATE TABLE {formattedTableName} RESTART IDENTITY" + : $"TRUNCATE TABLE {formattedTableName}"; + database.ExecuteSqlRaw(sql); + } + public static bool TableExists(this DatabaseFacade database, string tableName) + { + var objectName = database.ParseObjectName(tableName); + if (database.IsMySql()) + { + return Convert.ToBoolean(database.ExecuteScalar( + "SELECT EXISTS (SELECT 1 FROM information_schema.TABLES WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = @name)", + [CreateParameter(database, "@name", objectName.Name)])); + } + return Convert.ToBoolean(database.ExecuteScalar( + database.IsPostgreSql() + ? "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = @schema AND table_name = @name)" + : "SELECT CASE WHEN EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = @schema AND TABLE_NAME = @name) THEN 1 ELSE 0 END", + [CreateParameter(database, "@schema", objectName.Schema), CreateParameter(database, "@name", objectName.Name)])); + } + public static bool TableHasIdentity(this DatabaseFacade database, string tableName) + { + var objectName = database.ParseObjectName(tableName); + if (database.IsMySql()) + { + return Convert.ToBoolean(database.ExecuteScalar( + "SELECT EXISTS (SELECT 1 FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = @name AND EXTRA LIKE '%auto_increment%')", + [CreateParameter(database, "@name", objectName.Name)])); + } + string sql = database.IsPostgreSql() + ? """ + SELECT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = @schema + AND table_name = @name + AND (is_identity = 'YES' OR column_default LIKE 'nextval(%') + ) + """ + : "SELECT ISNULL(OBJECTPROPERTY(OBJECT_ID(@fullName), 'TableHasIdentity'), 0)"; + + object[] parameters = database.IsPostgreSql() + ? [CreateParameter(database, "@schema", objectName.Schema), CreateParameter(database, "@name", objectName.Name)] + : [CreateParameter(database, "@fullName", $"{objectName.Schema}.{objectName.Name}")]; + + return Convert.ToBoolean(database.ExecuteScalar(sql, parameters)); + } + internal static int CloneTable(this DatabaseFacade database, string sourceTable, string destinationTable, IEnumerable columnNames, string internalIdColumnName = null, bool isTemporary = false) + { + return database.CloneTable([sourceTable], destinationTable, columnNames, internalIdColumnName, isTemporary); + } + internal static int CloneTable(this DatabaseFacade database, IEnumerable sourceTables, string destinationTable, IEnumerable columnNames, string internalIdColumnName = null, bool isTemporary = false) + { + string columns = columnNames != null && columnNames.Any() ? string.Join(",", columnNames.Select(database.FormatSelectColumn)) : "*"; + if (!string.IsNullOrEmpty(internalIdColumnName)) + columns = $"{columns},CAST(NULL AS SIGNED) AS {database.DelimitIdentifier(internalIdColumnName)}"; + + // MySQL TEMPORARY tables do not cause implicit transaction commits (unlike regular DDL tables) + string createKeyword = database.IsMySql() && isTemporary ? "CREATE TEMPORARY TABLE" : "CREATE TABLE"; + string sql = database.IsMySql() + ? $"{createKeyword} {destinationTable} AS SELECT {columns} FROM {string.Join(",", sourceTables)} WHERE 1=0" + : database.IsPostgreSql() + ? $"CREATE TABLE {destinationTable} AS SELECT {columns} FROM {string.Join(",", sourceTables)} LIMIT 0" + : $"SELECT TOP 0 {columns} INTO {destinationTable} FROM {string.Join(",", sourceTables)}"; + return database.ExecuteSqlInternal(sql); + } + internal static DbCommand CreateCommand(this DatabaseFacade database, ConnectionBehavior connectionBehavior = ConnectionBehavior.Default) + { + var dbConnection = database.GetDbConnection(connectionBehavior); + if (dbConnection.State != ConnectionState.Open) + dbConnection.Open(); + var command = dbConnection.CreateCommand(); + if (database.CurrentTransaction != null && connectionBehavior == ConnectionBehavior.Default) + command.Transaction = database.CurrentTransaction.GetDbTransaction(); + return command; + } + internal static int ExecuteSqlInternal(this DatabaseFacade database, string sql, int? commandTimeout = null, ConnectionBehavior connectionBehavior = default) + { + return database.ExecuteSql(sql, null, commandTimeout, connectionBehavior); + } + internal static int ExecuteSql(this DatabaseFacade database, string sql, object[] parameters = null, int? commandTimeout = null, ConnectionBehavior connectionBehavior = default) + { + using var command = database.CreateCommand(connectionBehavior); + command.CommandText = sql; + if (commandTimeout != null) + command.CommandTimeout = commandTimeout.Value; + if (parameters != null) + command.Parameters.AddRange(parameters); + return command.ExecuteNonQuery(); + } + internal static object ExecuteScalar(this DatabaseFacade database, string query, object[] parameters = null, int? commandTimeout = null) + { + using var command = database.CreateCommand(); + command.CommandText = query; + if (commandTimeout.HasValue) + command.CommandTimeout = commandTimeout.Value; + if (parameters != null) + command.Parameters.AddRange(parameters); + return command.ExecuteScalar(); + } + internal static void ToggleIdentityInsert(this DatabaseFacade database, string tableName, bool enable) + { + if (database.IsPostgreSql() || database.IsMySql()) + return; + + bool hasIdentity = database.TableHasIdentity(tableName); + if (hasIdentity) + { + string boolString = enable ? "ON" : "OFF"; + database.ExecuteSql($"SET IDENTITY_INSERT {tableName} {boolString}"); + } + } + internal static DbConnection GetDbConnection(this DatabaseFacade database, ConnectionBehavior connectionBehavior) + { + return connectionBehavior == ConnectionBehavior.New ? database.GetDbConnection().CloneConnection() : database.GetDbConnection(); + } + + private static DbParameter CreateParameter(DatabaseFacade database, string name, object value) + { + using var command = database.GetDbConnection().CreateCommand(); + var parameter = command.CreateParameter(); + parameter.ParameterName = name; + parameter.Value = value ?? DBNull.Value; + return parameter; + } + internal static string FormatSelectColumn(this DatabaseFacade database, string columnName) + { + if (columnName.Contains('[') || columnName.Contains('"') || columnName.Contains('`') || columnName.Contains('(') || columnName.Contains(' ')) + return columnName; + + if (columnName.Contains('.')) + return string.Join(".", columnName.Split('.').Select(database.DelimitIdentifier)); + + return database.DelimitIdentifier(columnName); + } +} diff --git a/N.EntityFramework.Extensions.MySql/Data/DatabaseFacadeExtensionsAsync.cs b/N.EntityFramework.Extensions.MySql/Data/DatabaseFacadeExtensionsAsync.cs new file mode 100644 index 0000000..a340222 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/DatabaseFacadeExtensionsAsync.cs @@ -0,0 +1,84 @@ +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +public static class DatabaseFacadeExtensionsAsync +{ + public static async Task ClearTableAsync(this DatabaseFacade database, string tableName, CancellationToken cancellationToken = default) + { + return await database.ExecuteSqlRawAsync($"DELETE FROM {database.DelimitTableName(tableName)}", cancellationToken); + } + public static async Task TruncateTableAsync(this DatabaseFacade database, string tableName, bool ifExists = false, CancellationToken cancellationToken = default) + { + bool truncateTable = !ifExists || database.TableExists(tableName); + if (!truncateTable) + return; + + string formattedTableName = database.DelimitTableName(tableName); + string sql = database.IsPostgreSql() + ? $"TRUNCATE TABLE {formattedTableName} RESTART IDENTITY" + : $"TRUNCATE TABLE {formattedTableName}"; + await database.ExecuteSqlRawAsync(sql, cancellationToken); + } + internal static async Task CloneTableAsync(this DatabaseFacade database, string sourceTable, string destinationTable, IEnumerable columnNames, string internalIdColumnName = null, CancellationToken cancellationToken = default, bool isTemporary = false) + { + return await database.CloneTableAsync([sourceTable], destinationTable, columnNames, internalIdColumnName, cancellationToken, isTemporary); + } + internal static async Task CloneTableAsync(this DatabaseFacade database, IEnumerable sourceTables, string destinationTable, IEnumerable columnNames, string internalIdColumnName = null, CancellationToken cancellationToken = default, bool isTemporary = false) + { + string columns = columnNames != null && columnNames.Any() ? string.Join(",", columnNames.Select(database.FormatSelectColumn)) : "*"; + if (!string.IsNullOrEmpty(internalIdColumnName)) + columns = $"{columns},CAST(NULL AS SIGNED) AS {database.DelimitIdentifier(internalIdColumnName)}"; + + string temporaryKeyword = isTemporary && database.IsMySql() ? "TEMPORARY " : ""; + string sql = database.IsMySql() + ? $"CREATE {temporaryKeyword}TABLE {destinationTable} AS SELECT {columns} FROM {string.Join(",", sourceTables)} WHERE 1=0" + : database.IsPostgreSql() + ? $"CREATE TABLE {destinationTable} AS SELECT {columns} FROM {string.Join(",", sourceTables)} LIMIT 0" + : $"SELECT TOP 0 {columns} INTO {destinationTable} FROM {string.Join(",", sourceTables)}"; + return await database.ExecuteSqlRawAsync(sql, cancellationToken); + } + internal static async Task ExecuteSqlAsync(this DatabaseFacade database, string sql, int? commandTimeout = null, CancellationToken cancellationToken = default) + { + return await database.ExecuteSqlAsync(sql, null, commandTimeout, cancellationToken); + } + internal static async Task ExecuteSqlAsync(this DatabaseFacade database, string sql, object[] parameters = null, int? commandTimeout = null, CancellationToken cancellationToken = default) + { + int value; + int? origCommandTimeout = database.GetCommandTimeout(); + database.SetCommandTimeout(commandTimeout); + value = parameters != null + ? await database.ExecuteSqlRawAsync(sql, parameters, cancellationToken) + : await database.ExecuteSqlRawAsync(sql, cancellationToken); + database.SetCommandTimeout(origCommandTimeout); + return value; + } + internal static async Task ExecuteScalarAsync(this DatabaseFacade database, string query, object[] parameters = null, int? commandTimeout = null, CancellationToken cancellationToken = default) + { + await using var command = database.CreateCommand(); + command.CommandText = query; + if (commandTimeout.HasValue) + command.CommandTimeout = commandTimeout.Value; + if (parameters != null) + command.Parameters.AddRange(parameters); + return await command.ExecuteScalarAsync(cancellationToken); + } + internal static async Task ToggleIdentityInsertAsync(this DatabaseFacade database, string tableName, bool enable) + { + if (database.IsPostgreSql() || database.IsMySql()) + return; + + bool hasIdentity = database.TableHasIdentity(tableName); + if (hasIdentity) + { + string boolString = enable ? "ON" : "OFF"; + await database.ExecuteSqlAsync($"SET IDENTITY_INSERT {tableName} {boolString}", database.GetCommandTimeout()); + } + } +} diff --git a/N.EntityFramework.Extensions.MySql/Data/DbContextExtensions.cs b/N.EntityFramework.Extensions.MySql/Data/DbContextExtensions.cs new file mode 100644 index 0000000..5e64877 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/DbContextExtensions.cs @@ -0,0 +1,814 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.IO; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MySqlConnector; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.Internal; +using N.EntityFrameworkCore.Extensions.Common; +using N.EntityFrameworkCore.Extensions.Enums; +using N.EntityFrameworkCore.Extensions.Extensions; +using N.EntityFrameworkCore.Extensions.Sql; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +public static class DbContextExtensions +{ + private static readonly EfExtensionsCommandInterceptor efExtensionsCommandInterceptor; + static DbContextExtensions() + { + efExtensionsCommandInterceptor = new EfExtensionsCommandInterceptor(); + } + public static void SetupEfCoreExtensions(this DbContextOptionsBuilder builder) + { + builder.AddInterceptors(efExtensionsCommandInterceptor); + } + public static int BulkDelete(this DbContext context, IEnumerable entities) + { + return context.BulkDelete(entities, new BulkDeleteOptions()); + } + public static int BulkDelete(this DbContext context, IEnumerable entities, Action> optionsAction) + { + return context.BulkDelete(entities, optionsAction.Build()); + } + public static int BulkDelete(this DbContext context, IEnumerable entities, BulkDeleteOptions options) + { + var tableMapping = context.GetTableMapping(typeof(T), options.EntityType); + + using (var dbTransactionContext = new DbTransactionContext(context, options)) + { + var dbConnection = dbTransactionContext.Connection; + var transaction = dbTransactionContext.CurrentTransaction; + int rowsAffected = 0; + try + { + string stagingTableName = CommonUtil.GetStagingTableName(tableMapping, options.UsePermanentTable, dbConnection); + string destinationTableName = context.DelimitIdentifier(tableMapping.TableName, tableMapping.Schema); + string[] keyColumnNames = options.DeleteOnCondition != null ? CommonUtil.GetColumns(options.DeleteOnCondition, ["s"]) + : tableMapping.GetPrimaryKeyColumns().ToArray(); + + if (keyColumnNames.Length == 0 && options.DeleteOnCondition == null) + throw new InvalidDataException("BulkDelete requires that the entity have a primary key or the Options.DeleteOnCondition must be set."); + + context.Database.CloneTable(destinationTableName, stagingTableName, keyColumnNames, isTemporary: !options.UsePermanentTable); + BulkInsert(entities, options, tableMapping, dbConnection, transaction, stagingTableName, keyColumnNames, false); + + string joinCondition = CommonUtil.GetJoinConditionSql(context, options.DeleteOnCondition, keyColumnNames); + // MySQL multi-table DELETE syntax + string deleteSql = $"DELETE t FROM {stagingTableName} s JOIN {destinationTableName} t ON {joinCondition}"; + rowsAffected = context.Database.ExecuteSqlInternal(deleteSql, options.CommandTimeout); + + context.Database.DropTable(stagingTableName, isTemporary: !options.UsePermanentTable); + dbTransactionContext.Commit(); + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + return rowsAffected; + } + } + public static IEnumerable BulkFetch(this DbSet dbSet, IEnumerable entities) where T : class, new() + { + return dbSet.BulkFetch(entities, new BulkFetchOptions()); + } + public static IEnumerable BulkFetch(this DbSet dbSet, IEnumerable entities, Action> optionsAction) where T : class, new() + { + return dbSet.BulkFetch(entities, optionsAction.Build()); + } + public static IEnumerable BulkFetch(this DbSet dbSet, IEnumerable entities, BulkFetchOptions options) where T : class, new() + { + var context = dbSet.GetDbContext(); + var tableMapping = context.GetTableMapping(typeof(T)); + + using (var dbTransactionContext = new DbTransactionContext(context, options.CommandTimeout, ConnectionBehavior.New)) + { + string selectSql, stagingTableName = string.Empty; + var dbConnection = dbTransactionContext.Connection; + var transaction = dbTransactionContext.CurrentTransaction; + try + { + stagingTableName = CommonUtil.GetStagingTableName(tableMapping, true, dbConnection); + string destinationTableName = context.DelimitIdentifier(tableMapping.TableName, tableMapping.Schema); + string[] keyColumnNames = options.JoinOnCondition != null ? CommonUtil.GetColumns(options.JoinOnCondition, ["s"]) + : tableMapping.GetPrimaryKeyColumns().ToArray(); + IEnumerable columnNames = CommonUtil.FilterColumns(tableMapping.GetColumns(true), keyColumnNames, options.InputColumns, options.IgnoreColumns); + IEnumerable columnsToFetch = CommonUtil.FormatColumns(context, "t", columnNames); + + if (keyColumnNames.Length == 0 && options.JoinOnCondition == null) + throw new InvalidDataException("BulkFetch requires that the entity have a primary key or the Options.JoinOnCondition must be set."); + + context.Database.CloneTable(destinationTableName, stagingTableName, keyColumnNames); + BulkInsert(entities, options, tableMapping, dbConnection, transaction, stagingTableName, keyColumnNames, false); + selectSql = $"SELECT {SqlUtil.ConvertToColumnString(columnsToFetch)} FROM {stagingTableName} s JOIN {destinationTableName} t ON {CommonUtil.GetJoinConditionSql(context, options.JoinOnCondition, keyColumnNames)}"; + + + dbTransactionContext.Commit(); + } + catch + { + dbTransactionContext.Rollback(); + throw; + } + + foreach (var item in context.FetchInternal(selectSql)) + { + yield return item; + } + context.Database.DropTable(stagingTableName); + } + } + public static void Fetch(this IQueryable queryable, Action> action, Action> optionsAction) where T : class, new() + { + Fetch(queryable, action, optionsAction.Build()); + } + public static void Fetch(this IQueryable queryable, Action> action, FetchOptions options) where T : class, new() + { + var dbContext = queryable.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + HashSet includedColumns = GetIncludedColumns(tableMapping, options.InputColumns, options.IgnoreColumns); + int batch = 1; + int count = 0; + List entities = []; + foreach (var entity in queryable.AsNoTracking().AsEnumerable()) + { + ClearExcludedColumns(dbContext, tableMapping, entity, includedColumns); + entities.Add(entity); + count++; + if (count == options.BatchSize) + { + action(new FetchResult { Results = entities, Batch = batch }); + entities.Clear(); + count = 0; + batch++; + } + } + + if (entities.Count > 0) + action(new FetchResult { Results = entities, Batch = batch }); + } + public static int BulkInsert(this DbContext context, IEnumerable entities) + { + return context.BulkInsert(entities, new BulkInsertOptions()); + } + public static int BulkInsert(this DbContext context, IEnumerable entities, Action> optionsAction) + { + return context.BulkInsert(entities, optionsAction.Build()); + } + public static int BulkInsert(this DbContext context, IEnumerable entities, BulkInsertOptions options) + { + int rowsAffected = 0; + using (var bulkOperation = new BulkOperation(context, options, options.InputColumns, options.IgnoreColumns)) + { + try + { + bool keepIdentity = options.KeepIdentity || bulkOperation.ShouldPreallocateIdentityValues(options.AutoMapOutput, options.KeepIdentity, entities); + if (keepIdentity && !options.KeepIdentity) + bulkOperation.PreallocateIdentityValues(entities); + var bulkInsertResult = bulkOperation.BulkInsertStagingData(entities, true, true); + var bulkMergeResult = bulkOperation.ExecuteMerge(bulkInsertResult.EntityMap, options.InsertOnCondition, + options.AutoMapOutput, keepIdentity, options.InsertIfNotExists); + rowsAffected = bulkMergeResult.RowsAffected; + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return rowsAffected; + } + public static BulkMergeResult BulkMerge(this DbContext context, IEnumerable entities) + { + return BulkMerge(context, entities, new BulkMergeOptions()); + } + public static BulkMergeResult BulkMerge(this DbContext context, IEnumerable entities, BulkMergeOptions options) + { + return InternalBulkMerge(context, entities, options); + } + public static BulkMergeResult BulkMerge(this DbContext context, IEnumerable entities, Action> optionsAction) + { + return BulkMerge(context, entities, optionsAction.Build()); + } + public static int BulkSaveChanges(this DbContext dbContext) + { + return dbContext.BulkSaveChanges(true); + } + public static int BulkSaveChanges(this DbContext dbContext, bool acceptAllChangesOnSuccess = true) + { + int rowsAffected = 0; + var stateManager = dbContext.GetDependencies().StateManager; + + dbContext.ChangeTracker.DetectChanges(); + var entries = stateManager.GetEntriesToSave(true); + + foreach (var saveEntryGroup in entries.GroupBy(o => new { o.EntityType, o.EntityState })) + { + var key = saveEntryGroup.Key; + var entities = saveEntryGroup.AsEnumerable(); + if (key.EntityState == EntityState.Added) + { + rowsAffected += dbContext.BulkInsert(entities, o => { o.EntityType = key.EntityType; }); + } + else if (key.EntityState == EntityState.Modified) + { + rowsAffected += dbContext.BulkUpdate(entities, o => { o.EntityType = key.EntityType; }); + } + else if (key.EntityState == EntityState.Deleted) + { + rowsAffected += dbContext.BulkDelete(entities, o => { o.EntityType = key.EntityType; }); + } + } + + if (acceptAllChangesOnSuccess) + dbContext.ChangeTracker.AcceptAllChanges(); + + return rowsAffected; + } + public static BulkSyncResult BulkSync(this DbContext context, IEnumerable entities) + { + return BulkSync(context, entities, new BulkSyncOptions()); + } + public static BulkSyncResult BulkSync(this DbContext context, IEnumerable entities, Action> optionsAction) + { + return BulkSyncResult.Map(InternalBulkMerge(context, entities, optionsAction.Build())); + } + public static BulkSyncResult BulkSync(this DbContext context, IEnumerable entities, BulkSyncOptions options) + { + return BulkSyncResult.Map(InternalBulkMerge(context, entities, options)); + } + public static int BulkUpdate(this DbContext context, IEnumerable entities) + { + return BulkUpdate(context, entities, new BulkUpdateOptions()); + } + public static int BulkUpdate(this DbContext context, IEnumerable entities, Action> optionsAction) + { + return BulkUpdate(context, entities, optionsAction.Build()); + } + public static int BulkUpdate(this DbContext context, IEnumerable entities, BulkUpdateOptions options) + { + int rowsUpdated = 0; + using (var bulkOperation = new BulkOperation(context, options, options.InputColumns, options.IgnoreColumns)) + { + try + { + bulkOperation.ValidateBulkUpdate(options.UpdateOnCondition); + bulkOperation.BulkInsertStagingData(entities); + rowsUpdated = bulkOperation.ExecuteUpdate(entities, options.UpdateOnCondition); + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return rowsUpdated; + } + public static int DeleteFromQuery(this IQueryable queryable, int? commandTimeout = null) where T : class + { + using (var dbTransactionContext = new DbTransactionContext(queryable.GetDbContext(), commandTimeout)) + { + try + { + int rowsAffected = queryable.ExecuteDelete(); + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static int InsertFromQuery(this IQueryable queryable, string tableName, Expression> insertObjectExpression, int? commandTimeout = null) where T : class + { + using (var dbTransactionContext = new DbTransactionContext(queryable.GetDbContext(), commandTimeout)) + { + var dbContext = dbTransactionContext.DbContext; + try + { + var tableMapping = dbContext.GetTableMapping(typeof(T)); + var columnNames = insertObjectExpression.GetObjectProperties(); + if (!dbContext.Database.TableExists(tableName)) + { + dbContext.Database.CloneTable(tableMapping.FullQualifedTableName, dbContext.Database.DelimitTableName(tableName), tableMapping.GetQualifiedColumnNames(columnNames)); + } + + var entities = queryable.AsNoTracking().ToList(); + int rowsAffected = BulkInsert(entities, new BulkInsertOptions { KeepIdentity = true, AutoMapOutput = false, CommandTimeout = commandTimeout }, tableMapping, + dbTransactionContext.Connection, dbTransactionContext.CurrentTransaction, dbContext.Database.DelimitTableName(tableName), columnNames).RowsAffected; + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static int UpdateFromQuery(this IQueryable queryable, Expression> updateExpression, int? commandTimeout = null) where T : class + { + using (var dbTransactionContext = new DbTransactionContext(queryable.GetDbContext(), commandTimeout)) + { + try + { + int rowsAffected = queryable.ExecuteUpdate(BuildSetPropertyCalls(updateExpression)); + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, string filePath) where T : class + { + return QueryToCsvFile(queryable, filePath, new QueryToFileOptions()); + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, Stream stream) where T : class + { + return QueryToCsvFile(queryable, stream, new QueryToFileOptions()); + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, string filePath, Action optionsAction) where T : class + { + return QueryToCsvFile(queryable, filePath, optionsAction.Build()); + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, Stream stream, Action optionsAction) where T : class + { + return QueryToCsvFile(queryable, stream, optionsAction.Build()); + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, string filePath, QueryToFileOptions options) where T : class + { + using var fileStream = File.Create(filePath); + return QueryToCsvFile(queryable, fileStream, options); + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, Stream stream, QueryToFileOptions options) where T : class + { + return InternalQueryToFile(queryable, stream, options); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, string filePath, string sqlText, params object[] parameters) + { + return SqlQueryToCsvFile(database, filePath, new QueryToFileOptions(), sqlText, parameters); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, Stream stream, string sqlText, params object[] parameters) + { + return SqlQueryToCsvFile(database, stream, new QueryToFileOptions(), sqlText, parameters); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, string filePath, Action optionsAction, string sqlText, params object[] parameters) + { + return SqlQueryToCsvFile(database, filePath, optionsAction.Build(), sqlText, parameters); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, Stream stream, Action optionsAction, string sqlText, params object[] parameters) + { + return SqlQueryToCsvFile(database, stream, optionsAction.Build(), sqlText, parameters); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, string filePath, QueryToFileOptions options, string sqlText, params object[] parameters) + { + using var fileStream = File.Create(filePath); + return SqlQueryToCsvFile(database, fileStream, options, sqlText, parameters); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, Stream stream, QueryToFileOptions options, string sqlText, params object[] parameters) + { + return InternalQueryToFile(database.GetDbConnection(), stream, options, sqlText, parameters); + } + public static void Clear(this DbSet dbSet) where T : class + { + var dbContext = dbSet.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + dbContext.Database.ClearTable(tableMapping.FullQualifedTableName); + } + public static void Truncate(this DbSet dbSet) where T : class + { + var dbContext = dbSet.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + dbContext.Database.TruncateTable(tableMapping.FullQualifedTableName); + } + public static IQueryable UsingTable(this IQueryable queryable, string tableName) where T : class + { + var dbContext = queryable.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + efExtensionsCommandInterceptor.AddCommand(Guid.NewGuid(), + new EfExtensionsCommand + { + CommandType = EfExtensionsCommandType.ChangeTableName, + OldValue = tableMapping.FullQualifedTableName, + NewValue = dbContext.Database.DelimitTableName(tableName), + Connection = dbContext.GetDbConnection() + }); + return queryable; + } + public static TableMapping GetTableMapping(this DbContext dbContext, Type type, IEntityType entityType = null) + { + entityType ??= dbContext.Model.FindEntityType(type); + return new TableMapping(dbContext, entityType); + } + internal static void SetStoreGeneratedValues(this DbContext context, T entity, IEnumerable properties, object[] values) + { + int index = 0; + var updateEntry = entity as InternalEntityEntry; + if (updateEntry == null) + { + var entry = context.Entry(entity); + updateEntry = entry.GetInfrastructure(); + } + + if (updateEntry != null) + { + foreach (var property in properties) + { + if ((updateEntry.EntityState == EntityState.Added && + (property.ValueGenerated == ValueGenerated.OnAdd || property.ValueGenerated == ValueGenerated.OnAddOrUpdate)) || + (updateEntry.EntityState == EntityState.Modified && + (property.ValueGenerated == ValueGenerated.OnUpdate || property.ValueGenerated == ValueGenerated.OnAddOrUpdate)) || + updateEntry.EntityState == EntityState.Detached + ) + { + updateEntry.SetStoreGeneratedValue(property, values[index]); + } + index++; + } + if (updateEntry.EntityState == EntityState.Detached) + updateEntry.AcceptChanges(); + } + else + { + throw new InvalidOperationException("SetStoreValues() failed because an instance of InternalEntityEntry was not found."); + } + } + internal static BulkInsertResult BulkInsert(IEnumerable entities, BulkOptions options, TableMapping tableMapping, DbConnection dbConnection, DbTransaction transaction, string tableName, + IEnumerable inputColumns = null, bool useInternalId = false) + { + using var dataReader = new EntityDataReader(tableMapping, entities, useInternalId); + if (dbConnection is MySqlConnection mySqlConnection) + { + var includeColumns = BuildIncludeColumns(dataReader, inputColumns, useInternalId); + if (includeColumns.Count == 0) + return new BulkInsertResult { RowsAffected = 0, EntityMap = dataReader.EntityMap }; + + string destTable = UnwrapTableName(tableName); + string columnList = string.Join(",", includeColumns.Select(c => $"`{c.name}`")); + const int batchSize = 500; + int totalInserted = 0; + var rowBuffer = new List(batchSize); + + using var cmd = mySqlConnection.CreateCommand(); + cmd.Transaction = transaction as MySqlTransaction; + if (options.CommandTimeout.HasValue) + cmd.CommandTimeout = options.CommandTimeout.Value; + + void FlushBatch() + { + if (rowBuffer.Count == 0) return; + cmd.Parameters.Clear(); + var sb = new System.Text.StringBuilder($"INSERT INTO `{destTable}` ({columnList}) VALUES "); + for (int r = 0; r < rowBuffer.Count; r++) + { + if (r > 0) sb.Append(','); + sb.Append('('); + for (int c = 0; c < includeColumns.Count; c++) + { + if (c > 0) sb.Append(','); + string paramName = $"@p{r}_{c}"; + sb.Append(paramName); + cmd.Parameters.AddWithValue(paramName, rowBuffer[r][c] ?? DBNull.Value); + } + sb.Append(')'); + } + cmd.CommandText = sb.ToString(); + totalInserted += cmd.ExecuteNonQuery(); + rowBuffer.Clear(); + } + + while (dataReader.Read()) + { + var rowData = new object[includeColumns.Count]; + for (int i = 0; i < includeColumns.Count; i++) + rowData[i] = dataReader.GetValue(includeColumns[i].ordinal) ?? DBNull.Value; + rowBuffer.Add(rowData); + if (rowBuffer.Count >= batchSize) + FlushBatch(); + } + FlushBatch(); + + return new BulkInsertResult + { + RowsAffected = totalInserted, + EntityMap = dataReader.EntityMap + }; + } + + throw new NotSupportedException($"The connection type '{dbConnection.GetType().Name}' is not supported for BulkInsert. Use a MySqlConnection."); + } + internal static List<(int ordinal, string name)> BuildIncludeColumns(EntityDataReader dataReader, IEnumerable inputColumns, bool useInternalId) + { + var includeColumns = new List<(int ordinal, string name)>(); + int colIdx = 0; + foreach (var property in dataReader.TableMapping.Properties) + { + var columnName = dataReader.TableMapping.GetColumnName(property); + if (inputColumns == null || inputColumns.Contains(columnName)) + includeColumns.Add((colIdx, columnName)); + colIdx++; + } + if (useInternalId) + includeColumns.Add((colIdx, Constants.InternalId_ColumnName)); + return includeColumns; + } + internal static string UnwrapTableName(string tableName) => tableName.Replace("`", ""); + internal static BulkQueryResult BulkQuery(this DbContext context, string sqlText, BulkOptions options) + { + List results = []; + List columns = []; + using var command = context.Database.CreateCommand(); + command.CommandText = sqlText; + if (options.CommandTimeout.HasValue) + command.CommandTimeout = options.CommandTimeout.Value; + using var reader = command.ExecuteReader(); + while (reader.Read()) + { + if (columns.Count == 0) + { + for (int i = 0; i < reader.FieldCount; i++) + columns.Add(reader.GetName(i)); + } + object[] values = new object[reader.FieldCount]; + reader.GetValues(values); + results.Add(values); + } + + return new BulkQueryResult + { + Columns = columns, + Results = results, + RowsAffected = reader.RecordsAffected + }; + } + internal static DbContext GetDbContext(this IQueryable queryable) where T : class + { + DbContext dbContext; + try + { + if ((queryable as InternalDbSet) != null) + { + dbContext = queryable.GetPrivateFieldValue("_context") as DbContext; + } + else if ((queryable as EntityQueryable) != null) + { + var queryCompiler = queryable.Provider.GetPrivateFieldValue("_queryCompiler"); + var contextFactory = queryCompiler.GetPrivateFieldValue("_queryContextFactory"); + var queryDependencies = contextFactory.GetPrivateFieldValue("Dependencies") as QueryContextDependencies; + dbContext = queryDependencies.CurrentContext.Context as DbContext; + } + else + { + throw new Exception("This extension method could not find the DbContext for this type that implements IQueryable"); + } + } + catch + { + throw new Exception("This extension method could not find the DbContext for this type that implements IQueryable"); + } + return dbContext; + } + internal static DbConnection GetDbConnection(this DbContext context, ConnectionBehavior connectionBehavior = ConnectionBehavior.Default) + { + var dbConnection = context.Database.GetDbConnection(); + return connectionBehavior == ConnectionBehavior.New ? dbConnection.CloneConnection() : dbConnection; + } + private static IEnumerable FetchInternal(this DbContext dbContext, string sqlText, object[] parameters = null) where T : class, new() + { + using var command = dbContext.Database.CreateCommand(ConnectionBehavior.New); + command.CommandText = sqlText; + if (parameters != null) + command.Parameters.AddRange(parameters); + + var tableMapping = dbContext.GetTableMapping(typeof(T), null); + using var reader = command.ExecuteReader(); + var properties = reader.GetProperties(tableMapping); + var valuesFromProvider = properties.Select(p => tableMapping.GetValueFromProvider(p)).ToArray(); + + while (reader.Read()) + { + var entity = reader.MapEntity(dbContext, properties, valuesFromProvider); + yield return entity; + } + } + private static BulkMergeResult InternalBulkMerge(this DbContext context, IEnumerable entities, BulkMergeOptions options) + { + BulkMergeResult bulkMergeResult; + using (var bulkOperation = new BulkOperation(context, options)) + { + try + { + bool shouldPreallocate = bulkOperation.ShouldPreallocateIdentityValues(true, false, entities); + bool keepIdentity = shouldPreallocate || bulkOperation.ShouldKeepIdentityForMerge(); + if (shouldPreallocate) + bulkOperation.PreallocateIdentityValues(entities); + bulkOperation.ValidateBulkMerge(options.MergeOnCondition); + var bulkInsertResult = bulkOperation.BulkInsertStagingData(entities, true, true); + bulkMergeResult = bulkOperation.ExecuteMerge(bulkInsertResult.EntityMap, options.MergeOnCondition, options.AutoMapOutput, + keepIdentity, true, true, options.DeleteIfNotMatched, shouldPreallocate); + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return bulkMergeResult; + } + private static void ClearEntityStateToUnchanged(DbContext dbContext, IEnumerable entities) + { + foreach (var entity in entities) + { + var entry = dbContext.Entry(entity); + if (entry.State == EntityState.Added || entry.State == EntityState.Modified) + dbContext.Entry(entity).State = EntityState.Unchanged; + } + } + private static void Validate(TableMapping tableMapping) + { + if (!tableMapping.GetPrimaryKeyColumns().Any()) + { + throw new Exception("You must have a primary key on this table to use this function."); + } + } + private static QueryToFileResult InternalQueryToFile(this IQueryable queryable, Stream stream, QueryToFileOptions options) where T : class + { + return InternalQueryToFile(queryable.AsNoTracking().AsEnumerable(), stream, options); + } + private static QueryToFileResult InternalQueryToFile(DbConnection dbConnection, Stream stream, QueryToFileOptions options, string sqlText, object[] parameters = null) + { + int dataRowCount = 0; + int totalRowCount = 0; + long bytesWritten = 0; + + if (dbConnection.State == ConnectionState.Closed) + dbConnection.Open(); + + using var command = dbConnection.CreateCommand(); + command.CommandText = sqlText; + if (parameters != null) + command.Parameters.AddRange(parameters); + if (options.CommandTimeout.HasValue) + command.CommandTimeout = options.CommandTimeout.Value; + + using var streamWriter = new StreamWriter(stream, leaveOpen: true); + using (var reader = command.ExecuteReader()) + { + if (options.IncludeHeaderRow) + { + for (int i = 0; i < reader.FieldCount; i++) + { + streamWriter.Write(options.TextQualifer); + streamWriter.Write(reader.GetName(i)); + streamWriter.Write(options.TextQualifer); + if (i != reader.FieldCount - 1) + { + streamWriter.Write(options.ColumnDelimiter); + } + } + totalRowCount++; + streamWriter.Write(options.RowDelimiter); + } + while (reader.Read()) + { + object[] values = new object[reader.FieldCount]; + reader.GetValues(values); + for (int i = 0; i < values.Length; i++) + { + streamWriter.Write(options.TextQualifer); + streamWriter.Write(values[i]); + streamWriter.Write(options.TextQualifer); + if (i != values.Length - 1) + { + streamWriter.Write(options.ColumnDelimiter); + } + } + streamWriter.Write(options.RowDelimiter); + dataRowCount++; + totalRowCount++; + } + streamWriter.Flush(); + bytesWritten = streamWriter.BaseStream.Length; + } + return new QueryToFileResult() + { + BytesWritten = bytesWritten, + DataRowCount = dataRowCount, + TotalRowCount = totalRowCount + }; + } + private static QueryToFileResult InternalQueryToFile(IEnumerable entities, Stream stream, QueryToFileOptions options) + { + int dataRowCount = 0; + int totalRowCount = 0; + long bytesWritten = 0; + var properties = typeof(T).GetProperties().Where(p => p.CanRead && !typeof(System.Collections.IEnumerable).IsAssignableFrom(p.PropertyType) || p.PropertyType == typeof(string)).ToArray(); + + using var streamWriter = new StreamWriter(stream, leaveOpen: true); + if (options.IncludeHeaderRow) + { + WriteCsvRow(streamWriter, properties.Select(p => p.Name), options); + totalRowCount++; + } + + foreach (var entity in entities) + { + WriteCsvRow(streamWriter, properties.Select(p => p.GetValue(entity)), options); + dataRowCount++; + totalRowCount++; + } + + streamWriter.Flush(); + bytesWritten = streamWriter.BaseStream.Length; + return new QueryToFileResult { BytesWritten = bytesWritten, DataRowCount = dataRowCount, TotalRowCount = totalRowCount }; + } + private static HashSet GetIncludedColumns(TableMapping tableMapping, Expression> inputColumns, Expression> ignoreColumns) + { + var includedColumns = inputColumns != null + ? inputColumns.GetObjectProperties().ToHashSet() + : tableMapping.Properties.Select(p => p.Name).ToHashSet(); + + if (ignoreColumns != null) + includedColumns.ExceptWith(ignoreColumns.GetObjectProperties()); + + return includedColumns; + } + private static void ClearExcludedColumns(DbContext dbContext, TableMapping tableMapping, T entity, HashSet includedColumns) where T : class + { + var entry = dbContext.Entry(entity); + foreach (var property in tableMapping.Properties) + { + if (includedColumns.Contains(property.Name)) + continue; + + object defaultValue = property.ClrType.IsValueType ? Activator.CreateInstance(property.ClrType) : null; + if (property.DeclaringType is IComplexType complexType) + { + var complexProperty = entry.ComplexProperty(complexType.ComplexProperty); + if (complexProperty.CurrentValue != null) + complexProperty.Property(property).CurrentValue = defaultValue; + } + else + { + entry.Property(property.Name).CurrentValue = defaultValue; + } + } + } + private static void WriteCsvRow(TextWriter writer, IEnumerable values, QueryToFileOptions options) + { + bool first = true; + foreach (var value in values) + { + if (!first) + writer.Write(options.ColumnDelimiter); + + writer.Write(options.TextQualifer); + writer.Write(value); + writer.Write(options.TextQualifer); + first = false; + } + writer.Write(options.RowDelimiter); + } + private static Expression, SetPropertyCalls>> BuildSetPropertyCalls(Expression> updateExpression) where T : class + { + if (updateExpression.Body is not MemberInitExpression memberInitExpression) + throw new InvalidOperationException("UpdateFromQuery requires a member initialization expression."); + + var entityParameter = updateExpression.Parameters[0]; + var callsParam = Expression.Parameter(typeof(SetPropertyCalls), "calls"); + var setPropertyMethod = typeof(SetPropertyCalls) + .GetMethods() + .Single(m => m.Name == nameof(SetPropertyCalls.SetProperty) && m.GetParameters().Length == 2 && m.GetParameters()[1].ParameterType.IsGenericType); + + Expression current = callsParam; + foreach (var binding in memberInitExpression.Bindings.OfType()) + { + var propertyInfo = binding.Member as PropertyInfo ?? throw new InvalidOperationException("Only property bindings are supported."); + var propertyLambda = Expression.Lambda(Expression.Property(entityParameter, propertyInfo), entityParameter); + var valueLambda = Expression.Lambda(binding.Expression, entityParameter); + current = Expression.Call(current, setPropertyMethod.MakeGenericMethod(propertyInfo.PropertyType), propertyLambda, valueLambda); + } + + return Expression.Lambda, SetPropertyCalls>>(current, callsParam); + } +} diff --git a/N.EntityFramework.Extensions.MySql/Data/DbContextExtensionsAsync.cs b/N.EntityFramework.Extensions.MySql/Data/DbContextExtensionsAsync.cs new file mode 100644 index 0000000..037ea69 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/DbContextExtensionsAsync.cs @@ -0,0 +1,709 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.IO; +using System.Linq; +using System.Linq.Expressions; +using System.Threading; +using System.Threading.Tasks; +using MySqlConnector; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query; +using N.EntityFrameworkCore.Extensions.Common; +using N.EntityFrameworkCore.Extensions.Enums; +using N.EntityFrameworkCore.Extensions.Extensions; +using N.EntityFrameworkCore.Extensions.Sql; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +public static class DbContextExtensionsAsync +{ + public static async Task BulkDeleteAsync(this DbContext context, IEnumerable entities, CancellationToken cancellationToken = default) + { + return await context.BulkDeleteAsync(entities, new BulkDeleteOptions(), cancellationToken); + } + public static async Task BulkDeleteAsync(this DbContext context, IEnumerable entities, Action> optionsAction, CancellationToken cancellationToken = default) + { + return await context.BulkDeleteAsync(entities, optionsAction.Build(), cancellationToken); + } + public static async Task BulkDeleteAsync(this DbContext context, IEnumerable entities, BulkDeleteOptions options, CancellationToken cancellationToken = default) + { + int rowsAffected = 0; + var tableMapping = context.GetTableMapping(typeof(T), options.EntityType); + + using (var dbTransactionContext = new DbTransactionContext(context, options)) + { + var dbConnection = dbTransactionContext.Connection; + var transaction = dbTransactionContext.CurrentTransaction; + try + { + string stagingTableName = CommonUtil.GetStagingTableName(tableMapping, options.UsePermanentTable, dbConnection); + string destinationTableName = context.DelimitIdentifier(tableMapping.TableName, tableMapping.Schema); + string[] keyColumnNames = options.DeleteOnCondition != null ? CommonUtil.GetColumns(options.DeleteOnCondition, ["s"]) + : tableMapping.GetPrimaryKeyColumns().ToArray(); + + if (keyColumnNames.Length == 0 && options.DeleteOnCondition == null) + throw new InvalidDataException("BulkDelete requires that the entity have a primary key or the Options.DeleteOnCondition must be set."); + + await context.Database.CloneTableAsync(destinationTableName, stagingTableName, keyColumnNames, null, cancellationToken, isTemporary: !options.UsePermanentTable); + await BulkInsertAsync(entities, options, tableMapping, dbConnection, transaction, stagingTableName, keyColumnNames, + false, cancellationToken); + string joinCondition = CommonUtil.GetJoinConditionSql(context, options.DeleteOnCondition, keyColumnNames); + // MySQL multi-table DELETE syntax + string deleteSql = $"DELETE t FROM {stagingTableName} s JOIN {destinationTableName} t ON {joinCondition}"; + rowsAffected = await context.Database.ExecuteSqlRawAsync(deleteSql, cancellationToken); + context.Database.DropTable(stagingTableName, isTemporary: !options.UsePermanentTable); + dbTransactionContext.Commit(); + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + return rowsAffected; + } + } + public static async Task> BulkFetchAsync(this DbSet dbSet, IEnumerable entities, CancellationToken cancellationToken = default) where T : class, new() + { + return await dbSet.BulkFetchAsync(entities, new BulkFetchOptions(), cancellationToken); + } + public static async Task> BulkFetchAsync(this DbSet dbSet, IEnumerable entities, Action> optionsAction, CancellationToken cancellationToken = default) where T : class, new() + { + return await dbSet.BulkFetchAsync(entities, optionsAction.Build(), cancellationToken); + } + public static async Task> BulkFetchAsync(this DbSet dbSet, IEnumerable entities, BulkFetchOptions options, CancellationToken cancellationToken = default) where T : class, new() + { + var context = dbSet.GetDbContext(); + var tableMapping = context.GetTableMapping(typeof(T)); + + using (var dbTransactionContext = new DbTransactionContext(context, options.CommandTimeout, ConnectionBehavior.New)) + { + string selectSql; + var dbConnection = dbTransactionContext.Connection; + var transaction = dbTransactionContext.CurrentTransaction; + string stagingTableName = string.Empty; + try + { + stagingTableName = CommonUtil.GetStagingTableName(tableMapping, true, dbConnection); + string destinationTableName = context.DelimitIdentifier(tableMapping.TableName, tableMapping.Schema); + string[] keyColumnNames = options.JoinOnCondition != null ? CommonUtil.GetColumns(options.JoinOnCondition, ["s"]) + : tableMapping.GetPrimaryKeyColumns().ToArray(); + IEnumerable columnNames = CommonUtil.FilterColumns(tableMapping.GetColumns(true), keyColumnNames, options.InputColumns, options.IgnoreColumns); + IEnumerable columnsToFetch = CommonUtil.FormatColumns(context, "t", columnNames); + + if (keyColumnNames.Length == 0 && options.JoinOnCondition == null) + throw new InvalidDataException("BulkFetch requires that the entity have a primary key or the Options.JoinOnCondition must be set."); + + await context.Database.CloneTableAsync(destinationTableName, stagingTableName, keyColumnNames, null, cancellationToken); + await BulkInsertAsync(entities, options, tableMapping, dbConnection, transaction, stagingTableName, keyColumnNames, false, cancellationToken); + selectSql = $"SELECT {SqlUtil.ConvertToColumnString(columnsToFetch)} FROM {stagingTableName} s JOIN {destinationTableName} t ON {CommonUtil.GetJoinConditionSql(context, options.JoinOnCondition, keyColumnNames)}"; + + dbTransactionContext.Commit(); + } + catch + { + dbTransactionContext.Rollback(); + throw; + } + + var results = await context.FetchInternalAsync(selectSql, cancellationToken: cancellationToken); + context.Database.DropTable(stagingTableName); + return results; + } + } + public static async Task FetchAsync(this IQueryable queryable, Func, Task> action, Action> optionsAction, CancellationToken cancellationToken = default) where T : class, new() + { + await FetchAsync(queryable, action, optionsAction.Build(), cancellationToken); + } + public static async Task FetchAsync(this IQueryable queryable, Func, Task> action, FetchOptions options, CancellationToken cancellationToken = default) where T : class, new() + { + var dbContext = queryable.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + HashSet includedColumns = GetIncludedColumns(tableMapping, options.InputColumns, options.IgnoreColumns); + int batch = 1; + int count = 0; + List entities = []; + await foreach (var entity in queryable.AsNoTracking().AsAsyncEnumerable().WithCancellation(cancellationToken)) + { + ClearExcludedColumns(dbContext, tableMapping, entity, includedColumns); + entities.Add(entity); + count++; + if (count == options.BatchSize) + { + await action(new FetchResult { Results = entities, Batch = batch }); + entities.Clear(); + count = 0; + batch++; + } + cancellationToken.ThrowIfCancellationRequested(); + } + + if (entities.Count > 0) + await action(new FetchResult { Results = entities, Batch = batch }); + } + public static async Task BulkInsertAsync(this DbContext context, IEnumerable entities, CancellationToken cancellationToken = default) + { + return await context.BulkInsertAsync(entities, new BulkInsertOptions(), cancellationToken); + } + public static async Task BulkInsertAsync(this DbContext context, IEnumerable entities, Action> optionsAction, CancellationToken cancellationToken = default) + { + return await context.BulkInsertAsync(entities, optionsAction.Build(), cancellationToken); + } + public static async Task BulkInsertAsync(this DbContext context, IEnumerable entities, BulkInsertOptions options, CancellationToken cancellationToken = default) + { + int rowsAffected = 0; + using (var bulkOperation = new BulkOperation(context, options, options.InputColumns, options.IgnoreColumns)) + { + try + { + bool keepIdentity = options.KeepIdentity || bulkOperation.ShouldPreallocateIdentityValues(options.AutoMapOutput, options.KeepIdentity, entities); + if (keepIdentity && !options.KeepIdentity) + await bulkOperation.PreallocateIdentityValuesAsync(entities, cancellationToken); + var bulkInsertResult = await bulkOperation.BulkInsertStagingDataAsync(entities, true, true); + var bulkMergeResult = await bulkOperation.ExecuteMergeAsync(bulkInsertResult.EntityMap, options.InsertOnCondition, + options.AutoMapOutput, keepIdentity, options.InsertIfNotExists); + rowsAffected = bulkMergeResult.RowsAffected; + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return rowsAffected; + } + public static async Task> BulkMergeAsync(this DbContext context, IEnumerable entities, CancellationToken cancellationToken = default) + { + return await BulkMergeAsync(context, entities, new BulkMergeOptions(), cancellationToken); + } + public static async Task> BulkMergeAsync(this DbContext context, IEnumerable entities, BulkMergeOptions options, CancellationToken cancellationToken = default) + { + return await InternalBulkMergeAsync(context, entities, options, cancellationToken); + } + public static async Task> BulkMergeAsync(this DbContext context, IEnumerable entities, Action> optionsAction, CancellationToken cancellationToken = default) + { + return await BulkMergeAsync(context, entities, optionsAction.Build(), cancellationToken); + } + public static async Task BulkSaveChangesAsync(this DbContext dbContext) + { + return await dbContext.BulkSaveChangesAsync(true); + } + public static async Task BulkSaveChangesAsync(this DbContext dbContext, bool acceptAllChangesOnSuccess = true) + { + int rowsAffected = 0; + var stateManager = dbContext.GetDependencies().StateManager; + + dbContext.ChangeTracker.DetectChanges(); + var entries = stateManager.GetEntriesToSave(true); + + foreach (var saveEntryGroup in entries.GroupBy(o => new { o.EntityType, o.EntityState })) + { + var key = saveEntryGroup.Key; + var entities = saveEntryGroup.AsEnumerable(); + if (key.EntityState == EntityState.Added) + { + rowsAffected += await dbContext.BulkInsertAsync(entities, o => { o.EntityType = key.EntityType; }); + } + else if (key.EntityState == EntityState.Modified) + { + rowsAffected += await dbContext.BulkUpdateAsync(entities, o => { o.EntityType = key.EntityType; }); + } + else if (key.EntityState == EntityState.Deleted) + { + rowsAffected += await dbContext.BulkDeleteAsync(entities, o => { o.EntityType = key.EntityType; }); + } + } + + if (acceptAllChangesOnSuccess) + dbContext.ChangeTracker.AcceptAllChanges(); + + return rowsAffected; + } + public static async Task> BulkSyncAsync(this DbContext context, IEnumerable entities, CancellationToken cancellationToken = default) + { + return await BulkSyncAsync(context, entities, new BulkSyncOptions(), cancellationToken); + } + public static async Task> BulkSyncAsync(this DbContext context, IEnumerable entities, Action> optionsAction, CancellationToken cancellationToken = default) + { + return BulkSyncResult.Map(await InternalBulkMergeAsync(context, entities, optionsAction.Build(), cancellationToken)); + } + public static async Task> BulkSyncAsync(this DbContext context, IEnumerable entities, BulkSyncOptions options, CancellationToken cancellationToken = default) + { + return BulkSyncResult.Map(await InternalBulkMergeAsync(context, entities, options, cancellationToken)); + } + public static async Task BulkUpdateAsync(this DbContext context, IEnumerable entities, CancellationToken cancellationToken = default) + { + return await BulkUpdateAsync(context, entities, new BulkUpdateOptions(), cancellationToken); + } + public static async Task BulkUpdateAsync(this DbContext context, IEnumerable entities, Action> optionsAction, CancellationToken cancellationToken = default) + { + return await BulkUpdateAsync(context, entities, optionsAction.Build(), cancellationToken); + } + public static async Task BulkUpdateAsync(this DbContext context, IEnumerable entities, BulkUpdateOptions options, CancellationToken cancellationToken = default) + { + int rowsUpdated = 0; + using (var bulkOperation = new BulkOperation(context, options, options.InputColumns, options.IgnoreColumns)) + { + try + { + bulkOperation.ValidateBulkUpdate(options.UpdateOnCondition); + await bulkOperation.BulkInsertStagingDataAsync(entities, cancellationToken: cancellationToken); + rowsUpdated = await bulkOperation.ExecuteUpdateAsync(entities, options.UpdateOnCondition, cancellationToken); + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return rowsUpdated; + } + public static async Task DeleteFromQueryAsync(this IQueryable queryable, int? commandTimeout = null, CancellationToken cancellationToken = default) where T : class + { + var dbContext = queryable.GetDbContext(); + using (var dbTransactionContext = new DbTransactionContext(dbContext, commandTimeout)) + { + try + { + int rowsAffected = await queryable.ExecuteDeleteAsync(cancellationToken); + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static async Task InsertFromQueryAsync(this IQueryable queryable, string tableName, Expression> insertObjectExpression, int? commandTimeout = null, + CancellationToken cancellationToken = default) where T : class + { + var dbContext = queryable.GetDbContext(); + using (var dbTransactionContext = new DbTransactionContext(dbContext, commandTimeout)) + { + try + { + var tableMapping = dbContext.GetTableMapping(typeof(T)); + var columnNames = insertObjectExpression.GetObjectProperties(); + if (!dbContext.Database.TableExists(tableName)) + { + await dbContext.Database.CloneTableAsync(tableMapping.FullQualifedTableName, dbContext.Database.DelimitTableName(tableName), tableMapping.GetQualifiedColumnNames(columnNames), cancellationToken: cancellationToken); + } + + var entities = await queryable.AsNoTracking().ToListAsync(cancellationToken); + int rowsAffected = (int)(await BulkInsertAsync(entities, new BulkInsertOptions { KeepIdentity = true, AutoMapOutput = false, CommandTimeout = commandTimeout }, tableMapping, + dbTransactionContext.Connection, dbTransactionContext.CurrentTransaction, dbContext.Database.DelimitTableName(tableName), columnNames, cancellationToken: cancellationToken)).RowsAffected; + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static async Task UpdateFromQueryAsync(this IQueryable queryable, Expression> updateExpression, int? commandTimeout = null, + CancellationToken cancellationToken = default) where T : class + { + var dbContext = queryable.GetDbContext(); + using (var dbTransactionContext = new DbTransactionContext(dbContext, commandTimeout)) + { + try + { + int rowsAffected = await queryable.ExecuteUpdateAsync(BuildSetPropertyCalls(updateExpression), cancellationToken); + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, string filePath, CancellationToken cancellationToken = default) where T : class + { + return await QueryToCsvFileAsync(queryable, filePath, new QueryToFileOptions(), cancellationToken); + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, Stream stream, CancellationToken cancellationToken = default) where T : class + { + return await QueryToCsvFileAsync(queryable, stream, new QueryToFileOptions(), cancellationToken); + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, string filePath, Action optionsAction, + CancellationToken cancellationToken = default) where T : class + { + return await QueryToCsvFileAsync(queryable, filePath, optionsAction.Build(), cancellationToken); + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, Stream stream, Action optionsAction, + CancellationToken cancellationToken = default) where T : class + { + return await QueryToCsvFileAsync(queryable, stream, optionsAction.Build(), cancellationToken); + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, string filePath, QueryToFileOptions options, + CancellationToken cancellationToken = default) where T : class + { + await using var fileStream = File.Create(filePath); + return await QueryToCsvFileAsync(queryable, fileStream, options, cancellationToken); + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, Stream stream, QueryToFileOptions options, + CancellationToken cancellationToken = default) where T : class + { + return await InternalQueryToFileAsync(queryable, stream, options, cancellationToken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, string filePath, string sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + return await SqlQueryToCsvFileAsync(database, filePath, new QueryToFileOptions(), sqlText, parameters, cancellationToken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, Stream stream, string sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + return await SqlQueryToCsvFileAsync(database, stream, new QueryToFileOptions(), sqlText, parameters, cancellationToken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, string filePath, Action optionsAction, string sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + return await SqlQueryToCsvFileAsync(database, filePath, optionsAction.Build(), sqlText, parameters, cancellationToken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, Stream stream, Action optionsAction, string sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + return await SqlQueryToCsvFileAsync(database, stream, optionsAction.Build(), sqlText, parameters, cancellationToken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, string filePath, QueryToFileOptions options, string sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + await using var fileStream = File.Create(filePath); + return await SqlQueryToCsvFileAsync(database, fileStream, options, sqlText, parameters, cancellationToken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, Stream stream, QueryToFileOptions options, string sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + return await InternalQueryToFileAsync(database.GetDbConnection(), stream, options, sqlText, parameters, cancellationToken); + } + public static async Task ClearAsync(this DbSet dbSet, CancellationToken cancellationToken = default) where T : class + { + var dbContext = dbSet.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + await dbContext.Database.ClearTableAsync(tableMapping.FullQualifedTableName, cancellationToken); + } + public static async Task TruncateAsync(this DbSet dbSet, CancellationToken cancellationToken = default) where T : class + { + var dbContext = dbSet.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + await dbContext.Database.TruncateTableAsync(tableMapping.FullQualifedTableName, false, cancellationToken); + } + internal static async Task> BulkInsertAsync(IEnumerable entities, BulkOptions options, TableMapping tableMapping, DbConnection dbConnection, DbTransaction transaction, string tableName, + IEnumerable inputColumns = null, bool useInternalId = false, CancellationToken cancellationToken = default) + { + using var dataReader = new EntityDataReader(tableMapping, entities, useInternalId); + if (dbConnection is MySqlConnection mySqlConnection) + { + var includeColumns = DbContextExtensions.BuildIncludeColumns(dataReader, inputColumns, useInternalId); + if (includeColumns.Count == 0) + return new BulkInsertResult { RowsAffected = 0, EntityMap = dataReader.EntityMap }; + + string destTable = DbContextExtensions.UnwrapTableName(tableName); + string columnList = string.Join(",", includeColumns.Select(c => $"`{c.name}`")); + const int batchSize = 500; + int totalInserted = 0; + var rowBuffer = new List(batchSize); + + await using var cmd = mySqlConnection.CreateCommand(); + cmd.Transaction = transaction as MySqlTransaction; + if (options.CommandTimeout.HasValue) + cmd.CommandTimeout = options.CommandTimeout.Value; + + async Task FlushBatchAsync() + { + if (rowBuffer.Count == 0) return; + cmd.Parameters.Clear(); + var sb = new System.Text.StringBuilder($"INSERT INTO `{destTable}` ({columnList}) VALUES "); + for (int r = 0; r < rowBuffer.Count; r++) + { + if (r > 0) sb.Append(','); + sb.Append('('); + for (int c = 0; c < includeColumns.Count; c++) + { + if (c > 0) sb.Append(','); + string paramName = $"@p{r}_{c}"; + sb.Append(paramName); + cmd.Parameters.AddWithValue(paramName, rowBuffer[r][c] ?? DBNull.Value); + } + sb.Append(')'); + } + cmd.CommandText = sb.ToString(); + totalInserted += await cmd.ExecuteNonQueryAsync(cancellationToken); + rowBuffer.Clear(); + } + + while (dataReader.Read()) + { + var rowData = new object[includeColumns.Count]; + for (int i = 0; i < includeColumns.Count; i++) + rowData[i] = dataReader.GetValue(includeColumns[i].ordinal) ?? DBNull.Value; + rowBuffer.Add(rowData); + if (rowBuffer.Count >= batchSize) + await FlushBatchAsync(); + } + await FlushBatchAsync(); + + return new BulkInsertResult + { + RowsAffected = totalInserted, + EntityMap = dataReader.EntityMap + }; + } + + throw new NotSupportedException($"The connection type '{dbConnection.GetType().Name}' is not supported for BulkInsertAsync. Use a MySqlConnection."); + } + internal static async Task BulkQueryAsync(this DbContext context, string sqlText, DbConnection dbConnection, DbTransaction transaction, BulkOptions options, CancellationToken cancellationToken = default) + { + List results = []; + List columns = []; + await using var command = dbConnection.CreateCommand(); + command.CommandText = sqlText; + command.Transaction = transaction; + if (options.CommandTimeout.HasValue) + command.CommandTimeout = options.CommandTimeout.Value; + await using var reader = await command.ExecuteReaderAsync(cancellationToken); + while (await reader.ReadAsync(cancellationToken)) + { + if (columns.Count == 0) + { + for (int i = 0; i < reader.FieldCount; i++) + columns.Add(reader.GetName(i)); + } + object[] values = new object[reader.FieldCount]; + reader.GetValues(values); + results.Add(values); + } + + return new BulkQueryResult + { + Columns = columns, + Results = results, + RowsAffected = reader.RecordsAffected + }; + } + private static async Task> InternalBulkMergeAsync(this DbContext context, IEnumerable entities, BulkMergeOptions options, CancellationToken cancellationToken = default) + { + BulkMergeResult bulkMergeResult; + using (var bulkOperation = new BulkOperation(context, options)) + { + try + { + bool shouldPreallocate = bulkOperation.ShouldPreallocateIdentityValues(true, false, entities); + bool keepIdentity = shouldPreallocate || bulkOperation.ShouldKeepIdentityForMerge(); + if (shouldPreallocate) + await bulkOperation.PreallocateIdentityValuesAsync(entities, cancellationToken); + bulkOperation.ValidateBulkMerge(options.MergeOnCondition); + var bulkInsertResult = await bulkOperation.BulkInsertStagingDataAsync(entities, true, true, cancellationToken); + bulkMergeResult = await bulkOperation.ExecuteMergeAsync(bulkInsertResult.EntityMap, options.MergeOnCondition, options.AutoMapOutput, + keepIdentity, true, true, options.DeleteIfNotMatched, shouldPreallocate, cancellationToken); + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return bulkMergeResult; + } + private static async Task InternalQueryToFileAsync(this IQueryable queryable, Stream stream, QueryToFileOptions options, + CancellationToken cancellationToken = default) where T : class + { + return await InternalQueryToFileAsync(queryable.AsNoTracking().AsAsyncEnumerable(), stream, options, cancellationToken); + } + private static async Task InternalQueryToFileAsync(DbConnection dbConnection, Stream stream, QueryToFileOptions options, string sqlText, object[] parameters = null, + CancellationToken cancellationToken = default) + { + int dataRowCount = 0; + int totalRowCount = 0; + long bytesWritten = 0; + + if (dbConnection.State == ConnectionState.Closed) + dbConnection.Open(); + + await using var command = dbConnection.CreateCommand(); + command.CommandText = sqlText; + if (parameters != null) + command.Parameters.AddRange(parameters); + if (options.CommandTimeout.HasValue) + command.CommandTimeout = options.CommandTimeout.Value; + + await using var streamWriter = new StreamWriter(stream, leaveOpen: true); + using (var reader = await command.ExecuteReaderAsync(cancellationToken)) + { + if (options.IncludeHeaderRow) + { + for (int i = 0; i < reader.FieldCount; i++) + { + streamWriter.Write(options.TextQualifer); + streamWriter.Write(reader.GetName(i)); + streamWriter.Write(options.TextQualifer); + if (i != reader.FieldCount - 1) + { + await streamWriter.WriteAsync(options.ColumnDelimiter); + } + } + totalRowCount++; + await streamWriter.WriteAsync(options.RowDelimiter); + } + while (await reader.ReadAsync(cancellationToken)) + { + object[] values = new object[reader.FieldCount]; + reader.GetValues(values); + for (int i = 0; i < values.Length; i++) + { + streamWriter.Write(options.TextQualifer); + streamWriter.Write(values[i]); + streamWriter.Write(options.TextQualifer); + if (i != values.Length - 1) + { + await streamWriter.WriteAsync(options.ColumnDelimiter); + } + } + await streamWriter.WriteAsync(options.RowDelimiter); + dataRowCount++; + totalRowCount++; + } + await streamWriter.FlushAsync(); + bytesWritten = streamWriter.BaseStream.Length; + } + return new QueryToFileResult() + { + BytesWritten = bytesWritten, + DataRowCount = dataRowCount, + TotalRowCount = totalRowCount + }; + } + private static async Task InternalQueryToFileAsync(IAsyncEnumerable entities, Stream stream, QueryToFileOptions options, CancellationToken cancellationToken) where T : class + { + int dataRowCount = 0; + int totalRowCount = 0; + long bytesWritten = 0; + var properties = typeof(T).GetProperties().Where(p => p.CanRead && (!typeof(System.Collections.IEnumerable).IsAssignableFrom(p.PropertyType) || p.PropertyType == typeof(string))).ToArray(); + + await using var streamWriter = new StreamWriter(stream, leaveOpen: true); + if (options.IncludeHeaderRow) + { + await WriteCsvRowAsync(streamWriter, properties.Select(p => (object)p.Name), options, cancellationToken); + totalRowCount++; + } + + await foreach (var entity in entities.WithCancellation(cancellationToken)) + { + await WriteCsvRowAsync(streamWriter, properties.Select(p => p.GetValue(entity)), options, cancellationToken); + dataRowCount++; + totalRowCount++; + } + + await streamWriter.FlushAsync(cancellationToken); + bytesWritten = streamWriter.BaseStream.Length; + return new QueryToFileResult { BytesWritten = bytesWritten, DataRowCount = dataRowCount, TotalRowCount = totalRowCount }; + } + private static async Task> FetchInternalAsync(this DbContext dbContext, string sqlText, object[] parameters = null, CancellationToken cancellationToken = default) where T : class, new() + { + List results = []; + await using var command = dbContext.Database.CreateCommand(ConnectionBehavior.New); + command.CommandText = sqlText; + if (parameters != null) + command.Parameters.AddRange(parameters); + + var tableMapping = dbContext.GetTableMapping(typeof(T), null); + var reader = await command.ExecuteReaderAsync(cancellationToken); + var properties = reader.GetProperties(tableMapping); + var valuesFromProvider = properties.Select(p => tableMapping.GetValueFromProvider(p)).ToArray(); + + while (await reader.ReadAsync(cancellationToken)) + { + var entity = reader.MapEntity(dbContext, properties, valuesFromProvider); + results.Add(entity); + } + + await reader.CloseAsync(); + await command.Connection.CloseAsync(); + return results; + } + private static HashSet GetIncludedColumns(TableMapping tableMapping, Expression> inputColumns, Expression> ignoreColumns) + { + var includedColumns = inputColumns != null + ? inputColumns.GetObjectProperties().ToHashSet() + : tableMapping.Properties.Select(p => p.Name).ToHashSet(); + + if (ignoreColumns != null) + includedColumns.ExceptWith(ignoreColumns.GetObjectProperties()); + + return includedColumns; + } + private static void ClearExcludedColumns(DbContext dbContext, TableMapping tableMapping, T entity, HashSet includedColumns) where T : class + { + var entry = dbContext.Entry(entity); + foreach (var property in tableMapping.Properties) + { + if (includedColumns.Contains(property.Name)) + continue; + + object defaultValue = property.ClrType.IsValueType ? Activator.CreateInstance(property.ClrType) : null; + if (property.DeclaringType is IComplexType complexType) + { + var complexProperty = entry.ComplexProperty(complexType.ComplexProperty); + if (complexProperty.CurrentValue != null) + complexProperty.Property(property).CurrentValue = defaultValue; + } + else + { + entry.Property(property.Name).CurrentValue = defaultValue; + } + } + } + private static async Task WriteCsvRowAsync(TextWriter writer, IEnumerable values, QueryToFileOptions options, CancellationToken cancellationToken) + { + bool first = true; + foreach (var value in values) + { + if (!first) + await writer.WriteAsync(options.ColumnDelimiter); + + await writer.WriteAsync(options.TextQualifer); + await writer.WriteAsync(value?.ToString()); + await writer.WriteAsync(options.TextQualifer); + first = false; + cancellationToken.ThrowIfCancellationRequested(); + } + await writer.WriteAsync(options.RowDelimiter); + } + private static Expression, SetPropertyCalls>> BuildSetPropertyCalls(Expression> updateExpression) where T : class + { + if (updateExpression.Body is not MemberInitExpression memberInitExpression) + throw new InvalidOperationException("UpdateFromQuery requires a member initialization expression."); + + var entityParameter = updateExpression.Parameters[0]; + var callsParam = Expression.Parameter(typeof(SetPropertyCalls), "calls"); + var setPropertyMethod = typeof(SetPropertyCalls) + .GetMethods() + .Single(m => m.Name == nameof(SetPropertyCalls.SetProperty) && m.GetParameters().Length == 2 && m.GetParameters()[1].ParameterType.IsGenericType); + + Expression current = callsParam; + foreach (var binding in memberInitExpression.Bindings.OfType()) + { + var propertyInfo = binding.Member as System.Reflection.PropertyInfo ?? throw new InvalidOperationException("Only property bindings are supported."); + var propertyLambda = Expression.Lambda(Expression.Property(entityParameter, propertyInfo), entityParameter); + var valueLambda = Expression.Lambda(binding.Expression, entityParameter); + current = Expression.Call(current, setPropertyMethod.MakeGenericMethod(propertyInfo.PropertyType), propertyLambda, valueLambda); + } + + return Expression.Lambda, SetPropertyCalls>>(current, callsParam); + } +} diff --git a/N.EntityFramework.Extensions.MySql/Data/DbTransactionContext.cs b/N.EntityFramework.Extensions.MySql/Data/DbTransactionContext.cs new file mode 100644 index 0000000..3830a13 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/DbTransactionContext.cs @@ -0,0 +1,71 @@ +using System; +using System.Data.Common; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Storage; +using N.EntityFrameworkCore.Extensions.Enums; +using N.EntityFrameworkCore.Extensions.Util; + + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed class DbTransactionContext : IDisposable +{ + private bool closeConnection; + private bool ownsTransaction; + private int? defaultCommandTimeout; + private DbContext context; + private IDbContextTransaction transaction; + + public DbConnection Connection { get; internal set; } + public DbTransaction CurrentTransaction { get; private set; } + public DbContext DbContext => context; + internal bool OwnsTransaction => ownsTransaction; + + public DbTransactionContext(DbContext context, BulkOptions bulkOptions, bool openConnection = true) : this(context, bulkOptions.CommandTimeout, bulkOptions.ConnectionBehavior, openConnection) + { + + } + public DbTransactionContext(DbContext context, int? commandTimeout = null, ConnectionBehavior connectionBehavior = ConnectionBehavior.Default, bool openConnection = true) + { + this.context = context; + Connection = context.GetDbConnection(connectionBehavior); + if (openConnection) + { + if (Connection.State == System.Data.ConnectionState.Closed) + { + Connection.Open(); + closeConnection = true; + } + } + if (connectionBehavior == ConnectionBehavior.Default) + { + ownsTransaction = context.Database.CurrentTransaction == null; + transaction = context.Database.CurrentTransaction; + defaultCommandTimeout = context.Database.GetCommandTimeout(); + if (transaction != null) + CurrentTransaction = transaction.GetDbTransaction(); + } + + context.Database.SetCommandTimeout(commandTimeout); + } + + public void Dispose() + { + context.Database.SetCommandTimeout(defaultCommandTimeout); + if (closeConnection) + { + Connection.Close(); + } + } + + internal void Commit() + { + if (ownsTransaction && transaction != null) + transaction.Commit(); + } + internal void Rollback() + { + if (transaction != null) + transaction.Rollback(); + } +} diff --git a/N.EntityFramework.Extensions.MySql/Data/EfExtensionsCommand.cs b/N.EntityFramework.Extensions.MySql/Data/EfExtensionsCommand.cs new file mode 100644 index 0000000..f94c8c7 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/EfExtensionsCommand.cs @@ -0,0 +1,22 @@ +using System.Data.Common; +using Microsoft.EntityFrameworkCore.Diagnostics; + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed class EfExtensionsCommand +{ + public EfExtensionsCommandType CommandType { get; set; } + public string OldValue { get; set; } + public string NewValue { get; set; } + public DbConnection Connection { get; internal set; } + + internal bool Execute(DbCommand command, CommandEventData eventData, InterceptionResult result) + { + if (CommandType == EfExtensionsCommandType.ChangeTableName) + { + command.CommandText = command.CommandText.Replace(OldValue, NewValue); + } + + return true; + } +} diff --git a/N.EntityFramework.Extensions.MySql/Data/EfExtensionsCommandInterceptor.cs b/N.EntityFramework.Extensions.MySql/Data/EfExtensionsCommandInterceptor.cs new file mode 100644 index 0000000..da8b8f2 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/EfExtensionsCommandInterceptor.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Concurrent; +using System.Data.Common; +using Microsoft.EntityFrameworkCore.Diagnostics; + +namespace N.EntityFrameworkCore.Extensions; + +public class EfExtensionsCommandInterceptor : DbCommandInterceptor +{ + private ConcurrentDictionary extensionCommands = new(); + public override InterceptionResult ReaderExecuting(DbCommand command, CommandEventData eventData, InterceptionResult result) + { + foreach (var extensionCommand in extensionCommands) + { + if (extensionCommand.Value.Connection == command.Connection) + { + extensionCommand.Value.Execute(command, eventData, result); + extensionCommands.TryRemove(extensionCommand.Key, out _); + } + } + return result; + } + internal void AddCommand(Guid clientConnectionId, EfExtensionsCommand efExtensionsCommand) + { + extensionCommands.TryAdd(clientConnectionId, efExtensionsCommand); + } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/EntityDataReader.cs b/N.EntityFramework.Extensions.MySql/Data/EntityDataReader.cs new file mode 100644 index 0000000..a2a6238 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/EntityDataReader.cs @@ -0,0 +1,248 @@ +using System; +using System.Collections.Generic; +using System.Data; +using Microsoft.EntityFrameworkCore.ChangeTracking; +using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; +using Microsoft.EntityFrameworkCore.Metadata; +using N.EntityFrameworkCore.Extensions.Common; + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed class EntityDataReader : IDataReader +{ + public TableMapping TableMapping { get; set; } + public Dictionary EntityMap { get; set; } + private Dictionary columnIndexes; + private int currentId; + private bool useInternalId; + private int tableFieldCount; + private IEnumerable entities; + private IEnumerator enumerator; + private Dictionary> selectors; + + public EntityDataReader(TableMapping tableMapping, IEnumerable entities, bool useInternalId) + { + this.columnIndexes = []; + this.currentId = 0; + this.useInternalId = useInternalId; + this.tableFieldCount = tableMapping.Properties.Length; + this.entities = entities; + this.enumerator = entities.GetEnumerator(); + this.selectors = []; + this.EntityMap = []; + this.FieldCount = tableMapping.Properties.Length; + this.TableMapping = tableMapping; + + + int i = 0; + foreach (var property in tableMapping.Properties) + { + selectors[i] = GetValueSelector(property); + columnIndexes[tableMapping.GetColumnName(property)] = i; + i++; + } + + if (useInternalId) + { + this.FieldCount++; + columnIndexes[Constants.InternalId_ColumnName] = i; + } + } + private Func GetValueSelector(IProperty property) + { + Func selector; + var valueGeneratorFactory = property.GetValueGeneratorFactory(); + if (valueGeneratorFactory != null) + { + var valueGenerator = valueGeneratorFactory.Invoke(property, this.TableMapping.EntityType); + selector = entry => valueGenerator.Next(entry); + } + else + { + var valueConverter = property.GetTypeMapping().Converter; + if (valueConverter != null) + { + selector = entry => valueConverter.ConvertToProvider(entry.CurrentValues[property]); + } + else + { + if (property.DeclaringType is IComplexType complexType) + { + selector = entry => entry.ComplexProperty(complexType.ComplexProperty).Property(property).CurrentValue; + } + else + { + selector = entry => entry.CurrentValues[property]; + } + } + } + return selector; + } + public object this[int i] => throw new NotImplementedException(); + + public object this[string name] => throw new NotImplementedException(); + + public int Depth { get; set; } + + public bool IsClosed => throw new NotImplementedException(); + + public int RecordsAffected => throw new NotImplementedException(); + + public int FieldCount { get; set; } + + public void Close() + { + throw new NotImplementedException(); + } + + public void Dispose() + { + selectors = null; + enumerator.Dispose(); + } + + public bool GetBoolean(int i) + { + throw new NotImplementedException(); + } + + public byte GetByte(int i) + { + throw new NotImplementedException(); + } + + public long GetBytes(int i, long fieldOffset, byte[] buffer, int bufferoffset, int length) + { + throw new NotImplementedException(); + } + + public char GetChar(int i) + { + throw new NotImplementedException(); + } + + public long GetChars(int i, long fieldoffset, char[] buffer, int bufferoffset, int length) + { + throw new NotImplementedException(); + } + + public IDataReader GetData(int i) + { + throw new NotImplementedException(); + } + + public string GetDataTypeName(int i) + { + throw new NotImplementedException(); + } + + public DateTime GetDateTime(int i) + { + throw new NotImplementedException(); + } + + public decimal GetDecimal(int i) + { + throw new NotImplementedException(); + } + + public double GetDouble(int i) + { + throw new NotImplementedException(); + } + + public Type GetFieldType(int i) + { + throw new NotImplementedException(); + } + + public float GetFloat(int i) + { + throw new NotImplementedException(); + } + + public Guid GetGuid(int i) + { + throw new NotImplementedException(); + } + + public short GetInt16(int i) + { + throw new NotImplementedException(); + } + + public int GetInt32(int i) + { + throw new NotImplementedException(); + } + + public long GetInt64(int i) + { + throw new NotImplementedException(); + } + + public string GetName(int i) + { + throw new NotImplementedException(); + } + + public int GetOrdinal(string name) + { + return columnIndexes[name]; + } + + public DataTable GetSchemaTable() + { + throw new NotImplementedException(); + } + + public string GetString(int i) + { + throw new NotImplementedException(); + } + + public object GetValue(int i) + { + if (i == tableFieldCount) + { + return this.currentId; + } + else + { + return selectors[i](FindEntry(enumerator.Current)); + } + + } + + private EntityEntry FindEntry(object entity) + { + return entity is InternalEntityEntry internalEntry ? internalEntry.ToEntityEntry() : TableMapping.DbContext.Entry(entity); + } + + public int GetValues(object[] values) + { + throw new NotImplementedException(); + } + + public bool IsDBNull(int i) + { + throw new NotImplementedException(); + } + + public bool NextResult() + { + throw new NotImplementedException(); + } + + public bool Read() + { + bool moveNext = enumerator.MoveNext(); + + if (moveNext && this.useInternalId) + { + this.currentId++; + this.EntityMap.Add(this.currentId, enumerator.Current); + } + return moveNext; + } +} diff --git a/N.EntityFramework.Extensions.MySql/Data/FetchOptions.cs b/N.EntityFramework.Extensions.MySql/Data/FetchOptions.cs new file mode 100644 index 0000000..1d00787 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/FetchOptions.cs @@ -0,0 +1,11 @@ +using System; +using System.Linq.Expressions; + +namespace N.EntityFrameworkCore.Extensions; + +public class FetchOptions +{ + public Expression> IgnoreColumns { get; set; } + public Expression> InputColumns { get; set; } + public int BatchSize { get; set; } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/FetchResult.cs b/N.EntityFramework.Extensions.MySql/Data/FetchResult.cs new file mode 100644 index 0000000..1b579ec --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/FetchResult.cs @@ -0,0 +1,9 @@ +using System.Collections.Generic; + +namespace N.EntityFrameworkCore.Extensions; + +public class FetchResult +{ + public List Results { get; set; } + public int Batch { get; set; } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/QueryToFileOptions.cs b/N.EntityFramework.Extensions.MySql/Data/QueryToFileOptions.cs new file mode 100644 index 0000000..ce9c3c4 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/QueryToFileOptions.cs @@ -0,0 +1,18 @@ +namespace N.EntityFrameworkCore.Extensions; + +public class QueryToFileOptions +{ + public string ColumnDelimiter { get; set; } + public int? CommandTimeout { get; set; } + public bool IncludeHeaderRow { get; set; } + public string RowDelimiter { get; set; } + public string TextQualifer { get; set; } + + public QueryToFileOptions() + { + ColumnDelimiter = ","; + IncludeHeaderRow = true; + RowDelimiter = "\r\n"; + TextQualifer = ""; + } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/QueryToFileResult.cs b/N.EntityFramework.Extensions.MySql/Data/QueryToFileResult.cs new file mode 100644 index 0000000..174615e --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/QueryToFileResult.cs @@ -0,0 +1,8 @@ +namespace N.EntityFrameworkCore.Extensions; + +public class QueryToFileResult +{ + public long BytesWritten { get; set; } + public int DataRowCount { get; internal set; } + public int TotalRowCount { get; internal set; } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/SqlMergeAction.cs b/N.EntityFramework.Extensions.MySql/Data/SqlMergeAction.cs new file mode 100644 index 0000000..64ac316 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/SqlMergeAction.cs @@ -0,0 +1,8 @@ +namespace N.EntityFrameworkCore.Extensions; + +internal static class SqlMergeAction +{ + public const string Insert = "INSERT"; + public const string Update = "UPDATE"; + public const string Delete = "DELETE"; +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Data/SqlQuery.cs b/N.EntityFramework.Extensions.MySql/Data/SqlQuery.cs new file mode 100644 index 0000000..51b434c --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/SqlQuery.cs @@ -0,0 +1,36 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore.Infrastructure; +using N.EntityFrameworkCore.Extensions.Sql; + +namespace N.EntityFrameworkCore.Extensions; + +public class SqlQuery +{ + private DatabaseFacade database; + public string SqlText { get; private set; } + public object[] Parameters { get; private set; } + + public SqlQuery(DatabaseFacade database, string sqlText, params object[] parameters) + { + this.database = database; + SqlText = sqlText; + Parameters = parameters; + } + + public int Count() + { + string countSqlText = SqlBuilder.Parse(SqlText).Count(); + return Convert.ToInt32(database.ExecuteScalar(countSqlText, Parameters)); + } + public async Task CountAsync(CancellationToken cancellationToken = default) + { + string countSqlText = SqlBuilder.Parse(SqlText).Count(); + return Convert.ToInt32(await database.ExecuteScalarAsync(countSqlText, Parameters, null, cancellationToken)); + } + public int ExecuteNonQuery() + { + return database.ExecuteSql(SqlText, Parameters); + } +} diff --git a/N.EntityFramework.Extensions.MySql/Data/TableMapping.cs b/N.EntityFramework.Extensions.MySql/Data/TableMapping.cs new file mode 100644 index 0000000..1597671 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Data/TableMapping.cs @@ -0,0 +1,135 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Metadata.Internal; +using N.EntityFrameworkCore.Extensions.Extensions; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +public class TableMapping +{ + public DbContext DbContext { get; private set; } + public IEntityType EntityType { get; set; } + public IProperty[] Properties { get; } + public string Schema { get; } + public string TableName { get; } + public IEnumerable EntityTypes { get; } + + public bool HasIdentityColumn => EntityType.FindPrimaryKey().Properties.Any(o => o.ValueGenerated != ValueGenerated.Never); + public StoreObjectIdentifier StoreObjectIdentifier => StoreObjectIdentifier.Table(TableName, EntityType.GetSchema() ?? DbContext.Database.GetDefaultSchema()); + private Dictionary ColumnMap { get; set; } + public string FullQualifedTableName => DbContext.DelimitIdentifier(TableName, Schema); + + public TableMapping(DbContext dbContext, IEntityType entityType) + { + DbContext = dbContext; + EntityType = entityType; + Properties = GetProperties(entityType); + ColumnMap = Properties.Select(p => new KeyValuePair(GetColumnName(p), p)).ToDictionary(); + Schema = entityType.GetSchema() ?? dbContext.Database.GetDefaultSchema(); + TableName = entityType.GetTableName(); + EntityTypes = EntityType.GetAllBaseTypesInclusive().Where(o => !o.IsAbstract()); + } + public IProperty GetPropertyFromColumnName(string columnName) => ColumnMap[columnName]; + private static IProperty[] GetProperties(IEntityType entityType) + { + var properties = entityType.GetProperties().ToList(); + properties.AddRange(entityType.GetComplexProperties().SelectMany(p => p.ComplexType.GetProperties())); + return properties.ToArray(); + } + + public IEnumerable GetQualifiedColumnNames(IEnumerable columnNames, IEntityType entityType = null) + { + return Properties.Where(o => entityType == null || o.GetDeclaringEntityType() == entityType) + .Select(o => new + { + Column = FindColumn(o), + Name = GetColumnName(o) + }) + .Where(o => columnNames == null || columnNames.Contains(o.Name)) + .Select(o => $"{DbContext.DelimitIdentifier(o.Column?.Table.Name ?? TableName)}.{DbContext.DelimitIdentifier(o.Name)}").ToList(); + } + public string GetColumnName(IProperty property) => FindColumn(property)?.Name ?? property.Name; + private IColumnBase FindColumn(IProperty property) + { + var entityType = property.GetDeclaringEntityType(); + if (entityType == null || entityType.IsAbstract()) + entityType = EntityType; + var storeObjectIdentifier = StoreObjectIdentifier.Table(entityType.GetTableName(), entityType.GetSchema()); + return property.FindColumn(storeObjectIdentifier); + } + + private string FindTableName(IEntityType declaringEntityType, IEntityType entityType) => + declaringEntityType != null && declaringEntityType.IsAbstract() ? declaringEntityType.GetTableName() : entityType.GetTableName(); + public IEnumerable GetColumnNames(IEntityType entityType, bool primaryKeyColumns) + { + List columns; + if (entityType != null) + { + columns = entityType.GetProperties().Where(o => (o.GetDeclaringEntityType() == entityType || o.GetDeclaringEntityType().IsAbstract() + || o.IsForeignKeyToSelf()) && o.ValueGenerated == ValueGenerated.Never) + .Select(GetColumnName).ToList(); + + columns.AddRange(entityType.GetComplexProperties().SelectMany(o => o.ComplexType.GetProperties() + .Select(GetColumnName))); + } + else + { + columns = EntityType.GetProperties().Where(o => o.ValueGenerated == ValueGenerated.Never) + .Select(GetColumnName).ToList(); + + columns.AddRange(EntityType.GetComplexProperties().SelectMany(o => o.ComplexType.GetProperties() + .Select(GetColumnName))); + } + if (primaryKeyColumns) + { + columns.AddRange(GetPrimaryKeyColumns()); + } + return columns.Distinct(); + } + public IEnumerable GetColumns(bool includePrimaryKeyColumns = false) + { + List columns = []; + foreach (var entityType in EntityTypes) + { + var storeObjectIdentifier = StoreObjectIdentifier.Create(entityType, StoreObjectType.Table).GetValueOrDefault(); + columns.AddRange(entityType.GetProperties().Where(o => o.ValueGenerated == ValueGenerated.Never) + .Select(GetColumnName)); + + columns.AddRange(EntityType.GetComplexProperties().SelectMany(o => o.ComplexType.GetProperties() + .Select(GetColumnName))); + + if (includePrimaryKeyColumns) + columns.AddRange(GetPrimaryKeyColumns()); + } + return columns.Where(o => o != null).Distinct(); + } + public IEnumerable GetPrimaryKeyColumns() => + EntityType.FindPrimaryKey().Properties.Select(GetColumnName); + + internal IEnumerable GetAutoGeneratedColumns(IEntityType entityType = null) + { + entityType ??= EntityType; + return entityType.GetProperties().Where(o => o.ValueGenerated != ValueGenerated.Never) + .Select(GetColumnName); + } + + internal IEnumerable GetEntityProperties(IEntityType entityType = null, ValueGenerated? valueGenerated = null) + { + entityType ??= EntityType; + return entityType.GetProperties().Where(o => valueGenerated == null || o.ValueGenerated == valueGenerated).AsEnumerable(); + } + internal Func GetValueFromProvider(IProperty property) + { + var valueConverter = property.GetTypeMapping().Converter; + return valueConverter != null ? value => valueConverter.ConvertFromProvider(value) : value => value; + } + internal IEnumerable GetSchemaQualifiedTableNames() + { + return EntityTypes + .Select(o => DbContext.DelimitIdentifier(o.GetTableName(), o.GetSchema() ?? DbContext.Database.GetDefaultSchema())).Distinct(); + } +} diff --git a/N.EntityFramework.Extensions.MySql/Enums/ConnectionBehavior.cs b/N.EntityFramework.Extensions.MySql/Enums/ConnectionBehavior.cs new file mode 100644 index 0000000..3e12ad4 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Enums/ConnectionBehavior.cs @@ -0,0 +1,7 @@ +namespace N.EntityFrameworkCore.Extensions.Enums; + +internal enum ConnectionBehavior +{ + Default, + New +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Enums/EfExtensionsCommandType.cs b/N.EntityFramework.Extensions.MySql/Enums/EfExtensionsCommandType.cs new file mode 100644 index 0000000..7dc1746 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Enums/EfExtensionsCommandType.cs @@ -0,0 +1,6 @@ +namespace N.EntityFrameworkCore.Extensions; + +internal enum EfExtensionsCommandType +{ + ChangeTableName +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Extensions/CommonExtensions.cs b/N.EntityFramework.Extensions.MySql/Extensions/CommonExtensions.cs new file mode 100644 index 0000000..8241e59 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Extensions/CommonExtensions.cs @@ -0,0 +1,13 @@ +using System; + +namespace N.EntityFrameworkCore.Extensions.Extensions; + +internal static class CommonExtensions +{ + internal static T Build(this Action buildAction) where T : new() + { + var parameter = new T(); + buildAction(parameter); + return parameter; + } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Extensions/DbDataReaderExtensions.cs b/N.EntityFramework.Extensions.MySql/Extensions/DbDataReaderExtensions.cs new file mode 100644 index 0000000..c5bffe5 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Extensions/DbDataReaderExtensions.cs @@ -0,0 +1,51 @@ +using System; +using System.Collections.Generic; +using System.Data.Common; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Metadata; + +namespace N.EntityFrameworkCore.Extensions.Extensions; + +internal static class DbDataReaderExtensions +{ + internal static T MapEntity(this DbDataReader reader, DbContext dbContext, IProperty[] properties, Func[] valuesFromProvider) where T : class, new() + { + var entity = new T(); + var entry = dbContext.Entry(entity); + + for (var i = 0; i < reader.FieldCount; i++) + { + var property = properties[i]; + var value = valuesFromProvider[i].Invoke(reader.GetValue(i)); + if (value == DBNull.Value) + value = null; + + if (property.DeclaringType is IComplexType complexType) + { + var complexProperty = entry.ComplexProperty(complexType.ComplexProperty); + if (complexProperty.CurrentValue == null) + { + complexProperty.CurrentValue = Activator.CreateInstance(complexType.ClrType); + } + complexProperty.Property(property).CurrentValue = value; + } + else + { + entry.Property(property).CurrentValue = value; + } + } + return entity; + } + internal static IProperty[] GetProperties(this DbDataReader reader, TableMapping tableMapping) + { + List properties = []; + + for (var i = 0; i < reader.FieldCount; i++) + { + var property = tableMapping.GetPropertyFromColumnName(reader.GetName(i)); + properties.Add(property); + } + + return properties.ToArray(); + } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Extensions/IPropertyExtensions.cs b/N.EntityFramework.Extensions.MySql/Extensions/IPropertyExtensions.cs new file mode 100644 index 0000000..6c95145 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Extensions/IPropertyExtensions.cs @@ -0,0 +1,11 @@ +using Microsoft.EntityFrameworkCore.Metadata; + +namespace N.EntityFrameworkCore.Extensions.Extensions; + +public static class IPropertyExtensions +{ + public static IEntityType GetDeclaringEntityType(this IProperty property) + { + return property.DeclaringType as IEntityType; + } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Extensions/LinqExtensions.cs b/N.EntityFramework.Extensions.MySql/Extensions/LinqExtensions.cs new file mode 100644 index 0000000..d305556 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Extensions/LinqExtensions.cs @@ -0,0 +1,246 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Text; +using System.Text.RegularExpressions; +using Microsoft.EntityFrameworkCore; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +internal static class LinqExtensions +{ + internal static List GetObjectProperties(this Expression> expression) + { + if (expression == null) + { + return []; + } + else if (expression.Body is MemberExpression propertyExpression) + { + return [propertyExpression.Member.Name]; + } + else if (expression.Body is NewExpression newExpression) + { + return newExpression.Members.Select(o => o.Name).ToList(); + } + else if ((expression.Body is UnaryExpression unaryExpression) && (unaryExpression.Operand.GetPrivateFieldValue("Member") is PropertyInfo propertyInfo)) + { + return [propertyInfo.Name]; + } + else + { + throw new InvalidOperationException("GetObjectProperties() encountered an unsupported expression type"); + } + } + internal static string ToSql(this ExpressionType expressionType) => expressionType switch + { + ExpressionType.AndAlso => "AND", + ExpressionType.Or => "OR", + ExpressionType.Add => "+", + ExpressionType.Subtract => "-", + ExpressionType.Multiply => "*", + ExpressionType.Divide => "/", + ExpressionType.Modulo => "%", + ExpressionType.Equal => "=", + _ => string.Empty + }; + + internal static string ToSql(this MemberBinding binding) + { + if (binding is MemberAssignment memberAssingment) + { + return GetExpressionValueAsString(memberAssingment.Expression); + } + else + { + throw new NotSupportedException(); + } + } + internal static string ToSql(this Expression expression) + { + var sb = new StringBuilder(); + if (expression is BinaryExpression binaryExpression) + { + sb.Append(binaryExpression.Left.ToSql()); + sb.Append($" {expression.NodeType.ToSql()} "); + sb.Append(binaryExpression.Right.ToSql()); + } + else if (expression is MemberExpression memberExpression) + { + return $"{memberExpression}"; + } + else if (expression is UnaryExpression unaryExpression) + { + return $"{unaryExpression.Operand}"; + } + return sb.ToString(); + } + internal static string GetExpressionValueAsString(Expression expression) + { + if (expression is ConstantExpression constantExpression) + { + return ConvertToSqlValue(constantExpression.Value); + } + else if (expression is MemberExpression memberExpression) + { + if (memberExpression.Expression is ParameterExpression parameterExpression) + { + return memberExpression.ToString(); + } + else + { + return ConvertToSqlValue(Expression.Lambda(expression).Compile().DynamicInvoke()); + } + } + else if (expression.NodeType == ExpressionType.Convert) + { + return ConvertToSqlValue(Expression.Lambda(expression).Compile().DynamicInvoke()); + } + else if (expression.NodeType == ExpressionType.Call) + { + var methodCallExpression = expression as MethodCallExpression; + List argValues = []; + foreach (var argument in methodCallExpression.Arguments) + { + argValues.Add(GetExpressionValueAsString(argument)); + } + return methodCallExpression.Method.Name switch + { + "ToString" => $"CONVERT(VARCHAR,{argValues[0]})", + _ => $"{methodCallExpression.Method.Name}({string.Join(",", argValues)})" + }; + } + else + { + var binaryExpression = expression as BinaryExpression; + string leftValue = GetExpressionValueAsString(binaryExpression.Left); + string rightValue = GetExpressionValueAsString(binaryExpression.Right); + string joinValue = expression.NodeType.ToSql(); + + return $"({leftValue} {joinValue} {rightValue})"; + } + } + internal static string ToSqlPredicate2(this Expression expression, params string[] parameters) + { + var sql = ToSqlString(expression.Body); + + for (var i = 0; i < parameters.Length; i++) + sql = sql.Replace($"${expression.Parameters[i].Name!}.", $"{parameters[i]}."); + + return sql; + } + internal static string ToSqlPredicate(this Expression expression, params string[] parameters) + { + var expressionBody = (string)expression.Body.GetPrivateFieldValue("DebugView"); + expressionBody = expressionBody.Replace(System.Environment.NewLine, " "); + var stringBuilder = new StringBuilder(expressionBody); + + int i = 0; + foreach (var expressionParam in expression.Parameters) + { + if (parameters.Length <= i) break; + stringBuilder.Replace((string)expressionParam.GetPrivateFieldValue("DebugView"), parameters[i]); + i++; + } + stringBuilder.Replace("== null", "IS NULL"); + stringBuilder.Replace("!= null", "IS NOT NULL"); + stringBuilder.Replace("&&", "AND"); + stringBuilder.Replace("==", "="); + stringBuilder.Replace("||", "OR"); + stringBuilder.Replace("(System.Nullable`1[System.Int32])", ""); + stringBuilder.Replace("(System.Int32)", ""); + return stringBuilder.ToString(); + } + internal static string ToSqlPredicate(this Expression expression, DbContext dbContext, params string[] parameters) + { + string predicate = expression.ToSqlPredicate(parameters); + return DelimitMemberAccess(dbContext, predicate); + } + internal static string ToSqlUpdateSetExpression(this Expression expression, string tableName) + { + List setValues = []; + var memberInitExpression = expression.Body as MemberInitExpression; + foreach (var binding in memberInitExpression.Bindings) + { + string expValue = binding.ToSql(); + expValue = expValue.Replace($"{expression.Parameters.First().Name}.", ""); + setValues.Add($"[{binding.Member.Name}]={expValue}"); + } + return string.Join(",", setValues); + } + internal static string ToSqlUpdateSetExpression(this Expression expression, DbContext dbContext, string tableName) + { + List setValues = []; + var memberInitExpression = expression.Body as MemberInitExpression; + foreach (var binding in memberInitExpression.Bindings) + { + string expValue = binding.ToSql(); + expValue = expValue.Replace($"{expression.Parameters.First().Name}.", ""); + expValue = DelimitMemberAccess(dbContext, expValue); + setValues.Add($"{dbContext.DelimitIdentifier(binding.Member.Name)}={expValue}"); + } + return string.Join(",", setValues); + } + private static string ToSqlString(Expression expression, string sql = null) + { + sql ??= ""; + if (expression is not BinaryExpression b) + return sql; + + var sb = new StringBuilder(); + if (b.Left is MemberExpression mel) + sb.Append($"${mel} = "); + if (b.Right is MemberExpression mer) + sb.Append($"${mer}"); + + if (b.Left is UnaryExpression ubl) + sb.Append($"${ubl.Operand} = "); + if (b.Right is UnaryExpression ubr) + sb.Append($"${ubr.Operand}"); + + if (sb.Length > 0) + return sb.ToString(); + + var left = ToSqlString(b.Left, sql); + if (string.IsNullOrWhiteSpace(left)) + return sql; + + var right = ToSqlString(b.Right, sql); + return $"{left} AND {right}"; + } + private static string ConvertToSqlValue(object value) + { + if (value == null) + return "NULL"; + if (value is string str) + return $"'{str.Replace("'", "''")}'"; + if (value is Guid guid) + return $"'{guid}'"; + if (value is bool b) + return b ? "1" : "0"; + if (value is DateTime dt) + return $"'{dt:yyyy-MM-ddTHH:mm:ss.fffffff}'"; // Convert to ISO-8601 + if (value is DateTimeOffset dto) + return $"'{dto:yyyy-MM-ddTHH:mm:ss.fffffffzzzz}'"; // Convert to ISO-8601 + var valueType = value.GetType(); + if (valueType.IsEnum) + return Convert.ToString((int)value); + if (!valueType.IsClass) + return Convert.ToString(value, CultureInfo.InvariantCulture); + + throw new NotImplementedException("Unhandled data type."); + } + private static string DelimitMemberAccess(DbContext dbContext, string expression) + { + return Regex.Replace(expression, @"(? + { + string alias = match.Groups[1].Value; + string member = match.Groups[2].Value; + return dbContext.DelimitMemberAccess(alias, member); + }); + } +} diff --git a/N.EntityFramework.Extensions.MySql/Extensions/ObjectExtensions.cs b/N.EntityFramework.Extensions.MySql/Extensions/ObjectExtensions.cs new file mode 100644 index 0000000..3821322 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Extensions/ObjectExtensions.cs @@ -0,0 +1,32 @@ +using System; +using System.Reflection; + +namespace N.EntityFrameworkCore.Extensions; + +internal static class ObjectExtensions +{ + internal static object GetPrivateFieldValue(this object obj, string propName) + { + if (obj == null) throw new ArgumentNullException(nameof(obj)); + Type t = obj.GetType(); + FieldInfo fieldInfo = null; + PropertyInfo propertyInfo = null; + while (fieldInfo == null && propertyInfo == null && t != null) + { + fieldInfo = t.GetField(propName, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + if (fieldInfo == null) + { + propertyInfo = t.GetProperty(propName, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + } + + t = t.BaseType; + } + if (fieldInfo == null && propertyInfo == null) + throw new ArgumentOutOfRangeException(nameof(propName), $"Field {propName} was not found in Type {obj.GetType().FullName}"); + + if (fieldInfo != null) + return fieldInfo.GetValue(obj); + + return propertyInfo.GetValue(obj, null); + } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Extensions/SqlStatementExtensions.cs b/N.EntityFramework.Extensions.MySql/Extensions/SqlStatementExtensions.cs new file mode 100644 index 0000000..b202282 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Extensions/SqlStatementExtensions.cs @@ -0,0 +1,13 @@ +using System.Collections.Generic; +using N.EntityFrameworkCore.Extensions.Sql; + +namespace N.EntityFrameworkCore.Extensions.Extensions; + +internal static class SqlStatementExtensions +{ + internal static void WriteInsert(this SqlStatement statement, IEnumerable insertColumns) + { + statement.CreatePart(SqlKeyword.Insert, SqlExpression.Columns(insertColumns)); + statement.CreatePart(SqlKeyword.Values, SqlExpression.Columns(insertColumns)); + } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/GlobalSuppressions.cs b/N.EntityFramework.Extensions.MySql/GlobalSuppressions.cs new file mode 100644 index 0000000..cb5f59f --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/GlobalSuppressions.cs @@ -0,0 +1,12 @@ +// This file is used by Code Analysis to maintain SuppressMessage +// attributes that are applied to this project. +// Project-level suppressions either have no target or are given +// a specific target and scoped to a namespace, type, member, etc. + +using System.Diagnostics.CodeAnalysis; + +[assembly: SuppressMessage("Usage", "EF1001:Internal EF Core API usage.", Justification = "EntityFrameworkCore Extension", Scope = "member", Target = "~M:N.EntityFrameworkCore.Extensions.DbContextExtensions.BulkSaveChanges(Microsoft.EntityFrameworkCore.DbContext,System.Boolean)~System.Int32")] +[assembly: SuppressMessage("Usage", "EF1001:Internal EF Core API usage.", Justification = "EntityFrameworkCore Extension", Scope = "member", Target = "~M:N.EntityFrameworkCore.Extensions.DbContextExtensions.SetStoreGeneratedValues``1(Microsoft.EntityFrameworkCore.DbContext,``0,System.Collections.Generic.IEnumerable{Microsoft.EntityFrameworkCore.Metadata.IProperty},System.Object[])")] +[assembly: SuppressMessage("Usage", "EF1001:Internal EF Core API usage.", Justification = "EntityFrameworkCore Extension", Scope = "member", Target = "~M:N.EntityFrameworkCore.Extensions.DbContextExtensionsAsync.BulkSaveChangesAsync(Microsoft.EntityFrameworkCore.DbContext,System.Boolean)~System.Threading.Tasks.Task{System.Int32}")] +[assembly: SuppressMessage("Usage", "EF1001:Internal EF Core API usage.", Justification = "EntityFrameworkCore Extension", Scope = "member", Target = "~M:N.EntityFrameworkCore.Extensions.TableMapping.GetColumnNames(Microsoft.EntityFrameworkCore.Metadata.IEntityType,System.Boolean)~System.Collections.Generic.IEnumerable{System.String}")] +[assembly: SuppressMessage("Usage", "EF1001:Internal EF Core API usage.", Justification = "EntityFrameworkCore Extension", Scope = "member", Target = "~M:N.EntityFrameworkCore.Extensions.EntityDataReader`1.FindEntry(System.Object)~Microsoft.EntityFrameworkCore.ChangeTracking.EntityEntry")] \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/N.EntityFramework.Extensions.MySql.csproj b/N.EntityFramework.Extensions.MySql/N.EntityFramework.Extensions.MySql.csproj new file mode 100644 index 0000000..50f8d7d --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/N.EntityFramework.Extensions.MySql.csproj @@ -0,0 +1,40 @@ + + + + net9.0 + 9.0.0.1 + N.EntityFramework.Extensions.MySql + true + https://github.com/NorthernLight1/N.EntityFrameworkCore.Extensions/ + Northern25 + Copyright © 2026 + + N.EntityFramework.Extensions.MySql extends your DbContext in EF Core with high-performance bulk operations for MySql: BulkDelete, BulkInsert, BulkMerge, BulkSync, BulkUpdate, Fetch, DeleteFromQuery, InsertFromQuery, UpdateFromQuery. + +Inheritance models supported: Table-Per-Concrete, Table-Per-Hierarchy, Table-Per-Type + MIT + README.md + + + + 5 + + + + + True + \ + + + + + + + + + + + + + + diff --git a/N.EntityFramework.Extensions.MySql/Sql/SqlBuilder.cs b/N.EntityFramework.Extensions.MySql/Sql/SqlBuilder.cs new file mode 100644 index 0000000..2a06ac0 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Sql/SqlBuilder.cs @@ -0,0 +1,152 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Linq; +using System.Linq.Expressions; + +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal sealed class SqlBuilder +{ + private static readonly string[] keywords = ["DECLARE", "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY"]; + internal string Sql => ToString(); + internal List Clauses { get; private set; } + internal List<(string Name, DbType DbType, int Size, object Value)> Parameters { get; private set; } + private SqlBuilder(string sql) + { + Clauses = []; + Parameters = []; + Initialize(sql); + } + + internal string Count() => + $"SELECT COUNT(*) FROM ({string.Join("\r\n", Clauses.Where(o => o.Name != "ORDER BY").Select(o => o.ToString()))}) s"; + public override string ToString() => string.Join("\r\n", Clauses.Select(o => o.ToString())); + internal static SqlBuilder Parse(string sql) => new SqlBuilder(sql); + internal string GetTableAlias() + { + var sqlFromClause = Clauses.First(o => o.Name == "FROM"); + var startIndex = sqlFromClause.InputText.LastIndexOf(" AS "); + return startIndex > 0 ? sqlFromClause.InputText[(startIndex + 4)..] : ""; + } + internal void ChangeToDelete() + { + Validate(); + var sqlClause = Clauses.FirstOrDefault(); + var sqlFromClause = Clauses.First(o => o.Name == "FROM"); + if (sqlClause != null) + { + sqlClause.Name = "DELETE"; + int aliasStartIndex = sqlFromClause.InputText.IndexOf("AS ") + 3; + int aliasLength = sqlFromClause.InputText.IndexOf(']', aliasStartIndex) - aliasStartIndex + 1; + sqlClause.InputText = sqlFromClause.InputText[aliasStartIndex..(aliasStartIndex + aliasLength)]; + } + } + internal void ChangeToUpdate(string updateExpression, string setExpression) + { + Validate(); + var sqlClause = Clauses.FirstOrDefault(); + if (sqlClause != null) + { + sqlClause.Name = "UPDATE"; + sqlClause.InputText = updateExpression; + Clauses.Insert(1, new SqlClause { Name = "SET", InputText = setExpression }); + } + } + internal void ChangeToInsert(string tableName, Expression> insertObjectExpression) + { + Validate(); + var sqlSelectClause = Clauses.FirstOrDefault(); + string columnsToInsert = string.Join(",", insertObjectExpression.GetObjectProperties()); + string insertValueExpression = $"INTO {tableName} ({columnsToInsert})"; + Clauses.Insert(0, new SqlClause { Name = "INSERT", InputText = insertValueExpression }); + sqlSelectClause.InputText = columnsToInsert; + } + internal void SelectColumns(IEnumerable columns) + { + var tableAlias = GetTableAlias(); + var sqlClause = Clauses.FirstOrDefault(); + if (sqlClause.Name == "SELECT") + { + sqlClause.InputText = string.Join(",", columns.Select(c => $"{tableAlias}.{c}")); + } + } + private void Initialize(string sqlText) + { + string curClause = string.Empty; + int curClauseIndex = 0; + for (int i = 0; i < sqlText.Length;) + { + string keyword = StartsWithString(sqlText.AsSpan(i), keywords, StringComparison.OrdinalIgnoreCase); + bool isWordStart = i == 0 || sqlText[i - 1] == ' ' || (i > 1 && sqlText[i - 2] == '\r' && sqlText[i - 1] == '\n'); + if (keyword != null && isWordStart) + { + string inputText = sqlText[curClauseIndex..i]; + if (!string.IsNullOrEmpty(curClause)) + { + if (curClause == "DECLARE") + { + var declareParts = inputText[..inputText.IndexOf(';')].Trim().Split(' '); + int sizeStartIndex = declareParts[1].IndexOf('('); + int sizeLength = declareParts[1].IndexOf(')') - (sizeStartIndex + 1); + string dbTypeString = sizeStartIndex != -1 ? declareParts[1][..sizeStartIndex] : declareParts[1]; + DbType dbType = (DbType)Enum.Parse(typeof(DbType), dbTypeString, true); + int size = sizeStartIndex != -1 ? + Convert.ToInt32(declareParts[1][(sizeStartIndex + 1)..(sizeStartIndex + 1 + sizeLength)]) : 0; + string value = GetDeclareValue(declareParts[3]); + Parameters.Add((declareParts[0], dbType, size, value)); + } + else + { + Clauses.Add(SqlClause.Parse(curClause, inputText)); + } + } + curClause = keyword; + curClauseIndex = i + curClause.Length; + i = i + curClause.Length; + } + else + { + i++; + } + } + if (!string.IsNullOrEmpty(curClause)) + Clauses.Add(SqlClause.Parse(curClause, sqlText[curClauseIndex..])); + } + private string GetDeclareValue(string value) + { + if (value.StartsWith('\'')) + { + return value[1..^1]; + } + else if (value.StartsWith("N'")) + { + return value[2..^1]; + } + else if (value.StartsWith("CAST(")) + { + return value[5..]; + } + else + { + return value; + } + } + private static string StartsWithString(ReadOnlySpan textToSearch, string[] valuesToFind, StringComparison stringComparison) + { + foreach (var valueToFind in valuesToFind) + { + if (textToSearch.StartsWith(valueToFind, stringComparison)) + return valueToFind; + } + + return null; + } + private void Validate() + { + if (Clauses.Count == 0) + { + throw new Exception("You must parse a valid sql statement before you can use this function."); + } + } +} diff --git a/N.EntityFramework.Extensions.MySql/Sql/SqlClause.cs b/N.EntityFramework.Extensions.MySql/Sql/SqlClause.cs new file mode 100644 index 0000000..f5f4148 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Sql/SqlClause.cs @@ -0,0 +1,14 @@ +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal sealed class SqlClause +{ + internal string Name { get; set; } + internal string InputText { get; set; } + internal string Sql => ToString(); + internal static SqlClause Parse(string name, string inputText) + { + string cleanText = inputText.Replace("\r\n", "").Trim(); + return new SqlClause { Name = name, InputText = cleanText }; + } + public override string ToString() => $"{Name} {InputText}"; +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Sql/SqlExpression.cs b/N.EntityFramework.Extensions.MySql/Sql/SqlExpression.cs new file mode 100644 index 0000000..8ff597d --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Sql/SqlExpression.cs @@ -0,0 +1,70 @@ +using System.Collections.Generic; +using System.Linq; +using System.Text; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal sealed class SqlExpression +{ + internal SqlExpressionType ExpressionType { get; } + List Items { get; set; } + internal string Sql => ToSql(); + string Alias { get; } + internal bool IsEmpty => Items.Count == 0; + + SqlExpression(SqlExpressionType expressionType, object item, string alias = null) + { + ExpressionType = expressionType; + Items = []; + if (item is IEnumerable values) + { + Items.AddRange(values.ToArray()); + } + else + { + Items.Add(item); + } + Alias = alias; + } + SqlExpression(SqlExpressionType expressionType, object[] items, string alias = null) + { + ExpressionType = expressionType; + Items = []; + Items.AddRange(items); + Alias = alias; + } + internal static SqlExpression Columns(IEnumerable columns) => + new SqlExpression(SqlExpressionType.Columns, columns); + + internal static SqlExpression Set(IEnumerable columns) => + new SqlExpression(SqlExpressionType.Set, columns); + + internal static SqlExpression String(string joinOnCondition) => + new SqlExpression(SqlExpressionType.String, joinOnCondition); + + internal static SqlExpression Table(string tableName, string alias = null) => + new SqlExpression(SqlExpressionType.Table, Util.CommonUtil.FormatTableName(tableName), alias); + + private string ToSql() + { + var sbSql = new StringBuilder(); + if (ExpressionType == SqlExpressionType.Columns) + { + var values = Items.Where(o => o != null).Select(o => o.ToString()).Where(o => !string.IsNullOrWhiteSpace(o)).ToArray(); + sbSql.Append(string.Join(",", CommonUtil.FormatColumns(values))); + } + else + { + sbSql.Append(string.Join(",", Items.Where(o => o != null).Select(o => o.ToString()).Where(o => !string.IsNullOrWhiteSpace(o)))); + } + if (Alias != null) + { + sbSql.Append(" "); + sbSql.Append(SqlKeyword.As.ToString().ToUpper()); + sbSql.Append(" "); + sbSql.Append(Alias); + } + return sbSql.ToString(); + } +} diff --git a/N.EntityFramework.Extensions.MySql/Sql/SqlExpressionType.cs b/N.EntityFramework.Extensions.MySql/Sql/SqlExpressionType.cs new file mode 100644 index 0000000..5c326df --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Sql/SqlExpressionType.cs @@ -0,0 +1,9 @@ +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal enum SqlExpressionType +{ + String, + Table, + Columns, + Set +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Sql/SqlKeyword.cs b/N.EntityFramework.Extensions.MySql/Sql/SqlKeyword.cs new file mode 100644 index 0000000..d109a97 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Sql/SqlKeyword.cs @@ -0,0 +1,29 @@ +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal enum SqlKeyword +{ + Select, + Delete, + Insert, + Values, + Update, + Set, + Merge, + Into, + From, + On, + Where, + Using, + When, + Then, + Matched, + Not, + Output, + As, + By, + Source, + Target, + Off, + Identity_Insert, + Semicolon, +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Sql/SqlPart.cs b/N.EntityFramework.Extensions.MySql/Sql/SqlPart.cs new file mode 100644 index 0000000..652b2f3 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Sql/SqlPart.cs @@ -0,0 +1,14 @@ +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal sealed class SqlPart +{ + internal SqlKeyword Keyword { get; } + internal SqlExpression Expression { get; } + internal bool IgnoreOutput => GetIgnoreOutput(); + internal SqlPart(SqlKeyword keyword, SqlExpression expression) + { + Keyword = keyword; + Expression = expression; + } + private bool GetIgnoreOutput() => Keyword == SqlKeyword.Output && (Expression == null || Expression.IsEmpty); +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Sql/SqlStatement.cs b/N.EntityFramework.Extensions.MySql/Sql/SqlStatement.cs new file mode 100644 index 0000000..bcf98ca --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Sql/SqlStatement.cs @@ -0,0 +1,105 @@ +using System.Collections.Generic; +using System.Linq; +using System.Text; +using N.EntityFrameworkCore.Extensions.Extensions; + +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal sealed class SqlStatement +{ + internal string Sql => ToSql(); + List SqlParts { get; } + SqlStatement() + { + SqlParts = []; + } + internal void CreatePart(SqlKeyword keyword, SqlExpression expression = null) => + SqlParts.Add(new SqlPart(keyword, expression)); + internal void SetIdentityInsert(string tableName, bool enable) + { + CreatePart(SqlKeyword.Set); + CreatePart(SqlKeyword.Identity_Insert, SqlExpression.Table(tableName)); + if (enable) + CreatePart(SqlKeyword.On); + else + CreatePart(SqlKeyword.Off); + CreatePart(SqlKeyword.Semicolon); + } + internal static SqlStatement CreateMerge(string sourceTableName, string targetTableName, string joinOnCondition, + IEnumerable insertColumns, IEnumerable updateColumns, IEnumerable outputColumns, + bool deleteIfNotMatched = false, bool hasIdentityColumn = false) + { + var statement = new SqlStatement(); + if (hasIdentityColumn) + statement.SetIdentityInsert(targetTableName, true); + statement.CreatePart(SqlKeyword.Merge, SqlExpression.Table(targetTableName, "t")); + statement.CreatePart(SqlKeyword.Using, SqlExpression.Table(sourceTableName, "s")); + statement.CreatePart(SqlKeyword.On, SqlExpression.String(joinOnCondition)); + statement.CreatePart(SqlKeyword.When); + statement.CreatePart(SqlKeyword.Not); + statement.CreatePart(SqlKeyword.Matched); + statement.CreatePart(SqlKeyword.Then); + statement.WriteInsert(insertColumns); + if (updateColumns.Any()) + { + var updateSetColumns = updateColumns.Select(c => $"t.[{c}]=s.[{c}]"); + statement.CreatePart(SqlKeyword.When); + statement.CreatePart(SqlKeyword.Matched); + statement.CreatePart(SqlKeyword.Then); + statement.CreatePart(SqlKeyword.Update); + statement.CreatePart(SqlKeyword.Set, SqlExpression.Set(updateSetColumns)); + } + if (deleteIfNotMatched) + { + statement.CreatePart(SqlKeyword.When); + statement.CreatePart(SqlKeyword.Not); + statement.CreatePart(SqlKeyword.Matched); + statement.CreatePart(SqlKeyword.By); + statement.CreatePart(SqlKeyword.Source); + statement.CreatePart(SqlKeyword.Then); + statement.CreatePart(SqlKeyword.Delete); + } + if (outputColumns.Any()) + statement.CreatePart(SqlKeyword.Output, SqlExpression.Columns(outputColumns)); + statement.CreatePart(SqlKeyword.Semicolon); + + if (hasIdentityColumn) + statement.SetIdentityInsert(targetTableName, false); + return statement; + } + + private string ToSql() + { + var sbSql = new StringBuilder(); + foreach (var part in SqlParts) + { + if (part.Keyword == SqlKeyword.Semicolon) + { + int lastIndex = sbSql.Length - 1; + if (lastIndex > -1 && sbSql[lastIndex] == ' ') + { + sbSql[lastIndex] = ';'; + sbSql.Append("\n"); + } + else + { + sbSql.Append(";\n"); + } + } + else if (!part.IgnoreOutput) + { + sbSql.Append(part.Keyword.ToString().ToUpper()); + sbSql.Append(" "); + bool useParenthese = part.Keyword == SqlKeyword.Insert || part.Keyword == SqlKeyword.Values; + + if (part.Expression != null) + { + string expressionSql = useParenthese ? $"({part.Expression.Sql})" : part.Expression.Sql; + sbSql.Append(expressionSql); + sbSql.Append(" "); + } + } + } + return sbSql.ToString(); + } +} \ No newline at end of file diff --git a/N.EntityFramework.Extensions.MySql/Util/CommonUtil.cs b/N.EntityFramework.Extensions.MySql/Util/CommonUtil.cs new file mode 100644 index 0000000..9e3ba27 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Util/CommonUtil.cs @@ -0,0 +1,135 @@ +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.IO; +using System.Linq; +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; + +namespace N.EntityFrameworkCore.Extensions.Util; + +internal static class CommonUtil +{ + internal static string GetStagingTableName(TableMapping tableMapping, bool usePermanentTable, DbConnection dbConnection) + { + string uniqueSuffix = Guid.NewGuid().ToString("N"); + if (usePermanentTable) + return tableMapping.DbContext.Database.GetPermanentStagingTableName(tableMapping.Schema, tableMapping.TableName, uniqueSuffix); + return tableMapping.DbContext.Database.GetTemporaryTableName(tableMapping.TableName); + } + internal static IEnumerable FormatColumns(DbContext dbContext, IEnumerable columns) + { + return columns.Select(s => FormatColumn(dbContext, s)); + } + internal static IEnumerable FormatColumns(IEnumerable columns) + { + return columns.Select(FormatColumnLegacy); + } + internal static IEnumerable FormatColumns(DbContext dbContext, string tableAlias, IEnumerable columns) + { + return columns.Select(s => dbContext.DelimitMemberAccess(tableAlias, RemoveQualifier(s))); + } + internal static IEnumerable FormatColumns(DatabaseFacade database, string tableAlias, IEnumerable columns) + { + return columns.Select(s => database.DelimitMemberAccess(tableAlias, RemoveQualifier(s))); + } + internal static IEnumerable FormatColumns(string tableAlias, IEnumerable columns) + { + return columns.Select(s => s.StartsWith('[') && s.EndsWith(']') ? $"[{tableAlias}].{s}" : $"[{tableAlias}].[{s}]"); + } + internal static IEnumerable FilterColumns(IEnumerable columnNames, string[] primaryKeyColumnNames, Expression> inputColumns, Expression> ignoreColumns) + { + var filteredColumnNames = columnNames; + if (inputColumns != null) + { + var inputColumnNames = inputColumns.GetObjectProperties(); + filteredColumnNames = filteredColumnNames.Intersect(inputColumnNames.Union(primaryKeyColumnNames)); + } + if (ignoreColumns != null) + { + var ignoreColumnNames = ignoreColumns.GetObjectProperties(); + if (ignoreColumnNames.Intersect(primaryKeyColumnNames).Any()) + { + throw new InvalidDataException("Primary key columns can not be ignored in BulkInsertOptions.IgnoreColumns"); + } + else + { + filteredColumnNames = filteredColumnNames.Except(ignoreColumnNames); + } + } + return filteredColumnNames; + } + internal static string FormatTableName(DatabaseFacade database, string tableName) + { + return database.DelimitTableName(tableName); + } + internal static string FormatTableName(string tableName) + { + return string.Join(".", tableName.Split('.').Select(s => $"[{RemoveQualifier(s)}]")); + } + private static string FormatColumn(DbContext dbContext, string column) + { + var parts = column.Split('.'); + return string.Join(".", parts.Select(p => p.StartsWith('$') ? p : dbContext.DelimitIdentifier(RemoveQualifier(p)))); + } + private static string FormatColumnLegacy(string column) + { + var parts = column.Split('.'); + return string.Join(".", parts.Select(p => p.StartsWith('$') || (p.StartsWith('[') && p.EndsWith(']')) ? p : $"[{p}]")); + } + private static string RemoveQualifier(string name) + { + return name.TrimStart('[').TrimEnd(']').Trim('"'); + } +} +internal static class CommonUtil +{ + internal static string[] GetColumns(Expression> expression, string[] tableNames) + { + List foundColumns = []; + string sqlText = (string)expression.Body.GetPrivateFieldValue("DebugView"); + var sqlSpan = sqlText.AsSpan(); + + int offset = 0; + while (offset < sqlSpan.Length) + { + int startIndex = sqlSpan[offset..].IndexOf('$'); + if (startIndex == -1) break; + startIndex += offset; + + var remaining = sqlSpan[startIndex..]; + int spaceIndex = remaining.IndexOf(' '); + var columnSpan = spaceIndex == -1 ? remaining : remaining[..spaceIndex]; + + int dotIndex = columnSpan.IndexOf('.'); + if (dotIndex >= 0) + { + var tablePart = columnSpan[1..dotIndex]; // skip leading '$' + var columnPart = columnSpan[(dotIndex + 1)..]; + if (tableNames == null || tableNames.Contains(tablePart.ToString())) + { + foundColumns.Add(columnPart.ToString()); + } + } + + offset = startIndex + 1; + } + + return foundColumns.ToArray(); + } + internal static string GetJoinConditionSql(Expression> joinKeyExpression, string[] storeGeneratedColumnNames, string sourceTableName = "s", string targetTableName = "t") + { + if (joinKeyExpression != null) + return joinKeyExpression.ToSqlPredicate(sourceTableName, targetTableName); + + return string.Join(" AND ", storeGeneratedColumnNames.Select(c => $"{sourceTableName}.[{c}]={targetTableName}.[{c}]")); + } + internal static string GetJoinConditionSql(DbContext dbContext, Expression> joinKeyExpression, string[] storeGeneratedColumnNames, string sourceTableName = "s", string targetTableName = "t") + { + if (joinKeyExpression != null) + return joinKeyExpression.ToSqlPredicate(dbContext, sourceTableName, targetTableName); + + return string.Join(" AND ", storeGeneratedColumnNames.Select(c => $"{dbContext.DelimitMemberAccess(sourceTableName, c)}={dbContext.DelimitMemberAccess(targetTableName, c)}")); + } +} diff --git a/N.EntityFramework.Extensions.MySql/Util/RelationalProviderUtil.cs b/N.EntityFramework.Extensions.MySql/Util/RelationalProviderUtil.cs new file mode 100644 index 0000000..c192fcd --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Util/RelationalProviderUtil.cs @@ -0,0 +1,128 @@ +using System; +using System.Data.Common; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Storage; + +namespace N.EntityFrameworkCore.Extensions.Util; + +internal enum DatabaseProvider +{ + SqlServer, + PostgreSql, + MySql +} + +internal readonly record struct DatabaseObjectName(string Schema, string Name) +{ + internal bool HasSchema => !string.IsNullOrWhiteSpace(Schema); +} + +internal static class RelationalProviderUtil +{ + internal static DatabaseProvider GetDatabaseProvider(this DatabaseFacade database) + { + return database.ProviderName switch + { + string providerName when providerName.Contains("SqlServer", StringComparison.OrdinalIgnoreCase) => DatabaseProvider.SqlServer, + string providerName when providerName.Contains("Npgsql", StringComparison.OrdinalIgnoreCase) => DatabaseProvider.PostgreSql, + string providerName when providerName.Contains("MySql", StringComparison.OrdinalIgnoreCase) => DatabaseProvider.MySql, + _ => throw new NotSupportedException($"The database provider '{database.ProviderName}' is not supported.") + }; + } + + internal static bool IsSqlServer(this DatabaseFacade database) => database.GetDatabaseProvider() == DatabaseProvider.SqlServer; + + internal static bool IsPostgreSql(this DatabaseFacade database) => database.GetDatabaseProvider() == DatabaseProvider.PostgreSql; + + internal static string GetDefaultSchema(this DatabaseFacade database) => + database.GetDatabaseProvider() == DatabaseProvider.MySql ? null : (database.IsPostgreSql() ? "public" : "dbo"); + + internal static string DelimitIdentifier(this DatabaseFacade database, string identifier) => + database.GetSqlGenerationHelper().DelimitIdentifier(UnwrapIdentifier(identifier)); + + internal static string DelimitIdentifier(this DatabaseFacade database, string identifier, string schema) => + schema == null + ? database.DelimitIdentifier(identifier) + : database.GetSqlGenerationHelper().DelimitIdentifier(UnwrapIdentifier(identifier), UnwrapIdentifier(schema)); + + internal static string DelimitIdentifier(this DbContext dbContext, string identifier) => + dbContext.Database.DelimitIdentifier(identifier); + + internal static string DelimitIdentifier(this DbContext dbContext, string identifier, string schema) => + dbContext.Database.DelimitIdentifier(identifier, schema); + + internal static string DelimitTableName(this DatabaseFacade database, string tableName) + { + var objectName = database.ParseObjectName(tableName); + return objectName.HasSchema + ? database.DelimitIdentifier(objectName.Name, objectName.Schema) + : database.DelimitIdentifier(objectName.Name); + } + + internal static string DelimitTableName(this DbContext dbContext, string tableName) => + dbContext.Database.DelimitTableName(tableName); + + internal static string DelimitMemberAccess(this DbContext dbContext, string alias, string columnName) => + $"{dbContext.DelimitIdentifier(alias)}.{dbContext.DelimitIdentifier(columnName)}"; + + internal static string DelimitMemberAccess(this DatabaseFacade database, string alias, string columnName) => + $"{database.DelimitIdentifier(alias)}.{database.DelimitIdentifier(columnName)}"; + + internal static DatabaseObjectName ParseObjectName(this DatabaseFacade database, string objectName) + { + string normalized = objectName.Trim(); + if (string.IsNullOrWhiteSpace(normalized)) + throw new ArgumentException("Object name cannot be empty.", nameof(objectName)); + + var parts = normalized.Split('.', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + return parts.Length switch + { + 1 => new DatabaseObjectName(IsTemporaryName(parts[0]) ? null : database.GetDefaultSchema(), UnwrapIdentifier(parts[0])), + 2 => new DatabaseObjectName(UnwrapIdentifier(parts[0]), UnwrapIdentifier(parts[1])), + _ => throw new InvalidOperationException($"Unsupported object name format '{objectName}'.") + }; + } + + internal static string UnwrapIdentifier(string value) => + value.Trim().Trim('[', ']', '"', '`'); + + internal static string GetTemporaryTableName(this DatabaseFacade database, string baseName) + { + const string prefix = "tmp_be_xx_"; + const int guidSuffixLength = 33; // "_" + 32 hex chars (Guid:N) + const int maxIdentifierLength = 64; // MySQL identifier limit + string unwrapped = UnwrapIdentifier(baseName); + int maxNameLength = maxIdentifierLength - prefix.Length - guidSuffixLength; + if (unwrapped.Length > maxNameLength) + unwrapped = unwrapped[..maxNameLength]; + string temporaryName = $"{prefix}{unwrapped}_{Guid.NewGuid():N}"; + return database.DelimitIdentifier(temporaryName); + } + + internal static string GetPermanentStagingTableName(this DatabaseFacade database, string schema, string tableName, string uniqueSuffix) + { + const string prefix = "tmp_be_xx_"; + const int maxIdentifierLength = 64; // MySQL identifier limit + string unwrapped = UnwrapIdentifier(tableName); + int maxNameLength = maxIdentifierLength - prefix.Length - 1 - uniqueSuffix.Length; // 1 for "_" + if (unwrapped.Length > maxNameLength) + unwrapped = unwrapped[..maxNameLength]; + string stagingName = $"{prefix}{unwrapped}_{uniqueSuffix}"; + return database.DelimitIdentifier(stagingName, schema); + } + + internal static DbConnection CloneConnection(this DbConnection dbConnection) => + dbConnection switch + { + ICloneable cloneable => (DbConnection)cloneable.Clone(), + _ => throw new NotSupportedException($"Connection type '{dbConnection.GetType().FullName}' does not support cloning.") + }; + + private static ISqlGenerationHelper GetSqlGenerationHelper(this DatabaseFacade database) => + ((IInfrastructure)database).Instance.GetService(typeof(ISqlGenerationHelper)) as ISqlGenerationHelper + ?? throw new InvalidOperationException("Unable to resolve ISqlGenerationHelper."); + + private static bool IsTemporaryName(string objectName) => + UnwrapIdentifier(objectName).StartsWith("#", StringComparison.Ordinal); +} diff --git a/N.EntityFramework.Extensions.MySql/Util/SqlUtil.cs b/N.EntityFramework.Extensions.MySql/Util/SqlUtil.cs new file mode 100644 index 0000000..de42c71 --- /dev/null +++ b/N.EntityFramework.Extensions.MySql/Util/SqlUtil.cs @@ -0,0 +1,11 @@ +using System.Collections.Generic; + +namespace N.EntityFrameworkCore.Extensions; + +internal static class SqlUtil +{ + internal static string ConvertToColumnString(IEnumerable columnNames) + { + return string.Join(",", columnNames); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.runsettings b/N.EntityFramework.Extensions.PostgreSql.runsettings similarity index 79% rename from N.EntityFrameworkCore.Extensions.runsettings rename to N.EntityFramework.Extensions.PostgreSql.runsettings index 8a39bda..6f2f423 100644 --- a/N.EntityFrameworkCore.Extensions.runsettings +++ b/N.EntityFramework.Extensions.PostgreSql.runsettings @@ -1,5 +1,8 @@ + + 4 + diff --git a/N.EntityFramework.Extensions.SqlServer.runsettings b/N.EntityFramework.Extensions.SqlServer.runsettings new file mode 100644 index 0000000..6f2f423 --- /dev/null +++ b/N.EntityFramework.Extensions.SqlServer.runsettings @@ -0,0 +1,15 @@ + + + + 4 + + + + + + detailed + + + + + diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Common/Config.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Common/Config.cs new file mode 100644 index 0000000..83416f2 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Common/Config.cs @@ -0,0 +1,32 @@ +using System; +using System.Data.Common; +using Microsoft.Extensions.Configuration; +using MySqlConnector; + +namespace N.EntityFrameworkCore.Extensions.Test.Common; + +public class Config +{ + private static readonly IConfigurationRoot configuration = new ConfigurationBuilder() + .AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .Build(); + + public static string GetConnectionString(string name) + { + return configuration.GetConnectionString(name); + } + public static bool IsMySql => true; + public static bool IsPostgreSql => false; + public static bool IsSqlServer => false; + public static bool UseMySqlContainer => + !string.Equals(configuration["UseMySqlContainer"], "false", StringComparison.OrdinalIgnoreCase); + public static string GetTestDatabaseConnectionString() => + UseMySqlContainer ? MySqlContainerManager.GetConnectionString() : GetConnectionString("MySqlTestDatabase"); + public static DbParameter CreateParameter(string name, object value) => + new MySqlParameter(name, value ?? DBNull.Value); + public static string DelimitIdentifier(string identifier) => $"`{identifier}`"; + public static string DelimitTableName(string tableName) => $"`{tableName}`"; + public static bool IsPrimaryKeyViolation(Exception exception) => + exception.Message.Contains("Duplicate entry", StringComparison.OrdinalIgnoreCase); +} diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Common/MySqlContainerManager.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Common/MySqlContainerManager.cs new file mode 100644 index 0000000..e83a04b --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Common/MySqlContainerManager.cs @@ -0,0 +1,69 @@ +using System; +using System.Threading.Tasks; +using Testcontainers.MySql; + +namespace N.EntityFrameworkCore.Extensions.Test.Common; + +internal static class MySqlContainerManager +{ + private static readonly object syncRoot = new(); + private static Task initializationTask; + private static MySqlContainer container; + private static bool cleanupRegistered; + + internal static string GetConnectionString() + { + EnsureStarted(); + return container.GetConnectionString() + ";AllowLoadLocalInfile=true;UseAffectedRows=false"; + } + + internal static void EnsureStarted() + { + EnsureStartedAsync().GetAwaiter().GetResult(); + } + + internal static Task EnsureStartedAsync() + { + lock (syncRoot) + { + initializationTask ??= StartContainerAsync(); + return initializationTask; + } + } + + private static async Task StartContainerAsync() + { + try + { + container = new MySqlBuilder() + .WithImage("mysql:8.4") + .WithDatabase("NEntityFrameworkCoreExtensions") + .WithUsername("root") + .WithPassword("mysql") + .Build(); + + await container.StartAsync(); + RegisterCleanup(); + } + catch (Exception ex) + { + throw new InvalidOperationException("MySql tests require Docker when UseMySqlContainer is enabled.", ex); + } + } + + private static void RegisterCleanup() + { + lock (syncRoot) + { + if (cleanupRegistered) + return; + + AppDomain.CurrentDomain.ProcessExit += (_, _) => + { + if (container != null) + container.DisposeAsync().AsTask().GetAwaiter().GetResult(); + }; + cleanupRegistered = true; + } + } +} diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Common/TestDatabaseInitializer.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Common/TestDatabaseInitializer.cs new file mode 100644 index 0000000..5579398 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Common/TestDatabaseInitializer.cs @@ -0,0 +1,76 @@ +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.Common; + +internal static class TestDatabaseInitializer +{ + internal static void EnsureCreated(TestDbContext dbContext) + { + if (Config.UseMySqlContainer) + MySqlContainerManager.EnsureStarted(); + + dbContext.Database.EnsureCreated(); + CreateProviderSpecificObjects(dbContext); + } + + internal static async Task EnsureCreatedAsync(TestDbContext dbContext) + { + if (Config.UseMySqlContainer) + await MySqlContainerManager.EnsureStartedAsync(); + + await dbContext.Database.EnsureCreatedAsync(); + await CreateProviderSpecificObjectsAsync(dbContext); + } + + internal static void CreateProviderSpecificObjects(TestDbContext dbContext) + { + dbContext.Database.ExecuteSqlRaw("DROP TRIGGER IF EXISTS trg_order_modified_datetime_before_insert"); + dbContext.Database.ExecuteSqlRaw(""" + CREATE TRIGGER trg_order_modified_datetime_before_insert + BEFORE INSERT ON `Orders` + FOR EACH ROW + SET NEW.`DbModifiedDateTime` = NOW(6) + """); + dbContext.Database.ExecuteSqlRaw("DROP TRIGGER IF EXISTS trg_order_modified_datetime_before_update"); + dbContext.Database.ExecuteSqlRaw(""" + CREATE TRIGGER trg_order_modified_datetime_before_update + BEFORE UPDATE ON `Orders` + FOR EACH ROW + SET NEW.`DbModifiedDateTime` = NOW(6) + """); + dbContext.Database.ExecuteSqlRaw("DROP TRIGGER IF EXISTS trgProductWithTriggers"); + dbContext.Database.ExecuteSqlRaw(""" + CREATE TRIGGER trgProductWithTriggers + BEFORE INSERT ON `ProductsWithTrigger` + FOR EACH ROW + SET NEW.`Id` = NEW.`Id` + """); + } + + internal static async Task CreateProviderSpecificObjectsAsync(TestDbContext dbContext) + { + await dbContext.Database.ExecuteSqlRawAsync("DROP TRIGGER IF EXISTS trg_order_modified_datetime_before_insert"); + await dbContext.Database.ExecuteSqlRawAsync(""" + CREATE TRIGGER trg_order_modified_datetime_before_insert + BEFORE INSERT ON `Orders` + FOR EACH ROW + SET NEW.`DbModifiedDateTime` = NOW(6) + """); + await dbContext.Database.ExecuteSqlRawAsync("DROP TRIGGER IF EXISTS trg_order_modified_datetime_before_update"); + await dbContext.Database.ExecuteSqlRawAsync(""" + CREATE TRIGGER trg_order_modified_datetime_before_update + BEFORE UPDATE ON `Orders` + FOR EACH ROW + SET NEW.`DbModifiedDateTime` = NOW(6) + """); + await dbContext.Database.ExecuteSqlRawAsync("DROP TRIGGER IF EXISTS trgProductWithTriggers"); + await dbContext.Database.ExecuteSqlRawAsync(""" + CREATE TRIGGER trgProductWithTriggers + BEFORE INSERT ON `ProductsWithTrigger` + FOR EACH ROW + SET NEW.`Id` = NEW.`Id` + """); + } +} diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/Address.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/Address.cs new file mode 100644 index 0000000..6396e16 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/Address.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations.Schema; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +[ComplexType] +public class Address +{ + public required string Line1 { get; set; } + public string? Line2 { get; set; } + public required string City { get; set; } + public required string Country { get; set; } + public required string PostCode { get; set; } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/Enums/ProductStatus.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/Enums/ProductStatus.cs new file mode 100644 index 0000000..91a250d --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/Enums/ProductStatus.cs @@ -0,0 +1,7 @@ +namespace N.EntityFrameworkCore.Extensions.Test.Data.Enums; + +public enum ProductStatus +{ + InStock, + OutOfStock, +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/Order.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/Order.cs new file mode 100644 index 0000000..48f7715 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/Order.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel.DataAnnotations; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class Order +{ + [Key] + public long Id { get; set; } + public string ExternalId { get; set; } + public Guid? GlobalId { get; set; } + public decimal Price { get; set; } + public DateTime AddedDateTime { get; set; } + public DateTime? ModifiedDateTime { get; set; } + public DateTimeOffset? ModifiedDateTimeOffset { get; set; } + public bool DbActive { get; set; } + public DateTime DbAddedDateTime { get; set; } + public DateTime DbModifiedDateTime { get; set; } + public bool? Trigger { get; set; } + public bool Active { get; set; } + public OrderStatus Status { get; set; } + public Order() + { + AddedDateTime = DateTime.UtcNow; + Active = true; + } +} + +public enum OrderStatus +{ + Unknown, + Completed, + Error +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/OrderWithComplexType.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/OrderWithComplexType.cs new file mode 100644 index 0000000..6318d48 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/OrderWithComplexType.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class OrderWithComplexType +{ + [Key] + public long Id { get; set; } + [Required] + public Address ShippingAddress { get; set; } + [Required] + public Address BillingAddress { get; set; } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/Position.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/Position.cs new file mode 100644 index 0000000..95ec3e4 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/Position.cs @@ -0,0 +1,8 @@ +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class Position +{ + public int Building; + public int Aisle; + public int Bay; +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/Product.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/Product.cs new file mode 100644 index 0000000..84fcc0f --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/Product.cs @@ -0,0 +1,32 @@ +using System; +using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations.Schema; +using N.EntityFrameworkCore.Extensions.Test.Data.Enums; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class Product +{ + [Key] + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public string Id { get; set; } + [StringLength(50)] + public string Name { get; set; } + public decimal Price { get; set; } + public bool OutOfStock { get; set; } + [Column("Status")] + [StringLength(25)] + public string StatusString { get; set; } + public int? ProductCategoryId { get; set; } + public System.Drawing.Color Color { get; set; } + public ProductStatus? StatusEnum { get; set; } + public DateTime? UpdatedDateTime { get; set; } + + public Position Position { get; set; } + + public virtual ProductCategory ProductCategory { get; set; } + public Product() + { + + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/ProductCategory.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/ProductCategory.cs new file mode 100644 index 0000000..168263b --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/ProductCategory.cs @@ -0,0 +1,8 @@ +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class ProductCategory +{ + public int Id { get; set; } + public string Name { get; set; } + public bool Active { get; internal set; } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/ProductWithComplexKey.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/ProductWithComplexKey.cs new file mode 100644 index 0000000..948fc86 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/ProductWithComplexKey.cs @@ -0,0 +1,25 @@ +using System; +using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class ProductWithComplexKey +{ + public Guid Key1 { get; set; } + public Guid Key2 { get; set; } + public Guid Key3 { get; set; } + public Guid Key4 { get; set; } + public string ExternalId { get; set; } + public decimal Price { get; set; } + public bool OutOfStock { get; set; } + [Column("Status")] + [StringLength(25)] + public string StatusString { get; set; } + public DateTime? UpdatedDateTime { get; set; } + public ProductWithComplexKey() + { + Key3 = Guid.NewGuid(); + Key4 = Guid.NewGuid(); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/ProductWithCustomSchema.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/ProductWithCustomSchema.cs new file mode 100644 index 0000000..dead561 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/ProductWithCustomSchema.cs @@ -0,0 +1,14 @@ +using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class ProductWithCustomSchema +{ + [Key] + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public string Id { get; set; } + [StringLength(50)] + public string Name { get; set; } + public decimal Price { get; set; } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/ProductWithTrigger.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/ProductWithTrigger.cs new file mode 100644 index 0000000..d3c9302 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/ProductWithTrigger.cs @@ -0,0 +1,22 @@ +using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class ProductWithTrigger +{ + [Key] + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public string Id { get; set; } + [StringLength(50)] + public string Name { get; set; } + public decimal Price { get; set; } + public bool OutOfStock { get; set; } + [Column("Status")] + [StringLength(25)] + public string StatusString { get; set; } + public ProductWithTrigger() + { + + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TestDbContext.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TestDbContext.cs new file mode 100644 index 0000000..ccda76f --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TestDbContext.cs @@ -0,0 +1,65 @@ +using System; +using System.Drawing; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Diagnostics; +using N.EntityFrameworkCore.Extensions.Test.Common; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TestDbContext : DbContext +{ + public virtual DbSet Products { get; set; } + public virtual DbSet ProductCategories { get; set; } + public virtual DbSet ProductsWithCustomSchema { get; set; } + public virtual DbSet ProductsWithComplexKey { get; set; } + public virtual DbSet ProductsWithTrigger { get; set; } + public virtual DbSet Orders { get; set; } + public virtual DbSet OrdersWithComplexType { get; set; } + public virtual DbSet TpcPeople { get; set; } + public virtual DbSet TphPeople { get; set; } + public virtual DbSet TphCustomers { get; set; } + public virtual DbSet TphVendors { get; set; } + public virtual DbSet TptPeople { get; set; } + public virtual DbSet TptCustomers { get; set; } + public virtual DbSet TptVendors { get; set; } + + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) + { + var connectionString = Config.GetTestDatabaseConnectionString(); + optionsBuilder.UseMySql(connectionString, ServerVersion.AutoDetect(connectionString)); + optionsBuilder.SetupEfCoreExtensions(); + optionsBuilder.UseLazyLoadingProxies(); + optionsBuilder.ConfigureWarnings(warnings => + warnings.Ignore(RelationalEventId.PendingModelChangesWarning)); + } + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity().ToTable("ProductsWithCustomSchema"); + modelBuilder.Entity().HasKey(c => new { c.Key1 }); + modelBuilder.Entity().Property("Key1").HasDefaultValueSql("(UUID())"); + modelBuilder.Entity().Property("Key2").HasDefaultValueSql("(UUID())"); + modelBuilder.Entity().HasKey(p => new { p.Key3, p.Key4 }); + modelBuilder.Entity().Property("DbAddedDateTime").HasDefaultValueSql("CURRENT_TIMESTAMP(6)"); + modelBuilder.Entity().Property("DbModifiedDateTime").HasDefaultValueSql("CURRENT_TIMESTAMP(6)").ValueGeneratedOnAddOrUpdate(); + modelBuilder.Entity().Property(p => p.DbActive).HasDefaultValueSql("TRUE"); + modelBuilder.Entity().Property(p => p.Status).HasConversion(); + modelBuilder.Entity().HasIndex(o => o.ExternalId); + modelBuilder.Entity(b => + { + b.ComplexProperty(e => e.BillingAddress); + b.ComplexProperty(e => e.ShippingAddress); + }); + modelBuilder.Entity().UseTpcMappingStrategy(); + modelBuilder.Entity().ToTable("TpcCustomer"); + modelBuilder.Entity().ToTable("TpcVendor"); + modelBuilder.Entity().Property("CreatedDate"); + modelBuilder.Entity().ToTable("TptPeople"); + modelBuilder.Entity().ToTable("TptCustomer"); + modelBuilder.Entity().ToTable("TptVendor"); + modelBuilder.Entity(t => + { + t.ComplexProperty(p => p.Position).IsRequired(); + t.Property(p => p.Color).HasConversion(x => x.ToArgb(), x => Color.FromArgb(x)); + }); + } +} diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TpcCustomer.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TpcCustomer.cs new file mode 100644 index 0000000..c67f4dc --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TpcCustomer.cs @@ -0,0 +1,10 @@ +using System; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TpcCustomer : TpcPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public DateTime AddedDate { get; set; } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TpcPerson.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TpcPerson.cs new file mode 100644 index 0000000..c3823e5 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TpcPerson.cs @@ -0,0 +1,11 @@ +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public abstract class TpcPerson +{ + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public long Id { get; set; } + public string FirstName { get; set; } + public string LastName { get; set; } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TpcVendor.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TpcVendor.cs new file mode 100644 index 0000000..7c32100 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TpcVendor.cs @@ -0,0 +1,9 @@ + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TpcVendor : TpcPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public string Url { get; set; } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TphCustomer.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TphCustomer.cs new file mode 100644 index 0000000..5325781 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TphCustomer.cs @@ -0,0 +1,10 @@ +using System; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TphCustomer : TphPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public DateTime AddedDate { get; set; } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TphPerson.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TphPerson.cs new file mode 100644 index 0000000..37f090b --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TphPerson.cs @@ -0,0 +1,11 @@ +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +[Table("TphPeople")] +public abstract class TphPerson +{ + public long Id { get; set; } + public string FirstName { get; set; } + public string LastName { get; set; } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TphVendor.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TphVendor.cs new file mode 100644 index 0000000..2effa1d --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TphVendor.cs @@ -0,0 +1,9 @@ + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TphVendor : TphPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public string Url { get; set; } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TptCustomer.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TptCustomer.cs new file mode 100644 index 0000000..ec17284 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TptCustomer.cs @@ -0,0 +1,10 @@ +using System; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TptCustomer : TptPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public DateTime AddedDate { get; set; } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TptPerson.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TptPerson.cs new file mode 100644 index 0000000..4c23db2 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TptPerson.cs @@ -0,0 +1,11 @@ +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TptPerson +{ + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public long Id { get; set; } + public string FirstName { get; set; } + public string LastName { get; set; } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TptVendor.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TptVendor.cs new file mode 100644 index 0000000..ca9eb25 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/Data/TptVendor.cs @@ -0,0 +1,9 @@ + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TptVendor : TptPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public string Url { get; set; } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/DatabaseExtensionsBase.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/DatabaseExtensionsBase.cs new file mode 100644 index 0000000..b7bf6b2 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/DatabaseExtensionsBase.cs @@ -0,0 +1,70 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +public class DatabaseExtensionsBase +{ + private TestDbContext _currentDbContext; + + [TestCleanup] + public void Cleanup() + { + _currentDbContext?.Dispose(); + _currentDbContext = null; + } + + protected TestDbContext SetupDbContext(bool populateData) + { + var dbContext = new TestDbContext(); + _currentDbContext = dbContext; + TestDatabaseInitializer.EnsureCreated(dbContext); + dbContext.Orders.Truncate(); + if (populateData) + { + var orders = new List(); + int id = 1; + for (int i = 0; i < 2050; i++) + { + DateTime addedDateTime = DateTime.UtcNow.AddDays(-id); + orders.Add(new Order + { + Id = id, + ExternalId = string.Format("id-{0}", i), + Price = 1.25M, + AddedDateTime = addedDateTime, + ModifiedDateTime = addedDateTime.AddHours(3) + }); + id++; + } + for (int i = 0; i < 1050; i++) + { + orders.Add(new Order { Id = id, Price = 5.35M }); + id++; + } + for (int i = 0; i < 2050; i++) + { + orders.Add(new Order { Id = id, Price = 1.25M }); + id++; + } + for (int i = 0; i < 6000; i++) + { + orders.Add(new Order { Id = id, Price = 15.35M }); + id++; + } + for (int i = 0; i < 6000; i++) + { + orders.Add(new Order { Id = id, Price = 15.35M }); + id++; + } + + Debug.WriteLine("Last Id for Order is {0}", id); + dbContext.BulkInsert(orders, new BulkInsertOptions() { KeepIdentity = true }); + } + return dbContext; + } +} diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/SqlQueryToCsvFile.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/SqlQueryToCsvFile.cs new file mode 100644 index 0000000..c75dce6 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/SqlQueryToCsvFile.cs @@ -0,0 +1,36 @@ +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class SqlQueryToCsvFile : DatabaseExtensionsBase +{ + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + int count = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Price"; + var queryToCsvFileResult = dbContext.Database.SqlQueryToCsvFile("SqlQueryToCsvFile-Test.csv", sql, Config.CreateParameter("@Price", 5M)); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should match the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file should match the count from the database plus the header row"); + } + [TestMethod] + public void With_Options_ColumnDelimiter_TextQualifer() + { + var dbContext = SetupDbContext(true); + string filePath = "SqlQueryToCsvFile_Options_ColumnDelimiter_TextQualifer-Test.csv"; + int count = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Price"; + var queryToCsvFileResult = dbContext.Database.SqlQueryToCsvFile(filePath, options => { options.ColumnDelimiter = "|"; options.TextQualifer = "\""; }, + sql, Config.CreateParameter("@Price", 5M)); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should match the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file should match the count from the database plus the header row"); + } +} diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/SqlQueryToCsvFileAsync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/SqlQueryToCsvFileAsync.cs new file mode 100644 index 0000000..2a87427 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/SqlQueryToCsvFileAsync.cs @@ -0,0 +1,37 @@ +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class SqlQueryToCsvFileAsync : DatabaseExtensionsBase +{ + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + int count = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Price"; + var queryToCsvFileResult = await dbContext.Database.SqlQueryToCsvFileAsync("SqlQueryToCsvFile-Test.csv", sql, new object[] { Config.CreateParameter("@Price", 5M) }); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should match the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file should match the count from the database plus the header row"); + } + [TestMethod] + public async Task With_Options_ColumnDelimiter_TextQualifer() + { + var dbContext = SetupDbContext(true); + string filePath = "SqlQueryToCsvFile_Options_ColumnDelimiter_TextQualifer-Test.csv"; + int count = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Price"; + var queryToCsvFileResult = await dbContext.Database.SqlQueryToCsvFileAsync(filePath, options => { options.ColumnDelimiter = "|"; options.TextQualifer = "\""; }, + sql, new object[] { Config.CreateParameter("@Price", 5M) }); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should match the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file should match the count from the database plus the header row"); + } +} diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/SqlQuery_Count.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/SqlQuery_Count.cs new file mode 100644 index 0000000..19d10e6 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/SqlQuery_Count.cs @@ -0,0 +1,34 @@ +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class SqlQuery_Count : DatabaseExtensionsBase +{ + [TestMethod] + public void With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + int efCount = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Price"; + var sqlCount = dbContext.Database.FromSqlQuery(sql, Config.CreateParameter("@Price", 5M)).Count(); + + Assert.IsTrue(efCount > 0, "Count from EF should be greater than zero"); + Assert.IsTrue(efCount > 0, "Count from SQL should be greater than zero"); + Assert.IsTrue(efCount == sqlCount, "Count from EF should match the count from the SqlQuery"); + } + [TestMethod] + public void With_OrderBy() + { + var dbContext = SetupDbContext(true); + int efCount = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Price ORDER BY {Config.DelimitIdentifier("Id")}"; + var sqlCount = dbContext.Database.FromSqlQuery(sql, Config.CreateParameter("@Price", 5M)).Count(); + + Assert.IsTrue(efCount > 0, "Count from EF should be greater than zero"); + Assert.IsTrue(efCount > 0, "Count from SQL should be greater than zero"); + Assert.IsTrue(efCount == sqlCount, "Count from EF should match the count from the SqlQuery"); + } +} diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/SqlQuery_CountAsync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/SqlQuery_CountAsync.cs new file mode 100644 index 0000000..4a5ade2 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/SqlQuery_CountAsync.cs @@ -0,0 +1,35 @@ +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class SqlQuery_CountAsync : DatabaseExtensionsBase +{ + [TestMethod] + public async Task With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + int efCount = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Price"; + var sqlCount = await dbContext.Database.FromSqlQuery(sql, Config.CreateParameter("@Price", 5M)).CountAsync(); + + Assert.IsTrue(efCount > 0, "Count from EF should be greater than zero"); + Assert.IsTrue(efCount > 0, "Count from SQL should be greater than zero"); + Assert.IsTrue(efCount == sqlCount, "Count from EF should match the count from the SqlQuery"); + } + [TestMethod] + public async Task With_OrderBy() + { + var dbContext = SetupDbContext(true); + int efCount = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Price ORDER BY {Config.DelimitIdentifier("Id")}"; + var sqlCount = await dbContext.Database.FromSqlQuery(sql, Config.CreateParameter("@Price", 5M)).CountAsync(); + + Assert.IsTrue(efCount > 0, "Count from EF should be greater than zero"); + Assert.IsTrue(efCount > 0, "Count from SQL should be greater than zero"); + Assert.IsTrue(efCount == sqlCount, "Count from EF should match the count from the SqlQuery"); + } +} diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/TableExists.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/TableExists.cs new file mode 100644 index 0000000..75fa7f7 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/TableExists.cs @@ -0,0 +1,20 @@ +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class TableExists : DatabaseExtensionsBase +{ + [TestMethod] + public void With_Orders_Table() + { + var dbContext = SetupDbContext(true); + int efCount = dbContext.Orders.Where(o => o.Price > 5M).Count(); + bool ordersTableExists = dbContext.Database.TableExists("Orders"); + bool orderNewTableExists = dbContext.Database.TableExists("OrdersNew"); + + Assert.IsTrue(ordersTableExists, "Orders table should exist"); + Assert.IsTrue(!orderNewTableExists, "Orders_New table should not exist"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/TruncateTable.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/TruncateTable.cs new file mode 100644 index 0000000..b1e3e52 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/TruncateTable.cs @@ -0,0 +1,20 @@ +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class TruncateTable : DatabaseExtensionsBase +{ + [TestMethod] + public void With_Orders_Table() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + dbContext.Database.TruncateTable("Orders"); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/TruncateTableAsync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/TruncateTableAsync.cs new file mode 100644 index 0000000..d5f2dee --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DatabaseExtensions/TruncateTableAsync.cs @@ -0,0 +1,21 @@ +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class TruncateTableAsync : DatabaseExtensionsBase +{ + [TestMethod] + public async Task With_Orders_Table() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + await dbContext.Database.TruncateTableAsync("Orders"); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkDelete.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkDelete.cs new file mode 100644 index 0000000..e5eeaa3 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkDelete.cs @@ -0,0 +1,91 @@ +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkDelete : DbContextExtensionsBase +{ + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + int rowsDeleted = dbContext.BulkDelete(orders); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newTotal == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = dbContext.TpcPeople.OfType().ToList(); + int rowsDeleted = dbContext.BulkDelete(customers); + var newCustomers = dbContext.TpcPeople.OfType().Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = dbContext.TphPeople.OfType().ToList(); + int rowsDeleted = dbContext.BulkDelete(customers); + var newCustomers = dbContext.TphPeople.OfType().Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptCustomers.ToList(); + int rowsDeleted = dbContext.BulkDelete(customers); + var newCustomers = dbContext.TptCustomers.Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Options_DeleteOnCondition() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).ToList(); + int rowsDeleted = dbContext.BulkDelete(orders, options => { options.DeleteOnCondition = (s, t) => s.ExternalId == t.ExternalId; options.UsePermanentTable = true; }); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price < $2)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newTotal == oldTotal - rowsDeleted, "Must be 0 to indicate all records were deleted"); + } + [DoNotParallelize] + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + int rowsDeleted, newTotal = 0; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsDeleted = dbContext.BulkDelete(orders); + newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + transaction.Rollback(); + } + var rollbackTotal = dbContext.Orders.Count(o => o.Price == 1.25M); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price < $2)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newTotal == 0, "Must be 0 to indicate all records were deleted"); + Assert.IsTrue(rollbackTotal == orders.Count, "The number of rows after the transacation has been rollbacked should match the original count"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkDeleteAsync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkDeleteAsync.cs new file mode 100644 index 0000000..ef60220 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkDeleteAsync.cs @@ -0,0 +1,92 @@ +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkDeleteAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + int rowsDeleted = await dbContext.BulkDeleteAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newTotal == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = dbContext.TpcPeople.OfType().ToList(); + int rowsDeleted = await dbContext.BulkDeleteAsync(customers); + var newCustomers = dbContext.TpcPeople.OfType().Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = dbContext.TphPeople.OfType().ToList(); + int rowsDeleted = await dbContext.BulkDeleteAsync(customers); + var newCustomers = dbContext.TphPeople.OfType().Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptCustomers.ToList(); + int rowsDeleted = await dbContext.BulkDeleteAsync(customers); + var newCustomers = dbContext.TptCustomers.Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Options_DeleteOnCondition() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).ToList(); + int rowsDeleted = await dbContext.BulkDeleteAsync(orders, options => { options.DeleteOnCondition = (s, t) => s.ExternalId == t.ExternalId; options.UsePermanentTable = true; }); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price < $2)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newTotal == oldTotal - rowsDeleted, "Must be 0 to indicate all records were deleted"); + } + [DoNotParallelize] + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + int rowsDeleted, newTotal = 0; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsDeleted = await dbContext.BulkDeleteAsync(orders); + newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + transaction.Rollback(); + } + var rollbackTotal = dbContext.Orders.Count(o => o.Price == 1.25M); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price < $2)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newTotal == 0, "Must be 0 to indicate all records were deleted"); + Assert.IsTrue(rollbackTotal == orders.Count, "The number of rows after the transacation has been rollbacked should match the original count"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkFetch.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkFetch.cs new file mode 100644 index 0000000..fe34879 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkFetch.cs @@ -0,0 +1,114 @@ +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkFetch : DbContextExtensionsBase +{ + [TestMethod] + public void With_Complex_Property() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25m).ToList(); + var fetchedProducts = dbContext.Products.BulkFetch(products); + bool foundNullPositionProperty = fetchedProducts.Any(o => o.Position == null); + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count(), "The number of rows deleted must match the count of existing rows in database"); + Assert.IsFalse(foundNullPositionProperty, "The Position complex property should be populated when using BulkFetch()"); + } + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + var fetchedOrders = dbContext.Orders.BulkFetch(orders); + bool ordersAreMatched = true; + + foreach (var fetchedOrder in fetchedOrders) + { + var order = orders.First(o => o.Id == fetchedOrder.Id); + if (order.ExternalId != fetchedOrder.ExternalId || order.AddedDateTime != fetchedOrder.AddedDateTime || order.ModifiedDateTime != fetchedOrder.ModifiedDateTime) + { + ordersAreMatched = false; + break; + } + } + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(orders.Count == fetchedOrders.Count(), "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(ordersAreMatched, "The orders from BulkFetch() should match what is retrieved from DbContext"); + } + [TestMethod] + public void With_Enum() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25m).ToList(); + var fetchedProducts = dbContext.Products.BulkFetch(products); + bool productsAreMatched = true; + + foreach (var fetchedProduct in fetchedProducts) + { + var product = products.First(o => o.Id == fetchedProduct.Id); + if (product.Id != fetchedProduct.Id || product.Name != fetchedProduct.Name || product.StatusEnum != fetchedProduct.StatusEnum) + { + productsAreMatched = false; + break; + } + } + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count(), "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(productsAreMatched, "The products from BulkFetch() should match what is retrieved from DbContext"); + } + [TestMethod] + public void With_IQueryable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId != null); + var fetchedOrders = dbContext.Orders.BulkFetch(orders, options => { options.IgnoreColumns = o => new { o.ExternalId }; }).ToList(); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + bool foundNullExternalId = fetchedOrders.Where(o => o.ExternalId != null).Any(); + + Assert.IsTrue(orders.Count() > 0, "There must be orders in the database that match condition (Price <= 10 And ExternalId != null)"); + Assert.IsTrue(orders.Count() == fetchedOrders.Count(), "The number of orders must match the number of fetched orders"); + Assert.IsTrue(!foundNullExternalId, "Fetched orders should not contain any items where ExternalId is null."); + } + [TestMethod] + public void With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId != null).ToList(); + var fetchedOrders = dbContext.Orders.BulkFetch(orders, options => { options.IgnoreColumns = o => new { o.ExternalId }; }).ToList(); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + bool foundNullExternalId = fetchedOrders.Where(o => o.ExternalId != null).Any(); + + Assert.IsTrue(orders.Count() > 0, "There must be orders in the database that match condition (Price <= 10 And ExternalId != null)"); + Assert.IsTrue(orders.Count() == fetchedOrders.Count(), "The number of orders must match the number of fetched orders"); + Assert.IsTrue(!foundNullExternalId, "Fetched orders should not contain any items where ExternalId is null."); + } + [TestMethod] + public void With_ValueConverter() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).ToList(); + var fetchedProducts = dbContext.Products.BulkFetch(products); + bool areMatched = true; + + foreach (var fetchedProduct in fetchedProducts) + { + var product = products.First(o => o.Id == fetchedProduct.Id); + if (product.Name != fetchedProduct.Name || product.Price != fetchedProduct.Price + || product.Color != fetchedProduct.Color) + { + areMatched = false; + break; + } + } + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count(), "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(areMatched, "The products from BulkFetch() should match what is retrieved from DbContext"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkFetchAsync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkFetchAsync.cs new file mode 100644 index 0000000..9be0270 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkFetchAsync.cs @@ -0,0 +1,113 @@ +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkFetchAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Complex_Property() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25m).ToList(); + var fetchedProducts = (await dbContext.Products.BulkFetchAsync(products)).ToList(); + bool foundNullPositionProperty = fetchedProducts.Any(o => o.Position == null); + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsFalse(foundNullPositionProperty, "The Position complex property should be populated when using BulkFetchAsync()"); + } + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + var fetchedOrders = (await dbContext.Orders.BulkFetchAsync(orders)).ToList(); + bool ordersAreMatched = true; + + foreach (var fetchedOrder in fetchedOrders) + { + var order = orders.First(o => o.Id == fetchedOrder.Id); + if (order.ExternalId != fetchedOrder.ExternalId || order.AddedDateTime != fetchedOrder.AddedDateTime || order.ModifiedDateTime != fetchedOrder.ModifiedDateTime) + { + ordersAreMatched = false; + break; + } + } + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(orders.Count == fetchedOrders.Count, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(ordersAreMatched, "The orders from BulkFetchAsync() should match what is retrieved from DbContext"); + } + [TestMethod] + public async Task With_Enum() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25m).ToList(); + var fetchedProducts = (await dbContext.Products.BulkFetchAsync(products)).ToList(); + bool productsAreMatched = true; + + foreach (var fetchedProduct in fetchedProducts) + { + var product = products.First(o => o.Id == fetchedProduct.Id); + if (product.Id != fetchedProduct.Id || product.Name != fetchedProduct.Name || product.StatusEnum != fetchedProduct.StatusEnum) + { + productsAreMatched = false; + break; + } + } + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(productsAreMatched, "The products from BulkFetchAsync() should match what is retrieved from DbContext"); + } + [TestMethod] + public async Task With_IQueryable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId != null); + var fetchedOrders = (await dbContext.Orders.BulkFetchAsync(orders, options => { options.IgnoreColumns = o => new { o.ExternalId }; })).ToList(); + bool foundNonNullExternalId = fetchedOrders.Any(o => o.ExternalId != null); + + Assert.IsTrue(orders.Count() > 0, "There must be orders in the database that match condition (Price <= 10 And ExternalId != null)"); + Assert.IsTrue(orders.Count() == fetchedOrders.Count, "The number of orders must match the number of fetched orders"); + Assert.IsFalse(foundNonNullExternalId, "Fetched orders should not contain any items where ExternalId is not null."); + } + [TestMethod] + public async Task With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId != null).ToList(); + var fetchedOrders = (await dbContext.Orders.BulkFetchAsync(orders, options => { options.IgnoreColumns = o => new { o.ExternalId }; })).ToList(); + bool foundNonNullExternalId = fetchedOrders.Any(o => o.ExternalId != null); + + Assert.IsTrue(orders.Count() > 0, "There must be orders in the database that match condition (Price <= 10 And ExternalId != null)"); + Assert.IsTrue(orders.Count() == fetchedOrders.Count, "The number of orders must match the number of fetched orders"); + Assert.IsFalse(foundNonNullExternalId, "Fetched orders should not contain any items where ExternalId is not null."); + } + [TestMethod] + public async Task With_ValueConverter() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).ToList(); + var fetchedProducts = (await dbContext.Products.BulkFetchAsync(products)).ToList(); + bool areMatched = true; + + foreach (var fetchedProduct in fetchedProducts) + { + var product = products.First(o => o.Id == fetchedProduct.Id); + if (product.Name != fetchedProduct.Name || product.Price != fetchedProduct.Price + || product.Color != fetchedProduct.Color) + { + areMatched = false; + break; + } + } + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(areMatched, "The products from BulkFetchAsync() should match what is retrieved from DbContext"); + } +} diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkInsert.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkInsert.cs new file mode 100644 index 0000000..a5c90fb --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkInsert.cs @@ -0,0 +1,458 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkInsert : DbContextExtensionsBase +{ + [TestMethod] + public void With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = new List(); + for (int i = 50000; i < 60000; i++) + { + var key = i.ToString(); + products.Add(new ProductWithComplexKey { Price = 1.57M }); + } + int oldTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(products); + int newTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public void With_Complex_Type() + { + var dbContext = SetupDbContext(true); + var orders = new List(); + for (int i = 1; i < 1000; i++) + { + orders.Add(new OrderWithComplexType + { + Id = i, + ShippingAddress = new Address + { + Line1 = $"123 Main St, {i}", + City = "Atlanta", + Country = "USA", + PostCode = "30303" + }, + BillingAddress = new Address + { + Line1 = $"456 Oak St, {i}", + City = "Atlanta", + Country = "USA", + PostCode = "30303" + } + }); + } + int oldTotal = dbContext.OrdersWithComplexType.Count(); + int rowsInserted = dbContext.BulkInsert(orders); + int newTotal = dbContext.OrdersWithComplexType.Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TpcCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TpcVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TpcPeople.Count(); + int customerRowsInserted = dbContext.BulkInsert(customers, o => o.UsePermanentTable = true); + int vendorRowsInserted = dbContext.BulkInsert(vendors, o => o.UsePermanentTable = true); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TpcPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TphCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "404-555-1111", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TphVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Phone = "404-555-2222", + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TphPeople.Count(); + int customerRowsInserted = dbContext.BulkInsert(customers); + int vendorRowsInserted = dbContext.BulkInsert(vendors); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TphPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TptCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "777-555-1234", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TptVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TptPeople.Count(); + int customerRowsInserted = dbContext.BulkInsert(customers, o => o.UsePermanentTable = true); + int vendorRowsInserted = dbContext.BulkInsert(vendors); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TptPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public void Without_Identity_Column() + { + var dbContext = SetupDbContext(true); + var products = new List(); + for (int i = 50000; i < 60000; i++) + { + products.Add(new Product { Id = i.ToString(), Price = 1.57M }); + } + int oldTotal = dbContext.Products.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(products); + int newTotal = dbContext.Products.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public void With_Options_AutoMapIdentity() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 5000; i++) + { + orders.Add(new Order { ExternalId = i.ToString(), Price = ((decimal)i + 0.55M) }); + } + int rowsAdded = dbContext.BulkInsert(orders, new BulkInsertOptions + { + UsePermanentTable = true + }); + bool autoMapIdentityMatched = true; + var ordersInDb = dbContext.Orders.ToList(); + Order order1 = null; + Order order2 = null; + foreach (var order in orders) + { + order1 = order; + var orderinDb = ordersInDb.First(o => o.Id == order.Id); + order2 = orderinDb; + if (!(orderinDb.ExternalId == order.ExternalId && orderinDb.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(rowsAdded == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up"); + } + [TestMethod] + public void With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, ExternalId = i.ToString(), Price = 1.57M, Active = true }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + int rowsInserted = dbContext.BulkInsert(orders, options => { options.UsePermanentTable = true; options.IgnoreColumns = o => new { o.ExternalId }; }); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public void With_Options_InputColumns() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, ExternalId = i.ToString(), Price = 1.57M, Active = true, Status = OrderStatus.Completed }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price == 1.57M && o.ExternalId == null && o.Active == true).Count(); + int rowsInserted = dbContext.BulkInsert(orders, options => + { + options.UsePermanentTable = true; + options.InputColumns = o => new { o.Price, o.Active, o.AddedDateTime, o.Status }; + }); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.57M && o.ExternalId == null && o.Active == true).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public void With_KeepIdentity() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i + 1000, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Count(); + int rowsInserted = dbContext.BulkInsert(orders, options => { options.KeepIdentity = true; options.BatchSize = 1000; }); + var oldOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool allIdentityFieldsMatch = true; + for (int i = 0; i < 20000; i++) + { + if (newOrders[i].Id != oldOrders[i].Id) + { + allIdentityFieldsMatch = false; + break; + } + } + try + { + int rowsInserted2 = dbContext.BulkInsert(orders, new BulkInsertOptions() + { + KeepIdentity = true, + BatchSize = 1000, + }); + } + catch (Exception ex) + { + Assert.IsTrue(Config.IsPrimaryKeyViolation(ex)); + } + + Assert.IsTrue(oldTotal == 0, "There should not be any records in the table"); + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(allIdentityFieldsMatch, "The identities between the source and the database should match."); + } + [TestMethod] + public void With_Schema() + { + var dbContext = SetupDbContext(false); + var products = new List(); + for (int i = 1; i < 10000; i++) + { + var key = i.ToString(); + products.Add(new ProductWithCustomSchema + { + Id = key, + Name = $"Product-{key}", + Price = 1.57M + }); + } + int oldTotal = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(products); + int newTotal = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [DoNotParallelize] + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted, newTotal; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsInserted = dbContext.BulkInsert(orders); + newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + transaction.Rollback(); + } + int rollbackTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + Assert.IsTrue(rollbackTotal == oldTotal, "The number of rows after the transacation has been rollbacked should match the original count"); + } + [TestMethod] + public void With_Options_InsertIfNotExists() + { + var dbContext = SetupDbContext(true); + var orders = new List(); + long maxId = dbContext.Orders.Max(o => o.Id); + long expectedRowsInserted = 1000; + int existingRowsToAdd = 100; + long startId = maxId - existingRowsToAdd + 1, endId = maxId + expectedRowsInserted + 1; + for (long i = startId; i < endId; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(orders, new BulkInsertOptions() { InsertIfNotExists = true }); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == expectedRowsInserted, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == expectedRowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public void With_Proxy_Type() + { + var dbContext = SetupDbContext(false); + int oldTotalCount = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + + var products = new List(); + for (int i = 0; i < 2000; i++) + { + var product = dbContext.Products.CreateProxy(); + product.Id = (-i).ToString(); + product.Price = 10.57M; + products.Add(product); + } + int oldTotal = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + int rowsInserted = dbContext.BulkInsert(products); + int newTotal = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of products list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public void With_Trigger() + { + var dbContext = SetupDbContext(false); + var products = new List(); + for (int i = 1; i < 1000; i++) + { + products.Add(new ProductWithTrigger { Id = i.ToString(), Price = 1.57M, StatusString = "InStock" }); + } + + //The return int from BulkInsert() will be off when using triggers + dbContext.BulkInsert(products, options => + { + options.AutoMapOutput = false; + }); + var rowsInserted = dbContext.ProductsWithTrigger.Count(); + + Assert.IsTrue(rowsInserted == products.Count, $"The number of rows inserted must match the count of products ({rowsInserted}!={products.Count})"); + } + [TestMethod] + public void With_ValueGenerated_Default() + { + var dbContext = SetupDbContext(false); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbAddedDateTime > nowDateTime && o.DbActive).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public void With_ValueGenerated_Computed() + { + var dbContext = SetupDbContext(false); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbModifiedDateTime > nowDateTime).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } +} diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkInsertAsync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkInsertAsync.cs new file mode 100644 index 0000000..f9e7d9b --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkInsertAsync.cs @@ -0,0 +1,477 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkInsertAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = new List(); + for (int i = 50000; i < 60000; i++) + { + var key = i.ToString(); + products.Add(new ProductWithComplexKey { Price = 1.57M }); + } + int oldTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(products); + int newTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public async Task With_Complex_Type() + { + var dbContext = SetupDbContext(true); + var orders = new List(); + for (int i = 1; i < 1000; i++) + { + orders.Add(new OrderWithComplexType + { + Id = i, + ShippingAddress = new Address + { + Line1 = $"123 Main St, {i}", + City = "Atlanta", + Country = "USA", + PostCode = "30303" + }, + BillingAddress = new Address + { + Line1 = $"456 Oak St, {i}", + City = "Atlanta", + Country = "USA", + PostCode = "30303" + } + }); + } + int oldTotal = dbContext.OrdersWithComplexType.Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders); + int newTotal = dbContext.OrdersWithComplexType.Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + //[TestMethod] + //public async Task With_IEnumerable() + //{ + // var dbContext = SetupDbContext(false); + // var orders = dbContext.Orders.Where(o => o.Price <= 10); + + // foreach(var order in orders) + // { + // order.Price = 15.75M; + // } + // int oldTotal = orders.Count(); + // int rowsInserted = await dbContext.BulkInsertAsync(orders); + // int newTotal = orders.Count(); + + // Assert.IsTrue(rowsInserted == oldTotal, "The number of rows inserted must match the count of order list"); + // Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + //} + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TpcCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TpcVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TpcPeople.Count(); + int customerRowsInserted = await dbContext.BulkInsertAsync(customers); + int vendorRowsInserted = await dbContext.BulkInsertAsync(vendors); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TpcPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TphCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "404-555-1111", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TphVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Phone = "404-555-2222", + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TphPeople.Count(); + int customerRowsInserted = await dbContext.BulkInsertAsync(customers); + int vendorRowsInserted = await dbContext.BulkInsertAsync(vendors); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TphPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TptCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "777-555-1234", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TptVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TptPeople.Count(); + int customerRowsInserted = await dbContext.BulkInsertAsync(customers, o => o.UsePermanentTable = true); + int vendorRowsInserted = await dbContext.BulkInsertAsync(vendors); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TptPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public async Task Without_Identity_Column() + { + var dbContext = SetupDbContext(true); + var products = new List(); + for (int i = 50000; i < 60000; i++) + { + products.Add(new Product { Id = i.ToString(), Price = 1.57M }); + } + int oldTotal = dbContext.Products.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(products); + int newTotal = dbContext.Products.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public async Task With_Options_AutoMapIdentity() + { + + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 5000; i++) + { + orders.Add(new Order { ExternalId = i.ToString(), Price = ((decimal)i + 0.55M) }); + } + int rowsAdded = await dbContext.BulkInsertAsync(orders, new BulkInsertOptions + { + UsePermanentTable = true + }); + bool autoMapIdentityMatched = true; + var ordersInDb = dbContext.Orders.ToList(); + Order order1 = null; + Order order2 = null; + foreach (var order in orders) + { + order1 = order; + var orderinDb = ordersInDb.First(o => o.Id == order.Id); + order2 = orderinDb; + if (!(orderinDb.ExternalId == order.ExternalId && orderinDb.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(rowsAdded == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up"); + } + [TestMethod] + public async Task With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, ExternalId = i.ToString(), Price = 1.57M, Active = true }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders, options => { options.UsePermanentTable = true; options.IgnoreColumns = o => new { o.ExternalId }; }); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public async Task With_Options_InputColumns() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, ExternalId = i.ToString(), Price = 1.57M, Active = true, Status = OrderStatus.Completed }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price == 1.57M && o.ExternalId == null && o.Active == true).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders, options => + { + options.UsePermanentTable = true; + options.InputColumns = o => new { o.Price, o.Active, o.AddedDateTime, o.Status }; + }); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.57M && o.ExternalId == null && o.Active == true).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public async Task With_KeepIdentity() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i + 1000, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders, options => { options.KeepIdentity = true; options.BatchSize = 1000; }); + var oldOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool allIdentityFieldsMatch = true; + for (int i = 0; i < 20000; i++) + { + if (newOrders[i].Id != oldOrders[i].Id) + { + allIdentityFieldsMatch = false; + break; + } + } + try + { + int rowsInserted2 = await dbContext.BulkInsertAsync(orders, new BulkInsertOptions() + { + KeepIdentity = true, + BatchSize = 1000, + }); + } + catch (Exception ex) + { + Assert.IsTrue(Config.IsPrimaryKeyViolation(ex)); + } + + Assert.IsTrue(oldTotal == 0, "There should not be any records in the table"); + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(allIdentityFieldsMatch, "The identities between the source and the database should match."); + } + [TestMethod] + public async Task With_Proxy_Type() + { + var dbContext = SetupDbContext(false); + int oldTotalCount = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + + var products = new List(); + for (int i = 0; i < 2000; i++) + { + var product = dbContext.Products.CreateProxy(); + product.Id = (-i).ToString(); + product.Price = 10.57M; + products.Add(product); + } + int oldTotal = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(products); + int newTotal = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of products list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public async Task With_Trigger() + { + var dbContext = SetupDbContext(false); + var products = new List(); + for (int i = 1; i < 1000; i++) + { + products.Add(new ProductWithTrigger { Id = i.ToString(), Price = 1.57M, StatusString = "InStock" }); + } + + //The return int from BulkInsertAsync() will be off when using triggers + await dbContext.BulkInsertAsync(products, options => + { + options.AutoMapOutput = false; + }); + var rowsInserted = dbContext.ProductsWithTrigger.Count(); + + Assert.IsTrue(rowsInserted == products.Count, $"The number of rows inserted must match the count of products ({rowsInserted}!={products.Count})"); + } + [TestMethod] + public async Task With_Schema() + { + var dbContext = SetupDbContext(false); + var products = new List(); + for (int i = 1; i < 10000; i++) + { + var key = i.ToString(); + products.Add(new ProductWithCustomSchema + { + Id = key, + Name = $"Product-{key}", + Price = 1.57M + }); + } + int oldTotal = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(products); + int newTotal = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [DoNotParallelize] + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted, newTotal; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsInserted = await dbContext.BulkInsertAsync(orders); + newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + transaction.Rollback(); + } + int rollbackTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + Assert.IsTrue(rollbackTotal == oldTotal, "The number of rows after the transacation has been rollbacked should match the original count"); + } + [TestMethod] + public async Task With_Options_InsertIfNotExists() + { + var dbContext = SetupDbContext(true); + var orders = new List(); + long maxId = dbContext.Orders.Max(o => o.Id); + long expectedRowsInserted = 1000; + int existingRowsToAdd = 100; + long startId = maxId - existingRowsToAdd + 1, endId = maxId + expectedRowsInserted + 1; + for (long i = startId; i < endId; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders, new BulkInsertOptions() { InsertIfNotExists = true }); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == expectedRowsInserted, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == expectedRowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public async Task With_ValueGenerated_Default() + { + var dbContext = SetupDbContext(false); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbAddedDateTime > nowDateTime && o.DbActive).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public async Task With_ValueGenerated_Computed() + { + var dbContext = SetupDbContext(false); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbModifiedDateTime > nowDateTime).Count(); + + Assert.IsTrue(rowsInserted == orders.Count(), "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of rows inserted."); + } +} diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkMerge.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkMerge.cs new file mode 100644 index 0000000..51d65ee --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkMerge.cs @@ -0,0 +1,407 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkMerge : DbContextExtensionsBase +{ + [TestMethod] + public void With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.ProductsWithComplexKey.Where(o => o.Price == 1.25M).ToList(); + int productsToAdd = 5000; + decimal updatedPrice = 5.25M; + var productsToUpdate = products.ToList(); + foreach (var product in products) + { + product.Price = updatedPrice; + } + for (int i = 0; i < productsToAdd; i++) + { + products.Add(new ProductWithComplexKey { ExternalId = (20000 + i).ToString(), Price = 3.55M }); + } + var result = dbContext.BulkMerge(products); + var allProducts = dbContext.ProductsWithComplexKey.ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var product in allProducts) + { + if (productsToUpdate.Contains(product) && product.Price != updatedPrice) + { + areUpdatedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == products.Count(), "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == productsToUpdate.Count, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == productsToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = dbContext.BulkMerge(orders, o => o.UsePermanentTable = true); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 10000).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true); + var customers = dbContext.TpcPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMerge_Tpc_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TpcCustomer + { + Id = 10000 + i, + FirstName = "BulkMerge_Tpc_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkMerge(customers, options => { options.MergeOnCondition = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TpcPeople.Where(o => o.FirstName == "BulkMerge_Tpc_Add").OfType().Count(); + int customersUpdated = dbContext.TpcPeople.Where(o => o.FirstName == "BulkMerge_Tpc_Update").OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count, "The number of rows inserted must match the count of customer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true); + var customers = dbContext.TphPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMerge_Tph_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TphCustomer + { + Id = 10000 + i, + FirstName = "BulkMerge_Tph_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkMerge(customers); + int customersAdded = dbContext.TphPeople.Where(o => o.FirstName == "BulkMerge_Tph_Add").OfType().Count(); + int customersUpdated = dbContext.TphPeople.Where(o => o.FirstName == "BulkMerge_Tph_Update").OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count(), "The number of rows inserted must match the count of customer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMerge_Tpt_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TptCustomer + { + Id = 10000 + i, + FirstName = "BulkMerge_Tpt_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkMerge(customers); + int customersAdded = dbContext.TptPeople.Where(o => o.FirstName == "BulkMerge_Tpt_Add").OfType().Count(); + int customersUpdated = dbContext.TptPeople.Where(o => o.FirstName == "BulkMerge_Tpt_Update").OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count(), "The number of rows inserted must match the count of customer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public void With_Default_Options_MergeOnCondition() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 50; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = dbContext.BulkMerge(orders, options => { options.MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId; options.BatchSize = 1000; }); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public void With_Options_AutoMapIdentity() + { + var dbContext = SetupDbContext(true); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = dbContext.BulkMerge(orders, new BulkMergeOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + UsePermanentTable = true + }); + bool autoMapIdentityMatched = true; + foreach (var order in orders) + { + if (!dbContext.Orders.Any(o => o.ExternalId == order.ExternalId && o.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == ordersToAdd + ordersToUpdate, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up"); + } + [TestMethod] + public void With_Options_AutoMapOutput() + { + var dbContext = SetupDbContext(true); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = dbContext.BulkMerge(orders, new BulkMergeOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + AutoMapOutput = true + }); + var autoMapIdentityMatched = orders.All(x => x.Id != 0); + + Assert.IsTrue(result.RowsAffected == ordersToAdd + ordersToUpdate, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up"); + } + [TestMethod] + public void With_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + int productsToAdd = 5000; + var productsToUpdate = products.ToList(); + foreach (var product in products) + { + product.Price = Convert.ToDecimal(product.Id) + .25M; + } + for (int i = 0; i < productsToAdd; i++) + { + products.Add(new Product { Id = (20000 + i).ToString(), Price = 3.55M }); + } + var result = dbContext.BulkMerge(products); + var newProducts = dbContext.Products.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newProduct in newProducts.Where(o => productsToUpdate.Select(o => o.Id).Contains(o.Id))) + { + if (newProduct.Price != Convert.ToDecimal(newProduct.Id) + .25M) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newProduct in newProducts.Where(o => Convert.ToInt32(o.Id) >= 20000).OrderBy(o => o.Id)) + { + if (newProduct.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == products.Count(), "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == productsToUpdate.Count, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == productsToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [DoNotParallelize] + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + BulkMergeResult result; + using (var transaction = dbContext.Database.BeginTransaction()) + { + result = dbContext.BulkMerge(orders); + transaction.Rollback(); + } + int ordersUpdated = dbContext.Orders.Count(o => o.Id <= 10000 && o.Price == ((decimal)o.Id + .25M) && o.Price != 1.25M); + int ordersAdded = dbContext.Orders.Count(o => o.Id >= 100000); + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(ordersAdded == 0, "The number of rows added must equal 0 since transaction was rollbacked"); + Assert.IsTrue(ordersUpdated == 0, "The number of rows updated must equal 0 since transaction was rollbacked"); + } + [TestMethod] + public void With_ValueGenerated_Default() + { + var dbContext = SetupDbContext(false); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 1000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.DbAddedDateTime > nowDateTime).Count(); + var mergeResult = dbContext.BulkMerge(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 1.57M + && o.DbAddedDateTime > nowDateTime).Count(); + + Assert.IsTrue(mergeResult.RowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == mergeResult.RowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public void With_ValueGenerated_Computed() + { + var dbContext = SetupDbContext(false); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + var result = dbContext.BulkMerge(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbModifiedDateTime > nowDateTime).Count(); + + Assert.IsTrue(result.RowsInserted == orders.Count(), "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == result.RowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public void With_Merge_On_Enum() + { + var dbContext = SetupDbContext(true); + dbContext.BulkSaveChanges(); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime, Status = OrderStatus.Completed }); + } + + var result = dbContext.BulkMerge(orders, options => options.MergeOnCondition = (s, t) => s.Id == t.Id && s.Status == t.Status); + + Assert.AreEqual(1, result.RowsInserted); + Assert.AreEqual(19, result.RowsUpdated); + } +} diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkMergeAsync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkMergeAsync.cs new file mode 100644 index 0000000..14592ba --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkMergeAsync.cs @@ -0,0 +1,408 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkMergeAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.ProductsWithComplexKey.Where(o => o.Price == 1.25M).ToList(); + int productsToAdd = 5000; + decimal updatedPrice = 5.25M; + var productsToUpdate = products.ToList(); + foreach (var product in products) + { + product.Price = updatedPrice; + } + for (int i = 0; i < productsToAdd; i++) + { + products.Add(new ProductWithComplexKey { ExternalId = (20000 + i).ToString(), Price = 3.55M }); + } + var result = await dbContext.BulkMergeAsync(products); + var allProducts = dbContext.ProductsWithComplexKey.ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var product in allProducts) + { + if (productsToUpdate.Contains(product) && product.Price != updatedPrice) + { + areUpdatedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == products.Count(), "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == productsToUpdate.Count, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == productsToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = await dbContext.BulkMergeAsync(orders); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 10000).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true); + var customers = dbContext.TpcPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMerge_Tpc_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TpcCustomer + { + Id = 10000 + i, + FirstName = "BulkMerge_Tpc_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkMergeAsync(customers, options => { options.MergeOnCondition = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TpcPeople.Where(o => o.FirstName == "BulkMerge_Tpc_Add").OfType().Count(); + int customersUpdated = dbContext.TpcPeople.Where(o => o.FirstName == "BulkMerge_Tpc_Update").OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count, "The number of rows inserted must match the count of customer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true); + var customers = dbContext.TphPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMerge_Tph_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TphCustomer + { + Id = 10000 + i, + FirstName = "BulkMerge_Tph_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkMergeAsync(customers); + int customersAdded = dbContext.TphPeople.Where(o => o.FirstName == "BulkMerge_Tph_Add").OfType().Count(); + int customersUpdated = dbContext.TphPeople.Where(o => o.FirstName == "BulkMerge_Tph_Update").OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count(), "The number of rows inserted must match the count of customer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMergeAsync_Tpt_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TptCustomer + { + Id = 10000 + i, + FirstName = "BulkMergeAsync_Tpt_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkMergeAsync(customers); + int customersAdded = dbContext.TptPeople.Where(o => o.FirstName == "BulkMergeAsync_Tpt_Add").OfType().Count(); + int customersUpdated = dbContext.TptPeople.Where(o => o.FirstName == "BulkMergeAsync_Tpt_Update").OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count(), "The number of rows inserted must match the count of customer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Default_Options_MergeOnCondition() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 50; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = await dbContext.BulkMergeAsync(orders, options => { options.MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId; options.BatchSize = 1000; }); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Options_AutoMapIdentity() + { + var dbContext = SetupDbContext(true); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = await dbContext.BulkMergeAsync(orders, new BulkMergeOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + UsePermanentTable = true + }); + bool autoMapIdentityMatched = true; + foreach (var order in orders) + { + if (!dbContext.Orders.Any(o => o.ExternalId == order.ExternalId && o.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == ordersToAdd + ordersToUpdate, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up"); + } + [TestMethod] + public async Task With_Options_AutoMapOutput() + { + var dbContext = SetupDbContext(true); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = await dbContext.BulkMergeAsync(orders, new BulkMergeOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + AutoMapOutput = true + }); + var autoMapIdentityMatched = orders.All(x => x.Id != 0); + + Assert.IsTrue(result.RowsAffected == ordersToAdd + ordersToUpdate, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up"); + } + [TestMethod] + public async Task With_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + int productsToAdd = 5000; + var productsToUpdate = products.ToList(); + foreach (var product in products) + { + product.Price = Convert.ToDecimal(product.Id) + .25M; + } + for (int i = 0; i < productsToAdd; i++) + { + products.Add(new Product { Id = (20000 + i).ToString(), Price = 3.55M }); + } + var result = await dbContext.BulkMergeAsync(products); + var newProducts = dbContext.Products.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newProduct in newProducts.Where(o => productsToUpdate.Select(o => o.Id).Contains(o.Id))) + { + if (newProduct.Price != Convert.ToDecimal(newProduct.Id) + .25M) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newProduct in newProducts.Where(o => Convert.ToInt32(o.Id) >= 20000).OrderBy(o => o.Id)) + { + if (newProduct.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == products.Count(), "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == productsToUpdate.Count, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == productsToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [DoNotParallelize] + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + BulkMergeResult result; + using (var transaction = dbContext.Database.BeginTransaction()) + { + result = await dbContext.BulkMergeAsync(orders); + transaction.Rollback(); + } + int ordersUpdated = dbContext.Orders.Count(o => o.Id <= 10000 && o.Price == ((decimal)o.Id + .25M) && o.Price != 1.25M); + int ordersAdded = dbContext.Orders.Count(o => o.Id >= 100000); + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(ordersAdded == 0, "The number of rows added must equal 0 since transaction was rollbacked"); + Assert.IsTrue(ordersUpdated == 0, "The number of rows updated must equal 0 since transaction was rollbacked"); + } + [TestMethod] + public async Task With_ValueGenerated_Default() + { + var dbContext = SetupDbContext(false); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 1000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.DbAddedDateTime > nowDateTime).Count(); + var mergeResult = await dbContext.BulkMergeAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 1.57M + && o.DbAddedDateTime > nowDateTime).Count(); + + Assert.IsTrue(mergeResult.RowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == mergeResult.RowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public async Task With_ValueGenerated_Computed() + { + var dbContext = SetupDbContext(false); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + var result = await dbContext.BulkMergeAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbModifiedDateTime > nowDateTime).Count(); + + Assert.IsTrue(result.RowsInserted == orders.Count(), "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == result.RowsInserted, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public async Task With_Merge_On_Enum() + { + var dbContext = SetupDbContext(true); + await dbContext.BulkSaveChangesAsync(); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime, Status = OrderStatus.Completed }); + } + + var result = await dbContext.BulkMergeAsync(orders, options => options.MergeOnCondition = (s, t) => s.Id == t.Id && s.Status == t.Status); + + Assert.AreEqual(1, result.RowsInserted); + Assert.AreEqual(19, result.RowsUpdated); + } +} diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkSaveChanges.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkSaveChanges.cs new file mode 100644 index 0000000..d2f9fbc --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkSaveChanges.cs @@ -0,0 +1,261 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkSaveChanges : DbContextExtensionsBase +{ + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var totalCount = dbContext.Orders.Count(); + + //Add new orders + var ordersToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + ordersToAdd.Add(new Order { Id = -i, Price = 10.57M }); + } + dbContext.Orders.AddRange(ordersToAdd); + + //Delete orders + var ordersToDelete = dbContext.Orders.Where(o => o.Price <= 5).ToList(); + dbContext.Orders.RemoveRange(ordersToDelete); + + //Update existing orders + var ordersToUpdate = dbContext.Orders.Where(o => o.Price > 5 && o.Price <= 10).ToList(); + foreach (var orderToUpdate in ordersToUpdate) + { + orderToUpdate.Price = 99M; + } + + + int rowsAffected = dbContext.BulkSaveChanges(); + int ordersAddedCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + int ordersDeletedCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + int ordersUpdatedCount = dbContext.Orders.Where(o => o.Price == 99M).Count(); + + Assert.IsTrue(rowsAffected == ordersToAdd.Count + ordersToDelete.Count + ordersToUpdate.Count, "The number of rows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(ordersAddedCount == ordersToAdd.Count(), "The number of orders to add did not match what was expected."); + Assert.IsTrue(ordersDeletedCount == 0, "The number of orders that was deleted did not match what was expected."); + Assert.IsTrue(ordersUpdatedCount == ordersToUpdate.Count(), "The number of orders that was updated did not match what was expected."); + } + [TestMethod] + public void With_Add_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + + var ordersToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + ordersToAdd.Add(new Order { Id = -i, Price = 10.57M }); + } + dbContext.Orders.AddRange(ordersToAdd); + + int rowsAffected = dbContext.BulkSaveChanges(); + int newTotalCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + + Assert.IsTrue(ordersToAdd.Where(o => o.Id <= 0).Count() == 0, "Primary key should have been updated for all entities"); + Assert.IsTrue(rowsAffected == ordersToAdd.Count, "The number of rows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(oldTotalCount + ordersToAdd.Count == newTotalCount, "The number of orders to add did not match what was expected."); + } + [TestMethod] + public void With_Delete_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + + //Delete orders + var ordersToDelete = dbContext.Orders.Where(o => o.Price <= 5).ToList(); + dbContext.Orders.RemoveRange(ordersToDelete); + + int rowsAffected = dbContext.BulkSaveChanges(); + int newTotalCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + + Assert.IsTrue(rowsAffected == ordersToDelete.Count, "The number of rows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(oldTotalCount - ordersToDelete.Count == newTotalCount, "The number of orders to add did not match what was expected."); + } + [TestMethod] + public void With_Update_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + //Update existing orders + var ordersToUpdate = dbContext.Orders.Where(o => o.Price <= 10).ToList(); + foreach (var orderToUpdate in ordersToUpdate) + { + orderToUpdate.Price = 99M; + } + + int rowsAffected = dbContext.BulkSaveChanges(); + int newTotalCount = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int expectedCount = dbContext.Orders.Where(o => o.Price == 99M).Count(); + + Assert.IsTrue(rowsAffected == ordersToUpdate.Count, "The number of rows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(oldTotalCount - ordersToUpdate.Count == newTotalCount, "The number of orders to add did not match what was expected."); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + //Delete Customers + var customersToDelete = dbContext.TpcPeople.OfType().Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TpcPeople.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TpcPeople.OfType().Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TpcPeople.OfType().Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TpcPeople.Add(new TpcCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = dbContext.BulkSaveChanges(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TpcPeople.OfType().Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected."); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + //Delete Customers + var customersToDelete = dbContext.TphCustomers.Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TphCustomers.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TphCustomers.Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TphPeople.Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TphCustomers.Add(new TphCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = dbContext.BulkSaveChanges(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TphCustomers.Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(expectedRowsDeleted > 0, "The expected number of rows to delete must be greater than zero."); + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected."); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + //Delete Customers + var customersToDelete = dbContext.TptCustomers.Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TptCustomers.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TptCustomers.Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.Email = "name@domain.com"; + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TptPeople.Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TptCustomers.Add(new TptCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = dbContext.BulkSaveChanges(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TptCustomers.Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected."); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public void With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + var totalCount = dbContext.ProductsWithCustomSchema.Count(); + + //Add new products + var productsToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + productsToAdd.Add(new ProductWithCustomSchema { Id = (-i).ToString(), Price = 10.57M }); + } + dbContext.ProductsWithCustomSchema.AddRange(productsToAdd); + + //Delete products + var productsToDelete = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 5).ToList(); + dbContext.ProductsWithCustomSchema.RemoveRange(productsToDelete); + + //Update existing products + var productsToUpdate = dbContext.ProductsWithCustomSchema.Where(o => o.Price > 5 && o.Price <= 10).ToList(); + foreach (var productToUpdate in productsToUpdate) + { + productToUpdate.Price = 99M; + } + + int rowsAffected = dbContext.BulkSaveChanges(); + int productsAddedCount = dbContext.ProductsWithCustomSchema.Where(o => o.Price == 10.57M).Count(); + int productsDeletedCount = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 5).Count(); + int productsUpdatedCount = dbContext.ProductsWithCustomSchema.Where(o => o.Price == 99M).Count(); + + Assert.IsTrue(rowsAffected == productsToAdd.Count + productsToDelete.Count + productsToUpdate.Count, "The number of rows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(productsAddedCount == productsToAdd.Count(), "The number of products to add did not match what was expected."); + Assert.IsTrue(productsDeletedCount == 0, "The number of products that was deleted did not match what was expected."); + Assert.IsTrue(productsUpdatedCount == productsToUpdate.Count(), "The number of products that was updated did not match what was expected."); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkSaveChangesAsync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkSaveChangesAsync.cs new file mode 100644 index 0000000..367c793 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkSaveChangesAsync.cs @@ -0,0 +1,263 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkSaveChangesAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var totalCount = dbContext.Orders.Count(); + + //Add new orders + var ordersToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + ordersToAdd.Add(new Order { Id = -i, Price = 10.57M }); + } + dbContext.Orders.AddRange(ordersToAdd); + + //Delete orders + var ordersToDelete = dbContext.Orders.Where(o => o.Price <= 5).ToList(); + dbContext.Orders.RemoveRange(ordersToDelete); + + //Update existing orders + var ordersToUpdate = dbContext.Orders.Where(o => o.Price > 5 && o.Price <= 10).ToList(); + foreach (var orderToUpdate in ordersToUpdate) + { + orderToUpdate.Price = 99M; + } + + + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int ordersAddedCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + int ordersDeletedCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + int ordersUpdatedCount = dbContext.Orders.Where(o => o.Price == 99M).Count(); + + Assert.IsTrue(rowsAffected == ordersToAdd.Count + ordersToDelete.Count + ordersToUpdate.Count, "The number of rows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(ordersAddedCount == ordersToAdd.Count(), "The number of orders to add did not match what was expected."); + Assert.IsTrue(ordersDeletedCount == 0, "The number of orders that was deleted did not match what was expected."); + Assert.IsTrue(ordersUpdatedCount == ordersToUpdate.Count(), "The number of orders that was updated did not match what was expected."); + } + [TestMethod] + public async Task With_Add_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + + var ordersToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + ordersToAdd.Add(new Order { Id = -i, Price = 10.57M }); + } + dbContext.Orders.AddRange(ordersToAdd); + + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int newTotalCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + + Assert.IsTrue(ordersToAdd.Where(o => o.Id <= 0).Count() == 0, "Primary key should have been updated for all entities"); + Assert.IsTrue(rowsAffected == ordersToAdd.Count, "The number of rows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(oldTotalCount + ordersToAdd.Count == newTotalCount, "The number of orders to add did not match what was expected."); + } + [TestMethod] + public async Task With_Delete_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + + //Delete orders + var ordersToDelete = dbContext.Orders.Where(o => o.Price <= 5).ToList(); + dbContext.Orders.RemoveRange(ordersToDelete); + + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int newTotalCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + + Assert.IsTrue(rowsAffected == ordersToDelete.Count, "The number of rows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(oldTotalCount - ordersToDelete.Count == newTotalCount, "The number of orders to add did not match what was expected."); + } + [TestMethod] + public async Task With_Update_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + //Update existing orders + var ordersToUpdate = dbContext.Orders.Where(o => o.Price <= 10).ToList(); + foreach (var orderToUpdate in ordersToUpdate) + { + orderToUpdate.Price = 99M; + } + + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int newTotalCount = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int expectedCount = dbContext.Orders.Where(o => o.Price == 99M).Count(); + + Assert.IsTrue(rowsAffected == ordersToUpdate.Count, "The number of rows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(oldTotalCount - ordersToUpdate.Count == newTotalCount, "The number of orders to add did not match what was expected."); + } + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + //Delete Customers + var customersToDelete = dbContext.TpcPeople.OfType().Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TpcPeople.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TpcPeople.OfType().Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TpcPeople.OfType().Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TpcPeople.Add(new TpcCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TpcPeople.OfType().Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected."); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + //Delete Customers + var customersToDelete = dbContext.TphCustomers.Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TphCustomers.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TphCustomers.Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TphPeople.Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TphCustomers.Add(new TphCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TphCustomers.Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(expectedRowsDeleted > 0, "The expected number of rows to delete must be greater than zero."); + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected."); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + //Delete Customers + var customersToDelete = dbContext.TptCustomers.Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TptCustomers.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TptCustomers.Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.Email = "name@domain.com"; + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TptPeople.Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TptCustomers.Add(new TptCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TptCustomers.Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected."); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number of rows inserted."); + } + [TestMethod] + public async Task With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + var totalCount = await dbContext.ProductsWithCustomSchema.CountAsync(); + + //Add new products + var productsToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + productsToAdd.Add(new ProductWithCustomSchema { Id = (-i).ToString(), Price = 10.57M }); + } + dbContext.ProductsWithCustomSchema.AddRange(productsToAdd); + + //Delete products + var productsToDelete = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 5).ToList(); + dbContext.ProductsWithCustomSchema.RemoveRange(productsToDelete); + + //Update existing products + var productsToUpdate = dbContext.ProductsWithCustomSchema.Where(o => o.Price > 5 && o.Price <= 10).ToList(); + foreach (var productToUpdate in productsToUpdate) + { + productToUpdate.Price = 99M; + } + + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int productsAddedCount = await dbContext.ProductsWithCustomSchema.Where(o => o.Price == 10.57M).CountAsync(); + int productsDeletedCount = await dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 5).CountAsync(); + int productsUpdatedCount = await dbContext.ProductsWithCustomSchema.Where(o => o.Price == 99M).CountAsync(); + + Assert.IsTrue(rowsAffected == productsToAdd.Count + productsToDelete.Count + productsToUpdate.Count, "The number of rows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(productsAddedCount == productsToAdd.Count(), "The number of products to add did not match what was expected."); + Assert.IsTrue(productsDeletedCount == 0, "The number of products that was deleted did not match what was expected."); + Assert.IsTrue(productsUpdatedCount == productsToUpdate.Count(), "The number of products that was updated did not match what was expected."); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkSync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkSync.cs new file mode 100644 index 0000000..3240ca5 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkSync.cs @@ -0,0 +1,241 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkSync : DbContextExtensionsBase +{ + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = dbContext.BulkSync(orders); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 10000).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must match the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = dbContext.TpcPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TpcPeople.OfType().Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tpc_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TpcCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tpc_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkSync(customers, options => { options.MergeOnCondition = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TpcPeople.Where(o => o.FirstName == "BulkSync_Tpc_Add").OfType().Count(); + int customersUpdated = dbContext.TpcPeople.Where(o => o.FirstName == "BulkSync_Tpc_Update").OfType().Count(); + int newCustomerTotal = dbContext.TpcPeople.OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database should match the sum of customers added and updated."); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = dbContext.TphCustomers.Where(o => o.Id <= 1000).ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TphPeople.Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tph_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TphCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tph_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkSync(customers, options => { options.UsePermanentTable = true; options.MergeOnCondition = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TphCustomers.Where(o => o.FirstName == "BulkSync_Tph_Add").Count(); + int customersUpdated = dbContext.TphCustomers.Where(o => o.FirstName == "BulkSync_Tph_Update").Count(); + int newCustomerTotal = dbContext.TphCustomers.Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The customers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database should match the sum of customers added and updated."); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TptCustomers.Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tpt_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TptCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tpt_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkSync(customers, options => { options.MergeOnCondition = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TptPeople.Where(o => o.FirstName == "BulkSync_Tpt_Add").OfType().Count(); + int customersUpdated = dbContext.TptPeople.Where(o => o.FirstName == "BulkSync_Tpt_Update").OfType().Count(); + int newCustomerTotal = dbContext.TptPeople.OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database should match the sum of customers added and updated."); + } + [TestMethod] + public void With_Options_AutoMapIdentity() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = dbContext.BulkSync(orders, options => { options.MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId; options.UsePermanentTable = true; }); + bool autoMapIdentityMatched = true; + foreach (var order in orders) + { + if (!dbContext.Orders.Any(o => o.ExternalId == order.ExternalId && o.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must match the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up"); + } + [TestMethod] + public void With_Options_MergeOnCondition() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 50; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = dbContext.BulkSync(orders, new BulkSyncOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + BatchSize = 1000 + }); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must match the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkSyncAsync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkSyncAsync.cs new file mode 100644 index 0000000..1108a0a --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkSyncAsync.cs @@ -0,0 +1,243 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkSyncAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = await dbContext.BulkSyncAsync(orders); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 10000).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must match the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = await dbContext.TpcPeople.Where(o => o.Id <= 1000).OfType().ToListAsync(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TpcPeople.OfType().Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tpc_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TpcCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tpc_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkSyncAsync(customers, options => { options.MergeOnCondition = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TpcPeople.Where(o => o.FirstName == "BulkSync_Tpc_Add").OfType().Count(); + int customersUpdated = dbContext.TpcPeople.Where(o => o.FirstName == "BulkSync_Tpc_Update").OfType().Count(); + int newCustomerTotal = dbContext.TpcPeople.OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database should match the sum of customers added and updated."); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = await dbContext.TphCustomers.Where(o => o.Id <= 1000).ToListAsync(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TphPeople.Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tph_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TphCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tph_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkSyncAsync(customers, options => { options.UsePermanentTable = true; options.MergeOnCondition = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TphCustomers.Where(o => o.FirstName == "BulkSync_Tph_Add").Count(); + int customersUpdated = dbContext.TphCustomers.Where(o => o.FirstName == "BulkSync_Tph_Update").Count(); + int newCustomerTotal = dbContext.TphCustomers.Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The customers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database should match the sum of customers added and updated."); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = await dbContext.TptPeople.Where(o => o.Id <= 1000).OfType().ToListAsync(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TptCustomers.Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tpt_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TptCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tpt_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkSyncAsync(customers, options => { options.MergeOnCondition = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TptPeople.Where(o => o.FirstName == "BulkSync_Tpt_Add").OfType().Count(); + int customersUpdated = dbContext.TptPeople.Where(o => o.FirstName == "BulkSync_Tpt_Update").OfType().Count(); + int newCustomerTotal = dbContext.TptPeople.OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database should match the sum of customers added and updated."); + } + [TestMethod] + public async Task With_Options_AutoMapIdentity() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = await dbContext.BulkSyncAsync(orders, options => { options.MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId; options.UsePermanentTable = true; }); + bool autoMapIdentityMatched = true; + foreach (var order in orders) + { + if (!dbContext.Orders.Any(o => o.ExternalId == order.ExternalId && o.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must match the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up"); + } + [TestMethod] + public async Task With_Options_MergeOnCondition() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 50; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = await dbContext.BulkSyncAsync(orders, new BulkSyncOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + BatchSize = 1000 + }); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must match the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkUpdate.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkUpdate.cs new file mode 100644 index 0000000..952a06f --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkUpdate.cs @@ -0,0 +1,244 @@ +using System.Linq; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkUpdate : DbContextExtensionsBase +{ + [TestMethod] + public void With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.ProductsWithComplexKey.Where(o => o.Price == 1.25M).ToList(); + foreach (var product in products) + { + product.Price = 2.35M; + } + var oldTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price == 2.35M).Count(); + int rowsUpdated = dbContext.BulkUpdate(products); + var newTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price == 2.35M).Count(); + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == products.Count, "The number of rows updated must match the count of entities that were retrieved"); + Assert.IsTrue(newTotal == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows updated in the database."); + } + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + long maxId = 0; + foreach (var order in orders) + { + order.Price = 2.35M; + maxId = order.Id; + } + int rowsUpdated = dbContext.BulkUpdate(orders); + var newOrders = dbContext.Orders.Where(o => o.Price == 2.35M).OrderBy(o => o.Id).Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == orders.Count, "The number of rows updated must match the count of entities that were retrieved"); + Assert.IsTrue(newOrders == rowsUpdated, "The count of new orders must be equal the number of rows updated in the database."); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = dbContext.TpcPeople.Where(o => o.LastName != "BulkUpdate_Tpc").OfType().ToList(); + var vendors = dbContext.TpcPeople.OfType().ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdate_Tpc"; + } + int rowsUpdated = dbContext.BulkUpdate(customers); + var newCustomers = dbContext.TpcPeople.Where(o => o.LastName == "BulkUpdate_Tpc").OfType().Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(vendors.Count > 0 && vendors.Count != customers.Count, "There should be vendor records in the database"); + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated in the database."); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = dbContext.TphPeople.Where(o => o.LastName != "BulkUpdateTest").OfType().ToList(); + var vendors = dbContext.TphPeople.OfType().ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdateTest"; + } + int rowsUpdated = dbContext.BulkUpdate(customers); + var newCustomers = dbContext.TphPeople.Where(o => o.LastName == "BulkUpdateTest").OrderBy(o => o.Id).Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(vendors.Count > 0 && vendors.Count != customers.Count, "There should be vendor records in the database"); + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated in the database."); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptCustomers.Where(o => o.LastName != "BulkUpdateTest").ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdateTest"; + } + int rowsUpdated = dbContext.BulkUpdate(customers); + var newCustomers = dbContext.TptCustomers.Where(o => o.LastName == "BulkUpdateTest").OrderBy(o => o.Id).Count(); + //int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated in the database."); + } + [TestMethod] + public void With_Options_InputColumns_PropertyExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = dbContext.BulkUpdate(orders, options => { options.InputColumns = o => o.Price; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows updated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public void With_Options_InputColumns_NewExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = dbContext.BulkUpdate(orders, options => { options.InputColumns = o => new { o.Price }; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows updated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public void With_Options_IgnoreColumns_PropertyExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = dbContext.BulkUpdate(orders, options => { options.IgnoreColumns = o => o.ExternalId; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows updated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public void With_Options_IgnoreColumns_NewExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = dbContext.BulkUpdate(orders, options => { options.IgnoreColumns = o => new { o.ExternalId }; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows updated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public void With_Options_UpdateOnCondition() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + int ordersWithExternalId = orders.Where(o => o.ExternalId != null).Count(); + foreach (var order in orders) + { + order.Price = 2.35M; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = dbContext.BulkUpdate(orders, options => { options.UpdateOnCondition = (s, t) => s.ExternalId == t.ExternalId; }); + var newTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == ordersWithExternalId, "The number of rows updated must match the count of entities that were retrieved"); + Assert.IsTrue(newTotal == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows updated in the database."); + } + [TestMethod] + public void With_Options_UpdateOnCondition_Enum() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + foreach (var product in products) + { + product.Price = 2.35M; + } + int rowsUpdated = dbContext.BulkUpdate(products, o => + { + o.UpdateOnCondition = (s, t) => s.Id == t.Id && s.StatusEnum == t.StatusEnum; + }); + var newProducts = dbContext.Products.Where(o => o.Price == 2.35M).OrderBy(o => o.Id).Count(); + + Assert.IsTrue(products.Count > 0, "There must be products in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == products.Count, "The number of rows updated must match the count of entities that were retrieved"); + Assert.IsTrue(newProducts == rowsUpdated, "The count of new products must be equal the number of rows updated in the database."); + } + [DoNotParallelize] + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + long maxId = 0; + foreach (var order in orders) + { + order.Price = 2.35M; + maxId = order.Id; + } + int rowsUpdated, newOrders; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsUpdated = dbContext.BulkUpdate(orders); + newOrders = dbContext.Orders.Where(o => o.Price == 2.35M).Count(); + transaction.Rollback(); + } + int rollbackTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == orders.Count, "The number of rows updated must match the count of entities that were retrieved"); + Assert.IsTrue(newOrders == rowsUpdated, "The count of new orders must be equal the number of rows updated in the database."); + Assert.IsTrue(rollbackTotal == orders.Count, "The number of rows after the transacation has been rollbacked should match the original count"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkUpdateAsync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkUpdateAsync.cs new file mode 100644 index 0000000..e444ac0 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/BulkUpdateAsync.cs @@ -0,0 +1,244 @@ +using System.Linq; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkUpdateAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.ProductsWithComplexKey.Where(o => o.Price == 1.25M).ToList(); + foreach (var product in products) + { + product.Price = 2.35M; + } + var oldTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price == 2.35M).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(products); + var newTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price == 2.35M).Count(); + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == products.Count, "The number of rows updated must match the count of entities that were retrieved"); + Assert.IsTrue(newTotal == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows updated in the database."); + } + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + long maxId = 0; + foreach (var order in orders) + { + order.Price = 2.35M; + maxId = order.Id; + } + int rowsUpdated = await dbContext.BulkUpdateAsync(orders); + var newOrders = dbContext.Orders.Where(o => o.Price == 2.35M).OrderBy(o => o.Id).Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == orders.Count, "The number of rows updated must match the count of entities that were retrieved"); + Assert.IsTrue(newOrders == rowsUpdated, "The count of new orders must be equal the number of rows updated in the database."); + } + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = dbContext.TpcPeople.Where(o => o.LastName != "BulkUpdateTest").OfType().ToList(); + var vendors = dbContext.TpcPeople.OfType().ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdate_Tpc"; + } + int rowsUpdated = await dbContext.BulkUpdateAsync(customers, options => { options.UpdateOnCondition = (s, t) => s.Id == t.Id; }); + var newCustomers = dbContext.TpcPeople.Where(o => o.LastName == "BulkUpdate_Tpc").OfType().Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(vendors.Count > 0 && vendors.Count != customers.Count, "There should be vendor records in the database"); + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated in the database."); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = dbContext.TphPeople.Where(o => o.LastName != "BulkUpdateTest").OfType().ToList(); + var vendors = dbContext.TphPeople.OfType().ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdateTest"; + } + int rowsUpdated = await dbContext.BulkUpdateAsync(customers); + var newCustomers = dbContext.TphPeople.Where(o => o.LastName == "BulkUpdateTest").OrderBy(o => o.Id).Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(vendors.Count > 0 && vendors.Count != customers.Count, "There should be vendor records in the database"); + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated in the database."); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptCustomers.Where(o => o.LastName != "BulkUpdateTest").ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdateTest"; + } + int rowsUpdated = await dbContext.BulkUpdateAsync(customers); + var newCustomers = await dbContext.TptCustomers.Where(o => o.LastName == "BulkUpdateTest").CountAsync(); + + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated in the database."); + } + [TestMethod] + public async Task With_Options_InputColumns_PropertyExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(orders, options => { options.InputColumns = o => o.Price; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows updated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public async Task With_Options_InputColumns_NewExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(orders, options => { options.InputColumns = o => new { o.Price }; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count() > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows updated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public async Task With_Options_IgnoreColumns_PropertyExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(orders, options => { options.IgnoreColumns = o => o.ExternalId; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows updated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public async Task With_Options_IgnoreColumns_NewExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(orders, options => { options.IgnoreColumns = o => new { o.ExternalId }; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows updated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public async Task With_Options_UpdateOnCondition() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + int ordersWithExternalId = orders.Where(o => o.ExternalId != null).Count(); + foreach (var order in orders) + { + order.Price = 2.35M; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(orders, options => { options.UpdateOnCondition = (s, t) => s.ExternalId == t.ExternalId; }); + var newTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == ordersWithExternalId, "The number of rows updated must match the count of entities that were retrieved"); + Assert.IsTrue(newTotal == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows updated in the database."); + } + [TestMethod] + public async Task With_Options_UpdateOnCondition_Enum() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + foreach (var product in products) + { + product.Price = 2.35M; + } + int rowsUpdated = await dbContext.BulkUpdateAsync(products, o => + { + o.UpdateOnCondition = (s, t) => s.Id == t.Id && s.StatusEnum == t.StatusEnum; + }); + var newProducts = dbContext.Products.Where(o => o.Price == 2.35M).OrderBy(o => o.Id).Count(); + + Assert.IsTrue(products.Count > 0, "There must be products in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == products.Count, "The number of rows updated must match the count of entities that were retrieved"); + Assert.IsTrue(newProducts == rowsUpdated, "The count of new products must be equal the number of rows updated in the database."); + } + [DoNotParallelize] + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + long maxId = 0; + foreach (var order in orders) + { + order.Price = 2.35M; + maxId = order.Id; + } + int rowsUpdated, newOrders; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsUpdated = await dbContext.BulkUpdateAsync(orders); + newOrders = dbContext.Orders.Where(o => o.Price == 2.35M).Count(); + transaction.Rollback(); + } + int rollbackTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == orders.Count, "The number of rows updated must match the count of entities that were retrieved"); + Assert.IsTrue(newOrders == rowsUpdated, "The count of new orders must be equal the number of rows updated in the database."); + Assert.IsTrue(rollbackTotal == orders.Count, "The number of rows after the transacation has been rollbacked should match the original count"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/DbContextExtensionsBase.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/DbContextExtensionsBase.cs new file mode 100644 index 0000000..526d7a7 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/DbContextExtensionsBase.cs @@ -0,0 +1,267 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Drawing; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; +using N.EntityFrameworkCore.Extensions.Test.Data; +using N.EntityFrameworkCore.Extensions.Test.Data.Enums; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +public enum PopulateDataMode +{ + Normal, + Tpc, + Tph, + Tpt, + Schema +} +[TestClass] +public class DbContextExtensionsBase +{ + private TestDbContext _currentDbContext; + + [TestInitialize] + public void Init() + { + using var dbContext = new TestDbContext(); + TestDatabaseInitializer.EnsureCreated(dbContext); + } + + [TestCleanup] + public void Cleanup() + { + _currentDbContext?.Dispose(); + _currentDbContext = null; + } + + protected TestDbContext SetupDbContext(bool populateData, PopulateDataMode mode = PopulateDataMode.Normal) + { + var dbContext = new TestDbContext(); + _currentDbContext = dbContext; + TestDatabaseInitializer.EnsureCreated(dbContext); + dbContext.Orders.Truncate(); + dbContext.Products.Truncate(); + dbContext.ProductCategories.Clear(); + dbContext.ProductsWithCustomSchema.Truncate(); + dbContext.ProductsWithTrigger.Truncate(); + dbContext.Database.ClearTable("TpcCustomer"); + dbContext.Database.ClearTable("TpcVendor"); + dbContext.TphPeople.Truncate(); + dbContext.Database.ClearTable("TptPeople"); + dbContext.Database.ClearTable("TptCustomer"); + dbContext.Database.ClearTable("TptVendor"); + dbContext.Database.DropTable("ProductsUnderTen", true); + dbContext.Database.DropTable("OrdersUnderTen", true); + dbContext.Database.DropTable("OrdersLast30Days", true); + if (populateData) + { + if (mode == PopulateDataMode.Normal) + { + var orders = new List(); + int id = 1; + for (int i = 0; i < 2050; i++) + { + DateTime addedDateTime = DateTime.UtcNow.AddDays(-id); + orders.Add(new Order + { + Id = id, + ExternalId = string.Format("id-{0}", i), + Price = 1.25M, + AddedDateTime = addedDateTime, + ModifiedDateTime = addedDateTime.AddHours(3), + Status = OrderStatus.Completed + }); + id++; + } + for (int i = 0; i < 1050; i++) + { + orders.Add(new Order { Id = id, Price = 5.35M }); + id++; + } + for (int i = 0; i < 2050; i++) + { + orders.Add(new Order { Id = id, Price = 1.25M }); + id++; + } + for (int i = 0; i < 6000; i++) + { + orders.Add(new Order { Id = id, Price = 15.35M }); + id++; + } + for (int i = 0; i < 6000; i++) + { + orders.Add(new Order { Id = id, Price = 15.35M }); + id++; + } + + Debug.WriteLine("Last Id for Order is {0}", id); + dbContext.BulkInsert(orders, new BulkInsertOptions() { KeepIdentity = true }); + + var productCategories = new List() + { + new ProductCategory { Id=1, Name="Category-1", Active=true}, + new ProductCategory { Id=2, Name="Category-2", Active=true}, + new ProductCategory { Id=3, Name="Category-3", Active=true}, + new ProductCategory { Id=4, Name="Category-4", Active=false}, + }; + dbContext.BulkInsert(productCategories, o => { o.KeepIdentity = true; o.UsePermanentTable = true; }); + var products = new List(); + id = 1; + for (int i = 0; i < 2050; i++) + { + products.Add(new Product + { + Id = i.ToString(), + Price = 1.25M, + OutOfStock = false, + ProductCategoryId = 4, + StatusEnum = ProductStatus.InStock, + Color = Color.Black, + Position = new Position { Building = 5, Aisle = 33, Bay = i }, + }); + id++; + } + for (int i = 2050; i < 7000; i++) + { + products.Add(new Product { Id = i.ToString(), Price = 1.25M, OutOfStock = true, StatusEnum = ProductStatus.OutOfStock }); + id++; + } + + Debug.WriteLine("Last Id for Product is {0}", id); + dbContext.BulkInsert(products, new BulkInsertOptions() { KeepIdentity = false, AutoMapOutput = false, UsePermanentTable = true }); + + //ProductWithComplexKey + var productsWithComplexKey = new List(); + id = 1; + + for (int i = 0; i < 2050; i++) + { + productsWithComplexKey.Add(new ProductWithComplexKey { Price = 1.25M }); + id++; + } + + Debug.WriteLine("Last Id for ProductsWithComplexKey is {0}", id); + dbContext.BulkInsert(productsWithComplexKey, new BulkInsertOptions() { KeepIdentity = false, AutoMapOutput = false }); + } + else if (mode == PopulateDataMode.Tph) + { + //TPH Customers & Vendors + var tphCustomers = new List(); + var tphVendors = new List(); + for (int i = 1; i < 2000; i++) + { + tphCustomers.Add(new TphCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "404-555-1111", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 2000; i < 3000; i++) + { + tphVendors.Add(new TphVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Phone = "404-555-2222", + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + dbContext.BulkInsert(tphCustomers, new BulkInsertOptions() { KeepIdentity = true }); + dbContext.BulkInsert(tphVendors, new BulkInsertOptions() { KeepIdentity = true }); + } + else if (mode == PopulateDataMode.Tpc) + { + //TPC Customers & Vendors + var tpcCustomers = new List(); + var tpcVendors = new List(); + for (int i = 1; i <= 2000; i++) + { + tpcCustomers.Add(new TpcCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "404-555-1111", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 2001; i <= 3000; i++) + { + tpcVendors.Add(new TpcVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Phone = "404-555-2222", + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + dbContext.BulkInsert(tpcCustomers, new BulkInsertOptions() { KeepIdentity = true }); + dbContext.BulkInsert(tpcVendors, new BulkInsertOptions() { KeepIdentity = true }); + } + else if (mode == PopulateDataMode.Tpt) + { + //Customers & Vendors + var tptCustomers = new List(); + var tptVendors = new List(); + for (int i = 1; i <= 2000; i++) + { + tptCustomers.Add(new TptCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "404-555-1111", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 2001; i < 3000; i++) + { + tptVendors.Add(new TptVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Phone = "404-555-2222", + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + dbContext.BulkInsert(tptCustomers, new BulkInsertOptions() { KeepIdentity = true, UsePermanentTable = true }); + dbContext.BulkInsert(tptVendors, new BulkInsertOptions() { KeepIdentity = true }); + } + else if (mode == PopulateDataMode.Schema) + { + //ProductWithCustomSchema + var productsWithCustomSchema = new List(); + int id = 1; + + for (int i = 0; i < 2050; i++) + { + productsWithCustomSchema.Add(new ProductWithCustomSchema { Id = id.ToString(), Price = 1.25M }); + id++; + } + for (int i = 2050; i < 5000; i++) + { + productsWithCustomSchema.Add(new ProductWithCustomSchema { Id = id.ToString(), Price = 6.75M }); + id++; + } + + dbContext.BulkInsert(productsWithCustomSchema); + } + } + return dbContext; + } +} diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/DeleteFromQuery.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/DeleteFromQuery.cs new file mode 100644 index 0000000..9c83c4b --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/DeleteFromQuery.cs @@ -0,0 +1,148 @@ +using System; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class DeleteFromQuery : DbContextExtensionsBase +{ + [TestMethod] + public void With_Boolean_Value() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(p => p.OutOfStock); + int oldTotal = products.Count(a => a.OutOfStock); + int rowUpdated = products.DeleteFromQuery(); + int newTotal = dbContext.Products.Count(o => o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (OutOfStock == true)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Child_Relationship() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(p => !p.ProductCategory.Active); + int oldTotal = products.Count(); + int rowsDeleted = products.DeleteFromQuery(); + int newTotal = products.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (ProductCategory.Active == false)"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows update must match the count of rows that match the condition (ProductCategory.Active == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Decimal_Using_IQueryable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int oldTotal = orders.Count(); + int rowsDeleted = orders.DeleteFromQuery(); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newTotal == 0, "Delete() Failed: must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Decimal_Using_IEnumerable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int oldTotal = orders.Count(); + int rowsDeleted = orders.DeleteFromQuery(); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_DateTime() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + int rowsToDelete = dbContext.Orders.Where(o => o.ModifiedDateTime != null && o.ModifiedDateTime >= dateTime).Count(); + int rowsDeleted = dbContext.Orders.Where(o => o.ModifiedDateTime != null && o.ModifiedDateTime >= dateTime) + .DeleteFromQuery(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == rowsToDelete, "The number of rows deleted must match the count of the rows that matched in the database"); + Assert.IsTrue(oldTotal - newTotal == rowsDeleted, "The rows deleted must match the new count minues the old count"); + } + [TestMethod] + public void With_Delete_All() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + int rowsDeleted = dbContext.Orders.DeleteFromQuery(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Different_Values() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.Id == 1 && o.Active && o.ModifiedDateTime >= dateTime); + int rowsToDelete = orders.Count(); + int rowsDeleted = orders.DeleteFromQuery(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == rowsToDelete, "The number of rows deleted must match the count of the rows that matched in the database"); + Assert.IsTrue(oldTotal - newTotal == rowsDeleted, "The rows deleted must match the new count minues the old count"); + } + [TestMethod] + public void With_Empty_List() + { + var dbContext = SetupDbContext(false); + int oldTotal = dbContext.Orders.Count(); + int rowsDeleted = dbContext.Orders.DeleteFromQuery(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal == 0, "There must be no orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + int oldTotal = dbContext.ProductsWithCustomSchema.Count(); + int rowsDeleted = dbContext.ProductsWithCustomSchema.DeleteFromQuery(); + int newTotal = dbContext.ProductsWithCustomSchema.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [DoNotParallelize] + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + int rowsDeleted; + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int rowsToDelete = orders.Count(); + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsDeleted = orders.DeleteFromQuery(); + transaction.Rollback(); + } + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowsDeleted == orders.Count(), "The number of rows update must match the count of rows that match the condtion (Price < $10)"); + Assert.IsTrue(newTotal == oldTotal, "The new count must match the old count since the transaction was rollbacked"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/DeleteFromQueryAsync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/DeleteFromQueryAsync.cs new file mode 100644 index 0000000..99a306e --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/DeleteFromQueryAsync.cs @@ -0,0 +1,149 @@ +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class DeleteFromQueryAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Boolean_Value() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(p => p.OutOfStock); + int oldTotal = products.Count(a => a.OutOfStock); + int rowUpdated = await products.DeleteFromQueryAsync(); + int newTotal = dbContext.Products.Count(o => o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (OutOfStock == true)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Child_Relationship() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(p => !p.ProductCategory.Active); + int oldTotal = products.Count(); + int rowsDeleted = await products.DeleteFromQueryAsync(); + int newTotal = products.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (ProductCategory.Active == false)"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows update must match the count of rows that match the condition (ProductCategory.Active == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Decimal_Using_IQueryable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int oldTotal = orders.Count(); + int rowsDeleted = await orders.DeleteFromQueryAsync(); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newTotal == 0, "Delete() Failed: must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Decimal_Using_IEnumerable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int oldTotal = orders.Count(); + int rowsDeleted = await orders.DeleteFromQueryAsync(); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_DateTime() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + int rowsToDelete = dbContext.Orders.Where(o => o.ModifiedDateTime != null && o.ModifiedDateTime >= dateTime).Count(); + int rowsDeleted = await dbContext.Orders.Where(o => o.ModifiedDateTime != null && o.ModifiedDateTime >= dateTime) + .DeleteFromQueryAsync(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == rowsToDelete, "The number of rows deleted must match the count of the rows that matched in the database"); + Assert.IsTrue(oldTotal - newTotal == rowsDeleted, "The rows deleted must match the new count minues the old count"); + } + [TestMethod] + public async Task With_Delete_All() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + int rowsDeleted = await dbContext.Orders.DeleteFromQueryAsync(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Different_Values() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.Id == 1 && o.Active && o.ModifiedDateTime >= dateTime); + int rowsToDelete = orders.Count(); + int rowsDeleted = await orders.DeleteFromQueryAsync(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == rowsToDelete, "The number of rows deleted must match the count of the rows that matched in the database"); + Assert.IsTrue(oldTotal - newTotal == rowsDeleted, "The rows deleted must match the new count minues the old count"); + } + [TestMethod] + public async Task With_Empty_List() + { + var dbContext = SetupDbContext(false); + int oldTotal = dbContext.Orders.Count(); + int rowsDeleted = await dbContext.Orders.DeleteFromQueryAsync(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal == 0, "There must be no orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + int oldTotal = dbContext.ProductsWithCustomSchema.Count(); + int rowsDeleted = await dbContext.ProductsWithCustomSchema.DeleteFromQueryAsync(); + int newTotal = dbContext.ProductsWithCustomSchema.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in database"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [DoNotParallelize] + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + int rowsDeleted; + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int rowsToDelete = orders.Count(); + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsDeleted = await orders.DeleteFromQueryAsync(); + transaction.Rollback(); + } + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowsDeleted == orders.Count(), "The number of rows update must match the count of rows that match the condtion (Price < $10)"); + Assert.IsTrue(newTotal == oldTotal, "The new count must match the old count since the transaction was rollbacked"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/Fetch.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/Fetch.cs new file mode 100644 index 0000000..d4c84d2 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/Fetch.cs @@ -0,0 +1,177 @@ +using System; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class Fetch : DbContextExtensionsBase +{ + [TestMethod] + public void With_BulkInsert() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int totalOrdersToFetch = orders.Count(); + int totalOrdersFetched = 0; + int batchSize = 5000; + orders.Fetch(result => + { + totalOrdersFetched += result.Results.Count(); + var ordersFetched = result.Results; + foreach (var orderFetched in ordersFetched) + { + orderFetched.Price = 75; + } + dbContext.BulkInsert(ordersFetched); + }, options => { options.BatchSize = batchSize; }); + + int totalOrder = orders.Count(); + int totalOrderInserted = orders.Where(o => o.Price == 75).Count(); + Assert.IsTrue(totalOrdersToFetch == totalOrdersFetched, "The total number of rows fetched must match the number of rows to fetch"); + Assert.IsTrue(totalOrderInserted == totalOrdersFetched, "The total number of rows updated must match the number of rows that were fetched"); + Assert.IsTrue(totalOrder - totalOrdersToFetch == totalOrderInserted, "The total number of rows must match the number of rows that were updated"); + } + [TestMethod] + public void With_BulkUpdate() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int totalOrdersToFetch = orders.Count(); + int totalOrdersFetched = 0; + int batchSize = 5000; + orders.Fetch(result => + { + totalOrdersFetched += result.Results.Count(); + var ordersFetched = result.Results; + foreach (var orderFetched in ordersFetched) + { + orderFetched.Price = 75; + } + dbContext.BulkUpdate(ordersFetched); + }, options => { options.BatchSize = batchSize; }); + + int totalOrder = orders.Count(); + int totalOrderUpdated = orders.Where(o => o.Price == 75).Count(); + Assert.IsTrue(totalOrdersToFetch == totalOrdersFetched, "The total number of rows fetched must match the number of rows to fetch"); + Assert.IsTrue(totalOrderUpdated == totalOrdersFetched, "The total number of rows updated must match the number of rows that were fetched"); + Assert.IsTrue(totalOrder == totalOrderUpdated, "The total number of rows must match the number of rows that were updated"); + } + [TestMethod] + public void With_DateTime() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + orders.Fetch(result => + { + batchCount++; + totalCount += result.Results.Count(); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less than or equal to the batchSize"); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected"); + } + [TestMethod] + public void With_Decimal() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + orders.Fetch(result => + { + batchCount++; + totalCount += result.Results.Count(); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less than or equal to the batchSize"); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected"); + } + [TestMethod] + public void With_Enum() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var products = dbContext.Products.Where(o => o.Price < 10M); + int expectedTotalCount = products.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + products.Fetch(result => + { + batchCount++; + totalCount += result.Results.Count(); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less than or equal to the batchSize"); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be products in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected"); + } + [TestMethod] + public void With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + orders.Fetch(result => + { + batchCount++; + totalCount += result.Results.Count(); + bool isAllExternalIdNull = !result.Results.Any(o => o.ExternalId != null); + Assert.IsTrue(isAllExternalIdNull, "All records should have ExternalId equal to NULL since it was not loaded."); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less than or equal to the batchSize"); + }, options => { options.BatchSize = batchSize; options.IgnoreColumns = s => new { s.ExternalId }; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected"); + } + [TestMethod] + public void With_Options_InputColumns() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + orders.Fetch(result => + { + batchCount++; + totalCount += result.Results.Count(); + bool isAllExternalIdNull = !result.Results.Any(o => o.ExternalId != null); + Assert.IsTrue(isAllExternalIdNull, "All records should have ExternalId equal to NULL since it was not loaded."); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less than or equal to the batchSize"); + }, options => { options.BatchSize = batchSize; options.InputColumns = s => new { s.Id, s.Price }; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/FetchAsync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/FetchAsync.cs new file mode 100644 index 0000000..c9a4865 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/FetchAsync.cs @@ -0,0 +1,193 @@ +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class FetchAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_BulkInsert() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int totalOrdersToFetch = orders.Count(); + int totalOrdersFetched = 0; + int batchSize = 5000; + await orders.FetchAsync(async result => + { + totalOrdersFetched += result.Results.Count; + var ordersFetched = result.Results; + foreach (var orderFetched in ordersFetched) + { + orderFetched.Price = 75; + } + await dbContext.BulkInsertAsync(ordersFetched); + }, options => { options.BatchSize = batchSize; }); + + int totalOrder = orders.Count(); + int totalOrderInserted = orders.Where(o => o.Price == 75).Count(); + Assert.IsTrue(totalOrdersToFetch == totalOrdersFetched, "The total number of rows fetched must match the number of rows to fetch"); + Assert.IsTrue(totalOrderInserted == totalOrdersFetched, "The total number of rows updated must match the number of rows that were fetched"); + Assert.IsTrue(totalOrder - totalOrdersToFetch == totalOrderInserted, "The total number of rows must match the number of rows that were updated"); + } + [TestMethod] + public async Task With_BulkUpdate() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int totalOrdersToFetch = orders.Count(); + int totalOrdersFetched = 0; + int batchSize = 5000; + await orders.FetchAsync(async result => + { + totalOrdersFetched += result.Results.Count; + var ordersFetched = result.Results; + foreach (var orderFetched in ordersFetched) + { + orderFetched.Price = 75; + } + await dbContext.BulkUpdateAsync(ordersFetched); + }, options => { options.BatchSize = batchSize; }); + + int totalOrder = orders.Count(); + int totalOrderUpdated = orders.Where(o => o.Price == 75).Count(); + Assert.IsTrue(totalOrdersToFetch == totalOrdersFetched, "The total number of rows fetched must match the number of rows to fetch"); + Assert.IsTrue(totalOrderUpdated == totalOrdersFetched, "The total number of rows updated must match the number of rows that were fetched"); + Assert.IsTrue(totalOrder == totalOrderUpdated, "The total number of rows must match the number of rows that were updated"); + } + [TestMethod] + public async Task With_DateTime() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + await orders.FetchAsync(async result => + { + await Task.Run(() => + { + batchCount++; + totalCount += result.Results.Count; + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less than or equal to the batchSize"); + }); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected"); + } + [TestMethod] + public async Task With_Decimal() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + await orders.FetchAsync(async result => + { + await Task.Run(() => + { + batchCount++; + totalCount += result.Results.Count(); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less than or equal to the batchSize"); + }); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected"); + } + [TestMethod] + public async Task With_Enum() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var products = dbContext.Products.Where(o => o.Price < 10M); + int expectedTotalCount = products.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + await products.FetchAsync(async result => + { + await Task.Run(() => + { + batchCount++; + totalCount += result.Results.Count(); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less than or equal to the batchSize"); + }); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be products in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected"); + } + [TestMethod] + public async Task With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + await orders.FetchAsync(async result => + { + await Task.Run(() => + { + batchCount++; + totalCount += result.Results.Count; + bool isAllExternalIdNull = !result.Results.Any(o => o.ExternalId != null); + Assert.IsTrue(isAllExternalIdNull, "All records should have ExternalId equal to NULL since it was not loaded."); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less than or equal to the batchSize"); + }); + }, options => { options.BatchSize = batchSize; options.IgnoreColumns = s => new { s.ExternalId }; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected"); + } + [TestMethod] + public async Task With_Options_InputColumns() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + await orders.FetchAsync(async result => + { + await Task.Run(() => + { + batchCount++; + totalCount += result.Results.Count(); + bool isAllExternalIdNull = !result.Results.Any(o => o.ExternalId != null); + Assert.IsTrue(isAllExternalIdNull, "All records should have ExternalId equal to NULL since it was not loaded."); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less than or equal to the batchSize"); + }); + }, options => { options.BatchSize = batchSize; options.InputColumns = s => new { s.Id, s.Price }; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/InsertFromQuery.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/InsertFromQuery.cs new file mode 100644 index 0000000..284576b --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/InsertFromQuery.cs @@ -0,0 +1,88 @@ +using System; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class InsertFromQuery : DbContextExtensionsBase +{ + [TestMethod] + public void With_DateTime_Value() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersLast30Days"; + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + int oldTotal = dbContext.Orders.Count(); + + var orders = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime); + int oldSourceTotal = orders.Count(); + int rowsInserted = orders.InsertFromQuery(tableName, + o => new { o.Id, o.ExternalId, o.Price, o.AddedDateTime, o.ModifiedDateTime, o.Active }); + int newSourceTotal = orders.Count(); + int newTargetTotal = orders.UsingTable(tableName).Count(); + + Assert.IsTrue(oldTotal > oldSourceTotal, "The total should be greater then the number of rows selected from the source table"); + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the source table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the source table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the insert must match the total row inserted"); + } + [TestMethod] + public void With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersUnderTen"; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldSourceTotal = orders.Count(); + int rowsInserted = dbContext.Orders.Where(o => o.Price < 10M).InsertFromQuery(tableName, o => new { o.Id, o.Price, o.AddedDateTime, o.Active }); + int newSourceTotal = orders.Count(); + int newTargetTotal = orders.UsingTable(tableName).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the source table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the source table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the insert must match the total row inserted"); + } + [TestMethod] + public void With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + string tableName = "ProductsUnderTen"; + var products = dbContext.ProductsWithCustomSchema.Where(o => o.Price < 10M); + int oldSourceTotal = products.Count(); + int rowsInserted = dbContext.ProductsWithCustomSchema.Where(o => o.Price < 10M).InsertFromQuery(tableName, o => new { o.Id, o.Price }); + int newSourceTotal = products.Count(); + int newTargetTotal = products.UsingTable(tableName).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the source table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the source table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the insert must match the total row inserted"); + } + [TestMethod] + [DoNotParallelize] + [Ignore("MySQL DDL auto-commit: CREATE TABLE in InsertFromQuery cannot be rolled back in MySQL because DDL statements cause implicit transaction commits. Table will persist after rollback.")] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersUnderTen"; + int rowsInserted; + bool tableExistsBefore, tableExistsAfter; + int oldSourceTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsInserted = dbContext.Orders.Where(o => o.Price < 10M).InsertFromQuery(tableName, o => new { o.Price, o.Id, o.AddedDateTime, o.Active }); + tableExistsBefore = dbContext.Database.TableExists(tableName); + transaction.Rollback(); + } + tableExistsAfter = dbContext.Database.TableExists(tableName); + int newSourceTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of rows update must match the count of rows that match the condtion (Price < $10)"); + Assert.IsTrue(newSourceTotal == oldSourceTotal, "The new count must match the old count since the transaction was rollbacked"); + Assert.IsTrue(tableExistsBefore, string.Format("Table {0} should exist before transaction rollback", tableName)); + Assert.IsFalse(tableExistsAfter, string.Format("Table {0} should not exist after transaction rollback", tableName)); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/InsertFromQueryAsync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/InsertFromQueryAsync.cs new file mode 100644 index 0000000..c31d2c7 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/InsertFromQueryAsync.cs @@ -0,0 +1,89 @@ +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class InsertFromQueryAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_DateTime_Value() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersLast30Days"; + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + int oldTotal = dbContext.Orders.Count(); + + var orders = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime); + int oldSourceTotal = orders.Count(); + int rowsInserted = await orders.InsertFromQueryAsync(tableName, + o => new { o.Id, o.ExternalId, o.Price, o.AddedDateTime, o.ModifiedDateTime, o.Active }); + int newSourceTotal = orders.Count(); + int newTargetTotal = orders.UsingTable(tableName).Count(); + + Assert.IsTrue(oldTotal > oldSourceTotal, "The total should be greater then the number of rows selected from the source table"); + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the source table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the source table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the insert must match the total row inserted"); + } + [TestMethod] + public async Task With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersUnderTen"; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldSourceTotal = orders.Count(); + int rowsInserted = await dbContext.Orders.Where(o => o.Price < 10M).InsertFromQueryAsync(tableName, o => new { o.Id, o.Price, o.AddedDateTime, o.Active }); + int newSourceTotal = orders.Count(); + int newTargetTotal = orders.UsingTable(tableName).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the source table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the source table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the insert must match the total row inserted"); + } + [TestMethod] + public async Task With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + string tableName = "ProductsUnderTen"; + var products = dbContext.ProductsWithCustomSchema.Where(o => o.Price < 10M); + int oldSourceTotal = products.Count(); + int rowsInserted = await dbContext.ProductsWithCustomSchema.Where(o => o.Price < 10M).InsertFromQueryAsync(tableName, o => new { o.Id, o.Price }); + int newSourceTotal = products.Count(); + int newTargetTotal = products.UsingTable(tableName).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the source table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the source table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the insert must match the total row inserted"); + } + [TestMethod] + [DoNotParallelize] + [Ignore("MySQL DDL auto-commit: CREATE TABLE in InsertFromQuery cannot be rolled back in MySQL because DDL statements cause implicit transaction commits. Table will persist after rollback.")] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersUnderTen"; + int rowsInserted; + bool tableExistsBefore, tableExistsAfter; + int oldSourceTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsInserted = await dbContext.Orders.Where(o => o.Price < 10M).InsertFromQueryAsync(tableName, o => new { o.Price, o.Id, o.AddedDateTime, o.Active }); + tableExistsBefore = dbContext.Database.TableExists(tableName); + transaction.Rollback(); + } + tableExistsAfter = dbContext.Database.TableExists(tableName); + int newSourceTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of rows update must match the count of rows that match the condtion (Price < $10)"); + Assert.IsTrue(newSourceTotal == oldSourceTotal, "The new count must match the old count since the transaction was rollbacked"); + Assert.IsTrue(tableExistsBefore, string.Format("Table {0} should exist before transaction rollback", tableName)); + Assert.IsFalse(tableExistsAfter, string.Format("Table {0} should not exist after transaction rollback", tableName)); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/QueryToCsvFile.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/QueryToCsvFile.cs new file mode 100644 index 0000000..8bd91de --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/QueryToCsvFile.cs @@ -0,0 +1,47 @@ +using System.IO; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class QueryToCsvFile : DbContextExtensionsBase +{ + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var queryToCsvFileResult = query.QueryToCsvFile("QueryToCsvFile-Test.csv"); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should match the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file should match the count from the database plus the header row"); + } + [TestMethod] + public void With_Options_ColumnDelimiter_TextQualifer_HeaderRow() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var queryToCsvFileResult = query.QueryToCsvFile("QueryToCsvFile_Options_ColumnDelimiter_TextQualifer_HeaderRow-Test.csv", options => { options.ColumnDelimiter = "|"; options.TextQualifer = "\""; options.IncludeHeaderRow = false; }); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should match the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count, "The total number of rows written to the file should match the count from the database without any header row"); + } + [TestMethod] + public void Using_FileStream() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var fileStream = File.Create("QueryToCsvFile_Stream-Test.csv"); + var queryToCsvFileResult = query.QueryToCsvFile(fileStream); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should match the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file should match the count from the database plus the header row"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/QueryToCsvFileAsync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/QueryToCsvFileAsync.cs new file mode 100644 index 0000000..5aff479 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/QueryToCsvFileAsync.cs @@ -0,0 +1,48 @@ +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class QueryToCsvFileAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var queryToCsvFileResult = await query.QueryToCsvFileAsync("QueryToCsvFile-Test.csv"); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should match the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file should match the count from the database plus the header row"); + } + [TestMethod] + public async Task With_Options_ColumnDelimiter_TextQualifer_HeaderRow() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var queryToCsvFileResult = await query.QueryToCsvFileAsync("QueryToCsvFile_Options_ColumnDelimiter_TextQualifer_HeaderRow-Test.csv", options => { options.ColumnDelimiter = "|"; options.TextQualifer = "\""; options.IncludeHeaderRow = false; }); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should match the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count, "The total number of rows written to the file should match the count from the database without any header row"); + } + [TestMethod] + public async Task Using_FileStream() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var fileStream = File.Create("QueryToCsvFile_Stream-Test.csv"); + var queryToCsvFileResult = await query.QueryToCsvFileAsync(fileStream); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should match the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file should match the count from the database plus the header row"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/UpdateFromQuery.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/UpdateFromQuery.cs new file mode 100644 index 0000000..98eaf83 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/UpdateFromQuery.cs @@ -0,0 +1,276 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Threading; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; +using N.EntityFrameworkCore.Extensions.Test.Data.Enums; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class UpdateFromQuery : DbContextExtensionsBase +{ + [TestMethod] + public void With_Boolean_Value() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Products.Count(a => a.OutOfStock); + int rowUpdated = dbContext.Products.Where(a => a.OutOfStock).UpdateFromQuery(a => new Product { OutOfStock = false }); + int newTotal = dbContext.Products.Count(o => o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be articles in database that match this condition (OutOfStock == true)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Concatenating_String() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId == null); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { ExternalId = Convert.ToString(o.Id) + "Test" }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Concatenating_String_And_Number() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId == null); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { ExternalId = Convert.ToString(o.Id) + "Test" }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_DateTime_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTime now = DateTime.UtcNow; + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQuery(o => new Order { ModifiedDateTime = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTime == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public void With_DateTimeOffset_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTimeOffset now = DateTimeOffset.UtcNow; + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQuery(o => new Order { ModifiedDateTimeOffset = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTimeOffset == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public void With_DateTimeOffset_No_UTC_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTimeOffset now = DateTimeOffset.Parse("2020-06-17T16:00:00+05:00").ToUniversalTime(); + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQuery(o => new Order { ModifiedDateTimeOffset = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTimeOffset == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public void With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { Price = 25.30M }); + int newTotal = orders.Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the database."); + } + [TestMethod] + public void With_Different_Culture() + { + Thread.CurrentThread.CurrentCulture = CultureInfo.GetCultureInfo("sv-SE"); + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int rowUpdated = dbContext.Orders.Where(o => o.Price < 10M).UpdateFromQuery(o => new Order { Price = 25.30M }); + int newTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.AreEqual("25,30", Convert.ToString(25.30M)); + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the database."); + } + [TestMethod] + public void With_Enum_Value() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(a => a.StatusEnum == ProductStatus.OutOfStock && a.OutOfStock); + int oldTotal = products.Count(); + int rowUpdated = products.UpdateFromQuery(a => new Product { StatusEnum = ProductStatus.InStock }); + int newTotal = products.Count(o => o.StatusEnum == ProductStatus.OutOfStock && o.OutOfStock); + int newTotal2 = dbContext.Products.Count(o => o.StatusEnum == ProductStatus.InStock && o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be articles in database that match this condition (OutOfStock == true)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(newTotal2 == oldTotal, "All rows must have been updated"); + } + [TestMethod] + public void With_Guid_Value() + { + var dbContext = SetupDbContext(true); + var guid = Guid.NewGuid(); + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { GlobalId = guid }); + int matchCount = dbContext.Orders.Where(o => o.GlobalId == guid).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, $"The number of rows update must match the count of rows that match the condition (GlobalId = '{guid}')"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the database."); + } + [TestMethod] + public void With_Long_List() + { + var dbContext = SetupDbContext(true); + var ids = new List() { 1, 2, 3, 4, 5, 6, 7, 8 }; + var orders = dbContext.Orders.Where(o => ids.Contains(o.Id)); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { Price = 25.25M }); + int newTotal = orders.Where(o => o.Price != 25.25M).Count(); + int matchCount = dbContext.Orders.Where(o => ids.Contains(o.Id) && o.Price == 25.25M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the database."); + } + [TestMethod] + public void With_MethodCall() + { + var dbContext = SetupDbContext(true); + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = dbContext.Orders.Where(a => a.Price < 10).UpdateFromQuery(o => new Order { Price = Math.Ceiling((o.Price + 10.5M) * 3 / 1) }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be order in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Null_Value() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId != null); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { ExternalId = null }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId != null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (ExternalId != null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + var products = dbContext.ProductsWithCustomSchema.Where(o => o.Price < 5M); + int oldTotal = products.Count(); + int rowUpdated = products.UpdateFromQuery(o => new ProductWithCustomSchema { Price = 25.30M }); + int newTotal = products.Count(); + int matchCount = dbContext.ProductsWithCustomSchema.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (Price < 5)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condtion (Price < 5)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the database."); + } + [TestMethod] + public void With_String_Containing_Apostrophe() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.ExternalId == null).Count(); + int rowUpdated = dbContext.Orders.Where(o => o.ExternalId == null).UpdateFromQuery(o => new Order { ExternalId = "inv'alid" }); + int newTotal = dbContext.Orders.Where(o => o.ExternalId == null).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [DoNotParallelize] + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int rowUpdated; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowUpdated = dbContext.Orders.Where(o => o.Price < 10M).UpdateFromQuery(o => new Order { Price = 25.30M }); + transaction.Rollback(); + } + int newTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condtion (Price < $10)"); + Assert.IsTrue(newTotal == oldTotal, "The new count must match the old count since the transaction was rollbacked"); + Assert.IsTrue(matchCount == 0, "The match count must be equal to 0 since the transaction was rollbacked."); + } + [TestMethod] + public void With_Variables() + { + var dbContext = SetupDbContext(true); + decimal priceStart = 10M; + decimal priceUpdate = 0.34M; + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = dbContext.Orders.Where(a => a.Price < 10).UpdateFromQuery(a => new Order { Price = priceStart + priceUpdate }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Variable_And_Decimal() + { + var dbContext = SetupDbContext(true); + decimal priceStart = 10M; + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = dbContext.Orders.Where(a => a.Price < 10).UpdateFromQuery(a => new Order { Price = priceStart + 7M }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/UpdateFromQueryAsync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/UpdateFromQueryAsync.cs new file mode 100644 index 0000000..c83626f --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbContextExtensions/UpdateFromQueryAsync.cs @@ -0,0 +1,278 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; +using N.EntityFrameworkCore.Extensions.Test.Data.Enums; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class UpdateFromQueryAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Boolean_Value() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Products.Count(a => a.OutOfStock); + int rowUpdated = await dbContext.Products.Where(a => a.OutOfStock).UpdateFromQueryAsync(a => new Product { OutOfStock = false }); + int newTotal = dbContext.Products.Count(o => o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be articles in database that match this condition (OutOfStock == true)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Concatenating_String() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId == null); + int oldTotal = orders.Count(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { ExternalId = Convert.ToString(o.Id) + "Test" }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Concatenating_String_And_Number() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId == null); + int oldTotal = orders.Count(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { ExternalId = Convert.ToString(o.Id) + "Test" }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_DateTime_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTime now = DateTime.UtcNow; + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = await dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQueryAsync(o => new Order { ModifiedDateTime = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTime == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public async Task With_DateTimeOffset_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTimeOffset now = DateTimeOffset.UtcNow; + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = await dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQueryAsync(o => new Order { ModifiedDateTimeOffset = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTimeOffset == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public async Task With_DateTimeOffset_No_UTC_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTimeOffset now = DateTimeOffset.Parse("2020-06-17T16:00:00+05:00").ToUniversalTime(); + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = await dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQueryAsync(o => new Order { ModifiedDateTimeOffset = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTimeOffset == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public async Task With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldTotal = orders.Count(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { Price = 25.30M }); + int newTotal = orders.Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the database."); + } + [TestMethod] + public async Task With_Different_Culture() + { + Thread.CurrentThread.CurrentCulture = CultureInfo.GetCultureInfo("sv-SE"); + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int rowUpdated = await dbContext.Orders.Where(o => o.Price < 10M).UpdateFromQueryAsync(o => new Order { Price = 25.30M }); + int newTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.AreEqual("25,30", Convert.ToString(25.30M)); + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the database."); + } + [TestMethod] + public async Task With_Enum_Value() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(a => a.StatusEnum == ProductStatus.OutOfStock && a.OutOfStock); + int oldTotal = products.Count(); + int rowUpdated = await products.UpdateFromQueryAsync(a => new Product { StatusEnum = ProductStatus.InStock }); + int newTotal = products.Count(o => o.StatusEnum == ProductStatus.OutOfStock && o.OutOfStock); + int newTotal2 = dbContext.Products.Count(o => o.StatusEnum == ProductStatus.InStock && o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be articles in database that match this condition (OutOfStock == true)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(newTotal2 == oldTotal, "All rows must have been updated"); + } + [TestMethod] + public async Task With_Guid_Value() + { + var dbContext = SetupDbContext(true); + var guid = Guid.NewGuid(); + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldTotal = await orders.CountAsync(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { GlobalId = guid }); + int matchCount = await dbContext.Orders.Where(o => o.GlobalId == guid).CountAsync(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, $"The number of rows update must match the count of rows that match the condition (GlobalId = '{guid}')"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the database."); + } + [TestMethod] + public async Task With_Long_List() + { + var dbContext = SetupDbContext(true); + var ids = new List() { 1, 2, 3, 4, 5, 6, 7, 8 }; + var orders = dbContext.Orders.Where(o => ids.Contains(o.Id)); + int oldTotal = orders.Count(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { Price = 25.25M }); + int newTotal = orders.Where(o => o.Price != 25.25M).Count(); + int matchCount = dbContext.Orders.Where(o => ids.Contains(o.Id) && o.Price == 25.25M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the database."); + } + [TestMethod] + public async Task With_MethodCall() + { + var dbContext = SetupDbContext(true); + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = await dbContext.Orders.Where(a => a.Price < 10).UpdateFromQueryAsync(o => new Order { Price = Math.Ceiling((o.Price + 10.5M) * 3 / 1) }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be order in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Null_Value() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId != null); + int oldTotal = orders.Count(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { ExternalId = null }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId != null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (ExternalId != null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + var products = dbContext.ProductsWithCustomSchema.Where(o => o.Price < 5M); + int oldTotal = products.Count(); + int rowUpdated = await products.UpdateFromQueryAsync(o => new ProductWithCustomSchema { Price = 25.30M }); + int newTotal = products.Count(); + int matchCount = dbContext.ProductsWithCustomSchema.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (Price < 5)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condtion (Price < 5)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the database."); + } + [TestMethod] + public async Task With_String_Containing_Apostrophe() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.ExternalId == null).Count(); + int rowUpdated = await dbContext.Orders.Where(o => o.ExternalId == null).UpdateFromQueryAsync(o => new Order { ExternalId = "inv'alid" }); + int newTotal = dbContext.Orders.Where(o => o.ExternalId == null).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [DoNotParallelize] + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int rowUpdated; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowUpdated = await dbContext.Orders.Where(o => o.Price < 10M).UpdateFromQueryAsync(o => new Order { Price = 25.30M }); + transaction.Rollback(); + } + int newTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condtion (Price < $10)"); + Assert.IsTrue(newTotal == oldTotal, "The new count must match the old count since the transaction was rollbacked"); + Assert.IsTrue(matchCount == 0, "The match count must be equal to 0 since the transaction was rollbacked."); + } + [TestMethod] + public async Task With_Variables() + { + var dbContext = SetupDbContext(true); + decimal priceStart = 10M; + decimal priceUpdate = 0.34M; + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = await dbContext.Orders.Where(a => a.Price < 10).UpdateFromQueryAsync(a => new Order { Price = priceStart + priceUpdate }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Variable_And_Decimal() + { + var dbContext = SetupDbContext(true); + decimal priceStart = 10M; + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = await dbContext.Orders.Where(a => a.Price < 10).UpdateFromQueryAsync(a => new Order { Price = priceStart + 7M }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the condition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbSetExtensions/Clear.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbSetExtensions/Clear.cs new file mode 100644 index 0000000..998cf97 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbSetExtensions/Clear.cs @@ -0,0 +1,21 @@ +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +namespace N.EntityFrameworkCore.Extensions.Test.DbSetExtensions; + +[TestClass] +public class Clear : DbContextExtensionsBase +{ + [TestMethod] + public void Using_Dbset() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + dbContext.Orders.Clear(); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbSetExtensions/ClearAsync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbSetExtensions/ClearAsync.cs new file mode 100644 index 0000000..ee209c7 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbSetExtensions/ClearAsync.cs @@ -0,0 +1,22 @@ +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +namespace N.EntityFrameworkCore.Extensions.Test.DbSetExtensions; + +[TestClass] +public class ClearAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task Using_Dbset() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + await dbContext.Orders.ClearAsync(); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbSetExtensions/Truncate.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbSetExtensions/Truncate.cs new file mode 100644 index 0000000..03b999e --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbSetExtensions/Truncate.cs @@ -0,0 +1,21 @@ +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +namespace N.EntityFrameworkCore.Extensions.Test.DbSetExtensions; + +[TestClass] +public class Truncate : DbContextExtensionsBase +{ + [TestMethod] + public void Using_Dbset() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + dbContext.Orders.Truncate(); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbSetExtensions/TruncateAsync.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbSetExtensions/TruncateAsync.cs new file mode 100644 index 0000000..838547e --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/DbSetExtensions/TruncateAsync.cs @@ -0,0 +1,22 @@ +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +namespace N.EntityFrameworkCore.Extensions.Test.DbSetExtensions; + +[TestClass] +public class TruncateAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task Using_Dbset() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + await dbContext.Orders.TruncateAsync(); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/LinqExtensions/ToSqlPredicateTests.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/LinqExtensions/ToSqlPredicateTests.cs new file mode 100644 index 0000000..69dfdff --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/LinqExtensions/ToSqlPredicateTests.cs @@ -0,0 +1,102 @@ +using System; +using System.Linq.Expressions; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.LinqExtensions; + +[TestClass] +public class ToSqlPredicateTests +{ + [TestMethod] + public void Should_handle_int() + { + Expression> expression = (s, t) => s.Id == t.Id; + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Id = t.Id", sqlPredicate); + } + + [TestMethod] + public void Should_handle_enum() + { + Expression> expression = (s, t) => s.Type == t.Type; + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Type = t.Type", sqlPredicate); + } + + [TestMethod] + public void Should_handle_complex_one() + { + Expression> expression = (s, t) => s.Type == t.Type && + (s.Id == t.Id && + s.ExternalId == t.ExternalId); + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Type = t.Type AND s.Id = t.Id AND s.ExternalId = t.ExternalId", sqlPredicate); + } + + [TestMethod] + public void Should_handle_prop_naming() + { + Expression> expression = (source, target) => source.Id == target.Id && + source.ExternalId == target.ExternalId; + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Id = t.Id AND s.ExternalId = t.ExternalId", sqlPredicate); + } + + [TestMethod] + public void Should_handle_simple_big_one() + { + Expression> expression = (s, t) => s.Type == t.Type && + s.Id == t.Id && + s.ExternalId == t.ExternalId && + s.TesterVar1 == t.TesterVar1 && + s.TesterVar2 == t.TesterVar2 && + s.TesterVar3 == t.TesterVar3 && + s.TesterVar4 == t.TesterVar4 && + s.TesterVar5 == t.TesterVar5; + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Type = t.Type AND s.Id = t.Id AND s.ExternalId = t.ExternalId AND s.TesterVar1 = t.TesterVar1 AND s.TesterVar2 = t.TesterVar2 AND s.TesterVar3 = t.TesterVar3 AND s.TesterVar4 = t.TesterVar4 AND s.TesterVar5 = t.TesterVar5", sqlPredicate); + } + + [TestMethod] + public void Should_handle_complex_big_one() + { + Expression> expression = (s, t) => s.Type == t.Type && + s.Id == t.Id && + (s.ExternalId == t.ExternalId || s.TesterVar1 == t.TesterVar1) && + (s.TesterVar2 == t.TesterVar2 || (s.TesterVar2 == null && t.TesterVar2 == null)) && + (s.TesterVar3 == t.TesterVar3 || (s.TesterVar3 != null && t.TesterVar3 != null)); + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Type = t.Type AND s.Id = t.Id AND (s.ExternalId = t.ExternalId OR s.TesterVar1 = t.TesterVar1) AND (s.TesterVar2 = t.TesterVar2 OR s.TesterVar2 IS NULL AND t.TesterVar2 IS NULL) AND (s.TesterVar3 = t.TesterVar3 OR s.TesterVar3 IS NOT NULL AND t.TesterVar3 IS NOT NULL)", sqlPredicate); + } + + record Entity + { + public Guid Id { get; set; } + public EntityType Type { get; set; } + public int ExternalId { get; set; } + public string TesterVar1 { get; set; } + public string TesterVar2 { get; set; } + public string TesterVar3 { get; set; } + public string TesterVar4 { get; set; } + public string TesterVar5 { get; set; } + } + + enum EntityType + { + One, + Two, + Three + } +} \ No newline at end of file diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/N.EntityFramework.Extensions.MySql.Test.csproj b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/N.EntityFramework.Extensions.MySql.Test.csproj new file mode 100644 index 0000000..fcf453b --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/N.EntityFramework.Extensions.MySql.Test.csproj @@ -0,0 +1,33 @@ + + + + net9.0 + + $(MSBuildThisFileDirectory)..\..\N.EntityFramework.Extensions.MySql.runsettings + + + + + + + + + + + + + + + + + + + + + + + Always + + + + diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/appsettings.json b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/appsettings.json new file mode 100644 index 0000000..f170582 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.MySql.Test/appsettings.json @@ -0,0 +1,7 @@ +{ + "DatabaseProvider": "MySql", + "UseMySqlContainer": true, + "ConnectionStrings": { + "MySqlTestDatabase": "Server=localhost;Port=3306;Database=NEntityFrameworkCoreExtensions;User Id=root;Password=mysql;" + } +} diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.SqlServer.Test/Common/Config.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.SqlServer.Test/Common/Config.cs index 79459c2..00019bc 100644 --- a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.SqlServer.Test/Common/Config.cs +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.SqlServer.Test/Common/Config.cs @@ -17,7 +17,10 @@ public static string GetConnectionString(string name) return configuration.GetConnectionString(name); } public static bool IsSqlServer => true; - public static string GetTestDatabaseConnectionString() => GetConnectionString("SqlServerTestDatabase"); + public static bool UseSqlServerContainer => + !string.Equals(configuration["UseSqlServerContainer"], "false", StringComparison.OrdinalIgnoreCase); + public static string GetTestDatabaseConnectionString() => + UseSqlServerContainer ? SqlServerContainerManager.GetConnectionString() : GetConnectionString("SqlServerTestDatabase"); public static DbParameter CreateParameter(string name, object value) => new SqlParameter(name, value ?? DBNull.Value); public static string DelimitIdentifier(string identifier) => $"[{identifier}]"; public static string DelimitTableName(string tableName) => tableName; diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.SqlServer.Test/Common/SqlServerContainerManager.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.SqlServer.Test/Common/SqlServerContainerManager.cs new file mode 100644 index 0000000..480e236 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.SqlServer.Test/Common/SqlServerContainerManager.cs @@ -0,0 +1,70 @@ +using System; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient; +using Testcontainers.MsSql; + +namespace N.EntityFrameworkCore.Extensions.Test.Common; + +internal static class SqlServerContainerManager +{ + private static readonly object syncRoot = new(); + private static Task initializationTask; + private static MsSqlContainer container; + private static bool cleanupRegistered; + + internal static string GetConnectionString() + { + EnsureStarted(); + var builder = new SqlConnectionStringBuilder(container.GetConnectionString()) + { + InitialCatalog = "NEntityFrameworkCoreExtensions" + }; + return builder.ConnectionString; + } + + internal static void EnsureStarted() + { + EnsureStartedAsync().GetAwaiter().GetResult(); + } + + internal static Task EnsureStartedAsync() + { + lock (syncRoot) + { + initializationTask ??= StartContainerAsync(); + return initializationTask; + } + } + + private static async Task StartContainerAsync() + { + try + { + container = new MsSqlBuilder("mcr.microsoft.com/mssql/server:2022-latest") + .Build(); + + await container.StartAsync(); + RegisterCleanup(); + } + catch (Exception ex) + { + throw new InvalidOperationException("SqlServer tests require Docker when UseSqlServerContainer is enabled.", ex); + } + } + + private static void RegisterCleanup() + { + lock (syncRoot) + { + if (cleanupRegistered) + return; + + AppDomain.CurrentDomain.ProcessExit += (_, _) => + { + if (container != null) + container.DisposeAsync().AsTask().GetAwaiter().GetResult(); + }; + cleanupRegistered = true; + } + } +} diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.SqlServer.Test/Common/TestDatabaseInitializer.cs b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.SqlServer.Test/Common/TestDatabaseInitializer.cs index c530f4d..cebd343 100644 --- a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.SqlServer.Test/Common/TestDatabaseInitializer.cs +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.SqlServer.Test/Common/TestDatabaseInitializer.cs @@ -8,12 +8,18 @@ internal static class TestDatabaseInitializer { internal static void EnsureCreated(TestDbContext dbContext) { + if (Config.UseSqlServerContainer) + SqlServerContainerManager.EnsureStarted(); + dbContext.Database.EnsureCreated(); CreateSqlServerObjects(dbContext); } internal static async Task EnsureCreatedAsync(TestDbContext dbContext) { + if (Config.UseSqlServerContainer) + await SqlServerContainerManager.EnsureStartedAsync(); + await dbContext.Database.EnsureCreatedAsync(); await CreateSqlServerObjectsAsync(dbContext); } diff --git a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.SqlServer.Test/N.EntityFramework.Extensions.SqlServer.Test.csproj b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.SqlServer.Test/N.EntityFramework.Extensions.SqlServer.Test.csproj index 9ccc176..9690c96 100644 --- a/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.SqlServer.Test/N.EntityFramework.Extensions.SqlServer.Test.csproj +++ b/N.EntityFrameworkCore.Extensions.Test/N.EntityFramework.Extensions.SqlServer.Test/N.EntityFramework.Extensions.SqlServer.Test.csproj @@ -3,7 +3,7 @@ net10.0 - $(MSBuildThisFileDirectory)..\..\N.EntityFrameworkCore.Extensions.runsettings + $(MSBuildThisFileDirectory)..\..\N.EntityFramework.Extensions.SqlServer.runsettings @@ -21,6 +21,7 @@ + diff --git a/N.EntityFrameworkCore.Extensions.sln b/N.EntityFrameworkCore.Extensions.sln index 3423c29..9e8d267 100644 --- a/N.EntityFrameworkCore.Extensions.sln +++ b/N.EntityFrameworkCore.Extensions.sln @@ -18,6 +18,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "N.EntityFramework.Extension EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "N.EntityFramework.Extensions.PostgreSql.Test", "N.EntityFrameworkCore.Extensions.Test\N.EntityFramework.Extensions.PostgreSql.Test\N.EntityFramework.Extensions.PostgreSql.Test.csproj", "{6CE9CBC4-2626-4464-A92E-35E92D5284B4}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "N.EntityFramework.Extensions.MySql", "N.EntityFramework.Extensions.MySql\N.EntityFramework.Extensions.MySql.csproj", "{51DD9C47-4F1B-472E-AFE1-BE1E0D3A1009}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "N.EntityFramework.Extensions.MySql.Test", "N.EntityFrameworkCore.Extensions.Test\N.EntityFramework.Extensions.MySql.Test\N.EntityFramework.Extensions.MySql.Test.csproj", "{C02A68C1-39E3-45A5-B7C3-CB24374B4B43}" +EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "N.EntityFrameworkCore.Extensions", "N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.csproj", "{DD7629D4-2E9D-4192-B4A5-06F49B9D5B5E}" EndProject Global @@ -78,6 +82,30 @@ Global {6CE9CBC4-2626-4464-A92E-35E92D5284B4}.Release|x64.Build.0 = Release|Any CPU {6CE9CBC4-2626-4464-A92E-35E92D5284B4}.Release|x86.ActiveCfg = Release|Any CPU {6CE9CBC4-2626-4464-A92E-35E92D5284B4}.Release|x86.Build.0 = Release|Any CPU + {51DD9C47-4F1B-472E-AFE1-BE1E0D3A1009}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {51DD9C47-4F1B-472E-AFE1-BE1E0D3A1009}.Debug|Any CPU.Build.0 = Debug|Any CPU + {51DD9C47-4F1B-472E-AFE1-BE1E0D3A1009}.Debug|x64.ActiveCfg = Debug|Any CPU + {51DD9C47-4F1B-472E-AFE1-BE1E0D3A1009}.Debug|x64.Build.0 = Debug|Any CPU + {51DD9C47-4F1B-472E-AFE1-BE1E0D3A1009}.Debug|x86.ActiveCfg = Debug|Any CPU + {51DD9C47-4F1B-472E-AFE1-BE1E0D3A1009}.Debug|x86.Build.0 = Debug|Any CPU + {51DD9C47-4F1B-472E-AFE1-BE1E0D3A1009}.Release|Any CPU.ActiveCfg = Release|Any CPU + {51DD9C47-4F1B-472E-AFE1-BE1E0D3A1009}.Release|Any CPU.Build.0 = Release|Any CPU + {51DD9C47-4F1B-472E-AFE1-BE1E0D3A1009}.Release|x64.ActiveCfg = Release|Any CPU + {51DD9C47-4F1B-472E-AFE1-BE1E0D3A1009}.Release|x64.Build.0 = Release|Any CPU + {51DD9C47-4F1B-472E-AFE1-BE1E0D3A1009}.Release|x86.ActiveCfg = Release|Any CPU + {51DD9C47-4F1B-472E-AFE1-BE1E0D3A1009}.Release|x86.Build.0 = Release|Any CPU + {C02A68C1-39E3-45A5-B7C3-CB24374B4B43}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {C02A68C1-39E3-45A5-B7C3-CB24374B4B43}.Debug|Any CPU.Build.0 = Debug|Any CPU + {C02A68C1-39E3-45A5-B7C3-CB24374B4B43}.Debug|x64.ActiveCfg = Debug|Any CPU + {C02A68C1-39E3-45A5-B7C3-CB24374B4B43}.Debug|x64.Build.0 = Debug|Any CPU + {C02A68C1-39E3-45A5-B7C3-CB24374B4B43}.Debug|x86.ActiveCfg = Debug|Any CPU + {C02A68C1-39E3-45A5-B7C3-CB24374B4B43}.Debug|x86.Build.0 = Debug|Any CPU + {C02A68C1-39E3-45A5-B7C3-CB24374B4B43}.Release|Any CPU.ActiveCfg = Release|Any CPU + {C02A68C1-39E3-45A5-B7C3-CB24374B4B43}.Release|Any CPU.Build.0 = Release|Any CPU + {C02A68C1-39E3-45A5-B7C3-CB24374B4B43}.Release|x64.ActiveCfg = Release|Any CPU + {C02A68C1-39E3-45A5-B7C3-CB24374B4B43}.Release|x64.Build.0 = Release|Any CPU + {C02A68C1-39E3-45A5-B7C3-CB24374B4B43}.Release|x86.ActiveCfg = Release|Any CPU + {C02A68C1-39E3-45A5-B7C3-CB24374B4B43}.Release|x86.Build.0 = Release|Any CPU {DD7629D4-2E9D-4192-B4A5-06F49B9D5B5E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {DD7629D4-2E9D-4192-B4A5-06F49B9D5B5E}.Debug|Any CPU.Build.0 = Debug|Any CPU {DD7629D4-2E9D-4192-B4A5-06F49B9D5B5E}.Debug|x64.ActiveCfg = Debug|Any CPU @@ -97,6 +125,7 @@ Global GlobalSection(NestedProjects) = preSolution {A52CDEFF-F507-4EA2-B5B1-AE46A3AFCE95} = {CBD9B889-7168-4D7F-898F-3111EABC28DE} {6CE9CBC4-2626-4464-A92E-35E92D5284B4} = {CBD9B889-7168-4D7F-898F-3111EABC28DE} + {C02A68C1-39E3-45A5-B7C3-CB24374B4B43} = {CBD9B889-7168-4D7F-898F-3111EABC28DE} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {68C5AD77-58F3-4E3F-8238-73ADAB7059C2} diff --git a/PROJECT_FILES_CONTENT.txt b/PROJECT_FILES_CONTENT.txt new file mode 100644 index 0000000..8262d0a --- /dev/null +++ b/PROJECT_FILES_CONTENT.txt @@ -0,0 +1,23893 @@ +=== DIRECTORY: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer === + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Common\Constants.cs ---namespace N.EntityFrameworkCore.Extensions.Common;public static class Constants +{ + public static readonly string InternalId_ColumnName = "_be_xx_id"; +}--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\BulkDeleteOptions.cs ---using System; +using System.Linq.Expressions;namespace N.EntityFrameworkCore.Extensions;public class BulkDeleteOptions : BulkOptions +{ + public Expression> DeleteOnCondition { get; set; } +} +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\BulkFetchOptions.cs ---using System; +using System.Linq.Expressions;namespace N.EntityFrameworkCore.Extensions;public class BulkFetchOptions : BulkOptions +{ + public Expression> IgnoreColumns { get; set; } + public Expression> InputColumns { get; set; } + public Expression> JoinOnCondition { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\BulkInsertO +Options.cs --- + +using System; +using System.Linq; +using System.Linq.Expressions; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkInsertOptions : BulkOptions +{ + public bool AutoMapOutput { get; set; } + public Expression> IgnoreColumns { get; set; } + public Expression> InputColumns { get; set; } + public bool InsertIfNotExists { get; set; } + public Expression> InsertOnCondition { get; set; } + public bool KeepIdentity { get; set; } + + public string[] GetInputColumns() => + InputColumns?.Body.Type.GetProperties().Select(o => o.Name).ToArray(); + + public BulkInsertOptions() + { + AutoMapOutput = true; + } + internal BulkInsertOptions(BulkOptions options) + { + EntityType = options.EntityType; + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\BulkInsertR +Result.cs --- + +using System.Collections.Generic; + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed class BulkInsertResult +{ + internal int RowsAffected { get; set; } + internal Dictionary EntityMap { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\BulkMergeOp +ption.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkMergeOptions : BulkOptions +{ + public Expression> MergeOnCondition { get; set; } + public Expression> IgnoreColumnsOnInsert { get; set; } + public Expression> IgnoreColumnsOnUpdate { get; set; } + public bool AutoMapOutput { get; set; } + internal bool DeleteIfNotMatched { get; set; } + + public BulkMergeOptions() + { + AutoMapOutput = true; + } + public List GetIgnoreColumnsOnInsert() => + IgnoreColumnsOnInsert?.Body.Type.GetProperties().Select(o => o.Name).ToList() ?? []; + public List GetIgnoreColumnsOnUpdate() => + IgnoreColumnsOnUpdate?.Body.Type.GetProperties().Select(o => o.Name).ToList() ?? []; +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\BulkMergeOu +utputRow.cs --- + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkMergeOutputRow +{ + public string Action { get; set; } + + public BulkMergeOutputRow(string action) + { + Action = action; + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\BulkMergeRe +esult.cs --- + +using System.Collections.Generic; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkMergeResult +{ + public IEnumerable> Output { get; set; } + public int RowsAffected { get; set; } + public int RowsDeleted { get; internal set; } + public int RowsInserted { get; internal set; } + public int RowsUpdated { get; internal set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\BulkOperati +ion.cs --- + +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.IO; +using System.Linq; +using System.Linq.Expressions; +using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Metadata; +using N.EntityFrameworkCore.Extensions.Common; +using N.EntityFrameworkCore.Extensions.Sql; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed partial class BulkOperation : IDisposable +{ + internal DbConnection Connection => DbTransactionContext.Connection; + internal DbContext Context { get; } + internal bool StagingTableCreated { get; set; } + internal string StagingTableName { get; } + internal string[] PrimaryKeyColumnNames { get; } + internal BulkOptions Options { get; } + internal Expression> InputColumns { get; } + internal Expression> IgnoreColumns { get; } + internal DbTransactionContext DbTransactionContext { get; } + internal Type EntityType => typeof(T); + internal DbTransaction Transaction => DbTransactionContext.CurrentTransaction; + internal TableMapping TableMapping { get; } + internal IEnumerable SchemaQualifiedTableNames => TableMapping.GetSchemaQualifiedTableNames(); + + public BulkOperation(DbContext dbContext, BulkOptions options, Expression> inputColumns = null, Expr +ression> ignoreColumns = null) + { + Context = dbContext; + Options = options; + InputColumns = inputColumns; + IgnoreColumns = ignoreColumns; + + DbTransactionContext = new DbTransactionContext(dbContext, options.CommandTimeout); + TableMapping = dbContext.GetTableMapping(typeof(T), options.EntityType); + StagingTableName = CommonUtil.GetStagingTableName(TableMapping, options.UsePermanentTable, Connection); + PrimaryKeyColumnNames = TableMapping.GetPrimaryKeyColumns().ToArray(); + } + public void Dispose() + { + if (StagingTableCreated) + Context.Database.DropTable(StagingTableName, true); + } + internal BulkInsertResult BulkInsertStagingData(IEnumerable entities, bool keepIdentity = true, bool useIntern +nalId = false) + { + IEnumerable columnsToInsert = GetColumnNames(keepIdentity); + string internalIdColumn = useInternalId ? Common.Constants.InternalId_ColumnName : null; + Context.Database.CloneTable(SchemaQualifiedTableNames, StagingTableName, TableMapping.GetQualifiedColumnNames(co +olumnsToInsert), internalIdColumn); + StagingTableCreated = true; + return DbContextExtensions.BulkInsert(entities, Options, TableMapping, Connection, Transaction, StagingTableName +e, columnsToInsert, SqlBulkCopyOptions.KeepIdentity, useInternalId); + } + internal BulkMergeResult ExecuteMerge(Dictionary entityMap, Expression> mergeOnConditio +on, + bool autoMapOutput, bool keepIdentity, bool insertIfNotExists, bool update = false, bool delete = false) + { + Dictionary rowsInserted = []; + Dictionary rowsUpdated = []; + Dictionary rowsDeleted = []; + Dictionary rowsAffected = []; + List> outputRows = []; + + foreach (var entityType in TableMapping.EntityTypes) + { + rowsInserted[entityType] = 0; + rowsUpdated[entityType] = 0; + rowsDeleted[entityType] = 0; + rowsAffected[entityType] = 0; + + var columnsToInsert = GetColumnNames(entityType, keepIdentity); + var columnsToUpdate = update ? GetColumnNames(entityType) : []; + var autoGeneratedColumns = autoMapOutput ? TableMapping.GetAutoGeneratedColumns(entityType) : []; + var columnsToOutput = autoMapOutput ? GetMergeOutputColumns(autoGeneratedColumns, delete) : []; + var deleteEntityType = TableMapping.EntityType == entityType && delete; + + string mergeOnConditionSql = insertIfNotExists ? CommonUtil.GetJoinConditionSql(mergeOnCondition, Primary +yKeyColumnNames, "t", "s") : "1=2"; + bool toggleIdentity = keepIdentity && TableMapping.HasIdentityColumn; + var mergeStatement = SqlStatement.CreateMerge(StagingTableName, entityType.GetSchemaQualifiedTableName(), + mergeOnConditionSql, columnsToInsert, columnsToUpdate, columnsToOutput, deleteEntityType, toggleIdentity +y); + + if (autoMapOutput) + { + List allProperties = + [ + .. TableMapping.GetEntityProperties(entityType, ValueGenerated.OnAdd).ToArray(), + .. TableMapping.GetEntityProperties(entityType, ValueGenerated.OnAddOrUpdate).ToArray() + ]; + + var bulkQueryResult = Context.BulkQuery(mergeStatement.Sql, Options); + rowsAffected[entityType] = bulkQueryResult.RowsAffected; + + foreach (var result in bulkQueryResult.Results) + { + string action = (string)result[0]; + outputRows.Add(new BulkMergeOutputRow(action)); + + if (action == SqlMergeAction.Delete) + { + rowsDeleted[entityType]++; + } + else + { + int entityId = (int)result[1]; + var entity = entityMap[entityId]; + if (action == SqlMergeAction.Insert) + { + rowsInserted[entityType]++; + if (allProperties.Count != 0) + { + var entityValues = GetMergeOutputValues(columnsToOutput, result, allProperties); + Context.SetStoreGeneratedValues(entity, allProperties, entityValues); + } + } + else if (action == SqlMergeAction.Update) + { + rowsUpdated[entityType]++; + if (allProperties.Count != 0) + { + var entityValues = GetMergeOutputValues(columnsToOutput, result, allProperties); + Context.SetStoreGeneratedValues(entity, allProperties, entityValues); + } + } + } + } + } + else + { + rowsAffected[entityType] = Context.Database.ExecuteSqlInternal(mergeStatement.Sql, Options.CommandTimeou +ut); + } + } + return new BulkMergeResult + { + Output = outputRows, + RowsAffected = rowsAffected.Values.LastOrDefault(), + RowsDeleted = rowsDeleted.Values.LastOrDefault(), + RowsInserted = rowsInserted.Values.LastOrDefault(), + RowsUpdated = rowsUpdated.Values.LastOrDefault() + }; + } + + private IEnumerable GetMergeOutputColumns(IEnumerable autoGeneratedColumns, bool delete = false) + { + List columnsToOutput = ["$action", $"[s].[{Constants.InternalId_ColumnName}]"]; + columnsToOutput.AddRange(autoGeneratedColumns.Select(o => $"[inserted].[{o}]")); + return columnsToOutput; + } + private object[] GetMergeOutputValues(IEnumerable columns, object[] values, IEnumerable propertie +es) + { + var columnList = columns.ToList(); + var valuesIndex = properties.Select(o => columnList.IndexOf($"[inserted].[{o.GetColumnName()}]")); + return valuesIndex.Select(i => values[i]).ToArray(); + } + internal int ExecuteUpdate(IEnumerable entities, Expression> updateOnCondition) + { + int rowsUpdated = 0; + foreach (var entityType in TableMapping.EntityTypes) + { + IEnumerable columnsToUpdate = CommonUtil.FormatColumns(GetColumnNames(entityType)); + string updateSetExpression = string.Join(",", columnsToUpdate.Select(o => $"t.{o}=s.{o}")); + string updateSql = $"UPDATE t SET {updateSetExpression} FROM {StagingTableName} AS s JOIN {CommonUtil.Format +tTableName(entityType.GetSchemaQualifiedTableName())} AS t ON {CommonUtil.GetJoinConditionSql(updateOnCondition, Prima +aryKeyColumnNames, "s", "t")}; SELECT @@RowCount;"; + rowsUpdated = Context.Database.ExecuteSqlInternal(updateSql, Options.CommandTimeout); + } + return rowsUpdated; + } + internal void ValidateBulkMerge(Expression> mergeOnCondition) + { + if (PrimaryKeyColumnNames.Length == 0 && mergeOnCondition == null) + throw new InvalidDataException("BulkMerge requires that the entity have a primary key or that Options.MergeO +OnCondition be set"); + } + internal void ValidateBulkUpdate(Expression> updateOnCondition) + { + if (PrimaryKeyColumnNames.Length == 0 && updateOnCondition == null) + throw new InvalidDataException("BulkUpdate requires that the entity have a primary key or the Options.Update +eOnCondition must be set."); + } + internal IEnumerable GetColumnNames(bool includePrimaryKeys = false) + { + return GetColumnNames(null, includePrimaryKeys); + } + internal IEnumerable GetColumnNames(IEntityType entityType, bool includePrimaryKeys = false) + { + return CommonUtil.FilterColumns(TableMapping.GetColumnNames(entityType, includePrimaryKeys), PrimaryKeyColumnNam +mes, InputColumns, IgnoreColumns); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\BulkOperati +ionAsync.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Metadata; +using N.EntityFrameworkCore.Extensions.Common; +using N.EntityFrameworkCore.Extensions.Sql; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed partial class BulkOperation +{ + internal async Task> BulkInsertStagingDataAsync(IEnumerable entities, bool keepIdentity = tru +ue, bool useInternalId = false, CancellationToken cancellationToken = default) + { + IEnumerable columnsToInsert = GetColumnNames(keepIdentity); + string internalIdColumn = useInternalId ? Common.Constants.InternalId_ColumnName : null; + await Context.Database.CloneTableAsync(SchemaQualifiedTableNames, StagingTableName, TableMapping.GetQualifiedCol +lumnNames(columnsToInsert), internalIdColumn, cancellationToken); + StagingTableCreated = true; + return await DbContextExtensionsAsync.BulkInsertAsync(entities, Options, TableMapping, Connection, Transaction, + StagingTableName, columnsToInsert, SqlBulkCopyOptions.KeepIdentity, useInternalId, cancellationToken); + } + + internal async Task> ExecuteMergeAsync(Dictionary entityMap, Expression +>> mergeOnCondition, + bool autoMapOutput, bool keepIdentity, bool insertIfNotExists, bool update = false, bool delete = false, Cancell +lationToken cancellationToken = default) + { + Dictionary rowsInserted = []; + Dictionary rowsUpdated = []; + Dictionary rowsDeleted = []; + Dictionary rowsAffected = []; + List> outputRows = []; + + foreach (var entityType in TableMapping.EntityTypes) + { + rowsInserted[entityType] = 0; + rowsUpdated[entityType] = 0; + rowsDeleted[entityType] = 0; + rowsAffected[entityType] = 0; + + var columnsToInsert = GetColumnNames(entityType, keepIdentity); + var columnsToUpdate = update ? GetColumnNames(entityType) : []; + var autoGeneratedColumns = autoMapOutput ? TableMapping.GetAutoGeneratedColumns(entityType) : []; + var columnsToOutput = autoMapOutput ? GetMergeOutputColumns(autoGeneratedColumns, delete) : []; + var deleteEntityType = TableMapping.EntityType == entityType && delete; + + string mergeOnConditionSql = insertIfNotExists ? CommonUtil.GetJoinConditionSql(mergeOnCondition, Primary +yKeyColumnNames, "t", "s") : "1=2"; + bool toggleIdentity = keepIdentity && TableMapping.HasIdentityColumn; + var mergeStatement = SqlStatement.CreateMerge(StagingTableName, entityType.GetSchemaQualifiedTableName(), + mergeOnConditionSql, columnsToInsert, columnsToUpdate, columnsToOutput, deleteEntityType, toggleIdentity +y); + + if (autoMapOutput) + { + List allProperties = + [ + .. TableMapping.GetEntityProperties(entityType, ValueGenerated.OnAdd).ToArray(), + .. TableMapping.GetEntityProperties(entityType, ValueGenerated.OnAddOrUpdate).ToArray() + ]; + + var bulkQueryResult = await Context.BulkQueryAsync(mergeStatement.Sql, Connection, Transaction, Options, +, cancellationToken); + rowsAffected[entityType] = bulkQueryResult.RowsAffected; + + foreach (var result in bulkQueryResult.Results) + { + string action = (string)result[0]; + outputRows.Add(new BulkMergeOutputRow(action)); + + if (action == SqlMergeAction.Delete) + { + rowsDeleted[entityType]++; + } + else + { + int entityId = (int)result[1]; + var entity = entityMap[entityId]; + if (action == SqlMergeAction.Insert) + { + rowsInserted[entityType]++; + if (allProperties.Count != 0) + { + var entityValues = GetMergeOutputValues(columnsToOutput, result, allProperties); + Context.SetStoreGeneratedValues(entity, allProperties, entityValues); + } + } + else if (action == SqlMergeAction.Update) + { + rowsUpdated[entityType]++; + if (allProperties.Count != 0) + { + var entityValues = GetMergeOutputValues(columnsToOutput, result, allProperties); + Context.SetStoreGeneratedValues(entity, allProperties, entityValues); + } + } + } + } + } + else + { + rowsAffected[entityType] = await Context.Database.ExecuteSqlAsync(mergeStatement.Sql, Options.CommandTim +meout, cancellationToken); + } + } + return new BulkMergeResult + { + Output = outputRows, + RowsAffected = rowsAffected.Values.LastOrDefault(), + RowsDeleted = rowsDeleted.Values.LastOrDefault(), + RowsInserted = rowsInserted.Values.LastOrDefault(), + RowsUpdated = rowsUpdated.Values.LastOrDefault() + }; + } + internal async Task ExecuteUpdateAsync(IEnumerable entities, Expression> updateOnCondition, +, CancellationToken cancellationToken = default) + { + int rowsUpdated = 0; + foreach (var entityType in TableMapping.EntityTypes) + { + IEnumerable columnsToUpdate = CommonUtil.FormatColumns(GetColumnNames(entityType)); + string updateSetExpression = string.Join(",", columnsToUpdate.Select(o => $"t.{o}=s.{o}")); + string updateSql = $"UPDATE t SET {updateSetExpression} FROM {StagingTableName} AS s JOIN {CommonUtil.Format +tTableName(entityType.GetSchemaQualifiedTableName())} AS t ON {CommonUtil.GetJoinConditionSql(updateOnCondition, Prima +aryKeyColumnNames, "s", "t")}; SELECT @@RowCount;"; + rowsUpdated = await Context.Database.ExecuteSqlAsync(updateSql, Options.CommandTimeout, cancellationToken); + } + return rowsUpdated; + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\BulkOptions +s.cs --- + +using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore.Metadata; +using N.EntityFrameworkCore.Extensions.Enums; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkOptions +{ + public int BatchSize { get; set; } + public SqlBulkCopyOptions BulkCopyOptions { get; internal set; } + public SqlBulkCopyColumnOrderHintCollection ColumnOrderHints { get; internal set; } + public bool EnableStreaming { get; internal set; } + public int NotifyAfter { get; internal set; } + public bool UsePermanentTable { get; set; } + public int? CommandTimeout { get; set; } + internal ConnectionBehavior ConnectionBehavior { get; set; } + internal IEntityType EntityType { get; set; } + + public SqlRowsCopiedEventHandler SqlRowsCopied { get; internal set; } + + public BulkOptions() + { + BulkCopyOptions = SqlBulkCopyOptions.Default; + ColumnOrderHints = new SqlBulkCopyColumnOrderHintCollection(); + ConnectionBehavior = ConnectionBehavior.Default; + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\BulkQueryRe +esult.cs --- + +using System.Collections.Generic; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkQueryResult +{ + public IEnumerable Results { get; internal set; } + public IEnumerable Columns { get; internal set; } + public int RowsAffected { get; internal set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\BulkSyncOpt +tions.cs --- + + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkSyncOptions : BulkMergeOptions +{ + public BulkSyncOptions() + { + DeleteIfNotMatched = true; + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\BulkSyncRes +sult.cs --- + + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkSyncResult : BulkMergeResult +{ + public new int RowsDeleted { get; set; } + public static BulkSyncResult Map(BulkMergeResult result) + { + return new BulkSyncResult() + { + Output = result.Output, + RowsAffected = result.RowsAffected, + RowsDeleted = result.RowsDeleted, + RowsInserted = result.RowsInserted, + RowsUpdated = result.RowsUpdated + }; + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\BulkUpdateO +Options.cs --- + +using System; +using System.Linq.Expressions; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkUpdateOptions : BulkOptions +{ + public Expression> InputColumns { get; set; } + public Expression> IgnoreColumns { get; set; } + public Expression> UpdateOnCondition { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\DatabaseFac +cadeExtensions.cs --- + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Linq; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Storage; +using N.EntityFrameworkCore.Extensions.Enums; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +public static class DatabaseFacadeExtensions +{ + public static SqlQuery FromSqlQuery(this DatabaseFacade database, string sqlText, params object[] parameters) + { + return new SqlQuery(database, sqlText, parameters); + } + public static int ClearTable(this DatabaseFacade database, string tableName) + { + return database.ExecuteSqlRaw($"DELETE FROM {database.DelimitTableName(tableName)}"); + } + public static int DropTable(this DatabaseFacade database, string tableName, bool ifExists = false) + { + string formattedTableName = database.DelimitTableName(tableName); + string sql = ifExists ? $"DROP TABLE IF EXISTS {formattedTableName}" : $"DROP TABLE {formattedTableName}"; + return database.ExecuteSqlInternal(sql, null, ConnectionBehavior.Default); + } + public static void TruncateTable(this DatabaseFacade database, string tableName, bool ifExists = false) + { + bool truncateTable = !ifExists || database.TableExists(tableName); + if (!truncateTable) + return; + + string formattedTableName = database.DelimitTableName(tableName); + database.ExecuteSqlRaw($"TRUNCATE TABLE {formattedTableName}"); + } + public static bool TableExists(this DatabaseFacade database, string tableName) + { + var objectName = database.ParseObjectName(tableName); + return Convert.ToBoolean(database.ExecuteScalar( + "SELECT CASE WHEN EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = @schema AND TABLE_NAM +ME = @name) THEN 1 ELSE 0 END", + [CreateParameter(database, "@schema", objectName.Schema), CreateParameter(database, "@name", objectName.Name +e)])); + } + public static bool TableHasIdentity(this DatabaseFacade database, string tableName) + { + var objectName = database.ParseObjectName(tableName); + return Convert.ToBoolean(database.ExecuteScalar("SELECT ISNULL(OBJECTPROPERTY(OBJECT_ID(@fullName), 'TableHasIde +entity'), 0)", + [CreateParameter(database, "@fullName", $"{objectName.Schema}.{objectName.Name}")])); + } + internal static int CloneTable(this DatabaseFacade database, string sourceTable, string destinationTable, IEnumerabl +le columnNames, string internalIdColumnName = null) + { + return database.CloneTable([sourceTable], destinationTable, columnNames, internalIdColumnName); + } + internal static int CloneTable(this DatabaseFacade database, IEnumerable sourceTables, string destinationTab +ble, IEnumerable columnNames, string internalIdColumnName = null) + { + string columns = columnNames != null && columnNames.Any() ? string.Join(",", columnNames.Select(database.FormatS +SelectColumn)) : "*"; + if (!string.IsNullOrEmpty(internalIdColumnName)) + columns = $"{columns},CAST(NULL AS INT) AS {database.DelimitIdentifier(internalIdColumnName)}"; + + return database.ExecuteSqlRaw($"SELECT TOP 0 {columns} INTO {destinationTable} FROM {string.Join(",", sourceTabl +les)}"); + } + internal static DbCommand CreateCommand(this DatabaseFacade database, ConnectionBehavior connectionBehavior = Connec +ctionBehavior.Default) + { + var dbConnection = database.GetDbConnection(connectionBehavior); + if (dbConnection.State != ConnectionState.Open) + dbConnection.Open(); + var command = dbConnection.CreateCommand(); + if (database.CurrentTransaction != null && connectionBehavior == ConnectionBehavior.Default) + command.Transaction = database.CurrentTransaction.GetDbTransaction(); + return command; + } + internal static int ExecuteSqlInternal(this DatabaseFacade database, string sql, int? commandTimeout = null, Connect +tionBehavior connectionBehavior = default) + { + return database.ExecuteSql(sql, null, commandTimeout, connectionBehavior); + } + internal static int ExecuteSql(this DatabaseFacade database, string sql, object[] parameters = null, int? commandTim +meout = null, ConnectionBehavior connectionBehavior = default) + { + using var command = database.CreateCommand(connectionBehavior); + command.CommandText = sql; + if (commandTimeout != null) + command.CommandTimeout = commandTimeout.Value; + if (parameters != null) + command.Parameters.AddRange(parameters); + return command.ExecuteNonQuery(); + } + internal static object ExecuteScalar(this DatabaseFacade database, string query, object[] parameters = null, int? co +ommandTimeout = null) + { + using var command = database.CreateCommand(); + command.CommandText = query; + if (commandTimeout.HasValue) + command.CommandTimeout = commandTimeout.Value; + if (parameters != null) + command.Parameters.AddRange(parameters); + return command.ExecuteScalar(); + } + internal static void ToggleIdentityInsert(this DatabaseFacade database, string tableName, bool enable) + { + bool hasIdentity = database.TableHasIdentity(tableName); + if (hasIdentity) + { + string boolString = enable ? "ON" : "OFF"; + database.ExecuteSql($"SET IDENTITY_INSERT {tableName} {boolString}"); + } + } + internal static DbConnection GetDbConnection(this DatabaseFacade database, ConnectionBehavior connectionBehavior) + { + return connectionBehavior == ConnectionBehavior.New ? database.GetDbConnection().CloneConnection() : database.Ge +etDbConnection(); + } + + private static DbParameter CreateParameter(DatabaseFacade database, string name, object value) + { + using var command = database.GetDbConnection().CreateCommand(); + var parameter = command.CreateParameter(); + parameter.ParameterName = name; + parameter.Value = value ?? DBNull.Value; + return parameter; + } + internal static string FormatSelectColumn(this DatabaseFacade database, string columnName) + { + if (columnName.Contains('[') || columnName.Contains('"') || columnName.Contains('(') || columnName.Contains(' ') +)) + return columnName; + + if (columnName.Contains('.')) + return string.Join(".", columnName.Split('.').Select(database.DelimitIdentifier)); + + return database.DelimitIdentifier(columnName); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\DatabaseFac +cadeExtensionsAsync.cs --- + +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +public static class DatabaseFacadeExtensionsAsync +{ + public static async Task ClearTableAsync(this DatabaseFacade database, string tableName, CancellationToken canc +cellationToken = default) + { + return await database.ExecuteSqlRawAsync($"DELETE FROM {database.DelimitTableName(tableName)}", cancellationToke +en); + } + public static async Task TruncateTableAsync(this DatabaseFacade database, string tableName, bool ifExists = false, C +CancellationToken cancellationToken = default) + { + bool truncateTable = !ifExists || database.TableExists(tableName); + if (!truncateTable) + return; + + string formattedTableName = database.DelimitTableName(tableName); + await database.ExecuteSqlRawAsync($"TRUNCATE TABLE {formattedTableName}", cancellationToken); + } + internal static async Task CloneTableAsync(this DatabaseFacade database, string sourceTable, string destination +nTable, IEnumerable columnNames, string internalIdColumnName = null, CancellationToken cancellationToken = defaul +lt) + { + return await database.CloneTableAsync([sourceTable], destinationTable, columnNames, internalIdColumnName, cancel +llationToken); + } + internal static async Task CloneTableAsync(this DatabaseFacade database, IEnumerable sourceTables, stri +ing destinationTable, IEnumerable columnNames, string internalIdColumnName = null, CancellationToken cancellation +nToken = default) + { + string columns = columnNames != null && columnNames.Any() ? string.Join(",", columnNames.Select(database.FormatS +SelectColumn)) : "*"; + if (!string.IsNullOrEmpty(internalIdColumnName)) + columns = $"{columns},CAST(NULL AS INT) AS {database.DelimitIdentifier(internalIdColumnName)}"; + + return await database.ExecuteSqlRawAsync($"SELECT TOP 0 {columns} INTO {destinationTable} FROM {string.Join(",", +, sourceTables)}", cancellationToken); + } + internal static async Task ExecuteSqlAsync(this DatabaseFacade database, string sql, int? commandTimeout = null +l, CancellationToken cancellationToken = default) + { + return await database.ExecuteSqlAsync(sql, null, commandTimeout, cancellationToken); + } + internal static async Task ExecuteSqlAsync(this DatabaseFacade database, string sql, object[] parameters = null +l, int? commandTimeout = null, CancellationToken cancellationToken = default) + { + int value; + int? origCommandTimeout = database.GetCommandTimeout(); + database.SetCommandTimeout(commandTimeout); + value = parameters != null + ? await database.ExecuteSqlRawAsync(sql, parameters, cancellationToken) + : await database.ExecuteSqlRawAsync(sql, cancellationToken); + database.SetCommandTimeout(origCommandTimeout); + return value; + } + internal static async Task ExecuteScalarAsync(this DatabaseFacade database, string query, object[] parameter +rs = null, int? commandTimeout = null, CancellationToken cancellationToken = default) + { + await using var command = database.CreateCommand(); + command.CommandText = query; + if (commandTimeout.HasValue) + command.CommandTimeout = commandTimeout.Value; + if (parameters != null) + command.Parameters.AddRange(parameters); + return await command.ExecuteScalarAsync(cancellationToken); + } + internal static async Task ToggleIdentityInsertAsync(this DatabaseFacade database, string tableName, bool enable) + { + bool hasIdentity = database.TableHasIdentity(tableName); + if (hasIdentity) + { + string boolString = enable ? "ON" : "OFF"; + await database.ExecuteSqlAsync($"SET IDENTITY_INSERT {tableName} {boolString}", database.GetCommandTimeout() +)); + } + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\DbContextEx +xtensions.cs --- + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.IO; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.Internal; +using N.EntityFrameworkCore.Extensions.Common; +using N.EntityFrameworkCore.Extensions.Enums; +using N.EntityFrameworkCore.Extensions.Extensions; +using N.EntityFrameworkCore.Extensions.Sql; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +public static class DbContextExtensions +{ + private static readonly EfExtensionsCommandInterceptor efExtensionsCommandInterceptor; + static DbContextExtensions() + { + efExtensionsCommandInterceptor = new EfExtensionsCommandInterceptor(); + } + public static void SetupEfCoreExtensions(this DbContextOptionsBuilder builder) + { + builder.AddInterceptors(efExtensionsCommandInterceptor); + } + public static int BulkDelete(this DbContext context, IEnumerable entities) + { + return context.BulkDelete(entities, new BulkDeleteOptions()); + } + public static int BulkDelete(this DbContext context, IEnumerable entities, Action> option +nsAction) + { + return context.BulkDelete(entities, optionsAction.Build()); + } + public static int BulkDelete(this DbContext context, IEnumerable entities, BulkDeleteOptions options) + { + var tableMapping = context.GetTableMapping(typeof(T), options.EntityType); + + using (var dbTransactionContext = new DbTransactionContext(context, options)) + { + var dbConnection = dbTransactionContext.Connection; + var transaction = dbTransactionContext.CurrentTransaction; + int rowsAffected = 0; + try + { + string stagingTableName = CommonUtil.GetStagingTableName(tableMapping, options.UsePermanentTable, dbConn +nection); + string destinationTableName = context.DelimitIdentifier(tableMapping.TableName, tableMapping.Schema); + string[] keyColumnNames = options.DeleteOnCondition != null ? CommonUtil.GetColumns(options.DeleteOnC +Condition, ["s"]) + : tableMapping.GetPrimaryKeyColumns().ToArray(); + + if (keyColumnNames.Length == 0 && options.DeleteOnCondition == null) + throw new InvalidDataException("BulkDelete requires that the entity have a primary key or the Option +ns.DeleteOnCondition must be set."); + + context.Database.CloneTable(destinationTableName, stagingTableName, keyColumnNames); + BulkInsert(entities, options, tableMapping, dbConnection, transaction, stagingTableName, keyColumnNames, +, SqlBulkCopyOptions.KeepIdentity, false); + + string joinCondition = CommonUtil.GetJoinConditionSql(context, options.DeleteOnCondition, keyColumnNa +ames); + string deleteSql = $"DELETE t FROM {stagingTableName} s JOIN {destinationTableName} t ON {joinCondition} +}"; + rowsAffected = context.Database.ExecuteSqlInternal(deleteSql, options.CommandTimeout); + + context.Database.DropTable(stagingTableName); + dbTransactionContext.Commit(); + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + return rowsAffected; + } + } + public static IEnumerable BulkFetch(this DbSet dbSet, IEnumerable entities) where T : class, new() + { + return dbSet.BulkFetch(entities, new BulkFetchOptions()); + } + public static IEnumerable BulkFetch(this DbSet dbSet, IEnumerable entities, Action> optionsAction) where T : class, new() + { + return dbSet.BulkFetch(entities, optionsAction.Build()); + } + public static IEnumerable BulkFetch(this DbSet dbSet, IEnumerable entities, BulkFetchOptions optio +ons) where T : class, new() + { + var context = dbSet.GetDbContext(); + var tableMapping = context.GetTableMapping(typeof(T)); + + using (var dbTransactionContext = new DbTransactionContext(context, options.CommandTimeout, ConnectionBehavior.N +New)) + { + string selectSql, stagingTableName = string.Empty; + var dbConnection = dbTransactionContext.Connection; + var transaction = dbTransactionContext.CurrentTransaction; + try + { + stagingTableName = CommonUtil.GetStagingTableName(tableMapping, true, dbConnection); + string destinationTableName = context.DelimitIdentifier(tableMapping.TableName, tableMapping.Schema); + string[] keyColumnNames = options.JoinOnCondition != null ? CommonUtil.GetColumns(options.JoinOnCondi +ition, ["s"]) + : tableMapping.GetPrimaryKeyColumns().ToArray(); + IEnumerable columnNames = CommonUtil.FilterColumns(tableMapping.GetColumns(true), keyColumnNa +ames, options.InputColumns, options.IgnoreColumns); + IEnumerable columnsToFetch = CommonUtil.FormatColumns(context, "t", columnNames); + + if (keyColumnNames.Length == 0 && options.JoinOnCondition == null) + throw new InvalidDataException("BulkFetch requires that the entity have a primary key or the Options +s.JoinOnCondition must be set."); + + context.Database.CloneTable(destinationTableName, stagingTableName, keyColumnNames); + BulkInsert(entities, options, tableMapping, dbConnection, transaction, stagingTableName, keyColumnNames, +, SqlBulkCopyOptions.KeepIdentity, false); + selectSql = $"SELECT {SqlUtil.ConvertToColumnString(columnsToFetch)} FROM {stagingTableName} s JOIN {des +stinationTableName} t ON {CommonUtil.GetJoinConditionSql(context, options.JoinOnCondition, keyColumnNames)}"; + + + dbTransactionContext.Commit(); + } + catch + { + dbTransactionContext.Rollback(); + throw; + } + + foreach (var item in context.FetchInternal(selectSql)) + { + yield return item; + } + context.Database.DropTable(stagingTableName); + } + } + public static void Fetch(this IQueryable queryable, Action> action, Action> opt +tionsAction) where T : class, new() + { + Fetch(queryable, action, optionsAction.Build()); + } + public static void Fetch(this IQueryable queryable, Action> action, FetchOptions options) wh +here T : class, new() + { + var dbContext = queryable.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + HashSet includedColumns = GetIncludedColumns(tableMapping, options.InputColumns, options.IgnoreColumns); + int batch = 1; + int count = 0; + List entities = []; + foreach (var entity in queryable.AsNoTracking().AsEnumerable()) + { + ClearExcludedColumns(dbContext, tableMapping, entity, includedColumns); + entities.Add(entity); + count++; + if (count == options.BatchSize) + { + action(new FetchResult { Results = entities, Batch = batch }); + entities.Clear(); + count = 0; + batch++; + } + } + + if (entities.Count > 0) + action(new FetchResult { Results = entities, Batch = batch }); + } + public static int BulkInsert(this DbContext context, IEnumerable entities) + { + return context.BulkInsert(entities, new BulkInsertOptions()); + } + public static int BulkInsert(this DbContext context, IEnumerable entities, Action> option +nsAction) + { + return context.BulkInsert(entities, optionsAction.Build()); + } + public static int BulkInsert(this DbContext context, IEnumerable entities, BulkInsertOptions options) + { + int rowsAffected = 0; + using (var bulkOperation = new BulkOperation(context, options, options.InputColumns, options.IgnoreColumns)) + { + try + { + var bulkInsertResult = bulkOperation.BulkInsertStagingData(entities, true, true); + var bulkMergeResult = bulkOperation.ExecuteMerge(bulkInsertResult.EntityMap, options.InsertOnCondition, + options.AutoMapOutput, options.KeepIdentity, options.InsertIfNotExists); + rowsAffected = bulkMergeResult.RowsAffected; + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return rowsAffected; + } + public static BulkMergeResult BulkMerge(this DbContext context, IEnumerable entities) + { + return BulkMerge(context, entities, new BulkMergeOptions()); + } + public static BulkMergeResult BulkMerge(this DbContext context, IEnumerable entities, BulkMergeOptions o +options) + { + return InternalBulkMerge(context, entities, options); + } + public static BulkMergeResult BulkMerge(this DbContext context, IEnumerable entities, Action> optionsAction) + { + return BulkMerge(context, entities, optionsAction.Build()); + } + public static int BulkSaveChanges(this DbContext dbContext) + { + return dbContext.BulkSaveChanges(true); + } + public static int BulkSaveChanges(this DbContext dbContext, bool acceptAllChangesOnSuccess = true) + { + int rowsAffected = 0; + var stateManager = dbContext.GetDependencies().StateManager; + + dbContext.ChangeTracker.DetectChanges(); + var entries = stateManager.GetEntriesToSave(true); + + foreach (var saveEntryGroup in entries.GroupBy(o => new { o.EntityType, o.EntityState })) + { + var key = saveEntryGroup.Key; + var entities = saveEntryGroup.AsEnumerable(); + if (key.EntityState == EntityState.Added) + { + rowsAffected += dbContext.BulkInsert(entities, o => { o.EntityType = key.EntityType; }); + } + else if (key.EntityState == EntityState.Modified) + { + rowsAffected += dbContext.BulkUpdate(entities, o => { o.EntityType = key.EntityType; }); + } + else if (key.EntityState == EntityState.Deleted) + { + rowsAffected += dbContext.BulkDelete(entities, o => { o.EntityType = key.EntityType; }); + } + } + + if (acceptAllChangesOnSuccess) + dbContext.ChangeTracker.AcceptAllChanges(); + + return rowsAffected; + } + public static BulkSyncResult BulkSync(this DbContext context, IEnumerable entities) + { + return BulkSync(context, entities, new BulkSyncOptions()); + } + public static BulkSyncResult BulkSync(this DbContext context, IEnumerable entities, Action> optionsAction) + { + return BulkSyncResult.Map(InternalBulkMerge(context, entities, optionsAction.Build())); + } + public static BulkSyncResult BulkSync(this DbContext context, IEnumerable entities, BulkSyncOptions opti +ions) + { + return BulkSyncResult.Map(InternalBulkMerge(context, entities, options)); + } + public static int BulkUpdate(this DbContext context, IEnumerable entities) + { + return BulkUpdate(context, entities, new BulkUpdateOptions()); + } + public static int BulkUpdate(this DbContext context, IEnumerable entities, Action> option +nsAction) + { + return BulkUpdate(context, entities, optionsAction.Build()); + } + public static int BulkUpdate(this DbContext context, IEnumerable entities, BulkUpdateOptions options) + { + int rowsUpdated = 0; + using (var bulkOperation = new BulkOperation(context, options, options.InputColumns, options.IgnoreColumns)) + { + try + { + bulkOperation.ValidateBulkUpdate(options.UpdateOnCondition); + bulkOperation.BulkInsertStagingData(entities); + rowsUpdated = bulkOperation.ExecuteUpdate(entities, options.UpdateOnCondition); + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return rowsUpdated; + } + public static int DeleteFromQuery(this IQueryable queryable, int? commandTimeout = null) where T : class + { + using (var dbTransactionContext = new DbTransactionContext(queryable.GetDbContext(), commandTimeout)) + { + try + { + int rowsAffected = queryable.ExecuteDelete(); + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static int InsertFromQuery(this IQueryable queryable, string tableName, Expression> ins +sertObjectExpression, int? commandTimeout = null) where T : class + { + using (var dbTransactionContext = new DbTransactionContext(queryable.GetDbContext(), commandTimeout)) + { + var dbContext = dbTransactionContext.DbContext; + try + { + var tableMapping = dbContext.GetTableMapping(typeof(T)); + var columnNames = insertObjectExpression.GetObjectProperties(); + if (!dbContext.Database.TableExists(tableName)) + { + dbContext.Database.CloneTable(tableMapping.FullQualifedTableName, dbContext.Database.DelimitTableNam +me(tableName), tableMapping.GetQualifiedColumnNames(columnNames)); + } + + var entities = queryable.AsNoTracking().ToList(); + int rowsAffected = BulkInsert(entities, new BulkInsertOptions { KeepIdentity = true, AutoMapOutput = + false, CommandTimeout = commandTimeout }, tableMapping, + dbTransactionContext.Connection, dbTransactionContext.CurrentTransaction, dbContext.Database.Delimit +tTableName(tableName), columnNames, SqlBulkCopyOptions.KeepIdentity).RowsAffected; + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static int UpdateFromQuery(this IQueryable queryable, Expression> updateExpression, int? com +mmandTimeout = null) where T : class + { + using (var dbTransactionContext = new DbTransactionContext(queryable.GetDbContext(), commandTimeout)) + { + try + { + int rowsAffected = queryable.ExecuteUpdate(BuildSetPropertyCalls(updateExpression)); + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, string filePath) where T : class + { + return QueryToCsvFile(queryable, filePath, new QueryToFileOptions()); + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, Stream stream) where T : class + { + return QueryToCsvFile(queryable, stream, new QueryToFileOptions()); + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, string filePath, Action optionsAction) where T : class + { + return QueryToCsvFile(queryable, filePath, optionsAction.Build()); + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, Stream stream, Action optionsAction) where T : class + { + return QueryToCsvFile(queryable, stream, optionsAction.Build()); + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, string filePath, QueryToFileOptions + options) where T : class + { + using var fileStream = File.Create(filePath); + return QueryToCsvFile(queryable, fileStream, options); + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, Stream stream, QueryToFileOptions op +ptions) where T : class + { + return InternalQueryToFile(queryable, stream, options); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, string filePath, string sqlText, par +rams object[] parameters) + { + return SqlQueryToCsvFile(database, filePath, new QueryToFileOptions(), sqlText, parameters); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, Stream stream, string sqlText, param +ms object[] parameters) + { + return SqlQueryToCsvFile(database, stream, new QueryToFileOptions(), sqlText, parameters); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, string filePath, Action optionsAction, string sqlText, params object[] parameters) + { + return SqlQueryToCsvFile(database, filePath, optionsAction.Build(), sqlText, parameters); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, Stream stream, Action optionsAction, string sqlText, params object[] parameters) + { + return SqlQueryToCsvFile(database, stream, optionsAction.Build(), sqlText, parameters); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, string filePath, QueryToFileOptions + options, string sqlText, params object[] parameters) + { + using var fileStream = File.Create(filePath); + return SqlQueryToCsvFile(database, fileStream, options, sqlText, parameters); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, Stream stream, QueryToFileOptions op +ptions, string sqlText, params object[] parameters) + { + return InternalQueryToFile(database.GetDbConnection(), stream, options, sqlText, parameters); + } + public static void Clear(this DbSet dbSet) where T : class + { + var dbContext = dbSet.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + dbContext.Database.ClearTable(tableMapping.FullQualifedTableName); + } + public static void Truncate(this DbSet dbSet) where T : class + { + var dbContext = dbSet.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + dbContext.Database.TruncateTable(tableMapping.FullQualifedTableName); + } + public static IQueryable UsingTable(this IQueryable queryable, string tableName) where T : class + { + var dbContext = queryable.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + efExtensionsCommandInterceptor.AddCommand(Guid.NewGuid(), + new EfExtensionsCommand + { + CommandType = EfExtensionsCommandType.ChangeTableName, + OldValue = tableMapping.FullQualifedTableName, + NewValue = dbContext.Database.DelimitTableName(tableName), + Connection = dbContext.GetDbConnection() + }); + return queryable; + } + public static TableMapping GetTableMapping(this DbContext dbContext, Type type, IEntityType entityType = null) + { + entityType ??= dbContext.Model.FindEntityType(type); + return new TableMapping(dbContext, entityType); + } + internal static void SetStoreGeneratedValues(this DbContext context, T entity, IEnumerable properties, +, object[] values) + { + int index = 0; + var updateEntry = entity as InternalEntityEntry; + if (updateEntry == null) + { + var entry = context.Entry(entity); + updateEntry = entry.GetInfrastructure(); + } + + if (updateEntry != null) + { + foreach (var property in properties) + { + if ((updateEntry.EntityState == EntityState.Added && + (property.ValueGenerated == ValueGenerated.OnAdd || property.ValueGenerated == ValueGenerated.OnAddO +OrUpdate)) || + (updateEntry.EntityState == EntityState.Modified && + (property.ValueGenerated == ValueGenerated.OnUpdate || property.ValueGenerated == ValueGenerated.OnA +AddOrUpdate)) || + updateEntry.EntityState == EntityState.Detached + ) + { + updateEntry.SetStoreGeneratedValue(property, values[index]); + } + index++; + } + if (updateEntry.EntityState == EntityState.Detached) + updateEntry.AcceptChanges(); + } + else + { + throw new InvalidOperationException("SetStoreValues() failed because an instance of InternalEntityEntry was + not found."); + } + } + internal static BulkInsertResult BulkInsert(IEnumerable entities, BulkOptions options, TableMapping tableMa +apping, DbConnection dbConnection, DbTransaction transaction, string tableName, + IEnumerable inputColumns = null, SqlBulkCopyOptions bulkCopyOptions = SqlBulkCopyOptions.Default, bool u +useInternalId = false) + { + using var dataReader = new EntityDataReader(tableMapping, entities, useInternalId); + var sqlBulkCopy = new SqlBulkCopy((SqlConnection)dbConnection, bulkCopyOptions | options.BulkCopyOptions, (SqlTr +ransaction)transaction) + { + DestinationTableName = tableName, + BatchSize = options.BatchSize, + NotifyAfter = options.NotifyAfter, + EnableStreaming = options.EnableStreaming, + }; + sqlBulkCopy.BulkCopyTimeout = options.CommandTimeout.HasValue ? options.CommandTimeout.Value : sqlBulkCopy.BulkC +CopyTimeout; + if (options.SqlRowsCopied != null) + sqlBulkCopy.SqlRowsCopied += options.SqlRowsCopied; + foreach (SqlBulkCopyColumnOrderHint columnOrderHint in options.ColumnOrderHints) + sqlBulkCopy.ColumnOrderHints.Add(columnOrderHint); + foreach (var property in dataReader.TableMapping.Properties) + { + var columnName = dataReader.TableMapping.GetColumnName(property); + if (inputColumns == null || inputColumns.Contains(columnName)) + sqlBulkCopy.ColumnMappings.Add(columnName, columnName); + } + if (useInternalId) + sqlBulkCopy.ColumnMappings.Add(Constants.InternalId_ColumnName, Constants.InternalId_ColumnName); + sqlBulkCopy.WriteToServer(dataReader); + + return new BulkInsertResult + { + RowsAffected = sqlBulkCopy.RowsCopied, + EntityMap = dataReader.EntityMap + }; + } + internal static BulkQueryResult BulkQuery(this DbContext context, string sqlText, BulkOptions options) + { + List results = []; + List columns = []; + using var command = context.Database.CreateCommand(); + command.CommandText = sqlText; + if (options.CommandTimeout.HasValue) + command.CommandTimeout = options.CommandTimeout.Value; + using var reader = command.ExecuteReader(); + while (reader.Read()) + { + if (columns.Count == 0) + { + for (int i = 0; i < reader.FieldCount; i++) + columns.Add(reader.GetName(i)); + } + object[] values = new object[reader.FieldCount]; + reader.GetValues(values); + results.Add(values); + } + + return new BulkQueryResult + { + Columns = columns, + Results = results, + RowsAffected = reader.RecordsAffected + }; + } + internal static DbContext GetDbContext(this IQueryable queryable) where T : class + { + DbContext dbContext; + try + { + if ((queryable as InternalDbSet) != null) + { + dbContext = queryable.GetPrivateFieldValue("_context") as DbContext; + } + else if ((queryable as EntityQueryable) != null) + { + var queryCompiler = queryable.Provider.GetPrivateFieldValue("_queryCompiler"); + var contextFactory = queryCompiler.GetPrivateFieldValue("_queryContextFactory"); + var queryDependencies = contextFactory.GetPrivateFieldValue("Dependencies") as QueryContextDependencies; + dbContext = queryDependencies.CurrentContext.Context as DbContext; + } + else + { + throw new Exception("This extension method could not find the DbContext for this type that implements IQ +Queryable"); + } + } + catch + { + throw new Exception("This extension method could not find the DbContext for this type that implements IQuery +yable"); + } + return dbContext; + } + internal static DbConnection GetDbConnection(this DbContext context, ConnectionBehavior connectionBehavior = Connect +tionBehavior.Default) + { + var dbConnection = context.Database.GetDbConnection(); + return connectionBehavior == ConnectionBehavior.New ? dbConnection.CloneConnection() : dbConnection; + } + private static IEnumerable FetchInternal(this DbContext dbContext, string sqlText, object[] parameters = null) +) where T : class, new() + { + using var command = dbContext.Database.CreateCommand(ConnectionBehavior.New); + command.CommandText = sqlText; + if (parameters != null) + command.Parameters.AddRange(parameters); + + var tableMapping = dbContext.GetTableMapping(typeof(T), null); + using var reader = command.ExecuteReader(); + var properties = reader.GetProperties(tableMapping); + var valuesFromProvider = properties.Select(p => tableMapping.GetValueFromProvider(p)).ToArray(); + + while (reader.Read()) + { + var entity = reader.MapEntity(dbContext, properties, valuesFromProvider); + yield return entity; + } + } + private static BulkMergeResult InternalBulkMerge(this DbContext context, IEnumerable entities, BulkMergeOpt +tions options) + { + BulkMergeResult bulkMergeResult; + using (var bulkOperation = new BulkOperation(context, options)) + { + try + { + bulkOperation.ValidateBulkMerge(options.MergeOnCondition); + var bulkInsertResult = bulkOperation.BulkInsertStagingData(entities, true, true); + bulkMergeResult = bulkOperation.ExecuteMerge(bulkInsertResult.EntityMap, options.MergeOnCondition, optio +ons.AutoMapOutput, + false, true, true, options.DeleteIfNotMatched); + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return bulkMergeResult; + } + private static void ClearEntityStateToUnchanged(DbContext dbContext, IEnumerable entities) + { + foreach (var entity in entities) + { + var entry = dbContext.Entry(entity); + if (entry.State == EntityState.Added || entry.State == EntityState.Modified) + dbContext.Entry(entity).State = EntityState.Unchanged; + } + } + private static void Validate(TableMapping tableMapping) + { + if (!tableMapping.GetPrimaryKeyColumns().Any()) + { + throw new Exception("You must have a primary key on this table to use this function."); + } + } + private static QueryToFileResult InternalQueryToFile(this IQueryable queryable, Stream stream, QueryToFileOpti +ions options) where T : class + { + return InternalQueryToFile(queryable.AsNoTracking().AsEnumerable(), stream, options); + } + private static QueryToFileResult InternalQueryToFile(DbConnection dbConnection, Stream stream, QueryToFileOptions op +ptions, string sqlText, object[] parameters = null) + { + int dataRowCount = 0; + int totalRowCount = 0; + long bytesWritten = 0; + + if (dbConnection.State == ConnectionState.Closed) + dbConnection.Open(); + + using var command = dbConnection.CreateCommand(); + command.CommandText = sqlText; + if (parameters != null) + command.Parameters.AddRange(parameters); + if (options.CommandTimeout.HasValue) + command.CommandTimeout = options.CommandTimeout.Value; + + using var streamWriter = new StreamWriter(stream, leaveOpen: true); + using (var reader = command.ExecuteReader()) + { + if (options.IncludeHeaderRow) + { + for (int i = 0; i < reader.FieldCount; i++) + { + streamWriter.Write(options.TextQualifer); + streamWriter.Write(reader.GetName(i)); + streamWriter.Write(options.TextQualifer); + if (i != reader.FieldCount - 1) + { + streamWriter.Write(options.ColumnDelimiter); + } + } + totalRowCount++; + streamWriter.Write(options.RowDelimiter); + } + while (reader.Read()) + { + object[] values = new object[reader.FieldCount]; + reader.GetValues(values); + for (int i = 0; i < values.Length; i++) + { + streamWriter.Write(options.TextQualifer); + streamWriter.Write(values[i]); + streamWriter.Write(options.TextQualifer); + if (i != values.Length - 1) + { + streamWriter.Write(options.ColumnDelimiter); + } + } + streamWriter.Write(options.RowDelimiter); + dataRowCount++; + totalRowCount++; + } + streamWriter.Flush(); + bytesWritten = streamWriter.BaseStream.Length; + } + return new QueryToFileResult() + { + BytesWritten = bytesWritten, + DataRowCount = dataRowCount, + TotalRowCount = totalRowCount + }; + } + private static QueryToFileResult InternalQueryToFile(IEnumerable entities, Stream stream, QueryToFileOptions o +options) + { + int dataRowCount = 0; + int totalRowCount = 0; + long bytesWritten = 0; + var properties = typeof(T).GetProperties().Where(p => p.CanRead && !typeof(System.Collections.IEnumerable).IsAss +signableFrom(p.PropertyType) || p.PropertyType == typeof(string)).ToArray(); + + using var streamWriter = new StreamWriter(stream, leaveOpen: true); + if (options.IncludeHeaderRow) + { + WriteCsvRow(streamWriter, properties.Select(p => p.Name), options); + totalRowCount++; + } + + foreach (var entity in entities) + { + WriteCsvRow(streamWriter, properties.Select(p => p.GetValue(entity)), options); + dataRowCount++; + totalRowCount++; + } + + streamWriter.Flush(); + bytesWritten = streamWriter.BaseStream.Length; + return new QueryToFileResult { BytesWritten = bytesWritten, DataRowCount = dataRowCount, TotalRowCount = totalRo +owCount }; + } + private static HashSet GetIncludedColumns(TableMapping tableMapping, Expression> inputCol +lumns, Expression> ignoreColumns) + { + var includedColumns = inputColumns != null + ? inputColumns.GetObjectProperties().ToHashSet() + : tableMapping.Properties.Select(p => p.Name).ToHashSet(); + + if (ignoreColumns != null) + includedColumns.ExceptWith(ignoreColumns.GetObjectProperties()); + + return includedColumns; + } + private static void ClearExcludedColumns(DbContext dbContext, TableMapping tableMapping, T entity, HashSet includedColumns) where T : class + { + var entry = dbContext.Entry(entity); + foreach (var property in tableMapping.Properties) + { + if (includedColumns.Contains(property.Name)) + continue; + + object defaultValue = property.ClrType.IsValueType ? Activator.CreateInstance(property.ClrType) : null; + if (property.DeclaringType is IComplexType complexType) + { + var complexProperty = entry.ComplexProperty(complexType.ComplexProperty); + if (complexProperty.CurrentValue != null) + complexProperty.Property(property).CurrentValue = defaultValue; + } + else + { + entry.Property(property.Name).CurrentValue = defaultValue; + } + } + } + private static void WriteCsvRow(TextWriter writer, IEnumerable values, QueryToFileOptions options) + { + bool first = true; + foreach (var value in values) + { + if (!first) + writer.Write(options.ColumnDelimiter); + + writer.Write(options.TextQualifer); + writer.Write(value); + writer.Write(options.TextQualifer); + first = false; + } + writer.Write(options.RowDelimiter); + } + private static Action> BuildSetPropertyCalls(Expression> updateExpression) whe +ere T : class + { + if (updateExpression.Body is not MemberInitExpression memberInitExpression) + throw new InvalidOperationException("UpdateFromQuery requires a member initialization expression."); + + var entityParameter = updateExpression.Parameters[0]; + var setPropertyMethod = typeof(UpdateSettersBuilder) + .GetMethods() + .Single(m => m.Name == nameof(UpdateSettersBuilder.SetProperty) && m.GetParameters().Length == 2 && m.Get +tParameters()[1].ParameterType.IsGenericType); + + return setters => + { + object currentBuilder = setters; + foreach (var binding in memberInitExpression.Bindings.OfType()) + { + var propertyInfo = binding.Member as PropertyInfo ?? throw new InvalidOperationException("Only property + bindings are supported."); + var propertyLambda = Expression.Lambda(Expression.Property(entityParameter, propertyInfo), entityParamet +ter); + var valueLambda = Expression.Lambda(binding.Expression, entityParameter); + currentBuilder = setPropertyMethod.MakeGenericMethod(propertyInfo.PropertyType).Invoke(currentBuilder, [ +[propertyLambda, valueLambda]); + } + }; + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\DbContextEx +xtensionsAsync.cs --- + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.IO; +using System.Linq; +using System.Linq.Expressions; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query; +using N.EntityFrameworkCore.Extensions.Common; +using N.EntityFrameworkCore.Extensions.Enums; +using N.EntityFrameworkCore.Extensions.Extensions; +using N.EntityFrameworkCore.Extensions.Sql; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +public static class DbContextExtensionsAsync +{ + public static async Task BulkDeleteAsync(this DbContext context, IEnumerable entities, CancellationToken + cancellationToken = default) + { + return await context.BulkDeleteAsync(entities, new BulkDeleteOptions(), cancellationToken); + } + public static async Task BulkDeleteAsync(this DbContext context, IEnumerable entities, Action> optionsAction, CancellationToken cancellationToken = default) + { + return await context.BulkDeleteAsync(entities, optionsAction.Build(), cancellationToken); + } + public static async Task BulkDeleteAsync(this DbContext context, IEnumerable entities, BulkDeleteOptions< + options, CancellationToken cancellationToken = default) + { + int rowsAffected = 0; + var tableMapping = context.GetTableMapping(typeof(T), options.EntityType); + + using (var dbTransactionContext = new DbTransactionContext(context, options)) + { + var dbConnection = dbTransactionContext.Connection; + var transaction = dbTransactionContext.CurrentTransaction; + try + { + string stagingTableName = CommonUtil.GetStagingTableName(tableMapping, options.UsePermanentTable, dbConn +nection); + string destinationTableName = context.DelimitIdentifier(tableMapping.TableName, tableMapping.Schema); + string[] keyColumnNames = options.DeleteOnCondition != null ? CommonUtil.GetColumns(options.DeleteOnC +Condition, ["s"]) + : tableMapping.GetPrimaryKeyColumns().ToArray(); + + if (keyColumnNames.Length == 0 && options.DeleteOnCondition == null) + throw new InvalidDataException("BulkDelete requires that the entity have a primary key or the Option +ns.DeleteOnCondition must be set."); + + await context.Database.CloneTableAsync(destinationTableName, stagingTableName, keyColumnNames, null, can +ncellationToken); + await BulkInsertAsync(entities, options, tableMapping, dbConnection, transaction, stagingTableName, keyC +ColumnNames, SqlBulkCopyOptions.KeepIdentity, + false, cancellationToken); + string joinCondition = CommonUtil.GetJoinConditionSql(context, options.DeleteOnCondition, keyColumnNa +ames); + string deleteSql = $"DELETE t FROM {stagingTableName} s JOIN {destinationTableName} t ON {joinCondition} +}"; + rowsAffected = await context.Database.ExecuteSqlRawAsync(deleteSql, cancellationToken); + context.Database.DropTable(stagingTableName); + dbTransactionContext.Commit(); + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + return rowsAffected; + } + } + public static async Task> BulkFetchAsync(this DbSet dbSet, IEnumerable entities, Cancella +ationToken cancellationToken = default) where T : class, new() + { + return await dbSet.BulkFetchAsync(entities, new BulkFetchOptions(), cancellationToken); + } + public static async Task> BulkFetchAsync(this DbSet dbSet, IEnumerable entities, Action> optionsAction, CancellationToken cancellationToken = default) where T : class, new() + { + return await dbSet.BulkFetchAsync(entities, optionsAction.Build(), cancellationToken); + } + public static async Task> BulkFetchAsync(this DbSet dbSet, IEnumerable entities, BulkFetc +chOptions options, CancellationToken cancellationToken = default) where T : class, new() + { + var context = dbSet.GetDbContext(); + var tableMapping = context.GetTableMapping(typeof(T)); + + using (var dbTransactionContext = new DbTransactionContext(context, options.CommandTimeout, ConnectionBehavior.N +New)) + { + string selectSql; + var dbConnection = dbTransactionContext.Connection; + var transaction = dbTransactionContext.CurrentTransaction; + string stagingTableName = string.Empty; + try + { + stagingTableName = CommonUtil.GetStagingTableName(tableMapping, true, dbConnection); + string destinationTableName = context.DelimitIdentifier(tableMapping.TableName, tableMapping.Schema); + string[] keyColumnNames = options.JoinOnCondition != null ? CommonUtil.GetColumns(options.JoinOnCondi +ition, ["s"]) + : tableMapping.GetPrimaryKeyColumns().ToArray(); + IEnumerable columnNames = CommonUtil.FilterColumns(tableMapping.GetColumns(true), keyColumnNa +ames, options.InputColumns, options.IgnoreColumns); + IEnumerable columnsToFetch = CommonUtil.FormatColumns(context, "t", columnNames); + + if (keyColumnNames.Length == 0 && options.JoinOnCondition == null) + throw new InvalidDataException("BulkFetch requires that the entity have a primary key or the Options +s.JoinOnCondition must be set."); + + await context.Database.CloneTableAsync(destinationTableName, stagingTableName, keyColumnNames, null, can +ncellationToken); + await BulkInsertAsync(entities, options, tableMapping, dbConnection, transaction, stagingTableName, keyC +ColumnNames, SqlBulkCopyOptions.KeepIdentity, false, cancellationToken); + selectSql = $"SELECT {SqlUtil.ConvertToColumnString(columnsToFetch)} FROM {stagingTableName} s JOIN {des +stinationTableName} t ON {CommonUtil.GetJoinConditionSql(context, options.JoinOnCondition, keyColumnNames)}"; + + dbTransactionContext.Commit(); + } + catch + { + dbTransactionContext.Rollback(); + throw; + } + + var results = await context.FetchInternalAsync(selectSql, cancellationToken: cancellationToken); + context.Database.DropTable(stagingTableName); + return results; + } + } + public static async Task FetchAsync(this IQueryable queryable, Func, Task> action, Action> optionsAction, CancellationToken cancellationToken = default) where T : class, new() + { + await FetchAsync(queryable, action, optionsAction.Build(), cancellationToken); + } + public static async Task FetchAsync(this IQueryable queryable, Func, Task> action, FetchOptions +s options, CancellationToken cancellationToken = default) where T : class, new() + { + var dbContext = queryable.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + HashSet includedColumns = GetIncludedColumns(tableMapping, options.InputColumns, options.IgnoreColumns); + int batch = 1; + int count = 0; + List entities = []; + await foreach (var entity in queryable.AsNoTracking().AsAsyncEnumerable().WithCancellation(cancellationToken)) + { + ClearExcludedColumns(dbContext, tableMapping, entity, includedColumns); + entities.Add(entity); + count++; + if (count == options.BatchSize) + { + await action(new FetchResult { Results = entities, Batch = batch }); + entities.Clear(); + count = 0; + batch++; + } + cancellationToken.ThrowIfCancellationRequested(); + } + + if (entities.Count > 0) + await action(new FetchResult { Results = entities, Batch = batch }); + } + public static async Task BulkInsertAsync(this DbContext context, IEnumerable entities, CancellationToken + cancellationToken = default) + { + return await context.BulkInsertAsync(entities, new BulkInsertOptions(), cancellationToken); + } + public static async Task BulkInsertAsync(this DbContext context, IEnumerable entities, Action> optionsAction, CancellationToken cancellationToken = default) + { + return await context.BulkInsertAsync(entities, optionsAction.Build(), cancellationToken); + } + public static async Task BulkInsertAsync(this DbContext context, IEnumerable entities, BulkInsertOptions< + options, CancellationToken cancellationToken = default) + { + int rowsAffected = 0; + using (var bulkOperation = new BulkOperation(context, options, options.InputColumns, options.IgnoreColumns)) + { + try + { + var bulkInsertResult = await bulkOperation.BulkInsertStagingDataAsync(entities, true, true); + var bulkMergeResult = await bulkOperation.ExecuteMergeAsync(bulkInsertResult.EntityMap, options.InsertOn +nCondition, + options.AutoMapOutput, options.KeepIdentity, options.InsertIfNotExists); + rowsAffected = bulkMergeResult.RowsAffected; + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return rowsAffected; + } + public static async Task> BulkMergeAsync(this DbContext context, IEnumerable entities, Canc +cellationToken cancellationToken = default) + { + return await BulkMergeAsync(context, entities, new BulkMergeOptions(), cancellationToken); + } + public static async Task> BulkMergeAsync(this DbContext context, IEnumerable entities, Bulk +kMergeOptions options, CancellationToken cancellationToken = default) + { + return await InternalBulkMergeAsync(context, entities, options, cancellationToken); + } + public static async Task> BulkMergeAsync(this DbContext context, IEnumerable entities, Acti +ion> optionsAction, CancellationToken cancellationToken = default) + { + return await BulkMergeAsync(context, entities, optionsAction.Build(), cancellationToken); + } + public static async Task BulkSaveChangesAsync(this DbContext dbContext) + { + return await dbContext.BulkSaveChangesAsync(true); + } + public static async Task BulkSaveChangesAsync(this DbContext dbContext, bool acceptAllChangesOnSuccess = true) + { + int rowsAffected = 0; + var stateManager = dbContext.GetDependencies().StateManager; + + dbContext.ChangeTracker.DetectChanges(); + var entries = stateManager.GetEntriesToSave(true); + + foreach (var saveEntryGroup in entries.GroupBy(o => new { o.EntityType, o.EntityState })) + { + var key = saveEntryGroup.Key; + var entities = saveEntryGroup.AsEnumerable(); + if (key.EntityState == EntityState.Added) + { + rowsAffected += await dbContext.BulkInsertAsync(entities, o => { o.EntityType = key.EntityType; }); + } + else if (key.EntityState == EntityState.Modified) + { + rowsAffected += await dbContext.BulkUpdateAsync(entities, o => { o.EntityType = key.EntityType; }); + } + else if (key.EntityState == EntityState.Deleted) + { + rowsAffected += await dbContext.BulkDeleteAsync(entities, o => { o.EntityType = key.EntityType; }); + } + } + + if (acceptAllChangesOnSuccess) + dbContext.ChangeTracker.AcceptAllChanges(); + + return rowsAffected; + } + public static async Task> BulkSyncAsync(this DbContext context, IEnumerable entities, Cancel +llationToken cancellationToken = default) + { + return await BulkSyncAsync(context, entities, new BulkSyncOptions(), cancellationToken); + } + public static async Task> BulkSyncAsync(this DbContext context, IEnumerable entities, Action +n> optionsAction, CancellationToken cancellationToken = default) + { + return BulkSyncResult.Map(await InternalBulkMergeAsync(context, entities, optionsAction.Build(), cancellation +nToken)); + } + public static async Task> BulkSyncAsync(this DbContext context, IEnumerable entities, BulkSy +yncOptions options, CancellationToken cancellationToken = default) + { + return BulkSyncResult.Map(await InternalBulkMergeAsync(context, entities, options, cancellationToken)); + } + public static async Task BulkUpdateAsync(this DbContext context, IEnumerable entities, CancellationToken + cancellationToken = default) + { + return await BulkUpdateAsync(context, entities, new BulkUpdateOptions(), cancellationToken); + } + public static async Task BulkUpdateAsync(this DbContext context, IEnumerable entities, Action> optionsAction, CancellationToken cancellationToken = default) + { + return await BulkUpdateAsync(context, entities, optionsAction.Build(), cancellationToken); + } + public static async Task BulkUpdateAsync(this DbContext context, IEnumerable entities, BulkUpdateOptions< + options, CancellationToken cancellationToken = default) + { + int rowsUpdated = 0; + using (var bulkOperation = new BulkOperation(context, options, options.InputColumns, options.IgnoreColumns)) + { + try + { + bulkOperation.ValidateBulkUpdate(options.UpdateOnCondition); + await bulkOperation.BulkInsertStagingDataAsync(entities, cancellationToken: cancellationToken); + rowsUpdated = await bulkOperation.ExecuteUpdateAsync(entities, options.UpdateOnCondition, cancellationTo +oken); + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return rowsUpdated; + } + public static async Task DeleteFromQueryAsync(this IQueryable queryable, int? commandTimeout = null, Canc +cellationToken cancellationToken = default) where T : class + { + var dbContext = queryable.GetDbContext(); + using (var dbTransactionContext = new DbTransactionContext(dbContext, commandTimeout)) + { + try + { + int rowsAffected = await queryable.ExecuteDeleteAsync(cancellationToken); + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static async Task InsertFromQueryAsync(this IQueryable queryable, string tableName, Expression> insertObjectExpression, int? commandTimeout = null, + CancellationToken cancellationToken = default) where T : class + { + var dbContext = queryable.GetDbContext(); + using (var dbTransactionContext = new DbTransactionContext(dbContext, commandTimeout)) + { + try + { + var tableMapping = dbContext.GetTableMapping(typeof(T)); + var columnNames = insertObjectExpression.GetObjectProperties(); + if (!dbContext.Database.TableExists(tableName)) + { + await dbContext.Database.CloneTableAsync(tableMapping.FullQualifedTableName, dbContext.Database.Deli +imitTableName(tableName), tableMapping.GetQualifiedColumnNames(columnNames), cancellationToken: cancellationToken); + } + + var entities = await queryable.AsNoTracking().ToListAsync(cancellationToken); + int rowsAffected = (int)(await BulkInsertAsync(entities, new BulkInsertOptions { KeepIdentity = true, +, AutoMapOutput = false, CommandTimeout = commandTimeout }, tableMapping, + dbTransactionContext.Connection, dbTransactionContext.CurrentTransaction, dbContext.Database.Delimit +tTableName(tableName), columnNames, SqlBulkCopyOptions.KeepIdentity, cancellationToken: cancellationToken)).RowsAffected; + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static async Task UpdateFromQueryAsync(this IQueryable queryable, Expression> updateExp +pression, int? commandTimeout = null, + CancellationToken cancellationToken = default) where T : class + { + var dbContext = queryable.GetDbContext(); + using (var dbTransactionContext = new DbTransactionContext(dbContext, commandTimeout)) + { + try + { + int rowsAffected = await queryable.ExecuteUpdateAsync(BuildSetPropertyCalls(updateExpression), cancellat +tionToken); + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, string filePath, Ca +ancellationToken cancellationToken = default) where T : class + { + return await QueryToCsvFileAsync(queryable, filePath, new QueryToFileOptions(), cancellationToken); + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, Stream stream, Canc +cellationToken cancellationToken = default) where T : class + { + return await QueryToCsvFileAsync(queryable, stream, new QueryToFileOptions(), cancellationToken); + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, string filePath, Ac +ction optionsAction, + CancellationToken cancellationToken = default) where T : class + { + return await QueryToCsvFileAsync(queryable, filePath, optionsAction.Build(), cancellationToken); + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, Stream stream, Acti +ion optionsAction, + CancellationToken cancellationToken = default) where T : class + { + return await QueryToCsvFileAsync(queryable, stream, optionsAction.Build(), cancellationToken); + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, string filePath, Qu +ueryToFileOptions options, + CancellationToken cancellationToken = default) where T : class + { + await using var fileStream = File.Create(filePath); + return await QueryToCsvFileAsync(queryable, fileStream, options, cancellationToken); + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, Stream stream, Quer +ryToFileOptions options, + CancellationToken cancellationToken = default) where T : class + { + return await InternalQueryToFileAsync(queryable, stream, options, cancellationToken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, string filePath, st +tring sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + return await SqlQueryToCsvFileAsync(database, filePath, new QueryToFileOptions(), sqlText, parameters, cancellat +tionToken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, Stream stream, stri +ing sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + return await SqlQueryToCsvFileAsync(database, stream, new QueryToFileOptions(), sqlText, parameters, cancellatio +onToken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, string filePath, Ac +ction optionsAction, string sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + return await SqlQueryToCsvFileAsync(database, filePath, optionsAction.Build(), sqlText, parameters, cancellation +nToken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, Stream stream, Acti +ion optionsAction, string sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + return await SqlQueryToCsvFileAsync(database, stream, optionsAction.Build(), sqlText, parameters, cancellationTo +oken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, string filePath, Qu +ueryToFileOptions options, string sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + await using var fileStream = File.Create(filePath); + return await SqlQueryToCsvFileAsync(database, fileStream, options, sqlText, parameters, cancellationToken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, Stream stream, Quer +ryToFileOptions options, string sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + return await InternalQueryToFileAsync(database.GetDbConnection(), stream, options, sqlText, parameters, cancella +ationToken); + } + public static async Task ClearAsync(this DbSet dbSet, CancellationToken cancellationToken = default) where T : +: class + { + var dbContext = dbSet.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + await dbContext.Database.ClearTableAsync(tableMapping.FullQualifedTableName, cancellationToken); + } + public static async Task TruncateAsync(this DbSet dbSet, CancellationToken cancellationToken = default) where + T : class + { + var dbContext = dbSet.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + await dbContext.Database.TruncateTableAsync(tableMapping.FullQualifedTableName, false, cancellationToken); + } + internal static async Task> BulkInsertAsync(IEnumerable entities, BulkOptions options, Tab +bleMapping tableMapping, DbConnection dbConnection, DbTransaction transaction, string tableName, + IEnumerable inputColumns = null, SqlBulkCopyOptions bulkCopyOptions = SqlBulkCopyOptions.Default, bool u +useInternalId = false, CancellationToken cancellationToken = default) + { + using var dataReader = new EntityDataReader(tableMapping, entities, useInternalId); + var sqlBulkCopy = new SqlBulkCopy((SqlConnection)dbConnection, bulkCopyOptions, (SqlTransaction)transaction) + { + DestinationTableName = tableName, + BatchSize = options.BatchSize + }; + if (options.CommandTimeout.HasValue) + { + sqlBulkCopy.BulkCopyTimeout = options.CommandTimeout.Value; + } + foreach (var property in dataReader.TableMapping.Properties) + { + var columnName = dataReader.TableMapping.GetColumnName(property); + if (inputColumns == null || inputColumns.Contains(columnName)) + sqlBulkCopy.ColumnMappings.Add(columnName, columnName); + } + if (useInternalId) + { + sqlBulkCopy.ColumnMappings.Add(Constants.InternalId_ColumnName, Constants.InternalId_ColumnName); + } + await sqlBulkCopy.WriteToServerAsync(dataReader, cancellationToken); + + return new BulkInsertResult + { + RowsAffected = sqlBulkCopy.RowsCopied, + EntityMap = dataReader.EntityMap + }; + } + internal static async Task BulkQueryAsync(this DbContext context, string sqlText, DbConnection dbCo +onnection, DbTransaction transaction, BulkOptions options, CancellationToken cancellationToken = default) + { + List results = []; + List columns = []; + await using var command = dbConnection.CreateCommand(); + command.CommandText = sqlText; + command.Transaction = transaction; + if (options.CommandTimeout.HasValue) + command.CommandTimeout = options.CommandTimeout.Value; + await using var reader = await command.ExecuteReaderAsync(cancellationToken); + while (await reader.ReadAsync(cancellationToken)) + { + if (columns.Count == 0) + { + for (int i = 0; i < reader.FieldCount; i++) + columns.Add(reader.GetName(i)); + } + object[] values = new object[reader.FieldCount]; + reader.GetValues(values); + results.Add(values); + } + + return new BulkQueryResult + { + Columns = columns, + Results = results, + RowsAffected = reader.RecordsAffected + }; + } + private static async Task> InternalBulkMergeAsync(this DbContext context, IEnumerable entit +ties, BulkMergeOptions options, CancellationToken cancellationToken = default) + { + BulkMergeResult bulkMergeResult; + using (var bulkOperation = new BulkOperation(context, options)) + { + try + { + bulkOperation.ValidateBulkMerge(options.MergeOnCondition); + var bulkInsertResult = await bulkOperation.BulkInsertStagingDataAsync(entities, true, true, cancellation +nToken); + bulkMergeResult = await bulkOperation.ExecuteMergeAsync(bulkInsertResult.EntityMap, options.MergeOnCondi +ition, options.AutoMapOutput, + false, true, true, options.DeleteIfNotMatched, cancellationToken); + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return bulkMergeResult; + } + private static async Task InternalQueryToFileAsync(this IQueryable queryable, Stream stream +m, QueryToFileOptions options, + CancellationToken cancellationToken = default) where T : class + { + return await InternalQueryToFileAsync(queryable.AsNoTracking().AsAsyncEnumerable(), stream, options, cancellatio +onToken); + } + private static async Task InternalQueryToFileAsync(DbConnection dbConnection, Stream stream, Quer +ryToFileOptions options, string sqlText, object[] parameters = null, + CancellationToken cancellationToken = default) + { + int dataRowCount = 0; + int totalRowCount = 0; + long bytesWritten = 0; + + if (dbConnection.State == ConnectionState.Closed) + dbConnection.Open(); + + await using var command = dbConnection.CreateCommand(); + command.CommandText = sqlText; + if (parameters != null) + command.Parameters.AddRange(parameters); + if (options.CommandTimeout.HasValue) + command.CommandTimeout = options.CommandTimeout.Value; + + await using var streamWriter = new StreamWriter(stream, leaveOpen: true); + using (var reader = await command.ExecuteReaderAsync(cancellationToken)) + { + if (options.IncludeHeaderRow) + { + for (int i = 0; i < reader.FieldCount; i++) + { + streamWriter.Write(options.TextQualifer); + streamWriter.Write(reader.GetName(i)); + streamWriter.Write(options.TextQualifer); + if (i != reader.FieldCount - 1) + { + await streamWriter.WriteAsync(options.ColumnDelimiter); + } + } + totalRowCount++; + await streamWriter.WriteAsync(options.RowDelimiter); + } + while (await reader.ReadAsync(cancellationToken)) + { + object[] values = new object[reader.FieldCount]; + reader.GetValues(values); + for (int i = 0; i < values.Length; i++) + { + streamWriter.Write(options.TextQualifer); + streamWriter.Write(values[i]); + streamWriter.Write(options.TextQualifer); + if (i != values.Length - 1) + { + await streamWriter.WriteAsync(options.ColumnDelimiter); + } + } + await streamWriter.WriteAsync(options.RowDelimiter); + dataRowCount++; + totalRowCount++; + } + await streamWriter.FlushAsync(); + bytesWritten = streamWriter.BaseStream.Length; + } + return new QueryToFileResult() + { + BytesWritten = bytesWritten, + DataRowCount = dataRowCount, + TotalRowCount = totalRowCount + }; + } + private static async Task InternalQueryToFileAsync(IAsyncEnumerable entities, Stream stream +m, QueryToFileOptions options, CancellationToken cancellationToken) where T : class + { + int dataRowCount = 0; + int totalRowCount = 0; + long bytesWritten = 0; + var properties = typeof(T).GetProperties().Where(p => p.CanRead && (!typeof(System.Collections.IEnumerable).IsAs +ssignableFrom(p.PropertyType) || p.PropertyType == typeof(string))).ToArray(); + + await using var streamWriter = new StreamWriter(stream, leaveOpen: true); + if (options.IncludeHeaderRow) + { + await WriteCsvRowAsync(streamWriter, properties.Select(p => (object)p.Name), options, cancellationToken); + totalRowCount++; + } + + await foreach (var entity in entities.WithCancellation(cancellationToken)) + { + await WriteCsvRowAsync(streamWriter, properties.Select(p => p.GetValue(entity)), options, cancellationToken) +); + dataRowCount++; + totalRowCount++; + } + + await streamWriter.FlushAsync(cancellationToken); + bytesWritten = streamWriter.BaseStream.Length; + return new QueryToFileResult { BytesWritten = bytesWritten, DataRowCount = dataRowCount, TotalRowCount = totalRo +owCount }; + } + private static async Task> FetchInternalAsync(this DbContext dbContext, string sqlText, object[] p +parameters = null, CancellationToken cancellationToken = default) where T : class, new() + { + List results = []; + await using var command = dbContext.Database.CreateCommand(ConnectionBehavior.New); + command.CommandText = sqlText; + if (parameters != null) + command.Parameters.AddRange(parameters); + + var tableMapping = dbContext.GetTableMapping(typeof(T), null); + var reader = await command.ExecuteReaderAsync(cancellationToken); + var properties = reader.GetProperties(tableMapping); + var valuesFromProvider = properties.Select(p => tableMapping.GetValueFromProvider(p)).ToArray(); + + while (await reader.ReadAsync(cancellationToken)) + { + var entity = reader.MapEntity(dbContext, properties, valuesFromProvider); + results.Add(entity); + } + + await reader.CloseAsync(); + await command.Connection.CloseAsync(); + return results; + } + private static HashSet GetIncludedColumns(TableMapping tableMapping, Expression> inputCol +lumns, Expression> ignoreColumns) + { + var includedColumns = inputColumns != null + ? inputColumns.GetObjectProperties().ToHashSet() + : tableMapping.Properties.Select(p => p.Name).ToHashSet(); + + if (ignoreColumns != null) + includedColumns.ExceptWith(ignoreColumns.GetObjectProperties()); + + return includedColumns; + } + private static void ClearExcludedColumns(DbContext dbContext, TableMapping tableMapping, T entity, HashSet includedColumns) where T : class + { + var entry = dbContext.Entry(entity); + foreach (var property in tableMapping.Properties) + { + if (includedColumns.Contains(property.Name)) + continue; + + object defaultValue = property.ClrType.IsValueType ? Activator.CreateInstance(property.ClrType) : null; + if (property.DeclaringType is IComplexType complexType) + { + var complexProperty = entry.ComplexProperty(complexType.ComplexProperty); + if (complexProperty.CurrentValue != null) + complexProperty.Property(property).CurrentValue = defaultValue; + } + else + { + entry.Property(property.Name).CurrentValue = defaultValue; + } + } + } + private static async Task WriteCsvRowAsync(TextWriter writer, IEnumerable values, QueryToFileOptions options +s, CancellationToken cancellationToken) + { + bool first = true; + foreach (var value in values) + { + if (!first) + await writer.WriteAsync(options.ColumnDelimiter); + + await writer.WriteAsync(options.TextQualifer); + await writer.WriteAsync(value?.ToString()); + await writer.WriteAsync(options.TextQualifer); + first = false; + cancellationToken.ThrowIfCancellationRequested(); + } + await writer.WriteAsync(options.RowDelimiter); + } + private static Action> BuildSetPropertyCalls(Expression> updateExpression) whe +ere T : class + { + if (updateExpression.Body is not MemberInitExpression memberInitExpression) + throw new InvalidOperationException("UpdateFromQuery requires a member initialization expression."); + + var entityParameter = updateExpression.Parameters[0]; + var setPropertyMethod = typeof(UpdateSettersBuilder) + .GetMethods() + .Single(m => m.Name == nameof(UpdateSettersBuilder.SetProperty) && m.GetParameters().Length == 2 && m.Get +tParameters()[1].ParameterType.IsGenericType); + + return setters => + { + object currentBuilder = setters; + foreach (var binding in memberInitExpression.Bindings.OfType()) + { + var propertyInfo = binding.Member as System.Reflection.PropertyInfo ?? throw new InvalidOperationExcepti +ion("Only property bindings are supported."); + var propertyLambda = Expression.Lambda(Expression.Property(entityParameter, propertyInfo), entityParamet +ter); + var valueLambda = Expression.Lambda(binding.Expression, entityParameter); + currentBuilder = setPropertyMethod.MakeGenericMethod(propertyInfo.PropertyType).Invoke(currentBuilder, [ +[propertyLambda, valueLambda]); + } + }; + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\DbTransacti +ionContext.cs --- + +using System; +using System.Data.Common; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Storage; +using N.EntityFrameworkCore.Extensions.Enums; +using N.EntityFrameworkCore.Extensions.Util; + + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed class DbTransactionContext : IDisposable +{ + private bool closeConnection; + private bool ownsTransaction; + private int? defaultCommandTimeout; + private DbContext context; + private IDbContextTransaction transaction; + + public DbConnection Connection { get; internal set; } + public DbTransaction CurrentTransaction { get; private set; } + public DbContext DbContext => context; + + public DbTransactionContext(DbContext context, BulkOptions bulkOptions, bool openConnection = true) : this(context, + bulkOptions.CommandTimeout, bulkOptions.ConnectionBehavior, openConnection) + { + + } + public DbTransactionContext(DbContext context, int? commandTimeout = null, ConnectionBehavior connectionBehavior = C +ConnectionBehavior.Default, bool openConnection = true) + { + this.context = context; + Connection = context.GetDbConnection(connectionBehavior); + if (openConnection) + { + if (Connection.State == System.Data.ConnectionState.Closed) + { + Connection.Open(); + closeConnection = true; + } + } + if (connectionBehavior == ConnectionBehavior.Default) + { + ownsTransaction = context.Database.CurrentTransaction == null; + transaction = context.Database.CurrentTransaction; + defaultCommandTimeout = context.Database.GetCommandTimeout(); + if (transaction != null) + CurrentTransaction = transaction.GetDbTransaction(); + } + + context.Database.SetCommandTimeout(commandTimeout); + } + + public void Dispose() + { + context.Database.SetCommandTimeout(defaultCommandTimeout); + if (closeConnection) + { + Connection.Close(); + } + } + + internal void Commit() + { + if (ownsTransaction && transaction != null) + transaction.Commit(); + } + internal void Rollback() + { + if (transaction != null) + transaction.Rollback(); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\EfExtension +nsCommand.cs --- + +using System.Data.Common; +using Microsoft.EntityFrameworkCore.Diagnostics; + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed class EfExtensionsCommand +{ + public EfExtensionsCommandType CommandType { get; set; } + public string OldValue { get; set; } + public string NewValue { get; set; } + public DbConnection Connection { get; internal set; } + + internal bool Execute(DbCommand command, CommandEventData eventData, InterceptionResult result) + { + if (CommandType == EfExtensionsCommandType.ChangeTableName) + { + command.CommandText = command.CommandText.Replace(OldValue, NewValue); + } + + return true; + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\EfExtension +nsCommandInterceptor.cs --- + +using System; +using System.Collections.Concurrent; +using System.Data.Common; +using Microsoft.EntityFrameworkCore.Diagnostics; + +namespace N.EntityFrameworkCore.Extensions; + +public class EfExtensionsCommandInterceptor : DbCommandInterceptor +{ + private ConcurrentDictionary extensionCommands = new(); + public override InterceptionResult ReaderExecuting(DbCommand command, CommandEventData eventData, Inte +erceptionResult result) + { + foreach (var extensionCommand in extensionCommands) + { + if (extensionCommand.Value.Connection == command.Connection) + { + extensionCommand.Value.Execute(command, eventData, result); + extensionCommands.TryRemove(extensionCommand.Key, out _); + } + } + return result; + } + internal void AddCommand(Guid clientConnectionId, EfExtensionsCommand efExtensionsCommand) + { + extensionCommands.TryAdd(clientConnectionId, efExtensionsCommand); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\EntityDataR +Reader.cs --- + +using System; +using System.Collections.Generic; +using System.Data; +using Microsoft.EntityFrameworkCore.ChangeTracking; +using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; +using Microsoft.EntityFrameworkCore.Metadata; +using N.EntityFrameworkCore.Extensions.Common; + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed class EntityDataReader : IDataReader +{ + public TableMapping TableMapping { get; set; } + public Dictionary EntityMap { get; set; } + private Dictionary columnIndexes; + private int currentId; + private bool useInternalId; + private int tableFieldCount; + private IEnumerable entities; + private IEnumerator enumerator; + private Dictionary> selectors; + + public EntityDataReader(TableMapping tableMapping, IEnumerable entities, bool useInternalId) + { + this.columnIndexes = []; + this.currentId = 0; + this.useInternalId = useInternalId; + this.tableFieldCount = tableMapping.Properties.Length; + this.entities = entities; + this.enumerator = entities.GetEnumerator(); + this.selectors = []; + this.EntityMap = []; + this.FieldCount = tableMapping.Properties.Length; + this.TableMapping = tableMapping; + + + int i = 0; + foreach (var property in tableMapping.Properties) + { + selectors[i] = GetValueSelector(property); + columnIndexes[tableMapping.GetColumnName(property)] = i; + i++; + } + + if (useInternalId) + { + this.FieldCount++; + columnIndexes[Constants.InternalId_ColumnName] = i; + } + } + private Func GetValueSelector(IProperty property) + { + Func selector; + var valueGeneratorFactory = property.GetValueGeneratorFactory(); + if (valueGeneratorFactory != null) + { + var valueGenerator = valueGeneratorFactory.Invoke(property, this.TableMapping.EntityType); + selector = entry => valueGenerator.Next(entry); + } + else + { + var valueConverter = property.GetTypeMapping().Converter; + if (valueConverter != null) + { + selector = entry => valueConverter.ConvertToProvider(entry.CurrentValues[property]); + } + else + { + if (property.DeclaringType is IComplexType complexType) + { + selector = entry => entry.ComplexProperty(complexType.ComplexProperty).Property(property).CurrentVal +lue; + } + else + { + selector = entry => entry.CurrentValues[property]; + } + } + } + return selector; + } + public object this[int i] => throw new NotImplementedException(); + + public object this[string name] => throw new NotImplementedException(); + + public int Depth { get; set; } + + public bool IsClosed => throw new NotImplementedException(); + + public int RecordsAffected => throw new NotImplementedException(); + + public int FieldCount { get; set; } + + public void Close() + { + throw new NotImplementedException(); + } + + public void Dispose() + { + selectors = null; + enumerator.Dispose(); + } + + public bool GetBoolean(int i) + { + throw new NotImplementedException(); + } + + public byte GetByte(int i) + { + throw new NotImplementedException(); + } + + public long GetBytes(int i, long fieldOffset, byte[] buffer, int bufferoffset, int length) + { + throw new NotImplementedException(); + } + + public char GetChar(int i) + { + throw new NotImplementedException(); + } + + public long GetChars(int i, long fieldoffset, char[] buffer, int bufferoffset, int length) + { + throw new NotImplementedException(); + } + + public IDataReader GetData(int i) + { + throw new NotImplementedException(); + } + + public string GetDataTypeName(int i) + { + throw new NotImplementedException(); + } + + public DateTime GetDateTime(int i) + { + throw new NotImplementedException(); + } + + public decimal GetDecimal(int i) + { + throw new NotImplementedException(); + } + + public double GetDouble(int i) + { + throw new NotImplementedException(); + } + + public Type GetFieldType(int i) + { + throw new NotImplementedException(); + } + + public float GetFloat(int i) + { + throw new NotImplementedException(); + } + + public Guid GetGuid(int i) + { + throw new NotImplementedException(); + } + + public short GetInt16(int i) + { + throw new NotImplementedException(); + } + + public int GetInt32(int i) + { + throw new NotImplementedException(); + } + + public long GetInt64(int i) + { + throw new NotImplementedException(); + } + + public string GetName(int i) + { + throw new NotImplementedException(); + } + + public int GetOrdinal(string name) + { + return columnIndexes[name]; + } + + public DataTable GetSchemaTable() + { + throw new NotImplementedException(); + } + + public string GetString(int i) + { + throw new NotImplementedException(); + } + + public object GetValue(int i) + { + if (i == tableFieldCount) + { + return this.currentId; + } + else + { + return selectors[i](FindEntry(enumerator.Current)); + } + + } + + private EntityEntry FindEntry(object entity) + { + return entity is InternalEntityEntry internalEntry ? internalEntry.ToEntityEntry() : TableMapping.DbContext.Entr +ry(entity); + } + + public int GetValues(object[] values) + { + throw new NotImplementedException(); + } + + public bool IsDBNull(int i) + { + throw new NotImplementedException(); + } + + public bool NextResult() + { + throw new NotImplementedException(); + } + + public bool Read() + { + bool moveNext = enumerator.MoveNext(); + + if (moveNext && this.useInternalId) + { + this.currentId++; + this.EntityMap.Add(this.currentId, enumerator.Current); + } + return moveNext; + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\FetchOption +ns.cs --- + +using System; +using System.Linq.Expressions; + +namespace N.EntityFrameworkCore.Extensions; + +public class FetchOptions +{ + public Expression> IgnoreColumns { get; set; } + public Expression> InputColumns { get; set; } + public int BatchSize { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\FetchResult +t.cs --- + +using System.Collections.Generic; + +namespace N.EntityFrameworkCore.Extensions; + +public class FetchResult +{ + public List Results { get; set; } + public int Batch { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\QueryToFile +eOptions.cs --- + +namespace N.EntityFrameworkCore.Extensions; + +public class QueryToFileOptions +{ + public string ColumnDelimiter { get; set; } + public int? CommandTimeout { get; set; } + public bool IncludeHeaderRow { get; set; } + public string RowDelimiter { get; set; } + public string TextQualifer { get; set; } + + public QueryToFileOptions() + { + ColumnDelimiter = ","; + IncludeHeaderRow = true; + RowDelimiter = "\r\n"; + TextQualifer = ""; + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\QueryToFile +eResult.cs --- + +namespace N.EntityFrameworkCore.Extensions; + +public class QueryToFileResult +{ + public long BytesWritten { get; set; } + public int DataRowCount { get; internal set; } + public int TotalRowCount { get; internal set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\SqlMergeAct +tion.cs --- + +namespace N.EntityFrameworkCore.Extensions; + +internal static class SqlMergeAction +{ + public const string Insert = "INSERT"; + public const string Update = "UPDATE"; + public const string Delete = "DELETE"; +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\SqlQuery.cs +s --- + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore.Infrastructure; +using N.EntityFrameworkCore.Extensions.Sql; + +namespace N.EntityFrameworkCore.Extensions; + +public class SqlQuery +{ + private DatabaseFacade database; + public string SqlText { get; private set; } + public object[] Parameters { get; private set; } + + public SqlQuery(DatabaseFacade database, string sqlText, params object[] parameters) + { + this.database = database; + SqlText = sqlText; + Parameters = parameters; + } + + public int Count() + { + string countSqlText = SqlBuilder.Parse(SqlText).Count(); + return Convert.ToInt32(database.ExecuteScalar(countSqlText, Parameters)); + } + public async Task CountAsync(CancellationToken cancellationToken = default) + { + string countSqlText = SqlBuilder.Parse(SqlText).Count(); + return Convert.ToInt32(await database.ExecuteScalarAsync(countSqlText, Parameters, null, cancellationToken)); + } + public int ExecuteNonQuery() + { + return database.ExecuteSql(SqlText, Parameters); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Data\TableMappin +ng.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Metadata.Internal; +using N.EntityFrameworkCore.Extensions.Extensions; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +public class TableMapping +{ + public DbContext DbContext { get; private set; } + public IEntityType EntityType { get; set; } + public IProperty[] Properties { get; } + public string Schema { get; } + public string TableName { get; } + public IEnumerable EntityTypes { get; } + + public bool HasIdentityColumn => EntityType.FindPrimaryKey().Properties.Any(o => o.ValueGenerated != ValueGenerated. +.Never); + public StoreObjectIdentifier StoreObjectIdentifier => StoreObjectIdentifier.Table(TableName, EntityType.GetSchema() + ?? DbContext.Database.GetDefaultSchema()); + private Dictionary ColumnMap { get; set; } + public string FullQualifedTableName => DbContext.DelimitIdentifier(TableName, Schema); + + public TableMapping(DbContext dbContext, IEntityType entityType) + { + DbContext = dbContext; + EntityType = entityType; + Properties = GetProperties(entityType); + ColumnMap = Properties.Select(p => new KeyValuePair(GetColumnName(p), p)).ToDictionary(); + Schema = entityType.GetSchema() ?? dbContext.Database.GetDefaultSchema(); + TableName = entityType.GetTableName(); + EntityTypes = EntityType.GetAllBaseTypesInclusive().Where(o => !o.IsAbstract()); + } + public IProperty GetPropertyFromColumnName(string columnName) => ColumnMap[columnName]; + private static IProperty[] GetProperties(IEntityType entityType) + { + var properties = entityType.GetProperties().ToList(); + properties.AddRange(entityType.GetComplexProperties().SelectMany(p => p.ComplexType.GetProperties())); + return properties.ToArray(); + } + + public IEnumerable GetQualifiedColumnNames(IEnumerable columnNames, IEntityType entityType = null) + { + return Properties.Where(o => entityType == null || o.GetDeclaringEntityType() == entityType) + .Select(o => new + { + Column = FindColumn(o), + Name = GetColumnName(o) + }) + .Where(o => columnNames == null || columnNames.Contains(o.Name)) + .Select(o => $"{DbContext.DelimitIdentifier(o.Column?.Table.Name ?? TableName)}.{DbContext.DelimitIdentifier +r(o.Name)}").ToList(); + } + public string GetColumnName(IProperty property) => FindColumn(property)?.Name ?? property.Name; + private IColumnBase FindColumn(IProperty property) + { + var entityType = property.GetDeclaringEntityType(); + if (entityType == null || entityType.IsAbstract()) + entityType = EntityType; + var storeObjectIdentifier = StoreObjectIdentifier.Table(entityType.GetTableName(), entityType.GetSchema()); + return property.FindColumn(storeObjectIdentifier); + } + + private string FindTableName(IEntityType declaringEntityType, IEntityType entityType) => + declaringEntityType != null && declaringEntityType.IsAbstract() ? declaringEntityType.GetTableName() : entityTyp +pe.GetTableName(); + public IEnumerable GetColumnNames(IEntityType entityType, bool primaryKeyColumns) + { + List columns; + if (entityType != null) + { + columns = entityType.GetProperties().Where(o => (o.GetDeclaringEntityType() == entityType || o.GetDeclaringE +EntityType().IsAbstract() + || o.IsForeignKeyToSelf()) && o.ValueGenerated == ValueGenerated.Never) + .Select(GetColumnName).ToList(); + + columns.AddRange(entityType.GetComplexProperties().SelectMany(o => o.ComplexType.GetProperties() + .Select(GetColumnName))); + } + else + { + columns = EntityType.GetProperties().Where(o => o.ValueGenerated == ValueGenerated.Never) + .Select(GetColumnName).ToList(); + + columns.AddRange(EntityType.GetComplexProperties().SelectMany(o => o.ComplexType.GetProperties() + .Select(GetColumnName))); + } + if (primaryKeyColumns) + { + columns.AddRange(GetPrimaryKeyColumns()); + } + return columns.Distinct(); + } + public IEnumerable GetColumns(bool includePrimaryKeyColumns = false) + { + List columns = []; + foreach (var entityType in EntityTypes) + { + var storeObjectIdentifier = StoreObjectIdentifier.Create(entityType, StoreObjectType.Table).GetValueOrDefaul +lt(); + columns.AddRange(entityType.GetProperties().Where(o => o.ValueGenerated == ValueGenerated.Never) + .Select(GetColumnName)); + + columns.AddRange(EntityType.GetComplexProperties().SelectMany(o => o.ComplexType.GetProperties() + .Select(GetColumnName))); + + if (includePrimaryKeyColumns) + columns.AddRange(GetPrimaryKeyColumns()); + } + return columns.Where(o => o != null).Distinct(); + } + public IEnumerable GetPrimaryKeyColumns() => + EntityType.FindPrimaryKey().Properties.Select(GetColumnName); + + internal IEnumerable GetAutoGeneratedColumns(IEntityType entityType = null) + { + entityType ??= EntityType; + return entityType.GetProperties().Where(o => o.ValueGenerated != ValueGenerated.Never) + .Select(GetColumnName); + } + + internal IEnumerable GetEntityProperties(IEntityType entityType = null, ValueGenerated? valueGenerated = + null) + { + entityType ??= EntityType; + return entityType.GetProperties().Where(o => valueGenerated == null || o.ValueGenerated == valueGenerated).AsEnu +umerable(); + } + internal Func GetValueFromProvider(IProperty property) + { + var valueConverter = property.GetTypeMapping().Converter; + return valueConverter != null ? value => valueConverter.ConvertFromProvider(value) : value => value; + } + internal IEnumerable GetSchemaQualifiedTableNames() + { + return EntityTypes + .Select(o => DbContext.DelimitIdentifier(o.GetTableName(), o.GetSchema() ?? DbContext.Database.GetDefaultSch +hema())).Distinct(); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Enums\Connection +nBehavior.cs --- + +namespace N.EntityFrameworkCore.Extensions.Enums; + +internal enum ConnectionBehavior +{ + Default, + New +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Enums\EfExtensio +onsCommandType.cs --- + +namespace N.EntityFrameworkCore.Extensions; + +internal enum EfExtensionsCommandType +{ + ChangeTableName +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Extensions\Commo +onExtensions.cs --- + +using System; + +namespace N.EntityFrameworkCore.Extensions.Extensions; + +internal static class CommonExtensions +{ + internal static T Build(this Action buildAction) where T : new() + { + var parameter = new T(); + buildAction(parameter); + return parameter; + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Extensions\DbDat +taReaderExtensions.cs --- + +using System; +using System.Collections.Generic; +using System.Data.Common; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Metadata; + +namespace N.EntityFrameworkCore.Extensions.Extensions; + +internal static class DbDataReaderExtensions +{ + internal static T MapEntity(this DbDataReader reader, DbContext dbContext, IProperty[] properties, Func[] valuesFromProvider) where T : class, new() + { + var entity = new T(); + var entry = dbContext.Entry(entity); + + for (var i = 0; i < reader.FieldCount; i++) + { + var property = properties[i]; + var value = valuesFromProvider[i].Invoke(reader.GetValue(i)); + if (value == DBNull.Value) + value = null; + + if (property.DeclaringType is IComplexType complexType) + { + var complexProperty = entry.ComplexProperty(complexType.ComplexProperty); + if (complexProperty.CurrentValue == null) + { + complexProperty.CurrentValue = Activator.CreateInstance(complexType.ClrType); + } + complexProperty.Property(property).CurrentValue = value; + } + else + { + entry.Property(property).CurrentValue = value; + } + } + return entity; + } + internal static IProperty[] GetProperties(this DbDataReader reader, TableMapping tableMapping) + { + List properties = []; + + for (var i = 0; i < reader.FieldCount; i++) + { + var property = tableMapping.GetPropertyFromColumnName(reader.GetName(i)); + properties.Add(property); + } + + return properties.ToArray(); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Extensions\IProp +pertyExtensions.cs --- + +using Microsoft.EntityFrameworkCore.Metadata; + +namespace N.EntityFrameworkCore.Extensions.Extensions; + +public static class IPropertyExtensions +{ + public static IEntityType GetDeclaringEntityType(this IProperty property) + { + return property.DeclaringType as IEntityType; + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Extensions\LinqE +Extensions.cs --- + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Text; +using System.Text.RegularExpressions; +using Microsoft.EntityFrameworkCore; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +internal static class LinqExtensions +{ + internal static List GetObjectProperties(this Expression> expression) + { + if (expression == null) + { + return []; + } + else if (expression.Body is MemberExpression propertyExpression) + { + return [propertyExpression.Member.Name]; + } + else if (expression.Body is NewExpression newExpression) + { + return newExpression.Members.Select(o => o.Name).ToList(); + } + else if ((expression.Body is UnaryExpression unaryExpression) && (unaryExpression.Operand.GetPrivateFieldValue(" +"Member") is PropertyInfo propertyInfo)) + { + return [propertyInfo.Name]; + } + else + { + throw new InvalidOperationException("GetObjectProperties() encountered an unsupported expression type"); + } + } + internal static string ToSql(this ExpressionType expressionType) => expressionType switch + { + ExpressionType.AndAlso => "AND", + ExpressionType.Or => "OR", + ExpressionType.Add => "+", + ExpressionType.Subtract => "-", + ExpressionType.Multiply => "*", + ExpressionType.Divide => "/", + ExpressionType.Modulo => "%", + ExpressionType.Equal => "=", + _ => string.Empty + }; + + internal static string ToSql(this MemberBinding binding) + { + if (binding is MemberAssignment memberAssingment) + { + return GetExpressionValueAsString(memberAssingment.Expression); + } + else + { + throw new NotSupportedException(); + } + } + internal static string ToSql(this Expression expression) + { + var sb = new StringBuilder(); + if (expression is BinaryExpression binaryExpression) + { + sb.Append(binaryExpression.Left.ToSql()); + sb.Append($" {expression.NodeType.ToSql()} "); + sb.Append(binaryExpression.Right.ToSql()); + } + else if (expression is MemberExpression memberExpression) + { + return $"{memberExpression}"; + } + else if (expression is UnaryExpression unaryExpression) + { + return $"{unaryExpression.Operand}"; + } + return sb.ToString(); + } + internal static string GetExpressionValueAsString(Expression expression) + { + if (expression is ConstantExpression constantExpression) + { + return ConvertToSqlValue(constantExpression.Value); + } + else if (expression is MemberExpression memberExpression) + { + if (memberExpression.Expression is ParameterExpression parameterExpression) + { + return memberExpression.ToString(); + } + else + { + return ConvertToSqlValue(Expression.Lambda(expression).Compile().DynamicInvoke()); + } + } + else if (expression.NodeType == ExpressionType.Convert) + { + return ConvertToSqlValue(Expression.Lambda(expression).Compile().DynamicInvoke()); + } + else if (expression.NodeType == ExpressionType.Call) + { + var methodCallExpression = expression as MethodCallExpression; + List argValues = []; + foreach (var argument in methodCallExpression.Arguments) + { + argValues.Add(GetExpressionValueAsString(argument)); + } + return methodCallExpression.Method.Name switch + { + "ToString" => $"CONVERT(VARCHAR,{argValues[0]})", + _ => $"{methodCallExpression.Method.Name}({string.Join(",", argValues)})" + }; + } + else + { + var binaryExpression = expression as BinaryExpression; + string leftValue = GetExpressionValueAsString(binaryExpression.Left); + string rightValue = GetExpressionValueAsString(binaryExpression.Right); + string joinValue = expression.NodeType.ToSql(); + + return $"({leftValue} {joinValue} {rightValue})"; + } + } + internal static string ToSqlPredicate2(this Expression expression, params string[] parameters) + { + var sql = ToSqlString(expression.Body); + + for (var i = 0; i < parameters.Length; i++) + sql = sql.Replace($"${expression.Parameters[i].Name!}.", $"{parameters[i]}."); + + return sql; + } + internal static string ToSqlPredicate(this Expression expression, params string[] parameters) + { + var expressionBody = (string)expression.Body.GetPrivateFieldValue("DebugView"); + expressionBody = expressionBody.Replace(System.Environment.NewLine, " "); + var stringBuilder = new StringBuilder(expressionBody); + + int i = 0; + foreach (var expressionParam in expression.Parameters) + { + if (parameters.Length <= i) break; + stringBuilder.Replace((string)expressionParam.GetPrivateFieldValue("DebugView"), parameters[i]); + i++; + } + stringBuilder.Replace("== null", "IS NULL"); + stringBuilder.Replace("!= null", "IS NOT NULL"); + stringBuilder.Replace("&&", "AND"); + stringBuilder.Replace("==", "="); + stringBuilder.Replace("||", "OR"); + stringBuilder.Replace("(System.Nullable`1[System.Int32])", ""); + stringBuilder.Replace("(System.Int32)", ""); + return stringBuilder.ToString(); + } + internal static string ToSqlPredicate(this Expression expression, DbContext dbContext, params string[] paramet +ters) + { + string predicate = expression.ToSqlPredicate(parameters); + return DelimitMemberAccess(dbContext, predicate); + } + internal static string ToSqlUpdateSetExpression(this Expression expression, string tableName) + { + List setValues = []; + var memberInitExpression = expression.Body as MemberInitExpression; + foreach (var binding in memberInitExpression.Bindings) + { + string expValue = binding.ToSql(); + expValue = expValue.Replace($"{expression.Parameters.First().Name}.", ""); + setValues.Add($"[{binding.Member.Name}]={expValue}"); + } + return string.Join(",", setValues); + } + internal static string ToSqlUpdateSetExpression(this Expression expression, DbContext dbContext, string tableN +Name) + { + List setValues = []; + var memberInitExpression = expression.Body as MemberInitExpression; + foreach (var binding in memberInitExpression.Bindings) + { + string expValue = binding.ToSql(); + expValue = expValue.Replace($"{expression.Parameters.First().Name}.", ""); + expValue = DelimitMemberAccess(dbContext, expValue); + setValues.Add($"{dbContext.DelimitIdentifier(binding.Member.Name)}={expValue}"); + } + return string.Join(",", setValues); + } + private static string ToSqlString(Expression expression, string sql = null) + { + sql ??= ""; + if (expression is not BinaryExpression b) + return sql; + + var sb = new StringBuilder(); + if (b.Left is MemberExpression mel) + sb.Append($"${mel} = "); + if (b.Right is MemberExpression mer) + sb.Append($"${mer}"); + + if (b.Left is UnaryExpression ubl) + sb.Append($"${ubl.Operand} = "); + if (b.Right is UnaryExpression ubr) + sb.Append($"${ubr.Operand}"); + + if (sb.Length > 0) + return sb.ToString(); + + var left = ToSqlString(b.Left, sql); + if (string.IsNullOrWhiteSpace(left)) + return sql; + + var right = ToSqlString(b.Right, sql); + return $"{left} AND {right}"; + } + private static string ConvertToSqlValue(object value) + { + if (value == null) + return "NULL"; + if (value is string str) + return $"'{str.Replace("'", "''")}'"; + if (value is Guid guid) + return $"'{guid}'"; + if (value is bool b) + return b ? "1" : "0"; + if (value is DateTime dt) + return $"'{dt:yyyy-MM-ddTHH:mm:ss.fffffff}'"; // Convert to ISO-8601 + if (value is DateTimeOffset dto) + return $"'{dto:yyyy-MM-ddTHH:mm:ss.fffffffzzzz}'"; // Convert to ISO-8601 + var valueType = value.GetType(); + if (valueType.IsEnum) + return Convert.ToString((int)value); + if (!valueType.IsClass) + return Convert.ToString(value, CultureInfo.InvariantCulture); + + throw new NotImplementedException("Unhandled data type."); + } + private static string DelimitMemberAccess(DbContext dbContext, string expression) + { + return Regex.Replace(expression, @"(? + { + string alias = match.Groups[1].Value; + string member = match.Groups[2].Value; + return dbContext.DelimitMemberAccess(alias, member); + }); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Extensions\Objec +ctExtensions.cs --- + +using System; +using System.Reflection; + +namespace N.EntityFrameworkCore.Extensions; + +internal static class ObjectExtensions +{ + internal static object GetPrivateFieldValue(this object obj, string propName) + { + if (obj == null) throw new ArgumentNullException(nameof(obj)); + Type t = obj.GetType(); + FieldInfo fieldInfo = null; + PropertyInfo propertyInfo = null; + while (fieldInfo == null && propertyInfo == null && t != null) + { + fieldInfo = t.GetField(propName, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + if (fieldInfo == null) + { + propertyInfo = t.GetProperty(propName, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Insta +ance); + } + + t = t.BaseType; + } + if (fieldInfo == null && propertyInfo == null) + throw new ArgumentOutOfRangeException(nameof(propName), $"Field {propName} was not found in Type {obj.GetTyp +pe().FullName}"); + + if (fieldInfo != null) + return fieldInfo.GetValue(obj); + + return propertyInfo.GetValue(obj, null); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Extensions\SqlSt +tatementExtensions.cs --- + +using System.Collections.Generic; +using N.EntityFrameworkCore.Extensions.Sql; + +namespace N.EntityFrameworkCore.Extensions.Extensions; + +internal static class SqlStatementExtensions +{ + internal static void WriteInsert(this SqlStatement statement, IEnumerable insertColumns) + { + statement.CreatePart(SqlKeyword.Insert, SqlExpression.Columns(insertColumns)); + statement.CreatePart(SqlKeyword.Values, SqlExpression.Columns(insertColumns)); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\GlobalSuppressio +ons.cs --- + +// This file is used by Code Analysis to maintain SuppressMessage +// attributes that are applied to this project. +// Project-level suppressions either have no target or are given +// a specific target and scoped to a namespace, type, member, etc. + +using System.Diagnostics.CodeAnalysis; + +[assembly: SuppressMessage("Usage", "EF1001:Internal EF Core API usage.", Justification = "EntityFrameworkCore Extension +n", Scope = "member", Target = "~M:N.EntityFrameworkCore.Extensions.DbContextExtensions.BulkSaveChanges(Microsoft.EntityF +FrameworkCore.DbContext,System.Boolean)~System.Int32")] +[assembly: SuppressMessage("Usage", "EF1001:Internal EF Core API usage.", Justification = "EntityFrameworkCore Extension +n", Scope = "member", Target = "~M:N.EntityFrameworkCore.Extensions.DbContextExtensions.SetStoreGeneratedValues``1(Micros +soft.EntityFrameworkCore.DbContext,``0,System.Collections.Generic.IEnumerable{Microsoft.EntityFrameworkCore.Metadata.IPro +operty},System.Object[])")] +[assembly: SuppressMessage("Usage", "EF1001:Internal EF Core API usage.", Justification = "EntityFrameworkCore Extension +n", Scope = "member", Target = "~M:N.EntityFrameworkCore.Extensions.DbContextExtensionsAsync.BulkSaveChangesAsync(Microso +oft.EntityFrameworkCore.DbContext,System.Boolean)~System.Threading.Tasks.Task{System.Int32}")] +[assembly: SuppressMessage("Usage", "EF1001:Internal EF Core API usage.", Justification = "EntityFrameworkCore Extension +n", Scope = "member", Target = "~M:N.EntityFrameworkCore.Extensions.TableMapping.GetColumnNames(Microsoft.EntityFramework +kCore.Metadata.IEntityType,System.Boolean)~System.Collections.Generic.IEnumerable{System.String}")] +[assembly: SuppressMessage("Usage", "EF1001:Internal EF Core API usage.", Justification = "EntityFrameworkCore Extension +n", Scope = "member", Target = "~M:N.EntityFrameworkCore.Extensions.EntityDataReader`1.FindEntry(System.Object)~Microsoft +t.EntityFrameworkCore.ChangeTracking.EntityEntry")] + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\N.EntityFramewor +rk.Extensions.SqlServer.csproj --- + + + + + net10.0 + 10.0.5.1 + N.EntityFramework.Extensions.SqlServer + N.EntityFramework.Extensions.SqlServer + true + https://github.com/NorthernLight1/N.EntityFrameworkCore.Extensions/ + Northern25 + Copyright © 2026 + + N.EntityFrameworkCore.Extensions extends your DbContext in EF Core with high-performance bulk operation +ns: BulkDelete, BulkInsert, BulkMerge, BulkSync, BulkUpdate, Fetch, DeleteFromQuery, InsertFromQuery, UpdateFromQuery. + +Inheritance models supported: Table-Per-Concrete, Table-Per-Hierarchy, Table-Per-Type + MIT + README.md + + + + 5 + + + + + True + \ + + + + + + + + + + + + + + + + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Sql\SqlBuilder.c +cs --- + +using System; +using System.Collections.Generic; +using System.Data; +using System.Linq; +using System.Linq.Expressions; +using Microsoft.Data.SqlClient; + +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal sealed class SqlBuilder +{ + private static readonly string[] keywords = ["DECLARE", "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY"]; + internal string Sql => ToString(); + internal List Clauses { get; private set; } + internal List Parameters { get; private set; } + private SqlBuilder(string sql) + { + Clauses = []; + Parameters = []; + Initialize(sql); + } + + internal string Count() => + $"SELECT COUNT(*) FROM ({string.Join("\r\n", Clauses.Where(o => o.Name != "ORDER BY").Select(o => o.ToString())) +)}) s"; + public override string ToString() => string.Join("\r\n", Clauses.Select(o => o.ToString())); + internal static SqlBuilder Parse(string sql) => new SqlBuilder(sql); + internal string GetTableAlias() + { + var sqlFromClause = Clauses.First(o => o.Name == "FROM"); + var startIndex = sqlFromClause.InputText.LastIndexOf(" AS "); + return startIndex > 0 ? sqlFromClause.InputText[(startIndex + 4)..] : ""; + } + internal void ChangeToDelete() + { + Validate(); + var sqlClause = Clauses.FirstOrDefault(); + var sqlFromClause = Clauses.First(o => o.Name == "FROM"); + if (sqlClause != null) + { + sqlClause.Name = "DELETE"; + int aliasStartIndex = sqlFromClause.InputText.IndexOf("AS ") + 3; + int aliasLength = sqlFromClause.InputText.IndexOf(']', aliasStartIndex) - aliasStartIndex + 1; + sqlClause.InputText = sqlFromClause.InputText[aliasStartIndex..(aliasStartIndex + aliasLength)]; + } + } + internal void ChangeToUpdate(string updateExpression, string setExpression) + { + Validate(); + var sqlClause = Clauses.FirstOrDefault(); + if (sqlClause != null) + { + sqlClause.Name = "UPDATE"; + sqlClause.InputText = updateExpression; + Clauses.Insert(1, new SqlClause { Name = "SET", InputText = setExpression }); + } + } + internal void ChangeToInsert(string tableName, Expression> insertObjectExpression) + { + Validate(); + var sqlSelectClause = Clauses.FirstOrDefault(); + string columnsToInsert = string.Join(",", insertObjectExpression.GetObjectProperties()); + string insertValueExpression = $"INTO {tableName} ({columnsToInsert})"; + Clauses.Insert(0, new SqlClause { Name = "INSERT", InputText = insertValueExpression }); + sqlSelectClause.InputText = columnsToInsert; + } + internal void SelectColumns(IEnumerable columns) + { + var tableAlias = GetTableAlias(); + var sqlClause = Clauses.FirstOrDefault(); + if (sqlClause.Name == "SELECT") + { + sqlClause.InputText = string.Join(",", columns.Select(c => $"{tableAlias}.{c}")); + } + } + private void Initialize(string sqlText) + { + string curClause = string.Empty; + int curClauseIndex = 0; + for (int i = 0; i < sqlText.Length;) + { + string keyword = StartsWithString(sqlText.AsSpan(i), keywords, StringComparison.OrdinalIgnoreCase); + bool isWordStart = i == 0 || sqlText[i - 1] == ' ' || (i > 1 && sqlText[i - 2] == '\r' && sqlText[i - 1] == + '\n'); + if (keyword != null && isWordStart) + { + string inputText = sqlText[curClauseIndex..i]; + if (!string.IsNullOrEmpty(curClause)) + { + if (curClause == "DECLARE") + { + var declareParts = inputText[..inputText.IndexOf(';')].Trim().Split(' '); + int sizeStartIndex = declareParts[1].IndexOf('('); + int sizeLength = declareParts[1].IndexOf(')') - (sizeStartIndex + 1); + string dbTypeString = sizeStartIndex != -1 ? declareParts[1][..sizeStartIndex] : declareParts[1] +]; + SqlDbType dbType = (SqlDbType)Enum.Parse(typeof(SqlDbType), dbTypeString, true); + int size = sizeStartIndex != -1 ? + Convert.ToInt32(declareParts[1][(sizeStartIndex + 1)..(sizeStartIndex + 1 + sizeLength)]) : + 0; + string value = GetDeclareValue(declareParts[3]); + Parameters.Add(new SqlParameter(declareParts[0], dbType, size) { Value = value }); + } + else + { + Clauses.Add(SqlClause.Parse(curClause, inputText)); + } + } + curClause = keyword; + curClauseIndex = i + curClause.Length; + i = i + curClause.Length; + } + else + { + i++; + } + } + if (!string.IsNullOrEmpty(curClause)) + Clauses.Add(SqlClause.Parse(curClause, sqlText[curClauseIndex..])); + } + private string GetDeclareValue(string value) + { + if (value.StartsWith('\'')) + { + return value[1..^1]; + } + else if (value.StartsWith("N'")) + { + return value[2..^1]; + } + else if (value.StartsWith("CAST(")) + { + return value[5..]; + } + else + { + return value; + } + } + private static string StartsWithString(ReadOnlySpan textToSearch, string[] valuesToFind, StringComparison stri +ingComparison) + { + foreach (var valueToFind in valuesToFind) + { + if (textToSearch.StartsWith(valueToFind, stringComparison)) + return valueToFind; + } + + return null; + } + private void Validate() + { + if (Clauses.Count == 0) + { + throw new Exception("You must parse a valid sql statement before you can use this function."); + } + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Sql\SqlClause.cs +s --- + +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal sealed class SqlClause +{ + internal string Name { get; set; } + internal string InputText { get; set; } + internal string Sql => ToString(); + internal static SqlClause Parse(string name, string inputText) + { + string cleanText = inputText.Replace("\r\n", "").Trim(); + return new SqlClause { Name = name, InputText = cleanText }; + } + public override string ToString() => $"{Name} {InputText}"; +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Sql\SqlExpressio +on.cs --- + +using System.Collections.Generic; +using System.Linq; +using System.Text; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal sealed class SqlExpression +{ + internal SqlExpressionType ExpressionType { get; } + List Items { get; set; } + internal string Sql => ToSql(); + string Alias { get; } + internal bool IsEmpty => Items.Count == 0; + + SqlExpression(SqlExpressionType expressionType, object item, string alias = null) + { + ExpressionType = expressionType; + Items = []; + if (item is IEnumerable values) + { + Items.AddRange(values.ToArray()); + } + else + { + Items.Add(item); + } + Alias = alias; + } + SqlExpression(SqlExpressionType expressionType, object[] items, string alias = null) + { + ExpressionType = expressionType; + Items = []; + Items.AddRange(items); + Alias = alias; + } + internal static SqlExpression Columns(IEnumerable columns) => + new SqlExpression(SqlExpressionType.Columns, columns); + + internal static SqlExpression Set(IEnumerable columns) => + new SqlExpression(SqlExpressionType.Set, columns); + + internal static SqlExpression String(string joinOnCondition) => + new SqlExpression(SqlExpressionType.String, joinOnCondition); + + internal static SqlExpression Table(string tableName, string alias = null) => + new SqlExpression(SqlExpressionType.Table, Util.CommonUtil.FormatTableName(tableName), alias); + + private string ToSql() + { + var sbSql = new StringBuilder(); + if (ExpressionType == SqlExpressionType.Columns) + { + var values = Items.Where(o => o != null).Select(o => o.ToString()).Where(o => !string.IsNullOrWhiteSpace(o)) +).ToArray(); + sbSql.Append(string.Join(",", CommonUtil.FormatColumns(values))); + } + else + { + sbSql.Append(string.Join(",", Items.Where(o => o != null).Select(o => o.ToString()).Where(o => !string.IsNul +llOrWhiteSpace(o)))); + } + if (Alias != null) + { + sbSql.Append(" "); + sbSql.Append(SqlKeyword.As.ToString().ToUpper()); + sbSql.Append(" "); + sbSql.Append(Alias); + } + return sbSql.ToString(); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Sql\SqlExpressio +onType.cs --- + +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal enum SqlExpressionType +{ + String, + Table, + Columns, + Set +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Sql\SqlKeyword.c +cs --- + +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal enum SqlKeyword +{ + Select, + Delete, + Insert, + Values, + Update, + Set, + Merge, + Into, + From, + On, + Where, + Using, + When, + Then, + Matched, + Not, + Output, + As, + By, + Source, + Target, + Off, + Identity_Insert, + Semicolon, +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Sql\SqlPart.cs - +--- + +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal sealed class SqlPart +{ + internal SqlKeyword Keyword { get; } + internal SqlExpression Expression { get; } + internal bool IgnoreOutput => GetIgnoreOutput(); + internal SqlPart(SqlKeyword keyword, SqlExpression expression) + { + Keyword = keyword; + Expression = expression; + } + private bool GetIgnoreOutput() => Keyword == SqlKeyword.Output && (Expression == null || Expression.IsEmpty); +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Sql\SqlStatement +t.cs --- + +using System.Collections.Generic; +using System.Linq; +using System.Text; +using N.EntityFrameworkCore.Extensions.Extensions; + +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal sealed class SqlStatement +{ + internal string Sql => ToSql(); + List SqlParts { get; } + SqlStatement() + { + SqlParts = []; + } + internal void CreatePart(SqlKeyword keyword, SqlExpression expression = null) => + SqlParts.Add(new SqlPart(keyword, expression)); + internal void SetIdentityInsert(string tableName, bool enable) + { + CreatePart(SqlKeyword.Set); + CreatePart(SqlKeyword.Identity_Insert, SqlExpression.Table(tableName)); + if (enable) + CreatePart(SqlKeyword.On); + else + CreatePart(SqlKeyword.Off); + CreatePart(SqlKeyword.Semicolon); + } + internal static SqlStatement CreateMerge(string sourceTableName, string targetTableName, string joinOnCondition, + IEnumerable insertColumns, IEnumerable updateColumns, IEnumerable outputColumns, + bool deleteIfNotMatched = false, bool hasIdentityColumn = false) + { + var statement = new SqlStatement(); + if (hasIdentityColumn) + statement.SetIdentityInsert(targetTableName, true); + statement.CreatePart(SqlKeyword.Merge, SqlExpression.Table(targetTableName, "t")); + statement.CreatePart(SqlKeyword.Using, SqlExpression.Table(sourceTableName, "s")); + statement.CreatePart(SqlKeyword.On, SqlExpression.String(joinOnCondition)); + statement.CreatePart(SqlKeyword.When); + statement.CreatePart(SqlKeyword.Not); + statement.CreatePart(SqlKeyword.Matched); + statement.CreatePart(SqlKeyword.Then); + statement.WriteInsert(insertColumns); + if (updateColumns.Any()) + { + var updateSetColumns = updateColumns.Select(c => $"t.[{c}]=s.[{c}]"); + statement.CreatePart(SqlKeyword.When); + statement.CreatePart(SqlKeyword.Matched); + statement.CreatePart(SqlKeyword.Then); + statement.CreatePart(SqlKeyword.Update); + statement.CreatePart(SqlKeyword.Set, SqlExpression.Set(updateSetColumns)); + } + if (deleteIfNotMatched) + { + statement.CreatePart(SqlKeyword.When); + statement.CreatePart(SqlKeyword.Not); + statement.CreatePart(SqlKeyword.Matched); + statement.CreatePart(SqlKeyword.By); + statement.CreatePart(SqlKeyword.Source); + statement.CreatePart(SqlKeyword.Then); + statement.CreatePart(SqlKeyword.Delete); + } + if (outputColumns.Any()) + statement.CreatePart(SqlKeyword.Output, SqlExpression.Columns(outputColumns)); + statement.CreatePart(SqlKeyword.Semicolon); + + if (hasIdentityColumn) + statement.SetIdentityInsert(targetTableName, false); + return statement; + } + + private string ToSql() + { + var sbSql = new StringBuilder(); + foreach (var part in SqlParts) + { + if (part.Keyword == SqlKeyword.Semicolon) + { + int lastIndex = sbSql.Length - 1; + if (lastIndex > -1 && sbSql[lastIndex] == ' ') + { + sbSql[lastIndex] = ';'; + sbSql.Append("\n"); + } + else + { + sbSql.Append(";\n"); + } + } + else if (!part.IgnoreOutput) + { + sbSql.Append(part.Keyword.ToString().ToUpper()); + sbSql.Append(" "); + bool useParenthese = part.Keyword == SqlKeyword.Insert || part.Keyword == SqlKeyword.Values; + + if (part.Expression != null) + { + string expressionSql = useParenthese ? $"({part.Expression.Sql})" : part.Expression.Sql; + sbSql.Append(expressionSql); + sbSql.Append(" "); + } + } + } + return sbSql.ToString(); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Util\CommonUtil. +.cs --- + +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.IO; +using System.Linq; +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; + +namespace N.EntityFrameworkCore.Extensions.Util; + +internal static class CommonUtil +{ + internal static string GetStagingTableName(TableMapping tableMapping, bool usePermanentTable, DbConnection dbConnect +tion) + { + string uniqueSuffix = Guid.NewGuid().ToString("N"); + if (usePermanentTable) + return tableMapping.DbContext.Database.GetPermanentStagingTableName(tableMapping.Schema, tableMapping.TableN +Name, uniqueSuffix); + return tableMapping.DbContext.Database.GetTemporaryTableName(tableMapping.TableName); + } + internal static IEnumerable FormatColumns(DbContext dbContext, IEnumerable columns) + { + return columns.Select(s => FormatColumn(dbContext, s)); + } + internal static IEnumerable FormatColumns(IEnumerable columns) + { + return columns.Select(FormatColumnLegacy); + } + internal static IEnumerable FormatColumns(DbContext dbContext, string tableAlias, IEnumerable column +ns) + { + return columns.Select(s => dbContext.DelimitMemberAccess(tableAlias, RemoveQualifier(s))); + } + internal static IEnumerable FormatColumns(DatabaseFacade database, string tableAlias, IEnumerable co +olumns) + { + return columns.Select(s => database.DelimitMemberAccess(tableAlias, RemoveQualifier(s))); + } + internal static IEnumerable FormatColumns(string tableAlias, IEnumerable columns) + { + return columns.Select(s => s.StartsWith('[') && s.EndsWith(']') ? $"[{tableAlias}].{s}" : $"[{tableAlias}].[{s}] +]"); + } + internal static IEnumerable FilterColumns(IEnumerable columnNames, string[] primaryKeyColumnNames +s, Expression> inputColumns, Expression> ignoreColumns) + { + var filteredColumnNames = columnNames; + if (inputColumns != null) + { + var inputColumnNames = inputColumns.GetObjectProperties(); + filteredColumnNames = filteredColumnNames.Intersect(inputColumnNames.Union(primaryKeyColumnNames)); + } + if (ignoreColumns != null) + { + var ignoreColumnNames = ignoreColumns.GetObjectProperties(); + if (ignoreColumnNames.Intersect(primaryKeyColumnNames).Any()) + { + throw new InvalidDataException("Primary key columns can not be ignored in BulkInsertOptions.IgnoreColumn +ns"); + } + else + { + filteredColumnNames = filteredColumnNames.Except(ignoreColumnNames); + } + } + return filteredColumnNames; + } + internal static string FormatTableName(DatabaseFacade database, string tableName) + { + return database.DelimitTableName(tableName); + } + internal static string FormatTableName(string tableName) + { + return string.Join(".", tableName.Split('.').Select(s => $"[{RemoveQualifier(s)}]")); + } + private static string FormatColumn(DbContext dbContext, string column) + { + var parts = column.Split('.'); + return string.Join(".", parts.Select(p => p.StartsWith('$') ? p : dbContext.DelimitIdentifier(RemoveQualifier(p) +)))); + } + private static string FormatColumnLegacy(string column) + { + var parts = column.Split('.'); + return string.Join(".", parts.Select(p => p.StartsWith('$') || (p.StartsWith('[') && p.EndsWith(']')) ? p : $"[{ +{p}]")); + } + private static string RemoveQualifier(string name) + { + return name.TrimStart('[').TrimEnd(']').Trim('"'); + } +} +internal static class CommonUtil +{ + internal static string[] GetColumns(Expression> expression, string[] tableNames) + { + List foundColumns = []; + string sqlText = (string)expression.Body.GetPrivateFieldValue("DebugView"); + var sqlSpan = sqlText.AsSpan(); + + int offset = 0; + while (offset < sqlSpan.Length) + { + int startIndex = sqlSpan[offset..].IndexOf('$'); + if (startIndex == -1) break; + startIndex += offset; + + var remaining = sqlSpan[startIndex..]; + int spaceIndex = remaining.IndexOf(' '); + var columnSpan = spaceIndex == -1 ? remaining : remaining[..spaceIndex]; + + int dotIndex = columnSpan.IndexOf('.'); + if (dotIndex >= 0) + { + var tablePart = columnSpan[1..dotIndex]; // skip leading '$' + var columnPart = columnSpan[(dotIndex + 1)..]; + if (tableNames == null || tableNames.Contains(tablePart.ToString())) + { + foundColumns.Add(columnPart.ToString()); + } + } + + offset = startIndex + 1; + } + + return foundColumns.ToArray(); + } + internal static string GetJoinConditionSql(Expression> joinKeyExpression, string[] storeGeneratedCo +olumnNames, string sourceTableName = "s", string targetTableName = "t") + { + if (joinKeyExpression != null) + return joinKeyExpression.ToSqlPredicate(sourceTableName, targetTableName); + + return string.Join(" AND ", storeGeneratedColumnNames.Select(c => $"{sourceTableName}.[{c}]={targetTableName}.[{ +{c}]")); + } + internal static string GetJoinConditionSql(DbContext dbContext, Expression> joinKeyExpression, stri +ing[] storeGeneratedColumnNames, string sourceTableName = "s", string targetTableName = "t") + { + if (joinKeyExpression != null) + return joinKeyExpression.ToSqlPredicate(dbContext, sourceTableName, targetTableName); + + return string.Join(" AND ", storeGeneratedColumnNames.Select(c => $"{dbContext.DelimitMemberAccess(sourceTableNa +ame, c)}={dbContext.DelimitMemberAccess(targetTableName, c)}")); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Util\RelationalP +ProviderUtil.cs --- + +using System; +using System.Data.Common; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Storage; + +namespace N.EntityFrameworkCore.Extensions.Util; + +internal readonly record struct DatabaseObjectName(string Schema, string Name) +{ + internal bool HasSchema => !string.IsNullOrWhiteSpace(Schema); +} + +internal static class RelationalProviderUtil +{ + internal static string GetDefaultSchema(this DatabaseFacade database) => "dbo"; + + internal static string DelimitIdentifier(this DatabaseFacade database, string identifier) => + database.GetSqlGenerationHelper().DelimitIdentifier(UnwrapIdentifier(identifier)); + + internal static string DelimitIdentifier(this DatabaseFacade database, string identifier, string schema) => + schema == null + ? database.DelimitIdentifier(identifier) + : database.GetSqlGenerationHelper().DelimitIdentifier(UnwrapIdentifier(identifier), UnwrapIdentifier(schema) +)); + + internal static string DelimitIdentifier(this DbContext dbContext, string identifier) => + dbContext.Database.DelimitIdentifier(identifier); + + internal static string DelimitIdentifier(this DbContext dbContext, string identifier, string schema) => + dbContext.Database.DelimitIdentifier(identifier, schema); + + internal static string DelimitTableName(this DatabaseFacade database, string tableName) + { + var objectName = database.ParseObjectName(tableName); + return objectName.HasSchema + ? database.DelimitIdentifier(objectName.Name, objectName.Schema) + : database.DelimitIdentifier(objectName.Name); + } + + internal static string DelimitTableName(this DbContext dbContext, string tableName) => + dbContext.Database.DelimitTableName(tableName); + + internal static string DelimitMemberAccess(this DbContext dbContext, string alias, string columnName) => + $"{dbContext.DelimitIdentifier(alias)}.{dbContext.DelimitIdentifier(columnName)}"; + + internal static string DelimitMemberAccess(this DatabaseFacade database, string alias, string columnName) => + $"{database.DelimitIdentifier(alias)}.{database.DelimitIdentifier(columnName)}"; + + internal static DatabaseObjectName ParseObjectName(this DatabaseFacade database, string objectName) + { + string normalized = objectName.Trim(); + if (string.IsNullOrWhiteSpace(normalized)) + throw new ArgumentException("Object name cannot be empty.", nameof(objectName)); + + var parts = normalized.Split('.', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + return parts.Length switch + { + 1 => new DatabaseObjectName(IsTemporaryName(parts[0]) ? null : "dbo", UnwrapIdentifier(parts[0])), + 2 => new DatabaseObjectName(UnwrapIdentifier(parts[0]), UnwrapIdentifier(parts[1])), + _ => throw new InvalidOperationException($"Unsupported object name format '{objectName}'.") + }; + } + + internal static string UnwrapIdentifier(string value) => + value.Trim().Trim('[', ']', '"'); + + internal static string GetTemporaryTableName(this DatabaseFacade database, string baseName) + { + string temporaryName = $"tmp_be_xx_{UnwrapIdentifier(baseName)}_{Guid.NewGuid():N}"; + return database.DelimitIdentifier(temporaryName); + } + + internal static string GetPermanentStagingTableName(this DatabaseFacade database, string schema, string tableName, s +string uniqueSuffix) + { + string stagingName = $"tmp_be_xx_{UnwrapIdentifier(tableName)}_{uniqueSuffix}"; + return database.DelimitIdentifier(stagingName, schema); + } + + internal static DbConnection CloneConnection(this DbConnection dbConnection) => + dbConnection is ICloneable cloneable + ? (DbConnection)cloneable.Clone() + : throw new NotSupportedException($"Connection type '{dbConnection.GetType().FullName}' does not support clo +oning."); + + private static ISqlGenerationHelper GetSqlGenerationHelper(this DatabaseFacade database) => + ((IInfrastructure)database).Instance.GetService(typeof(ISqlGenerationHelper)) as ISqlGeneratio +onHelper + ?? throw new InvalidOperationException("Unable to resolve ISqlGenerationHelper."); + + private static bool IsTemporaryName(string objectName) => + UnwrapIdentifier(objectName).StartsWith("#", StringComparison.Ordinal); +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.SqlServer\Util\SqlUtil.cs + --- + +using System.Collections.Generic; + +namespace N.EntityFrameworkCore.Extensions; + +internal static class SqlUtil +{ + internal static string ConvertToColumnString(IEnumerable columnNames) + { + return string.Join(",", columnNames); + } +} + +=== DIRECTORY: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql === + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Common\Constant +ts.cs --- + +namespace N.EntityFrameworkCore.Extensions.Common; + +public static class Constants +{ + public static readonly string InternalId_ColumnName = "_be_xx_id"; +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\BulkDelete +eOptions.cs --- + +using System; +using System.Linq.Expressions; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkDeleteOptions : BulkOptions +{ + public Expression> DeleteOnCondition { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\BulkFetchO +Options.cs --- + +using System; +using System.Linq.Expressions; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkFetchOptions : BulkOptions +{ + public Expression> IgnoreColumns { get; set; } + public Expression> InputColumns { get; set; } + public Expression> JoinOnCondition { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\BulkInsert +tOptions.cs --- + +using System; +using System.Linq; +using System.Linq.Expressions; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkInsertOptions : BulkOptions +{ + public bool AutoMapOutput { get; set; } + public Expression> IgnoreColumns { get; set; } + public Expression> InputColumns { get; set; } + public bool InsertIfNotExists { get; set; } + public Expression> InsertOnCondition { get; set; } + public bool KeepIdentity { get; set; } + + public string[] GetInputColumns() => + InputColumns?.Body.Type.GetProperties().Select(o => o.Name).ToArray(); + + public BulkInsertOptions() + { + AutoMapOutput = true; + } + internal BulkInsertOptions(BulkOptions options) + { + EntityType = options.EntityType; + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\BulkInsert +tResult.cs --- + +using System.Collections.Generic; + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed class BulkInsertResult +{ + internal int RowsAffected { get; set; } + internal Dictionary EntityMap { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\BulkMergeO +Option.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkMergeOptions : BulkOptions +{ + public Expression> MergeOnCondition { get; set; } + public Expression> IgnoreColumnsOnInsert { get; set; } + public Expression> IgnoreColumnsOnUpdate { get; set; } + public bool AutoMapOutput { get; set; } + internal bool DeleteIfNotMatched { get; set; } + + public BulkMergeOptions() + { + AutoMapOutput = true; + } + public List GetIgnoreColumnsOnInsert() => + IgnoreColumnsOnInsert?.Body.Type.GetProperties().Select(o => o.Name).ToList() ?? []; + public List GetIgnoreColumnsOnUpdate() => + IgnoreColumnsOnUpdate?.Body.Type.GetProperties().Select(o => o.Name).ToList() ?? []; +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\BulkMergeO +OutputRow.cs --- + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkMergeOutputRow +{ + public string Action { get; set; } + + public BulkMergeOutputRow(string action) + { + Action = action; + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\BulkMergeR +Result.cs --- + +using System.Collections.Generic; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkMergeResult +{ + public IEnumerable> Output { get; set; } + public int RowsAffected { get; set; } + public int RowsDeleted { get; internal set; } + public int RowsInserted { get; internal set; } + public int RowsUpdated { get; internal set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\BulkOperat +tion.cs --- + +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.IO; +using System.Linq; +using System.Linq.Expressions; +using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; +using Microsoft.EntityFrameworkCore.Metadata; +using N.EntityFrameworkCore.Extensions.Common; +using N.EntityFrameworkCore.Extensions.Sql; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed partial class BulkOperation : IDisposable +{ + internal DbConnection Connection => DbTransactionContext.Connection; + internal DbContext Context { get; } + internal bool StagingTableCreated { get; set; } + internal string StagingTableName { get; } + internal string[] PrimaryKeyColumnNames { get; } + internal BulkOptions Options { get; } + internal Expression> InputColumns { get; } + internal Expression> IgnoreColumns { get; } + internal DbTransactionContext DbTransactionContext { get; } + internal Type EntityType => typeof(T); + internal DbTransaction Transaction => DbTransactionContext.CurrentTransaction; + internal TableMapping TableMapping { get; } + internal IEnumerable SchemaQualifiedTableNames => TableMapping.GetSchemaQualifiedTableNames(); + + + public BulkOperation(DbContext dbContext, BulkOptions options, Expression> inputColumns = null, Expr +ression> ignoreColumns = null) + { + Context = dbContext; + Options = options; + InputColumns = inputColumns; + IgnoreColumns = ignoreColumns; + + DbTransactionContext = new DbTransactionContext(dbContext, options.CommandTimeout); + TableMapping = dbContext.GetTableMapping(typeof(T), options.EntityType); + StagingTableName = CommonUtil.GetStagingTableName(TableMapping, options.UsePermanentTable, Connection); + PrimaryKeyColumnNames = TableMapping.GetPrimaryKeyColumns().ToArray(); + } + public void Dispose() + { + if (StagingTableCreated) + { + Context.Database.DropTable(StagingTableName, true); + } + } + internal bool ShouldKeepIdentityForPostgresMerge() + { + return Context.Database.IsPostgreSql() + && GetGeneratedPrimaryKeyProperty()?.PropertyInfo != null + && PrimaryKeyColumnNames.Length == 1; + } + internal bool ShouldPreallocateIdentityValues(bool autoMapOutput, bool keepIdentity, IEnumerable entities) + { + if (!Context.Database.IsPostgreSql() || keepIdentity || !autoMapOutput) + return false; + + var identityProperty = GetGeneratedPrimaryKeyProperty(); + if (identityProperty?.PropertyInfo == null || PrimaryKeyColumnNames.Length != 1) + return false; + + var entityList = entities as IList ?? entities.ToList(); + if (entityList.Count == 0) + return false; + + // For BulkSaveChanges, entities are InternalEntityEntry (Added state) — always preallocate + if (entityList[0] is InternalEntityEntry) + return true; + + // For regular POCOs, only preallocate if all entities have the default PK value + object defaultValue = identityProperty.ClrType.IsValueType ? Activator.CreateInstance(identityProperty.ClrType) + : null; + return entityList.All(entity => Equals(identityProperty.PropertyInfo.GetValue(entity), defaultValue)); + } + internal void PreallocateIdentityValues(IEnumerable entities) + { + var identityProperty = GetGeneratedPrimaryKeyProperty(); + if (identityProperty?.PropertyInfo == null) + return; + + var entityList = entities.ToList(); + if (entityList.Count == 0) + return; + + string tableName = Context.DelimitIdentifier(TableMapping.EntityType.GetTableName(), TableMapping.EntityType.Get +tSchema() ?? Context.Database.GetDefaultSchema()); + string sequenceSql = $"SELECT nextval(pg_get_serial_sequence('{tableName}', '{identityProperty.GetColumnName()}' +')) FROM generate_series(1, {entityList.Count})"; + using var command = Connection.CreateCommand(); + command.CommandText = sequenceSql; + command.Transaction = Transaction; + using var reader = command.ExecuteReader(); + foreach (var entity in entityList) + { + if (!reader.Read()) + throw new InvalidDataException("Failed to allocate PostgreSql identity values."); + + object sequenceValue = Convert.ChangeType(reader.GetValue(0), identityProperty.ClrType); + if (entity is InternalEntityEntry internalEntry) + internalEntry.SetStoreGeneratedValue(identityProperty, sequenceValue); + else + identityProperty.PropertyInfo.SetValue(entity, sequenceValue); + } + } + internal BulkInsertResult BulkInsertStagingData(IEnumerable entities, bool keepIdentity = true, bool useIntern +nalId = false) + { + IEnumerable columnsToInsert = GetColumnNames(keepIdentity); + string internalIdColumn = useInternalId ? Common.Constants.InternalId_ColumnName : null; + Context.Database.CloneTable(SchemaQualifiedTableNames, StagingTableName, TableMapping.GetQualifiedColumnNames(co +olumnsToInsert), internalIdColumn); + StagingTableCreated = true; + return DbContextExtensions.BulkInsert(entities, Options, TableMapping, Connection, Transaction, StagingTableName +e, columnsToInsert, SqlBulkCopyOptions.KeepIdentity, useInternalId); + } + internal BulkMergeResult ExecuteMerge(Dictionary entityMap, Expression> mergeOnConditio +on, + bool autoMapOutput, bool keepIdentity, bool insertIfNotExists, bool update = false, bool delete = false, bool pr +reallocatedIds = false) + { + if (Context.Database.IsPostgreSql()) + return ExecuteMergePostgreSql(entityMap, mergeOnCondition, autoMapOutput, keepIdentity, insertIfNotExists, u +update, delete, preallocatedIds); + + Dictionary rowsInserted = []; + Dictionary rowsUpdated = []; + Dictionary rowsDeleted = []; + Dictionary rowsAffected = []; + List> outputRows = []; + + foreach (var entityType in TableMapping.EntityTypes) + { + rowsInserted[entityType] = 0; + rowsUpdated[entityType] = 0; + rowsDeleted[entityType] = 0; + rowsAffected[entityType] = 0; + + var columnsToInsert = GetColumnNames(entityType, keepIdentity); + var columnsToUpdate = update ? GetColumnNames(entityType) : []; + var autoGeneratedColumns = autoMapOutput ? TableMapping.GetAutoGeneratedColumns(entityType) : []; + var columnsToOutput = autoMapOutput ? GetMergeOutputColumns(autoGeneratedColumns, delete) : []; + var deleteEntityType = TableMapping.EntityType == entityType && delete ? delete : false; + + string mergeOnConditionSql = insertIfNotExists ? CommonUtil.GetJoinConditionSql(mergeOnCondition, Primary +yKeyColumnNames, "t", "s") : "1=2"; + bool toggleIdentity = keepIdentity && TableMapping.HasIdentityColumn; + var mergeStatement = SqlStatement.CreateMerge(StagingTableName, entityType.GetSchemaQualifiedTableName(), + mergeOnConditionSql, columnsToInsert, columnsToUpdate, columnsToOutput, deleteEntityType, toggleIdentity +y); + + if (autoMapOutput) + { + List allProperties = + [ + .. TableMapping.GetEntityProperties(entityType, ValueGenerated.OnAdd).ToArray(), + .. TableMapping.GetEntityProperties(entityType, ValueGenerated.OnAddOrUpdate).ToArray() + ]; + + var bulkQueryResult = Context.BulkQuery(mergeStatement.Sql, Options); + rowsAffected[entityType] = bulkQueryResult.RowsAffected; + + foreach (var result in bulkQueryResult.Results) + { + string action = (string)result[0]; + outputRows.Add(new BulkMergeOutputRow(action)); + + if (action == SqlMergeAction.Delete) + { + rowsDeleted[entityType]++; + } + else + { + int entityId = (int)result[1]; + var entity = entityMap[entityId]; + if (action == SqlMergeAction.Insert) + { + rowsInserted[entityType]++; + if (allProperties.Count != 0) + { + var entityValues = GetMergeOutputValues(columnsToOutput, result, allProperties); + Context.SetStoreGeneratedValues(entity, allProperties, entityValues); + } + } + else if (action == SqlMergeAction.Update) + { + rowsUpdated[entityType]++; + if (allProperties.Count != 0) + { + var entityValues = GetMergeOutputValues(columnsToOutput, result, allProperties); + Context.SetStoreGeneratedValues(entity, allProperties, entityValues); + } + } + } + } + } + else + { + rowsAffected[entityType] = Context.Database.ExecuteSqlInternal(mergeStatement.Sql, Options.CommandTimeou +ut); + } + } + return new BulkMergeResult + { + Output = outputRows, + RowsAffected = rowsAffected.Values.LastOrDefault(), + RowsDeleted = rowsDeleted.Values.LastOrDefault(), + RowsInserted = rowsInserted.Values.LastOrDefault(), + RowsUpdated = rowsUpdated.Values.LastOrDefault() + }; + } + + private IEnumerable GetMergeOutputColumns(IEnumerable autoGeneratedColumns, bool delete = false) + { + List columnsToOutput = ["$action", $"[s].[{Constants.InternalId_ColumnName}]"]; + columnsToOutput.AddRange(autoGeneratedColumns.Select(o => $"[inserted].[{o}]")); + return columnsToOutput; + } + private object[] GetMergeOutputValues(IEnumerable columns, object[] values, IEnumerable propertie +es) + { + var columnList = columns.ToList(); + var valuesIndex = properties.Select(o => columnList.IndexOf($"[inserted].[{o.GetColumnName()}]")); + return valuesIndex.Select(i => values[i]).ToArray(); + } + internal int ExecuteUpdate(IEnumerable entities, Expression> updateOnCondition) + { + if (Context.Database.IsPostgreSql()) + return ExecuteUpdatePostgreSql(updateOnCondition); + + int rowsUpdated = 0; + foreach (var entityType in TableMapping.EntityTypes) + { + IEnumerable columnsToUpdate = CommonUtil.FormatColumns(GetColumnNames(entityType)); + string updateSetExpression = string.Join(",", columnsToUpdate.Select(o => $"t.{o}=s.{o}")); + string updateSql = $"UPDATE t SET {updateSetExpression} FROM {StagingTableName} AS s JOIN {CommonUtil.Format +tTableName(entityType.GetSchemaQualifiedTableName())} AS t ON {CommonUtil.GetJoinConditionSql(updateOnCondition, Prima +aryKeyColumnNames, "s", "t")}; SELECT @@RowCount;"; + rowsUpdated = Context.Database.ExecuteSqlInternal(updateSql, Options.CommandTimeout); + } + return rowsUpdated; + } + private BulkMergeResult ExecuteMergePostgreSql(Dictionary entityMap, Expression> mergeO +OnCondition, + bool autoMapOutput, bool keepIdentity, bool insertIfNotExists, bool update, bool delete, bool preallocatedIds = + false) + { + Dictionary rowsInserted = []; + Dictionary rowsUpdated = []; + Dictionary rowsDeleted = []; + List> outputRows = []; + + foreach (var entityType in TableMapping.EntityTypes) + { + var targetTableName = Context.DelimitIdentifier(entityType.GetTableName(), entityType.GetSchema() ?? Context +t.Database.GetDefaultSchema()); + var columnsToInsert = GetColumnNames(entityType, keepIdentity).ToList(); + var columnsToUpdate = update ? GetColumnNames(entityType).ToList() : []; + var autoGeneratedColumns = autoMapOutput ? TableMapping.GetAutoGeneratedColumns(entityType).ToList() : []; + var allProperties = autoMapOutput + ? TableMapping.GetEntityProperties(entityType, ValueGenerated.OnAdd).Concat(TableMapping.GetEntityProper +rties(entityType, ValueGenerated.OnAddOrUpdate)).ToList() + : []; + + string matchJoinCondition = CommonUtil.GetJoinConditionSql(Context, mergeOnCondition, PrimaryKeyColumnNam +mes, "s", "t"); + string pkJoinCondition = CommonUtil.GetJoinConditionSql(Context, null, PrimaryKeyColumnNames, "s", "t"); + string joinCondition = insertIfNotExists ? matchJoinCondition : "1=2"; + + HashSet matchedIds = autoMapOutput && update + ? GetMatchedInternalIds(targetTableName, matchJoinCondition) + : []; + + rowsUpdated[entityType] = 0; + if (columnsToUpdate.Count > 0) + { + string updateSetExpression = string.Join(",", columnsToUpdate.Select(c => $"{Context.DelimitIdentifier(c +c)}={Context.DelimitMemberAccess("s", c)}")); + string updateSql = $"UPDATE {targetTableName} AS t SET {updateSetExpression} FROM {StagingTableName} AS + s WHERE {joinCondition}"; + rowsUpdated[entityType] = Context.Database.ExecuteSqlInternal(updateSql, Options.CommandTimeout); + } + + string insertColumnsSql = string.Join(",", columnsToInsert.Select(Context.DelimitIdentifier)); + string sourceColumnsSql = string.Join(",", columnsToInsert.Select(c => Context.DelimitMemberAccess("s", c))) +); + string insertSql = $"INSERT INTO {targetTableName} ({insertColumnsSql}) SELECT {sourceColumnsSql} FROM {Stag +gingTableName} AS s WHERE NOT EXISTS (SELECT 1 FROM {targetTableName} AS t WHERE {joinCondition})"; + rowsInserted[entityType] = Context.Database.ExecuteSqlInternal(insertSql, Options.CommandTimeout); + if (keepIdentity && rowsInserted[entityType] > 0) + SyncPostgreSqlIdentitySequence(entityType); + + rowsDeleted[entityType] = 0; + if (TableMapping.EntityType == entityType && delete) + { + // When IDs were preallocated (entities had Id=0), staging PKs are new sequences that don't match + // existing target PKs (UPDATE excludes PK from SET). Use the merge condition to identify rows to keep. + // When entities had explicit IDs, staging PKs match inserted/updated target rows → use PK-based delete. + string deleteJoinCondition = (preallocatedIds && mergeOnCondition != null) ? matchJoinCondition : pkJoin +nCondition; + string deleteSql = $"DELETE FROM {targetTableName} AS t WHERE NOT EXISTS (SELECT 1 FROM {StagingTableNam +me} AS s WHERE {deleteJoinCondition})"; + rowsDeleted[entityType] = Context.Database.ExecuteSqlInternal(deleteSql, Options.CommandTimeout); + for (int i = 0; i < rowsDeleted[entityType]; i++) + outputRows.Add(new BulkMergeOutputRow(SqlMergeAction.Delete)); + } + + if (autoMapOutput) + { + string outputColumnsSql = autoGeneratedColumns.Any() + ? "," + string.Join(",", autoGeneratedColumns.Select(c => Context.DelimitMemberAccess("t", c))) + : string.Empty; + var outputQuery = $"SELECT {Context.DelimitMemberAccess("s", Constants.InternalId_ColumnName)}{outputCol +lumnsSql} FROM {StagingTableName} AS s JOIN {targetTableName} AS t ON {matchJoinCondition}"; + var bulkQueryResult = Context.BulkQuery(outputQuery, Options); + var autoGeneratedColumnList = autoGeneratedColumns.ToList(); + foreach (var result in bulkQueryResult.Results) + { + int entityId = Convert.ToInt32(result[0]); + bool wasMatched = matchedIds.Contains(entityId); + string action = wasMatched ? SqlMergeAction.Update : SqlMergeAction.Insert; + outputRows.Add(new BulkMergeOutputRow(action)); + + if (entityMap.TryGetValue(entityId, out var entity) && allProperties.Count > 0) + { + object[] entityValues = allProperties.Select(p => result[1 + autoGeneratedColumnList.IndexOf(p.G +GetColumnName())]).ToArray(); + Context.SetStoreGeneratedValues(entity, allProperties, entityValues); + } + } + } + } + + return new BulkMergeResult + { + Output = outputRows, + RowsAffected = rowsInserted.Values.LastOrDefault() + rowsUpdated.Values.LastOrDefault() + rowsDeleted.Values +s.LastOrDefault(), + RowsDeleted = rowsDeleted.Values.LastOrDefault(), + RowsInserted = rowsInserted.Values.LastOrDefault(), + RowsUpdated = rowsUpdated.Values.LastOrDefault() + }; + } + private int ExecuteUpdatePostgreSql(Expression> updateOnCondition) + { + int rowsUpdated = 0; + foreach (var entityType in TableMapping.EntityTypes) + { + IEnumerable columnsToUpdate = GetColumnNames(entityType); + string updateSetExpression = string.Join(",", columnsToUpdate.Select(c => $"{Context.DelimitIdentifier(c)}={ +{Context.DelimitMemberAccess("s", c)}")); + string targetTableName = Context.DelimitIdentifier(entityType.GetTableName(), entityType.GetSchema() ?? Cont +text.Database.GetDefaultSchema()); + string updateSql = $"UPDATE {targetTableName} AS t SET {updateSetExpression} FROM {StagingTableName} AS s WH +HERE {CommonUtil.GetJoinConditionSql(Context, updateOnCondition, PrimaryKeyColumnNames, "s", "t")}"; + rowsUpdated = Context.Database.ExecuteSqlInternal(updateSql, Options.CommandTimeout); + } + return rowsUpdated; + } + private HashSet GetMatchedInternalIds(string targetTableName, string joinCondition) + { + var results = Context.BulkQuery( + $"SELECT {Context.DelimitMemberAccess("s", Constants.InternalId_ColumnName)} FROM {StagingTableName} AS s JO +OIN {targetTableName} AS t ON {joinCondition}", + Options); + return results.Results.Select(r => Convert.ToInt32(r[0])).ToHashSet(); + } + private IProperty GetGeneratedPrimaryKeyProperty() + { + return TableMapping.EntityType.GetProperties().SingleOrDefault(o => o.IsPrimaryKey() && o.ValueGenerated != Valu +ueGenerated.Never); + } + private void SyncPostgreSqlIdentitySequence(IEntityType entityType) + { + var identityProperty = entityType.GetProperties().SingleOrDefault(o => o.IsPrimaryKey() && o.ValueGenerated != V +ValueGenerated.Never); + if (identityProperty == null) + return; + + string tableName = Context.DelimitIdentifier(entityType.GetTableName(), entityType.GetSchema() ?? Context.Databa +ase.GetDefaultSchema()); + string columnName = Context.DelimitIdentifier(identityProperty.GetColumnName()); + string sequenceSql = $"SELECT setval(pg_get_serial_sequence('{tableName}', '{identityProperty.GetColumnName()}') +), COALESCE(MAX({columnName}), 0)) FROM {tableName}"; + Context.Database.ExecuteSqlInternal(sequenceSql, Options.CommandTimeout); + } + internal void ValidateBulkMerge(Expression> mergeOnCondition) + { + if (PrimaryKeyColumnNames.Length == 0 && mergeOnCondition == null) + throw new InvalidDataException("BulkMerge requires that the entity have a primary key or that Options.MergeO +OnCondition be set"); + } + internal void ValidateBulkUpdate(Expression> updateOnCondition) + { + if (PrimaryKeyColumnNames.Length == 0 && updateOnCondition == null) + throw new InvalidDataException("BulkUpdate requires that the entity have a primary key or the Options.Update +eOnCondition must be set."); + + } + internal IEnumerable GetColumnNames(bool includePrimaryKeys = false) + { + return GetColumnNames(null, includePrimaryKeys); + } + internal IEnumerable GetColumnNames(IEntityType entityType, bool includePrimaryKeys = false) + { + return CommonUtil.FilterColumns(TableMapping.GetColumnNames(entityType, includePrimaryKeys), PrimaryKeyColumnNam +mes, InputColumns, IgnoreColumns); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\BulkOperat +tionAsync.cs --- + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Linq.Expressions; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; +using Microsoft.EntityFrameworkCore.Metadata; +using N.EntityFrameworkCore.Extensions.Common; +using N.EntityFrameworkCore.Extensions.Sql; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed partial class BulkOperation +{ + internal async Task> BulkInsertStagingDataAsync(IEnumerable entities, bool keepIdentity = tru +ue, bool useInternalId = false, CancellationToken cancellationToken = default) + { + IEnumerable columnsToInsert = GetColumnNames(keepIdentity); + string internalIdColumn = useInternalId ? Common.Constants.InternalId_ColumnName : null; + await Context.Database.CloneTableAsync(SchemaQualifiedTableNames, StagingTableName, TableMapping.GetQualifiedCol +lumnNames(columnsToInsert), internalIdColumn, cancellationToken); + StagingTableCreated = true; + return await DbContextExtensionsAsync.BulkInsertAsync(entities, Options, TableMapping, Connection, Transaction, + StagingTableName, columnsToInsert, SqlBulkCopyOptions.KeepIdentity, useInternalId, cancellationToken); + } + + internal async Task> ExecuteMergeAsync(Dictionary entityMap, Expression +>> mergeOnCondition, + bool autoMapOutput, bool keepIdentity, bool insertIfNotExists, bool update = false, bool delete = false, bool pr +reallocatedIds = false, CancellationToken cancellationToken = default) + { + if (Context.Database.IsPostgreSql()) + return await ExecuteMergePostgreSqlAsync(entityMap, mergeOnCondition, autoMapOutput, keepIdentity, insertIfN +NotExists, update, delete, preallocatedIds, cancellationToken); + + Dictionary rowsInserted = []; + Dictionary rowsUpdated = []; + Dictionary rowsDeleted = []; + Dictionary rowsAffected = []; + List> outputRows = []; + + foreach (var entityType in TableMapping.EntityTypes) + { + rowsInserted[entityType] = 0; + rowsUpdated[entityType] = 0; + rowsDeleted[entityType] = 0; + rowsAffected[entityType] = 0; + + var columnsToInsert = GetColumnNames(entityType, keepIdentity); + var columnsToUpdate = update ? GetColumnNames(entityType) : []; + var autoGeneratedColumns = autoMapOutput ? TableMapping.GetAutoGeneratedColumns(entityType) : []; + var columnsToOutput = autoMapOutput ? GetMergeOutputColumns(autoGeneratedColumns, delete) : []; + var deleteEntityType = TableMapping.EntityType == entityType && delete ? delete : false; + + string mergeOnConditionSql = insertIfNotExists ? CommonUtil.GetJoinConditionSql(mergeOnCondition, Primary +yKeyColumnNames, "t", "s") : "1=2"; + bool toggleIdentity = keepIdentity && TableMapping.HasIdentityColumn; + var mergeStatement = SqlStatement.CreateMerge(StagingTableName, entityType.GetSchemaQualifiedTableName(), + mergeOnConditionSql, columnsToInsert, columnsToUpdate, columnsToOutput, deleteEntityType, toggleIdentity +y); + + if (autoMapOutput) + { + List allProperties = + [ + .. TableMapping.GetEntityProperties(entityType, ValueGenerated.OnAdd).ToArray(), + .. TableMapping.GetEntityProperties(entityType, ValueGenerated.OnAddOrUpdate).ToArray() + ]; + + var bulkQueryResult = await Context.BulkQueryAsync(mergeStatement.Sql, Connection, Transaction, Options, +, cancellationToken); + rowsAffected[entityType] = bulkQueryResult.RowsAffected; + + foreach (var result in bulkQueryResult.Results) + { + string action = (string)result[0]; + outputRows.Add(new BulkMergeOutputRow(action)); + + if (action == SqlMergeAction.Delete) + { + rowsDeleted[entityType]++; + } + else + { + int entityId = (int)result[1]; + var entity = entityMap[entityId]; + if (action == SqlMergeAction.Insert) + { + rowsInserted[entityType]++; + if (allProperties.Count != 0) + { + var entityValues = GetMergeOutputValues(columnsToOutput, result, allProperties); + Context.SetStoreGeneratedValues(entity, allProperties, entityValues); + } + } + else if (action == SqlMergeAction.Update) + { + rowsUpdated[entityType]++; + if (allProperties.Count != 0) + { + var entityValues = GetMergeOutputValues(columnsToOutput, result, allProperties); + Context.SetStoreGeneratedValues(entity, allProperties, entityValues); + } + } + } + } + } + else + { + rowsAffected[entityType] = await Context.Database.ExecuteSqlAsync(mergeStatement.Sql, Options.CommandTim +meout, cancellationToken); + } + } + return new BulkMergeResult + { + Output = outputRows, + RowsAffected = rowsAffected.Values.LastOrDefault(), + RowsDeleted = rowsDeleted.Values.LastOrDefault(), + RowsInserted = rowsInserted.Values.LastOrDefault(), + RowsUpdated = rowsUpdated.Values.LastOrDefault() + }; + } + internal async Task ExecuteUpdateAsync(IEnumerable entities, Expression> updateOnCondition, +, CancellationToken cancellationToken = default) + { + if (Context.Database.IsPostgreSql()) + return await ExecuteUpdatePostgreSqlAsync(updateOnCondition, cancellationToken); + + int rowsUpdated = 0; + foreach (var entityType in TableMapping.EntityTypes) + { + IEnumerable columnsToUpdate = CommonUtil.FormatColumns(GetColumnNames(entityType)); + string updateSetExpression = string.Join(",", columnsToUpdate.Select(o => $"t.{o}=s.{o}")); + string updateSql = $"UPDATE t SET {updateSetExpression} FROM {StagingTableName} AS s JOIN {CommonUtil.Format +tTableName(entityType.GetSchemaQualifiedTableName())} AS t ON {CommonUtil.GetJoinConditionSql(updateOnCondition, Prima +aryKeyColumnNames, "s", "t")}; SELECT @@RowCount;"; + rowsUpdated = await Context.Database.ExecuteSqlAsync(updateSql, Options.CommandTimeout, cancellationToken); + } + return rowsUpdated; + } + private async Task> ExecuteMergePostgreSqlAsync(Dictionary entityMap, Expression> mergeOnCondition, + bool autoMapOutput, bool keepIdentity, bool insertIfNotExists, bool update, bool delete, bool preallocatedIds = + false, CancellationToken cancellationToken = default) + { + Dictionary rowsInserted = []; + Dictionary rowsUpdated = []; + Dictionary rowsDeleted = []; + List> outputRows = []; + + foreach (var entityType in TableMapping.EntityTypes) + { + var targetTableName = Context.DelimitIdentifier(entityType.GetTableName(), entityType.GetSchema() ?? Context +t.Database.GetDefaultSchema()); + var columnsToInsert = GetColumnNames(entityType, keepIdentity).ToList(); + var columnsToUpdate = update ? GetColumnNames(entityType).ToList() : []; + var autoGeneratedColumns = autoMapOutput ? TableMapping.GetAutoGeneratedColumns(entityType).ToList() : []; + var allProperties = autoMapOutput + ? TableMapping.GetEntityProperties(entityType, ValueGenerated.OnAdd).Concat(TableMapping.GetEntityProper +rties(entityType, ValueGenerated.OnAddOrUpdate)).ToList() + : []; + + string matchJoinCondition = CommonUtil.GetJoinConditionSql(Context, mergeOnCondition, PrimaryKeyColumnNam +mes, "s", "t"); + string pkJoinCondition = CommonUtil.GetJoinConditionSql(Context, null, PrimaryKeyColumnNames, "s", "t"); + string joinCondition = insertIfNotExists ? matchJoinCondition : "1=2"; + + HashSet matchedIds = autoMapOutput && update + ? await GetMatchedInternalIdsAsync(targetTableName, matchJoinCondition, cancellationToken) + : []; + + rowsUpdated[entityType] = 0; + if (columnsToUpdate.Count > 0) + { + string updateSetExpression = string.Join(",", columnsToUpdate.Select(c => $"{Context.DelimitIdentifier(c +c)}={Context.DelimitMemberAccess("s", c)}")); + string updateSql = $"UPDATE {targetTableName} AS t SET {updateSetExpression} FROM {StagingTableName} AS + s WHERE {joinCondition}"; + rowsUpdated[entityType] = await Context.Database.ExecuteSqlAsync(updateSql, Options.CommandTimeout, canc +cellationToken); + } + + string insertColumnsSql = string.Join(",", columnsToInsert.Select(Context.DelimitIdentifier)); + string sourceColumnsSql = string.Join(",", columnsToInsert.Select(c => Context.DelimitMemberAccess("s", c))) +); + string insertSql = $"INSERT INTO {targetTableName} ({insertColumnsSql}) SELECT {sourceColumnsSql} FROM {Stag +gingTableName} AS s WHERE NOT EXISTS (SELECT 1 FROM {targetTableName} AS t WHERE {joinCondition})"; + rowsInserted[entityType] = await Context.Database.ExecuteSqlAsync(insertSql, Options.CommandTimeout, cancell +lationToken); + if (keepIdentity && rowsInserted[entityType] > 0) + await SyncPostgreSqlIdentitySequenceAsync(entityType, cancellationToken); + + rowsDeleted[entityType] = 0; + if (TableMapping.EntityType == entityType && delete) + { + string deleteJoinCondition = (preallocatedIds && mergeOnCondition != null) ? matchJoinCondition : pkJoin +nCondition; + string deleteSql = $"DELETE FROM {targetTableName} AS t WHERE NOT EXISTS (SELECT 1 FROM {StagingTableNam +me} AS s WHERE {deleteJoinCondition})"; + rowsDeleted[entityType] = await Context.Database.ExecuteSqlAsync(deleteSql, Options.CommandTimeout, canc +cellationToken); + for (int i = 0; i < rowsDeleted[entityType]; i++) + outputRows.Add(new BulkMergeOutputRow(SqlMergeAction.Delete)); + } + + if (autoMapOutput) + { + string outputColumnsSql = autoGeneratedColumns.Any() + ? "," + string.Join(",", autoGeneratedColumns.Select(c => Context.DelimitMemberAccess("t", c))) + : string.Empty; + string outputQuery = $"SELECT {Context.DelimitMemberAccess("s", Constants.InternalId_ColumnName)}{output +tColumnsSql} FROM {StagingTableName} AS s JOIN {targetTableName} AS t ON {matchJoinCondition}"; + var bulkQueryResult = await Context.BulkQueryAsync(outputQuery, Connection, Transaction, Options, cancel +llationToken); + var autoGeneratedColumnList = autoGeneratedColumns.ToList(); + foreach (var result in bulkQueryResult.Results) + { + int entityId = Convert.ToInt32(result[0]); + string action = matchedIds.Contains(entityId) ? SqlMergeAction.Update : SqlMergeAction.Insert; + outputRows.Add(new BulkMergeOutputRow(action)); + + if (entityMap.TryGetValue(entityId, out var entity) && allProperties.Count > 0) + { + object[] entityValues = allProperties.Select(p => result[1 + autoGeneratedColumnList.IndexOf(p.G +GetColumnName())]).ToArray(); + Context.SetStoreGeneratedValues(entity, allProperties, entityValues); + } + } + } + } + + return new BulkMergeResult + { + Output = outputRows, + RowsAffected = rowsInserted.Values.LastOrDefault() + rowsUpdated.Values.LastOrDefault() + rowsDeleted.Values +s.LastOrDefault(), + RowsDeleted = rowsDeleted.Values.LastOrDefault(), + RowsInserted = rowsInserted.Values.LastOrDefault(), + RowsUpdated = rowsUpdated.Values.LastOrDefault() + }; + } + private async Task ExecuteUpdatePostgreSqlAsync(Expression> updateOnCondition, CancellationTok +ken cancellationToken) + { + int rowsUpdated = 0; + foreach (var entityType in TableMapping.EntityTypes) + { + IEnumerable columnsToUpdate = GetColumnNames(entityType); + string updateSetExpression = string.Join(",", columnsToUpdate.Select(c => $"{Context.DelimitIdentifier(c)}={ +{Context.DelimitMemberAccess("s", c)}")); + string targetTableName = Context.DelimitIdentifier(entityType.GetTableName(), entityType.GetSchema() ?? Cont +text.Database.GetDefaultSchema()); + string updateSql = $"UPDATE {targetTableName} AS t SET {updateSetExpression} FROM {StagingTableName} AS s WH +HERE {CommonUtil.GetJoinConditionSql(Context, updateOnCondition, PrimaryKeyColumnNames, "s", "t")}"; + rowsUpdated = await Context.Database.ExecuteSqlAsync(updateSql, Options.CommandTimeout, cancellationToken); + } + return rowsUpdated; + } + private async Task> GetMatchedInternalIdsAsync(string targetTableName, string joinCondition, Cancellati +ionToken cancellationToken) + { + var results = await Context.BulkQueryAsync( + $"SELECT {Context.DelimitMemberAccess("s", Constants.InternalId_ColumnName)} FROM {StagingTableName} AS s JO +OIN {targetTableName} AS t ON {joinCondition}", + Connection, Transaction, Options, cancellationToken); + return results.Results.Select(r => Convert.ToInt32(r[0])).ToHashSet(); + } + internal async Task PreallocateIdentityValuesAsync(IEnumerable entities, CancellationToken cancellationToken) + { + var identityProperty = GetGeneratedPrimaryKeyProperty(); + if (identityProperty?.PropertyInfo == null) + return; + + var entityList = entities.ToList(); + if (entityList.Count == 0) + return; + + string tableName = Context.DelimitIdentifier(TableMapping.EntityType.GetTableName(), TableMapping.EntityType.Get +tSchema() ?? Context.Database.GetDefaultSchema()); + string sequenceSql = $"SELECT nextval(pg_get_serial_sequence('{tableName}', '{identityProperty.GetColumnName()}' +')) FROM generate_series(1, {entityList.Count})"; + await using var command = Connection.CreateCommand(); + command.CommandText = sequenceSql; + command.Transaction = Transaction; + await using var reader = await command.ExecuteReaderAsync(cancellationToken); + foreach (var entity in entityList) + { + if (!await reader.ReadAsync(cancellationToken)) + throw new InvalidDataException("Failed to allocate PostgreSql identity values."); + + object sequenceValue = Convert.ChangeType(reader.GetValue(0), identityProperty.ClrType); + if (entity is InternalEntityEntry internalEntry) + internalEntry.SetStoreGeneratedValue(identityProperty, sequenceValue); + else + identityProperty.PropertyInfo.SetValue(entity, sequenceValue); + } + } + private async Task SyncPostgreSqlIdentitySequenceAsync(IEntityType entityType, CancellationToken cancellationToken) + { + var identityProperty = entityType.GetProperties().SingleOrDefault(o => o.IsPrimaryKey() && o.ValueGenerated != V +ValueGenerated.Never); + if (identityProperty == null) + return; + + string tableName = Context.DelimitIdentifier(entityType.GetTableName(), entityType.GetSchema() ?? Context.Databa +ase.GetDefaultSchema()); + string columnName = Context.DelimitIdentifier(identityProperty.GetColumnName()); + string sequenceSql = $"SELECT setval(pg_get_serial_sequence('{tableName}', '{identityProperty.GetColumnName()}') +), COALESCE(MAX({columnName}), 0)) FROM {tableName}"; + await Context.Database.ExecuteSqlAsync(sequenceSql, Options.CommandTimeout, cancellationToken); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\BulkOption +ns.cs --- + +using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore.Metadata; +using N.EntityFrameworkCore.Extensions.Enums; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkOptions +{ + public int BatchSize { get; set; } + public SqlBulkCopyOptions BulkCopyOptions { get; internal set; } + public SqlBulkCopyColumnOrderHintCollection ColumnOrderHints { get; internal set; } + public bool EnableStreaming { get; internal set; } + public int NotifyAfter { get; internal set; } + public bool UsePermanentTable { get; set; } + public int? CommandTimeout { get; set; } + internal ConnectionBehavior ConnectionBehavior { get; set; } + internal IEntityType EntityType { get; set; } + + public SqlRowsCopiedEventHandler SqlRowsCopied { get; internal set; } + + public BulkOptions() + { + BulkCopyOptions = SqlBulkCopyOptions.Default; + ColumnOrderHints = new SqlBulkCopyColumnOrderHintCollection(); + ConnectionBehavior = ConnectionBehavior.Default; + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\BulkQueryR +Result.cs --- + +using System.Collections.Generic; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkQueryResult +{ + public IEnumerable Results { get; internal set; } + public IEnumerable Columns { get; internal set; } + public int RowsAffected { get; internal set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\BulkSyncOp +ptions.cs --- + + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkSyncOptions : BulkMergeOptions +{ + public BulkSyncOptions() + { + DeleteIfNotMatched = true; + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\BulkSyncRe +esult.cs --- + + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkSyncResult : BulkMergeResult +{ + public new int RowsDeleted { get; set; } + public static BulkSyncResult Map(BulkMergeResult result) + { + return new BulkSyncResult() + { + Output = result.Output, + RowsAffected = result.RowsAffected, + RowsDeleted = result.RowsDeleted, + RowsInserted = result.RowsInserted, + RowsUpdated = result.RowsUpdated + }; + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\BulkUpdate +eOptions.cs --- + +using System; +using System.Linq.Expressions; + +namespace N.EntityFrameworkCore.Extensions; + +public class BulkUpdateOptions : BulkOptions +{ + public Expression> InputColumns { get; set; } + public Expression> IgnoreColumns { get; set; } + public Expression> UpdateOnCondition { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\DatabaseFa +acadeExtensions.cs --- + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Linq; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Storage; +using N.EntityFrameworkCore.Extensions.Enums; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +public static class DatabaseFacadeExtensions +{ + public static SqlQuery FromSqlQuery(this DatabaseFacade database, string sqlText, params object[] parameters) + { + return new SqlQuery(database, sqlText, parameters); + } + public static int ClearTable(this DatabaseFacade database, string tableName) + { + return database.ExecuteSqlRaw($"DELETE FROM {database.DelimitTableName(tableName)}"); + } + public static int DropTable(this DatabaseFacade database, string tableName, bool ifExists = false) + { + string formattedTableName = database.DelimitTableName(tableName); + string sql = ifExists ? $"DROP TABLE IF EXISTS {formattedTableName}" : $"DROP TABLE {formattedTableName}"; + return database.ExecuteSqlInternal(sql, null, ConnectionBehavior.Default); + } + public static void TruncateTable(this DatabaseFacade database, string tableName, bool ifExists = false) + { + bool truncateTable = !ifExists || database.TableExists(tableName); + if (!truncateTable) + return; + + string formattedTableName = database.DelimitTableName(tableName); + string sql = database.IsPostgreSql() + ? $"TRUNCATE TABLE {formattedTableName} RESTART IDENTITY" + : $"TRUNCATE TABLE {formattedTableName}"; + database.ExecuteSqlRaw(sql); + } + public static bool TableExists(this DatabaseFacade database, string tableName) + { + var objectName = database.ParseObjectName(tableName); + return Convert.ToBoolean(database.ExecuteScalar( + database.IsPostgreSql() + ? "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = @schema AND table_name = + @name)" + : "SELECT CASE WHEN EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = @schema AND TAB +BLE_NAME = @name) THEN 1 ELSE 0 END", + [CreateParameter(database, "@schema", objectName.Schema), CreateParameter(database, "@name", objectName.Name +e)])); + } + public static bool TableHasIdentity(this DatabaseFacade database, string tableName) + { + var objectName = database.ParseObjectName(tableName); + string sql = database.IsPostgreSql() + ? """ + SELECT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = @schema + AND table_name = @name + AND (is_identity = 'YES' OR column_default LIKE 'nextval(%') + ) + """ + : "SELECT ISNULL(OBJECTPROPERTY(OBJECT_ID(@fullName), 'TableHasIdentity'), 0)"; + + object[] parameters = database.IsPostgreSql() + ? [CreateParameter(database, "@schema", objectName.Schema), CreateParameter(database, "@name", objectName.Na +ame)] + : [CreateParameter(database, "@fullName", $"{objectName.Schema}.{objectName.Name}")]; + + return Convert.ToBoolean(database.ExecuteScalar(sql, parameters)); + } + internal static int CloneTable(this DatabaseFacade database, string sourceTable, string destinationTable, IEnumerabl +le columnNames, string internalIdColumnName = null) + { + return database.CloneTable([sourceTable], destinationTable, columnNames, internalIdColumnName); + } + internal static int CloneTable(this DatabaseFacade database, IEnumerable sourceTables, string destinationTab +ble, IEnumerable columnNames, string internalIdColumnName = null) + { + string columns = columnNames != null && columnNames.Any() ? string.Join(",", columnNames.Select(database.FormatS +SelectColumn)) : "*"; + if (!string.IsNullOrEmpty(internalIdColumnName)) + columns = $"{columns},CAST(NULL AS INT) AS {database.DelimitIdentifier(internalIdColumnName)}"; + + string sql = database.IsPostgreSql() + ? $"CREATE TABLE {destinationTable} AS SELECT {columns} FROM {string.Join(",", sourceTables)} LIMIT 0" + : $"SELECT TOP 0 {columns} INTO {destinationTable} FROM {string.Join(",", sourceTables)}"; + return database.ExecuteSqlRaw(sql); + } + internal static DbCommand CreateCommand(this DatabaseFacade database, ConnectionBehavior connectionBehavior = Connec +ctionBehavior.Default) + { + var dbConnection = database.GetDbConnection(connectionBehavior); + if (dbConnection.State != ConnectionState.Open) + dbConnection.Open(); + var command = dbConnection.CreateCommand(); + if (database.CurrentTransaction != null && connectionBehavior == ConnectionBehavior.Default) + command.Transaction = database.CurrentTransaction.GetDbTransaction(); + return command; + } + internal static int ExecuteSqlInternal(this DatabaseFacade database, string sql, int? commandTimeout = null, Connect +tionBehavior connectionBehavior = default) + { + return database.ExecuteSql(sql, null, commandTimeout, connectionBehavior); + } + internal static int ExecuteSql(this DatabaseFacade database, string sql, object[] parameters = null, int? commandTim +meout = null, ConnectionBehavior connectionBehavior = default) + { + using var command = database.CreateCommand(connectionBehavior); + command.CommandText = sql; + if (commandTimeout != null) + command.CommandTimeout = commandTimeout.Value; + if (parameters != null) + command.Parameters.AddRange(parameters); + return command.ExecuteNonQuery(); + } + internal static object ExecuteScalar(this DatabaseFacade database, string query, object[] parameters = null, int? co +ommandTimeout = null) + { + using var command = database.CreateCommand(); + command.CommandText = query; + if (commandTimeout.HasValue) + command.CommandTimeout = commandTimeout.Value; + if (parameters != null) + command.Parameters.AddRange(parameters); + return command.ExecuteScalar(); + } + internal static void ToggleIdentityInsert(this DatabaseFacade database, string tableName, bool enable) + { + if (database.IsPostgreSql()) + return; + + bool hasIdentity = database.TableHasIdentity(tableName); + if (hasIdentity) + { + string boolString = enable ? "ON" : "OFF"; + database.ExecuteSql($"SET IDENTITY_INSERT {tableName} {boolString}"); + } + } + internal static DbConnection GetDbConnection(this DatabaseFacade database, ConnectionBehavior connectionBehavior) + { + return connectionBehavior == ConnectionBehavior.New ? database.GetDbConnection().CloneConnection() : database.Ge +etDbConnection(); + } + + private static DbParameter CreateParameter(DatabaseFacade database, string name, object value) + { + using var command = database.GetDbConnection().CreateCommand(); + var parameter = command.CreateParameter(); + parameter.ParameterName = name; + parameter.Value = value ?? DBNull.Value; + return parameter; + } + internal static string FormatSelectColumn(this DatabaseFacade database, string columnName) + { + if (columnName.Contains('[') || columnName.Contains('"') || columnName.Contains('(') || columnName.Contains(' ') +)) + return columnName; + + if (columnName.Contains('.')) + return string.Join(".", columnName.Split('.').Select(database.DelimitIdentifier)); + + return database.DelimitIdentifier(columnName); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\DatabaseFa +acadeExtensionsAsync.cs --- + +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +public static class DatabaseFacadeExtensionsAsync +{ + public static async Task ClearTableAsync(this DatabaseFacade database, string tableName, CancellationToken canc +cellationToken = default) + { + return await database.ExecuteSqlRawAsync($"DELETE FROM {database.DelimitTableName(tableName)}", cancellationToke +en); + } + public static async Task TruncateTableAsync(this DatabaseFacade database, string tableName, bool ifExists = false, C +CancellationToken cancellationToken = default) + { + bool truncateTable = !ifExists || database.TableExists(tableName); + if (!truncateTable) + return; + + string formattedTableName = database.DelimitTableName(tableName); + string sql = database.IsPostgreSql() + ? $"TRUNCATE TABLE {formattedTableName} RESTART IDENTITY" + : $"TRUNCATE TABLE {formattedTableName}"; + await database.ExecuteSqlRawAsync(sql, cancellationToken); + } + internal static async Task CloneTableAsync(this DatabaseFacade database, string sourceTable, string destination +nTable, IEnumerable columnNames, string internalIdColumnName = null, CancellationToken cancellationToken = defaul +lt) + { + return await database.CloneTableAsync([sourceTable], destinationTable, columnNames, internalIdColumnName, cancel +llationToken); + } + internal static async Task CloneTableAsync(this DatabaseFacade database, IEnumerable sourceTables, stri +ing destinationTable, IEnumerable columnNames, string internalIdColumnName = null, CancellationToken cancellation +nToken = default) + { + string columns = columnNames != null && columnNames.Any() ? string.Join(",", columnNames.Select(database.FormatS +SelectColumn)) : "*"; + if (!string.IsNullOrEmpty(internalIdColumnName)) + columns = $"{columns},CAST(NULL AS INT) AS {database.DelimitIdentifier(internalIdColumnName)}"; + + string sql = database.IsPostgreSql() + ? $"CREATE TABLE {destinationTable} AS SELECT {columns} FROM {string.Join(",", sourceTables)} LIMIT 0" + : $"SELECT TOP 0 {columns} INTO {destinationTable} FROM {string.Join(",", sourceTables)}"; + return await database.ExecuteSqlRawAsync(sql, cancellationToken); + } + internal static async Task ExecuteSqlAsync(this DatabaseFacade database, string sql, int? commandTimeout = null +l, CancellationToken cancellationToken = default) + { + return await database.ExecuteSqlAsync(sql, null, commandTimeout, cancellationToken); + } + internal static async Task ExecuteSqlAsync(this DatabaseFacade database, string sql, object[] parameters = null +l, int? commandTimeout = null, CancellationToken cancellationToken = default) + { + int value; + int? origCommandTimeout = database.GetCommandTimeout(); + database.SetCommandTimeout(commandTimeout); + value = parameters != null + ? await database.ExecuteSqlRawAsync(sql, parameters, cancellationToken) + : await database.ExecuteSqlRawAsync(sql, cancellationToken); + database.SetCommandTimeout(origCommandTimeout); + return value; + } + internal static async Task ExecuteScalarAsync(this DatabaseFacade database, string query, object[] parameter +rs = null, int? commandTimeout = null, CancellationToken cancellationToken = default) + { + await using var command = database.CreateCommand(); + command.CommandText = query; + if (commandTimeout.HasValue) + command.CommandTimeout = commandTimeout.Value; + if (parameters != null) + command.Parameters.AddRange(parameters); + return await command.ExecuteScalarAsync(cancellationToken); + } + internal static async Task ToggleIdentityInsertAsync(this DatabaseFacade database, string tableName, bool enable) + { + if (database.IsPostgreSql()) + return; + + bool hasIdentity = database.TableHasIdentity(tableName); + if (hasIdentity) + { + string boolString = enable ? "ON" : "OFF"; + await database.ExecuteSqlAsync($"SET IDENTITY_INSERT {tableName} {boolString}", database.GetCommandTimeout() +)); + } + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\DbContextE +Extensions.cs --- + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.IO; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.Internal; +using Npgsql; +using N.EntityFrameworkCore.Extensions.Common; +using N.EntityFrameworkCore.Extensions.Enums; +using N.EntityFrameworkCore.Extensions.Extensions; +using N.EntityFrameworkCore.Extensions.Sql; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +public static class DbContextExtensions +{ + private static readonly EfExtensionsCommandInterceptor efExtensionsCommandInterceptor; + static DbContextExtensions() + { + efExtensionsCommandInterceptor = new EfExtensionsCommandInterceptor(); + } + public static void SetupEfCoreExtensions(this DbContextOptionsBuilder builder) + { + builder.AddInterceptors(efExtensionsCommandInterceptor); + } + public static int BulkDelete(this DbContext context, IEnumerable entities) + { + return context.BulkDelete(entities, new BulkDeleteOptions()); + } + public static int BulkDelete(this DbContext context, IEnumerable entities, Action> option +nsAction) + { + return context.BulkDelete(entities, optionsAction.Build()); + } + public static int BulkDelete(this DbContext context, IEnumerable entities, BulkDeleteOptions options) + { + var tableMapping = context.GetTableMapping(typeof(T), options.EntityType); + + using (var dbTransactionContext = new DbTransactionContext(context, options)) + { + var dbConnection = dbTransactionContext.Connection; + var transaction = dbTransactionContext.CurrentTransaction; + int rowsAffected = 0; + try + { + string stagingTableName = CommonUtil.GetStagingTableName(tableMapping, options.UsePermanentTable, dbConn +nection); + string destinationTableName = context.DelimitIdentifier(tableMapping.TableName, tableMapping.Schema); + string[] keyColumnNames = options.DeleteOnCondition != null ? CommonUtil.GetColumns(options.DeleteOnC +Condition, ["s"]) + : tableMapping.GetPrimaryKeyColumns().ToArray(); + + if (keyColumnNames.Length == 0 && options.DeleteOnCondition == null) + throw new InvalidDataException("BulkDelete requires that the entity have a primary key or the Option +ns.DeleteOnCondition must be set."); + + context.Database.CloneTable(destinationTableName, stagingTableName, keyColumnNames); + BulkInsert(entities, options, tableMapping, dbConnection, transaction, stagingTableName, keyColumnNames, +, SqlBulkCopyOptions.KeepIdentity, false); + + string joinCondition = CommonUtil.GetJoinConditionSql(context, options.DeleteOnCondition, keyColumnNa +ames); + string deleteSql = context.Database.IsPostgreSql() + ? $"DELETE FROM {destinationTableName} AS t USING {stagingTableName} AS s WHERE {joinCondition}" + : $"DELETE t FROM {stagingTableName} s JOIN {destinationTableName} t ON {joinCondition}"; + rowsAffected = context.Database.ExecuteSqlInternal(deleteSql, options.CommandTimeout); + + context.Database.DropTable(stagingTableName); + dbTransactionContext.Commit(); + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + return rowsAffected; + } + } + public static IEnumerable BulkFetch(this DbSet dbSet, IEnumerable entities) where T : class, new() + { + return dbSet.BulkFetch(entities, new BulkFetchOptions()); + } + public static IEnumerable BulkFetch(this DbSet dbSet, IEnumerable entities, Action> optionsAction) where T : class, new() + { + return dbSet.BulkFetch(entities, optionsAction.Build()); + } + public static IEnumerable BulkFetch(this DbSet dbSet, IEnumerable entities, BulkFetchOptions optio +ons) where T : class, new() + { + var context = dbSet.GetDbContext(); + var tableMapping = context.GetTableMapping(typeof(T)); + + using (var dbTransactionContext = new DbTransactionContext(context, options.CommandTimeout, ConnectionBehavior.N +New)) + { + string selectSql, stagingTableName = string.Empty; + var dbConnection = dbTransactionContext.Connection; + var transaction = dbTransactionContext.CurrentTransaction; + try + { + stagingTableName = CommonUtil.GetStagingTableName(tableMapping, true, dbConnection); + string destinationTableName = context.DelimitIdentifier(tableMapping.TableName, tableMapping.Schema); + string[] keyColumnNames = options.JoinOnCondition != null ? CommonUtil.GetColumns(options.JoinOnCondi +ition, ["s"]) + : tableMapping.GetPrimaryKeyColumns().ToArray(); + IEnumerable columnNames = CommonUtil.FilterColumns(tableMapping.GetColumns(true), keyColumnNa +ames, options.InputColumns, options.IgnoreColumns); + IEnumerable columnsToFetch = CommonUtil.FormatColumns(context, "t", columnNames); + + if (keyColumnNames.Length == 0 && options.JoinOnCondition == null) + throw new InvalidDataException("BulkFetch requires that the entity have a primary key or the Options +s.JoinOnCondition must be set."); + + context.Database.CloneTable(destinationTableName, stagingTableName, keyColumnNames); + BulkInsert(entities, options, tableMapping, dbConnection, transaction, stagingTableName, keyColumnNames, +, SqlBulkCopyOptions.KeepIdentity, false); + selectSql = $"SELECT {SqlUtil.ConvertToColumnString(columnsToFetch)} FROM {stagingTableName} s JOIN {des +stinationTableName} t ON {CommonUtil.GetJoinConditionSql(context, options.JoinOnCondition, keyColumnNames)}"; + + + dbTransactionContext.Commit(); + } + catch + { + dbTransactionContext.Rollback(); + throw; + } + + foreach (var item in context.FetchInternal(selectSql)) + { + yield return item; + } + context.Database.DropTable(stagingTableName); + } + } + public static void Fetch(this IQueryable queryable, Action> action, Action> opt +tionsAction) where T : class, new() + { + Fetch(queryable, action, optionsAction.Build()); + } + public static void Fetch(this IQueryable queryable, Action> action, FetchOptions options) wh +here T : class, new() + { + var dbContext = queryable.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + HashSet includedColumns = GetIncludedColumns(tableMapping, options.InputColumns, options.IgnoreColumns); + int batch = 1; + int count = 0; + List entities = []; + foreach (var entity in queryable.AsNoTracking().AsEnumerable()) + { + ClearExcludedColumns(dbContext, tableMapping, entity, includedColumns); + entities.Add(entity); + count++; + if (count == options.BatchSize) + { + action(new FetchResult { Results = entities, Batch = batch }); + entities.Clear(); + count = 0; + batch++; + } + } + + if (entities.Count > 0) + action(new FetchResult { Results = entities, Batch = batch }); + } + public static int BulkInsert(this DbContext context, IEnumerable entities) + { + return context.BulkInsert(entities, new BulkInsertOptions()); + } + public static int BulkInsert(this DbContext context, IEnumerable entities, Action> option +nsAction) + { + return context.BulkInsert(entities, optionsAction.Build()); + } + public static int BulkInsert(this DbContext context, IEnumerable entities, BulkInsertOptions options) + { + int rowsAffected = 0; + using (var bulkOperation = new BulkOperation(context, options, options.InputColumns, options.IgnoreColumns)) + { + try + { + bool keepIdentity = options.KeepIdentity || bulkOperation.ShouldPreallocateIdentityValues(options.AutoMa +apOutput, options.KeepIdentity, entities); + if (keepIdentity && !options.KeepIdentity) + bulkOperation.PreallocateIdentityValues(entities); + var bulkInsertResult = bulkOperation.BulkInsertStagingData(entities, true, true); + var bulkMergeResult = bulkOperation.ExecuteMerge(bulkInsertResult.EntityMap, options.InsertOnCondition, + options.AutoMapOutput, keepIdentity, options.InsertIfNotExists); + rowsAffected = bulkMergeResult.RowsAffected; + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return rowsAffected; + } + public static BulkMergeResult BulkMerge(this DbContext context, IEnumerable entities) + { + return BulkMerge(context, entities, new BulkMergeOptions()); + } + public static BulkMergeResult BulkMerge(this DbContext context, IEnumerable entities, BulkMergeOptions o +options) + { + return InternalBulkMerge(context, entities, options); + } + public static BulkMergeResult BulkMerge(this DbContext context, IEnumerable entities, Action> optionsAction) + { + return BulkMerge(context, entities, optionsAction.Build()); + } + public static int BulkSaveChanges(this DbContext dbContext) + { + return dbContext.BulkSaveChanges(true); + } + public static int BulkSaveChanges(this DbContext dbContext, bool acceptAllChangesOnSuccess = true) + { + int rowsAffected = 0; + var stateManager = dbContext.GetDependencies().StateManager; + + dbContext.ChangeTracker.DetectChanges(); + var entries = stateManager.GetEntriesToSave(true); + + foreach (var saveEntryGroup in entries.GroupBy(o => new { o.EntityType, o.EntityState })) + { + var key = saveEntryGroup.Key; + var entities = saveEntryGroup.AsEnumerable(); + if (key.EntityState == EntityState.Added) + { + rowsAffected += dbContext.BulkInsert(entities, o => { o.EntityType = key.EntityType; }); + } + else if (key.EntityState == EntityState.Modified) + { + rowsAffected += dbContext.BulkUpdate(entities, o => { o.EntityType = key.EntityType; }); + } + else if (key.EntityState == EntityState.Deleted) + { + rowsAffected += dbContext.BulkDelete(entities, o => { o.EntityType = key.EntityType; }); + } + } + + if (acceptAllChangesOnSuccess) + dbContext.ChangeTracker.AcceptAllChanges(); + + return rowsAffected; + } + public static BulkSyncResult BulkSync(this DbContext context, IEnumerable entities) + { + return BulkSync(context, entities, new BulkSyncOptions()); + } + public static BulkSyncResult BulkSync(this DbContext context, IEnumerable entities, Action> optionsAction) + { + return BulkSyncResult.Map(InternalBulkMerge(context, entities, optionsAction.Build())); + } + public static BulkSyncResult BulkSync(this DbContext context, IEnumerable entities, BulkSyncOptions opti +ions) + { + return BulkSyncResult.Map(InternalBulkMerge(context, entities, options)); + } + public static int BulkUpdate(this DbContext context, IEnumerable entities) + { + return BulkUpdate(context, entities, new BulkUpdateOptions()); + } + public static int BulkUpdate(this DbContext context, IEnumerable entities, Action> option +nsAction) + { + return BulkUpdate(context, entities, optionsAction.Build()); + } + public static int BulkUpdate(this DbContext context, IEnumerable entities, BulkUpdateOptions options) + { + int rowsUpdated = 0; + using (var bulkOperation = new BulkOperation(context, options, options.InputColumns, options.IgnoreColumns)) + { + try + { + bulkOperation.ValidateBulkUpdate(options.UpdateOnCondition); + bulkOperation.BulkInsertStagingData(entities); + rowsUpdated = bulkOperation.ExecuteUpdate(entities, options.UpdateOnCondition); + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return rowsUpdated; + } + public static int DeleteFromQuery(this IQueryable queryable, int? commandTimeout = null) where T : class + { + using (var dbTransactionContext = new DbTransactionContext(queryable.GetDbContext(), commandTimeout)) + { + try + { + int rowsAffected = queryable.ExecuteDelete(); + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static int InsertFromQuery(this IQueryable queryable, string tableName, Expression> ins +sertObjectExpression, int? commandTimeout = null) where T : class + { + using (var dbTransactionContext = new DbTransactionContext(queryable.GetDbContext(), commandTimeout)) + { + var dbContext = dbTransactionContext.DbContext; + try + { + var tableMapping = dbContext.GetTableMapping(typeof(T)); + var columnNames = insertObjectExpression.GetObjectProperties(); + if (!dbContext.Database.TableExists(tableName)) + { + dbContext.Database.CloneTable(tableMapping.FullQualifedTableName, dbContext.Database.DelimitTableNam +me(tableName), tableMapping.GetQualifiedColumnNames(columnNames)); + } + + var entities = queryable.AsNoTracking().ToList(); + int rowsAffected = BulkInsert(entities, new BulkInsertOptions { KeepIdentity = true, AutoMapOutput = + false, CommandTimeout = commandTimeout }, tableMapping, + dbTransactionContext.Connection, dbTransactionContext.CurrentTransaction, dbContext.Database.Delimit +tTableName(tableName), columnNames, SqlBulkCopyOptions.KeepIdentity).RowsAffected; + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static int UpdateFromQuery(this IQueryable queryable, Expression> updateExpression, int? com +mmandTimeout = null) where T : class + { + using (var dbTransactionContext = new DbTransactionContext(queryable.GetDbContext(), commandTimeout)) + { + try + { + int rowsAffected = queryable.ExecuteUpdate(BuildSetPropertyCalls(updateExpression)); + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, string filePath) where T : class + { + return QueryToCsvFile(queryable, filePath, new QueryToFileOptions()); + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, Stream stream) where T : class + { + return QueryToCsvFile(queryable, stream, new QueryToFileOptions()); + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, string filePath, Action optionsAction) where T : class + { + return QueryToCsvFile(queryable, filePath, optionsAction.Build()); + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, Stream stream, Action optionsAction) where T : class + { + return QueryToCsvFile(queryable, stream, optionsAction.Build()); + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, string filePath, QueryToFileOptions + options) where T : class + { + using var fileStream = File.Create(filePath); + return QueryToCsvFile(queryable, fileStream, options); + } + public static QueryToFileResult QueryToCsvFile(this IQueryable queryable, Stream stream, QueryToFileOptions op +ptions) where T : class + { + return InternalQueryToFile(queryable, stream, options); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, string filePath, string sqlText, par +rams object[] parameters) + { + return SqlQueryToCsvFile(database, filePath, new QueryToFileOptions(), sqlText, parameters); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, Stream stream, string sqlText, param +ms object[] parameters) + { + return SqlQueryToCsvFile(database, stream, new QueryToFileOptions(), sqlText, parameters); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, string filePath, Action optionsAction, string sqlText, params object[] parameters) + { + return SqlQueryToCsvFile(database, filePath, optionsAction.Build(), sqlText, parameters); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, Stream stream, Action optionsAction, string sqlText, params object[] parameters) + { + return SqlQueryToCsvFile(database, stream, optionsAction.Build(), sqlText, parameters); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, string filePath, QueryToFileOptions + options, string sqlText, params object[] parameters) + { + using var fileStream = File.Create(filePath); + return SqlQueryToCsvFile(database, fileStream, options, sqlText, parameters); + } + public static QueryToFileResult SqlQueryToCsvFile(this DatabaseFacade database, Stream stream, QueryToFileOptions op +ptions, string sqlText, params object[] parameters) + { + return InternalQueryToFile(database.GetDbConnection(), stream, options, sqlText, parameters); + } + public static void Clear(this DbSet dbSet) where T : class + { + var dbContext = dbSet.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + dbContext.Database.ClearTable(tableMapping.FullQualifedTableName); + } + public static void Truncate(this DbSet dbSet) where T : class + { + var dbContext = dbSet.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + dbContext.Database.TruncateTable(tableMapping.FullQualifedTableName); + } + public static IQueryable UsingTable(this IQueryable queryable, string tableName) where T : class + { + var dbContext = queryable.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + efExtensionsCommandInterceptor.AddCommand(Guid.NewGuid(), + new EfExtensionsCommand + { + CommandType = EfExtensionsCommandType.ChangeTableName, + OldValue = tableMapping.FullQualifedTableName, + NewValue = dbContext.Database.DelimitTableName(tableName), + Connection = dbContext.GetDbConnection() + }); + return queryable; + } + public static TableMapping GetTableMapping(this DbContext dbContext, Type type, IEntityType entityType = null) + { + entityType ??= dbContext.Model.FindEntityType(type); + return new TableMapping(dbContext, entityType); + } + internal static void SetStoreGeneratedValues(this DbContext context, T entity, IEnumerable properties, +, object[] values) + { + int index = 0; + var updateEntry = entity as InternalEntityEntry; + if (updateEntry == null) + { + var entry = context.Entry(entity); + updateEntry = entry.GetInfrastructure(); + } + + if (updateEntry != null) + { + foreach (var property in properties) + { + if ((updateEntry.EntityState == EntityState.Added && + (property.ValueGenerated == ValueGenerated.OnAdd || property.ValueGenerated == ValueGenerated.OnAddO +OrUpdate)) || + (updateEntry.EntityState == EntityState.Modified && + (property.ValueGenerated == ValueGenerated.OnUpdate || property.ValueGenerated == ValueGenerated.OnA +AddOrUpdate)) || + updateEntry.EntityState == EntityState.Detached + ) + { + updateEntry.SetStoreGeneratedValue(property, values[index]); + } + index++; + } + if (updateEntry.EntityState == EntityState.Detached) + updateEntry.AcceptChanges(); + } + else + { + throw new InvalidOperationException("SetStoreValues() failed because an instance of InternalEntityEntry was + not found."); + } + } + internal static BulkInsertResult BulkInsert(IEnumerable entities, BulkOptions options, TableMapping tableMa +apping, DbConnection dbConnection, DbTransaction transaction, string tableName, + IEnumerable inputColumns = null, SqlBulkCopyOptions bulkCopyOptions = SqlBulkCopyOptions.Default, bool u +useInternalId = false) + { + using var dataReader = new EntityDataReader(tableMapping, entities, useInternalId); + if (dbConnection is NpgsqlConnection npgsqlConnection) + { + var columnNames = tableMapping.Properties + .Select(tableMapping.GetColumnName) + .Where(columnName => inputColumns == null || inputColumns.Contains(columnName)) + .ToList(); + if (useInternalId) + columnNames.Add(Constants.InternalId_ColumnName); + + string copySql = $"COPY {tableName} ({string.Join(",", columnNames.Select(tableMapping.DbContext.DelimitIden +ntifier))}) FROM STDIN (FORMAT BINARY)"; + using var importer = npgsqlConnection.BeginBinaryImport(copySql); + long rowsCopied = 0; + while (dataReader.Read()) + { + importer.StartRow(); + foreach (var columnName in columnNames) + { + object value = dataReader.GetValue(dataReader.GetOrdinal(columnName)); + if (value == null || value == DBNull.Value) + importer.WriteNull(); + else + importer.Write(value); + } + rowsCopied++; + } + importer.Complete(); + + return new BulkInsertResult + { + RowsAffected = (int)rowsCopied, + EntityMap = dataReader.EntityMap + }; + } + + var sqlBulkCopy = new SqlBulkCopy((SqlConnection)dbConnection, bulkCopyOptions | options.BulkCopyOptions, (SqlTr +ransaction)transaction) + { + DestinationTableName = tableName, + BatchSize = options.BatchSize, + NotifyAfter = options.NotifyAfter, + EnableStreaming = options.EnableStreaming, + }; + sqlBulkCopy.BulkCopyTimeout = options.CommandTimeout.HasValue ? options.CommandTimeout.Value : sqlBulkCopy.BulkC +CopyTimeout; + if (options.SqlRowsCopied != null) + sqlBulkCopy.SqlRowsCopied += options.SqlRowsCopied; + foreach (SqlBulkCopyColumnOrderHint columnOrderHint in options.ColumnOrderHints) + sqlBulkCopy.ColumnOrderHints.Add(columnOrderHint); + foreach (var property in dataReader.TableMapping.Properties) + { + var columnName = dataReader.TableMapping.GetColumnName(property); + if (inputColumns == null || inputColumns.Contains(columnName)) + sqlBulkCopy.ColumnMappings.Add(columnName, columnName); + } + if (useInternalId) + sqlBulkCopy.ColumnMappings.Add(Constants.InternalId_ColumnName, Constants.InternalId_ColumnName); + sqlBulkCopy.WriteToServer(dataReader); + + return new BulkInsertResult + { + RowsAffected = sqlBulkCopy.RowsCopied, + EntityMap = dataReader.EntityMap + }; + } + internal static BulkQueryResult BulkQuery(this DbContext context, string sqlText, BulkOptions options) + { + List results = []; + List columns = []; + using var command = context.Database.CreateCommand(); + command.CommandText = sqlText; + if (options.CommandTimeout.HasValue) + command.CommandTimeout = options.CommandTimeout.Value; + using var reader = command.ExecuteReader(); + while (reader.Read()) + { + if (columns.Count == 0) + { + for (int i = 0; i < reader.FieldCount; i++) + columns.Add(reader.GetName(i)); + } + object[] values = new object[reader.FieldCount]; + reader.GetValues(values); + results.Add(values); + } + + return new BulkQueryResult + { + Columns = columns, + Results = results, + RowsAffected = reader.RecordsAffected + }; + } + internal static DbContext GetDbContext(this IQueryable queryable) where T : class + { + DbContext dbContext; + try + { + if ((queryable as InternalDbSet) != null) + { + dbContext = queryable.GetPrivateFieldValue("_context") as DbContext; + } + else if ((queryable as EntityQueryable) != null) + { + var queryCompiler = queryable.Provider.GetPrivateFieldValue("_queryCompiler"); + var contextFactory = queryCompiler.GetPrivateFieldValue("_queryContextFactory"); + var queryDependencies = contextFactory.GetPrivateFieldValue("Dependencies") as QueryContextDependencies; + dbContext = queryDependencies.CurrentContext.Context as DbContext; + } + else + { + throw new Exception("This extension method could not find the DbContext for this type that implements IQ +Queryable"); + } + } + catch + { + throw new Exception("This extension method could not find the DbContext for this type that implements IQuery +yable"); + } + return dbContext; + } + internal static DbConnection GetDbConnection(this DbContext context, ConnectionBehavior connectionBehavior = Connect +tionBehavior.Default) + { + var dbConnection = context.Database.GetDbConnection(); + return connectionBehavior == ConnectionBehavior.New ? dbConnection.CloneConnection() : dbConnection; + } + private static IEnumerable FetchInternal(this DbContext dbContext, string sqlText, object[] parameters = null) +) where T : class, new() + { + using var command = dbContext.Database.CreateCommand(ConnectionBehavior.New); + command.CommandText = sqlText; + if (parameters != null) + command.Parameters.AddRange(parameters); + + var tableMapping = dbContext.GetTableMapping(typeof(T), null); + using var reader = command.ExecuteReader(); + var properties = reader.GetProperties(tableMapping); + var valuesFromProvider = properties.Select(p => tableMapping.GetValueFromProvider(p)).ToArray(); + + while (reader.Read()) + { + var entity = reader.MapEntity(dbContext, properties, valuesFromProvider); + yield return entity; + } + } + private static BulkMergeResult InternalBulkMerge(this DbContext context, IEnumerable entities, BulkMergeOpt +tions options) + { + BulkMergeResult bulkMergeResult; + using (var bulkOperation = new BulkOperation(context, options)) + { + try + { + bool shouldPreallocate = bulkOperation.ShouldPreallocateIdentityValues(true, false, entities); + bool keepIdentity = shouldPreallocate || bulkOperation.ShouldKeepIdentityForPostgresMerge(); + if (shouldPreallocate) + bulkOperation.PreallocateIdentityValues(entities); + bulkOperation.ValidateBulkMerge(options.MergeOnCondition); + var bulkInsertResult = bulkOperation.BulkInsertStagingData(entities, true, true); + bulkMergeResult = bulkOperation.ExecuteMerge(bulkInsertResult.EntityMap, options.MergeOnCondition, optio +ons.AutoMapOutput, + keepIdentity, true, true, options.DeleteIfNotMatched, shouldPreallocate); + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return bulkMergeResult; + } + private static void ClearEntityStateToUnchanged(DbContext dbContext, IEnumerable entities) + { + foreach (var entity in entities) + { + var entry = dbContext.Entry(entity); + if (entry.State == EntityState.Added || entry.State == EntityState.Modified) + dbContext.Entry(entity).State = EntityState.Unchanged; + } + } + private static void Validate(TableMapping tableMapping) + { + if (!tableMapping.GetPrimaryKeyColumns().Any()) + { + throw new Exception("You must have a primary key on this table to use this function."); + } + } + private static QueryToFileResult InternalQueryToFile(this IQueryable queryable, Stream stream, QueryToFileOpti +ions options) where T : class + { + return InternalQueryToFile(queryable.AsNoTracking().AsEnumerable(), stream, options); + } + private static QueryToFileResult InternalQueryToFile(DbConnection dbConnection, Stream stream, QueryToFileOptions op +ptions, string sqlText, object[] parameters = null) + { + int dataRowCount = 0; + int totalRowCount = 0; + long bytesWritten = 0; + + if (dbConnection.State == ConnectionState.Closed) + dbConnection.Open(); + + using var command = dbConnection.CreateCommand(); + command.CommandText = sqlText; + if (parameters != null) + command.Parameters.AddRange(parameters); + if (options.CommandTimeout.HasValue) + command.CommandTimeout = options.CommandTimeout.Value; + + using var streamWriter = new StreamWriter(stream, leaveOpen: true); + using (var reader = command.ExecuteReader()) + { + if (options.IncludeHeaderRow) + { + for (int i = 0; i < reader.FieldCount; i++) + { + streamWriter.Write(options.TextQualifer); + streamWriter.Write(reader.GetName(i)); + streamWriter.Write(options.TextQualifer); + if (i != reader.FieldCount - 1) + { + streamWriter.Write(options.ColumnDelimiter); + } + } + totalRowCount++; + streamWriter.Write(options.RowDelimiter); + } + while (reader.Read()) + { + object[] values = new object[reader.FieldCount]; + reader.GetValues(values); + for (int i = 0; i < values.Length; i++) + { + streamWriter.Write(options.TextQualifer); + streamWriter.Write(values[i]); + streamWriter.Write(options.TextQualifer); + if (i != values.Length - 1) + { + streamWriter.Write(options.ColumnDelimiter); + } + } + streamWriter.Write(options.RowDelimiter); + dataRowCount++; + totalRowCount++; + } + streamWriter.Flush(); + bytesWritten = streamWriter.BaseStream.Length; + } + return new QueryToFileResult() + { + BytesWritten = bytesWritten, + DataRowCount = dataRowCount, + TotalRowCount = totalRowCount + }; + } + private static QueryToFileResult InternalQueryToFile(IEnumerable entities, Stream stream, QueryToFileOptions o +options) + { + int dataRowCount = 0; + int totalRowCount = 0; + long bytesWritten = 0; + var properties = typeof(T).GetProperties().Where(p => p.CanRead && !typeof(System.Collections.IEnumerable).IsAss +signableFrom(p.PropertyType) || p.PropertyType == typeof(string)).ToArray(); + + using var streamWriter = new StreamWriter(stream, leaveOpen: true); + if (options.IncludeHeaderRow) + { + WriteCsvRow(streamWriter, properties.Select(p => p.Name), options); + totalRowCount++; + } + + foreach (var entity in entities) + { + WriteCsvRow(streamWriter, properties.Select(p => p.GetValue(entity)), options); + dataRowCount++; + totalRowCount++; + } + + streamWriter.Flush(); + bytesWritten = streamWriter.BaseStream.Length; + return new QueryToFileResult { BytesWritten = bytesWritten, DataRowCount = dataRowCount, TotalRowCount = totalRo +owCount }; + } + private static HashSet GetIncludedColumns(TableMapping tableMapping, Expression> inputCol +lumns, Expression> ignoreColumns) + { + var includedColumns = inputColumns != null + ? inputColumns.GetObjectProperties().ToHashSet() + : tableMapping.Properties.Select(p => p.Name).ToHashSet(); + + if (ignoreColumns != null) + includedColumns.ExceptWith(ignoreColumns.GetObjectProperties()); + + return includedColumns; + } + private static void ClearExcludedColumns(DbContext dbContext, TableMapping tableMapping, T entity, HashSet includedColumns) where T : class + { + var entry = dbContext.Entry(entity); + foreach (var property in tableMapping.Properties) + { + if (includedColumns.Contains(property.Name)) + continue; + + object defaultValue = property.ClrType.IsValueType ? Activator.CreateInstance(property.ClrType) : null; + if (property.DeclaringType is IComplexType complexType) + { + var complexProperty = entry.ComplexProperty(complexType.ComplexProperty); + if (complexProperty.CurrentValue != null) + complexProperty.Property(property).CurrentValue = defaultValue; + } + else + { + entry.Property(property.Name).CurrentValue = defaultValue; + } + } + } + private static void WriteCsvRow(TextWriter writer, IEnumerable values, QueryToFileOptions options) + { + bool first = true; + foreach (var value in values) + { + if (!first) + writer.Write(options.ColumnDelimiter); + + writer.Write(options.TextQualifer); + writer.Write(value); + writer.Write(options.TextQualifer); + first = false; + } + writer.Write(options.RowDelimiter); + } + private static Action> BuildSetPropertyCalls(Expression> updateExpression) whe +ere T : class + { + if (updateExpression.Body is not MemberInitExpression memberInitExpression) + throw new InvalidOperationException("UpdateFromQuery requires a member initialization expression."); + + var entityParameter = updateExpression.Parameters[0]; + var setPropertyMethod = typeof(UpdateSettersBuilder) + .GetMethods() + .Single(m => m.Name == nameof(UpdateSettersBuilder.SetProperty) && m.GetParameters().Length == 2 && m.Get +tParameters()[1].ParameterType.IsGenericType); + + return setters => + { + object currentBuilder = setters; + foreach (var binding in memberInitExpression.Bindings.OfType()) + { + var propertyInfo = binding.Member as PropertyInfo ?? throw new InvalidOperationException("Only property + bindings are supported."); + var propertyLambda = Expression.Lambda(Expression.Property(entityParameter, propertyInfo), entityParamet +ter); + var valueLambda = Expression.Lambda(binding.Expression, entityParameter); + currentBuilder = setPropertyMethod.MakeGenericMethod(propertyInfo.PropertyType).Invoke(currentBuilder, [ +[propertyLambda, valueLambda]); + } + }; + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\DbContextE +ExtensionsAsync.cs --- + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.IO; +using System.Linq; +using System.Linq.Expressions; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query; +using Npgsql; +using N.EntityFrameworkCore.Extensions.Common; +using N.EntityFrameworkCore.Extensions.Enums; +using N.EntityFrameworkCore.Extensions.Extensions; +using N.EntityFrameworkCore.Extensions.Sql; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +public static class DbContextExtensionsAsync +{ + public static async Task BulkDeleteAsync(this DbContext context, IEnumerable entities, CancellationToken + cancellationToken = default) + { + return await context.BulkDeleteAsync(entities, new BulkDeleteOptions(), cancellationToken); + } + public static async Task BulkDeleteAsync(this DbContext context, IEnumerable entities, Action> optionsAction, CancellationToken cancellationToken = default) + { + return await context.BulkDeleteAsync(entities, optionsAction.Build(), cancellationToken); + } + public static async Task BulkDeleteAsync(this DbContext context, IEnumerable entities, BulkDeleteOptions< + options, CancellationToken cancellationToken = default) + { + int rowsAffected = 0; + var tableMapping = context.GetTableMapping(typeof(T), options.EntityType); + + using (var dbTransactionContext = new DbTransactionContext(context, options)) + { + var dbConnection = dbTransactionContext.Connection; + var transaction = dbTransactionContext.CurrentTransaction; + try + { + string stagingTableName = CommonUtil.GetStagingTableName(tableMapping, options.UsePermanentTable, dbConn +nection); + string destinationTableName = context.DelimitIdentifier(tableMapping.TableName, tableMapping.Schema); + string[] keyColumnNames = options.DeleteOnCondition != null ? CommonUtil.GetColumns(options.DeleteOnC +Condition, ["s"]) + : tableMapping.GetPrimaryKeyColumns().ToArray(); + + if (keyColumnNames.Length == 0 && options.DeleteOnCondition == null) + throw new InvalidDataException("BulkDelete requires that the entity have a primary key or the Option +ns.DeleteOnCondition must be set."); + + await context.Database.CloneTableAsync(destinationTableName, stagingTableName, keyColumnNames, null, can +ncellationToken); + await BulkInsertAsync(entities, options, tableMapping, dbConnection, transaction, stagingTableName, keyC +ColumnNames, SqlBulkCopyOptions.KeepIdentity, + false, cancellationToken); + string joinCondition = CommonUtil.GetJoinConditionSql(context, options.DeleteOnCondition, keyColumnNa +ames); + string deleteSql = context.Database.IsPostgreSql() + ? $"DELETE FROM {destinationTableName} AS t USING {stagingTableName} AS s WHERE {joinCondition}" + : $"DELETE t FROM {stagingTableName} s JOIN {destinationTableName} t ON {joinCondition}"; + rowsAffected = await context.Database.ExecuteSqlRawAsync(deleteSql, cancellationToken); + context.Database.DropTable(stagingTableName); + dbTransactionContext.Commit(); + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + return rowsAffected; + } + } + public static async Task> BulkFetchAsync(this DbSet dbSet, IEnumerable entities, Cancella +ationToken cancellationToken = default) where T : class, new() + { + return await dbSet.BulkFetchAsync(entities, new BulkFetchOptions(), cancellationToken); + } + public static async Task> BulkFetchAsync(this DbSet dbSet, IEnumerable entities, Action> optionsAction, CancellationToken cancellationToken = default) where T : class, new() + { + return await dbSet.BulkFetchAsync(entities, optionsAction.Build(), cancellationToken); + } + public static async Task> BulkFetchAsync(this DbSet dbSet, IEnumerable entities, BulkFetc +chOptions options, CancellationToken cancellationToken = default) where T : class, new() + { + var context = dbSet.GetDbContext(); + var tableMapping = context.GetTableMapping(typeof(T)); + + using (var dbTransactionContext = new DbTransactionContext(context, options.CommandTimeout, ConnectionBehavior.N +New)) + { + string selectSql; + var dbConnection = dbTransactionContext.Connection; + var transaction = dbTransactionContext.CurrentTransaction; + string stagingTableName = string.Empty; + try + { + stagingTableName = CommonUtil.GetStagingTableName(tableMapping, true, dbConnection); + string destinationTableName = context.DelimitIdentifier(tableMapping.TableName, tableMapping.Schema); + string[] keyColumnNames = options.JoinOnCondition != null ? CommonUtil.GetColumns(options.JoinOnCondi +ition, ["s"]) + : tableMapping.GetPrimaryKeyColumns().ToArray(); + IEnumerable columnNames = CommonUtil.FilterColumns(tableMapping.GetColumns(true), keyColumnNa +ames, options.InputColumns, options.IgnoreColumns); + IEnumerable columnsToFetch = CommonUtil.FormatColumns(context, "t", columnNames); + + if (keyColumnNames.Length == 0 && options.JoinOnCondition == null) + throw new InvalidDataException("BulkFetch requires that the entity have a primary key or the Options +s.JoinOnCondition must be set."); + + await context.Database.CloneTableAsync(destinationTableName, stagingTableName, keyColumnNames, null, can +ncellationToken); + await BulkInsertAsync(entities, options, tableMapping, dbConnection, transaction, stagingTableName, keyC +ColumnNames, SqlBulkCopyOptions.KeepIdentity, false, cancellationToken); + selectSql = $"SELECT {SqlUtil.ConvertToColumnString(columnsToFetch)} FROM {stagingTableName} s JOIN {des +stinationTableName} t ON {CommonUtil.GetJoinConditionSql(context, options.JoinOnCondition, keyColumnNames)}"; + + dbTransactionContext.Commit(); + } + catch + { + dbTransactionContext.Rollback(); + throw; + } + + var results = await context.FetchInternalAsync(selectSql, cancellationToken: cancellationToken); + context.Database.DropTable(stagingTableName); + return results; + } + } + public static async Task FetchAsync(this IQueryable queryable, Func, Task> action, Action> optionsAction, CancellationToken cancellationToken = default) where T : class, new() + { + await FetchAsync(queryable, action, optionsAction.Build(), cancellationToken); + } + public static async Task FetchAsync(this IQueryable queryable, Func, Task> action, FetchOptions +s options, CancellationToken cancellationToken = default) where T : class, new() + { + var dbContext = queryable.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + HashSet includedColumns = GetIncludedColumns(tableMapping, options.InputColumns, options.IgnoreColumns); + int batch = 1; + int count = 0; + List entities = []; + await foreach (var entity in queryable.AsNoTracking().AsAsyncEnumerable().WithCancellation(cancellationToken)) + { + ClearExcludedColumns(dbContext, tableMapping, entity, includedColumns); + entities.Add(entity); + count++; + if (count == options.BatchSize) + { + await action(new FetchResult { Results = entities, Batch = batch }); + entities.Clear(); + count = 0; + batch++; + } + cancellationToken.ThrowIfCancellationRequested(); + } + + if (entities.Count > 0) + await action(new FetchResult { Results = entities, Batch = batch }); + } + public static async Task BulkInsertAsync(this DbContext context, IEnumerable entities, CancellationToken + cancellationToken = default) + { + return await context.BulkInsertAsync(entities, new BulkInsertOptions(), cancellationToken); + } + public static async Task BulkInsertAsync(this DbContext context, IEnumerable entities, Action> optionsAction, CancellationToken cancellationToken = default) + { + return await context.BulkInsertAsync(entities, optionsAction.Build(), cancellationToken); + } + public static async Task BulkInsertAsync(this DbContext context, IEnumerable entities, BulkInsertOptions< + options, CancellationToken cancellationToken = default) + { + int rowsAffected = 0; + using (var bulkOperation = new BulkOperation(context, options, options.InputColumns, options.IgnoreColumns)) + { + try + { + bool keepIdentity = options.KeepIdentity || bulkOperation.ShouldPreallocateIdentityValues(options.AutoMa +apOutput, options.KeepIdentity, entities); + if (keepIdentity && !options.KeepIdentity) + await bulkOperation.PreallocateIdentityValuesAsync(entities, cancellationToken); + var bulkInsertResult = await bulkOperation.BulkInsertStagingDataAsync(entities, true, true); + var bulkMergeResult = await bulkOperation.ExecuteMergeAsync(bulkInsertResult.EntityMap, options.InsertOn +nCondition, + options.AutoMapOutput, keepIdentity, options.InsertIfNotExists); + rowsAffected = bulkMergeResult.RowsAffected; + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return rowsAffected; + } + public static async Task> BulkMergeAsync(this DbContext context, IEnumerable entities, Canc +cellationToken cancellationToken = default) + { + return await BulkMergeAsync(context, entities, new BulkMergeOptions(), cancellationToken); + } + public static async Task> BulkMergeAsync(this DbContext context, IEnumerable entities, Bulk +kMergeOptions options, CancellationToken cancellationToken = default) + { + return await InternalBulkMergeAsync(context, entities, options, cancellationToken); + } + public static async Task> BulkMergeAsync(this DbContext context, IEnumerable entities, Acti +ion> optionsAction, CancellationToken cancellationToken = default) + { + return await BulkMergeAsync(context, entities, optionsAction.Build(), cancellationToken); + } + public static async Task BulkSaveChangesAsync(this DbContext dbContext) + { + return await dbContext.BulkSaveChangesAsync(true); + } + public static async Task BulkSaveChangesAsync(this DbContext dbContext, bool acceptAllChangesOnSuccess = true) + { + int rowsAffected = 0; + var stateManager = dbContext.GetDependencies().StateManager; + + dbContext.ChangeTracker.DetectChanges(); + var entries = stateManager.GetEntriesToSave(true); + + foreach (var saveEntryGroup in entries.GroupBy(o => new { o.EntityType, o.EntityState })) + { + var key = saveEntryGroup.Key; + var entities = saveEntryGroup.AsEnumerable(); + if (key.EntityState == EntityState.Added) + { + rowsAffected += await dbContext.BulkInsertAsync(entities, o => { o.EntityType = key.EntityType; }); + } + else if (key.EntityState == EntityState.Modified) + { + rowsAffected += await dbContext.BulkUpdateAsync(entities, o => { o.EntityType = key.EntityType; }); + } + else if (key.EntityState == EntityState.Deleted) + { + rowsAffected += await dbContext.BulkDeleteAsync(entities, o => { o.EntityType = key.EntityType; }); + } + } + + if (acceptAllChangesOnSuccess) + dbContext.ChangeTracker.AcceptAllChanges(); + + return rowsAffected; + } + public static async Task> BulkSyncAsync(this DbContext context, IEnumerable entities, Cancel +llationToken cancellationToken = default) + { + return await BulkSyncAsync(context, entities, new BulkSyncOptions(), cancellationToken); + } + public static async Task> BulkSyncAsync(this DbContext context, IEnumerable entities, Action +n> optionsAction, CancellationToken cancellationToken = default) + { + return BulkSyncResult.Map(await InternalBulkMergeAsync(context, entities, optionsAction.Build(), cancellation +nToken)); + } + public static async Task> BulkSyncAsync(this DbContext context, IEnumerable entities, BulkSy +yncOptions options, CancellationToken cancellationToken = default) + { + return BulkSyncResult.Map(await InternalBulkMergeAsync(context, entities, options, cancellationToken)); + } + public static async Task BulkUpdateAsync(this DbContext context, IEnumerable entities, CancellationToken + cancellationToken = default) + { + return await BulkUpdateAsync(context, entities, new BulkUpdateOptions(), cancellationToken); + } + public static async Task BulkUpdateAsync(this DbContext context, IEnumerable entities, Action> optionsAction, CancellationToken cancellationToken = default) + { + return await BulkUpdateAsync(context, entities, optionsAction.Build(), cancellationToken); + } + public static async Task BulkUpdateAsync(this DbContext context, IEnumerable entities, BulkUpdateOptions< + options, CancellationToken cancellationToken = default) + { + int rowsUpdated = 0; + using (var bulkOperation = new BulkOperation(context, options, options.InputColumns, options.IgnoreColumns)) + { + try + { + bulkOperation.ValidateBulkUpdate(options.UpdateOnCondition); + await bulkOperation.BulkInsertStagingDataAsync(entities, cancellationToken: cancellationToken); + rowsUpdated = await bulkOperation.ExecuteUpdateAsync(entities, options.UpdateOnCondition, cancellationTo +oken); + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return rowsUpdated; + } + public static async Task DeleteFromQueryAsync(this IQueryable queryable, int? commandTimeout = null, Canc +cellationToken cancellationToken = default) where T : class + { + var dbContext = queryable.GetDbContext(); + using (var dbTransactionContext = new DbTransactionContext(dbContext, commandTimeout)) + { + try + { + int rowsAffected = await queryable.ExecuteDeleteAsync(cancellationToken); + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static async Task InsertFromQueryAsync(this IQueryable queryable, string tableName, Expression> insertObjectExpression, int? commandTimeout = null, + CancellationToken cancellationToken = default) where T : class + { + var dbContext = queryable.GetDbContext(); + using (var dbTransactionContext = new DbTransactionContext(dbContext, commandTimeout)) + { + try + { + var tableMapping = dbContext.GetTableMapping(typeof(T)); + var columnNames = insertObjectExpression.GetObjectProperties(); + if (!dbContext.Database.TableExists(tableName)) + { + await dbContext.Database.CloneTableAsync(tableMapping.FullQualifedTableName, dbContext.Database.Deli +imitTableName(tableName), tableMapping.GetQualifiedColumnNames(columnNames), cancellationToken: cancellationToken); + } + + var entities = await queryable.AsNoTracking().ToListAsync(cancellationToken); + int rowsAffected = (int)(await BulkInsertAsync(entities, new BulkInsertOptions { KeepIdentity = true, +, AutoMapOutput = false, CommandTimeout = commandTimeout }, tableMapping, + dbTransactionContext.Connection, dbTransactionContext.CurrentTransaction, dbContext.Database.Delimit +tTableName(tableName), columnNames, SqlBulkCopyOptions.KeepIdentity, cancellationToken: cancellationToken)).RowsAffected; + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static async Task UpdateFromQueryAsync(this IQueryable queryable, Expression> updateExp +pression, int? commandTimeout = null, + CancellationToken cancellationToken = default) where T : class + { + var dbContext = queryable.GetDbContext(); + using (var dbTransactionContext = new DbTransactionContext(dbContext, commandTimeout)) + { + try + { + int rowsAffected = await queryable.ExecuteUpdateAsync(BuildSetPropertyCalls(updateExpression), cancellat +tionToken); + dbTransactionContext.Commit(); + return rowsAffected; + } + catch (Exception) + { + dbTransactionContext.Rollback(); + throw; + } + } + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, string filePath, Ca +ancellationToken cancellationToken = default) where T : class + { + return await QueryToCsvFileAsync(queryable, filePath, new QueryToFileOptions(), cancellationToken); + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, Stream stream, Canc +cellationToken cancellationToken = default) where T : class + { + return await QueryToCsvFileAsync(queryable, stream, new QueryToFileOptions(), cancellationToken); + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, string filePath, Ac +ction optionsAction, + CancellationToken cancellationToken = default) where T : class + { + return await QueryToCsvFileAsync(queryable, filePath, optionsAction.Build(), cancellationToken); + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, Stream stream, Acti +ion optionsAction, + CancellationToken cancellationToken = default) where T : class + { + return await QueryToCsvFileAsync(queryable, stream, optionsAction.Build(), cancellationToken); + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, string filePath, Qu +ueryToFileOptions options, + CancellationToken cancellationToken = default) where T : class + { + await using var fileStream = File.Create(filePath); + return await QueryToCsvFileAsync(queryable, fileStream, options, cancellationToken); + } + public static async Task QueryToCsvFileAsync(this IQueryable queryable, Stream stream, Quer +ryToFileOptions options, + CancellationToken cancellationToken = default) where T : class + { + return await InternalQueryToFileAsync(queryable, stream, options, cancellationToken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, string filePath, st +tring sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + return await SqlQueryToCsvFileAsync(database, filePath, new QueryToFileOptions(), sqlText, parameters, cancellat +tionToken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, Stream stream, stri +ing sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + return await SqlQueryToCsvFileAsync(database, stream, new QueryToFileOptions(), sqlText, parameters, cancellatio +onToken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, string filePath, Ac +ction optionsAction, string sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + return await SqlQueryToCsvFileAsync(database, filePath, optionsAction.Build(), sqlText, parameters, cancellation +nToken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, Stream stream, Acti +ion optionsAction, string sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + return await SqlQueryToCsvFileAsync(database, stream, optionsAction.Build(), sqlText, parameters, cancellationTo +oken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, string filePath, Qu +ueryToFileOptions options, string sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + await using var fileStream = File.Create(filePath); + return await SqlQueryToCsvFileAsync(database, fileStream, options, sqlText, parameters, cancellationToken); + } + public static async Task SqlQueryToCsvFileAsync(this DatabaseFacade database, Stream stream, Quer +ryToFileOptions options, string sqlText, object[] parameters, + CancellationToken cancellationToken = default) + { + return await InternalQueryToFileAsync(database.GetDbConnection(), stream, options, sqlText, parameters, cancella +ationToken); + } + public static async Task ClearAsync(this DbSet dbSet, CancellationToken cancellationToken = default) where T : +: class + { + var dbContext = dbSet.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + await dbContext.Database.ClearTableAsync(tableMapping.FullQualifedTableName, cancellationToken); + } + public static async Task TruncateAsync(this DbSet dbSet, CancellationToken cancellationToken = default) where + T : class + { + var dbContext = dbSet.GetDbContext(); + var tableMapping = dbContext.GetTableMapping(typeof(T)); + await dbContext.Database.TruncateTableAsync(tableMapping.FullQualifedTableName, false, cancellationToken); + } + internal static async Task> BulkInsertAsync(IEnumerable entities, BulkOptions options, Tab +bleMapping tableMapping, DbConnection dbConnection, DbTransaction transaction, string tableName, + IEnumerable inputColumns = null, SqlBulkCopyOptions bulkCopyOptions = SqlBulkCopyOptions.Default, bool u +useInternalId = false, CancellationToken cancellationToken = default) + { + using var dataReader = new EntityDataReader(tableMapping, entities, useInternalId); + if (dbConnection is NpgsqlConnection npgsqlConnection) + { + var columnNames = tableMapping.Properties + .Select(tableMapping.GetColumnName) + .Where(columnName => inputColumns == null || inputColumns.Contains(columnName)) + .ToList(); + if (useInternalId) + columnNames.Add(Constants.InternalId_ColumnName); + + string copySql = $"COPY {tableName} ({string.Join(",", columnNames.Select(tableMapping.DbContext.DelimitIden +ntifier))}) FROM STDIN (FORMAT BINARY)"; + await using var importer = await npgsqlConnection.BeginBinaryImportAsync(copySql, cancellationToken); + long rowsCopied = 0; + while (dataReader.Read()) + { + await importer.StartRowAsync(cancellationToken); + foreach (var columnName in columnNames) + { + object value = dataReader.GetValue(dataReader.GetOrdinal(columnName)); + if (value == null || value == DBNull.Value) + await importer.WriteNullAsync(cancellationToken); + else + await importer.WriteAsync(value, cancellationToken); + } + rowsCopied++; + } + await importer.CompleteAsync(cancellationToken); + + return new BulkInsertResult + { + RowsAffected = (int)rowsCopied, + EntityMap = dataReader.EntityMap + }; + } + + var sqlBulkCopy = new SqlBulkCopy((SqlConnection)dbConnection, bulkCopyOptions, (SqlTransaction)transaction) + { + DestinationTableName = tableName, + BatchSize = options.BatchSize + }; + if (options.CommandTimeout.HasValue) + { + sqlBulkCopy.BulkCopyTimeout = options.CommandTimeout.Value; + } + foreach (var property in dataReader.TableMapping.Properties) + { + var columnName = dataReader.TableMapping.GetColumnName(property); + if (inputColumns == null || inputColumns.Contains(columnName)) + sqlBulkCopy.ColumnMappings.Add(columnName, columnName); + } + if (useInternalId) + { + sqlBulkCopy.ColumnMappings.Add(Constants.InternalId_ColumnName, Constants.InternalId_ColumnName); + } + await sqlBulkCopy.WriteToServerAsync(dataReader, cancellationToken); + + return new BulkInsertResult + { + RowsAffected = sqlBulkCopy.RowsCopied, + EntityMap = dataReader.EntityMap + }; + } + internal static async Task BulkQueryAsync(this DbContext context, string sqlText, DbConnection dbCo +onnection, DbTransaction transaction, BulkOptions options, CancellationToken cancellationToken = default) + { + List results = []; + List columns = []; + await using var command = dbConnection.CreateCommand(); + command.CommandText = sqlText; + command.Transaction = transaction; + if (options.CommandTimeout.HasValue) + command.CommandTimeout = options.CommandTimeout.Value; + await using var reader = await command.ExecuteReaderAsync(cancellationToken); + while (await reader.ReadAsync(cancellationToken)) + { + if (columns.Count == 0) + { + for (int i = 0; i < reader.FieldCount; i++) + columns.Add(reader.GetName(i)); + } + object[] values = new object[reader.FieldCount]; + reader.GetValues(values); + results.Add(values); + } + + return new BulkQueryResult + { + Columns = columns, + Results = results, + RowsAffected = reader.RecordsAffected + }; + } + private static async Task> InternalBulkMergeAsync(this DbContext context, IEnumerable entit +ties, BulkMergeOptions options, CancellationToken cancellationToken = default) + { + BulkMergeResult bulkMergeResult; + using (var bulkOperation = new BulkOperation(context, options)) + { + try + { + bool shouldPreallocate = bulkOperation.ShouldPreallocateIdentityValues(true, false, entities); + bool keepIdentity = shouldPreallocate || bulkOperation.ShouldKeepIdentityForPostgresMerge(); + if (shouldPreallocate) + await bulkOperation.PreallocateIdentityValuesAsync(entities, cancellationToken); + bulkOperation.ValidateBulkMerge(options.MergeOnCondition); + var bulkInsertResult = await bulkOperation.BulkInsertStagingDataAsync(entities, true, true, cancellation +nToken); + bulkMergeResult = await bulkOperation.ExecuteMergeAsync(bulkInsertResult.EntityMap, options.MergeOnCondi +ition, options.AutoMapOutput, + keepIdentity, true, true, options.DeleteIfNotMatched, shouldPreallocate, cancellationToken); + bulkOperation.DbTransactionContext.Commit(); + } + catch (Exception) + { + bulkOperation.DbTransactionContext.Rollback(); + throw; + } + } + return bulkMergeResult; + } + private static async Task InternalQueryToFileAsync(this IQueryable queryable, Stream stream +m, QueryToFileOptions options, + CancellationToken cancellationToken = default) where T : class + { + return await InternalQueryToFileAsync(queryable.AsNoTracking().AsAsyncEnumerable(), stream, options, cancellatio +onToken); + } + private static async Task InternalQueryToFileAsync(DbConnection dbConnection, Stream stream, Quer +ryToFileOptions options, string sqlText, object[] parameters = null, + CancellationToken cancellationToken = default) + { + int dataRowCount = 0; + int totalRowCount = 0; + long bytesWritten = 0; + + if (dbConnection.State == ConnectionState.Closed) + dbConnection.Open(); + + await using var command = dbConnection.CreateCommand(); + command.CommandText = sqlText; + if (parameters != null) + command.Parameters.AddRange(parameters); + if (options.CommandTimeout.HasValue) + command.CommandTimeout = options.CommandTimeout.Value; + + await using var streamWriter = new StreamWriter(stream, leaveOpen: true); + using (var reader = await command.ExecuteReaderAsync(cancellationToken)) + { + if (options.IncludeHeaderRow) + { + for (int i = 0; i < reader.FieldCount; i++) + { + streamWriter.Write(options.TextQualifer); + streamWriter.Write(reader.GetName(i)); + streamWriter.Write(options.TextQualifer); + if (i != reader.FieldCount - 1) + { + await streamWriter.WriteAsync(options.ColumnDelimiter); + } + } + totalRowCount++; + await streamWriter.WriteAsync(options.RowDelimiter); + } + while (await reader.ReadAsync(cancellationToken)) + { + object[] values = new object[reader.FieldCount]; + reader.GetValues(values); + for (int i = 0; i < values.Length; i++) + { + streamWriter.Write(options.TextQualifer); + streamWriter.Write(values[i]); + streamWriter.Write(options.TextQualifer); + if (i != values.Length - 1) + { + await streamWriter.WriteAsync(options.ColumnDelimiter); + } + } + await streamWriter.WriteAsync(options.RowDelimiter); + dataRowCount++; + totalRowCount++; + } + await streamWriter.FlushAsync(); + bytesWritten = streamWriter.BaseStream.Length; + } + return new QueryToFileResult() + { + BytesWritten = bytesWritten, + DataRowCount = dataRowCount, + TotalRowCount = totalRowCount + }; + } + private static async Task InternalQueryToFileAsync(IAsyncEnumerable entities, Stream stream +m, QueryToFileOptions options, CancellationToken cancellationToken) where T : class + { + int dataRowCount = 0; + int totalRowCount = 0; + long bytesWritten = 0; + var properties = typeof(T).GetProperties().Where(p => p.CanRead && (!typeof(System.Collections.IEnumerable).IsAs +ssignableFrom(p.PropertyType) || p.PropertyType == typeof(string))).ToArray(); + + await using var streamWriter = new StreamWriter(stream, leaveOpen: true); + if (options.IncludeHeaderRow) + { + await WriteCsvRowAsync(streamWriter, properties.Select(p => (object)p.Name), options, cancellationToken); + totalRowCount++; + } + + await foreach (var entity in entities.WithCancellation(cancellationToken)) + { + await WriteCsvRowAsync(streamWriter, properties.Select(p => p.GetValue(entity)), options, cancellationToken) +); + dataRowCount++; + totalRowCount++; + } + + await streamWriter.FlushAsync(cancellationToken); + bytesWritten = streamWriter.BaseStream.Length; + return new QueryToFileResult { BytesWritten = bytesWritten, DataRowCount = dataRowCount, TotalRowCount = totalRo +owCount }; + } + private static async Task> FetchInternalAsync(this DbContext dbContext, string sqlText, object[] p +parameters = null, CancellationToken cancellationToken = default) where T : class, new() + { + List results = []; + await using var command = dbContext.Database.CreateCommand(ConnectionBehavior.New); + command.CommandText = sqlText; + if (parameters != null) + command.Parameters.AddRange(parameters); + + var tableMapping = dbContext.GetTableMapping(typeof(T), null); + var reader = await command.ExecuteReaderAsync(cancellationToken); + var properties = reader.GetProperties(tableMapping); + var valuesFromProvider = properties.Select(p => tableMapping.GetValueFromProvider(p)).ToArray(); + + while (await reader.ReadAsync(cancellationToken)) + { + var entity = reader.MapEntity(dbContext, properties, valuesFromProvider); + results.Add(entity); + } + + await reader.CloseAsync(); + await command.Connection.CloseAsync(); + return results; + } + private static HashSet GetIncludedColumns(TableMapping tableMapping, Expression> inputCol +lumns, Expression> ignoreColumns) + { + var includedColumns = inputColumns != null + ? inputColumns.GetObjectProperties().ToHashSet() + : tableMapping.Properties.Select(p => p.Name).ToHashSet(); + + if (ignoreColumns != null) + includedColumns.ExceptWith(ignoreColumns.GetObjectProperties()); + + return includedColumns; + } + private static void ClearExcludedColumns(DbContext dbContext, TableMapping tableMapping, T entity, HashSet includedColumns) where T : class + { + var entry = dbContext.Entry(entity); + foreach (var property in tableMapping.Properties) + { + if (includedColumns.Contains(property.Name)) + continue; + + object defaultValue = property.ClrType.IsValueType ? Activator.CreateInstance(property.ClrType) : null; + if (property.DeclaringType is IComplexType complexType) + { + var complexProperty = entry.ComplexProperty(complexType.ComplexProperty); + if (complexProperty.CurrentValue != null) + complexProperty.Property(property).CurrentValue = defaultValue; + } + else + { + entry.Property(property.Name).CurrentValue = defaultValue; + } + } + } + private static async Task WriteCsvRowAsync(TextWriter writer, IEnumerable values, QueryToFileOptions options +s, CancellationToken cancellationToken) + { + bool first = true; + foreach (var value in values) + { + if (!first) + await writer.WriteAsync(options.ColumnDelimiter); + + await writer.WriteAsync(options.TextQualifer); + await writer.WriteAsync(value?.ToString()); + await writer.WriteAsync(options.TextQualifer); + first = false; + cancellationToken.ThrowIfCancellationRequested(); + } + await writer.WriteAsync(options.RowDelimiter); + } + private static Action> BuildSetPropertyCalls(Expression> updateExpression) whe +ere T : class + { + if (updateExpression.Body is not MemberInitExpression memberInitExpression) + throw new InvalidOperationException("UpdateFromQuery requires a member initialization expression."); + + var entityParameter = updateExpression.Parameters[0]; + var setPropertyMethod = typeof(UpdateSettersBuilder) + .GetMethods() + .Single(m => m.Name == nameof(UpdateSettersBuilder.SetProperty) && m.GetParameters().Length == 2 && m.Get +tParameters()[1].ParameterType.IsGenericType); + + return setters => + { + object currentBuilder = setters; + foreach (var binding in memberInitExpression.Bindings.OfType()) + { + var propertyInfo = binding.Member as System.Reflection.PropertyInfo ?? throw new InvalidOperationExcepti +ion("Only property bindings are supported."); + var propertyLambda = Expression.Lambda(Expression.Property(entityParameter, propertyInfo), entityParamet +ter); + var valueLambda = Expression.Lambda(binding.Expression, entityParameter); + currentBuilder = setPropertyMethod.MakeGenericMethod(propertyInfo.PropertyType).Invoke(currentBuilder, [ +[propertyLambda, valueLambda]); + } + }; + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\DbTransact +tionContext.cs --- + +using System; +using System.Data.Common; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Storage; +using N.EntityFrameworkCore.Extensions.Enums; +using N.EntityFrameworkCore.Extensions.Util; + + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed class DbTransactionContext : IDisposable +{ + private bool closeConnection; + private bool ownsTransaction; + private int? defaultCommandTimeout; + private DbContext context; + private IDbContextTransaction transaction; + + public DbConnection Connection { get; internal set; } + public DbTransaction CurrentTransaction { get; private set; } + public DbContext DbContext => context; + + public DbTransactionContext(DbContext context, BulkOptions bulkOptions, bool openConnection = true) : this(context, + bulkOptions.CommandTimeout, bulkOptions.ConnectionBehavior, openConnection) + { + + } + public DbTransactionContext(DbContext context, int? commandTimeout = null, ConnectionBehavior connectionBehavior = C +ConnectionBehavior.Default, bool openConnection = true) + { + this.context = context; + Connection = context.GetDbConnection(connectionBehavior); + if (openConnection) + { + if (Connection.State == System.Data.ConnectionState.Closed) + { + Connection.Open(); + closeConnection = true; + } + } + if (connectionBehavior == ConnectionBehavior.Default) + { + ownsTransaction = context.Database.CurrentTransaction == null; + transaction = context.Database.CurrentTransaction; + defaultCommandTimeout = context.Database.GetCommandTimeout(); + if (transaction != null) + CurrentTransaction = transaction.GetDbTransaction(); + } + + context.Database.SetCommandTimeout(commandTimeout); + } + + public void Dispose() + { + context.Database.SetCommandTimeout(defaultCommandTimeout); + if (closeConnection) + { + Connection.Close(); + } + } + + internal void Commit() + { + if (ownsTransaction && transaction != null) + transaction.Commit(); + } + internal void Rollback() + { + if (transaction != null) + transaction.Rollback(); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\EfExtensio +onsCommand.cs --- + +using System.Data.Common; +using Microsoft.EntityFrameworkCore.Diagnostics; + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed class EfExtensionsCommand +{ + public EfExtensionsCommandType CommandType { get; set; } + public string OldValue { get; set; } + public string NewValue { get; set; } + public DbConnection Connection { get; internal set; } + + internal bool Execute(DbCommand command, CommandEventData eventData, InterceptionResult result) + { + if (CommandType == EfExtensionsCommandType.ChangeTableName) + { + command.CommandText = command.CommandText.Replace(OldValue, NewValue); + } + + return true; + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\EfExtensio +onsCommandInterceptor.cs --- + +using System; +using System.Collections.Concurrent; +using System.Data.Common; +using Microsoft.EntityFrameworkCore.Diagnostics; + +namespace N.EntityFrameworkCore.Extensions; + +public class EfExtensionsCommandInterceptor : DbCommandInterceptor +{ + private ConcurrentDictionary extensionCommands = new(); + public override InterceptionResult ReaderExecuting(DbCommand command, CommandEventData eventData, Inte +erceptionResult result) + { + foreach (var extensionCommand in extensionCommands) + { + if (extensionCommand.Value.Connection == command.Connection) + { + extensionCommand.Value.Execute(command, eventData, result); + extensionCommands.TryRemove(extensionCommand.Key, out _); + } + } + return result; + } + internal void AddCommand(Guid clientConnectionId, EfExtensionsCommand efExtensionsCommand) + { + extensionCommands.TryAdd(clientConnectionId, efExtensionsCommand); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\EntityData +aReader.cs --- + +using System; +using System.Collections.Generic; +using System.Data; +using Microsoft.EntityFrameworkCore.ChangeTracking; +using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; +using Microsoft.EntityFrameworkCore.Metadata; +using N.EntityFrameworkCore.Extensions.Common; + +namespace N.EntityFrameworkCore.Extensions; + +internal sealed class EntityDataReader : IDataReader +{ + public TableMapping TableMapping { get; set; } + public Dictionary EntityMap { get; set; } + private Dictionary columnIndexes; + private int currentId; + private bool useInternalId; + private int tableFieldCount; + private IEnumerable entities; + private IEnumerator enumerator; + private Dictionary> selectors; + + public EntityDataReader(TableMapping tableMapping, IEnumerable entities, bool useInternalId) + { + this.columnIndexes = []; + this.currentId = 0; + this.useInternalId = useInternalId; + this.tableFieldCount = tableMapping.Properties.Length; + this.entities = entities; + this.enumerator = entities.GetEnumerator(); + this.selectors = []; + this.EntityMap = []; + this.FieldCount = tableMapping.Properties.Length; + this.TableMapping = tableMapping; + + + int i = 0; + foreach (var property in tableMapping.Properties) + { + selectors[i] = GetValueSelector(property); + columnIndexes[tableMapping.GetColumnName(property)] = i; + i++; + } + + if (useInternalId) + { + this.FieldCount++; + columnIndexes[Constants.InternalId_ColumnName] = i; + } + } + private Func GetValueSelector(IProperty property) + { + Func selector; + var valueGeneratorFactory = property.GetValueGeneratorFactory(); + if (valueGeneratorFactory != null) + { + var valueGenerator = valueGeneratorFactory.Invoke(property, this.TableMapping.EntityType); + selector = entry => valueGenerator.Next(entry); + } + else + { + var valueConverter = property.GetTypeMapping().Converter; + if (valueConverter != null) + { + selector = entry => valueConverter.ConvertToProvider(entry.CurrentValues[property]); + } + else + { + if (property.DeclaringType is IComplexType complexType) + { + selector = entry => entry.ComplexProperty(complexType.ComplexProperty).Property(property).CurrentVal +lue; + } + else + { + selector = entry => entry.CurrentValues[property]; + } + } + } + return selector; + } + public object this[int i] => throw new NotImplementedException(); + + public object this[string name] => throw new NotImplementedException(); + + public int Depth { get; set; } + + public bool IsClosed => throw new NotImplementedException(); + + public int RecordsAffected => throw new NotImplementedException(); + + public int FieldCount { get; set; } + + public void Close() + { + throw new NotImplementedException(); + } + + public void Dispose() + { + selectors = null; + enumerator.Dispose(); + } + + public bool GetBoolean(int i) + { + throw new NotImplementedException(); + } + + public byte GetByte(int i) + { + throw new NotImplementedException(); + } + + public long GetBytes(int i, long fieldOffset, byte[] buffer, int bufferoffset, int length) + { + throw new NotImplementedException(); + } + + public char GetChar(int i) + { + throw new NotImplementedException(); + } + + public long GetChars(int i, long fieldoffset, char[] buffer, int bufferoffset, int length) + { + throw new NotImplementedException(); + } + + public IDataReader GetData(int i) + { + throw new NotImplementedException(); + } + + public string GetDataTypeName(int i) + { + throw new NotImplementedException(); + } + + public DateTime GetDateTime(int i) + { + throw new NotImplementedException(); + } + + public decimal GetDecimal(int i) + { + throw new NotImplementedException(); + } + + public double GetDouble(int i) + { + throw new NotImplementedException(); + } + + public Type GetFieldType(int i) + { + throw new NotImplementedException(); + } + + public float GetFloat(int i) + { + throw new NotImplementedException(); + } + + public Guid GetGuid(int i) + { + throw new NotImplementedException(); + } + + public short GetInt16(int i) + { + throw new NotImplementedException(); + } + + public int GetInt32(int i) + { + throw new NotImplementedException(); + } + + public long GetInt64(int i) + { + throw new NotImplementedException(); + } + + public string GetName(int i) + { + throw new NotImplementedException(); + } + + public int GetOrdinal(string name) + { + return columnIndexes[name]; + } + + public DataTable GetSchemaTable() + { + throw new NotImplementedException(); + } + + public string GetString(int i) + { + throw new NotImplementedException(); + } + + public object GetValue(int i) + { + if (i == tableFieldCount) + { + return this.currentId; + } + else + { + return selectors[i](FindEntry(enumerator.Current)); + } + + } + + private EntityEntry FindEntry(object entity) + { + return entity is InternalEntityEntry internalEntry ? internalEntry.ToEntityEntry() : TableMapping.DbContext.Entr +ry(entity); + } + + public int GetValues(object[] values) + { + throw new NotImplementedException(); + } + + public bool IsDBNull(int i) + { + throw new NotImplementedException(); + } + + public bool NextResult() + { + throw new NotImplementedException(); + } + + public bool Read() + { + bool moveNext = enumerator.MoveNext(); + + if (moveNext && this.useInternalId) + { + this.currentId++; + this.EntityMap.Add(this.currentId, enumerator.Current); + } + return moveNext; + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\FetchOptio +ons.cs --- + +using System; +using System.Linq.Expressions; + +namespace N.EntityFrameworkCore.Extensions; + +public class FetchOptions +{ + public Expression> IgnoreColumns { get; set; } + public Expression> InputColumns { get; set; } + public int BatchSize { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\FetchResul +lt.cs --- + +using System.Collections.Generic; + +namespace N.EntityFrameworkCore.Extensions; + +public class FetchResult +{ + public List Results { get; set; } + public int Batch { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\QueryToFil +leOptions.cs --- + +namespace N.EntityFrameworkCore.Extensions; + +public class QueryToFileOptions +{ + public string ColumnDelimiter { get; set; } + public int? CommandTimeout { get; set; } + public bool IncludeHeaderRow { get; set; } + public string RowDelimiter { get; set; } + public string TextQualifer { get; set; } + + public QueryToFileOptions() + { + ColumnDelimiter = ","; + IncludeHeaderRow = true; + RowDelimiter = "\r\n"; + TextQualifer = ""; + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\QueryToFil +leResult.cs --- + +namespace N.EntityFrameworkCore.Extensions; + +public class QueryToFileResult +{ + public long BytesWritten { get; set; } + public int DataRowCount { get; internal set; } + public int TotalRowCount { get; internal set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\SqlMergeAc +ction.cs --- + +namespace N.EntityFrameworkCore.Extensions; + +internal static class SqlMergeAction +{ + public const string Insert = "INSERT"; + public const string Update = "UPDATE"; + public const string Delete = "DELETE"; +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\SqlQuery.c +cs --- + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore.Infrastructure; +using N.EntityFrameworkCore.Extensions.Sql; + +namespace N.EntityFrameworkCore.Extensions; + +public class SqlQuery +{ + private DatabaseFacade database; + public string SqlText { get; private set; } + public object[] Parameters { get; private set; } + + public SqlQuery(DatabaseFacade database, string sqlText, params object[] parameters) + { + this.database = database; + SqlText = sqlText; + Parameters = parameters; + } + + public int Count() + { + string countSqlText = SqlBuilder.Parse(SqlText).Count(); + return Convert.ToInt32(database.ExecuteScalar(countSqlText, Parameters)); + } + public async Task CountAsync(CancellationToken cancellationToken = default) + { + string countSqlText = SqlBuilder.Parse(SqlText).Count(); + return Convert.ToInt32(await database.ExecuteScalarAsync(countSqlText, Parameters, null, cancellationToken)); + } + public int ExecuteNonQuery() + { + return database.ExecuteSql(SqlText, Parameters); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Data\TableMappi +ing.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Metadata.Internal; +using N.EntityFrameworkCore.Extensions.Extensions; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +public class TableMapping +{ + public DbContext DbContext { get; private set; } + public IEntityType EntityType { get; set; } + public IProperty[] Properties { get; } + public string Schema { get; } + public string TableName { get; } + public IEnumerable EntityTypes { get; } + + public bool HasIdentityColumn => EntityType.FindPrimaryKey().Properties.Any(o => o.ValueGenerated != ValueGenerated. +.Never); + public StoreObjectIdentifier StoreObjectIdentifier => StoreObjectIdentifier.Table(TableName, EntityType.GetSchema() + ?? DbContext.Database.GetDefaultSchema()); + private Dictionary ColumnMap { get; set; } + public string FullQualifedTableName => DbContext.DelimitIdentifier(TableName, Schema); + + public TableMapping(DbContext dbContext, IEntityType entityType) + { + DbContext = dbContext; + EntityType = entityType; + Properties = GetProperties(entityType); + ColumnMap = Properties.Select(p => new KeyValuePair(GetColumnName(p), p)).ToDictionary(); + Schema = entityType.GetSchema() ?? dbContext.Database.GetDefaultSchema(); + TableName = entityType.GetTableName(); + EntityTypes = EntityType.GetAllBaseTypesInclusive().Where(o => !o.IsAbstract()); + } + public IProperty GetPropertyFromColumnName(string columnName) => ColumnMap[columnName]; + private static IProperty[] GetProperties(IEntityType entityType) + { + var properties = entityType.GetProperties().ToList(); + properties.AddRange(entityType.GetComplexProperties().SelectMany(p => p.ComplexType.GetProperties())); + return properties.ToArray(); + } + + public IEnumerable GetQualifiedColumnNames(IEnumerable columnNames, IEntityType entityType = null) + { + return Properties.Where(o => entityType == null || o.GetDeclaringEntityType() == entityType) + .Select(o => new + { + Column = FindColumn(o), + Name = GetColumnName(o) + }) + .Where(o => columnNames == null || columnNames.Contains(o.Name)) + .Select(o => $"{DbContext.DelimitIdentifier(o.Column?.Table.Name ?? TableName)}.{DbContext.DelimitIdentifier +r(o.Name)}").ToList(); + } + public string GetColumnName(IProperty property) => FindColumn(property)?.Name ?? property.Name; + private IColumnBase FindColumn(IProperty property) + { + var entityType = property.GetDeclaringEntityType(); + if (entityType == null || entityType.IsAbstract()) + entityType = EntityType; + var storeObjectIdentifier = StoreObjectIdentifier.Table(entityType.GetTableName(), entityType.GetSchema()); + return property.FindColumn(storeObjectIdentifier); + } + + private string FindTableName(IEntityType declaringEntityType, IEntityType entityType) => + declaringEntityType != null && declaringEntityType.IsAbstract() ? declaringEntityType.GetTableName() : entityTyp +pe.GetTableName(); + public IEnumerable GetColumnNames(IEntityType entityType, bool primaryKeyColumns) + { + List columns; + if (entityType != null) + { + columns = entityType.GetProperties().Where(o => (o.GetDeclaringEntityType() == entityType || o.GetDeclaringE +EntityType().IsAbstract() + || o.IsForeignKeyToSelf()) && o.ValueGenerated == ValueGenerated.Never) + .Select(GetColumnName).ToList(); + + columns.AddRange(entityType.GetComplexProperties().SelectMany(o => o.ComplexType.GetProperties() + .Select(GetColumnName))); + } + else + { + columns = EntityType.GetProperties().Where(o => o.ValueGenerated == ValueGenerated.Never) + .Select(GetColumnName).ToList(); + + columns.AddRange(EntityType.GetComplexProperties().SelectMany(o => o.ComplexType.GetProperties() + .Select(GetColumnName))); + } + if (primaryKeyColumns) + { + columns.AddRange(GetPrimaryKeyColumns()); + } + return columns.Distinct(); + } + public IEnumerable GetColumns(bool includePrimaryKeyColumns = false) + { + List columns = []; + foreach (var entityType in EntityTypes) + { + var storeObjectIdentifier = StoreObjectIdentifier.Create(entityType, StoreObjectType.Table).GetValueOrDefaul +lt(); + columns.AddRange(entityType.GetProperties().Where(o => o.ValueGenerated == ValueGenerated.Never) + .Select(GetColumnName)); + + columns.AddRange(EntityType.GetComplexProperties().SelectMany(o => o.ComplexType.GetProperties() + .Select(GetColumnName))); + + if (includePrimaryKeyColumns) + columns.AddRange(GetPrimaryKeyColumns()); + } + return columns.Where(o => o != null).Distinct(); + } + public IEnumerable GetPrimaryKeyColumns() => + EntityType.FindPrimaryKey().Properties.Select(GetColumnName); + + internal IEnumerable GetAutoGeneratedColumns(IEntityType entityType = null) + { + entityType ??= EntityType; + return entityType.GetProperties().Where(o => o.ValueGenerated != ValueGenerated.Never) + .Select(GetColumnName); + } + + internal IEnumerable GetEntityProperties(IEntityType entityType = null, ValueGenerated? valueGenerated = + null) + { + entityType ??= EntityType; + return entityType.GetProperties().Where(o => valueGenerated == null || o.ValueGenerated == valueGenerated).AsEnu +umerable(); + } + internal Func GetValueFromProvider(IProperty property) + { + var valueConverter = property.GetTypeMapping().Converter; + return valueConverter != null ? value => valueConverter.ConvertFromProvider(value) : value => value; + } + internal IEnumerable GetSchemaQualifiedTableNames() + { + return EntityTypes + .Select(o => DbContext.DelimitIdentifier(o.GetTableName(), o.GetSchema() ?? DbContext.Database.GetDefaultSch +hema())).Distinct(); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Enums\Connectio +onBehavior.cs --- + +namespace N.EntityFrameworkCore.Extensions.Enums; + +internal enum ConnectionBehavior +{ + Default, + New +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Enums\EfExtensi +ionsCommandType.cs --- + +namespace N.EntityFrameworkCore.Extensions; + +internal enum EfExtensionsCommandType +{ + ChangeTableName +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Extensions\Comm +monExtensions.cs --- + +using System; + +namespace N.EntityFrameworkCore.Extensions.Extensions; + +internal static class CommonExtensions +{ + internal static T Build(this Action buildAction) where T : new() + { + var parameter = new T(); + buildAction(parameter); + return parameter; + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Extensions\DbDa +ataReaderExtensions.cs --- + +using System; +using System.Collections.Generic; +using System.Data.Common; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Metadata; + +namespace N.EntityFrameworkCore.Extensions.Extensions; + +internal static class DbDataReaderExtensions +{ + internal static T MapEntity(this DbDataReader reader, DbContext dbContext, IProperty[] properties, Func[] valuesFromProvider) where T : class, new() + { + var entity = new T(); + var entry = dbContext.Entry(entity); + + for (var i = 0; i < reader.FieldCount; i++) + { + var property = properties[i]; + var value = valuesFromProvider[i].Invoke(reader.GetValue(i)); + if (value == DBNull.Value) + value = null; + + if (property.DeclaringType is IComplexType complexType) + { + var complexProperty = entry.ComplexProperty(complexType.ComplexProperty); + if (complexProperty.CurrentValue == null) + { + complexProperty.CurrentValue = Activator.CreateInstance(complexType.ClrType); + } + complexProperty.Property(property).CurrentValue = value; + } + else + { + entry.Property(property).CurrentValue = value; + } + } + return entity; + } + internal static IProperty[] GetProperties(this DbDataReader reader, TableMapping tableMapping) + { + List properties = []; + + for (var i = 0; i < reader.FieldCount; i++) + { + var property = tableMapping.GetPropertyFromColumnName(reader.GetName(i)); + properties.Add(property); + } + + return properties.ToArray(); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Extensions\IPro +opertyExtensions.cs --- + +using Microsoft.EntityFrameworkCore.Metadata; + +namespace N.EntityFrameworkCore.Extensions.Extensions; + +public static class IPropertyExtensions +{ + public static IEntityType GetDeclaringEntityType(this IProperty property) + { + return property.DeclaringType as IEntityType; + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Extensions\Linq +qExtensions.cs --- + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Text; +using System.Text.RegularExpressions; +using Microsoft.EntityFrameworkCore; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions; + +internal static class LinqExtensions +{ + internal static List GetObjectProperties(this Expression> expression) + { + if (expression == null) + { + return []; + } + else if (expression.Body is MemberExpression propertyExpression) + { + return [propertyExpression.Member.Name]; + } + else if (expression.Body is NewExpression newExpression) + { + return newExpression.Members.Select(o => o.Name).ToList(); + } + else if ((expression.Body is UnaryExpression unaryExpression) && (unaryExpression.Operand.GetPrivateFieldValue(" +"Member") is PropertyInfo propertyInfo)) + { + return [propertyInfo.Name]; + } + else + { + throw new InvalidOperationException("GetObjectProperties() encountered an unsupported expression type"); + } + } + internal static string ToSql(this ExpressionType expressionType) => expressionType switch + { + ExpressionType.AndAlso => "AND", + ExpressionType.Or => "OR", + ExpressionType.Add => "+", + ExpressionType.Subtract => "-", + ExpressionType.Multiply => "*", + ExpressionType.Divide => "/", + ExpressionType.Modulo => "%", + ExpressionType.Equal => "=", + _ => string.Empty + }; + + internal static string ToSql(this MemberBinding binding) + { + if (binding is MemberAssignment memberAssingment) + { + return GetExpressionValueAsString(memberAssingment.Expression); + } + else + { + throw new NotSupportedException(); + } + } + internal static string ToSql(this Expression expression) + { + var sb = new StringBuilder(); + if (expression is BinaryExpression binaryExpression) + { + sb.Append(binaryExpression.Left.ToSql()); + sb.Append($" {expression.NodeType.ToSql()} "); + sb.Append(binaryExpression.Right.ToSql()); + } + else if (expression is MemberExpression memberExpression) + { + return $"{memberExpression}"; + } + else if (expression is UnaryExpression unaryExpression) + { + return $"{unaryExpression.Operand}"; + } + return sb.ToString(); + } + internal static string GetExpressionValueAsString(Expression expression) + { + if (expression is ConstantExpression constantExpression) + { + return ConvertToSqlValue(constantExpression.Value); + } + else if (expression is MemberExpression memberExpression) + { + if (memberExpression.Expression is ParameterExpression parameterExpression) + { + return memberExpression.ToString(); + } + else + { + return ConvertToSqlValue(Expression.Lambda(expression).Compile().DynamicInvoke()); + } + } + else if (expression.NodeType == ExpressionType.Convert) + { + return ConvertToSqlValue(Expression.Lambda(expression).Compile().DynamicInvoke()); + } + else if (expression.NodeType == ExpressionType.Call) + { + var methodCallExpression = expression as MethodCallExpression; + List argValues = []; + foreach (var argument in methodCallExpression.Arguments) + { + argValues.Add(GetExpressionValueAsString(argument)); + } + return methodCallExpression.Method.Name switch + { + "ToString" => $"CONVERT(VARCHAR,{argValues[0]})", + _ => $"{methodCallExpression.Method.Name}({string.Join(",", argValues)})" + }; + } + else + { + var binaryExpression = expression as BinaryExpression; + string leftValue = GetExpressionValueAsString(binaryExpression.Left); + string rightValue = GetExpressionValueAsString(binaryExpression.Right); + string joinValue = expression.NodeType.ToSql(); + + return $"({leftValue} {joinValue} {rightValue})"; + } + } + internal static string ToSqlPredicate2(this Expression expression, params string[] parameters) + { + var sql = ToSqlString(expression.Body); + + for (var i = 0; i < parameters.Length; i++) + sql = sql.Replace($"${expression.Parameters[i].Name!}.", $"{parameters[i]}."); + + return sql; + } + internal static string ToSqlPredicate(this Expression expression, params string[] parameters) + { + var expressionBody = (string)expression.Body.GetPrivateFieldValue("DebugView"); + expressionBody = expressionBody.Replace(System.Environment.NewLine, " "); + var stringBuilder = new StringBuilder(expressionBody); + + int i = 0; + foreach (var expressionParam in expression.Parameters) + { + if (parameters.Length <= i) break; + stringBuilder.Replace((string)expressionParam.GetPrivateFieldValue("DebugView"), parameters[i]); + i++; + } + stringBuilder.Replace("== null", "IS NULL"); + stringBuilder.Replace("!= null", "IS NOT NULL"); + stringBuilder.Replace("&&", "AND"); + stringBuilder.Replace("==", "="); + stringBuilder.Replace("||", "OR"); + stringBuilder.Replace("(System.Nullable`1[System.Int32])", ""); + stringBuilder.Replace("(System.Int32)", ""); + return stringBuilder.ToString(); + } + internal static string ToSqlPredicate(this Expression expression, DbContext dbContext, params string[] paramet +ters) + { + string predicate = expression.ToSqlPredicate(parameters); + return DelimitMemberAccess(dbContext, predicate); + } + internal static string ToSqlUpdateSetExpression(this Expression expression, string tableName) + { + List setValues = []; + var memberInitExpression = expression.Body as MemberInitExpression; + foreach (var binding in memberInitExpression.Bindings) + { + string expValue = binding.ToSql(); + expValue = expValue.Replace($"{expression.Parameters.First().Name}.", ""); + setValues.Add($"[{binding.Member.Name}]={expValue}"); + } + return string.Join(",", setValues); + } + internal static string ToSqlUpdateSetExpression(this Expression expression, DbContext dbContext, string tableN +Name) + { + List setValues = []; + var memberInitExpression = expression.Body as MemberInitExpression; + foreach (var binding in memberInitExpression.Bindings) + { + string expValue = binding.ToSql(); + expValue = expValue.Replace($"{expression.Parameters.First().Name}.", ""); + expValue = DelimitMemberAccess(dbContext, expValue); + setValues.Add($"{dbContext.DelimitIdentifier(binding.Member.Name)}={expValue}"); + } + return string.Join(",", setValues); + } + private static string ToSqlString(Expression expression, string sql = null) + { + sql ??= ""; + if (expression is not BinaryExpression b) + return sql; + + var sb = new StringBuilder(); + if (b.Left is MemberExpression mel) + sb.Append($"${mel} = "); + if (b.Right is MemberExpression mer) + sb.Append($"${mer}"); + + if (b.Left is UnaryExpression ubl) + sb.Append($"${ubl.Operand} = "); + if (b.Right is UnaryExpression ubr) + sb.Append($"${ubr.Operand}"); + + if (sb.Length > 0) + return sb.ToString(); + + var left = ToSqlString(b.Left, sql); + if (string.IsNullOrWhiteSpace(left)) + return sql; + + var right = ToSqlString(b.Right, sql); + return $"{left} AND {right}"; + } + private static string ConvertToSqlValue(object value) + { + if (value == null) + return "NULL"; + if (value is string str) + return $"'{str.Replace("'", "''")}'"; + if (value is Guid guid) + return $"'{guid}'"; + if (value is bool b) + return b ? "1" : "0"; + if (value is DateTime dt) + return $"'{dt:yyyy-MM-ddTHH:mm:ss.fffffff}'"; // Convert to ISO-8601 + if (value is DateTimeOffset dto) + return $"'{dto:yyyy-MM-ddTHH:mm:ss.fffffffzzzz}'"; // Convert to ISO-8601 + var valueType = value.GetType(); + if (valueType.IsEnum) + return Convert.ToString((int)value); + if (!valueType.IsClass) + return Convert.ToString(value, CultureInfo.InvariantCulture); + + throw new NotImplementedException("Unhandled data type."); + } + private static string DelimitMemberAccess(DbContext dbContext, string expression) + { + return Regex.Replace(expression, @"(? + { + string alias = match.Groups[1].Value; + string member = match.Groups[2].Value; + return dbContext.DelimitMemberAccess(alias, member); + }); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Extensions\Obje +ectExtensions.cs --- + +using System; +using System.Reflection; + +namespace N.EntityFrameworkCore.Extensions; + +internal static class ObjectExtensions +{ + internal static object GetPrivateFieldValue(this object obj, string propName) + { + if (obj == null) throw new ArgumentNullException(nameof(obj)); + Type t = obj.GetType(); + FieldInfo fieldInfo = null; + PropertyInfo propertyInfo = null; + while (fieldInfo == null && propertyInfo == null && t != null) + { + fieldInfo = t.GetField(propName, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + if (fieldInfo == null) + { + propertyInfo = t.GetProperty(propName, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Insta +ance); + } + + t = t.BaseType; + } + if (fieldInfo == null && propertyInfo == null) + throw new ArgumentOutOfRangeException(nameof(propName), $"Field {propName} was not found in Type {obj.GetTyp +pe().FullName}"); + + if (fieldInfo != null) + return fieldInfo.GetValue(obj); + + return propertyInfo.GetValue(obj, null); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Extensions\SqlS +StatementExtensions.cs --- + +using System.Collections.Generic; +using N.EntityFrameworkCore.Extensions.Sql; + +namespace N.EntityFrameworkCore.Extensions.Extensions; + +internal static class SqlStatementExtensions +{ + internal static void WriteInsert(this SqlStatement statement, IEnumerable insertColumns) + { + statement.CreatePart(SqlKeyword.Insert, SqlExpression.Columns(insertColumns)); + statement.CreatePart(SqlKeyword.Values, SqlExpression.Columns(insertColumns)); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\GlobalSuppressi +ions.cs --- + +// This file is used by Code Analysis to maintain SuppressMessage +// attributes that are applied to this project. +// Project-level suppressions either have no target or are given +// a specific target and scoped to a namespace, type, member, etc. + +using System.Diagnostics.CodeAnalysis; + +[assembly: SuppressMessage("Usage", "EF1001:Internal EF Core API usage.", Justification = "EntityFrameworkCore Extension +n", Scope = "member", Target = "~M:N.EntityFrameworkCore.Extensions.DbContextExtensions.BulkSaveChanges(Microsoft.EntityF +FrameworkCore.DbContext,System.Boolean)~System.Int32")] +[assembly: SuppressMessage("Usage", "EF1001:Internal EF Core API usage.", Justification = "EntityFrameworkCore Extension +n", Scope = "member", Target = "~M:N.EntityFrameworkCore.Extensions.DbContextExtensions.SetStoreGeneratedValues``1(Micros +soft.EntityFrameworkCore.DbContext,``0,System.Collections.Generic.IEnumerable{Microsoft.EntityFrameworkCore.Metadata.IPro +operty},System.Object[])")] +[assembly: SuppressMessage("Usage", "EF1001:Internal EF Core API usage.", Justification = "EntityFrameworkCore Extension +n", Scope = "member", Target = "~M:N.EntityFrameworkCore.Extensions.DbContextExtensionsAsync.BulkSaveChangesAsync(Microso +oft.EntityFrameworkCore.DbContext,System.Boolean)~System.Threading.Tasks.Task{System.Int32}")] +[assembly: SuppressMessage("Usage", "EF1001:Internal EF Core API usage.", Justification = "EntityFrameworkCore Extension +n", Scope = "member", Target = "~M:N.EntityFrameworkCore.Extensions.TableMapping.GetColumnNames(Microsoft.EntityFramework +kCore.Metadata.IEntityType,System.Boolean)~System.Collections.Generic.IEnumerable{System.String}")] +[assembly: SuppressMessage("Usage", "EF1001:Internal EF Core API usage.", Justification = "EntityFrameworkCore Extension +n", Scope = "member", Target = "~M:N.EntityFrameworkCore.Extensions.EntityDataReader`1.FindEntry(System.Object)~Microsoft +t.EntityFrameworkCore.ChangeTracking.EntityEntry")] + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\N.EntityFramewo +ork.Extensions.PostgreSql.csproj --- + + + + + net10.0 + 10.0.5.1 + N.EntityFramework.Extensions.PostgreSql + true + https://github.com/NorthernLight1/N.EntityFramework.Extensions.PostgreSql/ + Northern25 + Copyright © 2026 + + N.EntityFramework.Extensions.PostgreSql extends your DbContext in EF Core with high-performance bulk op +perations for PostgreSql: BulkDelete, BulkInsert, BulkMerge, BulkSync, BulkUpdate, Fetch, DeleteFromQuery, InsertFromQuer +ry, UpdateFromQuery. + +Inheritance models supported: Table-Per-Concrete, Table-Per-Hierarchy, Table-Per-Type + MIT + README.md + + + + 5 + + + + + True + \ + + + + + + + + + + + + + + + + + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Sql\SqlBuilder. +.cs --- + +using System; +using System.Collections.Generic; +using System.Data; +using System.Linq; +using System.Linq.Expressions; +using Microsoft.Data.SqlClient; + +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal sealed class SqlBuilder +{ + private static readonly string[] keywords = ["DECLARE", "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY"]; + internal string Sql => ToString(); + internal List Clauses { get; private set; } + internal List Parameters { get; private set; } + private SqlBuilder(string sql) + { + Clauses = []; + Parameters = []; + Initialize(sql); + } + + internal string Count() => + $"SELECT COUNT(*) FROM ({string.Join("\r\n", Clauses.Where(o => o.Name != "ORDER BY").Select(o => o.ToString())) +)}) s"; + public override string ToString() => string.Join("\r\n", Clauses.Select(o => o.ToString())); + internal static SqlBuilder Parse(string sql) => new SqlBuilder(sql); + internal string GetTableAlias() + { + var sqlFromClause = Clauses.First(o => o.Name == "FROM"); + var startIndex = sqlFromClause.InputText.LastIndexOf(" AS "); + return startIndex > 0 ? sqlFromClause.InputText[(startIndex + 4)..] : ""; + } + internal void ChangeToDelete() + { + Validate(); + var sqlClause = Clauses.FirstOrDefault(); + var sqlFromClause = Clauses.First(o => o.Name == "FROM"); + if (sqlClause != null) + { + sqlClause.Name = "DELETE"; + int aliasStartIndex = sqlFromClause.InputText.IndexOf("AS ") + 3; + int aliasLength = sqlFromClause.InputText.IndexOf(']', aliasStartIndex) - aliasStartIndex + 1; + sqlClause.InputText = sqlFromClause.InputText[aliasStartIndex..(aliasStartIndex + aliasLength)]; + } + } + internal void ChangeToUpdate(string updateExpression, string setExpression) + { + Validate(); + var sqlClause = Clauses.FirstOrDefault(); + if (sqlClause != null) + { + sqlClause.Name = "UPDATE"; + sqlClause.InputText = updateExpression; + Clauses.Insert(1, new SqlClause { Name = "SET", InputText = setExpression }); + } + } + internal void ChangeToInsert(string tableName, Expression> insertObjectExpression) + { + Validate(); + var sqlSelectClause = Clauses.FirstOrDefault(); + string columnsToInsert = string.Join(",", insertObjectExpression.GetObjectProperties()); + string insertValueExpression = $"INTO {tableName} ({columnsToInsert})"; + Clauses.Insert(0, new SqlClause { Name = "INSERT", InputText = insertValueExpression }); + sqlSelectClause.InputText = columnsToInsert; + } + internal void SelectColumns(IEnumerable columns) + { + var tableAlias = GetTableAlias(); + var sqlClause = Clauses.FirstOrDefault(); + if (sqlClause.Name == "SELECT") + { + sqlClause.InputText = string.Join(",", columns.Select(c => $"{tableAlias}.{c}")); + } + } + private void Initialize(string sqlText) + { + string curClause = string.Empty; + int curClauseIndex = 0; + for (int i = 0; i < sqlText.Length;) + { + string keyword = StartsWithString(sqlText.AsSpan(i), keywords, StringComparison.OrdinalIgnoreCase); + bool isWordStart = i == 0 || sqlText[i - 1] == ' ' || (i > 1 && sqlText[i - 2] == '\r' && sqlText[i - 1] == + '\n'); + if (keyword != null && isWordStart) + { + string inputText = sqlText[curClauseIndex..i]; + if (!string.IsNullOrEmpty(curClause)) + { + if (curClause == "DECLARE") + { + var declareParts = inputText[..inputText.IndexOf(';')].Trim().Split(' '); + int sizeStartIndex = declareParts[1].IndexOf('('); + int sizeLength = declareParts[1].IndexOf(')') - (sizeStartIndex + 1); + string dbTypeString = sizeStartIndex != -1 ? declareParts[1][..sizeStartIndex] : declareParts[1] +]; + SqlDbType dbType = (SqlDbType)Enum.Parse(typeof(SqlDbType), dbTypeString, true); + int size = sizeStartIndex != -1 ? + Convert.ToInt32(declareParts[1][(sizeStartIndex + 1)..(sizeStartIndex + 1 + sizeLength)]) : + 0; + string value = GetDeclareValue(declareParts[3]); + Parameters.Add(new SqlParameter(declareParts[0], dbType, size) { Value = value }); + } + else + { + Clauses.Add(SqlClause.Parse(curClause, inputText)); + } + } + curClause = keyword; + curClauseIndex = i + curClause.Length; + i = i + curClause.Length; + } + else + { + i++; + } + } + if (!string.IsNullOrEmpty(curClause)) + Clauses.Add(SqlClause.Parse(curClause, sqlText[curClauseIndex..])); + } + private string GetDeclareValue(string value) + { + if (value.StartsWith('\'')) + { + return value[1..^1]; + } + else if (value.StartsWith("N'")) + { + return value[2..^1]; + } + else if (value.StartsWith("CAST(")) + { + return value[5..]; + } + else + { + return value; + } + } + private static string StartsWithString(ReadOnlySpan textToSearch, string[] valuesToFind, StringComparison stri +ingComparison) + { + foreach (var valueToFind in valuesToFind) + { + if (textToSearch.StartsWith(valueToFind, stringComparison)) + return valueToFind; + } + + return null; + } + private void Validate() + { + if (Clauses.Count == 0) + { + throw new Exception("You must parse a valid sql statement before you can use this function."); + } + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Sql\SqlClause.c +cs --- + +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal sealed class SqlClause +{ + internal string Name { get; set; } + internal string InputText { get; set; } + internal string Sql => ToString(); + internal static SqlClause Parse(string name, string inputText) + { + string cleanText = inputText.Replace("\r\n", "").Trim(); + return new SqlClause { Name = name, InputText = cleanText }; + } + public override string ToString() => $"{Name} {InputText}"; +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Sql\SqlExpressi +ion.cs --- + +using System.Collections.Generic; +using System.Linq; +using System.Text; +using N.EntityFrameworkCore.Extensions.Util; + +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal sealed class SqlExpression +{ + internal SqlExpressionType ExpressionType { get; } + List Items { get; set; } + internal string Sql => ToSql(); + string Alias { get; } + internal bool IsEmpty => Items.Count == 0; + + SqlExpression(SqlExpressionType expressionType, object item, string alias = null) + { + ExpressionType = expressionType; + Items = []; + if (item is IEnumerable values) + { + Items.AddRange(values.ToArray()); + } + else + { + Items.Add(item); + } + Alias = alias; + } + SqlExpression(SqlExpressionType expressionType, object[] items, string alias = null) + { + ExpressionType = expressionType; + Items = []; + Items.AddRange(items); + Alias = alias; + } + internal static SqlExpression Columns(IEnumerable columns) => + new SqlExpression(SqlExpressionType.Columns, columns); + + internal static SqlExpression Set(IEnumerable columns) => + new SqlExpression(SqlExpressionType.Set, columns); + + internal static SqlExpression String(string joinOnCondition) => + new SqlExpression(SqlExpressionType.String, joinOnCondition); + + internal static SqlExpression Table(string tableName, string alias = null) => + new SqlExpression(SqlExpressionType.Table, Util.CommonUtil.FormatTableName(tableName), alias); + + private string ToSql() + { + var sbSql = new StringBuilder(); + if (ExpressionType == SqlExpressionType.Columns) + { + var values = Items.Where(o => o != null).Select(o => o.ToString()).Where(o => !string.IsNullOrWhiteSpace(o)) +).ToArray(); + sbSql.Append(string.Join(",", CommonUtil.FormatColumns(values))); + } + else + { + sbSql.Append(string.Join(",", Items.Where(o => o != null).Select(o => o.ToString()).Where(o => !string.IsNul +llOrWhiteSpace(o)))); + } + if (Alias != null) + { + sbSql.Append(" "); + sbSql.Append(SqlKeyword.As.ToString().ToUpper()); + sbSql.Append(" "); + sbSql.Append(Alias); + } + return sbSql.ToString(); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Sql\SqlExpressi +ionType.cs --- + +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal enum SqlExpressionType +{ + String, + Table, + Columns, + Set +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Sql\SqlKeyword. +.cs --- + +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal enum SqlKeyword +{ + Select, + Delete, + Insert, + Values, + Update, + Set, + Merge, + Into, + From, + On, + Where, + Using, + When, + Then, + Matched, + Not, + Output, + As, + By, + Source, + Target, + Off, + Identity_Insert, + Semicolon, +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Sql\SqlPart.cs + --- + +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal sealed class SqlPart +{ + internal SqlKeyword Keyword { get; } + internal SqlExpression Expression { get; } + internal bool IgnoreOutput => GetIgnoreOutput(); + internal SqlPart(SqlKeyword keyword, SqlExpression expression) + { + Keyword = keyword; + Expression = expression; + } + private bool GetIgnoreOutput() => Keyword == SqlKeyword.Output && (Expression == null || Expression.IsEmpty); +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Sql\SqlStatemen +nt.cs --- + +using System.Collections.Generic; +using System.Linq; +using System.Text; +using N.EntityFrameworkCore.Extensions.Extensions; + +namespace N.EntityFrameworkCore.Extensions.Sql; + +internal sealed class SqlStatement +{ + internal string Sql => ToSql(); + List SqlParts { get; } + SqlStatement() + { + SqlParts = []; + } + internal void CreatePart(SqlKeyword keyword, SqlExpression expression = null) => + SqlParts.Add(new SqlPart(keyword, expression)); + internal void SetIdentityInsert(string tableName, bool enable) + { + CreatePart(SqlKeyword.Set); + CreatePart(SqlKeyword.Identity_Insert, SqlExpression.Table(tableName)); + if (enable) + CreatePart(SqlKeyword.On); + else + CreatePart(SqlKeyword.Off); + CreatePart(SqlKeyword.Semicolon); + } + internal static SqlStatement CreateMerge(string sourceTableName, string targetTableName, string joinOnCondition, + IEnumerable insertColumns, IEnumerable updateColumns, IEnumerable outputColumns, + bool deleteIfNotMatched = false, bool hasIdentityColumn = false) + { + var statement = new SqlStatement(); + if (hasIdentityColumn) + statement.SetIdentityInsert(targetTableName, true); + statement.CreatePart(SqlKeyword.Merge, SqlExpression.Table(targetTableName, "t")); + statement.CreatePart(SqlKeyword.Using, SqlExpression.Table(sourceTableName, "s")); + statement.CreatePart(SqlKeyword.On, SqlExpression.String(joinOnCondition)); + statement.CreatePart(SqlKeyword.When); + statement.CreatePart(SqlKeyword.Not); + statement.CreatePart(SqlKeyword.Matched); + statement.CreatePart(SqlKeyword.Then); + statement.WriteInsert(insertColumns); + if (updateColumns.Any()) + { + var updateSetColumns = updateColumns.Select(c => $"t.[{c}]=s.[{c}]"); + statement.CreatePart(SqlKeyword.When); + statement.CreatePart(SqlKeyword.Matched); + statement.CreatePart(SqlKeyword.Then); + statement.CreatePart(SqlKeyword.Update); + statement.CreatePart(SqlKeyword.Set, SqlExpression.Set(updateSetColumns)); + } + if (deleteIfNotMatched) + { + statement.CreatePart(SqlKeyword.When); + statement.CreatePart(SqlKeyword.Not); + statement.CreatePart(SqlKeyword.Matched); + statement.CreatePart(SqlKeyword.By); + statement.CreatePart(SqlKeyword.Source); + statement.CreatePart(SqlKeyword.Then); + statement.CreatePart(SqlKeyword.Delete); + } + if (outputColumns.Any()) + statement.CreatePart(SqlKeyword.Output, SqlExpression.Columns(outputColumns)); + statement.CreatePart(SqlKeyword.Semicolon); + + if (hasIdentityColumn) + statement.SetIdentityInsert(targetTableName, false); + return statement; + } + + private string ToSql() + { + var sbSql = new StringBuilder(); + foreach (var part in SqlParts) + { + if (part.Keyword == SqlKeyword.Semicolon) + { + int lastIndex = sbSql.Length - 1; + if (lastIndex > -1 && sbSql[lastIndex] == ' ') + { + sbSql[lastIndex] = ';'; + sbSql.Append("\n"); + } + else + { + sbSql.Append(";\n"); + } + } + else if (!part.IgnoreOutput) + { + sbSql.Append(part.Keyword.ToString().ToUpper()); + sbSql.Append(" "); + bool useParenthese = part.Keyword == SqlKeyword.Insert || part.Keyword == SqlKeyword.Values; + + if (part.Expression != null) + { + string expressionSql = useParenthese ? $"({part.Expression.Sql})" : part.Expression.Sql; + sbSql.Append(expressionSql); + sbSql.Append(" "); + } + } + } + return sbSql.ToString(); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Util\CommonUtil +l.cs --- + +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.IO; +using System.Linq; +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; + +namespace N.EntityFrameworkCore.Extensions.Util; + +internal static class CommonUtil +{ + internal static string GetStagingTableName(TableMapping tableMapping, bool usePermanentTable, DbConnection dbConnect +tion) + { + string uniqueSuffix = Guid.NewGuid().ToString("N"); + if (usePermanentTable) + return tableMapping.DbContext.Database.GetPermanentStagingTableName(tableMapping.Schema, tableMapping.TableN +Name, uniqueSuffix); + return tableMapping.DbContext.Database.GetTemporaryTableName(tableMapping.TableName); + } + internal static IEnumerable FormatColumns(DbContext dbContext, IEnumerable columns) + { + return columns.Select(s => FormatColumn(dbContext, s)); + } + internal static IEnumerable FormatColumns(IEnumerable columns) + { + return columns.Select(FormatColumnLegacy); + } + internal static IEnumerable FormatColumns(DbContext dbContext, string tableAlias, IEnumerable column +ns) + { + return columns.Select(s => dbContext.DelimitMemberAccess(tableAlias, RemoveQualifier(s))); + } + internal static IEnumerable FormatColumns(DatabaseFacade database, string tableAlias, IEnumerable co +olumns) + { + return columns.Select(s => database.DelimitMemberAccess(tableAlias, RemoveQualifier(s))); + } + internal static IEnumerable FormatColumns(string tableAlias, IEnumerable columns) + { + return columns.Select(s => s.StartsWith('[') && s.EndsWith(']') ? $"[{tableAlias}].{s}" : $"[{tableAlias}].[{s}] +]"); + } + internal static IEnumerable FilterColumns(IEnumerable columnNames, string[] primaryKeyColumnNames +s, Expression> inputColumns, Expression> ignoreColumns) + { + var filteredColumnNames = columnNames; + if (inputColumns != null) + { + var inputColumnNames = inputColumns.GetObjectProperties(); + filteredColumnNames = filteredColumnNames.Intersect(inputColumnNames.Union(primaryKeyColumnNames)); + } + if (ignoreColumns != null) + { + var ignoreColumnNames = ignoreColumns.GetObjectProperties(); + if (ignoreColumnNames.Intersect(primaryKeyColumnNames).Any()) + { + throw new InvalidDataException("Primary key columns can not be ignored in BulkInsertOptions.IgnoreColumn +ns"); + } + else + { + filteredColumnNames = filteredColumnNames.Except(ignoreColumnNames); + } + } + return filteredColumnNames; + } + internal static string FormatTableName(DatabaseFacade database, string tableName) + { + return database.DelimitTableName(tableName); + } + internal static string FormatTableName(string tableName) + { + return string.Join(".", tableName.Split('.').Select(s => $"[{RemoveQualifier(s)}]")); + } + private static string FormatColumn(DbContext dbContext, string column) + { + var parts = column.Split('.'); + return string.Join(".", parts.Select(p => p.StartsWith('$') ? p : dbContext.DelimitIdentifier(RemoveQualifier(p) +)))); + } + private static string FormatColumnLegacy(string column) + { + var parts = column.Split('.'); + return string.Join(".", parts.Select(p => p.StartsWith('$') || (p.StartsWith('[') && p.EndsWith(']')) ? p : $"[{ +{p}]")); + } + private static string RemoveQualifier(string name) + { + return name.TrimStart('[').TrimEnd(']').Trim('"'); + } +} +internal static class CommonUtil +{ + internal static string[] GetColumns(Expression> expression, string[] tableNames) + { + List foundColumns = []; + string sqlText = (string)expression.Body.GetPrivateFieldValue("DebugView"); + var sqlSpan = sqlText.AsSpan(); + + int offset = 0; + while (offset < sqlSpan.Length) + { + int startIndex = sqlSpan[offset..].IndexOf('$'); + if (startIndex == -1) break; + startIndex += offset; + + var remaining = sqlSpan[startIndex..]; + int spaceIndex = remaining.IndexOf(' '); + var columnSpan = spaceIndex == -1 ? remaining : remaining[..spaceIndex]; + + int dotIndex = columnSpan.IndexOf('.'); + if (dotIndex >= 0) + { + var tablePart = columnSpan[1..dotIndex]; // skip leading '$' + var columnPart = columnSpan[(dotIndex + 1)..]; + if (tableNames == null || tableNames.Contains(tablePart.ToString())) + { + foundColumns.Add(columnPart.ToString()); + } + } + + offset = startIndex + 1; + } + + return foundColumns.ToArray(); + } + internal static string GetJoinConditionSql(Expression> joinKeyExpression, string[] storeGeneratedCo +olumnNames, string sourceTableName = "s", string targetTableName = "t") + { + if (joinKeyExpression != null) + return joinKeyExpression.ToSqlPredicate(sourceTableName, targetTableName); + + return string.Join(" AND ", storeGeneratedColumnNames.Select(c => $"{sourceTableName}.[{c}]={targetTableName}.[{ +{c}]")); + } + internal static string GetJoinConditionSql(DbContext dbContext, Expression> joinKeyExpression, stri +ing[] storeGeneratedColumnNames, string sourceTableName = "s", string targetTableName = "t") + { + if (joinKeyExpression != null) + return joinKeyExpression.ToSqlPredicate(dbContext, sourceTableName, targetTableName); + + return string.Join(" AND ", storeGeneratedColumnNames.Select(c => $"{dbContext.DelimitMemberAccess(sourceTableNa +ame, c)}={dbContext.DelimitMemberAccess(targetTableName, c)}")); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Util\Relational +lProviderUtil.cs --- + +using System; +using System.Data.Common; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Storage; +using Npgsql; + +namespace N.EntityFrameworkCore.Extensions.Util; + +internal enum DatabaseProvider +{ + SqlServer, + PostgreSql +} + +internal readonly record struct DatabaseObjectName(string Schema, string Name) +{ + internal bool HasSchema => !string.IsNullOrWhiteSpace(Schema); +} + +internal static class RelationalProviderUtil +{ + internal static DatabaseProvider GetDatabaseProvider(this DatabaseFacade database) + { + return database.ProviderName switch + { + string providerName when providerName.Contains("SqlServer", StringComparison.OrdinalIgnoreCase) => DatabaseP +Provider.SqlServer, + string providerName when providerName.Contains("Npgsql", StringComparison.OrdinalIgnoreCase) => DatabaseProv +vider.PostgreSql, + _ => throw new NotSupportedException($"The database provider '{database.ProviderName}' is not supported.") + }; + } + + internal static bool IsSqlServer(this DatabaseFacade database) => database.GetDatabaseProvider() == DatabaseProvider +r.SqlServer; + + internal static bool IsPostgreSql(this DatabaseFacade database) => database.GetDatabaseProvider() == DatabaseProvide +er.PostgreSql; + + internal static string GetDefaultSchema(this DatabaseFacade database) => + database.IsPostgreSql() ? "public" : "dbo"; + + internal static string DelimitIdentifier(this DatabaseFacade database, string identifier) => + database.GetSqlGenerationHelper().DelimitIdentifier(UnwrapIdentifier(identifier)); + + internal static string DelimitIdentifier(this DatabaseFacade database, string identifier, string schema) => + schema == null + ? database.DelimitIdentifier(identifier) + : database.GetSqlGenerationHelper().DelimitIdentifier(UnwrapIdentifier(identifier), UnwrapIdentifier(schema) +)); + + internal static string DelimitIdentifier(this DbContext dbContext, string identifier) => + dbContext.Database.DelimitIdentifier(identifier); + + internal static string DelimitIdentifier(this DbContext dbContext, string identifier, string schema) => + dbContext.Database.DelimitIdentifier(identifier, schema); + + internal static string DelimitTableName(this DatabaseFacade database, string tableName) + { + var objectName = database.ParseObjectName(tableName); + return objectName.HasSchema + ? database.DelimitIdentifier(objectName.Name, objectName.Schema) + : database.DelimitIdentifier(objectName.Name); + } + + internal static string DelimitTableName(this DbContext dbContext, string tableName) => + dbContext.Database.DelimitTableName(tableName); + + internal static string DelimitMemberAccess(this DbContext dbContext, string alias, string columnName) => + $"{dbContext.DelimitIdentifier(alias)}.{dbContext.DelimitIdentifier(columnName)}"; + + internal static string DelimitMemberAccess(this DatabaseFacade database, string alias, string columnName) => + $"{database.DelimitIdentifier(alias)}.{database.DelimitIdentifier(columnName)}"; + + internal static DatabaseObjectName ParseObjectName(this DatabaseFacade database, string objectName) + { + string normalized = objectName.Trim(); + if (string.IsNullOrWhiteSpace(normalized)) + throw new ArgumentException("Object name cannot be empty.", nameof(objectName)); + + var parts = normalized.Split('.', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + return parts.Length switch + { + 1 => new DatabaseObjectName(IsTemporaryName(parts[0]) ? null : database.GetDefaultSchema(), UnwrapIdentifier +r(parts[0])), + 2 => new DatabaseObjectName(UnwrapIdentifier(parts[0]), UnwrapIdentifier(parts[1])), + _ => throw new InvalidOperationException($"Unsupported object name format '{objectName}'.") + }; + } + + internal static string UnwrapIdentifier(string value) => + value.Trim().Trim('[', ']', '"'); + + internal static string GetTemporaryTableName(this DatabaseFacade database, string baseName) + { + string temporaryName = $"tmp_be_xx_{UnwrapIdentifier(baseName)}_{Guid.NewGuid():N}"; + return database.DelimitIdentifier(temporaryName); + } + + internal static string GetPermanentStagingTableName(this DatabaseFacade database, string schema, string tableName, s +string uniqueSuffix) + { + string stagingName = $"tmp_be_xx_{UnwrapIdentifier(tableName)}_{uniqueSuffix}"; + return database.DelimitIdentifier(stagingName, schema); + } + + internal static DbConnection CloneConnection(this DbConnection dbConnection) => + dbConnection switch + { + ICloneable cloneable => (DbConnection)cloneable.Clone(), + _ => throw new NotSupportedException($"Connection type '{dbConnection.GetType().FullName}' does not support + cloning.") + }; + + private static ISqlGenerationHelper GetSqlGenerationHelper(this DatabaseFacade database) => + ((IInfrastructure)database).Instance.GetService(typeof(ISqlGenerationHelper)) as ISqlGeneratio +onHelper + ?? throw new InvalidOperationException("Unable to resolve ISqlGenerationHelper."); + + private static bool IsTemporaryName(string objectName) => + UnwrapIdentifier(objectName).StartsWith("#", StringComparison.Ordinal); +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFramework.Extensions.PostgreSql\Util\SqlUtil.cs +s --- + +using System.Collections.Generic; + +namespace N.EntityFrameworkCore.Extensions; + +internal static class SqlUtil +{ + internal static string ConvertToColumnString(IEnumerable columnNames) + { + return string.Join(",", columnNames); + } +} + +=== DIRECTORY: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFram +mework.Extensions.SqlServer.Test === + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\appsettings.json --- + +{ + "ConnectionStrings": { + "SqlServerTestDatabase": "Server=(localdb)\\mssqllocaldb;Database=N.EntityFrameworkCore.Test;Trusted_Connection=True +e;" + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Common\Config.cs --- + +using System; +using System.Data.Common; +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.Configuration; + +namespace N.EntityFrameworkCore.Extensions.Test.Common; + +public class Config +{ + private static readonly IConfigurationRoot configuration = new ConfigurationBuilder() + .AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .Build(); + + public static string GetConnectionString(string name) + { + return configuration.GetConnectionString(name); + } + public static bool IsSqlServer => true; + public static string GetTestDatabaseConnectionString() => GetConnectionString("SqlServerTestDatabase"); + public static DbParameter CreateParameter(string name, object value) => new SqlParameter(name, value ?? DBNull.Value +e); + public static string DelimitIdentifier(string identifier) => $"[{identifier}]"; + public static string DelimitTableName(string tableName) => tableName; + public static bool IsPrimaryKeyViolation(Exception exception) => + exception.Message.StartsWith("Violation of PRIMARY KEY constraint 'PK_Orders'.", StringComparison.Ordinal); +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Common\TestDatabaseInitializer.cs --- + +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.Common; + +internal static class TestDatabaseInitializer +{ + internal static void EnsureCreated(TestDbContext dbContext) + { + dbContext.Database.EnsureCreated(); + CreateSqlServerObjects(dbContext); + } + + internal static async Task EnsureCreatedAsync(TestDbContext dbContext) + { + await dbContext.Database.EnsureCreatedAsync(); + await CreateSqlServerObjectsAsync(dbContext); + } + + internal static void CreateSqlServerObjects(TestDbContext dbContext) + { + dbContext.Database.ExecuteSqlRaw(""" + CREATE OR ALTER TRIGGER trgProductWithTriggers + ON ProductsWithTrigger + FOR INSERT, UPDATE, DELETE + AS + BEGIN + PRINT 1 + END + """); + } + + internal static async Task CreateSqlServerObjectsAsync(TestDbContext dbContext) + { + await dbContext.Database.ExecuteSqlRawAsync(""" + CREATE OR ALTER TRIGGER trgProductWithTriggers + ON ProductsWithTrigger + FOR INSERT, UPDATE, DELETE + AS + BEGIN + PRINT 1 + END + """); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\Address.cs --- + +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations.Schema; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +[ComplexType] +public class Address +{ + public required string Line1 { get; set; } + public string? Line2 { get; set; } + public required string City { get; set; } + public required string Country { get; set; } + public required string PostCode { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\Enums\ProductStatus.cs --- + +namespace N.EntityFrameworkCore.Extensions.Test.Data.Enums; + +public enum ProductStatus +{ + InStock, + OutOfStock, +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\Order.cs --- + +using System; +using System.ComponentModel.DataAnnotations; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class Order +{ + [Key] + public long Id { get; set; } + public string ExternalId { get; set; } + public Guid? GlobalId { get; set; } + public decimal Price { get; set; } + public DateTime AddedDateTime { get; set; } + public DateTime? ModifiedDateTime { get; set; } + public DateTimeOffset? ModifiedDateTimeOffset { get; set; } + public bool DbActive { get; set; } + public DateTime DbAddedDateTime { get; set; } + public DateTime DbModifiedDateTime { get; set; } + public bool? Trigger { get; set; } + public bool Active { get; set; } + public OrderStatus Status { get; set; } + public Order() + { + AddedDateTime = DateTime.UtcNow; + Active = true; + } +} + +public enum OrderStatus +{ + Unknown, + Completed, + Error +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\OrderWithComplexType.cs --- + +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class OrderWithComplexType +{ + [Key] + public long Id { get; set; } + [Required] + public Address ShippingAddress { get; set; } + [Required] + public Address BillingAddress { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\Position.cs --- + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class Position +{ + public int Building; + public int Aisle; + public int Bay; +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\Product.cs --- + +using System; +using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations.Schema; +using N.EntityFrameworkCore.Extensions.Test.Data.Enums; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class Product +{ + [Key] + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public string Id { get; set; } + [StringLength(50)] + public string Name { get; set; } + public decimal Price { get; set; } + public bool OutOfStock { get; set; } + [Column("Status")] + [StringLength(25)] + public string StatusString { get; set; } + public int? ProductCategoryId { get; set; } + public System.Drawing.Color Color { get; set; } + public ProductStatus? StatusEnum { get; set; } + public DateTime? UpdatedDateTime { get; set; } + + public Position Position { get; set; } + + public virtual ProductCategory ProductCategory { get; set; } + public Product() + { + + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\ProductCategory.cs --- + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class ProductCategory +{ + public int Id { get; set; } + public string Name { get; set; } + public bool Active { get; internal set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\ProductWithComplexKey.cs --- + +using System; +using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class ProductWithComplexKey +{ + public Guid Key1 { get; set; } + public Guid Key2 { get; set; } + public Guid Key3 { get; set; } + public Guid Key4 { get; set; } + public string ExternalId { get; set; } + public decimal Price { get; set; } + public bool OutOfStock { get; set; } + [Column("Status")] + [StringLength(25)] + public string StatusString { get; set; } + public DateTime? UpdatedDateTime { get; set; } + public ProductWithComplexKey() + { + Key3 = Guid.NewGuid(); + Key4 = Guid.NewGuid(); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\ProductWithCustomSchema.cs --- + +using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class ProductWithCustomSchema +{ + [Key] + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public string Id { get; set; } + [StringLength(50)] + public string Name { get; set; } + public decimal Price { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\ProductWithTrigger.cs --- + +using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class ProductWithTrigger +{ + [Key] + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public string Id { get; set; } + [StringLength(50)] + public string Name { get; set; } + public decimal Price { get; set; } + public bool OutOfStock { get; set; } + [Column("Status")] + [StringLength(25)] + public string StatusString { get; set; } + public ProductWithTrigger() + { + + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\TestDbContext.cs --- + +using System; +using System.Drawing; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Diagnostics; +using N.EntityFrameworkCore.Extensions.Test.Common; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TestDbContext : DbContext +{ + public virtual DbSet Products { get; set; } + public virtual DbSet ProductCategories { get; set; } + public virtual DbSet ProductsWithCustomSchema { get; set; } + public virtual DbSet ProductsWithComplexKey { get; set; } + public virtual DbSet ProductsWithTrigger { get; set; } + public virtual DbSet Orders { get; set; } + public virtual DbSet OrdersWithComplexType { get; set; } + public virtual DbSet TpcPeople { get; set; } + public virtual DbSet TphPeople { get; set; } + public virtual DbSet TphCustomers { get; set; } + public virtual DbSet TphVendors { get; set; } + public virtual DbSet TptPeople { get; set; } + public virtual DbSet TptCustomers { get; set; } + public virtual DbSet TptVendors { get; set; } + + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) + { + optionsBuilder.UseSqlServer(Config.GetTestDatabaseConnectionString()); + optionsBuilder.SetupEfCoreExtensions(); + optionsBuilder.UseLazyLoadingProxies(); + // Tell EF Core to allow mismatched models for this test run + optionsBuilder.ConfigureWarnings(warnings => + warnings.Ignore(RelationalEventId.PendingModelChangesWarning)); + } + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity().ToTable("Product", "top"); + modelBuilder.Entity().HasKey(c => new { c.Key1 }); + modelBuilder.Entity().Property("Key1").HasDefaultValueSql("newsequentialid()"); + modelBuilder.Entity().Property("Key2").HasDefaultValueSql("newsequentialid()"); + modelBuilder.Entity().HasKey(p => new { p.Key3, p.Key4 }); + modelBuilder.Entity().Property("DbAddedDateTime").HasDefaultValueSql("getdate()"); + modelBuilder.Entity().Property("DbModifiedDateTime").HasComputedColumnSql("getdate()"); + modelBuilder.Entity().Property(p => p.DbActive).HasDefaultValueSql("((1))"); + modelBuilder.Entity().Property(p => p.Status).HasConversion(); + modelBuilder.Entity(b => + { + b.ComplexProperty(e => e.BillingAddress); + b.ComplexProperty(e => e.ShippingAddress); + }); + modelBuilder.Entity().UseTpcMappingStrategy(); + modelBuilder.Entity().ToTable("TpcCustomer"); + modelBuilder.Entity().ToTable("TpcVendor"); + modelBuilder.Entity().Property("CreatedDate"); + modelBuilder.Entity().ToTable("TptPeople"); + modelBuilder.Entity().ToTable("TptCustomer"); + modelBuilder.Entity().ToTable("TptVendor"); + modelBuilder.Entity(t => + { + t.ComplexProperty(p => p.Position).IsRequired(); + t.Property(p => p.Color).HasConversion(x => x.ToArgb(), x => Color.FromArgb(x)); + }); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\TpcCustomer.cs --- + +using System; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TpcCustomer : TpcPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public DateTime AddedDate { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\TpcPerson.cs --- + +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public abstract class TpcPerson +{ + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public long Id { get; set; } + public string FirstName { get; set; } + public string LastName { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\TpcVendor.cs --- + + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TpcVendor : TpcPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public string Url { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\TphCustomer.cs --- + +using System; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TphCustomer : TphPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public DateTime AddedDate { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\TphPerson.cs --- + +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +[Table("TphPeople")] +public abstract class TphPerson +{ + public long Id { get; set; } + public string FirstName { get; set; } + public string LastName { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\TphVendor.cs --- + + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TphVendor : TphPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public string Url { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\TptCustomer.cs --- + +using System; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TptCustomer : TptPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public DateTime AddedDate { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\TptPerson.cs --- + +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TptPerson +{ + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public long Id { get; set; } + public string FirstName { get; set; } + public string LastName { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Data\TptVendor.cs --- + + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TptVendor : TptPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public string Url { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DatabaseExtensions\DatabaseExtensionsBase.cs --- + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using N.EntityFrameworkCore.Extensions.Test.Common; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +public class DatabaseExtensionsBase +{ + protected TestDbContext SetupDbContext(bool populateData) + { + TestDbContext dbContext = new TestDbContext(); + TestDatabaseInitializer.EnsureCreated(dbContext); + dbContext.Orders.Truncate(); + if (populateData) + { + var orders = new List(); + int id = 1; + for (int i = 0; i < 2050; i++) + { + DateTime addedDateTime = DateTime.UtcNow.AddDays(-id); + orders.Add(new Order + { + Id = id, + ExternalId = string.Format("id-{0}", i), + Price = 1.25M, + AddedDateTime = addedDateTime, + ModifiedDateTime = addedDateTime.AddHours(3) + }); + id++; + } + for (int i = 0; i < 1050; i++) + { + orders.Add(new Order { Id = id, Price = 5.35M }); + id++; + } + for (int i = 0; i < 2050; i++) + { + orders.Add(new Order { Id = id, Price = 1.25M }); + id++; + } + for (int i = 0; i < 6000; i++) + { + orders.Add(new Order { Id = id, Price = 15.35M }); + id++; + } + for (int i = 0; i < 6000; i++) + { + orders.Add(new Order { Id = id, Price = 15.35M }); + id++; + } + + Debug.WriteLine("Last Id for Order is {0}", id); + dbContext.BulkInsert(orders, new BulkInsertOptions() { KeepIdentity = true }); + } + return dbContext; + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DatabaseExtensions\SqlQuery_Count.cs --- + +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class SqlQuery_Count : DatabaseExtensionsBase +{ + [TestMethod] + public void With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + int efCount = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Pr +rice"; + var sqlCount = dbContext.Database.FromSqlQuery(sql, Config.CreateParameter("@Price", 5M)).Count(); + + Assert.IsTrue(efCount > 0, "Count from EF should be greater than zero"); + Assert.IsTrue(efCount > 0, "Count from SQL should be greater than zero"); + Assert.IsTrue(efCount == sqlCount, "Count from EF should match the count from the SqlQuery"); + } + [TestMethod] + public void With_OrderBy() + { + var dbContext = SetupDbContext(true); + int efCount = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Pr +rice ORDER BY {Config.DelimitIdentifier("Id")}"; + var sqlCount = dbContext.Database.FromSqlQuery(sql, Config.CreateParameter("@Price", 5M)).Count(); + + Assert.IsTrue(efCount > 0, "Count from EF should be greater than zero"); + Assert.IsTrue(efCount > 0, "Count from SQL should be greater than zero"); + Assert.IsTrue(efCount == sqlCount, "Count from EF should match the count from the SqlQuery"); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DatabaseExtensions\SqlQuery_CountAsync.cs --- + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class SqlQuery_CountAsync : DatabaseExtensionsBase +{ + [TestMethod] + public async Task With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + int efCount = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Pr +rice"; + var sqlCount = await dbContext.Database.FromSqlQuery(sql, Config.CreateParameter("@Price", 5M)).CountAsync(); + + Assert.IsTrue(efCount > 0, "Count from EF should be greater than zero"); + Assert.IsTrue(efCount > 0, "Count from SQL should be greater than zero"); + Assert.IsTrue(efCount == sqlCount, "Count from EF should match the count from the SqlQuery"); + } + [TestMethod] + public async Task With_OrderBy() + { + var dbContext = SetupDbContext(true); + int efCount = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Pr +rice ORDER BY {Config.DelimitIdentifier("Id")}"; + var sqlCount = await dbContext.Database.FromSqlQuery(sql, Config.CreateParameter("@Price", 5M)).CountAsync(); + + Assert.IsTrue(efCount > 0, "Count from EF should be greater than zero"); + Assert.IsTrue(efCount > 0, "Count from SQL should be greater than zero"); + Assert.IsTrue(efCount == sqlCount, "Count from EF should match the count from the SqlQuery"); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DatabaseExtensions\SqlQueryToCsvFile.cs --- + +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class SqlQueryToCsvFile : DatabaseExtensionsBase +{ + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + int count = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Pr +rice"; + var queryToCsvFileResult = dbContext.Database.SqlQueryToCsvFile("SqlQueryToCsvFile-Test.csv", sql, Config.Create +eParameter("@Price", 5M)); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file sho +ould match the count from the database plus the header row"); + } + [TestMethod] + public void With_Options_ColumnDelimiter_TextQualifer() + { + var dbContext = SetupDbContext(true); + string filePath = "SqlQueryToCsvFile_Options_ColumnDelimiter_TextQualifer-Test.csv"; + int count = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Pr +rice"; + var queryToCsvFileResult = dbContext.Database.SqlQueryToCsvFile(filePath, options => { options.ColumnDelimiter = += "|"; options.TextQualifer = "\""; }, + sql, Config.CreateParameter("@Price", 5M)); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file sho +ould match the count from the database plus the header row"); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DatabaseExtensions\SqlQueryToCsvFileAsync.cs --- + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class SqlQueryToCsvFileAsync : DatabaseExtensionsBase +{ + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + int count = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Pr +rice"; + var queryToCsvFileResult = await dbContext.Database.SqlQueryToCsvFileAsync("SqlQueryToCsvFile-Test.csv", sql, ne +ew object[] { Config.CreateParameter("@Price", 5M) }); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file sho +ould match the count from the database plus the header row"); + } + [TestMethod] + public async Task With_Options_ColumnDelimiter_TextQualifer() + { + var dbContext = SetupDbContext(true); + string filePath = "SqlQueryToCsvFile_Options_ColumnDelimiter_TextQualifer-Test.csv"; + int count = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Pr +rice"; + var queryToCsvFileResult = await dbContext.Database.SqlQueryToCsvFileAsync(filePath, options => { options.Column +nDelimiter = "|"; options.TextQualifer = "\""; }, + sql, new object[] { Config.CreateParameter("@Price", 5M) }); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file sho +ould match the count from the database plus the header row"); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DatabaseExtensions\TableExists.cs --- + +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class TableExists : DatabaseExtensionsBase +{ + [TestMethod] + public void With_Orders_Table() + { + var dbContext = SetupDbContext(true); + int efCount = dbContext.Orders.Where(o => o.Price > 5M).Count(); + bool ordersTableExists = dbContext.Database.TableExists("Orders"); + bool orderNewTableExists = dbContext.Database.TableExists("OrdersNew"); + + Assert.IsTrue(ordersTableExists, "Orders table should exist"); + Assert.IsTrue(!orderNewTableExists, "Orders_New table should not exist"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DatabaseExtensions\TruncateTable.cs --- + +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class TruncateTable : DatabaseExtensionsBase +{ + [TestMethod] + public void With_Orders_Table() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + dbContext.Database.TruncateTable("Orders"); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DatabaseExtensions\TruncateTableAsync.cs --- + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class TruncateTableAsync : DatabaseExtensionsBase +{ + [TestMethod] + public async Task With_Orders_Table() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + await dbContext.Database.TruncateTableAsync("Orders"); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\BulkDelete.cs --- + +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkDelete : DbContextExtensionsBase +{ + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + int rowsDeleted = dbContext.BulkDelete(orders); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in + database"); + Assert.IsTrue(newTotal == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = dbContext.TpcPeople.OfType().ToList(); + int rowsDeleted = dbContext.BulkDelete(customers); + var newCustomers = dbContext.TpcPeople.OfType().Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows + in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = dbContext.TphPeople.OfType().ToList(); + int rowsDeleted = dbContext.BulkDelete(customers); + var newCustomers = dbContext.TphPeople.OfType().Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows + in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptCustomers.ToList(); + int rowsDeleted = dbContext.BulkDelete(customers); + var newCustomers = dbContext.TptCustomers.Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows + in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Options_DeleteOnCondition() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).ToList(); + int rowsDeleted = dbContext.BulkDelete(orders, options => { options.DeleteOnCondition = (s, t) => s.ExternalId = +== t.ExternalId; options.UsePermanentTable = true; }); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price < $2)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in + database"); + Assert.IsTrue(newTotal == oldTotal - rowsDeleted, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + int rowsDeleted, newTotal = 0; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsDeleted = dbContext.BulkDelete(orders); + newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + transaction.Rollback(); + } + var rollbackTotal = dbContext.Orders.Count(o => o.Price == 1.25M); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price < $2)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in + database"); + Assert.IsTrue(newTotal == 0, "Must be 0 to indicate all records were deleted"); + Assert.IsTrue(rollbackTotal == orders.Count, "The number of rows after the transacation has been rollbacked shou +uld match the original count"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\BulkDeleteAsync.cs --- + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkDeleteAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + int rowsDeleted = await dbContext.BulkDeleteAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in + database"); + Assert.IsTrue(newTotal == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = dbContext.TpcPeople.OfType().ToList(); + int rowsDeleted = await dbContext.BulkDeleteAsync(customers); + var newCustomers = dbContext.TpcPeople.OfType().Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows + in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = dbContext.TphPeople.OfType().ToList(); + int rowsDeleted = await dbContext.BulkDeleteAsync(customers); + var newCustomers = dbContext.TphPeople.OfType().Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows + in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptCustomers.ToList(); + int rowsDeleted = await dbContext.BulkDeleteAsync(customers); + var newCustomers = dbContext.TptCustomers.Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows + in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Options_DeleteOnCondition() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).ToList(); + int rowsDeleted = await dbContext.BulkDeleteAsync(orders, options => { options.DeleteOnCondition = (s, t) => s.E +ExternalId == t.ExternalId; options.UsePermanentTable = true; }); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price < $2)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in + database"); + Assert.IsTrue(newTotal == oldTotal - rowsDeleted, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + int rowsDeleted, newTotal = 0; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsDeleted = await dbContext.BulkDeleteAsync(orders); + newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + transaction.Rollback(); + } + var rollbackTotal = dbContext.Orders.Count(o => o.Price == 1.25M); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price < $2)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in + database"); + Assert.IsTrue(newTotal == 0, "Must be 0 to indicate all records were deleted"); + Assert.IsTrue(rollbackTotal == orders.Count, "The number of rows after the transacation has been rollbacked shou +uld match the original count"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\BulkFetch.cs --- + +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkFetch : DbContextExtensionsBase +{ + [TestMethod] + public void With_Complex_Property() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25m).ToList(); + var fetchedProducts = dbContext.Products.BulkFetch(products); + bool foundNullPositionProperty = fetchedProducts.Any(o => o.Position == null); + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count(), "The number of rows deleted must match the count of exi +isting rows in database"); + Assert.IsFalse(foundNullPositionProperty, "The Position complex property should be populated when using BulkFetc +ch()"); + } + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + var fetchedOrders = dbContext.Orders.BulkFetch(orders); + bool ordersAreMatched = true; + + foreach (var fetchedOrder in fetchedOrders) + { + var order = orders.First(o => o.Id == fetchedOrder.Id); + if (order.ExternalId != fetchedOrder.ExternalId || order.AddedDateTime != fetchedOrder.AddedDateTime || orde +er.ModifiedDateTime != fetchedOrder.ModifiedDateTime) + { + ordersAreMatched = false; + break; + } + } + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(orders.Count == fetchedOrders.Count(), "The number of rows deleted must match the count of existin +ng rows in database"); + Assert.IsTrue(ordersAreMatched, "The orders from BulkFetch() should match what is retrieved from DbContext"); + } + [TestMethod] + public void With_Enum() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25m).ToList(); + var fetchedProducts = dbContext.Products.BulkFetch(products); + bool productsAreMatched = true; + + foreach (var fetchedProduct in fetchedProducts) + { + var product = products.First(o => o.Id == fetchedProduct.Id); + if (product.Id != fetchedProduct.Id || product.Name != fetchedProduct.Name || product.StatusEnum != fetchedP +Product.StatusEnum) + { + productsAreMatched = false; + break; + } + } + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count(), "The number of rows deleted must match the count of exi +isting rows in database"); + Assert.IsTrue(productsAreMatched, "The products from BulkFetch() should match what is retrieved from DbContext") +); + } + [TestMethod] + public void With_IQueryable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId != null); + var fetchedOrders = dbContext.Orders.BulkFetch(orders, options => { options.IgnoreColumns = o => new { o.Externa +alId }; }).ToList(); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + bool foundNullExternalId = fetchedOrders.Where(o => o.ExternalId != null).Any(); + + Assert.IsTrue(orders.Count() > 0, "There must be orders in the database that match condition (Price <= 10 And Ex +xternalId != null)"); + Assert.IsTrue(orders.Count() == fetchedOrders.Count(), "The number of orders must match the number of fetched or +rders"); + Assert.IsTrue(!foundNullExternalId, "Fetched orders should not contain any items where ExternalId is null."); + } + [TestMethod] + public void With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId != null).ToList(); + var fetchedOrders = dbContext.Orders.BulkFetch(orders, options => { options.IgnoreColumns = o => new { o.Externa +alId }; }).ToList(); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + bool foundNullExternalId = fetchedOrders.Where(o => o.ExternalId != null).Any(); + + Assert.IsTrue(orders.Count() > 0, "There must be orders in the database that match condition (Price <= 10 And Ex +xternalId != null)"); + Assert.IsTrue(orders.Count() == fetchedOrders.Count(), "The number of orders must match the number of fetched or +rders"); + Assert.IsTrue(!foundNullExternalId, "Fetched orders should not contain any items where ExternalId is null."); + } + [TestMethod] + public void With_ValueConverter() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).ToList(); + var fetchedProducts = dbContext.Products.BulkFetch(products); + bool areMatched = true; + + foreach (var fetchedProduct in fetchedProducts) + { + var product = products.First(o => o.Id == fetchedProduct.Id); + if (product.Name != fetchedProduct.Name || product.Price != fetchedProduct.Price + || product.Color != fetchedProduct.Color) + { + areMatched = false; + break; + } + } + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count(), "The number of rows deleted must match the count of exi +isting rows in database"); + Assert.IsTrue(areMatched, "The products from BulkFetch() should match what is retrieved from DbContext"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\BulkFetchAsync.cs --- + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkFetchAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Complex_Property() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25m).ToList(); + var fetchedProducts = (await dbContext.Products.BulkFetchAsync(products)).ToList(); + bool foundNullPositionProperty = fetchedProducts.Any(o => o.Position == null); + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count, "The number of rows deleted must match the count of exist +ting rows in database"); + Assert.IsFalse(foundNullPositionProperty, "The Position complex property should be populated when using BulkFetc +chAsync()"); + } + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + var fetchedOrders = (await dbContext.Orders.BulkFetchAsync(orders)).ToList(); + bool ordersAreMatched = true; + + foreach (var fetchedOrder in fetchedOrders) + { + var order = orders.First(o => o.Id == fetchedOrder.Id); + if (order.ExternalId != fetchedOrder.ExternalId || order.AddedDateTime != fetchedOrder.AddedDateTime || orde +er.ModifiedDateTime != fetchedOrder.ModifiedDateTime) + { + ordersAreMatched = false; + break; + } + } + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(orders.Count == fetchedOrders.Count, "The number of rows deleted must match the count of existing + rows in database"); + Assert.IsTrue(ordersAreMatched, "The orders from BulkFetchAsync() should match what is retrieved from DbContext" +"); + } + [TestMethod] + public async Task With_Enum() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25m).ToList(); + var fetchedProducts = (await dbContext.Products.BulkFetchAsync(products)).ToList(); + bool productsAreMatched = true; + + foreach (var fetchedProduct in fetchedProducts) + { + var product = products.First(o => o.Id == fetchedProduct.Id); + if (product.Id != fetchedProduct.Id || product.Name != fetchedProduct.Name || product.StatusEnum != fetchedP +Product.StatusEnum) + { + productsAreMatched = false; + break; + } + } + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count, "The number of rows deleted must match the count of exist +ting rows in database"); + Assert.IsTrue(productsAreMatched, "The products from BulkFetchAsync() should match what is retrieved from DbCont +text"); + } + [TestMethod] + public async Task With_IQueryable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId != null); + var fetchedOrders = (await dbContext.Orders.BulkFetchAsync(orders, options => { options.IgnoreColumns = o => new +w { o.ExternalId }; })).ToList(); + bool foundNonNullExternalId = fetchedOrders.Any(o => o.ExternalId != null); + + Assert.IsTrue(orders.Count() > 0, "There must be orders in the database that match condition (Price <= 10 And Ex +xternalId != null)"); + Assert.IsTrue(orders.Count() == fetchedOrders.Count, "The number of orders must match the number of fetched orde +ers"); + Assert.IsFalse(foundNonNullExternalId, "Fetched orders should not contain any items where ExternalId is not null +l."); + } + [TestMethod] + public async Task With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId != null).ToList(); + var fetchedOrders = (await dbContext.Orders.BulkFetchAsync(orders, options => { options.IgnoreColumns = o => new +w { o.ExternalId }; })).ToList(); + bool foundNonNullExternalId = fetchedOrders.Any(o => o.ExternalId != null); + + Assert.IsTrue(orders.Count() > 0, "There must be orders in the database that match condition (Price <= 10 And Ex +xternalId != null)"); + Assert.IsTrue(orders.Count() == fetchedOrders.Count, "The number of orders must match the number of fetched orde +ers"); + Assert.IsFalse(foundNonNullExternalId, "Fetched orders should not contain any items where ExternalId is not null +l."); + } + [TestMethod] + public async Task With_ValueConverter() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).ToList(); + var fetchedProducts = (await dbContext.Products.BulkFetchAsync(products)).ToList(); + bool areMatched = true; + + foreach (var fetchedProduct in fetchedProducts) + { + var product = products.First(o => o.Id == fetchedProduct.Id); + if (product.Name != fetchedProduct.Name || product.Price != fetchedProduct.Price + || product.Color != fetchedProduct.Color) + { + areMatched = false; + break; + } + } + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count, "The number of rows deleted must match the count of exist +ting rows in database"); + Assert.IsTrue(areMatched, "The products from BulkFetchAsync() should match what is retrieved from DbContext"); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\BulkInsert.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkInsert : DbContextExtensionsBase +{ + [TestMethod] + public void With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = new List(); + for (int i = 50000; i < 60000; i++) + { + var key = i.ToString(); + products.Add(new ProductWithComplexKey { Price = 1.57M }); + } + int oldTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(products); + int newTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Complex_Type() + { + var dbContext = SetupDbContext(true); + var orders = new List(); + for (int i = 1; i < 1000; i++) + { + orders.Add(new OrderWithComplexType + { + Id = i, + ShippingAddress = new Address + { + Line1 = $"123 Main St, {i}", + City = "Atlanta", + Country = "USA", + PostCode = "30303" + }, + BillingAddress = new Address + { + Line1 = $"456 Oak St, {i}", + City = "Atlanta", + Country = "USA", + PostCode = "30303" + } + }); + } + int oldTotal = dbContext.OrdersWithComplexType.Count(); + int rowsInserted = dbContext.BulkInsert(orders); + int newTotal = dbContext.OrdersWithComplexType.Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TpcCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TpcVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TpcPeople.Count(); + int customerRowsInserted = dbContext.BulkInsert(customers, o => o.UsePermanentTable = true); + int vendorRowsInserted = dbContext.BulkInsert(vendors, o => o.UsePermanentTable = true); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TpcPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count +t of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TphCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "404-555-1111", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TphVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Phone = "404-555-2222", + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TphPeople.Count(); + int customerRowsInserted = dbContext.BulkInsert(customers); + int vendorRowsInserted = dbContext.BulkInsert(vendors); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TphPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count +t of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TptCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "777-555-1234", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TptVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TptPeople.Count(); + int customerRowsInserted = dbContext.BulkInsert(customers, o => o.UsePermanentTable = true); + int vendorRowsInserted = dbContext.BulkInsert(vendors); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TptPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count +t of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void Without_Identity_Column() + { + var dbContext = SetupDbContext(true); + var products = new List(); + for (int i = 50000; i < 60000; i++) + { + products.Add(new Product { Id = i.ToString(), Price = 1.57M }); + } + int oldTotal = dbContext.Products.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(products); + int newTotal = dbContext.Products.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Options_AutoMapIdentity() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 5000; i++) + { + orders.Add(new Order { ExternalId = i.ToString(), Price = ((decimal)i + 0.55M) }); + } + int rowsAdded = dbContext.BulkInsert(orders, new BulkInsertOptions + { + UsePermanentTable = true + }); + bool autoMapIdentityMatched = true; + var ordersInDb = dbContext.Orders.ToList(); + Order order1 = null; + Order order2 = null; + foreach (var order in orders) + { + order1 = order; + var orderinDb = ordersInDb.First(o => o.Id == order.Id); + order2 = orderinDb; + if (!(orderinDb.ExternalId == order.ExternalId && orderinDb.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(rowsAdded == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up") +); + } + [TestMethod] + public void With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, ExternalId = i.ToString(), Price = 1.57M, Active = true }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + int rowsInserted = dbContext.BulkInsert(orders, options => { options.UsePermanentTable = true; options.IgnoreCol +lumns = o => new { o.ExternalId }; }); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Options_InputColumns() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, ExternalId = i.ToString(), Price = 1.57M, Active = true, Status = OrderStatus +s.Completed }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price == 1.57M && o.ExternalId == null && o.Active == true).Count() +); + int rowsInserted = dbContext.BulkInsert(orders, options => + { + options.UsePermanentTable = true; + options.InputColumns = o => new { o.Price, o.Active, o.AddedDateTime, o.Status }; + }); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.57M && o.ExternalId == null && o.Active == true).Count() +); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_KeepIdentity() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i + 1000, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Count(); + int rowsInserted = dbContext.BulkInsert(orders, options => { options.KeepIdentity = true; options.BatchSize = 10 +000; }); + var oldOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool allIdentityFieldsMatch = true; + for (int i = 0; i < 20000; i++) + { + if (newOrders[i].Id != oldOrders[i].Id) + { + allIdentityFieldsMatch = false; + break; + } + } + try + { + int rowsInserted2 = dbContext.BulkInsert(orders, new BulkInsertOptions() + { + KeepIdentity = true, + BatchSize = 1000, + }); + } + catch (Exception ex) + { + Assert.IsInstanceOfType(ex, typeof(SqlException)); + Assert.IsTrue(Config.IsPrimaryKeyViolation(ex)); + } + + Assert.IsTrue(oldTotal == 0, "There should not be any records in the table"); + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(allIdentityFieldsMatch, "The identities between the source and the database should match."); + } + [TestMethod] + public void With_Schema() + { + var dbContext = SetupDbContext(false); + var products = new List(); + for (int i = 1; i < 10000; i++) + { + var key = i.ToString(); + products.Add(new ProductWithCustomSchema + { + Id = key, + Name = $"Product-{key}", + Price = 1.57M + }); + } + int oldTotal = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(products); + int newTotal = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted, newTotal; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsInserted = dbContext.BulkInsert(orders); + newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + transaction.Rollback(); + } + int rollbackTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + Assert.IsTrue(rollbackTotal == oldTotal, "The number of rows after the transacation has been rollbacked should m +match the original count"); + } + [TestMethod] + public void With_Options_InsertIfNotExists() + { + var dbContext = SetupDbContext(true); + var orders = new List(); + long maxId = dbContext.Orders.Max(o => o.Id); + long expectedRowsInserted = 1000; + int existingRowsToAdd = 100; + long startId = maxId - existingRowsToAdd + 1, endId = maxId + expectedRowsInserted + 1; + for (long i = startId; i < endId; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(orders, new BulkInsertOptions() { InsertIfNotExists = true }); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == expectedRowsInserted, "The number of rows inserted must match the count of order l +list"); + Assert.IsTrue(newTotal - oldTotal == expectedRowsInserted, "The new count minus the old count should match the n +number of rows inserted."); + } + [TestMethod] + public void With_Proxy_Type() + { + var dbContext = SetupDbContext(false); + int oldTotalCount = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + + var products = new List(); + for (int i = 0; i < 2000; i++) + { + var product = dbContext.Products.CreateProxy(); + product.Id = (-i).ToString(); + product.Price = 10.57M; + products.Add(product); + } + int oldTotal = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + int rowsInserted = dbContext.BulkInsert(products); + int newTotal = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of products list +t"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Trigger() + { + var dbContext = SetupDbContext(false); + var products = new List(); + for (int i = 1; i < 1000; i++) + { + products.Add(new ProductWithTrigger { Id = i.ToString(), Price = 1.57M, StatusString = "InStock" }); + } + + //The return int from BulkInsert() will be off when using triggers + dbContext.BulkInsert(products, options => + { + options.AutoMapOutput = false; + if (Config.IsSqlServer) + options.BulkCopyOptions = SqlBulkCopyOptions.FireTriggers; + }); + var rowsInserted = dbContext.ProductsWithTrigger.Count(); + + Assert.IsTrue(rowsInserted == products.Count, $"The number of rows inserted must match the count of products ({r +rowsInserted}!={products.Count})"); + } + [TestMethod] + public void With_ValueGenerated_Default() + { + var dbContext = SetupDbContext(false); + var nowDateTime = DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbAddedDateTime > nowDateTime && o.DbActive).Count +t(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_ValueGenerated_Computed() + { + var dbContext = SetupDbContext(false); + var nowDateTime = DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbModifiedDateTime > nowDateTime).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\BulkInsertAsync.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkInsertAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = new List(); + for (int i = 50000; i < 60000; i++) + { + var key = i.ToString(); + products.Add(new ProductWithComplexKey { Price = 1.57M }); + } + int oldTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(products); + int newTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_Complex_Type() + { + var dbContext = SetupDbContext(true); + var orders = new List(); + for (int i = 1; i < 1000; i++) + { + orders.Add(new OrderWithComplexType + { + Id = i, + ShippingAddress = new Address + { + Line1 = $"123 Main St, {i}", + City = "Atlanta", + Country = "USA", + PostCode = "30303" + }, + BillingAddress = new Address + { + Line1 = $"456 Oak St, {i}", + City = "Atlanta", + Country = "USA", + PostCode = "30303" + } + }); + } + int oldTotal = dbContext.OrdersWithComplexType.Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders); + int newTotal = dbContext.OrdersWithComplexType.Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + //[TestMethod] + //public async Task With_IEnumerable() + //{ + // var dbContext = SetupDbContext(false); + // var orders = dbContext.Orders.Where(o => o.Price <= 10); + + // foreach(var order in orders) + // { + // order.Price = 15.75M; + // } + // int oldTotal = orders.Count(); + // int rowsInserted = await dbContext.BulkInsertAsync(orders); + // int newTotal = orders.Count(); + + // Assert.IsTrue(rowsInserted == oldTotal, "The number of rows inserted must match the count of order list"); + // Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number + of rows inserted."); + //} + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TpcCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TpcVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TpcPeople.Count(); + int customerRowsInserted = await dbContext.BulkInsertAsync(customers); + int vendorRowsInserted = await dbContext.BulkInsertAsync(vendors); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TpcPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count +t of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TphCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "404-555-1111", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TphVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Phone = "404-555-2222", + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TphPeople.Count(); + int customerRowsInserted = await dbContext.BulkInsertAsync(customers); + int vendorRowsInserted = await dbContext.BulkInsertAsync(vendors); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TphPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count +t of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TptCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "777-555-1234", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TptVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TptPeople.Count(); + int customerRowsInserted = await dbContext.BulkInsertAsync(customers, o => o.UsePermanentTable = true); + int vendorRowsInserted = await dbContext.BulkInsertAsync(vendors); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TptPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count +t of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task Without_Identity_Column() + { + var dbContext = SetupDbContext(true); + var products = new List(); + for (int i = 50000; i < 60000; i++) + { + products.Add(new Product { Id = i.ToString(), Price = 1.57M }); + } + int oldTotal = dbContext.Products.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(products); + int newTotal = dbContext.Products.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_Options_AutoMapIdentity() + { + + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 5000; i++) + { + orders.Add(new Order { ExternalId = i.ToString(), Price = ((decimal)i + 0.55M) }); + } + int rowsAdded = await dbContext.BulkInsertAsync(orders, new BulkInsertOptions + { + UsePermanentTable = true + }); + bool autoMapIdentityMatched = true; + var ordersInDb = dbContext.Orders.ToList(); + Order order1 = null; + Order order2 = null; + foreach (var order in orders) + { + order1 = order; + var orderinDb = ordersInDb.First(o => o.Id == order.Id); + order2 = orderinDb; + if (!(orderinDb.ExternalId == order.ExternalId && orderinDb.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(rowsAdded == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up") +); + } + [TestMethod] + public async Task With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, ExternalId = i.ToString(), Price = 1.57M, Active = true }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders, options => { options.UsePermanentTable = true; option +ns.IgnoreColumns = o => new { o.ExternalId }; }); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_Options_InputColumns() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, ExternalId = i.ToString(), Price = 1.57M, Active = true, Status = OrderStatus +s.Completed }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price == 1.57M && o.ExternalId == null && o.Active == true).Count() +); + int rowsInserted = await dbContext.BulkInsertAsync(orders, options => + { + options.UsePermanentTable = true; + options.InputColumns = o => new { o.Price, o.Active, o.AddedDateTime, o.Status }; + }); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.57M && o.ExternalId == null && o.Active == true).Count() +); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_KeepIdentity() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i + 1000, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders, options => { options.KeepIdentity = true; options.Bat +tchSize = 1000; }); + var oldOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool allIdentityFieldsMatch = true; + for (int i = 0; i < 20000; i++) + { + if (newOrders[i].Id != oldOrders[i].Id) + { + allIdentityFieldsMatch = false; + break; + } + } + try + { + int rowsInserted2 = await dbContext.BulkInsertAsync(orders, new BulkInsertOptions() + { + KeepIdentity = true, + BatchSize = 1000, + }); + } + catch (Exception ex) + { + Assert.IsInstanceOfType(ex, typeof(SqlException)); + Assert.IsTrue(Config.IsPrimaryKeyViolation(ex)); + } + + Assert.IsTrue(oldTotal == 0, "There should not be any records in the table"); + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(allIdentityFieldsMatch, "The identities between the source and the database should match."); + } + [TestMethod] + public async Task With_Proxy_Type() + { + var dbContext = SetupDbContext(false); + int oldTotalCount = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + + var products = new List(); + for (int i = 0; i < 2000; i++) + { + var product = dbContext.Products.CreateProxy(); + product.Id = (-i).ToString(); + product.Price = 10.57M; + products.Add(product); + } + int oldTotal = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(products); + int newTotal = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of products list +t"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_Trigger() + { + var dbContext = SetupDbContext(false); + var products = new List(); + for (int i = 1; i < 1000; i++) + { + products.Add(new ProductWithTrigger { Id = i.ToString(), Price = 1.57M, StatusString = "InStock" }); + } + + //The return int from BulkInsertAsync() will be off when using triggers + await dbContext.BulkInsertAsync(products, options => + { + options.AutoMapOutput = false; + if (Config.IsSqlServer) + options.BulkCopyOptions = SqlBulkCopyOptions.FireTriggers; + }); + var rowsInserted = dbContext.ProductsWithTrigger.Count(); + + Assert.IsTrue(rowsInserted == products.Count, $"The number of rows inserted must match the count of products ({r +rowsInserted}!={products.Count})"); + } + [TestMethod] + public async Task With_Schema() + { + var dbContext = SetupDbContext(false); + var products = new List(); + for (int i = 1; i < 10000; i++) + { + var key = i.ToString(); + products.Add(new ProductWithCustomSchema + { + Id = key, + Name = $"Product-{key}", + Price = 1.57M + }); + } + int oldTotal = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(products); + int newTotal = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted, newTotal; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsInserted = await dbContext.BulkInsertAsync(orders); + newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + transaction.Rollback(); + } + int rollbackTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + Assert.IsTrue(rollbackTotal == oldTotal, "The number of rows after the transacation has been rollbacked should m +match the original count"); + } + [TestMethod] + public async Task With_Options_InsertIfNotExists() + { + var dbContext = SetupDbContext(true); + var orders = new List(); + long maxId = dbContext.Orders.Max(o => o.Id); + long expectedRowsInserted = 1000; + int existingRowsToAdd = 100; + long startId = maxId - existingRowsToAdd + 1, endId = maxId + expectedRowsInserted + 1; + for (long i = startId; i < endId; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders, new BulkInsertOptions() { InsertIfNotExists = + true }); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == expectedRowsInserted, "The number of rows inserted must match the count of order l +list"); + Assert.IsTrue(newTotal - oldTotal == expectedRowsInserted, "The new count minus the old count should match the n +number of rows inserted."); + } + [TestMethod] + public async Task With_ValueGenerated_Default() + { + var dbContext = SetupDbContext(false); + var nowDateTime = DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbAddedDateTime > nowDateTime && o.DbActive).Count +t(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_ValueGenerated_Computed() + { + var dbContext = SetupDbContext(false); + var nowDateTime = DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbModifiedDateTime > nowDateTime).Count(); + + Assert.IsTrue(rowsInserted == orders.Count(), "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\BulkMerge.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkMerge : DbContextExtensionsBase +{ + [TestMethod] + public void With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.ProductsWithComplexKey.Where(o => o.Price == 1.25M).ToList(); + int productsToAdd = 5000; + decimal updatedPrice = 5.25M; + var productsToUpdate = products.ToList(); + foreach (var product in products) + { + product.Price = updatedPrice; + } + for (int i = 0; i < productsToAdd; i++) + { + products.Add(new ProductWithComplexKey { ExternalId = (20000 + i).ToString(), Price = 3.55M }); + } + var result = dbContext.BulkMerge(products); + var allProducts = dbContext.ProductsWithComplexKey.ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var product in allProducts) + { + if (productsToUpdate.Contains(product) && product.Price != updatedPrice) + { + areUpdatedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == products.Count(), "The number of rows inserted must match the count of orde +er list"); + Assert.IsTrue(result.RowsUpdated == productsToUpdate.Count, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == productsToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = dbContext.BulkMerge(orders, o => o.UsePermanentTable = true); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 10000).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order + list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true); + var customers = dbContext.TpcPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMerge_Tpc_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TpcCustomer + { + Id = 10000 + i, + FirstName = "BulkMerge_Tpc_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkMerge(customers, options => { options.MergeOnCondition = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TpcPeople.Where(o => o.FirstName == "BulkMerge_Tpc_Add").OfType().Co +ount(); + int customersUpdated = dbContext.TpcPeople.Where(o => o.FirstName == "BulkMerge_Tpc_Update").OfType +>().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count, "The number of rows inserted must match the count of custo +omer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true); + var customers = dbContext.TphPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMerge_Tph_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TphCustomer + { + Id = 10000 + i, + FirstName = "BulkMerge_Tph_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkMerge(customers); + int customersAdded = dbContext.TphPeople.Where(o => o.FirstName == "BulkMerge_Tph_Add").OfType().Co +ount(); + int customersUpdated = dbContext.TphPeople.Where(o => o.FirstName == "BulkMerge_Tph_Update").OfType +>().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count(), "The number of rows inserted must match the count of cus +stomer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMerge_Tpt_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TptCustomer + { + Id = 10000 + i, + FirstName = "BulkMerge_Tpt_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkMerge(customers); + int customersAdded = dbContext.TptPeople.Where(o => o.FirstName == "BulkMerge_Tpt_Add").OfType().Co +ount(); + int customersUpdated = dbContext.TptPeople.Where(o => o.FirstName == "BulkMerge_Tpt_Update").OfType +>().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count(), "The number of rows inserted must match the count of cus +stomer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public void With_Default_Options_MergeOnCondition() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 50; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = dbContext.BulkMerge(orders, options => { options.MergeOnCondition = (s, t) => s.ExternalId == t.Ext +ternalId; options.BatchSize = 1000; }); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order + list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public void With_Options_AutoMapIdentity() + { + var dbContext = SetupDbContext(true); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = dbContext.BulkMerge(orders, new BulkMergeOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + UsePermanentTable = true + }); + bool autoMapIdentityMatched = true; + foreach (var order in orders) + { + if (!dbContext.Orders.Any(o => o.ExternalId == order.ExternalId && o.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == ordersToAdd + ordersToUpdate, "The number of rows inserted must match the c +count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up") +); + } + [TestMethod] + public void With_Options_AutoMapOutput() + { + var dbContext = SetupDbContext(true); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = dbContext.BulkMerge(orders, new BulkMergeOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + AutoMapOutput = true + }); + var autoMapIdentityMatched = orders.All(x => x.Id != 0); + + Assert.IsTrue(result.RowsAffected == ordersToAdd + ordersToUpdate, "The number of rows inserted must match the c +count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up") +); + } + [TestMethod] + public void With_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + int productsToAdd = 5000; + var productsToUpdate = products.ToList(); + foreach (var product in products) + { + product.Price = Convert.ToDecimal(product.Id) + .25M; + } + for (int i = 0; i < productsToAdd; i++) + { + products.Add(new Product { Id = (20000 + i).ToString(), Price = 3.55M }); + } + var result = dbContext.BulkMerge(products); + var newProducts = dbContext.Products.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newProduct in newProducts.Where(o => productsToUpdate.Select(o => o.Id).Contains(o.Id))) + { + if (newProduct.Price != Convert.ToDecimal(newProduct.Id) + .25M) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newProduct in newProducts.Where(o => Convert.ToInt32(o.Id) >= 20000).OrderBy(o => o.Id)) + { + if (newProduct.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == products.Count(), "The number of rows inserted must match the count of orde +er list"); + Assert.IsTrue(result.RowsUpdated == productsToUpdate.Count, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == productsToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + BulkMergeResult result; + using (var transaction = dbContext.Database.BeginTransaction()) + { + result = dbContext.BulkMerge(orders); + transaction.Rollback(); + } + int ordersUpdated = dbContext.Orders.Count(o => o.Id <= 10000 && o.Price == ((decimal)o.Id + .25M) && o.Price != += 1.25M); + int ordersAdded = dbContext.Orders.Count(o => o.Id >= 100000); + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order + list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(ordersAdded == 0, "The number of rows added must equal 0 since transaction was rollbacked"); + Assert.IsTrue(ordersUpdated == 0, "The number of rows updated must equal 0 since transaction was rollbacked"); + } + [TestMethod] + public void With_ValueGenerated_Default() + { + var dbContext = SetupDbContext(false); + var nowDateTime = DateTime.Now; + var orders = new List(); + for (int i = 0; i < 1000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.DbAddedDateTime > nowDateTime).Count(); + var mergeResult = dbContext.BulkMerge(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 1.57M + && o.DbAddedDateTime > nowDateTime).Count(); + + Assert.IsTrue(mergeResult.RowsInserted == orders.Count, "The number of rows inserted must match the count of ord +der list"); + Assert.IsTrue(newTotal - oldTotal == mergeResult.RowsInserted, "The new count minus the old count should match t +the number of rows inserted."); + } + [TestMethod] + public void With_ValueGenerated_Computed() + { + var dbContext = SetupDbContext(false); + var nowDateTime = DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + var result = dbContext.BulkMerge(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbModifiedDateTime > nowDateTime).Count(); + + Assert.IsTrue(result.RowsInserted == orders.Count(), "The number of rows inserted must match the count of order + list"); + Assert.IsTrue(newTotal - oldTotal == result.RowsInserted, "The new count minus the old count should match the nu +umber of rows inserted."); + } + [TestMethod] + public void With_Merge_On_Enum() + { + var dbContext = SetupDbContext(true); + dbContext.BulkSaveChanges(); + var nowDateTime = DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime, Status = OrderStatus.Complet +ted }); + } + + var result = dbContext.BulkMerge(orders, options => options.MergeOnCondition = (s, t) => s.Id == t.Id && s.Statu +us == t.Status); + + Assert.AreEqual(1, result.RowsInserted); + Assert.AreEqual(19, result.RowsUpdated); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\BulkMergeAsync.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkMergeAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.ProductsWithComplexKey.Where(o => o.Price == 1.25M).ToList(); + int productsToAdd = 5000; + decimal updatedPrice = 5.25M; + var productsToUpdate = products.ToList(); + foreach (var product in products) + { + product.Price = updatedPrice; + } + for (int i = 0; i < productsToAdd; i++) + { + products.Add(new ProductWithComplexKey { ExternalId = (20000 + i).ToString(), Price = 3.55M }); + } + var result = await dbContext.BulkMergeAsync(products); + var allProducts = dbContext.ProductsWithComplexKey.ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var product in allProducts) + { + if (productsToUpdate.Contains(product) && product.Price != updatedPrice) + { + areUpdatedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == products.Count(), "The number of rows inserted must match the count of orde +er list"); + Assert.IsTrue(result.RowsUpdated == productsToUpdate.Count, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == productsToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = await dbContext.BulkMergeAsync(orders); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 10000).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order + list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true); + var customers = dbContext.TpcPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMerge_Tpc_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TpcCustomer + { + Id = 10000 + i, + FirstName = "BulkMerge_Tpc_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkMergeAsync(customers, options => { options.MergeOnCondition = (s, t) => s.Id == += t.Id; }); + int customersAdded = dbContext.TpcPeople.Where(o => o.FirstName == "BulkMerge_Tpc_Add").OfType().Co +ount(); + int customersUpdated = dbContext.TpcPeople.Where(o => o.FirstName == "BulkMerge_Tpc_Update").OfType +>().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count, "The number of rows inserted must match the count of custo +omer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true); + var customers = dbContext.TphPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMerge_Tph_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TphCustomer + { + Id = 10000 + i, + FirstName = "BulkMerge_Tph_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkMergeAsync(customers); + int customersAdded = dbContext.TphPeople.Where(o => o.FirstName == "BulkMerge_Tph_Add").OfType().Co +ount(); + int customersUpdated = dbContext.TphPeople.Where(o => o.FirstName == "BulkMerge_Tph_Update").OfType +>().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count(), "The number of rows inserted must match the count of cus +stomer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMergeAsync_Tpt_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TptCustomer + { + Id = 10000 + i, + FirstName = "BulkMergeAsync_Tpt_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkMergeAsync(customers); + int customersAdded = dbContext.TptPeople.Where(o => o.FirstName == "BulkMergeAsync_Tpt_Add").OfType +>().Count(); + int customersUpdated = dbContext.TptPeople.Where(o => o.FirstName == "BulkMergeAsync_Tpt_Update").OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count(), "The number of rows inserted must match the count of cus +stomer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Default_Options_MergeOnCondition() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 50; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = await dbContext.BulkMergeAsync(orders, options => { options.MergeOnCondition = (s, t) => s.External +lId == t.ExternalId; options.BatchSize = 1000; }); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order + list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Options_AutoMapIdentity() + { + var dbContext = SetupDbContext(true); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = await dbContext.BulkMergeAsync(orders, new BulkMergeOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + UsePermanentTable = true + }); + bool autoMapIdentityMatched = true; + foreach (var order in orders) + { + if (!dbContext.Orders.Any(o => o.ExternalId == order.ExternalId && o.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == ordersToAdd + ordersToUpdate, "The number of rows inserted must match the c +count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up") +); + } + [TestMethod] + public async Task With_Options_AutoMapOutput() + { + var dbContext = SetupDbContext(true); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = await dbContext.BulkMergeAsync(orders, new BulkMergeOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + AutoMapOutput = true + }); + var autoMapIdentityMatched = orders.All(x => x.Id != 0); + + Assert.IsTrue(result.RowsAffected == ordersToAdd + ordersToUpdate, "The number of rows inserted must match the c +count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up") +); + } + [TestMethod] + public async Task With_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + int productsToAdd = 5000; + var productsToUpdate = products.ToList(); + foreach (var product in products) + { + product.Price = Convert.ToDecimal(product.Id) + .25M; + } + for (int i = 0; i < productsToAdd; i++) + { + products.Add(new Product { Id = (20000 + i).ToString(), Price = 3.55M }); + } + var result = await dbContext.BulkMergeAsync(products); + var newProducts = dbContext.Products.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newProduct in newProducts.Where(o => productsToUpdate.Select(o => o.Id).Contains(o.Id))) + { + if (newProduct.Price != Convert.ToDecimal(newProduct.Id) + .25M) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newProduct in newProducts.Where(o => Convert.ToInt32(o.Id) >= 20000).OrderBy(o => o.Id)) + { + if (newProduct.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == products.Count(), "The number of rows inserted must match the count of orde +er list"); + Assert.IsTrue(result.RowsUpdated == productsToUpdate.Count, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == productsToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + BulkMergeResult result; + using (var transaction = dbContext.Database.BeginTransaction()) + { + result = await dbContext.BulkMergeAsync(orders); + transaction.Rollback(); + } + int ordersUpdated = dbContext.Orders.Count(o => o.Id <= 10000 && o.Price == ((decimal)o.Id + .25M) && o.Price != += 1.25M); + int ordersAdded = dbContext.Orders.Count(o => o.Id >= 100000); + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order + list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(ordersAdded == 0, "The number of rows added must equal 0 since transaction was rollbacked"); + Assert.IsTrue(ordersUpdated == 0, "The number of rows updated must equal 0 since transaction was rollbacked"); + } + [TestMethod] + public async Task With_ValueGenerated_Default() + { + var dbContext = SetupDbContext(false); + var nowDateTime = DateTime.Now; + var orders = new List(); + for (int i = 0; i < 1000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.DbAddedDateTime > nowDateTime).Count(); + var mergeResult = await dbContext.BulkMergeAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 1.57M + && o.DbAddedDateTime > nowDateTime).Count(); + + Assert.IsTrue(mergeResult.RowsInserted == orders.Count, "The number of rows inserted must match the count of ord +der list"); + Assert.IsTrue(newTotal - oldTotal == mergeResult.RowsInserted, "The new count minus the old count should match t +the number of rows inserted."); + } + [TestMethod] + public async Task With_ValueGenerated_Computed() + { + var dbContext = SetupDbContext(false); + var nowDateTime = DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + var result = await dbContext.BulkMergeAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbModifiedDateTime > nowDateTime).Count(); + + Assert.IsTrue(result.RowsInserted == orders.Count(), "The number of rows inserted must match the count of order + list"); + Assert.IsTrue(newTotal - oldTotal == result.RowsInserted, "The new count minus the old count should match the nu +umber of rows inserted."); + } + [TestMethod] + public async Task With_Merge_On_Enum() + { + var dbContext = SetupDbContext(true); + await dbContext.BulkSaveChangesAsync(); + var nowDateTime = DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime, Status = OrderStatus.Complet +ted }); + } + + var result = await dbContext.BulkMergeAsync(orders, options => options.MergeOnCondition = (s, t) => s.Id == t.Id +d && s.Status == t.Status); + + Assert.AreEqual(1, result.RowsInserted); + Assert.AreEqual(19, result.RowsUpdated); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\BulkSaveChanges.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkSaveChanges : DbContextExtensionsBase +{ + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var totalCount = dbContext.Orders.Count(); + + //Add new orders + var ordersToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + ordersToAdd.Add(new Order { Id = -i, Price = 10.57M }); + } + dbContext.Orders.AddRange(ordersToAdd); + + //Delete orders + var ordersToDelete = dbContext.Orders.Where(o => o.Price <= 5).ToList(); + dbContext.Orders.RemoveRange(ordersToDelete); + + //Update existing orders + var ordersToUpdate = dbContext.Orders.Where(o => o.Price > 5 && o.Price <= 10).ToList(); + foreach (var orderToUpdate in ordersToUpdate) + { + orderToUpdate.Price = 99M; + } + + + int rowsAffected = dbContext.BulkSaveChanges(); + int ordersAddedCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + int ordersDeletedCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + int ordersUpdatedCount = dbContext.Orders.Where(o => o.Price == 99M).Count(); + + Assert.IsTrue(rowsAffected == ordersToAdd.Count + ordersToDelete.Count + ordersToUpdate.Count, "The number of ro +ows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(ordersAddedCount == ordersToAdd.Count(), "The number of orders to add did not match what was expec +cted."); + Assert.IsTrue(ordersDeletedCount == 0, "The number of orders that was deleted did not match what was expected.") +); + Assert.IsTrue(ordersUpdatedCount == ordersToUpdate.Count(), "The number of orders that was updated did not match +h what was expected."); + } + [TestMethod] + public void With_Add_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + + var ordersToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + ordersToAdd.Add(new Order { Id = -i, Price = 10.57M }); + } + dbContext.Orders.AddRange(ordersToAdd); + + int rowsAffected = dbContext.BulkSaveChanges(); + int newTotalCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + + Assert.IsTrue(ordersToAdd.Where(o => o.Id <= 0).Count() == 0, "Primary key should have been updated for all enti +ities"); + Assert.IsTrue(rowsAffected == ordersToAdd.Count, "The number of rows affected must equal the sum of entities add +ded, deleted and updated"); + Assert.IsTrue(oldTotalCount + ordersToAdd.Count == newTotalCount, "The number of orders to add did not match wha +at was expected."); + } + [TestMethod] + public void With_Delete_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + + //Delete orders + var ordersToDelete = dbContext.Orders.Where(o => o.Price <= 5).ToList(); + dbContext.Orders.RemoveRange(ordersToDelete); + + int rowsAffected = dbContext.BulkSaveChanges(); + int newTotalCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + + Assert.IsTrue(rowsAffected == ordersToDelete.Count, "The number of rows affected must equal the sum of entities + added, deleted and updated"); + Assert.IsTrue(oldTotalCount - ordersToDelete.Count == newTotalCount, "The number of orders to add did not match + what was expected."); + } + [TestMethod] + public void With_Update_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + //Update existing orders + var ordersToUpdate = dbContext.Orders.Where(o => o.Price <= 10).ToList(); + foreach (var orderToUpdate in ordersToUpdate) + { + orderToUpdate.Price = 99M; + } + + int rowsAffected = dbContext.BulkSaveChanges(); + int newTotalCount = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int expectedCount = dbContext.Orders.Where(o => o.Price == 99M).Count(); + + Assert.IsTrue(rowsAffected == ordersToUpdate.Count, "The number of rows affected must equal the sum of entities + added, deleted and updated"); + Assert.IsTrue(oldTotalCount - ordersToUpdate.Count == newTotalCount, "The number of orders to add did not match + what was expected."); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + //Delete Customers + var customersToDelete = dbContext.TpcPeople.OfType().Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TpcPeople.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TpcPeople.OfType().Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TpcPeople.OfType().Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TpcPeople.Add(new TpcCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = dbContext.BulkSaveChanges(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TpcPeople.OfType().Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected.") +); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number o +of rows inserted."); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + //Delete Customers + var customersToDelete = dbContext.TphCustomers.Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TphCustomers.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TphCustomers.Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TphPeople.Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TphCustomers.Add(new TphCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = dbContext.BulkSaveChanges(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TphCustomers.Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(expectedRowsDeleted > 0, "The expected number of rows to delete must be greater than zero."); + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected.") +); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number o +of rows inserted."); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + //Delete Customers + var customersToDelete = dbContext.TptCustomers.Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TptCustomers.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TptCustomers.Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.Email = "name@domain.com"; + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TptPeople.Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TptCustomers.Add(new TptCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = dbContext.BulkSaveChanges(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TptCustomers.Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected.") +); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number o +of rows inserted."); + } + [TestMethod] + public void With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + var totalCount = dbContext.ProductsWithCustomSchema.Count(); + + //Add new products + var productsToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + productsToAdd.Add(new ProductWithCustomSchema { Id = (-i).ToString(), Price = 10.57M }); + } + dbContext.ProductsWithCustomSchema.AddRange(productsToAdd); + + //Delete products + var productsToDelete = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 5).ToList(); + dbContext.ProductsWithCustomSchema.RemoveRange(productsToDelete); + + //Update existing products + var productsToUpdate = dbContext.ProductsWithCustomSchema.Where(o => o.Price > 5 && o.Price <= 10).ToList(); + foreach (var productToUpdate in productsToUpdate) + { + productToUpdate.Price = 99M; + } + + int rowsAffected = dbContext.BulkSaveChanges(); + int productsAddedCount = dbContext.ProductsWithCustomSchema.Where(o => o.Price == 10.57M).Count(); + int productsDeletedCount = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 5).Count(); + int productsUpdatedCount = dbContext.ProductsWithCustomSchema.Where(o => o.Price == 99M).Count(); + + Assert.IsTrue(rowsAffected == productsToAdd.Count + productsToDelete.Count + productsToUpdate.Count, "The number +r of rows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(productsAddedCount == productsToAdd.Count(), "The number of products to add did not match what was +s expected."); + Assert.IsTrue(productsDeletedCount == 0, "The number of products that was deleted did not match what was expecte +ed."); + Assert.IsTrue(productsUpdatedCount == productsToUpdate.Count(), "The number of products that was updated did not +t match what was expected."); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\BulkSaveChangesAsync.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkSaveChangesAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var totalCount = dbContext.Orders.Count(); + + //Add new orders + var ordersToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + ordersToAdd.Add(new Order { Id = -i, Price = 10.57M }); + } + dbContext.Orders.AddRange(ordersToAdd); + + //Delete orders + var ordersToDelete = dbContext.Orders.Where(o => o.Price <= 5).ToList(); + dbContext.Orders.RemoveRange(ordersToDelete); + + //Update existing orders + var ordersToUpdate = dbContext.Orders.Where(o => o.Price > 5 && o.Price <= 10).ToList(); + foreach (var orderToUpdate in ordersToUpdate) + { + orderToUpdate.Price = 99M; + } + + + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int ordersAddedCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + int ordersDeletedCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + int ordersUpdatedCount = dbContext.Orders.Where(o => o.Price == 99M).Count(); + + Assert.IsTrue(rowsAffected == ordersToAdd.Count + ordersToDelete.Count + ordersToUpdate.Count, "The number of ro +ows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(ordersAddedCount == ordersToAdd.Count(), "The number of orders to add did not match what was expec +cted."); + Assert.IsTrue(ordersDeletedCount == 0, "The number of orders that was deleted did not match what was expected.") +); + Assert.IsTrue(ordersUpdatedCount == ordersToUpdate.Count(), "The number of orders that was updated did not match +h what was expected."); + } + [TestMethod] + public async Task With_Add_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + + var ordersToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + ordersToAdd.Add(new Order { Id = -i, Price = 10.57M }); + } + dbContext.Orders.AddRange(ordersToAdd); + + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int newTotalCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + + Assert.IsTrue(ordersToAdd.Where(o => o.Id <= 0).Count() == 0, "Primary key should have been updated for all enti +ities"); + Assert.IsTrue(rowsAffected == ordersToAdd.Count, "The number of rows affected must equal the sum of entities add +ded, deleted and updated"); + Assert.IsTrue(oldTotalCount + ordersToAdd.Count == newTotalCount, "The number of orders to add did not match wha +at was expected."); + } + [TestMethod] + public async Task With_Delete_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + + //Delete orders + var ordersToDelete = dbContext.Orders.Where(o => o.Price <= 5).ToList(); + dbContext.Orders.RemoveRange(ordersToDelete); + + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int newTotalCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + + Assert.IsTrue(rowsAffected == ordersToDelete.Count, "The number of rows affected must equal the sum of entities + added, deleted and updated"); + Assert.IsTrue(oldTotalCount - ordersToDelete.Count == newTotalCount, "The number of orders to add did not match + what was expected."); + } + [TestMethod] + public async Task With_Update_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + //Update existing orders + var ordersToUpdate = dbContext.Orders.Where(o => o.Price <= 10).ToList(); + foreach (var orderToUpdate in ordersToUpdate) + { + orderToUpdate.Price = 99M; + } + + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int newTotalCount = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int expectedCount = dbContext.Orders.Where(o => o.Price == 99M).Count(); + + Assert.IsTrue(rowsAffected == ordersToUpdate.Count, "The number of rows affected must equal the sum of entities + added, deleted and updated"); + Assert.IsTrue(oldTotalCount - ordersToUpdate.Count == newTotalCount, "The number of orders to add did not match + what was expected."); + } + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + //Delete Customers + var customersToDelete = dbContext.TpcPeople.OfType().Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TpcPeople.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TpcPeople.OfType().Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TpcPeople.OfType().Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TpcPeople.Add(new TpcCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TpcPeople.OfType().Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected.") +); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number o +of rows inserted."); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + //Delete Customers + var customersToDelete = dbContext.TphCustomers.Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TphCustomers.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TphCustomers.Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TphPeople.Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TphCustomers.Add(new TphCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TphCustomers.Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(expectedRowsDeleted > 0, "The expected number of rows to delete must be greater than zero."); + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected.") +); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number o +of rows inserted."); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + //Delete Customers + var customersToDelete = dbContext.TptCustomers.Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TptCustomers.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TptCustomers.Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.Email = "name@domain.com"; + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TptPeople.Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TptCustomers.Add(new TptCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TptCustomers.Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected.") +); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number o +of rows inserted."); + } + [TestMethod] + public async Task With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + var totalCount = await dbContext.ProductsWithCustomSchema.CountAsync(); + + //Add new products + var productsToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + productsToAdd.Add(new ProductWithCustomSchema { Id = (-i).ToString(), Price = 10.57M }); + } + dbContext.ProductsWithCustomSchema.AddRange(productsToAdd); + + //Delete products + var productsToDelete = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 5).ToList(); + dbContext.ProductsWithCustomSchema.RemoveRange(productsToDelete); + + //Update existing products + var productsToUpdate = dbContext.ProductsWithCustomSchema.Where(o => o.Price > 5 && o.Price <= 10).ToList(); + foreach (var productToUpdate in productsToUpdate) + { + productToUpdate.Price = 99M; + } + + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int productsAddedCount = await dbContext.ProductsWithCustomSchema.Where(o => o.Price == 10.57M).CountAsync(); + int productsDeletedCount = await dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 5).CountAsync(); + int productsUpdatedCount = await dbContext.ProductsWithCustomSchema.Where(o => o.Price == 99M).CountAsync(); + + Assert.IsTrue(rowsAffected == productsToAdd.Count + productsToDelete.Count + productsToUpdate.Count, "The number +r of rows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(productsAddedCount == productsToAdd.Count(), "The number of products to add did not match what was +s expected."); + Assert.IsTrue(productsDeletedCount == 0, "The number of products that was deleted did not match what was expecte +ed."); + Assert.IsTrue(productsUpdatedCount == productsToUpdate.Count(), "The number of products that was updated did not +t match what was expected."); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\BulkSync.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkSync : DbContextExtensionsBase +{ + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = dbContext.BulkSync(orders); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 10000).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count o +of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must ma +atch the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = dbContext.TpcPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TpcPeople.OfType().Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tpc_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TpcCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tpc_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkSync(customers, options => { options.MergeOnCondition = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TpcPeople.Where(o => o.FirstName == "BulkSync_Tpc_Add").OfType().Cou +unt(); + int customersUpdated = dbContext.TpcPeople.Where(o => o.FirstName == "BulkSync_Tpc_Update").OfType( +().Count(); + int newCustomerTotal = dbContext.TpcPeople.OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows +s affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference fro +om the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database sh +hould match the sum of customers added and updated."); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = dbContext.TphCustomers.Where(o => o.Id <= 1000).ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TphPeople.Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tph_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TphCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tph_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkSync(customers, options => { options.UsePermanentTable = true; options.MergeOnConditi +ion = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TphCustomers.Where(o => o.FirstName == "BulkSync_Tph_Add").Count(); + int customersUpdated = dbContext.TphCustomers.Where(o => o.FirstName == "BulkSync_Tph_Update").Count(); + int newCustomerTotal = dbContext.TphCustomers.Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows +s affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference fro +om the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The customers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database sh +hould match the sum of customers added and updated."); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TptCustomers.Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tpt_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TptCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tpt_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkSync(customers, options => { options.MergeOnCondition = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TptPeople.Where(o => o.FirstName == "BulkSync_Tpt_Add").OfType().Cou +unt(); + int customersUpdated = dbContext.TptPeople.Where(o => o.FirstName == "BulkSync_Tpt_Update").OfType( +().Count(); + int newCustomerTotal = dbContext.TptPeople.OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows +s affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference fro +om the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database sh +hould match the sum of customers added and updated."); + } + [TestMethod] + public void With_Options_AutoMapIdentity() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = dbContext.BulkSync(orders, options => { options.MergeOnCondition = (s, t) => s.ExternalId == t.Exte +ernalId; options.UsePermanentTable = true; }); + bool autoMapIdentityMatched = true; + foreach (var order in orders) + { + if (!dbContext.Orders.Any(o => o.ExternalId == order.ExternalId && o.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count o +of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must ma +atch the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up") +); + } + [TestMethod] + public void With_Options_MergeOnCondition() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 50; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = dbContext.BulkSync(orders, new BulkSyncOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + BatchSize = 1000 + }); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count o +of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must ma +atch the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\BulkSyncAsync.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkSyncAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = await dbContext.BulkSyncAsync(orders); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 10000).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count o +of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must ma +atch the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = await dbContext.TpcPeople.Where(o => o.Id <= 1000).OfType().ToListAsync(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TpcPeople.OfType().Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tpc_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TpcCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tpc_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkSyncAsync(customers, options => { options.MergeOnCondition = (s, t) => s.Id == + t.Id; }); + int customersAdded = dbContext.TpcPeople.Where(o => o.FirstName == "BulkSync_Tpc_Add").OfType().Cou +unt(); + int customersUpdated = dbContext.TpcPeople.Where(o => o.FirstName == "BulkSync_Tpc_Update").OfType( +().Count(); + int newCustomerTotal = dbContext.TpcPeople.OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows +s affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference fro +om the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database sh +hould match the sum of customers added and updated."); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = await dbContext.TphCustomers.Where(o => o.Id <= 1000).ToListAsync(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TphPeople.Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tph_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TphCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tph_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkSyncAsync(customers, options => { options.UsePermanentTable = true; options.Mer +rgeOnCondition = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TphCustomers.Where(o => o.FirstName == "BulkSync_Tph_Add").Count(); + int customersUpdated = dbContext.TphCustomers.Where(o => o.FirstName == "BulkSync_Tph_Update").Count(); + int newCustomerTotal = dbContext.TphCustomers.Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows +s affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference fro +om the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The customers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database sh +hould match the sum of customers added and updated."); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = await dbContext.TptPeople.Where(o => o.Id <= 1000).OfType().ToListAsync(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TptCustomers.Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tpt_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TptCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tpt_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkSyncAsync(customers, options => { options.MergeOnCondition = (s, t) => s.Id == + t.Id; }); + int customersAdded = dbContext.TptPeople.Where(o => o.FirstName == "BulkSync_Tpt_Add").OfType().Cou +unt(); + int customersUpdated = dbContext.TptPeople.Where(o => o.FirstName == "BulkSync_Tpt_Update").OfType( +().Count(); + int newCustomerTotal = dbContext.TptPeople.OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows +s affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference fro +om the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database sh +hould match the sum of customers added and updated."); + } + [TestMethod] + public async Task With_Options_AutoMapIdentity() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = await dbContext.BulkSyncAsync(orders, options => { options.MergeOnCondition = (s, t) => s.ExternalI +Id == t.ExternalId; options.UsePermanentTable = true; }); + bool autoMapIdentityMatched = true; + foreach (var order in orders) + { + if (!dbContext.Orders.Any(o => o.ExternalId == order.ExternalId && o.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count o +of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must ma +atch the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up") +); + } + [TestMethod] + public async Task With_Options_MergeOnCondition() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 50; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = await dbContext.BulkSyncAsync(orders, new BulkSyncOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + BatchSize = 1000 + }); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count o +of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must ma +atch the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\BulkUpdate.cs --- + +using System.Linq; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkUpdate : DbContextExtensionsBase +{ + [TestMethod] + public void With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.ProductsWithComplexKey.Where(o => o.Price == 1.25M).ToList(); + foreach (var product in products) + { + product.Price = 2.35M; + } + var oldTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price == 2.35M).Count(); + int rowsUpdated = dbContext.BulkUpdate(products); + var newTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price == 2.35M).Count(); + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == products.Count, "The number of rows updated must match the count of entities that w +were retrieved"); + Assert.IsTrue(newTotal == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upda +ated in the database."); + } + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + long maxId = 0; + foreach (var order in orders) + { + order.Price = 2.35M; + maxId = order.Id; + } + int rowsUpdated = dbContext.BulkUpdate(orders); + var newOrders = dbContext.Orders.Where(o => o.Price == 2.35M).OrderBy(o => o.Id).Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == orders.Count, "The number of rows updated must match the count of entities that wer +re retrieved"); + Assert.IsTrue(newOrders == rowsUpdated, "The count of new orders must be equal the number of rows updated in the +e database."); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = dbContext.TpcPeople.Where(o => o.LastName != "BulkUpdate_Tpc").OfType().ToList(); + var vendors = dbContext.TpcPeople.OfType().ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdate_Tpc"; + } + int rowsUpdated = dbContext.BulkUpdate(customers); + var newCustomers = dbContext.TpcPeople.Where(o => o.LastName == "BulkUpdate_Tpc").OfType().Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(vendors.Count > 0 && vendors.Count != customers.Count, "There should be vendor records in the data +abase"); + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25 +5)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that + were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated + in the database."); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = dbContext.TphPeople.Where(o => o.LastName != "BulkUpdateTest").OfType().ToList(); + var vendors = dbContext.TphPeople.OfType().ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdateTest"; + } + int rowsUpdated = dbContext.BulkUpdate(customers); + var newCustomers = dbContext.TphPeople.Where(o => o.LastName == "BulkUpdateTest").OrderBy(o => o.Id).Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(vendors.Count > 0 && vendors.Count != customers.Count, "There should be vendor records in the data +abase"); + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25 +5)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that + were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated + in the database."); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptCustomers.Where(o => o.LastName != "BulkUpdateTest").ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdateTest"; + } + int rowsUpdated = dbContext.BulkUpdate(customers); + var newCustomers = dbContext.TptCustomers.Where(o => o.LastName == "BulkUpdateTest").OrderBy(o => o.Id).Count(); + //int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count( +(); + + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25 +5)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that + were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated + in the database."); + } + [TestMethod] + public void With_Options_InputColumns_PropertyExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = dbContext.BulkUpdate(orders, options => { options.InputColumns = o => o.Price; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upd +dated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public void With_Options_InputColumns_NewExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = dbContext.BulkUpdate(orders, options => { options.InputColumns = o => new { o.Price }; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upd +dated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public void With_Options_IgnoreColumns_PropertyExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = dbContext.BulkUpdate(orders, options => { options.IgnoreColumns = o => o.ExternalId; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upd +dated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public void With_Options_IgnoreColumns_NewExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = dbContext.BulkUpdate(orders, options => { options.IgnoreColumns = o => new { o.ExternalId }; } +}); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upd +dated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public void With_Options_UpdateOnCondition() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + int ordersWithExternalId = orders.Where(o => o.ExternalId != null).Count(); + foreach (var order in orders) + { + order.Price = 2.35M; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = dbContext.BulkUpdate(orders, options => { options.UpdateOnCondition = (s, t) => s.ExternalId = +== t.ExternalId; }); + var newTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == ordersWithExternalId, "The number of rows updated must match the count of entities + that were retrieved"); + Assert.IsTrue(newTotal == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upda +ated in the database."); + } + [TestMethod] + public void With_Options_UpdateOnCondition_Enum() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + foreach (var product in products) + { + product.Price = 2.35M; + } + int rowsUpdated = dbContext.BulkUpdate(products, o => + { + o.UpdateOnCondition = (s, t) => s.Id == t.Id && s.StatusEnum == t.StatusEnum; + }); + var newProducts = dbContext.Products.Where(o => o.Price == 2.35M).OrderBy(o => o.Id).Count(); + + Assert.IsTrue(products.Count > 0, "There must be products in database that match this condition (Price = $1.25)" +"); + Assert.IsTrue(rowsUpdated == products.Count, "The number of rows updated must match the count of entities that w +were retrieved"); + Assert.IsTrue(newProducts == rowsUpdated, "The count of new products must be equal the number of rows updated in +n the database."); + } + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + long maxId = 0; + foreach (var order in orders) + { + order.Price = 2.35M; + maxId = order.Id; + } + int rowsUpdated, newOrders; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsUpdated = dbContext.BulkUpdate(orders); + newOrders = dbContext.Orders.Where(o => o.Price == 2.35M).Count(); + transaction.Rollback(); + } + int rollbackTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == orders.Count, "The number of rows updated must match the count of entities that wer +re retrieved"); + Assert.IsTrue(newOrders == rowsUpdated, "The count of new orders must be equal the number of rows updated in the +e database."); + Assert.IsTrue(rollbackTotal == orders.Count, "The number of rows after the transacation has been rollbacked shou +uld match the original count"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\BulkUpdateAsync.cs --- + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkUpdateAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.ProductsWithComplexKey.Where(o => o.Price == 1.25M).ToList(); + foreach (var product in products) + { + product.Price = 2.35M; + } + var oldTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price == 2.35M).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(products); + var newTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price == 2.35M).Count(); + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == products.Count, "The number of rows updated must match the count of entities that w +were retrieved"); + Assert.IsTrue(newTotal == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upda +ated in the database."); + } + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + long maxId = 0; + foreach (var order in orders) + { + order.Price = 2.35M; + maxId = order.Id; + } + int rowsUpdated = await dbContext.BulkUpdateAsync(orders); + var newOrders = dbContext.Orders.Where(o => o.Price == 2.35M).OrderBy(o => o.Id).Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == orders.Count, "The number of rows updated must match the count of entities that wer +re retrieved"); + Assert.IsTrue(newOrders == rowsUpdated, "The count of new orders must be equal the number of rows updated in the +e database."); + } + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = dbContext.TpcPeople.Where(o => o.LastName != "BulkUpdateTest").OfType().ToList(); + var vendors = dbContext.TpcPeople.OfType().ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdate_Tpc"; + } + int rowsUpdated = await dbContext.BulkUpdateAsync(customers, options => { options.UpdateOnCondition = (s, t) => + s.Id == t.Id; }); + var newCustomers = dbContext.TpcPeople.Where(o => o.LastName == "BulkUpdate_Tpc").OfType().Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(vendors.Count > 0 && vendors.Count != customers.Count, "There should be vendor records in the data +abase"); + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25 +5)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that + were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated + in the database."); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = dbContext.TphPeople.Where(o => o.LastName != "BulkUpdateTest").OfType().ToList(); + var vendors = dbContext.TphPeople.OfType().ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdateTest"; + } + int rowsUpdated = await dbContext.BulkUpdateAsync(customers); + var newCustomers = dbContext.TphPeople.Where(o => o.LastName == "BulkUpdateTest").OrderBy(o => o.Id).Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(vendors.Count > 0 && vendors.Count != customers.Count, "There should be vendor records in the data +abase"); + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25 +5)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that + were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated + in the database."); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptCustomers.Where(o => o.LastName != "BulkUpdateTest").ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdateTest"; + } + int rowsUpdated = await dbContext.BulkUpdateAsync(customers); + var newCustomers = await dbContext.TptCustomers.Where(o => o.LastName == "BulkUpdateTest").CountAsync(); + + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25 +5)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that + were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated + in the database."); + } + [TestMethod] + public async Task With_Options_InputColumns_PropertyExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(orders, options => { options.InputColumns = o => o.Price; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upd +dated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public async Task With_Options_InputColumns_NewExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(orders, options => { options.InputColumns = o => new { o.Price +e }; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count() > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upd +dated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public async Task With_Options_IgnoreColumns_PropertyExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(orders, options => { options.IgnoreColumns = o => o.ExternalId +d; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upd +dated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public async Task With_Options_IgnoreColumns_NewExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(orders, options => { options.IgnoreColumns = o => new { o.Exte +ernalId }; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upd +dated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public async Task With_Options_UpdateOnCondition() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + int ordersWithExternalId = orders.Where(o => o.ExternalId != null).Count(); + foreach (var order in orders) + { + order.Price = 2.35M; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(orders, options => { options.UpdateOnCondition = (s, t) => s.E +ExternalId == t.ExternalId; }); + var newTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == ordersWithExternalId, "The number of rows updated must match the count of entities + that were retrieved"); + Assert.IsTrue(newTotal == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upda +ated in the database."); + } + [TestMethod] + public async Task With_Options_UpdateOnCondition_Enum() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + foreach (var product in products) + { + product.Price = 2.35M; + } + int rowsUpdated = await dbContext.BulkUpdateAsync(products, o => + { + o.UpdateOnCondition = (s, t) => s.Id == t.Id && s.StatusEnum == t.StatusEnum; + }); + var newProducts = dbContext.Products.Where(o => o.Price == 2.35M).OrderBy(o => o.Id).Count(); + + Assert.IsTrue(products.Count > 0, "There must be products in database that match this condition (Price = $1.25)" +"); + Assert.IsTrue(rowsUpdated == products.Count, "The number of rows updated must match the count of entities that w +were retrieved"); + Assert.IsTrue(newProducts == rowsUpdated, "The count of new products must be equal the number of rows updated in +n the database."); + } + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + long maxId = 0; + foreach (var order in orders) + { + order.Price = 2.35M; + maxId = order.Id; + } + int rowsUpdated, newOrders; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsUpdated = await dbContext.BulkUpdateAsync(orders); + newOrders = dbContext.Orders.Where(o => o.Price == 2.35M).Count(); + transaction.Rollback(); + } + int rollbackTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == orders.Count, "The number of rows updated must match the count of entities that wer +re retrieved"); + Assert.IsTrue(newOrders == rowsUpdated, "The count of new orders must be equal the number of rows updated in the +e database."); + Assert.IsTrue(rollbackTotal == orders.Count, "The number of rows after the transacation has been rollbacked shou +uld match the original count"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\DbContextExtensionsBase.cs --- + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Drawing; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; +using N.EntityFrameworkCore.Extensions.Test.Data; +using N.EntityFrameworkCore.Extensions.Test.Data.Enums; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +public enum PopulateDataMode +{ + Normal, + Tpc, + Tph, + Tpt, + Schema +} +[TestClass] +public class DbContextExtensionsBase +{ + [TestInitialize] + public void Init() + { + TestDbContext dbContext = new TestDbContext(); + TestDatabaseInitializer.EnsureCreated(dbContext); + } + protected TestDbContext SetupDbContext(bool populateData, PopulateDataMode mode = PopulateDataMode.Normal) + { + TestDbContext dbContext = new TestDbContext(); + TestDatabaseInitializer.EnsureCreated(dbContext); + dbContext.Orders.Truncate(); + dbContext.Products.Truncate(); + dbContext.ProductCategories.Clear(); + dbContext.ProductsWithCustomSchema.Truncate(); + dbContext.ProductsWithTrigger.Truncate(); + dbContext.Database.ClearTable("TpcCustomer"); + dbContext.Database.ClearTable("TpcVendor"); + dbContext.TphPeople.Truncate(); + dbContext.Database.ClearTable("TptPeople"); + dbContext.Database.ClearTable("TptCustomer"); + dbContext.Database.ClearTable("TptVendor"); + dbContext.Database.DropTable("ProductsUnderTen", true); + dbContext.Database.DropTable("OrdersUnderTen", true); + dbContext.Database.DropTable("OrdersLast30Days", true); + if (populateData) + { + if (mode == PopulateDataMode.Normal) + { + var orders = new List(); + int id = 1; + for (int i = 0; i < 2050; i++) + { + DateTime addedDateTime = DateTime.UtcNow.AddDays(-id); + orders.Add(new Order + { + Id = id, + ExternalId = string.Format("id-{0}", i), + Price = 1.25M, + AddedDateTime = addedDateTime, + ModifiedDateTime = addedDateTime.AddHours(3), + Status = OrderStatus.Completed + }); + id++; + } + for (int i = 0; i < 1050; i++) + { + orders.Add(new Order { Id = id, Price = 5.35M }); + id++; + } + for (int i = 0; i < 2050; i++) + { + orders.Add(new Order { Id = id, Price = 1.25M }); + id++; + } + for (int i = 0; i < 6000; i++) + { + orders.Add(new Order { Id = id, Price = 15.35M }); + id++; + } + for (int i = 0; i < 6000; i++) + { + orders.Add(new Order { Id = id, Price = 15.35M }); + id++; + } + + Debug.WriteLine("Last Id for Order is {0}", id); + dbContext.BulkInsert(orders, new BulkInsertOptions() { KeepIdentity = true }); + + var productCategories = new List() + { + new ProductCategory { Id=1, Name="Category-1", Active=true}, + new ProductCategory { Id=2, Name="Category-2", Active=true}, + new ProductCategory { Id=3, Name="Category-3", Active=true}, + new ProductCategory { Id=4, Name="Category-4", Active=false}, + }; + dbContext.BulkInsert(productCategories, o => { o.KeepIdentity = true; o.UsePermanentTable = true; }); + var products = new List(); + id = 1; + for (int i = 0; i < 2050; i++) + { + products.Add(new Product + { + Id = i.ToString(), + Price = 1.25M, + OutOfStock = false, + ProductCategoryId = 4, + StatusEnum = ProductStatus.InStock, + Color = Color.Black, + Position = new Position { Building = 5, Aisle = 33, Bay = i }, + }); + id++; + } + for (int i = 2050; i < 7000; i++) + { + products.Add(new Product { Id = i.ToString(), Price = 1.25M, OutOfStock = true, StatusEnum = Product +tStatus.OutOfStock }); + id++; + } + + Debug.WriteLine("Last Id for Product is {0}", id); + dbContext.BulkInsert(products, new BulkInsertOptions() { KeepIdentity = false, AutoMapOutput = + false, UsePermanentTable = true }); + + //ProductWithComplexKey + var productsWithComplexKey = new List(); + id = 1; + + for (int i = 0; i < 2050; i++) + { + productsWithComplexKey.Add(new ProductWithComplexKey { Price = 1.25M }); + id++; + } + + Debug.WriteLine("Last Id for ProductsWithComplexKey is {0}", id); + dbContext.BulkInsert(productsWithComplexKey, new BulkInsertOptions() { KeepIdenti +ity = false, AutoMapOutput = false }); + } + else if (mode == PopulateDataMode.Tph) + { + //TPH Customers & Vendors + var tphCustomers = new List(); + var tphVendors = new List(); + for (int i = 0; i < 2000; i++) + { + tphCustomers.Add(new TphCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "404-555-1111", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 2000; i < 3000; i++) + { + tphVendors.Add(new TphVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Phone = "404-555-2222", + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + dbContext.BulkInsert(tphCustomers, new BulkInsertOptions() { KeepIdentity = true }); + dbContext.BulkInsert(tphVendors, new BulkInsertOptions() { KeepIdentity = true }); + } + else if (mode == PopulateDataMode.Tpc) + { + //TPC Customers & Vendors + var tpcCustomers = new List(); + var tpcVendors = new List(); + for (int i = 1; i <= 2000; i++) + { + tpcCustomers.Add(new TpcCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "404-555-1111", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 2001; i <= 3000; i++) + { + tpcVendors.Add(new TpcVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Phone = "404-555-2222", + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + dbContext.BulkInsert(tpcCustomers, new BulkInsertOptions() { KeepIdentity = true }); + dbContext.BulkInsert(tpcVendors, new BulkInsertOptions() { KeepIdentity = true }); + } + else if (mode == PopulateDataMode.Tpt) + { + //Customers & Vendors + var tptCustomers = new List(); + var tptVendors = new List(); + for (int i = 1; i <= 2000; i++) + { + tptCustomers.Add(new TptCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "404-555-1111", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 2001; i < 3000; i++) + { + tptVendors.Add(new TptVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Phone = "404-555-2222", + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + dbContext.BulkInsert(tptCustomers, new BulkInsertOptions() { KeepIdentity = true, UsePerman +nentTable = true }); + dbContext.BulkInsert(tptVendors, new BulkInsertOptions() { KeepIdentity = true }); + } + else if (mode == PopulateDataMode.Schema) + { + //ProductWithCustomSchema + var productsWithCustomSchema = new List(); + int id = 1; + + for (int i = 0; i < 2050; i++) + { + productsWithCustomSchema.Add(new ProductWithCustomSchema { Id = id.ToString(), Price = 1.25M }); + id++; + } + for (int i = 2050; i < 5000; i++) + { + productsWithCustomSchema.Add(new ProductWithCustomSchema { Id = id.ToString(), Price = 6.75M }); + id++; + } + + dbContext.BulkInsert(productsWithCustomSchema); + } + } + return dbContext; + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\DeleteFromQuery.cs --- + +using System; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class DeleteFromQuery : DbContextExtensionsBase +{ + [TestMethod] + public void With_Boolean_Value() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(p => p.OutOfStock); + int oldTotal = products.Count(a => a.OutOfStock); + int rowUpdated = products.DeleteFromQuery(); + int newTotal = dbContext.Products.Count(o => o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (OutOfStock == true)") +); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Child_Relationship() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(p => !p.ProductCategory.Active); + int oldTotal = products.Count(); + int rowsDeleted = products.DeleteFromQuery(); + int newTotal = products.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (ProductCategory.Activ +ve == false)"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows update must match the count of rows that match the co +ondition (ProductCategory.Active == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Decimal_Using_IQueryable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int oldTotal = orders.Count(); + int rowsDeleted = orders.DeleteFromQuery(); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "Delete() Failed: must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Decimal_Using_IEnumerable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int oldTotal = orders.Count(); + int rowsDeleted = orders.DeleteFromQuery(); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_DateTime() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + int rowsToDelete = dbContext.Orders.Where(o => o.ModifiedDateTime != null && o.ModifiedDateTime >= dateTime).Cou +unt(); + int rowsDeleted = dbContext.Orders.Where(o => o.ModifiedDateTime != null && o.ModifiedDateTime >= dateTime) + .DeleteFromQuery(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == rowsToDelete, "The number of rows deleted must match the count of the rows that mat +tched in the database"); + Assert.IsTrue(oldTotal - newTotal == rowsDeleted, "The rows deleted must match the new count minues the old coun +nt"); + } + [TestMethod] + public void With_Delete_All() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + int rowsDeleted = dbContext.Orders.DeleteFromQuery(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Different_Values() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.Id == 1 && o.Active && o.ModifiedDateTime >= dateTime); + int rowsToDelete = orders.Count(); + int rowsDeleted = orders.DeleteFromQuery(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == rowsToDelete, "The number of rows deleted must match the count of the rows that mat +tched in the database"); + Assert.IsTrue(oldTotal - newTotal == rowsDeleted, "The rows deleted must match the new count minues the old coun +nt"); + } + [TestMethod] + public void With_Empty_List() + { + var dbContext = SetupDbContext(false); + int oldTotal = dbContext.Orders.Count(); + int rowsDeleted = dbContext.Orders.DeleteFromQuery(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal == 0, "There must be no orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + int oldTotal = dbContext.ProductsWithCustomSchema.Count(); + int rowsDeleted = dbContext.ProductsWithCustomSchema.DeleteFromQuery(); + int newTotal = dbContext.ProductsWithCustomSchema.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + int rowsDeleted; + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int rowsToDelete = orders.Count(); + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsDeleted = orders.DeleteFromQuery(); + transaction.Rollback(); + } + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowsDeleted == orders.Count(), "The number of rows update must match the count of rows that match + the condtion (Price < $10)"); + Assert.IsTrue(newTotal == oldTotal, "The new count must match the old count since the transaction was rollbacked +d"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\DeleteFromQueryAsync.cs --- + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class DeleteFromQueryAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Boolean_Value() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(p => p.OutOfStock); + int oldTotal = products.Count(a => a.OutOfStock); + int rowUpdated = await products.DeleteFromQueryAsync(); + int newTotal = dbContext.Products.Count(o => o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (OutOfStock == true)") +); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Child_Relationship() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(p => !p.ProductCategory.Active); + int oldTotal = products.Count(); + int rowsDeleted = await products.DeleteFromQueryAsync(); + int newTotal = products.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (ProductCategory.Activ +ve == false)"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows update must match the count of rows that match the co +ondition (ProductCategory.Active == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Decimal_Using_IQueryable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int oldTotal = orders.Count(); + int rowsDeleted = await orders.DeleteFromQueryAsync(); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "Delete() Failed: must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Decimal_Using_IEnumerable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int oldTotal = orders.Count(); + int rowsDeleted = await orders.DeleteFromQueryAsync(); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_DateTime() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + int rowsToDelete = dbContext.Orders.Where(o => o.ModifiedDateTime != null && o.ModifiedDateTime >= dateTime).Cou +unt(); + int rowsDeleted = await dbContext.Orders.Where(o => o.ModifiedDateTime != null && o.ModifiedDateTime >= dateTime +e) + .DeleteFromQueryAsync(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == rowsToDelete, "The number of rows deleted must match the count of the rows that mat +tched in the database"); + Assert.IsTrue(oldTotal - newTotal == rowsDeleted, "The rows deleted must match the new count minues the old coun +nt"); + } + [TestMethod] + public async Task With_Delete_All() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + int rowsDeleted = await dbContext.Orders.DeleteFromQueryAsync(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Different_Values() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.Id == 1 && o.Active && o.ModifiedDateTime >= dateTime); + int rowsToDelete = orders.Count(); + int rowsDeleted = await orders.DeleteFromQueryAsync(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == rowsToDelete, "The number of rows deleted must match the count of the rows that mat +tched in the database"); + Assert.IsTrue(oldTotal - newTotal == rowsDeleted, "The rows deleted must match the new count minues the old coun +nt"); + } + [TestMethod] + public async Task With_Empty_List() + { + var dbContext = SetupDbContext(false); + int oldTotal = dbContext.Orders.Count(); + int rowsDeleted = await dbContext.Orders.DeleteFromQueryAsync(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal == 0, "There must be no orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + int oldTotal = dbContext.ProductsWithCustomSchema.Count(); + int rowsDeleted = await dbContext.ProductsWithCustomSchema.DeleteFromQueryAsync(); + int newTotal = dbContext.ProductsWithCustomSchema.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + int rowsDeleted; + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int rowsToDelete = orders.Count(); + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsDeleted = await orders.DeleteFromQueryAsync(); + transaction.Rollback(); + } + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowsDeleted == orders.Count(), "The number of rows update must match the count of rows that match + the condtion (Price < $10)"); + Assert.IsTrue(newTotal == oldTotal, "The new count must match the old count since the transaction was rollbacked +d"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\Fetch.cs --- + +using System; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class Fetch : DbContextExtensionsBase +{ + [TestMethod] + public void With_BulkInsert() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int totalOrdersToFetch = orders.Count(); + int totalOrdersFetched = 0; + int batchSize = 5000; + orders.Fetch(result => + { + totalOrdersFetched += result.Results.Count(); + var ordersFetched = result.Results; + foreach (var orderFetched in ordersFetched) + { + orderFetched.Price = 75; + } + dbContext.BulkInsert(ordersFetched); + }, options => { options.BatchSize = batchSize; }); + + int totalOrder = orders.Count(); + int totalOrderInserted = orders.Where(o => o.Price == 75).Count(); + Assert.IsTrue(totalOrdersToFetch == totalOrdersFetched, "The total number of rows fetched must match the number + of rows to fetch"); + Assert.IsTrue(totalOrderInserted == totalOrdersFetched, "The total number of rows updated must match the number + of rows that were fetched"); + Assert.IsTrue(totalOrder - totalOrdersToFetch == totalOrderInserted, "The total number of rows must match the nu +umber of rows that were updated"); + } + [TestMethod] + public void With_BulkUpdate() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int totalOrdersToFetch = orders.Count(); + int totalOrdersFetched = 0; + int batchSize = 5000; + orders.Fetch(result => + { + totalOrdersFetched += result.Results.Count(); + var ordersFetched = result.Results; + foreach (var orderFetched in ordersFetched) + { + orderFetched.Price = 75; + } + dbContext.BulkUpdate(ordersFetched); + }, options => { options.BatchSize = batchSize; }); + + int totalOrder = orders.Count(); + int totalOrderUpdated = orders.Where(o => o.Price == 75).Count(); + Assert.IsTrue(totalOrdersToFetch == totalOrdersFetched, "The total number of rows fetched must match the number + of rows to fetch"); + Assert.IsTrue(totalOrderUpdated == totalOrdersFetched, "The total number of rows updated must match the number o +of rows that were fetched"); + Assert.IsTrue(totalOrder == totalOrderUpdated, "The total number of rows must match the number of rows that were +e updated"); + } + [TestMethod] + public void With_DateTime() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + orders.Fetch(result => + { + batchCount++; + totalCount += result.Results.Count(); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less th +han or equal to the batchSize"); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } + [TestMethod] + public void With_Decimal() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + orders.Fetch(result => + { + batchCount++; + totalCount += result.Results.Count(); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less th +han or equal to the batchSize"); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } + [TestMethod] + public void With_Enum() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var products = dbContext.Products.Where(o => o.Price < 10M); + int expectedTotalCount = products.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + products.Fetch(result => + { + batchCount++; + totalCount += result.Results.Count(); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less th +han or equal to the batchSize"); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be products in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } + [TestMethod] + public void With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + orders.Fetch(result => + { + batchCount++; + totalCount += result.Results.Count(); + bool isAllExternalIdNull = !result.Results.Any(o => o.ExternalId != null); + Assert.IsTrue(isAllExternalIdNull, "All records should have ExternalId equal to NULL since it was not loaded +d."); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less th +han or equal to the batchSize"); + }, options => { options.BatchSize = batchSize; options.IgnoreColumns = s => new { s.ExternalId }; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } + [TestMethod] + public void With_Options_InputColumns() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + orders.Fetch(result => + { + batchCount++; + totalCount += result.Results.Count(); + bool isAllExternalIdNull = !result.Results.Any(o => o.ExternalId != null); + Assert.IsTrue(isAllExternalIdNull, "All records should have ExternalId equal to NULL since it was not loaded +d."); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less th +han or equal to the batchSize"); + }, options => { options.BatchSize = batchSize; options.InputColumns = s => new { s.Id, s.Price }; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\FetchAsync.cs --- + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class FetchAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_BulkInsert() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int totalOrdersToFetch = orders.Count(); + int totalOrdersFetched = 0; + int batchSize = 5000; + await orders.FetchAsync(async result => + { + totalOrdersFetched += result.Results.Count; + var ordersFetched = result.Results; + foreach (var orderFetched in ordersFetched) + { + orderFetched.Price = 75; + } + await dbContext.BulkInsertAsync(ordersFetched); + }, options => { options.BatchSize = batchSize; }); + + int totalOrder = orders.Count(); + int totalOrderInserted = orders.Where(o => o.Price == 75).Count(); + Assert.IsTrue(totalOrdersToFetch == totalOrdersFetched, "The total number of rows fetched must match the number + of rows to fetch"); + Assert.IsTrue(totalOrderInserted == totalOrdersFetched, "The total number of rows updated must match the number + of rows that were fetched"); + Assert.IsTrue(totalOrder - totalOrdersToFetch == totalOrderInserted, "The total number of rows must match the nu +umber of rows that were updated"); + } + [TestMethod] + public async Task With_BulkUpdate() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int totalOrdersToFetch = orders.Count(); + int totalOrdersFetched = 0; + int batchSize = 5000; + await orders.FetchAsync(async result => + { + totalOrdersFetched += result.Results.Count; + var ordersFetched = result.Results; + foreach (var orderFetched in ordersFetched) + { + orderFetched.Price = 75; + } + await dbContext.BulkUpdateAsync(ordersFetched); + }, options => { options.BatchSize = batchSize; }); + + int totalOrder = orders.Count(); + int totalOrderUpdated = orders.Where(o => o.Price == 75).Count(); + Assert.IsTrue(totalOrdersToFetch == totalOrdersFetched, "The total number of rows fetched must match the number + of rows to fetch"); + Assert.IsTrue(totalOrderUpdated == totalOrdersFetched, "The total number of rows updated must match the number o +of rows that were fetched"); + Assert.IsTrue(totalOrder == totalOrderUpdated, "The total number of rows must match the number of rows that were +e updated"); + } + [TestMethod] + public async Task With_DateTime() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + await orders.FetchAsync(async result => + { + await Task.Run(() => + { + batchCount++; + totalCount += result.Results.Count; + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should les +ss than or equal to the batchSize"); + }); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } + [TestMethod] + public async Task With_Decimal() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + await orders.FetchAsync(async result => + { + await Task.Run(() => + { + batchCount++; + totalCount += result.Results.Count(); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should les +ss than or equal to the batchSize"); + }); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } + [TestMethod] + public async Task With_Enum() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var products = dbContext.Products.Where(o => o.Price < 10M); + int expectedTotalCount = products.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + await products.FetchAsync(async result => + { + await Task.Run(() => + { + batchCount++; + totalCount += result.Results.Count(); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should les +ss than or equal to the batchSize"); + }); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be products in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } + [TestMethod] + public async Task With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + await orders.FetchAsync(async result => + { + await Task.Run(() => + { + batchCount++; + totalCount += result.Results.Count; + bool isAllExternalIdNull = !result.Results.Any(o => o.ExternalId != null); + Assert.IsTrue(isAllExternalIdNull, "All records should have ExternalId equal to NULL since it was not lo +oaded."); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should les +ss than or equal to the batchSize"); + }); + }, options => { options.BatchSize = batchSize; options.IgnoreColumns = s => new { s.ExternalId }; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } + [TestMethod] + public async Task With_Options_InputColumns() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + await orders.FetchAsync(async result => + { + await Task.Run(() => + { + batchCount++; + totalCount += result.Results.Count(); + bool isAllExternalIdNull = !result.Results.Any(o => o.ExternalId != null); + Assert.IsTrue(isAllExternalIdNull, "All records should have ExternalId equal to NULL since it was not lo +oaded."); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should les +ss than or equal to the batchSize"); + }); + }, options => { options.BatchSize = batchSize; options.InputColumns = s => new { s.Id, s.Price }; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\InsertFromQuery.cs --- + +using System; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class InsertFromQuery : DbContextExtensionsBase +{ + [TestMethod] + public void With_DateTime_Value() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersLast30Days"; + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + int oldTotal = dbContext.Orders.Count(); + + var orders = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime); + int oldSourceTotal = orders.Count(); + int rowsInserted = orders.InsertFromQuery(tableName, + o => new { o.Id, o.ExternalId, o.Price, o.AddedDateTime, o.ModifiedDateTime, o.Active }); + int newSourceTotal = orders.Count(); + int newTargetTotal = orders.UsingTable(tableName).Count(); + + Assert.IsTrue(oldTotal > oldSourceTotal, "The total should be greater then the number of rows selected from the + source table"); + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the sour +rce table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the sourc +ce table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the i +insert must match the total row inserted"); + } + [TestMethod] + public void With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersUnderTen"; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldSourceTotal = orders.Count(); + int rowsInserted = dbContext.Orders.Where(o => o.Price < 10M).InsertFromQuery(tableName, o => new { o.Id, o.Pric +ce, o.AddedDateTime, o.Active }); + int newSourceTotal = orders.Count(); + int newTargetTotal = orders.UsingTable(tableName).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the sour +rce table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the sourc +ce table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the i +insert must match the total row inserted"); + } + [TestMethod] + public void With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + string tableName = "ProductsUnderTen"; + var products = dbContext.ProductsWithCustomSchema.Where(o => o.Price < 10M); + int oldSourceTotal = products.Count(); + int rowsInserted = dbContext.ProductsWithCustomSchema.Where(o => o.Price < 10M).InsertFromQuery(tableName, o => + new { o.Id, o.Price }); + int newSourceTotal = products.Count(); + int newTargetTotal = products.UsingTable(tableName).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the sour +rce table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the sourc +ce table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the i +insert must match the total row inserted"); + } + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersUnderTen"; + int rowsInserted; + bool tableExistsBefore, tableExistsAfter; + int oldSourceTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsInserted = dbContext.Orders.Where(o => o.Price < 10M).InsertFromQuery(tableName, o => new { o.Price, o.I +Id, o.AddedDateTime, o.Active }); + tableExistsBefore = dbContext.Database.TableExists(tableName); + transaction.Rollback(); + } + tableExistsAfter = dbContext.Database.TableExists(tableName); + int newSourceTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of rows update must match the count of rows that match +h the condtion (Price < $10)"); + Assert.IsTrue(newSourceTotal == oldSourceTotal, "The new count must match the old count since the transaction wa +as rollbacked"); + Assert.IsTrue(tableExistsBefore, string.Format("Table {0} should exist before transaction rollback", tableName)) +); + Assert.IsFalse(tableExistsAfter, string.Format("Table {0} should not exist after transaction rollback", tableNam +me)); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\InsertFromQueryAsync.cs --- + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class InsertFromQueryAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_DateTime_Value() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersLast30Days"; + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + int oldTotal = dbContext.Orders.Count(); + + var orders = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime); + int oldSourceTotal = orders.Count(); + int rowsInserted = await orders.InsertFromQueryAsync(tableName, + o => new { o.Id, o.ExternalId, o.Price, o.AddedDateTime, o.ModifiedDateTime, o.Active }); + int newSourceTotal = orders.Count(); + int newTargetTotal = orders.UsingTable(tableName).Count(); + + Assert.IsTrue(oldTotal > oldSourceTotal, "The total should be greater then the number of rows selected from the + source table"); + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the sour +rce table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the sourc +ce table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the i +insert must match the total row inserted"); + } + [TestMethod] + public async Task With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersUnderTen"; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldSourceTotal = orders.Count(); + int rowsInserted = await dbContext.Orders.Where(o => o.Price < 10M).InsertFromQueryAsync(tableName, o => new { o +o.Id, o.Price, o.AddedDateTime, o.Active }); + int newSourceTotal = orders.Count(); + int newTargetTotal = orders.UsingTable(tableName).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the sour +rce table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the sourc +ce table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the i +insert must match the total row inserted"); + } + [TestMethod] + public async Task With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + string tableName = "ProductsUnderTen"; + var products = dbContext.ProductsWithCustomSchema.Where(o => o.Price < 10M); + int oldSourceTotal = products.Count(); + int rowsInserted = await dbContext.ProductsWithCustomSchema.Where(o => o.Price < 10M).InsertFromQueryAsync(table +eName, o => new { o.Id, o.Price }); + int newSourceTotal = products.Count(); + int newTargetTotal = products.UsingTable(tableName).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the sour +rce table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the sourc +ce table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the i +insert must match the total row inserted"); + } + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersUnderTen"; + int rowsInserted; + bool tableExistsBefore, tableExistsAfter; + int oldSourceTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsInserted = await dbContext.Orders.Where(o => o.Price < 10M).InsertFromQueryAsync(tableName, o => new { o +o.Price, o.Id, o.AddedDateTime, o.Active }); + tableExistsBefore = dbContext.Database.TableExists(tableName); + transaction.Rollback(); + } + tableExistsAfter = dbContext.Database.TableExists(tableName); + int newSourceTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of rows update must match the count of rows that match +h the condtion (Price < $10)"); + Assert.IsTrue(newSourceTotal == oldSourceTotal, "The new count must match the old count since the transaction wa +as rollbacked"); + Assert.IsTrue(tableExistsBefore, string.Format("Table {0} should exist before transaction rollback", tableName)) +); + Assert.IsFalse(tableExistsAfter, string.Format("Table {0} should not exist after transaction rollback", tableNam +me)); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\QueryToCsvFile.cs --- + +using System.IO; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class QueryToCsvFile : DbContextExtensionsBase +{ + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var queryToCsvFileResult = query.QueryToCsvFile("QueryToCsvFile-Test.csv"); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file sho +ould match the count from the database plus the header row"); + } + [TestMethod] + public void With_Options_ColumnDelimiter_TextQualifer_HeaderRow() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var queryToCsvFileResult = query.QueryToCsvFile("QueryToCsvFile_Options_ColumnDelimiter_TextQualifer_HeaderRow-T +Test.csv", options => { options.ColumnDelimiter = "|"; options.TextQualifer = "\""; options.IncludeHeaderRow = false; }); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count, "The total number of rows written to the file should + match the count from the database without any header row"); + } + [TestMethod] + public void Using_FileStream() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var fileStream = File.Create("QueryToCsvFile_Stream-Test.csv"); + var queryToCsvFileResult = query.QueryToCsvFile(fileStream); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file sho +ould match the count from the database plus the header row"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\QueryToCsvFileAsync.cs --- + +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class QueryToCsvFileAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var queryToCsvFileResult = await query.QueryToCsvFileAsync("QueryToCsvFile-Test.csv"); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file sho +ould match the count from the database plus the header row"); + } + [TestMethod] + public async Task With_Options_ColumnDelimiter_TextQualifer_HeaderRow() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var queryToCsvFileResult = await query.QueryToCsvFileAsync("QueryToCsvFile_Options_ColumnDelimiter_TextQualifer_ +_HeaderRow-Test.csv", options => { options.ColumnDelimiter = "|"; options.TextQualifer = "\""; options.IncludeHeaderRow = += false; }); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count, "The total number of rows written to the file should + match the count from the database without any header row"); + } + [TestMethod] + public async Task Using_FileStream() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var fileStream = File.Create("QueryToCsvFile_Stream-Test.csv"); + var queryToCsvFileResult = await query.QueryToCsvFileAsync(fileStream); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file sho +ould match the count from the database plus the header row"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\UpdateFromQuery.cs --- + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Threading; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; +using N.EntityFrameworkCore.Extensions.Test.Data.Enums; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class UpdateFromQuery : DbContextExtensionsBase +{ + [TestMethod] + public void With_Boolean_Value() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Products.Count(a => a.OutOfStock); + int rowUpdated = dbContext.Products.Where(a => a.OutOfStock).UpdateFromQuery(a => new Product { OutOfStock = fal +lse }); + int newTotal = dbContext.Products.Count(o => o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be articles in database that match this condition (OutOfStock == true)") +); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Concatenating_String() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId == null); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { ExternalId = Convert.ToString(o.Id) + "Test" }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Concatenating_String_And_Number() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId == null); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { ExternalId = Convert.ToString(o.Id) + "Test" }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_DateTime_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTime now = DateTime.UtcNow; + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQuery(o => new Order { Modif +fiedDateTime = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTime == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 +0 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match + the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should +d be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public void With_DateTimeOffset_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTimeOffset now = DateTimeOffset.UtcNow; + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQuery(o => new Order { Modif +fiedDateTimeOffset = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTimeOffset == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 +0 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match + the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should +d be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public void With_DateTimeOffset_No_UTC_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTimeOffset now = DateTimeOffset.Parse("2020-06-17T16:00:00+05:00"); + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQuery(o => new Order { Modif +fiedDateTimeOffset = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTimeOffset == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 +0 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match + the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should +d be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public void With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { Price = 25.30M }); + int newTotal = orders.Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public void With_Different_Culture() + { + Thread.CurrentThread.CurrentCulture = CultureInfo.GetCultureInfo("sv-SE"); + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int rowUpdated = dbContext.Orders.Where(o => o.Price < 10M).UpdateFromQuery(o => new Order { Price = 25.30M }); + int newTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.AreEqual("25,30", Convert.ToString(25.30M)); + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public void With_Enum_Value() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(a => a.StatusEnum == ProductStatus.OutOfStock && a.OutOfStock); + int oldTotal = products.Count(); + int rowUpdated = products.UpdateFromQuery(a => new Product { StatusEnum = ProductStatus.InStock }); + int newTotal = products.Count(o => o.StatusEnum == ProductStatus.OutOfStock && o.OutOfStock); + int newTotal2 = dbContext.Products.Count(o => o.StatusEnum == ProductStatus.InStock && o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be articles in database that match this condition (OutOfStock == true)") +); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(newTotal2 == oldTotal, "All rows must have been updated"); + } + [TestMethod] + public void With_Guid_Value() + { + var dbContext = SetupDbContext(true); + var guid = Guid.NewGuid(); + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { GlobalId = guid }); + int matchCount = dbContext.Orders.Where(o => o.GlobalId == guid).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, $"The number of rows update must match the count of rows that match the co +ondition (GlobalId = '{guid}')"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public void With_Long_List() + { + var dbContext = SetupDbContext(true); + var ids = new List() { 1, 2, 3, 4, 5, 6, 7, 8 }; + var orders = dbContext.Orders.Where(o => ids.Contains(o.Id)); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { Price = 25.25M }); + int newTotal = orders.Where(o => o.Price != 25.25M).Count(); + int matchCount = dbContext.Orders.Where(o => ids.Contains(o.Id) && o.Price == 25.25M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public void With_MethodCall() + { + var dbContext = SetupDbContext(true); + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = dbContext.Orders.Where(a => a.Price < 10).UpdateFromQuery(o => new Order { Price = Math.Ceiling +g((o.Price + 10.5M) * 3 / 1) }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be order in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Null_Value() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId != null); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { ExternalId = null }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId != null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (ExternalId != null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + var products = dbContext.ProductsWithCustomSchema.Where(o => o.Price < 5M); + int oldTotal = products.Count(); + int rowUpdated = products.UpdateFromQuery(o => new ProductWithCustomSchema { Price = 25.30M }); + int newTotal = products.Count(); + int matchCount = dbContext.ProductsWithCustomSchema.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (Price < 5)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < 5)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public void With_String_Containing_Apostrophe() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.ExternalId == null).Count(); + int rowUpdated = dbContext.Orders.Where(o => o.ExternalId == null).UpdateFromQuery(o => new Order { ExternalId = += "inv'alid" }); + int newTotal = dbContext.Orders.Where(o => o.ExternalId == null).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int rowUpdated; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowUpdated = dbContext.Orders.Where(o => o.Price < 10M).UpdateFromQuery(o => new Order { Price = 25.30M }); + transaction.Rollback(); + } + int newTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < $10)"); + Assert.IsTrue(newTotal == oldTotal, "The new count must match the old count since the transaction was rollbacked +d"); + Assert.IsTrue(matchCount == 0, "The match count must be equal to 0 since the transaction was rollbacked."); + } + [TestMethod] + public void With_Variables() + { + var dbContext = SetupDbContext(true); + decimal priceStart = 10M; + decimal priceUpdate = 0.34M; + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = dbContext.Orders.Where(a => a.Price < 10).UpdateFromQuery(a => new Order { Price = priceStart + ++ priceUpdate }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Variable_And_Decimal() + { + var dbContext = SetupDbContext(true); + decimal priceStart = 10M; + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = dbContext.Orders.Where(a => a.Price < 10).UpdateFromQuery(a => new Order { Price = priceStart + ++ 7M }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbContextExtensions\UpdateFromQueryAsync.cs --- + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; +using N.EntityFrameworkCore.Extensions.Test.Data.Enums; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class UpdateFromQueryAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Boolean_Value() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Products.Count(a => a.OutOfStock); + int rowUpdated = await dbContext.Products.Where(a => a.OutOfStock).UpdateFromQueryAsync(a => new Product { OutOf +fStock = false }); + int newTotal = dbContext.Products.Count(o => o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be articles in database that match this condition (OutOfStock == true)") +); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Concatenating_String() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId == null); + int oldTotal = orders.Count(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { ExternalId = Convert.ToString(o.Id) + "Test" +" }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Concatenating_String_And_Number() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId == null); + int oldTotal = orders.Count(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { ExternalId = Convert.ToString(o.Id) + "Test" +" }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_DateTime_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTime now = DateTime.UtcNow; + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = await dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQueryAsync(o => new Or +rder { ModifiedDateTime = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTime == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 +0 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match + the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should +d be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public async Task With_DateTimeOffset_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTimeOffset now = DateTimeOffset.UtcNow; + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = await dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQueryAsync(o => new Or +rder { ModifiedDateTimeOffset = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTimeOffset == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 +0 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match + the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should +d be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public async Task With_DateTimeOffset_No_UTC_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTimeOffset now = DateTimeOffset.Parse("2020-06-17T16:00:00+05:00"); + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = await dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQueryAsync(o => new Or +rder { ModifiedDateTimeOffset = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTimeOffset == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 +0 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match + the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should +d be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public async Task With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldTotal = orders.Count(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { Price = 25.30M }); + int newTotal = orders.Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public async Task With_Different_Culture() + { + Thread.CurrentThread.CurrentCulture = CultureInfo.GetCultureInfo("sv-SE"); + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int rowUpdated = await dbContext.Orders.Where(o => o.Price < 10M).UpdateFromQueryAsync(o => new Order { Price = + 25.30M }); + int newTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.AreEqual("25,30", Convert.ToString(25.30M)); + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public async Task With_Enum_Value() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(a => a.StatusEnum == ProductStatus.OutOfStock && a.OutOfStock); + int oldTotal = products.Count(); + int rowUpdated = await products.UpdateFromQueryAsync(a => new Product { StatusEnum = ProductStatus.InStock }); + int newTotal = products.Count(o => o.StatusEnum == ProductStatus.OutOfStock && o.OutOfStock); + int newTotal2 = dbContext.Products.Count(o => o.StatusEnum == ProductStatus.InStock && o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be articles in database that match this condition (OutOfStock == true)") +); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(newTotal2 == oldTotal, "All rows must have been updated"); + } + [TestMethod] + public async Task With_Guid_Value() + { + var dbContext = SetupDbContext(true); + var guid = Guid.NewGuid(); + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldTotal = await orders.CountAsync(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { GlobalId = guid }); + int matchCount = await dbContext.Orders.Where(o => o.GlobalId == guid).CountAsync(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, $"The number of rows update must match the count of rows that match the co +ondition (GlobalId = '{guid}')"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public async Task With_Long_List() + { + var dbContext = SetupDbContext(true); + var ids = new List() { 1, 2, 3, 4, 5, 6, 7, 8 }; + var orders = dbContext.Orders.Where(o => ids.Contains(o.Id)); + int oldTotal = orders.Count(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { Price = 25.25M }); + int newTotal = orders.Where(o => o.Price != 25.25M).Count(); + int matchCount = dbContext.Orders.Where(o => ids.Contains(o.Id) && o.Price == 25.25M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public async Task With_MethodCall() + { + var dbContext = SetupDbContext(true); + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = await dbContext.Orders.Where(a => a.Price < 10).UpdateFromQueryAsync(o => new Order { Price = M +Math.Ceiling((o.Price + 10.5M) * 3 / 1) }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be order in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Null_Value() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId != null); + int oldTotal = orders.Count(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { ExternalId = null }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId != null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (ExternalId != null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + var products = dbContext.ProductsWithCustomSchema.Where(o => o.Price < 5M); + int oldTotal = products.Count(); + int rowUpdated = await products.UpdateFromQueryAsync(o => new ProductWithCustomSchema { Price = 25.30M }); + int newTotal = products.Count(); + int matchCount = dbContext.ProductsWithCustomSchema.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (Price < 5)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < 5)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public async Task With_String_Containing_Apostrophe() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.ExternalId == null).Count(); + int rowUpdated = await dbContext.Orders.Where(o => o.ExternalId == null).UpdateFromQueryAsync(o => new Order { E +ExternalId = "inv'alid" }); + int newTotal = dbContext.Orders.Where(o => o.ExternalId == null).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int rowUpdated; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowUpdated = await dbContext.Orders.Where(o => o.Price < 10M).UpdateFromQueryAsync(o => new Order { Price = + 25.30M }); + transaction.Rollback(); + } + int newTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < $10)"); + Assert.IsTrue(newTotal == oldTotal, "The new count must match the old count since the transaction was rollbacked +d"); + Assert.IsTrue(matchCount == 0, "The match count must be equal to 0 since the transaction was rollbacked."); + } + [TestMethod] + public async Task With_Variables() + { + var dbContext = SetupDbContext(true); + decimal priceStart = 10M; + decimal priceUpdate = 0.34M; + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = await dbContext.Orders.Where(a => a.Price < 10).UpdateFromQueryAsync(a => new Order { Price = p +priceStart + priceUpdate }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Variable_And_Decimal() + { + var dbContext = SetupDbContext(true); + decimal priceStart = 10M; + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = await dbContext.Orders.Where(a => a.Price < 10).UpdateFromQueryAsync(a => new Order { Price = p +priceStart + 7M }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbSetExtensions\Clear.cs --- + +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +namespace N.EntityFrameworkCore.Extensions.Test.DbSetExtensions; + +[TestClass] +public class Clear : DbContextExtensionsBase +{ + [TestMethod] + public void Using_Dbset() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + dbContext.Orders.Clear(); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbSetExtensions\ClearAsync.cs --- + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +namespace N.EntityFrameworkCore.Extensions.Test.DbSetExtensions; + +[TestClass] +public class ClearAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task Using_Dbset() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + await dbContext.Orders.ClearAsync(); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbSetExtensions\Truncate.cs --- + +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +namespace N.EntityFrameworkCore.Extensions.Test.DbSetExtensions; + +[TestClass] +public class Truncate : DbContextExtensionsBase +{ + [TestMethod] + public void Using_Dbset() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + dbContext.Orders.Truncate(); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\DbSetExtensions\TruncateAsync.cs --- + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +namespace N.EntityFrameworkCore.Extensions.Test.DbSetExtensions; + +[TestClass] +public class TruncateAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task Using_Dbset() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + await dbContext.Orders.TruncateAsync(); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\LinqExtensions\ToSqlPredicateTests.cs --- + +using System; +using System.Linq.Expressions; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.LinqExtensions; + +[TestClass] +public class ToSqlPredicateTests +{ + [TestMethod] + public void Should_handle_int() + { + Expression> expression = (s, t) => s.Id == t.Id; + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Id = t.Id", sqlPredicate); + } + + [TestMethod] + public void Should_handle_enum() + { + Expression> expression = (s, t) => s.Type == t.Type; + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Type = t.Type", sqlPredicate); + } + + [TestMethod] + public void Should_handle_complex_one() + { + Expression> expression = (s, t) => s.Type == t.Type && + (s.Id == t.Id && + s.ExternalId == t.ExternalId); + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Type = t.Type AND s.Id = t.Id AND s.ExternalId = t.ExternalId", sqlPredicate); + } + + [TestMethod] + public void Should_handle_prop_naming() + { + Expression> expression = (source, target) => source.Id == target.Id && + source.ExternalId == target.ExternalId; + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Id = t.Id AND s.ExternalId = t.ExternalId", sqlPredicate); + } + + [TestMethod] + public void Should_handle_simple_big_one() + { + Expression> expression = (s, t) => s.Type == t.Type && + s.Id == t.Id && + s.ExternalId == t.ExternalId && + s.TesterVar1 == t.TesterVar1 && + s.TesterVar2 == t.TesterVar2 && + s.TesterVar3 == t.TesterVar3 && + s.TesterVar4 == t.TesterVar4 && + s.TesterVar5 == t.TesterVar5; + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Type = t.Type AND s.Id = t.Id AND s.ExternalId = t.ExternalId AND s.TesterVar1 = t.TesterVar1 +1 AND s.TesterVar2 = t.TesterVar2 AND s.TesterVar3 = t.TesterVar3 AND s.TesterVar4 = t.TesterVar4 AND s.TesterVar5 = t.Te +esterVar5", sqlPredicate); + } + + [TestMethod] + public void Should_handle_complex_big_one() + { + Expression> expression = (s, t) => s.Type == t.Type && + s.Id == t.Id && + (s.ExternalId == t.ExternalId || s.TesterVar1 == t +t.TesterVar1) && + (s.TesterVar2 == t.TesterVar2 || (s.TesterVar2 == + null && t.TesterVar2 == null)) && + (s.TesterVar3 == t.TesterVar3 || (s.TesterVar3 != + null && t.TesterVar3 != null)); + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Type = t.Type AND s.Id = t.Id AND (s.ExternalId = t.ExternalId OR s.TesterVar1 = t.TesterVar1 +1) AND (s.TesterVar2 = t.TesterVar2 OR s.TesterVar2 IS NULL AND t.TesterVar2 IS NULL) AND (s.TesterVar3 = t.TesterVar3 OR +R s.TesterVar3 IS NOT NULL AND t.TesterVar3 IS NOT NULL)", sqlPredicate); + } + + record Entity + { + public Guid Id { get; set; } + public EntityType Type { get; set; } + public int ExternalId { get; set; } + public string TesterVar1 { get; set; } + public string TesterVar2 { get; set; } + public string TesterVar3 { get; set; } + public string TesterVar4 { get; set; } + public string TesterVar5 { get; set; } + } + + enum EntityType + { + One, + Two, + Three + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Migrations\20250509021251_Initial.cs --- + +using System; +using Microsoft.EntityFrameworkCore.Migrations; + +#nullable disable + +namespace N.EntityFrameworkCore.Extensions.Test.Migrations +{ + public partial class Initial : Migration + { + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.Sql("CREATE TRIGGER trgProductWithTriggers\r\nON ProductsWithTrigger\r\nFOR INSERT, UPDATE, +, DELETE\r\nAS\r\nBEGIN\r\n PRINT 1 END"); + } + + protected override void Down(MigrationBuilder migrationBuilder) + { + + } + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Migrations\20250509021251_Initial.Designer.cs --- + +// +using System; +using System.Collections.Generic; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Migrations; +using Microsoft.EntityFrameworkCore.Storage.ValueConversion; +using N.EntityFrameworkCore.Extensions.Test.Data; + +#nullable disable + +namespace N.EntityFrameworkCore.Extensions.Test.Migrations +{ + [DbContext(typeof(TestDbContext))] + [Migration("20250509021251_Initial")] + partial class Initial + { + protected override void BuildTargetModel(ModelBuilder modelBuilder) + { + + } + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\Migrations\TestDbContextModelSnapshot.cs --- + +// +using System; +using System.Collections.Generic; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Storage.ValueConversion; +using N.EntityFrameworkCore.Extensions.Test.Data; + +#nullable disable + +namespace N.EntityFrameworkCore.Extensions.Test.Migrations +{ + [DbContext(typeof(TestDbContext))] + partial class TestDbContextModelSnapshot : ModelSnapshot + { + protected override void BuildModel(ModelBuilder modelBuilder) + { + + } + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.SqlServer.Test\N.EntityFramework.Extensions.SqlServer.Test.csproj --- + + + + + net10.0 + + $(MSBuildThisFileDirectory)..\..\N.EntityFramework.Extensions.SqlServer.runsettings + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + + + + + + + + + + Always + + + + + + +=== DIRECTORY: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFram +mework.Extensions.PostgreSql.Test === + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\appsettings.json --- + +{ + "DatabaseProvider": "PostgreSql", + "UsePostgreSqlContainer": true, + "ConnectionStrings": { + "SqlServerTestDatabase": "Server=(localdb)\\mssqllocaldb;Database=N.EntityFrameworkCore.Test;Trusted_Connection=True +e;", + "PostgreSqlTestDatabase": "Host=localhost;Database=N.EntityFrameworkCore.Test;Username=postgres;Password=postgres" + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Common\Config.cs --- + +using System; +using System.Data.Common; +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.Configuration; +using Npgsql; + +namespace N.EntityFrameworkCore.Extensions.Test.Common; + +public class Config +{ + private static readonly IConfigurationRoot configuration = new ConfigurationBuilder() + .AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .Build(); + + public static string GetConnectionString(string name) + { + return configuration.GetConnectionString(name); + } + public static string DatabaseProvider => configuration["DatabaseProvider"] ?? "SqlServer"; + public static bool IsSqlServer => string.Equals(DatabaseProvider, "SqlServer", StringComparison.OrdinalIgnoreCase); + public static bool IsPostgreSql => string.Equals(DatabaseProvider, "PostgreSql", StringComparison.OrdinalIgnoreCase) +); + public static bool UsePostgreSqlContainer => + IsPostgreSql && !string.Equals(configuration["UsePostgreSqlContainer"], "false", StringComparison.OrdinalIgnoreC +Case); + public static string GetTestDatabaseConnectionString() => IsPostgreSql + ? (UsePostgreSqlContainer ? PostgreSqlContainerManager.GetConnectionString() : GetConnectionString("PostgreSqlTe +estDatabase")) + : GetConnectionString("SqlServerTestDatabase"); + public static DbParameter CreateParameter(string name, object value) => IsPostgreSql + ? new NpgsqlParameter(name, value ?? DBNull.Value) + : new SqlParameter(name, value ?? DBNull.Value); + public static string DelimitIdentifier(string identifier) => IsPostgreSql ? $"\"{identifier}\"" : $"[{identifier}]"; + public static string DelimitTableName(string tableName) => IsPostgreSql ? $"\"{tableName}\"" : tableName; + public static bool IsPrimaryKeyViolation(Exception exception) => + IsSqlServer + ? exception.Message.StartsWith("Violation of PRIMARY KEY constraint 'PK_Orders'.", StringComparison.Ordinal) + : exception.Message.Contains("duplicate key value violates unique constraint", StringComparison.OrdinalIgnor +reCase); +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Common\PostgreSqlContainerManager.cs --- + +using System; +using System.Threading.Tasks; +using Testcontainers.PostgreSql; + +namespace N.EntityFrameworkCore.Extensions.Test.Common; + +internal static class PostgreSqlContainerManager +{ + private static readonly object syncRoot = new(); + private static Task initializationTask; + private static PostgreSqlContainer container; + private static bool cleanupRegistered; + + internal static string GetConnectionString() + { + EnsureStarted(); + return container.GetConnectionString(); + } + + internal static void EnsureStarted() + { + EnsureStartedAsync().GetAwaiter().GetResult(); + } + + internal static Task EnsureStartedAsync() + { + lock (syncRoot) + { + initializationTask ??= StartContainerAsync(); + return initializationTask; + } + } + + private static async Task StartContainerAsync() + { + try + { + container = new PostgreSqlBuilder("postgres:17-alpine") + .WithDatabase("N.EntityFrameworkCore.Test") + .WithUsername("postgres") + .WithPassword("postgres") + .Build(); + + await container.StartAsync(); + RegisterCleanup(); + } + catch (Exception ex) + { + throw new InvalidOperationException("PostgreSql tests require Docker when UsePostgreSqlContainer is enabled. +.", ex); + } + } + + private static void RegisterCleanup() + { + lock (syncRoot) + { + if (cleanupRegistered) + return; + + AppDomain.CurrentDomain.ProcessExit += (_, _) => + { + if (container != null) + container.DisposeAsync().AsTask().GetAwaiter().GetResult(); + }; + cleanupRegistered = true; + } + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Common\TestDatabaseInitializer.cs --- + +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.Common; + +internal static class TestDatabaseInitializer +{ + internal static void EnsureCreated(TestDbContext dbContext) + { + if (Config.UsePostgreSqlContainer) + PostgreSqlContainerManager.EnsureStarted(); + + if (Config.IsPostgreSql) + dbContext.Database.ExecuteSqlRaw("CREATE EXTENSION IF NOT EXISTS pgcrypto"); + + dbContext.Database.EnsureCreated(); + CreateProviderSpecificObjects(dbContext); + } + + internal static async Task EnsureCreatedAsync(TestDbContext dbContext) + { + if (Config.UsePostgreSqlContainer) + await PostgreSqlContainerManager.EnsureStartedAsync(); + + if (Config.IsPostgreSql) + await dbContext.Database.ExecuteSqlRawAsync("CREATE EXTENSION IF NOT EXISTS pgcrypto"); + + await dbContext.Database.EnsureCreatedAsync(); + await CreateProviderSpecificObjectsAsync(dbContext); + } + + internal static void CreateProviderSpecificObjects(TestDbContext dbContext) + { + if (Config.IsPostgreSql) + { + dbContext.Database.ExecuteSqlRaw(""" + CREATE OR REPLACE FUNCTION set_order_modified_datetime() + RETURNS TRIGGER AS $$ + BEGIN + NEW."DbModifiedDateTime" = CURRENT_TIMESTAMP; + RETURN NEW; + END; + $$ LANGUAGE plpgsql; + """); + dbContext.Database.ExecuteSqlRaw(""" + DROP TRIGGER IF EXISTS trg_order_modified_datetime ON "Orders"; + CREATE TRIGGER trg_order_modified_datetime + BEFORE INSERT OR UPDATE ON "Orders" + FOR EACH ROW + EXECUTE FUNCTION set_order_modified_datetime(); + """); + dbContext.Database.ExecuteSqlRaw(""" + CREATE OR REPLACE FUNCTION trg_product_with_triggers() + RETURNS TRIGGER AS $$ + BEGIN + RETURN COALESCE(NEW, OLD); + END; + $$ LANGUAGE plpgsql; + """); + dbContext.Database.ExecuteSqlRaw(""" + DROP TRIGGER IF EXISTS trgProductWithTriggers ON "ProductsWithTrigger"; + CREATE TRIGGER trgProductWithTriggers + BEFORE INSERT OR UPDATE OR DELETE ON "ProductsWithTrigger" + FOR EACH ROW + EXECUTE FUNCTION trg_product_with_triggers(); + """); + } + else + { + dbContext.Database.ExecuteSqlRaw(""" + IF OBJECT_ID('trgProductWithTriggers', 'TR') IS NOT NULL + DROP TRIGGER trgProductWithTriggers + """); + dbContext.Database.ExecuteSqlRaw(""" + CREATE TRIGGER trgProductWithTriggers + ON ProductsWithTrigger + FOR INSERT, UPDATE, DELETE + AS + BEGIN + PRINT 1 + END + """); + } + } + + internal static async Task CreateProviderSpecificObjectsAsync(TestDbContext dbContext) + { + if (Config.IsPostgreSql) + { + await dbContext.Database.ExecuteSqlRawAsync(""" + CREATE OR REPLACE FUNCTION set_order_modified_datetime() + RETURNS TRIGGER AS $$ + BEGIN + NEW."DbModifiedDateTime" = CURRENT_TIMESTAMP; + RETURN NEW; + END; + $$ LANGUAGE plpgsql; + """); + await dbContext.Database.ExecuteSqlRawAsync(""" + DROP TRIGGER IF EXISTS trg_order_modified_datetime ON "Orders"; + CREATE TRIGGER trg_order_modified_datetime + BEFORE INSERT OR UPDATE ON "Orders" + FOR EACH ROW + EXECUTE FUNCTION set_order_modified_datetime(); + """); + await dbContext.Database.ExecuteSqlRawAsync(""" + CREATE OR REPLACE FUNCTION trg_product_with_triggers() + RETURNS TRIGGER AS $$ + BEGIN + RETURN COALESCE(NEW, OLD); + END; + $$ LANGUAGE plpgsql; + """); + await dbContext.Database.ExecuteSqlRawAsync(""" + DROP TRIGGER IF EXISTS trgProductWithTriggers ON "ProductsWithTrigger"; + CREATE TRIGGER trgProductWithTriggers + BEFORE INSERT OR UPDATE OR DELETE ON "ProductsWithTrigger" + FOR EACH ROW + EXECUTE FUNCTION trg_product_with_triggers(); + """); + } + else + { + await dbContext.Database.ExecuteSqlRawAsync(""" + IF OBJECT_ID('trgProductWithTriggers', 'TR') IS NOT NULL + DROP TRIGGER trgProductWithTriggers + """); + await dbContext.Database.ExecuteSqlRawAsync(""" + CREATE TRIGGER trgProductWithTriggers + ON ProductsWithTrigger + FOR INSERT, UPDATE, DELETE + AS + BEGIN + PRINT 1 + END + """); + } + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\Address.cs --- + +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations.Schema; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +[ComplexType] +public class Address +{ + public required string Line1 { get; set; } + public string? Line2 { get; set; } + public required string City { get; set; } + public required string Country { get; set; } + public required string PostCode { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\Enums\ProductStatus.cs --- + +namespace N.EntityFrameworkCore.Extensions.Test.Data.Enums; + +public enum ProductStatus +{ + InStock, + OutOfStock, +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\Order.cs --- + +using System; +using System.ComponentModel.DataAnnotations; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class Order +{ + [Key] + public long Id { get; set; } + public string ExternalId { get; set; } + public Guid? GlobalId { get; set; } + public decimal Price { get; set; } + public DateTime AddedDateTime { get; set; } + public DateTime? ModifiedDateTime { get; set; } + public DateTimeOffset? ModifiedDateTimeOffset { get; set; } + public bool DbActive { get; set; } + public DateTime DbAddedDateTime { get; set; } + public DateTime DbModifiedDateTime { get; set; } + public bool? Trigger { get; set; } + public bool Active { get; set; } + public OrderStatus Status { get; set; } + public Order() + { + AddedDateTime = DateTime.UtcNow; + Active = true; + } +} + +public enum OrderStatus +{ + Unknown, + Completed, + Error +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\OrderWithComplexType.cs --- + +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class OrderWithComplexType +{ + [Key] + public long Id { get; set; } + [Required] + public Address ShippingAddress { get; set; } + [Required] + public Address BillingAddress { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\Position.cs --- + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class Position +{ + public int Building; + public int Aisle; + public int Bay; +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\Product.cs --- + +using System; +using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations.Schema; +using N.EntityFrameworkCore.Extensions.Test.Data.Enums; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class Product +{ + [Key] + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public string Id { get; set; } + [StringLength(50)] + public string Name { get; set; } + public decimal Price { get; set; } + public bool OutOfStock { get; set; } + [Column("Status")] + [StringLength(25)] + public string StatusString { get; set; } + public int? ProductCategoryId { get; set; } + public System.Drawing.Color Color { get; set; } + public ProductStatus? StatusEnum { get; set; } + public DateTime? UpdatedDateTime { get; set; } + + public Position Position { get; set; } + + public virtual ProductCategory ProductCategory { get; set; } + public Product() + { + + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\ProductCategory.cs --- + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class ProductCategory +{ + public int Id { get; set; } + public string Name { get; set; } + public bool Active { get; internal set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\ProductWithComplexKey.cs --- + +using System; +using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class ProductWithComplexKey +{ + public Guid Key1 { get; set; } + public Guid Key2 { get; set; } + public Guid Key3 { get; set; } + public Guid Key4 { get; set; } + public string ExternalId { get; set; } + public decimal Price { get; set; } + public bool OutOfStock { get; set; } + [Column("Status")] + [StringLength(25)] + public string StatusString { get; set; } + public DateTime? UpdatedDateTime { get; set; } + public ProductWithComplexKey() + { + Key3 = Guid.NewGuid(); + Key4 = Guid.NewGuid(); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\ProductWithCustomSchema.cs --- + +using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class ProductWithCustomSchema +{ + [Key] + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public string Id { get; set; } + [StringLength(50)] + public string Name { get; set; } + public decimal Price { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\ProductWithTrigger.cs --- + +using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class ProductWithTrigger +{ + [Key] + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public string Id { get; set; } + [StringLength(50)] + public string Name { get; set; } + public decimal Price { get; set; } + public bool OutOfStock { get; set; } + [Column("Status")] + [StringLength(25)] + public string StatusString { get; set; } + public ProductWithTrigger() + { + + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\TestDbContext.cs --- + +using System; +using System.Drawing; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Diagnostics; +using N.EntityFrameworkCore.Extensions.Test.Common; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TestDbContext : DbContext +{ + public virtual DbSet Products { get; set; } + public virtual DbSet ProductCategories { get; set; } + public virtual DbSet ProductsWithCustomSchema { get; set; } + public virtual DbSet ProductsWithComplexKey { get; set; } + public virtual DbSet ProductsWithTrigger { get; set; } + public virtual DbSet Orders { get; set; } + public virtual DbSet OrdersWithComplexType { get; set; } + public virtual DbSet TpcPeople { get; set; } + public virtual DbSet TphPeople { get; set; } + public virtual DbSet TphCustomers { get; set; } + public virtual DbSet TphVendors { get; set; } + public virtual DbSet TptPeople { get; set; } + public virtual DbSet TptCustomers { get; set; } + public virtual DbSet TptVendors { get; set; } + + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) + { + if (Config.IsPostgreSql) + optionsBuilder.UseNpgsql(Config.GetTestDatabaseConnectionString()); + else + optionsBuilder.UseSqlServer(Config.GetTestDatabaseConnectionString()); + optionsBuilder.SetupEfCoreExtensions(); + optionsBuilder.UseLazyLoadingProxies(); + // Tell EF Core to allow mismatched models for this test run + optionsBuilder.ConfigureWarnings(warnings => + warnings.Ignore(RelationalEventId.PendingModelChangesWarning)); + } + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity().ToTable("Product", "top"); + modelBuilder.Entity().HasKey(c => new { c.Key1 }); + modelBuilder.Entity().Property("Key1").HasDefaultValueSql(Config.IsPostgreSql ? "ge +en_random_uuid()" : "newsequentialid()"); + modelBuilder.Entity().Property("Key2").HasDefaultValueSql(Config.IsPostgreSql ? "ge +en_random_uuid()" : "newsequentialid()"); + modelBuilder.Entity().HasKey(p => new { p.Key3, p.Key4 }); + modelBuilder.Entity().Property("DbAddedDateTime").HasDefaultValueSql(Config.IsPostgreSql ? "CUR +RRENT_TIMESTAMP" : "getdate()"); + if (Config.IsPostgreSql) + modelBuilder.Entity().Property("DbModifiedDateTime").HasDefaultValueSql("CURRENT_TIMESTAMP" +").ValueGeneratedOnAddOrUpdate(); + else + modelBuilder.Entity().Property("DbModifiedDateTime").HasComputedColumnSql("getdate()"); + modelBuilder.Entity().Property(p => p.DbActive).HasDefaultValueSql(Config.IsPostgreSql ? "TRUE" : " +"((1))"); + modelBuilder.Entity().Property(p => p.Status).HasConversion(); + modelBuilder.Entity(b => + { + b.ComplexProperty(e => e.BillingAddress); + b.ComplexProperty(e => e.ShippingAddress); + }); + modelBuilder.Entity().UseTpcMappingStrategy(); + modelBuilder.Entity().ToTable("TpcCustomer"); + modelBuilder.Entity().ToTable("TpcVendor"); + modelBuilder.Entity().Property("CreatedDate"); + modelBuilder.Entity().ToTable("TptPeople"); + modelBuilder.Entity().ToTable("TptCustomer"); + modelBuilder.Entity().ToTable("TptVendor"); + modelBuilder.Entity(t => + { + t.ComplexProperty(p => p.Position).IsRequired(); + t.Property(p => p.Color).HasConversion(x => x.ToArgb(), x => Color.FromArgb(x)); + }); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\TpcCustomer.cs --- + +using System; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TpcCustomer : TpcPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public DateTime AddedDate { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\TpcPerson.cs --- + +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public abstract class TpcPerson +{ + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public long Id { get; set; } + public string FirstName { get; set; } + public string LastName { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\TpcVendor.cs --- + + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TpcVendor : TpcPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public string Url { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\TphCustomer.cs --- + +using System; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TphCustomer : TphPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public DateTime AddedDate { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\TphPerson.cs --- + +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +[Table("TphPeople")] +public abstract class TphPerson +{ + public long Id { get; set; } + public string FirstName { get; set; } + public string LastName { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\TphVendor.cs --- + + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TphVendor : TphPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public string Url { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\TptCustomer.cs --- + +using System; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TptCustomer : TptPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public DateTime AddedDate { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\TptPerson.cs --- + +using System.ComponentModel.DataAnnotations.Schema; + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TptPerson +{ + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public long Id { get; set; } + public string FirstName { get; set; } + public string LastName { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Data\TptVendor.cs --- + + +namespace N.EntityFrameworkCore.Extensions.Test.Data; + +public class TptVendor : TptPerson +{ + public string Email { get; set; } + public string Phone { get; set; } + public string Url { get; set; } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DatabaseExtensions\DatabaseExtensionsBase.cs --- + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +public class DatabaseExtensionsBase +{ + private TestDbContext _currentDbContext; + + [TestCleanup] + public void Cleanup() + { + _currentDbContext?.Dispose(); + _currentDbContext = null; + } + + protected TestDbContext SetupDbContext(bool populateData) + { + var dbContext = new TestDbContext(); + _currentDbContext = dbContext; + TestDatabaseInitializer.EnsureCreated(dbContext); + dbContext.Orders.Truncate(); + if (populateData) + { + var orders = new List(); + int id = 1; + for (int i = 0; i < 2050; i++) + { + DateTime addedDateTime = DateTime.UtcNow.AddDays(-id); + orders.Add(new Order + { + Id = id, + ExternalId = string.Format("id-{0}", i), + Price = 1.25M, + AddedDateTime = addedDateTime, + ModifiedDateTime = addedDateTime.AddHours(3) + }); + id++; + } + for (int i = 0; i < 1050; i++) + { + orders.Add(new Order { Id = id, Price = 5.35M }); + id++; + } + for (int i = 0; i < 2050; i++) + { + orders.Add(new Order { Id = id, Price = 1.25M }); + id++; + } + for (int i = 0; i < 6000; i++) + { + orders.Add(new Order { Id = id, Price = 15.35M }); + id++; + } + for (int i = 0; i < 6000; i++) + { + orders.Add(new Order { Id = id, Price = 15.35M }); + id++; + } + + Debug.WriteLine("Last Id for Order is {0}", id); + dbContext.BulkInsert(orders, new BulkInsertOptions() { KeepIdentity = true }); + } + return dbContext; + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DatabaseExtensions\SqlQuery_Count.cs --- + +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class SqlQuery_Count : DatabaseExtensionsBase +{ + [TestMethod] + public void With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + int efCount = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Pr +rice"; + var sqlCount = dbContext.Database.FromSqlQuery(sql, Config.CreateParameter("@Price", 5M)).Count(); + + Assert.IsTrue(efCount > 0, "Count from EF should be greater than zero"); + Assert.IsTrue(efCount > 0, "Count from SQL should be greater than zero"); + Assert.IsTrue(efCount == sqlCount, "Count from EF should match the count from the SqlQuery"); + } + [TestMethod] + public void With_OrderBy() + { + var dbContext = SetupDbContext(true); + int efCount = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Pr +rice ORDER BY {Config.DelimitIdentifier("Id")}"; + var sqlCount = dbContext.Database.FromSqlQuery(sql, Config.CreateParameter("@Price", 5M)).Count(); + + Assert.IsTrue(efCount > 0, "Count from EF should be greater than zero"); + Assert.IsTrue(efCount > 0, "Count from SQL should be greater than zero"); + Assert.IsTrue(efCount == sqlCount, "Count from EF should match the count from the SqlQuery"); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DatabaseExtensions\SqlQuery_CountAsync.cs --- + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class SqlQuery_CountAsync : DatabaseExtensionsBase +{ + [TestMethod] + public async Task With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + int efCount = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Pr +rice"; + var sqlCount = await dbContext.Database.FromSqlQuery(sql, Config.CreateParameter("@Price", 5M)).CountAsync(); + + Assert.IsTrue(efCount > 0, "Count from EF should be greater than zero"); + Assert.IsTrue(efCount > 0, "Count from SQL should be greater than zero"); + Assert.IsTrue(efCount == sqlCount, "Count from EF should match the count from the SqlQuery"); + } + [TestMethod] + public async Task With_OrderBy() + { + var dbContext = SetupDbContext(true); + int efCount = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Pr +rice ORDER BY {Config.DelimitIdentifier("Id")}"; + var sqlCount = await dbContext.Database.FromSqlQuery(sql, Config.CreateParameter("@Price", 5M)).CountAsync(); + + Assert.IsTrue(efCount > 0, "Count from EF should be greater than zero"); + Assert.IsTrue(efCount > 0, "Count from SQL should be greater than zero"); + Assert.IsTrue(efCount == sqlCount, "Count from EF should match the count from the SqlQuery"); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DatabaseExtensions\SqlQueryToCsvFile.cs --- + +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class SqlQueryToCsvFile : DatabaseExtensionsBase +{ + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + int count = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Pr +rice"; + var queryToCsvFileResult = dbContext.Database.SqlQueryToCsvFile("SqlQueryToCsvFile-Test.csv", sql, Config.Create +eParameter("@Price", 5M)); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file sho +ould match the count from the database plus the header row"); + } + [TestMethod] + public void With_Options_ColumnDelimiter_TextQualifer() + { + var dbContext = SetupDbContext(true); + string filePath = "SqlQueryToCsvFile_Options_ColumnDelimiter_TextQualifer-Test.csv"; + int count = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Pr +rice"; + var queryToCsvFileResult = dbContext.Database.SqlQueryToCsvFile(filePath, options => { options.ColumnDelimiter = += "|"; options.TextQualifer = "\""; }, + sql, Config.CreateParameter("@Price", 5M)); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file sho +ould match the count from the database plus the header row"); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DatabaseExtensions\SqlQueryToCsvFileAsync.cs --- + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class SqlQueryToCsvFileAsync : DatabaseExtensionsBase +{ + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + int count = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Pr +rice"; + var queryToCsvFileResult = await dbContext.Database.SqlQueryToCsvFileAsync("SqlQueryToCsvFile-Test.csv", sql, ne +ew object[] { Config.CreateParameter("@Price", 5M) }); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file sho +ould match the count from the database plus the header row"); + } + [TestMethod] + public async Task With_Options_ColumnDelimiter_TextQualifer() + { + var dbContext = SetupDbContext(true); + string filePath = "SqlQueryToCsvFile_Options_ColumnDelimiter_TextQualifer-Test.csv"; + int count = dbContext.Orders.Where(o => o.Price > 5M).Count(); + string sql = $"SELECT * FROM {Config.DelimitTableName("Orders")} WHERE {Config.DelimitIdentifier("Price")} > @Pr +rice"; + var queryToCsvFileResult = await dbContext.Database.SqlQueryToCsvFileAsync(filePath, options => { options.Column +nDelimiter = "|"; options.TextQualifer = "\""; }, + sql, new object[] { Config.CreateParameter("@Price", 5M) }); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file sho +ould match the count from the database plus the header row"); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DatabaseExtensions\TableExists.cs --- + +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class TableExists : DatabaseExtensionsBase +{ + [TestMethod] + public void With_Orders_Table() + { + var dbContext = SetupDbContext(true); + int efCount = dbContext.Orders.Where(o => o.Price > 5M).Count(); + bool ordersTableExists = dbContext.Database.TableExists("Orders"); + bool orderNewTableExists = dbContext.Database.TableExists("OrdersNew"); + + Assert.IsTrue(ordersTableExists, "Orders table should exist"); + Assert.IsTrue(!orderNewTableExists, "Orders_New table should not exist"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DatabaseExtensions\TruncateTable.cs --- + +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class TruncateTable : DatabaseExtensionsBase +{ + [TestMethod] + public void With_Orders_Table() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + dbContext.Database.TruncateTable("Orders"); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DatabaseExtensions\TruncateTableAsync.cs --- + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DatabaseExtensions; + +[TestClass] +public class TruncateTableAsync : DatabaseExtensionsBase +{ + [TestMethod] + public async Task With_Orders_Table() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + await dbContext.Database.TruncateTableAsync("Orders"); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\BulkDelete.cs --- + +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkDelete : DbContextExtensionsBase +{ + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + int rowsDeleted = dbContext.BulkDelete(orders); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in + database"); + Assert.IsTrue(newTotal == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = dbContext.TpcPeople.OfType().ToList(); + int rowsDeleted = dbContext.BulkDelete(customers); + var newCustomers = dbContext.TpcPeople.OfType().Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows + in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = dbContext.TphPeople.OfType().ToList(); + int rowsDeleted = dbContext.BulkDelete(customers); + var newCustomers = dbContext.TphPeople.OfType().Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows + in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptCustomers.ToList(); + int rowsDeleted = dbContext.BulkDelete(customers); + var newCustomers = dbContext.TptCustomers.Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows + in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Options_DeleteOnCondition() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).ToList(); + int rowsDeleted = dbContext.BulkDelete(orders, options => { options.DeleteOnCondition = (s, t) => s.ExternalId = +== t.ExternalId; options.UsePermanentTable = true; }); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price < $2)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in + database"); + Assert.IsTrue(newTotal == oldTotal - rowsDeleted, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + int rowsDeleted, newTotal = 0; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsDeleted = dbContext.BulkDelete(orders); + newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + transaction.Rollback(); + } + var rollbackTotal = dbContext.Orders.Count(o => o.Price == 1.25M); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price < $2)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in + database"); + Assert.IsTrue(newTotal == 0, "Must be 0 to indicate all records were deleted"); + Assert.IsTrue(rollbackTotal == orders.Count, "The number of rows after the transacation has been rollbacked shou +uld match the original count"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\BulkDeleteAsync.cs --- + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkDeleteAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + int rowsDeleted = await dbContext.BulkDeleteAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in + database"); + Assert.IsTrue(newTotal == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = dbContext.TpcPeople.OfType().ToList(); + int rowsDeleted = await dbContext.BulkDeleteAsync(customers); + var newCustomers = dbContext.TpcPeople.OfType().Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows + in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = dbContext.TphPeople.OfType().ToList(); + int rowsDeleted = await dbContext.BulkDeleteAsync(customers); + var newCustomers = dbContext.TphPeople.OfType().Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows + in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptCustomers.ToList(); + int rowsDeleted = await dbContext.BulkDeleteAsync(customers); + var newCustomers = dbContext.TptCustomers.Count(); + + Assert.IsTrue(customers.Count > 0, "There must be tphCustomer records in database"); + Assert.IsTrue(rowsDeleted == customers.Count, "The number of rows deleted must match the count of existing rows + in database"); + Assert.IsTrue(newCustomers == 0, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Options_DeleteOnCondition() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).ToList(); + int rowsDeleted = await dbContext.BulkDeleteAsync(orders, options => { options.DeleteOnCondition = (s, t) => s.E +ExternalId == t.ExternalId; options.UsePermanentTable = true; }); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price < $2)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in + database"); + Assert.IsTrue(newTotal == oldTotal - rowsDeleted, "Must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + int rowsDeleted, newTotal = 0; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsDeleted = await dbContext.BulkDeleteAsync(orders); + newTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + transaction.Rollback(); + } + var rollbackTotal = dbContext.Orders.Count(o => o.Price == 1.25M); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price < $2)"); + Assert.IsTrue(rowsDeleted == orders.Count, "The number of rows deleted must match the count of existing rows in + database"); + Assert.IsTrue(newTotal == 0, "Must be 0 to indicate all records were deleted"); + Assert.IsTrue(rollbackTotal == orders.Count, "The number of rows after the transacation has been rollbacked shou +uld match the original count"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\BulkFetch.cs --- + +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkFetch : DbContextExtensionsBase +{ + [TestMethod] + public void With_Complex_Property() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25m).ToList(); + var fetchedProducts = dbContext.Products.BulkFetch(products); + bool foundNullPositionProperty = fetchedProducts.Any(o => o.Position == null); + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count(), "The number of rows deleted must match the count of exi +isting rows in database"); + Assert.IsFalse(foundNullPositionProperty, "The Position complex property should be populated when using BulkFetc +ch()"); + } + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + var fetchedOrders = dbContext.Orders.BulkFetch(orders); + bool ordersAreMatched = true; + + foreach (var fetchedOrder in fetchedOrders) + { + var order = orders.First(o => o.Id == fetchedOrder.Id); + if (order.ExternalId != fetchedOrder.ExternalId || order.AddedDateTime != fetchedOrder.AddedDateTime || orde +er.ModifiedDateTime != fetchedOrder.ModifiedDateTime) + { + ordersAreMatched = false; + break; + } + } + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(orders.Count == fetchedOrders.Count(), "The number of rows deleted must match the count of existin +ng rows in database"); + Assert.IsTrue(ordersAreMatched, "The orders from BulkFetch() should match what is retrieved from DbContext"); + } + [TestMethod] + public void With_Enum() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25m).ToList(); + var fetchedProducts = dbContext.Products.BulkFetch(products); + bool productsAreMatched = true; + + foreach (var fetchedProduct in fetchedProducts) + { + var product = products.First(o => o.Id == fetchedProduct.Id); + if (product.Id != fetchedProduct.Id || product.Name != fetchedProduct.Name || product.StatusEnum != fetchedP +Product.StatusEnum) + { + productsAreMatched = false; + break; + } + } + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count(), "The number of rows deleted must match the count of exi +isting rows in database"); + Assert.IsTrue(productsAreMatched, "The products from BulkFetch() should match what is retrieved from DbContext") +); + } + [TestMethod] + public void With_IQueryable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId != null); + var fetchedOrders = dbContext.Orders.BulkFetch(orders, options => { options.IgnoreColumns = o => new { o.Externa +alId }; }).ToList(); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + bool foundNullExternalId = fetchedOrders.Where(o => o.ExternalId != null).Any(); + + Assert.IsTrue(orders.Count() > 0, "There must be orders in the database that match condition (Price <= 10 And Ex +xternalId != null)"); + Assert.IsTrue(orders.Count() == fetchedOrders.Count(), "The number of orders must match the number of fetched or +rders"); + Assert.IsTrue(!foundNullExternalId, "Fetched orders should not contain any items where ExternalId is null."); + } + [TestMethod] + public void With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId != null).ToList(); + var fetchedOrders = dbContext.Orders.BulkFetch(orders, options => { options.IgnoreColumns = o => new { o.Externa +alId }; }).ToList(); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + bool foundNullExternalId = fetchedOrders.Where(o => o.ExternalId != null).Any(); + + Assert.IsTrue(orders.Count() > 0, "There must be orders in the database that match condition (Price <= 10 And Ex +xternalId != null)"); + Assert.IsTrue(orders.Count() == fetchedOrders.Count(), "The number of orders must match the number of fetched or +rders"); + Assert.IsTrue(!foundNullExternalId, "Fetched orders should not contain any items where ExternalId is null."); + } + [TestMethod] + public void With_ValueConverter() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).ToList(); + var fetchedProducts = dbContext.Products.BulkFetch(products); + bool areMatched = true; + + foreach (var fetchedProduct in fetchedProducts) + { + var product = products.First(o => o.Id == fetchedProduct.Id); + if (product.Name != fetchedProduct.Name || product.Price != fetchedProduct.Price + || product.Color != fetchedProduct.Color) + { + areMatched = false; + break; + } + } + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count(), "The number of rows deleted must match the count of exi +isting rows in database"); + Assert.IsTrue(areMatched, "The products from BulkFetch() should match what is retrieved from DbContext"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\BulkFetchAsync.cs --- + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkFetchAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Complex_Property() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25m).ToList(); + var fetchedProducts = (await dbContext.Products.BulkFetchAsync(products)).ToList(); + bool foundNullPositionProperty = fetchedProducts.Any(o => o.Position == null); + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count, "The number of rows deleted must match the count of exist +ting rows in database"); + Assert.IsFalse(foundNullPositionProperty, "The Position complex property should be populated when using BulkFetc +chAsync()"); + } + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).ToList(); + var fetchedOrders = (await dbContext.Orders.BulkFetchAsync(orders)).ToList(); + bool ordersAreMatched = true; + + foreach (var fetchedOrder in fetchedOrders) + { + var order = orders.First(o => o.Id == fetchedOrder.Id); + if (order.ExternalId != fetchedOrder.ExternalId || order.AddedDateTime != fetchedOrder.AddedDateTime || orde +er.ModifiedDateTime != fetchedOrder.ModifiedDateTime) + { + ordersAreMatched = false; + break; + } + } + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(orders.Count == fetchedOrders.Count, "The number of rows deleted must match the count of existing + rows in database"); + Assert.IsTrue(ordersAreMatched, "The orders from BulkFetchAsync() should match what is retrieved from DbContext" +"); + } + [TestMethod] + public async Task With_Enum() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25m).ToList(); + var fetchedProducts = (await dbContext.Products.BulkFetchAsync(products)).ToList(); + bool productsAreMatched = true; + + foreach (var fetchedProduct in fetchedProducts) + { + var product = products.First(o => o.Id == fetchedProduct.Id); + if (product.Id != fetchedProduct.Id || product.Name != fetchedProduct.Name || product.StatusEnum != fetchedP +Product.StatusEnum) + { + productsAreMatched = false; + break; + } + } + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count, "The number of rows deleted must match the count of exist +ting rows in database"); + Assert.IsTrue(productsAreMatched, "The products from BulkFetchAsync() should match what is retrieved from DbCont +text"); + } + [TestMethod] + public async Task With_IQueryable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId != null); + var fetchedOrders = (await dbContext.Orders.BulkFetchAsync(orders, options => { options.IgnoreColumns = o => new +w { o.ExternalId }; })).ToList(); + bool foundNonNullExternalId = fetchedOrders.Any(o => o.ExternalId != null); + + Assert.IsTrue(orders.Count() > 0, "There must be orders in the database that match condition (Price <= 10 And Ex +xternalId != null)"); + Assert.IsTrue(orders.Count() == fetchedOrders.Count, "The number of orders must match the number of fetched orde +ers"); + Assert.IsFalse(foundNonNullExternalId, "Fetched orders should not contain any items where ExternalId is not null +l."); + } + [TestMethod] + public async Task With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId != null).ToList(); + var fetchedOrders = (await dbContext.Orders.BulkFetchAsync(orders, options => { options.IgnoreColumns = o => new +w { o.ExternalId }; })).ToList(); + bool foundNonNullExternalId = fetchedOrders.Any(o => o.ExternalId != null); + + Assert.IsTrue(orders.Count() > 0, "There must be orders in the database that match condition (Price <= 10 And Ex +xternalId != null)"); + Assert.IsTrue(orders.Count() == fetchedOrders.Count, "The number of orders must match the number of fetched orde +ers"); + Assert.IsFalse(foundNonNullExternalId, "Fetched orders should not contain any items where ExternalId is not null +l."); + } + [TestMethod] + public async Task With_ValueConverter() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).ToList(); + var fetchedProducts = (await dbContext.Products.BulkFetchAsync(products)).ToList(); + bool areMatched = true; + + foreach (var fetchedProduct in fetchedProducts) + { + var product = products.First(o => o.Id == fetchedProduct.Id); + if (product.Name != fetchedProduct.Name || product.Price != fetchedProduct.Price + || product.Color != fetchedProduct.Color) + { + areMatched = false; + break; + } + } + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(products.Count == fetchedProducts.Count, "The number of rows deleted must match the count of exist +ting rows in database"); + Assert.IsTrue(areMatched, "The products from BulkFetchAsync() should match what is retrieved from DbContext"); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\BulkInsert.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkInsert : DbContextExtensionsBase +{ + [TestMethod] + public void With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = new List(); + for (int i = 50000; i < 60000; i++) + { + var key = i.ToString(); + products.Add(new ProductWithComplexKey { Price = 1.57M }); + } + int oldTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(products); + int newTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Complex_Type() + { + var dbContext = SetupDbContext(true); + var orders = new List(); + for (int i = 1; i < 1000; i++) + { + orders.Add(new OrderWithComplexType + { + Id = i, + ShippingAddress = new Address + { + Line1 = $"123 Main St, {i}", + City = "Atlanta", + Country = "USA", + PostCode = "30303" + }, + BillingAddress = new Address + { + Line1 = $"456 Oak St, {i}", + City = "Atlanta", + Country = "USA", + PostCode = "30303" + } + }); + } + int oldTotal = dbContext.OrdersWithComplexType.Count(); + int rowsInserted = dbContext.BulkInsert(orders); + int newTotal = dbContext.OrdersWithComplexType.Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TpcCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TpcVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TpcPeople.Count(); + int customerRowsInserted = dbContext.BulkInsert(customers, o => o.UsePermanentTable = true); + int vendorRowsInserted = dbContext.BulkInsert(vendors, o => o.UsePermanentTable = true); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TpcPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count +t of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TphCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "404-555-1111", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TphVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Phone = "404-555-2222", + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TphPeople.Count(); + int customerRowsInserted = dbContext.BulkInsert(customers); + int vendorRowsInserted = dbContext.BulkInsert(vendors); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TphPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count +t of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TptCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "777-555-1234", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TptVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TptPeople.Count(); + int customerRowsInserted = dbContext.BulkInsert(customers, o => o.UsePermanentTable = true); + int vendorRowsInserted = dbContext.BulkInsert(vendors); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TptPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count +t of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void Without_Identity_Column() + { + var dbContext = SetupDbContext(true); + var products = new List(); + for (int i = 50000; i < 60000; i++) + { + products.Add(new Product { Id = i.ToString(), Price = 1.57M }); + } + int oldTotal = dbContext.Products.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(products); + int newTotal = dbContext.Products.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Options_AutoMapIdentity() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 5000; i++) + { + orders.Add(new Order { ExternalId = i.ToString(), Price = ((decimal)i + 0.55M) }); + } + int rowsAdded = dbContext.BulkInsert(orders, new BulkInsertOptions + { + UsePermanentTable = true + }); + bool autoMapIdentityMatched = true; + var ordersInDb = dbContext.Orders.ToList(); + Order order1 = null; + Order order2 = null; + foreach (var order in orders) + { + order1 = order; + var orderinDb = ordersInDb.First(o => o.Id == order.Id); + order2 = orderinDb; + if (!(orderinDb.ExternalId == order.ExternalId && orderinDb.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(rowsAdded == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up") +); + } + [TestMethod] + public void With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, ExternalId = i.ToString(), Price = 1.57M, Active = true }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + int rowsInserted = dbContext.BulkInsert(orders, options => { options.UsePermanentTable = true; options.IgnoreCol +lumns = o => new { o.ExternalId }; }); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Options_InputColumns() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, ExternalId = i.ToString(), Price = 1.57M, Active = true, Status = OrderStatus +s.Completed }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price == 1.57M && o.ExternalId == null && o.Active == true).Count() +); + int rowsInserted = dbContext.BulkInsert(orders, options => + { + options.UsePermanentTable = true; + options.InputColumns = o => new { o.Price, o.Active, o.AddedDateTime, o.Status }; + }); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.57M && o.ExternalId == null && o.Active == true).Count() +); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_KeepIdentity() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i + 1000, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Count(); + int rowsInserted = dbContext.BulkInsert(orders, options => { options.KeepIdentity = true; options.BatchSize = 10 +000; }); + var oldOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool allIdentityFieldsMatch = true; + for (int i = 0; i < 20000; i++) + { + if (newOrders[i].Id != oldOrders[i].Id) + { + allIdentityFieldsMatch = false; + break; + } + } + try + { + int rowsInserted2 = dbContext.BulkInsert(orders, new BulkInsertOptions() + { + KeepIdentity = true, + BatchSize = 1000, + }); + } + catch (Exception ex) + { + Assert.IsTrue(Config.IsPrimaryKeyViolation(ex)); + } + + Assert.IsTrue(oldTotal == 0, "There should not be any records in the table"); + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(allIdentityFieldsMatch, "The identities between the source and the database should match."); + } + [TestMethod] + public void With_Schema() + { + var dbContext = SetupDbContext(false); + var products = new List(); + for (int i = 1; i < 10000; i++) + { + var key = i.ToString(); + products.Add(new ProductWithCustomSchema + { + Id = key, + Name = $"Product-{key}", + Price = 1.57M + }); + } + int oldTotal = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(products); + int newTotal = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted, newTotal; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsInserted = dbContext.BulkInsert(orders); + newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + transaction.Rollback(); + } + int rollbackTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + Assert.IsTrue(rollbackTotal == oldTotal, "The number of rows after the transacation has been rollbacked should m +match the original count"); + } + [TestMethod] + public void With_Options_InsertIfNotExists() + { + var dbContext = SetupDbContext(true); + var orders = new List(); + long maxId = dbContext.Orders.Max(o => o.Id); + long expectedRowsInserted = 1000; + int existingRowsToAdd = 100; + long startId = maxId - existingRowsToAdd + 1, endId = maxId + expectedRowsInserted + 1; + for (long i = startId; i < endId; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(orders, new BulkInsertOptions() { InsertIfNotExists = true }); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == expectedRowsInserted, "The number of rows inserted must match the count of order l +list"); + Assert.IsTrue(newTotal - oldTotal == expectedRowsInserted, "The new count minus the old count should match the n +number of rows inserted."); + } + [TestMethod] + public void With_Proxy_Type() + { + var dbContext = SetupDbContext(false); + int oldTotalCount = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + + var products = new List(); + for (int i = 0; i < 2000; i++) + { + var product = dbContext.Products.CreateProxy(); + product.Id = (-i).ToString(); + product.Price = 10.57M; + products.Add(product); + } + int oldTotal = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + int rowsInserted = dbContext.BulkInsert(products); + int newTotal = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of products list +t"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_Trigger() + { + var dbContext = SetupDbContext(false); + var products = new List(); + for (int i = 1; i < 1000; i++) + { + products.Add(new ProductWithTrigger { Id = i.ToString(), Price = 1.57M, StatusString = "InStock" }); + } + + //The return int from BulkInsert() will be off when using triggers + dbContext.BulkInsert(products, options => + { + options.AutoMapOutput = false; + if (Config.IsSqlServer) + options.BulkCopyOptions = SqlBulkCopyOptions.FireTriggers; + }); + var rowsInserted = dbContext.ProductsWithTrigger.Count(); + + Assert.IsTrue(rowsInserted == products.Count, $"The number of rows inserted must match the count of products ({r +rowsInserted}!={products.Count})"); + } + [TestMethod] + public void With_ValueGenerated_Default() + { + var dbContext = SetupDbContext(false); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbAddedDateTime > nowDateTime && o.DbActive).Count +t(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public void With_ValueGenerated_Computed() + { + var dbContext = SetupDbContext(false); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = dbContext.BulkInsert(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbModifiedDateTime > nowDateTime).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\BulkInsertAsync.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkInsertAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = new List(); + for (int i = 50000; i < 60000; i++) + { + var key = i.ToString(); + products.Add(new ProductWithComplexKey { Price = 1.57M }); + } + int oldTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(products); + int newTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_Complex_Type() + { + var dbContext = SetupDbContext(true); + var orders = new List(); + for (int i = 1; i < 1000; i++) + { + orders.Add(new OrderWithComplexType + { + Id = i, + ShippingAddress = new Address + { + Line1 = $"123 Main St, {i}", + City = "Atlanta", + Country = "USA", + PostCode = "30303" + }, + BillingAddress = new Address + { + Line1 = $"456 Oak St, {i}", + City = "Atlanta", + Country = "USA", + PostCode = "30303" + } + }); + } + int oldTotal = dbContext.OrdersWithComplexType.Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders); + int newTotal = dbContext.OrdersWithComplexType.Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + //[TestMethod] + //public async Task With_IEnumerable() + //{ + // var dbContext = SetupDbContext(false); + // var orders = dbContext.Orders.Where(o => o.Price <= 10); + + // foreach(var order in orders) + // { + // order.Price = 15.75M; + // } + // int oldTotal = orders.Count(); + // int rowsInserted = await dbContext.BulkInsertAsync(orders); + // int newTotal = orders.Count(); + + // Assert.IsTrue(rowsInserted == oldTotal, "The number of rows inserted must match the count of order list"); + // Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number + of rows inserted."); + //} + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TpcCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TpcVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TpcPeople.Count(); + int customerRowsInserted = await dbContext.BulkInsertAsync(customers); + int vendorRowsInserted = await dbContext.BulkInsertAsync(vendors); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TpcPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count +t of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TphCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "404-555-1111", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TphVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Phone = "404-555-2222", + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TphPeople.Count(); + int customerRowsInserted = await dbContext.BulkInsertAsync(customers); + int vendorRowsInserted = await dbContext.BulkInsertAsync(vendors); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TphPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count +t of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(false); + var customers = new List(); + var vendors = new List(); + for (int i = 0; i < 20000; i++) + { + customers.Add(new TptCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "777-555-1234", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 20000; i < 30000; i++) + { + vendors.Add(new TptVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + int oldTotal = dbContext.TptPeople.Count(); + int customerRowsInserted = await dbContext.BulkInsertAsync(customers, o => o.UsePermanentTable = true); + int vendorRowsInserted = await dbContext.BulkInsertAsync(vendors); + int rowsInserted = customerRowsInserted + vendorRowsInserted; + int newTotal = dbContext.TptPeople.Count(); + + Assert.IsTrue(rowsInserted == customers.Count + vendors.Count, "The number of rows inserted must match the count +t of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task Without_Identity_Column() + { + var dbContext = SetupDbContext(true); + var products = new List(); + for (int i = 50000; i < 60000; i++) + { + products.Add(new Product { Id = i.ToString(), Price = 1.57M }); + } + int oldTotal = dbContext.Products.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(products); + int newTotal = dbContext.Products.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_Options_AutoMapIdentity() + { + + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 5000; i++) + { + orders.Add(new Order { ExternalId = i.ToString(), Price = ((decimal)i + 0.55M) }); + } + int rowsAdded = await dbContext.BulkInsertAsync(orders, new BulkInsertOptions + { + UsePermanentTable = true + }); + bool autoMapIdentityMatched = true; + var ordersInDb = dbContext.Orders.ToList(); + Order order1 = null; + Order order2 = null; + foreach (var order in orders) + { + order1 = order; + var orderinDb = ordersInDb.First(o => o.Id == order.Id); + order2 = orderinDb; + if (!(orderinDb.ExternalId == order.ExternalId && orderinDb.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(rowsAdded == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up") +); + } + [TestMethod] + public async Task With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, ExternalId = i.ToString(), Price = 1.57M, Active = true }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders, options => { options.UsePermanentTable = true; option +ns.IgnoreColumns = o => new { o.ExternalId }; }); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.ExternalId == null).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_Options_InputColumns() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, ExternalId = i.ToString(), Price = 1.57M, Active = true, Status = OrderStatus +s.Completed }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price == 1.57M && o.ExternalId == null && o.Active == true).Count() +); + int rowsInserted = await dbContext.BulkInsertAsync(orders, options => + { + options.UsePermanentTable = true; + options.InputColumns = o => new { o.Price, o.Active, o.AddedDateTime, o.Status }; + }); + int newTotal = dbContext.Orders.Where(o => o.Price == 1.57M && o.ExternalId == null && o.Active == true).Count() +); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_KeepIdentity() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i + 1000, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders, options => { options.KeepIdentity = true; options.Bat +tchSize = 1000; }); + var oldOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool allIdentityFieldsMatch = true; + for (int i = 0; i < 20000; i++) + { + if (newOrders[i].Id != oldOrders[i].Id) + { + allIdentityFieldsMatch = false; + break; + } + } + try + { + int rowsInserted2 = await dbContext.BulkInsertAsync(orders, new BulkInsertOptions() + { + KeepIdentity = true, + BatchSize = 1000, + }); + } + catch (Exception ex) + { + Assert.IsTrue(Config.IsPrimaryKeyViolation(ex)); + } + + Assert.IsTrue(oldTotal == 0, "There should not be any records in the table"); + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(allIdentityFieldsMatch, "The identities between the source and the database should match."); + } + [TestMethod] + public async Task With_Proxy_Type() + { + var dbContext = SetupDbContext(false); + int oldTotalCount = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + + var products = new List(); + for (int i = 0; i < 2000; i++) + { + var product = dbContext.Products.CreateProxy(); + product.Id = (-i).ToString(); + product.Price = 10.57M; + products.Add(product); + } + int oldTotal = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(products); + int newTotal = dbContext.Products.Where(o => o.Price == 10.57M).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of products list +t"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_Trigger() + { + var dbContext = SetupDbContext(false); + var products = new List(); + for (int i = 1; i < 1000; i++) + { + products.Add(new ProductWithTrigger { Id = i.ToString(), Price = 1.57M, StatusString = "InStock" }); + } + + //The return int from BulkInsertAsync() will be off when using triggers + await dbContext.BulkInsertAsync(products, options => + { + options.AutoMapOutput = false; + if (Config.IsSqlServer) + options.BulkCopyOptions = SqlBulkCopyOptions.FireTriggers; + }); + var rowsInserted = dbContext.ProductsWithTrigger.Count(); + + Assert.IsTrue(rowsInserted == products.Count, $"The number of rows inserted must match the count of products ({r +rowsInserted}!={products.Count})"); + } + [TestMethod] + public async Task With_Schema() + { + var dbContext = SetupDbContext(false); + var products = new List(); + for (int i = 1; i < 10000; i++) + { + var key = i.ToString(); + products.Add(new ProductWithCustomSchema + { + Id = key, + Name = $"Product-{key}", + Price = 1.57M + }); + } + int oldTotal = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(products); + int newTotal = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == products.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(false); + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted, newTotal; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsInserted = await dbContext.BulkInsertAsync(orders); + newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + transaction.Rollback(); + } + int rollbackTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + Assert.IsTrue(rollbackTotal == oldTotal, "The number of rows after the transacation has been rollbacked should m +match the original count"); + } + [TestMethod] + public async Task With_Options_InsertIfNotExists() + { + var dbContext = SetupDbContext(true); + var orders = new List(); + long maxId = dbContext.Orders.Max(o => o.Id); + long expectedRowsInserted = 1000; + int existingRowsToAdd = 100; + long startId = maxId - existingRowsToAdd + 1, endId = maxId + expectedRowsInserted + 1; + for (long i = startId; i < endId; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders, new BulkInsertOptions() { InsertIfNotExists = + true }); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + Assert.IsTrue(rowsInserted == expectedRowsInserted, "The number of rows inserted must match the count of order l +list"); + Assert.IsTrue(newTotal - oldTotal == expectedRowsInserted, "The new count minus the old count should match the n +number of rows inserted."); + } + [TestMethod] + public async Task With_ValueGenerated_Default() + { + var dbContext = SetupDbContext(false); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbAddedDateTime > nowDateTime && o.DbActive).Count +t(); + + Assert.IsTrue(rowsInserted == orders.Count, "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } + [TestMethod] + public async Task With_ValueGenerated_Computed() + { + var dbContext = SetupDbContext(false); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int rowsInserted = await dbContext.BulkInsertAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbModifiedDateTime > nowDateTime).Count(); + + Assert.IsTrue(rowsInserted == orders.Count(), "The number of rows inserted must match the count of order list"); + Assert.IsTrue(newTotal - oldTotal == rowsInserted, "The new count minus the old count should match the number of +f rows inserted."); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\BulkMerge.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkMerge : DbContextExtensionsBase +{ + [TestMethod] + public void With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.ProductsWithComplexKey.Where(o => o.Price == 1.25M).ToList(); + int productsToAdd = 5000; + decimal updatedPrice = 5.25M; + var productsToUpdate = products.ToList(); + foreach (var product in products) + { + product.Price = updatedPrice; + } + for (int i = 0; i < productsToAdd; i++) + { + products.Add(new ProductWithComplexKey { ExternalId = (20000 + i).ToString(), Price = 3.55M }); + } + var result = dbContext.BulkMerge(products); + var allProducts = dbContext.ProductsWithComplexKey.ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var product in allProducts) + { + if (productsToUpdate.Contains(product) && product.Price != updatedPrice) + { + areUpdatedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == products.Count(), "The number of rows inserted must match the count of orde +er list"); + Assert.IsTrue(result.RowsUpdated == productsToUpdate.Count, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == productsToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = dbContext.BulkMerge(orders, o => o.UsePermanentTable = true); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 10000).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order + list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true); + var customers = dbContext.TpcPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMerge_Tpc_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TpcCustomer + { + Id = 10000 + i, + FirstName = "BulkMerge_Tpc_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkMerge(customers, options => { options.MergeOnCondition = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TpcPeople.Where(o => o.FirstName == "BulkMerge_Tpc_Add").OfType().Co +ount(); + int customersUpdated = dbContext.TpcPeople.Where(o => o.FirstName == "BulkMerge_Tpc_Update").OfType +>().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count, "The number of rows inserted must match the count of custo +omer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true); + var customers = dbContext.TphPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMerge_Tph_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TphCustomer + { + Id = 10000 + i, + FirstName = "BulkMerge_Tph_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkMerge(customers); + int customersAdded = dbContext.TphPeople.Where(o => o.FirstName == "BulkMerge_Tph_Add").OfType().Co +ount(); + int customersUpdated = dbContext.TphPeople.Where(o => o.FirstName == "BulkMerge_Tph_Update").OfType +>().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count(), "The number of rows inserted must match the count of cus +stomer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMerge_Tpt_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TptCustomer + { + Id = 10000 + i, + FirstName = "BulkMerge_Tpt_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkMerge(customers); + int customersAdded = dbContext.TptPeople.Where(o => o.FirstName == "BulkMerge_Tpt_Add").OfType().Co +ount(); + int customersUpdated = dbContext.TptPeople.Where(o => o.FirstName == "BulkMerge_Tpt_Update").OfType +>().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count(), "The number of rows inserted must match the count of cus +stomer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public void With_Default_Options_MergeOnCondition() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 50; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = dbContext.BulkMerge(orders, options => { options.MergeOnCondition = (s, t) => s.ExternalId == t.Ext +ternalId; options.BatchSize = 1000; }); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order + list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public void With_Options_AutoMapIdentity() + { + var dbContext = SetupDbContext(true); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = dbContext.BulkMerge(orders, new BulkMergeOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + UsePermanentTable = true + }); + bool autoMapIdentityMatched = true; + foreach (var order in orders) + { + if (!dbContext.Orders.Any(o => o.ExternalId == order.ExternalId && o.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == ordersToAdd + ordersToUpdate, "The number of rows inserted must match the c +count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up") +); + } + [TestMethod] + public void With_Options_AutoMapOutput() + { + var dbContext = SetupDbContext(true); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = dbContext.BulkMerge(orders, new BulkMergeOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + AutoMapOutput = true + }); + var autoMapIdentityMatched = orders.All(x => x.Id != 0); + + Assert.IsTrue(result.RowsAffected == ordersToAdd + ordersToUpdate, "The number of rows inserted must match the c +count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up") +); + } + [TestMethod] + public void With_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + int productsToAdd = 5000; + var productsToUpdate = products.ToList(); + foreach (var product in products) + { + product.Price = Convert.ToDecimal(product.Id) + .25M; + } + for (int i = 0; i < productsToAdd; i++) + { + products.Add(new Product { Id = (20000 + i).ToString(), Price = 3.55M }); + } + var result = dbContext.BulkMerge(products); + var newProducts = dbContext.Products.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newProduct in newProducts.Where(o => productsToUpdate.Select(o => o.Id).Contains(o.Id))) + { + if (newProduct.Price != Convert.ToDecimal(newProduct.Id) + .25M) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newProduct in newProducts.Where(o => Convert.ToInt32(o.Id) >= 20000).OrderBy(o => o.Id)) + { + if (newProduct.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == products.Count(), "The number of rows inserted must match the count of orde +er list"); + Assert.IsTrue(result.RowsUpdated == productsToUpdate.Count, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == productsToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + BulkMergeResult result; + using (var transaction = dbContext.Database.BeginTransaction()) + { + result = dbContext.BulkMerge(orders); + transaction.Rollback(); + } + int ordersUpdated = dbContext.Orders.Count(o => o.Id <= 10000 && o.Price == ((decimal)o.Id + .25M) && o.Price != += 1.25M); + int ordersAdded = dbContext.Orders.Count(o => o.Id >= 100000); + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order + list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(ordersAdded == 0, "The number of rows added must equal 0 since transaction was rollbacked"); + Assert.IsTrue(ordersUpdated == 0, "The number of rows updated must equal 0 since transaction was rollbacked"); + } + [TestMethod] + public void With_ValueGenerated_Default() + { + var dbContext = SetupDbContext(false); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 1000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.DbAddedDateTime > nowDateTime).Count(); + var mergeResult = dbContext.BulkMerge(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 1.57M + && o.DbAddedDateTime > nowDateTime).Count(); + + Assert.IsTrue(mergeResult.RowsInserted == orders.Count, "The number of rows inserted must match the count of ord +der list"); + Assert.IsTrue(newTotal - oldTotal == mergeResult.RowsInserted, "The new count minus the old count should match t +the number of rows inserted."); + } + [TestMethod] + public void With_ValueGenerated_Computed() + { + var dbContext = SetupDbContext(false); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + var result = dbContext.BulkMerge(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbModifiedDateTime > nowDateTime).Count(); + + Assert.IsTrue(result.RowsInserted == orders.Count(), "The number of rows inserted must match the count of order + list"); + Assert.IsTrue(newTotal - oldTotal == result.RowsInserted, "The new count minus the old count should match the nu +umber of rows inserted."); + } + [TestMethod] + public void With_Merge_On_Enum() + { + var dbContext = SetupDbContext(true); + dbContext.BulkSaveChanges(); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime, Status = OrderStatus.Complet +ted }); + } + + var result = dbContext.BulkMerge(orders, options => options.MergeOnCondition = (s, t) => s.Id == t.Id && s.Statu +us == t.Status); + + Assert.AreEqual(1, result.RowsInserted); + Assert.AreEqual(19, result.RowsUpdated); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\BulkMergeAsync.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkMergeAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.ProductsWithComplexKey.Where(o => o.Price == 1.25M).ToList(); + int productsToAdd = 5000; + decimal updatedPrice = 5.25M; + var productsToUpdate = products.ToList(); + foreach (var product in products) + { + product.Price = updatedPrice; + } + for (int i = 0; i < productsToAdd; i++) + { + products.Add(new ProductWithComplexKey { ExternalId = (20000 + i).ToString(), Price = 3.55M }); + } + var result = await dbContext.BulkMergeAsync(products); + var allProducts = dbContext.ProductsWithComplexKey.ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var product in allProducts) + { + if (productsToUpdate.Contains(product) && product.Price != updatedPrice) + { + areUpdatedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == products.Count(), "The number of rows inserted must match the count of orde +er list"); + Assert.IsTrue(result.RowsUpdated == productsToUpdate.Count, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == productsToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = await dbContext.BulkMergeAsync(orders); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 10000).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order + list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true); + var customers = dbContext.TpcPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMerge_Tpc_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TpcCustomer + { + Id = 10000 + i, + FirstName = "BulkMerge_Tpc_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkMergeAsync(customers, options => { options.MergeOnCondition = (s, t) => s.Id == += t.Id; }); + int customersAdded = dbContext.TpcPeople.Where(o => o.FirstName == "BulkMerge_Tpc_Add").OfType().Co +ount(); + int customersUpdated = dbContext.TpcPeople.Where(o => o.FirstName == "BulkMerge_Tpc_Update").OfType +>().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count, "The number of rows inserted must match the count of custo +omer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true); + var customers = dbContext.TphPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMerge_Tph_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TphCustomer + { + Id = 10000 + i, + FirstName = "BulkMerge_Tph_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkMergeAsync(customers); + int customersAdded = dbContext.TphPeople.Where(o => o.FirstName == "BulkMerge_Tph_Add").OfType().Co +ount(); + int customersUpdated = dbContext.TphPeople.Where(o => o.FirstName == "BulkMerge_Tph_Update").OfType +>().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count(), "The number of rows inserted must match the count of cus +stomer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + foreach (var customer in customers) + { + customer.FirstName = "BulkMergeAsync_Tpt_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TptCustomer + { + Id = 10000 + i, + FirstName = "BulkMergeAsync_Tpt_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkMergeAsync(customers); + int customersAdded = dbContext.TptPeople.Where(o => o.FirstName == "BulkMergeAsync_Tpt_Add").OfType +>().Count(); + int customersUpdated = dbContext.TptPeople.Where(o => o.FirstName == "BulkMergeAsync_Tpt_Update").OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customers.Count(), "The number of rows inserted must match the count of cus +stomer list"); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Default_Options_MergeOnCondition() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 50; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = await dbContext.BulkMergeAsync(orders, options => { options.MergeOnCondition = (s, t) => s.External +lId == t.ExternalId; options.BatchSize = 1000; }); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order + list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Options_AutoMapIdentity() + { + var dbContext = SetupDbContext(true); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = await dbContext.BulkMergeAsync(orders, new BulkMergeOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + UsePermanentTable = true + }); + bool autoMapIdentityMatched = true; + foreach (var order in orders) + { + if (!dbContext.Orders.Any(o => o.ExternalId == order.ExternalId && o.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == ordersToAdd + ordersToUpdate, "The number of rows inserted must match the c +count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up") +); + } + [TestMethod] + public async Task With_Options_AutoMapOutput() + { + var dbContext = SetupDbContext(true); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = await dbContext.BulkMergeAsync(orders, new BulkMergeOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + AutoMapOutput = true + }); + var autoMapIdentityMatched = orders.All(x => x.Id != 0); + + Assert.IsTrue(result.RowsAffected == ordersToAdd + ordersToUpdate, "The number of rows inserted must match the c +count of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up") +); + } + [TestMethod] + public async Task With_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + int productsToAdd = 5000; + var productsToUpdate = products.ToList(); + foreach (var product in products) + { + product.Price = Convert.ToDecimal(product.Id) + .25M; + } + for (int i = 0; i < productsToAdd; i++) + { + products.Add(new Product { Id = (20000 + i).ToString(), Price = 3.55M }); + } + var result = await dbContext.BulkMergeAsync(products); + var newProducts = dbContext.Products.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newProduct in newProducts.Where(o => productsToUpdate.Select(o => o.Id).Contains(o.Id))) + { + if (newProduct.Price != Convert.ToDecimal(newProduct.Id) + .25M) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newProduct in newProducts.Where(o => Convert.ToInt32(o.Id) >= 20000).OrderBy(o => o.Id)) + { + if (newProduct.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == products.Count(), "The number of rows inserted must match the count of orde +er list"); + Assert.IsTrue(result.RowsUpdated == productsToUpdate.Count, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == productsToAdd, "The number of rows added must match"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + BulkMergeResult result; +using (var transaction = dbContext.Database.BeginTransaction()) + { + result = await dbContext.BulkMergeAsync(orders); + transaction.Rollback(); + } + int ordersUpdated = dbContext.Orders.Count(o => o.Id <= 10000 && o.Price == ((decimal)o.Id + .25M) && o.Price != += 1.25M); + int ordersAdded = dbContext.Orders.Count(o => o.Id >= 100000); + + Assert.IsTrue(result.RowsAffected == orders.Count(), "The number of rows inserted must match the count of order + list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(ordersAdded == 0, "The number of rows added must equal 0 since transaction was rollbacked"); + Assert.IsTrue(ordersUpdated == 0, "The number of rows updated must equal 0 since transaction was rollbacked"); + } + [TestMethod] + public async Task With_ValueGenerated_Default() + { + var dbContext = SetupDbContext(false); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 1000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M }); + } + int oldTotal = dbContext.Orders.Where(o => o.DbAddedDateTime > nowDateTime).Count(); + var mergeResult = await dbContext.BulkMergeAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 1.57M + && o.DbAddedDateTime > nowDateTime).Count(); + + Assert.IsTrue(mergeResult.RowsInserted == orders.Count, "The number of rows inserted must match the count of ord +der list"); + Assert.IsTrue(newTotal - oldTotal == mergeResult.RowsInserted, "The new count minus the old count should match t +the number of rows inserted."); + } + [TestMethod] + public async Task With_ValueGenerated_Computed() + { + var dbContext = SetupDbContext(false); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20000; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime }); + } + int oldTotal = dbContext.Orders.Where(o => o.Price <= 10).Count(); + var result = await dbContext.BulkMergeAsync(orders); + int newTotal = dbContext.Orders.Where(o => o.Price <= 10 && o.DbModifiedDateTime > nowDateTime).Count(); + + Assert.IsTrue(result.RowsInserted == orders.Count(), "The number of rows inserted must match the count of order + list"); + Assert.IsTrue(newTotal - oldTotal == result.RowsInserted, "The new count minus the old count should match the nu +umber of rows inserted."); + } + [TestMethod] + public async Task With_Merge_On_Enum() + { + var dbContext = SetupDbContext(true); + await dbContext.BulkSaveChangesAsync(); + var nowDateTime = Config.IsPostgreSql ? DateTime.UtcNow : DateTime.Now; + var orders = new List(); + for (int i = 0; i < 20; i++) + { + orders.Add(new Order { Id = i, Price = 1.57M, DbModifiedDateTime = nowDateTime, Status = OrderStatus.Complet +ted }); + } + + var result = await dbContext.BulkMergeAsync(orders, options => options.MergeOnCondition = (s, t) => s.Id == t.Id +d && s.Status == t.Status); + + Assert.AreEqual(1, result.RowsInserted); + Assert.AreEqual(19, result.RowsUpdated); + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\BulkSaveChanges.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkSaveChanges : DbContextExtensionsBase +{ + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var totalCount = dbContext.Orders.Count(); + + //Add new orders + var ordersToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + ordersToAdd.Add(new Order { Id = -i, Price = 10.57M }); + } + dbContext.Orders.AddRange(ordersToAdd); + + //Delete orders + var ordersToDelete = dbContext.Orders.Where(o => o.Price <= 5).ToList(); + dbContext.Orders.RemoveRange(ordersToDelete); + + //Update existing orders + var ordersToUpdate = dbContext.Orders.Where(o => o.Price > 5 && o.Price <= 10).ToList(); + foreach (var orderToUpdate in ordersToUpdate) + { + orderToUpdate.Price = 99M; + } + + + int rowsAffected = dbContext.BulkSaveChanges(); + int ordersAddedCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + int ordersDeletedCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + int ordersUpdatedCount = dbContext.Orders.Where(o => o.Price == 99M).Count(); + + Assert.IsTrue(rowsAffected == ordersToAdd.Count + ordersToDelete.Count + ordersToUpdate.Count, "The number of ro +ows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(ordersAddedCount == ordersToAdd.Count(), "The number of orders to add did not match what was expec +cted."); + Assert.IsTrue(ordersDeletedCount == 0, "The number of orders that was deleted did not match what was expected.") +); + Assert.IsTrue(ordersUpdatedCount == ordersToUpdate.Count(), "The number of orders that was updated did not match +h what was expected."); + } + [TestMethod] + public void With_Add_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + + var ordersToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + ordersToAdd.Add(new Order { Id = -i, Price = 10.57M }); + } + dbContext.Orders.AddRange(ordersToAdd); + + int rowsAffected = dbContext.BulkSaveChanges(); + int newTotalCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + + Assert.IsTrue(ordersToAdd.Where(o => o.Id <= 0).Count() == 0, "Primary key should have been updated for all enti +ities"); + Assert.IsTrue(rowsAffected == ordersToAdd.Count, "The number of rows affected must equal the sum of entities add +ded, deleted and updated"); + Assert.IsTrue(oldTotalCount + ordersToAdd.Count == newTotalCount, "The number of orders to add did not match wha +at was expected."); + } + [TestMethod] + public void With_Delete_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + + //Delete orders + var ordersToDelete = dbContext.Orders.Where(o => o.Price <= 5).ToList(); + dbContext.Orders.RemoveRange(ordersToDelete); + + int rowsAffected = dbContext.BulkSaveChanges(); + int newTotalCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + + Assert.IsTrue(rowsAffected == ordersToDelete.Count, "The number of rows affected must equal the sum of entities + added, deleted and updated"); + Assert.IsTrue(oldTotalCount - ordersToDelete.Count == newTotalCount, "The number of orders to add did not match + what was expected."); + } + [TestMethod] + public void With_Update_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + //Update existing orders + var ordersToUpdate = dbContext.Orders.Where(o => o.Price <= 10).ToList(); + foreach (var orderToUpdate in ordersToUpdate) + { + orderToUpdate.Price = 99M; + } + + int rowsAffected = dbContext.BulkSaveChanges(); + int newTotalCount = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int expectedCount = dbContext.Orders.Where(o => o.Price == 99M).Count(); + + Assert.IsTrue(rowsAffected == ordersToUpdate.Count, "The number of rows affected must equal the sum of entities + added, deleted and updated"); + Assert.IsTrue(oldTotalCount - ordersToUpdate.Count == newTotalCount, "The number of orders to add did not match + what was expected."); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + //Delete Customers + var customersToDelete = dbContext.TpcPeople.OfType().Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TpcPeople.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TpcPeople.OfType().Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TpcPeople.OfType().Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TpcPeople.Add(new TpcCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = dbContext.BulkSaveChanges(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TpcPeople.OfType().Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected.") +); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number o +of rows inserted."); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + //Delete Customers + var customersToDelete = dbContext.TphCustomers.Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TphCustomers.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TphCustomers.Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TphPeople.Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TphCustomers.Add(new TphCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = dbContext.BulkSaveChanges(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TphCustomers.Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(expectedRowsDeleted > 0, "The expected number of rows to delete must be greater than zero."); + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected.") +); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number o +of rows inserted."); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + //Delete Customers + var customersToDelete = dbContext.TptCustomers.Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TptCustomers.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TptCustomers.Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.Email = "name@domain.com"; + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TptPeople.Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TptCustomers.Add(new TptCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = dbContext.BulkSaveChanges(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TptCustomers.Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected.") +); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number o +of rows inserted."); + } + [TestMethod] + public void With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + var totalCount = dbContext.ProductsWithCustomSchema.Count(); + + //Add new products + var productsToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + productsToAdd.Add(new ProductWithCustomSchema { Id = (-i).ToString(), Price = 10.57M }); + } + dbContext.ProductsWithCustomSchema.AddRange(productsToAdd); + + //Delete products + var productsToDelete = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 5).ToList(); + dbContext.ProductsWithCustomSchema.RemoveRange(productsToDelete); + + //Update existing products + var productsToUpdate = dbContext.ProductsWithCustomSchema.Where(o => o.Price > 5 && o.Price <= 10).ToList(); + foreach (var productToUpdate in productsToUpdate) + { + productToUpdate.Price = 99M; + } + + int rowsAffected = dbContext.BulkSaveChanges(); + int productsAddedCount = dbContext.ProductsWithCustomSchema.Where(o => o.Price == 10.57M).Count(); + int productsDeletedCount = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 5).Count(); + int productsUpdatedCount = dbContext.ProductsWithCustomSchema.Where(o => o.Price == 99M).Count(); + + Assert.IsTrue(rowsAffected == productsToAdd.Count + productsToDelete.Count + productsToUpdate.Count, "The number +r of rows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(productsAddedCount == productsToAdd.Count(), "The number of products to add did not match what was +s expected."); + Assert.IsTrue(productsDeletedCount == 0, "The number of products that was deleted did not match what was expecte +ed."); + Assert.IsTrue(productsUpdatedCount == productsToUpdate.Count(), "The number of products that was updated did not +t match what was expected."); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\BulkSaveChangesAsync.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkSaveChangesAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var totalCount = dbContext.Orders.Count(); + + //Add new orders + var ordersToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + ordersToAdd.Add(new Order { Id = -i, Price = 10.57M }); + } + dbContext.Orders.AddRange(ordersToAdd); + + //Delete orders + var ordersToDelete = dbContext.Orders.Where(o => o.Price <= 5).ToList(); + dbContext.Orders.RemoveRange(ordersToDelete); + + //Update existing orders + var ordersToUpdate = dbContext.Orders.Where(o => o.Price > 5 && o.Price <= 10).ToList(); + foreach (var orderToUpdate in ordersToUpdate) + { + orderToUpdate.Price = 99M; + } + + + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int ordersAddedCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + int ordersDeletedCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + int ordersUpdatedCount = dbContext.Orders.Where(o => o.Price == 99M).Count(); + + Assert.IsTrue(rowsAffected == ordersToAdd.Count + ordersToDelete.Count + ordersToUpdate.Count, "The number of ro +ows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(ordersAddedCount == ordersToAdd.Count(), "The number of orders to add did not match what was expec +cted."); + Assert.IsTrue(ordersDeletedCount == 0, "The number of orders that was deleted did not match what was expected.") +); + Assert.IsTrue(ordersUpdatedCount == ordersToUpdate.Count(), "The number of orders that was updated did not match +h what was expected."); + } + [TestMethod] + public async Task With_Add_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + + var ordersToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + ordersToAdd.Add(new Order { Id = -i, Price = 10.57M }); + } + dbContext.Orders.AddRange(ordersToAdd); + + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int newTotalCount = dbContext.Orders.Where(o => o.Price == 10.57M).Count(); + + Assert.IsTrue(ordersToAdd.Where(o => o.Id <= 0).Count() == 0, "Primary key should have been updated for all enti +ities"); + Assert.IsTrue(rowsAffected == ordersToAdd.Count, "The number of rows affected must equal the sum of entities add +ded, deleted and updated"); + Assert.IsTrue(oldTotalCount + ordersToAdd.Count == newTotalCount, "The number of orders to add did not match wha +at was expected."); + } + [TestMethod] + public async Task With_Delete_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + + //Delete orders + var ordersToDelete = dbContext.Orders.Where(o => o.Price <= 5).ToList(); + dbContext.Orders.RemoveRange(ordersToDelete); + + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int newTotalCount = dbContext.Orders.Where(o => o.Price <= 5).Count(); + + Assert.IsTrue(rowsAffected == ordersToDelete.Count, "The number of rows affected must equal the sum of entities + added, deleted and updated"); + Assert.IsTrue(oldTotalCount - ordersToDelete.Count == newTotalCount, "The number of orders to add did not match + what was expected."); + } + [TestMethod] + public async Task With_Update_Changes() + { + var dbContext = SetupDbContext(true); + var oldTotalCount = dbContext.Orders.Where(o => o.Price <= 10).Count(); + + //Update existing orders + var ordersToUpdate = dbContext.Orders.Where(o => o.Price <= 10).ToList(); + foreach (var orderToUpdate in ordersToUpdate) + { + orderToUpdate.Price = 99M; + } + + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int newTotalCount = dbContext.Orders.Where(o => o.Price <= 10).Count(); + int expectedCount = dbContext.Orders.Where(o => o.Price == 99M).Count(); + + Assert.IsTrue(rowsAffected == ordersToUpdate.Count, "The number of rows affected must equal the sum of entities + added, deleted and updated"); + Assert.IsTrue(oldTotalCount - ordersToUpdate.Count == newTotalCount, "The number of orders to add did not match + what was expected."); + } + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + //Delete Customers + var customersToDelete = dbContext.TpcPeople.OfType().Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TpcPeople.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TpcPeople.OfType().Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TpcPeople.OfType().Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TpcPeople.Add(new TpcCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TpcPeople.OfType().Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected.") +); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number o +of rows inserted."); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + //Delete Customers + var customersToDelete = dbContext.TphCustomers.Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TphCustomers.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TphCustomers.Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TphPeople.Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TphCustomers.Add(new TphCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TphCustomers.Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(expectedRowsDeleted > 0, "The expected number of rows to delete must be greater than zero."); + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected.") +); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number o +of rows inserted."); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + //Delete Customers + var customersToDelete = dbContext.TptCustomers.Where(o => o.Id <= 1000); + int expectedRowsDeleted = customersToDelete.Count(); + dbContext.TptCustomers.RemoveRange(customersToDelete); + //Update Customers + var customersToUpdate = dbContext.TptCustomers.Where(o => o.Id > 1000 && o.Id <= 1500); + int expectedRowsUpdated = customersToUpdate.Count(); + foreach (var customerToUpdate in customersToUpdate) + { + customerToUpdate.Email = "name@domain.com"; + customerToUpdate.FirstName = "CustomerUpdated"; + } + //Add New Customers + long maxId = dbContext.TptPeople.Max(o => o.Id); + int expectedRowsAdded = 3000; + for (long i = maxId + 1; i <= maxId + expectedRowsAdded; i++) + { + dbContext.TptCustomers.Add(new TptCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + AddedDate = DateTime.UtcNow + }); + } + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int rowsAfterDelete = customersToDelete.Count(); + int rowsUpdated = customersToUpdate.Where(o => o.FirstName == "CustomerUpdated").Count(); + int rowsAdded = dbContext.TptCustomers.Where(o => o.Id > maxId).Count(); + int expectedRowsAffected = expectedRowsDeleted + expectedRowsUpdated + expectedRowsAdded; + + Assert.IsTrue(rowsAfterDelete == 0, "The number of rows deleted not not match what was expected."); + Assert.IsTrue(rowsUpdated == expectedRowsUpdated, "The number of rows updated not not match what was expected.") +); + Assert.IsTrue(rowsAdded == expectedRowsAdded, "The number of rows added not not match what was expected."); + Assert.IsTrue(rowsAffected == expectedRowsAffected, "The new count minus the old count should match the number o +of rows inserted."); + } + [TestMethod] + public async Task With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + var totalCount = await dbContext.ProductsWithCustomSchema.CountAsync(); + + //Add new products + var productsToAdd = new List(); + for (int i = 0; i < 2000; i++) + { + productsToAdd.Add(new ProductWithCustomSchema { Id = (-i).ToString(), Price = 10.57M }); + } + dbContext.ProductsWithCustomSchema.AddRange(productsToAdd); + + //Delete products + var productsToDelete = dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 5).ToList(); + dbContext.ProductsWithCustomSchema.RemoveRange(productsToDelete); + + //Update existing products + var productsToUpdate = dbContext.ProductsWithCustomSchema.Where(o => o.Price > 5 && o.Price <= 10).ToList(); + foreach (var productToUpdate in productsToUpdate) + { + productToUpdate.Price = 99M; + } + + int rowsAffected = await dbContext.BulkSaveChangesAsync(); + int productsAddedCount = await dbContext.ProductsWithCustomSchema.Where(o => o.Price == 10.57M).CountAsync(); + int productsDeletedCount = await dbContext.ProductsWithCustomSchema.Where(o => o.Price <= 5).CountAsync(); + int productsUpdatedCount = await dbContext.ProductsWithCustomSchema.Where(o => o.Price == 99M).CountAsync(); + + Assert.IsTrue(rowsAffected == productsToAdd.Count + productsToDelete.Count + productsToUpdate.Count, "The number +r of rows affected must equal the sum of entities added, deleted and updated"); + Assert.IsTrue(productsAddedCount == productsToAdd.Count(), "The number of products to add did not match what was +s expected."); + Assert.IsTrue(productsDeletedCount == 0, "The number of products that was deleted did not match what was expecte +ed."); + Assert.IsTrue(productsUpdatedCount == productsToUpdate.Count(), "The number of products that was updated did not +t match what was expected."); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\BulkSync.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkSync : DbContextExtensionsBase +{ + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = dbContext.BulkSync(orders); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 10000).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count o +of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must ma +atch the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = dbContext.TpcPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TpcPeople.OfType().Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tpc_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TpcCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tpc_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkSync(customers, options => { options.MergeOnCondition = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TpcPeople.Where(o => o.FirstName == "BulkSync_Tpc_Add").OfType().Cou +unt(); + int customersUpdated = dbContext.TpcPeople.Where(o => o.FirstName == "BulkSync_Tpc_Update").OfType( +().Count(); + int newCustomerTotal = dbContext.TpcPeople.OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows +s affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference fro +om the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database sh +hould match the sum of customers added and updated."); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = dbContext.TphCustomers.Where(o => o.Id <= 1000).ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TphPeople.Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tph_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TphCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tph_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkSync(customers, options => { options.UsePermanentTable = true; options.MergeOnConditi +ion = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TphCustomers.Where(o => o.FirstName == "BulkSync_Tph_Add").Count(); + int customersUpdated = dbContext.TphCustomers.Where(o => o.FirstName == "BulkSync_Tph_Update").Count(); + int newCustomerTotal = dbContext.TphCustomers.Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows +s affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference fro +om the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The customers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database sh +hould match the sum of customers added and updated."); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptPeople.Where(o => o.Id <= 1000).OfType().ToList(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TptCustomers.Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tpt_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TptCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tpt_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = dbContext.BulkSync(customers, options => { options.MergeOnCondition = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TptPeople.Where(o => o.FirstName == "BulkSync_Tpt_Add").OfType().Cou +unt(); + int customersUpdated = dbContext.TptPeople.Where(o => o.FirstName == "BulkSync_Tpt_Update").OfType( +().Count(); + int newCustomerTotal = dbContext.TptPeople.OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows +s affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference fro +om the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database sh +hould match the sum of customers added and updated."); + } + [TestMethod] + public void With_Options_AutoMapIdentity() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = dbContext.BulkSync(orders, options => { options.MergeOnCondition = (s, t) => s.ExternalId == t.Exte +ernalId; options.UsePermanentTable = true; }); + bool autoMapIdentityMatched = true; + foreach (var order in orders) + { + if (!dbContext.Orders.Any(o => o.ExternalId == order.ExternalId && o.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count o +of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must ma +atch the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up") +); + } + [TestMethod] + public void With_Options_MergeOnCondition() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 50; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = dbContext.BulkSync(orders, new BulkSyncOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + BatchSize = 1000 + }); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count o +of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must ma +atch the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\BulkSyncAsync.cs --- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkSyncAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Id <= 10000).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 5000; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = await dbContext.BulkSyncAsync(orders); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 10000).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count o +of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must ma +atch the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = await dbContext.TpcPeople.Where(o => o.Id <= 1000).OfType().ToListAsync(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TpcPeople.OfType().Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tpc_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TpcCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tpc_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkSyncAsync(customers, options => { options.MergeOnCondition = (s, t) => s.Id == + t.Id; }); + int customersAdded = dbContext.TpcPeople.Where(o => o.FirstName == "BulkSync_Tpc_Add").OfType().Cou +unt(); + int customersUpdated = dbContext.TpcPeople.Where(o => o.FirstName == "BulkSync_Tpc_Update").OfType( +().Count(); + int newCustomerTotal = dbContext.TpcPeople.OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows +s affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference fro +om the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database sh +hould match the sum of customers added and updated."); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = await dbContext.TphCustomers.Where(o => o.Id <= 1000).ToListAsync(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TphPeople.Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tph_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TphCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tph_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkSyncAsync(customers, options => { options.UsePermanentTable = true; options.Mer +rgeOnCondition = (s, t) => s.Id == t.Id; }); + int customersAdded = dbContext.TphCustomers.Where(o => o.FirstName == "BulkSync_Tph_Add").Count(); + int customersUpdated = dbContext.TphCustomers.Where(o => o.FirstName == "BulkSync_Tph_Update").Count(); + int newCustomerTotal = dbContext.TphCustomers.Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows +s affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference fro +om the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The customers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database sh +hould match the sum of customers added and updated."); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = await dbContext.TptPeople.Where(o => o.Id <= 1000).OfType().ToListAsync(); + int customersToAdd = 5000; + int customersToUpdate = customers.Count; + int customersToDelete = dbContext.TptCustomers.Count() - customersToUpdate; + + foreach (var customer in customers) + { + customer.FirstName = "BulkSync_Tpt_Update"; + } + for (int i = 0; i < customersToAdd; i++) + { + customers.Add(new TptCustomer + { + Id = 10000 + i, + FirstName = "BulkSync_Tpt_Add", + AddedDate = DateTime.UtcNow + }); + } + var result = await dbContext.BulkSyncAsync(customers, options => { options.MergeOnCondition = (s, t) => s.Id == + t.Id; }); + int customersAdded = dbContext.TptPeople.Where(o => o.FirstName == "BulkSync_Tpt_Add").OfType().Cou +unt(); + int customersUpdated = dbContext.TptPeople.Where(o => o.FirstName == "BulkSync_Tpt_Update").OfType( +().Count(); + int newCustomerTotal = dbContext.TptPeople.OfType().Count(); + + Assert.IsTrue(result.RowsAffected == customersAdded + customersToUpdate + customersToDelete, "The number of rows +s affected must match the sum of customers added, updated and deleted."); + Assert.IsTrue(result.RowsUpdated == customersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == customersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == customersToDelete, "The number of rows deleted must match the difference fro +om the total existing orders to the new orders to add/update"); + Assert.IsTrue(customersToAdd == customersAdded, "The custmoers that were added did not merge correctly"); + Assert.IsTrue(customersToUpdate == customersUpdated, "The customers that were updated did not merge correctly"); + Assert.IsTrue(newCustomerTotal == customersToAdd + customersToUpdate, "The count of customers in the database sh +hould match the sum of customers added and updated."); + } + [TestMethod] + public async Task With_Options_AutoMapIdentity() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + int ordersToUpdate = 3; + int ordersToAdd = 2; + var orders = new List + { + new Order { ExternalId = "id-1", Price=7.10M }, + new Order { ExternalId = "id-2", Price=9.33M }, + new Order { ExternalId = "id-3", Price=3.25M }, + new Order { ExternalId = "id-1000001", Price=2.15M }, + new Order { ExternalId = "id-1000002", Price=5.75M }, + }; + var result = await dbContext.BulkSyncAsync(orders, options => { options.MergeOnCondition = (s, t) => s.ExternalI +Id == t.ExternalId; options.UsePermanentTable = true; }); + bool autoMapIdentityMatched = true; + foreach (var order in orders) + { + if (!dbContext.Orders.Any(o => o.ExternalId == order.ExternalId && o.Price == order.Price)) + { + autoMapIdentityMatched = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count o +of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must ma +atch the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(autoMapIdentityMatched, "The auto mapping of ids of entities that were merged failed to match up") +); + } + [TestMethod] + public async Task With_Options_MergeOnCondition() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + int ordersToAdd = 50; + int ordersToUpdate = orders.Count; + foreach (var order in orders) + { + order.Price = Convert.ToDecimal(order.Id + .25); + } + for (int i = 0; i < ordersToAdd; i++) + { + orders.Add(new Order { Id = 100000 + i, Price = 3.55M }); + } + var result = await dbContext.BulkSyncAsync(orders, new BulkSyncOptions + { + MergeOnCondition = (s, t) => s.ExternalId == t.ExternalId, + BatchSize = 1000 + }); + var newOrders = dbContext.Orders.OrderBy(o => o.Id).ToList(); + bool areAddedOrdersMerged = true; + bool areUpdatedOrdersMerged = true; + foreach (var newOrder in newOrders.Where(o => o.Id <= 100 && o.ExternalId != null).OrderBy(o => o.Id)) + { + if (newOrder.Price != Convert.ToDecimal(newOrder.Id + .25)) + { + areUpdatedOrdersMerged = false; + break; + } + } + foreach (var newOrder in newOrders.Where(o => o.Id >= 500000).OrderBy(o => o.Id)) + { + if (newOrder.Price != 3.55M) + { + areAddedOrdersMerged = false; + break; + } + } + + Assert.IsTrue(result.RowsAffected == oldTotal + ordersToAdd, "The number of rows inserted must match the count o +of order list"); + Assert.IsTrue(result.RowsUpdated == ordersToUpdate, "The number of rows updated must match"); + Assert.IsTrue(result.RowsInserted == ordersToAdd, "The number of rows added must match"); + Assert.IsTrue(result.RowsDeleted == oldTotal - orders.Count() + ordersToAdd, "The number of rows deleted must ma +atch the difference from the total existing orders to the new orders to add/update"); + Assert.IsTrue(areAddedOrdersMerged, "The orders that were added did not merge correctly"); + Assert.IsTrue(areUpdatedOrdersMerged, "The orders that were updated did not merge correctly"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\BulkUpdate.cs --- + +using System.Linq; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkUpdate : DbContextExtensionsBase +{ + [TestMethod] + public void With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.ProductsWithComplexKey.Where(o => o.Price == 1.25M).ToList(); + foreach (var product in products) + { + product.Price = 2.35M; + } + var oldTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price == 2.35M).Count(); + int rowsUpdated = dbContext.BulkUpdate(products); + var newTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price == 2.35M).Count(); + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == products.Count, "The number of rows updated must match the count of entities that w +were retrieved"); + Assert.IsTrue(newTotal == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upda +ated in the database."); + } + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + long maxId = 0; + foreach (var order in orders) + { + order.Price = 2.35M; + maxId = order.Id; + } + int rowsUpdated = dbContext.BulkUpdate(orders); + var newOrders = dbContext.Orders.Where(o => o.Price == 2.35M).OrderBy(o => o.Id).Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == orders.Count, "The number of rows updated must match the count of entities that wer +re retrieved"); + Assert.IsTrue(newOrders == rowsUpdated, "The count of new orders must be equal the number of rows updated in the +e database."); + } + [TestMethod] + public void With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = dbContext.TpcPeople.Where(o => o.LastName != "BulkUpdate_Tpc").OfType().ToList(); + var vendors = dbContext.TpcPeople.OfType().ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdate_Tpc"; + } + int rowsUpdated = dbContext.BulkUpdate(customers); + var newCustomers = dbContext.TpcPeople.Where(o => o.LastName == "BulkUpdate_Tpc").OfType().Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(vendors.Count > 0 && vendors.Count != customers.Count, "There should be vendor records in the data +abase"); + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25 +5)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that + were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated + in the database."); + } + [TestMethod] + public void With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = dbContext.TphPeople.Where(o => o.LastName != "BulkUpdateTest").OfType().ToList(); + var vendors = dbContext.TphPeople.OfType().ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdateTest"; + } + int rowsUpdated = dbContext.BulkUpdate(customers); + var newCustomers = dbContext.TphPeople.Where(o => o.LastName == "BulkUpdateTest").OrderBy(o => o.Id).Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(vendors.Count > 0 && vendors.Count != customers.Count, "There should be vendor records in the data +abase"); + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25 +5)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that + were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated + in the database."); + } + [TestMethod] + public void With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptCustomers.Where(o => o.LastName != "BulkUpdateTest").ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdateTest"; + } + int rowsUpdated = dbContext.BulkUpdate(customers); + var newCustomers = dbContext.TptCustomers.Where(o => o.LastName == "BulkUpdateTest").OrderBy(o => o.Id).Count(); + //int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count( +(); + + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25 +5)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that + were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated + in the database."); + } + [TestMethod] + public void With_Options_InputColumns_PropertyExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = dbContext.BulkUpdate(orders, options => { options.InputColumns = o => o.Price; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upd +dated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public void With_Options_InputColumns_NewExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = dbContext.BulkUpdate(orders, options => { options.InputColumns = o => new { o.Price }; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upd +dated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public void With_Options_IgnoreColumns_PropertyExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = dbContext.BulkUpdate(orders, options => { options.IgnoreColumns = o => o.ExternalId; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upd +dated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public void With_Options_IgnoreColumns_NewExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = dbContext.BulkUpdate(orders, options => { options.IgnoreColumns = o => new { o.ExternalId }; } +}); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upd +dated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public void With_Options_UpdateOnCondition() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + int ordersWithExternalId = orders.Where(o => o.ExternalId != null).Count(); + foreach (var order in orders) + { + order.Price = 2.35M; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = dbContext.BulkUpdate(orders, options => { options.UpdateOnCondition = (s, t) => s.ExternalId = +== t.ExternalId; }); + var newTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == ordersWithExternalId, "The number of rows updated must match the count of entities + that were retrieved"); + Assert.IsTrue(newTotal == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upda +ated in the database."); + } + [TestMethod] + public void With_Options_UpdateOnCondition_Enum() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + foreach (var product in products) + { + product.Price = 2.35M; + } + int rowsUpdated = dbContext.BulkUpdate(products, o => + { + o.UpdateOnCondition = (s, t) => s.Id == t.Id && s.StatusEnum == t.StatusEnum; + }); + var newProducts = dbContext.Products.Where(o => o.Price == 2.35M).OrderBy(o => o.Id).Count(); + + Assert.IsTrue(products.Count > 0, "There must be products in database that match this condition (Price = $1.25)" +"); + Assert.IsTrue(rowsUpdated == products.Count, "The number of rows updated must match the count of entities that w +were retrieved"); + Assert.IsTrue(newProducts == rowsUpdated, "The count of new products must be equal the number of rows updated in +n the database."); + } + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + long maxId = 0; + foreach (var order in orders) + { + order.Price = 2.35M; + maxId = order.Id; + } + int rowsUpdated, newOrders; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsUpdated = dbContext.BulkUpdate(orders); + newOrders = dbContext.Orders.Where(o => o.Price == 2.35M).Count(); + transaction.Rollback(); + } + int rollbackTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == orders.Count, "The number of rows updated must match the count of entities that wer +re retrieved"); + Assert.IsTrue(newOrders == rowsUpdated, "The count of new orders must be equal the number of rows updated in the +e database."); + Assert.IsTrue(rollbackTotal == orders.Count, "The number of rows after the transacation has been rollbacked shou +uld match the original count"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\BulkUpdateAsync.cs --- + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class BulkUpdateAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Complex_Key() + { + var dbContext = SetupDbContext(true); + var products = dbContext.ProductsWithComplexKey.Where(o => o.Price == 1.25M).ToList(); + foreach (var product in products) + { + product.Price = 2.35M; + } + var oldTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price == 2.35M).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(products); + var newTotal = dbContext.ProductsWithComplexKey.Where(o => o.Price == 2.35M).Count(); + + Assert.IsTrue(products.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == products.Count, "The number of rows updated must match the count of entities that w +were retrieved"); + Assert.IsTrue(newTotal == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upda +ated in the database."); + } + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + long maxId = 0; + foreach (var order in orders) + { + order.Price = 2.35M; + maxId = order.Id; + } + int rowsUpdated = await dbContext.BulkUpdateAsync(orders); + var newOrders = dbContext.Orders.Where(o => o.Price == 2.35M).OrderBy(o => o.Id).Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == orders.Count, "The number of rows updated must match the count of entities that wer +re retrieved"); + Assert.IsTrue(newOrders == rowsUpdated, "The count of new orders must be equal the number of rows updated in the +e database."); + } + [TestMethod] + public async Task With_Inheritance_Tpc() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpc); + var customers = dbContext.TpcPeople.Where(o => o.LastName != "BulkUpdateTest").OfType().ToList(); + var vendors = dbContext.TpcPeople.OfType().ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdate_Tpc"; + } + int rowsUpdated = await dbContext.BulkUpdateAsync(customers, options => { options.UpdateOnCondition = (s, t) => + s.Id == t.Id; }); + var newCustomers = dbContext.TpcPeople.Where(o => o.LastName == "BulkUpdate_Tpc").OfType().Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(vendors.Count > 0 && vendors.Count != customers.Count, "There should be vendor records in the data +abase"); + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25 +5)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that + were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated + in the database."); + } + [TestMethod] + public async Task With_Inheritance_Tph() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tph); + var customers = dbContext.TphPeople.Where(o => o.LastName != "BulkUpdateTest").OfType().ToList(); + var vendors = dbContext.TphPeople.OfType().ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdateTest"; + } + int rowsUpdated = await dbContext.BulkUpdateAsync(customers); + var newCustomers = dbContext.TphPeople.Where(o => o.LastName == "BulkUpdateTest").OrderBy(o => o.Id).Count(); + int entitiesWithChanges = dbContext.ChangeTracker.Entries().Where(t => t.State == EntityState.Modified).Count(); + + Assert.IsTrue(vendors.Count > 0 && vendors.Count != customers.Count, "There should be vendor records in the data +abase"); + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25 +5)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that + were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated + in the database."); + } + [TestMethod] + public async Task With_Inheritance_Tpt() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Tpt); + var customers = dbContext.TptCustomers.Where(o => o.LastName != "BulkUpdateTest").ToList(); + foreach (var customer in customers) + { + customer.FirstName = string.Format("Id={0}", customer.Id); + customer.LastName = "BulkUpdateTest"; + } + int rowsUpdated = await dbContext.BulkUpdateAsync(customers); + var newCustomers = await dbContext.TptCustomers.Where(o => o.LastName == "BulkUpdateTest").CountAsync(); + + Assert.IsTrue(customers.Count > 0, "There must be customers in database that match this condition (Price = $1.25 +5)"); + Assert.IsTrue(rowsUpdated == customers.Count, "The number of rows updated must match the count of entities that + were retrieved"); + Assert.IsTrue(newCustomers == rowsUpdated, "The count of new customers must be equal the number of rows updated + in the database."); + } + [TestMethod] + public async Task With_Options_InputColumns_PropertyExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(orders, options => { options.InputColumns = o => o.Price; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upd +dated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public async Task With_Options_InputColumns_NewExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(orders, options => { options.InputColumns = o => new { o.Price +e }; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count() > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upd +dated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public async Task With_Options_IgnoreColumns_PropertyExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(orders, options => { options.IgnoreColumns = o => o.ExternalId +d; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upd +dated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public async Task With_Options_IgnoreColumns_NewExpression() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).OrderBy(o => o.Id).ToList(); + foreach (var order in orders) + { + order.Price = 2.35M; + order.ExternalId = null; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(orders, options => { options.IgnoreColumns = o => new { o.Exte +ernalId }; }); + var newTotal1 = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + var newTotal2 = dbContext.Orders.Where(o => o.Price == 1.25M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(newTotal1 == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upd +dated in the database."); + Assert.IsTrue(newTotal2 == 0, "There should be not records with condition (Price = $1.25)"); + } + [TestMethod] + public async Task With_Options_UpdateOnCondition() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + int ordersWithExternalId = orders.Where(o => o.ExternalId != null).Count(); + foreach (var order in orders) + { + order.Price = 2.35M; + } + var oldTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + int rowsUpdated = await dbContext.BulkUpdateAsync(orders, options => { options.UpdateOnCondition = (s, t) => s.E +ExternalId == t.ExternalId; }); + var newTotal = dbContext.Orders.Where(o => o.Price == 2.35M && o.ExternalId != null).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == ordersWithExternalId, "The number of rows updated must match the count of entities + that were retrieved"); + Assert.IsTrue(newTotal == rowsUpdated + oldTotal, "The count of new orders must be equal the number of rows upda +ated in the database."); + } + [TestMethod] + public async Task With_Options_UpdateOnCondition_Enum() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + foreach (var product in products) + { + product.Price = 2.35M; + } + int rowsUpdated = await dbContext.BulkUpdateAsync(products, o => + { + o.UpdateOnCondition = (s, t) => s.Id == t.Id && s.StatusEnum == t.StatusEnum; + }); + var newProducts = dbContext.Products.Where(o => o.Price == 2.35M).OrderBy(o => o.Id).Count(); + + Assert.IsTrue(products.Count > 0, "There must be products in database that match this condition (Price = $1.25)" +"); + Assert.IsTrue(rowsUpdated == products.Count, "The number of rows updated must match the count of entities that w +were retrieved"); + Assert.IsTrue(newProducts == rowsUpdated, "The count of new products must be equal the number of rows updated in +n the database."); + } + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price == 1.25M).OrderBy(o => o.Id).ToList(); + long maxId = 0; + foreach (var order in orders) + { + order.Price = 2.35M; + maxId = order.Id; + } + int rowsUpdated, newOrders; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsUpdated = await dbContext.BulkUpdateAsync(orders); + newOrders = dbContext.Orders.Where(o => o.Price == 2.35M).Count(); + transaction.Rollback(); + } + int rollbackTotal = dbContext.Orders.Where(o => o.Price == 1.25M).Count(); + + Assert.IsTrue(orders.Count > 0, "There must be orders in database that match this condition (Price = $1.25)"); + Assert.IsTrue(rowsUpdated == orders.Count, "The number of rows updated must match the count of entities that wer +re retrieved"); + Assert.IsTrue(newOrders == rowsUpdated, "The count of new orders must be equal the number of rows updated in the +e database."); + Assert.IsTrue(rollbackTotal == orders.Count, "The number of rows after the transacation has been rollbacked shou +uld match the original count"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\DbContextExtensionsBase.cs --- + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Drawing; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Common; +using N.EntityFrameworkCore.Extensions.Test.Data; +using N.EntityFrameworkCore.Extensions.Test.Data.Enums; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +public enum PopulateDataMode +{ + Normal, + Tpc, + Tph, + Tpt, + Schema +} +[TestClass] +public class DbContextExtensionsBase +{ + private TestDbContext _currentDbContext; + + [TestInitialize] + public void Init() + { + using var dbContext = new TestDbContext(); + TestDatabaseInitializer.EnsureCreated(dbContext); + } + + [TestCleanup] + public void Cleanup() + { + _currentDbContext?.Dispose(); + _currentDbContext = null; + } + + protected TestDbContext SetupDbContext(bool populateData, PopulateDataMode mode = PopulateDataMode.Normal) + { + var dbContext = new TestDbContext(); + _currentDbContext = dbContext; + TestDatabaseInitializer.EnsureCreated(dbContext); + dbContext.Orders.Truncate(); + dbContext.Products.Truncate(); + dbContext.ProductCategories.Clear(); + dbContext.ProductsWithCustomSchema.Truncate(); + dbContext.ProductsWithTrigger.Truncate(); + dbContext.Database.ClearTable("TpcCustomer"); + dbContext.Database.ClearTable("TpcVendor"); + dbContext.TphPeople.Truncate(); + dbContext.Database.ClearTable("TptPeople"); + dbContext.Database.ClearTable("TptCustomer"); + dbContext.Database.ClearTable("TptVendor"); + dbContext.Database.DropTable("ProductsUnderTen", true); + dbContext.Database.DropTable("OrdersUnderTen", true); + dbContext.Database.DropTable("OrdersLast30Days", true); + if (populateData) + { + if (mode == PopulateDataMode.Normal) + { + var orders = new List(); + int id = 1; + for (int i = 0; i < 2050; i++) + { + DateTime addedDateTime = DateTime.UtcNow.AddDays(-id); + orders.Add(new Order + { + Id = id, + ExternalId = string.Format("id-{0}", i), + Price = 1.25M, + AddedDateTime = addedDateTime, + ModifiedDateTime = addedDateTime.AddHours(3), + Status = OrderStatus.Completed + }); + id++; + } + for (int i = 0; i < 1050; i++) + { + orders.Add(new Order { Id = id, Price = 5.35M }); + id++; + } + for (int i = 0; i < 2050; i++) + { + orders.Add(new Order { Id = id, Price = 1.25M }); + id++; + } + for (int i = 0; i < 6000; i++) + { + orders.Add(new Order { Id = id, Price = 15.35M }); + id++; + } + for (int i = 0; i < 6000; i++) + { + orders.Add(new Order { Id = id, Price = 15.35M }); + id++; + } + + Debug.WriteLine("Last Id for Order is {0}", id); + dbContext.BulkInsert(orders, new BulkInsertOptions() { KeepIdentity = true }); + + var productCategories = new List() + { + new ProductCategory { Id=1, Name="Category-1", Active=true}, + new ProductCategory { Id=2, Name="Category-2", Active=true}, + new ProductCategory { Id=3, Name="Category-3", Active=true}, + new ProductCategory { Id=4, Name="Category-4", Active=false}, + }; + dbContext.BulkInsert(productCategories, o => { o.KeepIdentity = true; o.UsePermanentTable = true; }); + var products = new List(); + id = 1; + for (int i = 0; i < 2050; i++) + { + products.Add(new Product + { + Id = i.ToString(), + Price = 1.25M, + OutOfStock = false, + ProductCategoryId = 4, + StatusEnum = ProductStatus.InStock, + Color = Color.Black, + Position = new Position { Building = 5, Aisle = 33, Bay = i }, + }); + id++; + } + for (int i = 2050; i < 7000; i++) + { + products.Add(new Product { Id = i.ToString(), Price = 1.25M, OutOfStock = true, StatusEnum = Product +tStatus.OutOfStock }); + id++; + } + + Debug.WriteLine("Last Id for Product is {0}", id); + dbContext.BulkInsert(products, new BulkInsertOptions() { KeepIdentity = false, AutoMapOutput = + false, UsePermanentTable = true }); + + //ProductWithComplexKey + var productsWithComplexKey = new List(); + id = 1; + + for (int i = 0; i < 2050; i++) + { + productsWithComplexKey.Add(new ProductWithComplexKey { Price = 1.25M }); + id++; + } + + Debug.WriteLine("Last Id for ProductsWithComplexKey is {0}", id); + dbContext.BulkInsert(productsWithComplexKey, new BulkInsertOptions() { KeepIdenti +ity = false, AutoMapOutput = false }); + } + else if (mode == PopulateDataMode.Tph) + { + //TPH Customers & Vendors + var tphCustomers = new List(); + var tphVendors = new List(); + for (int i = 0; i < 2000; i++) + { + tphCustomers.Add(new TphCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "404-555-1111", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 2000; i < 3000; i++) + { + tphVendors.Add(new TphVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Phone = "404-555-2222", + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + dbContext.BulkInsert(tphCustomers, new BulkInsertOptions() { KeepIdentity = true }); + dbContext.BulkInsert(tphVendors, new BulkInsertOptions() { KeepIdentity = true }); + } + else if (mode == PopulateDataMode.Tpc) + { + //TPC Customers & Vendors + var tpcCustomers = new List(); + var tpcVendors = new List(); + for (int i = 1; i <= 2000; i++) + { + tpcCustomers.Add(new TpcCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "404-555-1111", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 2001; i <= 3000; i++) + { + tpcVendors.Add(new TpcVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Phone = "404-555-2222", + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + dbContext.BulkInsert(tpcCustomers, new BulkInsertOptions() { KeepIdentity = true }); + dbContext.BulkInsert(tpcVendors, new BulkInsertOptions() { KeepIdentity = true }); + } + else if (mode == PopulateDataMode.Tpt) + { + //Customers & Vendors + var tptCustomers = new List(); + var tptVendors = new List(); + for (int i = 1; i <= 2000; i++) + { + tptCustomers.Add(new TptCustomer + { + Id = i, + FirstName = string.Format("John_{0}", i), + LastName = string.Format("Smith_{0}", i), + Email = string.Format("john.smith{0}@domain.com", i), + Phone = "404-555-1111", + AddedDate = DateTime.UtcNow + }); + } + for (int i = 2001; i < 3000; i++) + { + tptVendors.Add(new TptVendor + { + Id = i, + FirstName = string.Format("Mike_{0}", i), + LastName = string.Format("Smith_{0}", i), + Phone = "404-555-2222", + Email = string.Format("mike.smith{0}@domain.com", i), + Url = string.Format("http://domain.com/mike.smith{0}", i) + }); + } + dbContext.BulkInsert(tptCustomers, new BulkInsertOptions() { KeepIdentity = true, UsePerman +nentTable = true }); + dbContext.BulkInsert(tptVendors, new BulkInsertOptions() { KeepIdentity = true }); + } + else if (mode == PopulateDataMode.Schema) + { + //ProductWithCustomSchema + var productsWithCustomSchema = new List(); + int id = 1; + + for (int i = 0; i < 2050; i++) + { + productsWithCustomSchema.Add(new ProductWithCustomSchema { Id = id.ToString(), Price = 1.25M }); + id++; + } + for (int i = 2050; i < 5000; i++) + { + productsWithCustomSchema.Add(new ProductWithCustomSchema { Id = id.ToString(), Price = 6.75M }); + id++; + } + + dbContext.BulkInsert(productsWithCustomSchema); + } + } + return dbContext; + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\DeleteFromQuery.cs --- + +using System; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class DeleteFromQuery : DbContextExtensionsBase +{ + [TestMethod] + public void With_Boolean_Value() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(p => p.OutOfStock); + int oldTotal = products.Count(a => a.OutOfStock); + int rowUpdated = products.DeleteFromQuery(); + int newTotal = dbContext.Products.Count(o => o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (OutOfStock == true)") +); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Child_Relationship() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(p => !p.ProductCategory.Active); + int oldTotal = products.Count(); + int rowsDeleted = products.DeleteFromQuery(); + int newTotal = products.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (ProductCategory.Activ +ve == false)"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows update must match the count of rows that match the co +ondition (ProductCategory.Active == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Decimal_Using_IQueryable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int oldTotal = orders.Count(); + int rowsDeleted = orders.DeleteFromQuery(); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "Delete() Failed: must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Decimal_Using_IEnumerable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int oldTotal = orders.Count(); + int rowsDeleted = orders.DeleteFromQuery(); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_DateTime() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + int rowsToDelete = dbContext.Orders.Where(o => o.ModifiedDateTime != null && o.ModifiedDateTime >= dateTime).Cou +unt(); + int rowsDeleted = dbContext.Orders.Where(o => o.ModifiedDateTime != null && o.ModifiedDateTime >= dateTime) + .DeleteFromQuery(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == rowsToDelete, "The number of rows deleted must match the count of the rows that mat +tched in the database"); + Assert.IsTrue(oldTotal - newTotal == rowsDeleted, "The rows deleted must match the new count minues the old coun +nt"); + } + [TestMethod] + public void With_Delete_All() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + int rowsDeleted = dbContext.Orders.DeleteFromQuery(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Different_Values() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.Id == 1 && o.Active && o.ModifiedDateTime >= dateTime); + int rowsToDelete = orders.Count(); + int rowsDeleted = orders.DeleteFromQuery(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == rowsToDelete, "The number of rows deleted must match the count of the rows that mat +tched in the database"); + Assert.IsTrue(oldTotal - newTotal == rowsDeleted, "The rows deleted must match the new count minues the old coun +nt"); + } + [TestMethod] + public void With_Empty_List() + { + var dbContext = SetupDbContext(false); + int oldTotal = dbContext.Orders.Count(); + int rowsDeleted = dbContext.Orders.DeleteFromQuery(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal == 0, "There must be no orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + int oldTotal = dbContext.ProductsWithCustomSchema.Count(); + int rowsDeleted = dbContext.ProductsWithCustomSchema.DeleteFromQuery(); + int newTotal = dbContext.ProductsWithCustomSchema.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + int rowsDeleted; + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int rowsToDelete = orders.Count(); + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsDeleted = orders.DeleteFromQuery(); + transaction.Rollback(); + } + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowsDeleted == orders.Count(), "The number of rows update must match the count of rows that match + the condtion (Price < $10)"); + Assert.IsTrue(newTotal == oldTotal, "The new count must match the old count since the transaction was rollbacked +d"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\DeleteFromQueryAsync.cs --- + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class DeleteFromQueryAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Boolean_Value() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(p => p.OutOfStock); + int oldTotal = products.Count(a => a.OutOfStock); + int rowUpdated = await products.DeleteFromQueryAsync(); + int newTotal = dbContext.Products.Count(o => o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (OutOfStock == true)") +); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Child_Relationship() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(p => !p.ProductCategory.Active); + int oldTotal = products.Count(); + int rowsDeleted = await products.DeleteFromQueryAsync(); + int newTotal = products.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (ProductCategory.Activ +ve == false)"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows update must match the count of rows that match the co +ondition (ProductCategory.Active == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Decimal_Using_IQueryable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int oldTotal = orders.Count(); + int rowsDeleted = await orders.DeleteFromQueryAsync(); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "Delete() Failed: must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Decimal_Using_IEnumerable() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int oldTotal = orders.Count(); + int rowsDeleted = await orders.DeleteFromQueryAsync(); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_DateTime() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + int rowsToDelete = dbContext.Orders.Where(o => o.ModifiedDateTime != null && o.ModifiedDateTime >= dateTime).Cou +unt(); + int rowsDeleted = await dbContext.Orders.Where(o => o.ModifiedDateTime != null && o.ModifiedDateTime >= dateTime +e) + .DeleteFromQueryAsync(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == rowsToDelete, "The number of rows deleted must match the count of the rows that mat +tched in the database"); + Assert.IsTrue(oldTotal - newTotal == rowsDeleted, "The rows deleted must match the new count minues the old coun +nt"); + } + [TestMethod] + public async Task With_Delete_All() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + int rowsDeleted = await dbContext.Orders.DeleteFromQueryAsync(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Different_Values() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Count(); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.Id == 1 && o.Active && o.ModifiedDateTime >= dateTime); + int rowsToDelete = orders.Count(); + int rowsDeleted = await orders.DeleteFromQueryAsync(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == rowsToDelete, "The number of rows deleted must match the count of the rows that mat +tched in the database"); + Assert.IsTrue(oldTotal - newTotal == rowsDeleted, "The rows deleted must match the new count minues the old coun +nt"); + } + [TestMethod] + public async Task With_Empty_List() + { + var dbContext = SetupDbContext(false); + int oldTotal = dbContext.Orders.Count(); + int rowsDeleted = await dbContext.Orders.DeleteFromQueryAsync(); + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal == 0, "There must be no orders in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + int oldTotal = dbContext.ProductsWithCustomSchema.Count(); + int rowsDeleted = await dbContext.ProductsWithCustomSchema.DeleteFromQueryAsync(); + int newTotal = dbContext.ProductsWithCustomSchema.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows deleted must match the count of existing rows in data +abase"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + int rowsDeleted; + int oldTotal = dbContext.Orders.Count(); + var orders = dbContext.Orders.Where(o => o.Price <= 10); + int rowsToDelete = orders.Count(); + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsDeleted = await orders.DeleteFromQueryAsync(); + transaction.Rollback(); + } + int newTotal = dbContext.Orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowsDeleted == orders.Count(), "The number of rows update must match the count of rows that match + the condtion (Price < $10)"); + Assert.IsTrue(newTotal == oldTotal, "The new count must match the old count since the transaction was rollbacked +d"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\Fetch.cs --- + +using System; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class Fetch : DbContextExtensionsBase +{ + [TestMethod] + public void With_BulkInsert() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int totalOrdersToFetch = orders.Count(); + int totalOrdersFetched = 0; + int batchSize = 5000; + orders.Fetch(result => + { + totalOrdersFetched += result.Results.Count(); + var ordersFetched = result.Results; + foreach (var orderFetched in ordersFetched) + { + orderFetched.Price = 75; + } + dbContext.BulkInsert(ordersFetched); + }, options => { options.BatchSize = batchSize; }); + + int totalOrder = orders.Count(); + int totalOrderInserted = orders.Where(o => o.Price == 75).Count(); + Assert.IsTrue(totalOrdersToFetch == totalOrdersFetched, "The total number of rows fetched must match the number + of rows to fetch"); + Assert.IsTrue(totalOrderInserted == totalOrdersFetched, "The total number of rows updated must match the number + of rows that were fetched"); + Assert.IsTrue(totalOrder - totalOrdersToFetch == totalOrderInserted, "The total number of rows must match the nu +umber of rows that were updated"); + } + [TestMethod] + public void With_BulkUpdate() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int totalOrdersToFetch = orders.Count(); + int totalOrdersFetched = 0; + int batchSize = 5000; + orders.Fetch(result => + { + totalOrdersFetched += result.Results.Count(); + var ordersFetched = result.Results; + foreach (var orderFetched in ordersFetched) + { + orderFetched.Price = 75; + } + dbContext.BulkUpdate(ordersFetched); + }, options => { options.BatchSize = batchSize; }); + + int totalOrder = orders.Count(); + int totalOrderUpdated = orders.Where(o => o.Price == 75).Count(); + Assert.IsTrue(totalOrdersToFetch == totalOrdersFetched, "The total number of rows fetched must match the number + of rows to fetch"); + Assert.IsTrue(totalOrderUpdated == totalOrdersFetched, "The total number of rows updated must match the number o +of rows that were fetched"); + Assert.IsTrue(totalOrder == totalOrderUpdated, "The total number of rows must match the number of rows that were +e updated"); + } + [TestMethod] + public void With_DateTime() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + orders.Fetch(result => + { + batchCount++; + totalCount += result.Results.Count(); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less th +han or equal to the batchSize"); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } + [TestMethod] + public void With_Decimal() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + orders.Fetch(result => + { + batchCount++; + totalCount += result.Results.Count(); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less th +han or equal to the batchSize"); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } + [TestMethod] + public void With_Enum() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var products = dbContext.Products.Where(o => o.Price < 10M); + int expectedTotalCount = products.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + products.Fetch(result => + { + batchCount++; + totalCount += result.Results.Count(); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less th +han or equal to the batchSize"); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be products in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } + [TestMethod] + public void With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + orders.Fetch(result => + { + batchCount++; + totalCount += result.Results.Count(); + bool isAllExternalIdNull = !result.Results.Any(o => o.ExternalId != null); + Assert.IsTrue(isAllExternalIdNull, "All records should have ExternalId equal to NULL since it was not loaded +d."); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less th +han or equal to the batchSize"); + }, options => { options.BatchSize = batchSize; options.IgnoreColumns = s => new { s.ExternalId }; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } + [TestMethod] + public void With_Options_InputColumns() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + orders.Fetch(result => + { + batchCount++; + totalCount += result.Results.Count(); + bool isAllExternalIdNull = !result.Results.Any(o => o.ExternalId != null); + Assert.IsTrue(isAllExternalIdNull, "All records should have ExternalId equal to NULL since it was not loaded +d."); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should less th +han or equal to the batchSize"); + }, options => { options.BatchSize = batchSize; options.InputColumns = s => new { s.Id, s.Price }; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\FetchAsync.cs --- + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class FetchAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_BulkInsert() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int totalOrdersToFetch = orders.Count(); + int totalOrdersFetched = 0; + int batchSize = 5000; + await orders.FetchAsync(async result => + { + totalOrdersFetched += result.Results.Count; + var ordersFetched = result.Results; + foreach (var orderFetched in ordersFetched) + { + orderFetched.Price = 75; + } + await dbContext.BulkInsertAsync(ordersFetched); + }, options => { options.BatchSize = batchSize; }); + + int totalOrder = orders.Count(); + int totalOrderInserted = orders.Where(o => o.Price == 75).Count(); + Assert.IsTrue(totalOrdersToFetch == totalOrdersFetched, "The total number of rows fetched must match the number + of rows to fetch"); + Assert.IsTrue(totalOrderInserted == totalOrdersFetched, "The total number of rows updated must match the number + of rows that were fetched"); + Assert.IsTrue(totalOrder - totalOrdersToFetch == totalOrderInserted, "The total number of rows must match the nu +umber of rows that were updated"); + } + [TestMethod] + public async Task With_BulkUpdate() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int totalOrdersToFetch = orders.Count(); + int totalOrdersFetched = 0; + int batchSize = 5000; + await orders.FetchAsync(async result => + { + totalOrdersFetched += result.Results.Count; + var ordersFetched = result.Results; + foreach (var orderFetched in ordersFetched) + { + orderFetched.Price = 75; + } + await dbContext.BulkUpdateAsync(ordersFetched); + }, options => { options.BatchSize = batchSize; }); + + int totalOrder = orders.Count(); + int totalOrderUpdated = orders.Where(o => o.Price == 75).Count(); + Assert.IsTrue(totalOrdersToFetch == totalOrdersFetched, "The total number of rows fetched must match the number + of rows to fetch"); + Assert.IsTrue(totalOrderUpdated == totalOrdersFetched, "The total number of rows updated must match the number o +of rows that were fetched"); + Assert.IsTrue(totalOrder == totalOrderUpdated, "The total number of rows must match the number of rows that were +e updated"); + } + [TestMethod] + public async Task With_DateTime() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + var orders = dbContext.Orders.Where(o => o.AddedDateTime <= dateTime); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + await orders.FetchAsync(async result => + { + await Task.Run(() => + { + batchCount++; + totalCount += result.Results.Count; + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should les +ss than or equal to the batchSize"); + }); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } + [TestMethod] + public async Task With_Decimal() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + await orders.FetchAsync(async result => + { + await Task.Run(() => + { + batchCount++; + totalCount += result.Results.Count(); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should les +ss than or equal to the batchSize"); + }); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } + [TestMethod] + public async Task With_Enum() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var products = dbContext.Products.Where(o => o.Price < 10M); + int expectedTotalCount = products.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + await products.FetchAsync(async result => + { + await Task.Run(() => + { + batchCount++; + totalCount += result.Results.Count(); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should les +ss than or equal to the batchSize"); + }); + }, options => { options.BatchSize = batchSize; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be products in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } + [TestMethod] + public async Task With_Options_IgnoreColumns() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + await orders.FetchAsync(async result => + { + await Task.Run(() => + { + batchCount++; + totalCount += result.Results.Count; + bool isAllExternalIdNull = !result.Results.Any(o => o.ExternalId != null); + Assert.IsTrue(isAllExternalIdNull, "All records should have ExternalId equal to NULL since it was not lo +oaded."); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should les +ss than or equal to the batchSize"); + }); + }, options => { options.BatchSize = batchSize; options.IgnoreColumns = s => new { s.ExternalId }; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } + [TestMethod] + public async Task With_Options_InputColumns() + { + var dbContext = SetupDbContext(true); + int batchSize = 1000; + int batchCount = 0; + int totalCount = 0; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int expectedTotalCount = orders.Count(); + int expectedBatchCount = (int)Math.Ceiling(expectedTotalCount / (decimal)batchSize); + + await orders.FetchAsync(async result => + { + await Task.Run(() => + { + batchCount++; + totalCount += result.Results.Count(); + bool isAllExternalIdNull = !result.Results.Any(o => o.ExternalId != null); + Assert.IsTrue(isAllExternalIdNull, "All records should have ExternalId equal to NULL since it was not lo +oaded."); + Assert.IsTrue(result.Results.Count <= batchSize, "The count of results in each batch callback should les +ss than or equal to the batchSize"); + }); + }, options => { options.BatchSize = batchSize; options.InputColumns = s => new { s.Id, s.Price }; }); + + Assert.IsTrue(expectedTotalCount > 0, "There must be orders in database that match this condition"); + Assert.IsTrue(expectedTotalCount == totalCount, "The total number of rows fetched must match the count of existi +ing rows in database"); + Assert.IsTrue(expectedBatchCount == batchCount, "The total number of batches fetched must match what is expected +d"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\InsertFromQuery.cs --- + +using System; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class InsertFromQuery : DbContextExtensionsBase +{ + [TestMethod] + public void With_DateTime_Value() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersLast30Days"; + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + int oldTotal = dbContext.Orders.Count(); + + var orders = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime); + int oldSourceTotal = orders.Count(); + int rowsInserted = orders.InsertFromQuery(tableName, + o => new { o.Id, o.ExternalId, o.Price, o.AddedDateTime, o.ModifiedDateTime, o.Active }); + int newSourceTotal = orders.Count(); + int newTargetTotal = orders.UsingTable(tableName).Count(); + + Assert.IsTrue(oldTotal > oldSourceTotal, "The total should be greater then the number of rows selected from the + source table"); + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the sour +rce table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the sourc +ce table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the i +insert must match the total row inserted"); + } + [TestMethod] + public void With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersUnderTen"; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldSourceTotal = orders.Count(); + int rowsInserted = dbContext.Orders.Where(o => o.Price < 10M).InsertFromQuery(tableName, o => new { o.Id, o.Pric +ce, o.AddedDateTime, o.Active }); + int newSourceTotal = orders.Count(); + int newTargetTotal = orders.UsingTable(tableName).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the sour +rce table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the sourc +ce table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the i +insert must match the total row inserted"); + } + [TestMethod] + public void With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + string tableName = "ProductsUnderTen"; + var products = dbContext.ProductsWithCustomSchema.Where(o => o.Price < 10M); + int oldSourceTotal = products.Count(); + int rowsInserted = dbContext.ProductsWithCustomSchema.Where(o => o.Price < 10M).InsertFromQuery(tableName, o => + new { o.Id, o.Price }); + int newSourceTotal = products.Count(); + int newTargetTotal = products.UsingTable(tableName).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the sour +rce table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the sourc +ce table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the i +insert must match the total row inserted"); + } + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersUnderTen"; + int rowsInserted; + bool tableExistsBefore, tableExistsAfter; + int oldSourceTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsInserted = dbContext.Orders.Where(o => o.Price < 10M).InsertFromQuery(tableName, o => new { o.Price, o.I +Id, o.AddedDateTime, o.Active }); + tableExistsBefore = dbContext.Database.TableExists(tableName); + transaction.Rollback(); + } + tableExistsAfter = dbContext.Database.TableExists(tableName); + int newSourceTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of rows update must match the count of rows that match +h the condtion (Price < $10)"); + Assert.IsTrue(newSourceTotal == oldSourceTotal, "The new count must match the old count since the transaction wa +as rollbacked"); + Assert.IsTrue(tableExistsBefore, string.Format("Table {0} should exist before transaction rollback", tableName)) +); + Assert.IsFalse(tableExistsAfter, string.Format("Table {0} should not exist after transaction rollback", tableNam +me)); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\InsertFromQueryAsync.cs --- + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class InsertFromQueryAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_DateTime_Value() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersLast30Days"; + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + int oldTotal = dbContext.Orders.Count(); + + var orders = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime); + int oldSourceTotal = orders.Count(); + int rowsInserted = await orders.InsertFromQueryAsync(tableName, + o => new { o.Id, o.ExternalId, o.Price, o.AddedDateTime, o.ModifiedDateTime, o.Active }); + int newSourceTotal = orders.Count(); + int newTargetTotal = orders.UsingTable(tableName).Count(); + + Assert.IsTrue(oldTotal > oldSourceTotal, "The total should be greater then the number of rows selected from the + source table"); + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the sour +rce table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the sourc +ce table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the i +insert must match the total row inserted"); + } + [TestMethod] + public async Task With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersUnderTen"; + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldSourceTotal = orders.Count(); + int rowsInserted = await dbContext.Orders.Where(o => o.Price < 10M).InsertFromQueryAsync(tableName, o => new { o +o.Id, o.Price, o.AddedDateTime, o.Active }); + int newSourceTotal = orders.Count(); + int newTargetTotal = orders.UsingTable(tableName).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the sour +rce table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the sourc +ce table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the i +insert must match the total row inserted"); + } + [TestMethod] + public async Task With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + string tableName = "ProductsUnderTen"; + var products = dbContext.ProductsWithCustomSchema.Where(o => o.Price < 10M); + int oldSourceTotal = products.Count(); + int rowsInserted = await dbContext.ProductsWithCustomSchema.Where(o => o.Price < 10M).InsertFromQueryAsync(table +eName, o => new { o.Id, o.Price }); + int newSourceTotal = products.Count(); + int newTargetTotal = products.UsingTable(tableName).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There should be existing data in the source table"); + Assert.IsTrue(oldSourceTotal == newSourceTotal, "There should not be any change in the count of rows in the sour +rce table"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of records inserted must match the count of the sourc +ce table"); + Assert.IsTrue(rowsInserted == newTargetTotal, "The different in count in the target table before and after the i +insert must match the total row inserted"); + } + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + string tableName = "OrdersUnderTen"; + int rowsInserted; + bool tableExistsBefore, tableExistsAfter; + int oldSourceTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowsInserted = await dbContext.Orders.Where(o => o.Price < 10M).InsertFromQueryAsync(tableName, o => new { o +o.Price, o.Id, o.AddedDateTime, o.Active }); + tableExistsBefore = dbContext.Database.TableExists(tableName); + transaction.Rollback(); + } + tableExistsAfter = dbContext.Database.TableExists(tableName); + int newSourceTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + + Assert.IsTrue(oldSourceTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowsInserted == oldSourceTotal, "The number of rows update must match the count of rows that match +h the condtion (Price < $10)"); + Assert.IsTrue(newSourceTotal == oldSourceTotal, "The new count must match the old count since the transaction wa +as rollbacked"); + Assert.IsTrue(tableExistsBefore, string.Format("Table {0} should exist before transaction rollback", tableName)) +); + Assert.IsFalse(tableExistsAfter, string.Format("Table {0} should not exist after transaction rollback", tableNam +me)); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\QueryToCsvFile.cs --- + +using System.IO; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class QueryToCsvFile : DbContextExtensionsBase +{ + [TestMethod] + public void With_Default_Options() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var queryToCsvFileResult = query.QueryToCsvFile("QueryToCsvFile-Test.csv"); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file sho +ould match the count from the database plus the header row"); + } + [TestMethod] + public void With_Options_ColumnDelimiter_TextQualifer_HeaderRow() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var queryToCsvFileResult = query.QueryToCsvFile("QueryToCsvFile_Options_ColumnDelimiter_TextQualifer_HeaderRow-T +Test.csv", options => { options.ColumnDelimiter = "|"; options.TextQualifer = "\""; options.IncludeHeaderRow = false; }); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count, "The total number of rows written to the file should + match the count from the database without any header row"); + } + [TestMethod] + public void Using_FileStream() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var fileStream = File.Create("QueryToCsvFile_Stream-Test.csv"); + var queryToCsvFileResult = query.QueryToCsvFile(fileStream); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file sho +ould match the count from the database plus the header row"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\QueryToCsvFileAsync.cs --- + +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class QueryToCsvFileAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Default_Options() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var queryToCsvFileResult = await query.QueryToCsvFileAsync("QueryToCsvFile-Test.csv"); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file sho +ould match the count from the database plus the header row"); + } + [TestMethod] + public async Task With_Options_ColumnDelimiter_TextQualifer_HeaderRow() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var queryToCsvFileResult = await query.QueryToCsvFileAsync("QueryToCsvFile_Options_ColumnDelimiter_TextQualifer_ +_HeaderRow-Test.csv", options => { options.ColumnDelimiter = "|"; options.TextQualifer = "\""; options.IncludeHeaderRow = += false; }); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count, "The total number of rows written to the file should + match the count from the database without any header row"); + } + [TestMethod] + public async Task Using_FileStream() + { + var dbContext = SetupDbContext(true); + var query = dbContext.Orders.Where(o => o.Price < 10M); + int count = query.Count(); + var fileStream = File.Create("QueryToCsvFile_Stream-Test.csv"); + var queryToCsvFileResult = await query.QueryToCsvFileAsync(fileStream); + + Assert.IsTrue(count > 0, "There should be existing data in the source table"); + Assert.IsTrue(queryToCsvFileResult.DataRowCount == count, "The number of data rows written to the file should ma +atch the count from the database"); + Assert.IsTrue(queryToCsvFileResult.TotalRowCount == count + 1, "The total number of rows written to the file sho +ould match the count from the database plus the header row"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\UpdateFromQuery.cs --- + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Threading; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; +using N.EntityFrameworkCore.Extensions.Test.Data.Enums; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class UpdateFromQuery : DbContextExtensionsBase +{ + [TestMethod] + public void With_Boolean_Value() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Products.Count(a => a.OutOfStock); + int rowUpdated = dbContext.Products.Where(a => a.OutOfStock).UpdateFromQuery(a => new Product { OutOfStock = fal +lse }); + int newTotal = dbContext.Products.Count(o => o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be articles in database that match this condition (OutOfStock == true)") +); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Concatenating_String() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId == null); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { ExternalId = Convert.ToString(o.Id) + "Test" }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Concatenating_String_And_Number() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId == null); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { ExternalId = Convert.ToString(o.Id) + "Test" }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_DateTime_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTime now = DateTime.UtcNow; + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQuery(o => new Order { Modif +fiedDateTime = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTime == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 +0 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match + the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should +d be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public void With_DateTimeOffset_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTimeOffset now = DateTimeOffset.UtcNow; + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQuery(o => new Order { Modif +fiedDateTimeOffset = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTimeOffset == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 +0 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match + the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should +d be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public void With_DateTimeOffset_No_UTC_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTimeOffset now = DateTimeOffset.Parse("2020-06-17T16:00:00+05:00").ToUniversalTime(); + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQuery(o => new Order { Modif +fiedDateTimeOffset = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTimeOffset == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 +0 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match + the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should +d be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public void With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { Price = 25.30M }); + int newTotal = orders.Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public void With_Different_Culture() + { + Thread.CurrentThread.CurrentCulture = CultureInfo.GetCultureInfo("sv-SE"); + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int rowUpdated = dbContext.Orders.Where(o => o.Price < 10M).UpdateFromQuery(o => new Order { Price = 25.30M }); + int newTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.AreEqual("25,30", Convert.ToString(25.30M)); + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public void With_Enum_Value() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(a => a.StatusEnum == ProductStatus.OutOfStock && a.OutOfStock); + int oldTotal = products.Count(); + int rowUpdated = products.UpdateFromQuery(a => new Product { StatusEnum = ProductStatus.InStock }); + int newTotal = products.Count(o => o.StatusEnum == ProductStatus.OutOfStock && o.OutOfStock); + int newTotal2 = dbContext.Products.Count(o => o.StatusEnum == ProductStatus.InStock && o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be articles in database that match this condition (OutOfStock == true)") +); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(newTotal2 == oldTotal, "All rows must have been updated"); + } + [TestMethod] + public void With_Guid_Value() + { + var dbContext = SetupDbContext(true); + var guid = Guid.NewGuid(); + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { GlobalId = guid }); + int matchCount = dbContext.Orders.Where(o => o.GlobalId == guid).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, $"The number of rows update must match the count of rows that match the co +ondition (GlobalId = '{guid}')"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public void With_Long_List() + { + var dbContext = SetupDbContext(true); + var ids = new List() { 1, 2, 3, 4, 5, 6, 7, 8 }; + var orders = dbContext.Orders.Where(o => ids.Contains(o.Id)); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { Price = 25.25M }); + int newTotal = orders.Where(o => o.Price != 25.25M).Count(); + int matchCount = dbContext.Orders.Where(o => ids.Contains(o.Id) && o.Price == 25.25M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public void With_MethodCall() + { + var dbContext = SetupDbContext(true); + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = dbContext.Orders.Where(a => a.Price < 10).UpdateFromQuery(o => new Order { Price = Math.Ceiling +g((o.Price + 10.5M) * 3 / 1) }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be order in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Null_Value() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId != null); + int oldTotal = orders.Count(); + int rowUpdated = orders.UpdateFromQuery(o => new Order { ExternalId = null }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId != null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (ExternalId != null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + var products = dbContext.ProductsWithCustomSchema.Where(o => o.Price < 5M); + int oldTotal = products.Count(); + int rowUpdated = products.UpdateFromQuery(o => new ProductWithCustomSchema { Price = 25.30M }); + int newTotal = products.Count(); + int matchCount = dbContext.ProductsWithCustomSchema.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (Price < 5)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < 5)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public void With_String_Containing_Apostrophe() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.ExternalId == null).Count(); + int rowUpdated = dbContext.Orders.Where(o => o.ExternalId == null).UpdateFromQuery(o => new Order { ExternalId = += "inv'alid" }); + int newTotal = dbContext.Orders.Where(o => o.ExternalId == null).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Transaction() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int rowUpdated; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowUpdated = dbContext.Orders.Where(o => o.Price < 10M).UpdateFromQuery(o => new Order { Price = 25.30M }); + transaction.Rollback(); + } + int newTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < $10)"); + Assert.IsTrue(newTotal == oldTotal, "The new count must match the old count since the transaction was rollbacked +d"); + Assert.IsTrue(matchCount == 0, "The match count must be equal to 0 since the transaction was rollbacked."); + } + [TestMethod] + public void With_Variables() + { + var dbContext = SetupDbContext(true); + decimal priceStart = 10M; + decimal priceUpdate = 0.34M; + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = dbContext.Orders.Where(a => a.Price < 10).UpdateFromQuery(a => new Order { Price = priceStart + ++ priceUpdate }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public void With_Variable_And_Decimal() + { + var dbContext = SetupDbContext(true); + decimal priceStart = 10M; + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = dbContext.Orders.Where(a => a.Price < 10).UpdateFromQuery(a => new Order { Price = priceStart + ++ 7M }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbContextExtensions\UpdateFromQueryAsync.cs --- + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data; +using N.EntityFrameworkCore.Extensions.Test.Data.Enums; + +namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +[TestClass] +public class UpdateFromQueryAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task With_Boolean_Value() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Products.Count(a => a.OutOfStock); + int rowUpdated = await dbContext.Products.Where(a => a.OutOfStock).UpdateFromQueryAsync(a => new Product { OutOf +fStock = false }); + int newTotal = dbContext.Products.Count(o => o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be articles in database that match this condition (OutOfStock == true)") +); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Concatenating_String() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId == null); + int oldTotal = orders.Count(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { ExternalId = Convert.ToString(o.Id) + "Test" +" }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Concatenating_String_And_Number() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId == null); + int oldTotal = orders.Count(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { ExternalId = Convert.ToString(o.Id) + "Test" +" }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_DateTime_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTime now = DateTime.UtcNow; + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = await dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQueryAsync(o => new Or +rder { ModifiedDateTime = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTime == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 +0 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match + the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should +d be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public async Task With_DateTimeOffset_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTimeOffset now = DateTimeOffset.UtcNow; + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = await dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQueryAsync(o => new Or +rder { ModifiedDateTimeOffset = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTimeOffset == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 +0 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match + the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should +d be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public async Task With_DateTimeOffset_No_UTC_Value() + { + var dbContext = SetupDbContext(true); + DateTime dateTime = dbContext.Orders.Max(o => o.AddedDateTime).AddDays(-30); + DateTimeOffset now = DateTimeOffset.Parse("2020-06-17T16:00:00+05:00").ToUniversalTime(); + + int oldTotal = dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).Count(); + int rowUpdated = await dbContext.Orders.Where(o => o.AddedDateTime >= dateTime).UpdateFromQueryAsync(o => new Or +rder { ModifiedDateTimeOffset = now }); + int newTotal = dbContext.Orders.Where(o => o.ModifiedDateTimeOffset == now).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Orders added in last 30 +0 days)"); + Assert.IsTrue(rowUpdated == newTotal, "The number of rows updated should equal the new count of rows that match + the condition (Orders added in last 30 days)"); + Assert.IsTrue(oldTotal == newTotal, "The count of rows matching the condition before and after the update should +d be equal. (Orders added in last 30 days)"); + } + [TestMethod] + public async Task With_Decimal_Value() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldTotal = orders.Count(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { Price = 25.30M }); + int newTotal = orders.Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public async Task With_Different_Culture() + { + Thread.CurrentThread.CurrentCulture = CultureInfo.GetCultureInfo("sv-SE"); + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int rowUpdated = await dbContext.Orders.Where(o => o.Price < 10M).UpdateFromQueryAsync(o => new Order { Price = + 25.30M }); + int newTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.AreEqual("25,30", Convert.ToString(25.30M)); + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public async Task With_Enum_Value() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(a => a.StatusEnum == ProductStatus.OutOfStock && a.OutOfStock); + int oldTotal = products.Count(); + int rowUpdated = await products.UpdateFromQueryAsync(a => new Product { StatusEnum = ProductStatus.InStock }); + int newTotal = products.Count(o => o.StatusEnum == ProductStatus.OutOfStock && o.OutOfStock); + int newTotal2 = dbContext.Products.Count(o => o.StatusEnum == ProductStatus.InStock && o.OutOfStock); + + Assert.IsTrue(oldTotal > 0, "There must be articles in database that match this condition (OutOfStock == true)") +); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (OutOfStock == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(newTotal2 == oldTotal, "All rows must have been updated"); + } + [TestMethod] + public async Task With_Guid_Value() + { + var dbContext = SetupDbContext(true); + var guid = Guid.NewGuid(); + var orders = dbContext.Orders.Where(o => o.Price < 10M); + int oldTotal = await orders.CountAsync(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { GlobalId = guid }); + int matchCount = await dbContext.Orders.Where(o => o.GlobalId == guid).CountAsync(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, $"The number of rows update must match the count of rows that match the co +ondition (GlobalId = '{guid}')"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public async Task With_Long_List() + { + var dbContext = SetupDbContext(true); + var ids = new List() { 1, 2, 3, 4, 5, 6, 7, 8 }; + var orders = dbContext.Orders.Where(o => ids.Contains(o.Id)); + int oldTotal = orders.Count(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { Price = 25.25M }); + int newTotal = orders.Where(o => o.Price != 25.25M).Count(); + int matchCount = dbContext.Orders.Where(o => ids.Contains(o.Id) && o.Price == 25.25M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < $10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public async Task With_MethodCall() + { + var dbContext = SetupDbContext(true); + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = await dbContext.Orders.Where(a => a.Price < 10).UpdateFromQueryAsync(o => new Order { Price = M +Math.Ceiling((o.Price + 10.5M) * 3 / 1) }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be order in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Null_Value() + { + var dbContext = SetupDbContext(true); + var orders = dbContext.Orders.Where(o => o.ExternalId != null); + int oldTotal = orders.Count(); + int rowUpdated = await orders.UpdateFromQueryAsync(o => new Order { ExternalId = null }); + int newTotal = orders.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId != null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (ExternalId != null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Schema() + { + var dbContext = SetupDbContext(true, PopulateDataMode.Schema); + var products = dbContext.ProductsWithCustomSchema.Where(o => o.Price < 5M); + int oldTotal = products.Count(); + int rowUpdated = await products.UpdateFromQueryAsync(o => new ProductWithCustomSchema { Price = 25.30M }); + int newTotal = products.Count(); + int matchCount = dbContext.ProductsWithCustomSchema.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (Price < 5)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < 5)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + Assert.IsTrue(matchCount == rowUpdated, "The match count must be equal the number of rows updated in the databas +se."); + } + [TestMethod] + public async Task With_String_Containing_Apostrophe() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.ExternalId == null).Count(); + int rowUpdated = await dbContext.Orders.Where(o => o.ExternalId == null).UpdateFromQueryAsync(o => new Order { E +ExternalId = "inv'alid" }); + int newTotal = dbContext.Orders.Where(o => o.ExternalId == null).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (ExternalId == null)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (ExternalId == null)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Transaction() + { + var dbContext = SetupDbContext(true); + int oldTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int rowUpdated; + using (var transaction = dbContext.Database.BeginTransaction()) + { + rowUpdated = await dbContext.Orders.Where(o => o.Price < 10M).UpdateFromQueryAsync(o => new Order { Price = + 25.30M }); + transaction.Rollback(); + } + int newTotal = dbContext.Orders.Where(o => o.Price < 10M).Count(); + int matchCount = dbContext.Orders.Where(o => o.Price == 25.30M).Count(); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < $10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndtion (Price < $10)"); + Assert.IsTrue(newTotal == oldTotal, "The new count must match the old count since the transaction was rollbacked +d"); + Assert.IsTrue(matchCount == 0, "The match count must be equal to 0 since the transaction was rollbacked."); + } + [TestMethod] + public async Task With_Variables() + { + var dbContext = SetupDbContext(true); + decimal priceStart = 10M; + decimal priceUpdate = 0.34M; + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = await dbContext.Orders.Where(a => a.Price < 10).UpdateFromQueryAsync(a => new Order { Price = p +priceStart + priceUpdate }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } + [TestMethod] + public async Task With_Variable_And_Decimal() + { + var dbContext = SetupDbContext(true); + decimal priceStart = 10M; + + int oldTotal = dbContext.Orders.Count(a => a.Price < 10); + int rowUpdated = await dbContext.Orders.Where(a => a.Price < 10).UpdateFromQueryAsync(a => new Order { Price = p +priceStart + 7M }); + int newTotal = dbContext.Orders.Count(o => o.Price < 10); + + Assert.IsTrue(oldTotal > 0, "There must be orders in database that match this condition (Price < 10)"); + Assert.IsTrue(rowUpdated == oldTotal, "The number of rows update must match the count of rows that match the con +ndition (Price < 10)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbSetExtensions\Clear.cs --- + +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +namespace N.EntityFrameworkCore.Extensions.Test.DbSetExtensions; + +[TestClass] +public class Clear : DbContextExtensionsBase +{ + [TestMethod] + public void Using_Dbset() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + dbContext.Orders.Clear(); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbSetExtensions\ClearAsync.cs --- + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +namespace N.EntityFrameworkCore.Extensions.Test.DbSetExtensions; + +[TestClass] +public class ClearAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task Using_Dbset() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + await dbContext.Orders.ClearAsync(); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbSetExtensions\Truncate.cs --- + +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +namespace N.EntityFrameworkCore.Extensions.Test.DbSetExtensions; + +[TestClass] +public class Truncate : DbContextExtensionsBase +{ + [TestMethod] + public void Using_Dbset() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + dbContext.Orders.Truncate(); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\DbSetExtensions\TruncateAsync.cs --- + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.DbContextExtensions; + +namespace N.EntityFrameworkCore.Extensions.Test.DbSetExtensions; + +[TestClass] +public class TruncateAsync : DbContextExtensionsBase +{ + [TestMethod] + public async Task Using_Dbset() + { + var dbContext = SetupDbContext(true); + int oldOrdersCount = dbContext.Orders.Count(); + await dbContext.Orders.TruncateAsync(); + int newOrdersCount = dbContext.Orders.Count(); + + Assert.IsTrue(oldOrdersCount > 0, "Orders table should have data"); + Assert.IsTrue(newOrdersCount == 0, "Order table should be empty after truncating"); + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\LinqExtensions\ToSqlPredicateTests.cs --- + +using System; +using System.Linq.Expressions; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace N.EntityFrameworkCore.Extensions.Test.LinqExtensions; + +[TestClass] +public class ToSqlPredicateTests +{ + [TestMethod] + public void Should_handle_int() + { + Expression> expression = (s, t) => s.Id == t.Id; + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Id = t.Id", sqlPredicate); + } + + [TestMethod] + public void Should_handle_enum() + { + Expression> expression = (s, t) => s.Type == t.Type; + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Type = t.Type", sqlPredicate); + } + + [TestMethod] + public void Should_handle_complex_one() + { + Expression> expression = (s, t) => s.Type == t.Type && + (s.Id == t.Id && + s.ExternalId == t.ExternalId); + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Type = t.Type AND s.Id = t.Id AND s.ExternalId = t.ExternalId", sqlPredicate); + } + + [TestMethod] + public void Should_handle_prop_naming() + { + Expression> expression = (source, target) => source.Id == target.Id && + source.ExternalId == target.ExternalId; + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Id = t.Id AND s.ExternalId = t.ExternalId", sqlPredicate); + } + + [TestMethod] + public void Should_handle_simple_big_one() + { + Expression> expression = (s, t) => s.Type == t.Type && + s.Id == t.Id && + s.ExternalId == t.ExternalId && + s.TesterVar1 == t.TesterVar1 && + s.TesterVar2 == t.TesterVar2 && + s.TesterVar3 == t.TesterVar3 && + s.TesterVar4 == t.TesterVar4 && + s.TesterVar5 == t.TesterVar5; + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Type = t.Type AND s.Id = t.Id AND s.ExternalId = t.ExternalId AND s.TesterVar1 = t.TesterVar1 +1 AND s.TesterVar2 = t.TesterVar2 AND s.TesterVar3 = t.TesterVar3 AND s.TesterVar4 = t.TesterVar4 AND s.TesterVar5 = t.Te +esterVar5", sqlPredicate); + } + + [TestMethod] + public void Should_handle_complex_big_one() + { + Expression> expression = (s, t) => s.Type == t.Type && + s.Id == t.Id && + (s.ExternalId == t.ExternalId || s.TesterVar1 == t +t.TesterVar1) && + (s.TesterVar2 == t.TesterVar2 || (s.TesterVar2 == + null && t.TesterVar2 == null)) && + (s.TesterVar3 == t.TesterVar3 || (s.TesterVar3 != + null && t.TesterVar3 != null)); + + var sqlPredicate = expression.ToSqlPredicate("s", "t"); + + Assert.AreEqual("s.Type = t.Type AND s.Id = t.Id AND (s.ExternalId = t.ExternalId OR s.TesterVar1 = t.TesterVar1 +1) AND (s.TesterVar2 = t.TesterVar2 OR s.TesterVar2 IS NULL AND t.TesterVar2 IS NULL) AND (s.TesterVar3 = t.TesterVar3 OR +R s.TesterVar3 IS NOT NULL AND t.TesterVar3 IS NOT NULL)", sqlPredicate); + } + + record Entity + { + public Guid Id { get; set; } + public EntityType Type { get; set; } + public int ExternalId { get; set; } + public string TesterVar1 { get; set; } + public string TesterVar2 { get; set; } + public string TesterVar3 { get; set; } + public string TesterVar4 { get; set; } + public string TesterVar5 { get; set; } + } + + enum EntityType + { + One, + Two, + Three + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Migrations\20250509021251_Initial.cs --- + +using System; +using Microsoft.EntityFrameworkCore.Migrations; + +#nullable disable + +namespace N.EntityFrameworkCore.Extensions.Test.Migrations +{ + public partial class Initial : Migration + { + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.Sql("CREATE TRIGGER trgProductWithTriggers\r\nON ProductsWithTrigger\r\nFOR INSERT, UPDATE, +, DELETE\r\nAS\r\nBEGIN\r\n PRINT 1 END"); + } + + protected override void Down(MigrationBuilder migrationBuilder) + { + + } + } +} + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Migrations\20250509021251_Initial.Designer.cs --- + +// +using System; +using System.Collections.Generic; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Migrations; +using Microsoft.EntityFrameworkCore.Storage.ValueConversion; +using N.EntityFrameworkCore.Extensions.Test.Data; + +#nullable disable + +namespace N.EntityFrameworkCore.Extensions.Test.Migrations +{ + [DbContext(typeof(TestDbContext))] + [Migration("20250509021251_Initial")] + partial class Initial + { + protected override void BuildTargetModel(ModelBuilder modelBuilder) + { + + } + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\Migrations\TestDbContextModelSnapshot.cs --- + +// +using System; +using System.Collections.Generic; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Storage.ValueConversion; +using N.EntityFrameworkCore.Extensions.Test.Data; + +#nullable disable + +namespace N.EntityFrameworkCore.Extensions.Test.Migrations +{ + [DbContext(typeof(TestDbContext))] + partial class TestDbContextModelSnapshot : ModelSnapshot + { + protected override void BuildModel(ModelBuilder modelBuilder) + { + + } + } +} + + + +--- FILE: C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions\N.EntityFrameworkCore.Extensions.Test\N.EntityFramework +k.Extensions.PostgreSql.Test\N.EntityFramework.Extensions.PostgreSql.Test.csproj --- + + + + + net10.0 + + $(MSBuildThisFileDirectory)..\..\N.EntityFramework.Extensions.PostgreSql.runsettings + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + + + + + + + + + + + + + Always + + + + + + +___BEGIN___COMMAND_DONE_MARKER___0 +PS C:\Users\ttsch\source\N.EntityFrameworkCore.Extensions> \ No newline at end of file