diff --git a/src/WTG.Analyzers.Test/TestData/LinqEnumerableAnalyzer/PrependTypeMismatch/Diagnostics.xml b/src/WTG.Analyzers.Test/TestData/LinqEnumerableAnalyzer/PrependTypeMismatch/Diagnostics.xml new file mode 100644 index 0000000..efb060e --- /dev/null +++ b/src/WTG.Analyzers.Test/TestData/LinqEnumerableAnalyzer/PrependTypeMismatch/Diagnostics.xml @@ -0,0 +1,19 @@ + + + + + Test0.cs: (10, 3-36) + + + Test0.cs: (11, 3-20) + + + + + Test0.cs: (12, 3-25) + + + Test0.cs: (13, 3-20) + + + diff --git a/src/WTG.Analyzers.Test/TestData/LinqEnumerableAnalyzer/PrependTypeMismatch/Result.cs b/src/WTG.Analyzers.Test/TestData/LinqEnumerableAnalyzer/PrependTypeMismatch/Result.cs new file mode 100644 index 0000000..de82624 --- /dev/null +++ b/src/WTG.Analyzers.Test/TestData/LinqEnumerableAnalyzer/PrependTypeMismatch/Result.cs @@ -0,0 +1,20 @@ +using System.Collections.Generic; +using System.Linq; + +public class Bob +{ + public void Method() + { + var viewModel = new ViewModel(); + + viewModel.Items.Prepend(viewModel); + Enumerable.Prepend(viewModel.Items, viewModel); + viewModel.Items.Append(viewModel); + Enumerable.Append(viewModel.Items, viewModel); + } +} + +public class ViewModel +{ + public IEnumerable Items { get; set; } +} diff --git a/src/WTG.Analyzers.Test/TestData/LinqEnumerableAnalyzer/PrependTypeMismatch/Source.cs b/src/WTG.Analyzers.Test/TestData/LinqEnumerableAnalyzer/PrependTypeMismatch/Source.cs new file mode 100644 index 0000000..bdc2318 --- /dev/null +++ b/src/WTG.Analyzers.Test/TestData/LinqEnumerableAnalyzer/PrependTypeMismatch/Source.cs @@ -0,0 +1,20 @@ +using System.Collections.Generic; +using System.Linq; + +public class Bob +{ + public void Method() + { + var viewModel = new ViewModel(); + + new object[] { viewModel }.Concat(viewModel.Items); + Enumerable.Concat(new object[] { viewModel }, viewModel.Items); + viewModel.Items.Concat(new object[] { viewModel }); + Enumerable.Concat(viewModel.Items, new object[] { viewModel }); + } +} + +public class ViewModel +{ + public IEnumerable Items { get; set; } +} diff --git a/src/WTG.Analyzers/Analyzers/LinqEnumerable/LinqEnumerableCodeFixProvider.cs b/src/WTG.Analyzers/Analyzers/LinqEnumerable/LinqEnumerableCodeFixProvider.cs index 18ea77f..ef7e656 100644 --- a/src/WTG.Analyzers/Analyzers/LinqEnumerable/LinqEnumerableCodeFixProvider.cs +++ b/src/WTG.Analyzers/Analyzers/LinqEnumerable/LinqEnumerableCodeFixProvider.cs @@ -104,15 +104,19 @@ public static SyntaxNode FixConcatWithAppendMethod(MemberAccessExpressionSyntax var listOfArgumentsAndSeparators = new List(); + ExpressionSyntax singleElementCollection; + switch (invocation.ArgumentList.Arguments.Count) { case 1: - listOfArgumentsAndSeparators.Add(Argument(LinqEnumerableUtils.GetFirstValue(invocation.ArgumentList.Arguments[0].Expression)!)); + singleElementCollection = invocation.ArgumentList.Arguments[0].Expression; + listOfArgumentsAndSeparators.Add(Argument(LinqEnumerableUtils.GetFirstValue(singleElementCollection)!)); break; case 2: + singleElementCollection = invocation.ArgumentList.Arguments[1].Expression; listOfArgumentsAndSeparators.Add(invocation.ArgumentList.Arguments[0]); listOfArgumentsAndSeparators.Add(Token(SyntaxKind.CommaToken)); - listOfArgumentsAndSeparators.Add(Argument(LinqEnumerableUtils.GetFirstValue(invocation.ArgumentList.Arguments[1].Expression)!)); + listOfArgumentsAndSeparators.Add(Argument(LinqEnumerableUtils.GetFirstValue(singleElementCollection)!)); break; default: throw new InvalidOperationException("Unreachable - Code fix should never trigger for >2 arguments."); @@ -125,7 +129,7 @@ public static SyntaxNode FixConcatWithAppendMethod(MemberAccessExpressionSyntax .WithTriviaFrom(m.Expression) .WithAdditionalAnnotations(Simplifier.Annotation), m.OperatorToken, - IdentifierName(nameof(Enumerable.Append)) + GetMethodName(nameof(Enumerable.Append), singleElementCollection) .WithTriviaFrom(m.Name))) .WithArgumentList( ArgumentList( @@ -141,19 +145,22 @@ public static SyntaxNode FixConcatWithAppendMethod(MemberAccessExpressionSyntax var listOfArgumentsAndSeparators = new List(); ExpressionSyntax member; + ExpressionSyntax singleElementCollection; switch (invocation.ArgumentList.Arguments.Count) { case 1: - listOfArgumentsAndSeparators.Add(Argument(LinqEnumerableUtils.GetFirstValue(m.Expression.TryGetExpressionFromParenthesizedExpression())!)); + singleElementCollection = m.Expression.TryGetExpressionFromParenthesizedExpression(); + listOfArgumentsAndSeparators.Add(Argument(LinqEnumerableUtils.GetFirstValue(singleElementCollection)!)); member = ParenthesizedExpression(invocation.ArgumentList.Arguments[0].Expression.WithoutTrivia()) .WithTriviaFrom(m.Expression) .WithAdditionalAnnotations(Simplifier.Annotation); break; case 2: + singleElementCollection = invocation.ArgumentList.Arguments[0].Expression; listOfArgumentsAndSeparators.Add(invocation.ArgumentList.Arguments[1]); listOfArgumentsAndSeparators.Add(Token(SyntaxKind.CommaToken)); - listOfArgumentsAndSeparators.Add(Argument(LinqEnumerableUtils.GetFirstValue(invocation.ArgumentList.Arguments[0].Expression)!)); + listOfArgumentsAndSeparators.Add(Argument(LinqEnumerableUtils.GetFirstValue(singleElementCollection)!)); member = m.Expression; break; @@ -166,7 +173,7 @@ public static SyntaxNode FixConcatWithAppendMethod(MemberAccessExpressionSyntax SyntaxKind.SimpleMemberAccessExpression, member, m.OperatorToken, - IdentifierName(nameof(Enumerable.Prepend)) + GetMethodName(nameof(Enumerable.Prepend), singleElementCollection) .WithTriviaFrom(m.Name))) .WithArgumentList( ArgumentList( @@ -213,5 +220,42 @@ public static SyntaxNode FixConcatWithNewCollection(MemberAccessExpressionSyntax .WithTriviaFrom(invocation) .WithAdditionalAnnotations(Simplifier.Annotation); } + + static SimpleNameSyntax GetMethodName(string methodName, ExpressionSyntax singleElementCollection) + { + var elementType = GetCollectionElementType(singleElementCollection.TryGetExpressionFromParenthesizedExpression()); + + if (elementType != null) + { + return GenericName(Identifier(methodName)) + .WithTypeArgumentList( + TypeArgumentList( + SingletonSeparatedList( + elementType.WithoutTrivia()))) + .WithAdditionalAnnotations(Simplifier.Annotation); + } + + return IdentifierName(methodName); + } + + static TypeSyntax? GetCollectionElementType(ExpressionSyntax expression) + { + switch (expression.Kind()) + { + case SyntaxKind.ArrayCreationExpression: + return ((ArrayCreationExpressionSyntax)expression).Type.ElementType; + + case SyntaxKind.ObjectCreationExpression: + var objectCreationType = ((ObjectCreationExpressionSyntax)expression).Type; + if (objectCreationType is GenericNameSyntax genericName && genericName.TypeArgumentList.Arguments.Count == 1) + { + return genericName.TypeArgumentList.Arguments[0]; + } + + break; + } + + return null; + } } }