Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ref/out parameters which return list #19

Merged
merged 1 commit into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions SqlMarshal.CompilationTests/ConnectionManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ public ConnectionManager(SqlConnection connection)
[SqlMarshal("persons_list")]
public partial IList<PersonInformation> GetResult();

[SqlMarshal("persons_list")]
public partial void GetResultWithOut(out IList<PersonInformation> result);

[SqlMarshal("persons_list")]
public partial IList<(int Id, string Name)> GetTupleResult();

Expand Down
7 changes: 7 additions & 0 deletions SqlMarshal.CompilationTests/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ private static void TestConnectionManager(ConnectionManager connectionManager)
WritePerson(personInfo);
}

connectionManager.GetResultWithOut(out var samePersons);
WriteLine("Same 10 rows from persons_list SP using out parameter");
foreach (var personInfo in persons.Take(10))
{
WritePerson(personInfo);
}

var persons2 = connectionManager.GetResultByPage(2, out var totalCount);
WriteLine("Print results of persons_by_page SP");
foreach (var personInfo in persons2)
Expand Down
142 changes: 142 additions & 0 deletions SqlMarshal.Tests/SqlConnectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,148 @@ partial class C
Assert.AreEqual(expectedOutput, output);
}

[TestMethod]
public void MapResultSetToProcedureWithOutParameter()
{
string source = @"
namespace Foo
{
public class Item
{
public string StringValue { get; set; }
public int Int32Value { get; set; }
public int? NullableInt32Value { get; set; }
}

class C
{
private DbConnection connection;

[SqlMarshal(""sp_TestSP"")]
public partial void M(out IList<Item> result)
}
}";
string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable);

Assert.IsNotNull(output);

var expectedOutput = @"// <auto-generated>
// Code generated by Stored Procedures Code Generator.
// Changes may cause incorrect behavior and will be lost if the code is
// regenerated.
// </auto-generated>
#nullable enable
#pragma warning disable 1591

namespace Foo
{
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Linq;

partial class C
{
public partial void M(out IList<Foo.Item> result)
{
var connection = this.connection;
using var command = connection.CreateCommand();

var sqlQuery = @""sp_TestSP"";
command.CommandText = sqlQuery;
using var reader = command.ExecuteReader();
var __result = new List<Item>();
while (reader.Read())
{
var item = new Item();
var value_0 = reader.GetValue(0);
item.StringValue = value_0 == DBNull.Value ? (string?)null : (string)value_0;
var value_1 = reader.GetValue(1);
item.Int32Value = (int)value_1;
var value_2 = reader.GetValue(2);
item.NullableInt32Value = value_2 == DBNull.Value ? (int?)null : (int)value_2;
__result.Add(item);
}

reader.Close();
result = __result;
}
}
}";
Assert.AreEqual(expectedOutput, output);
}

[TestMethod]
public void MapResultSetToProcedureWithRefParameter()
{
string source = @"
namespace Foo
{
public class Item
{
public string StringValue { get; set; }
public int Int32Value { get; set; }
public int? NullableInt32Value { get; set; }
}

class C
{
private DbConnection connection;

[SqlMarshal(""sp_TestSP"")]
public partial void M(ref IList<Item> result)
}
}";
string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable);

Assert.IsNotNull(output);

var expectedOutput = @"// <auto-generated>
// Code generated by Stored Procedures Code Generator.
// Changes may cause incorrect behavior and will be lost if the code is
// regenerated.
// </auto-generated>
#nullable enable
#pragma warning disable 1591

namespace Foo
{
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Linq;

partial class C
{
public partial void M(ref IList<Foo.Item> result)
{
var connection = this.connection;
using var command = connection.CreateCommand();

var sqlQuery = @""sp_TestSP"";
command.CommandText = sqlQuery;
using var reader = command.ExecuteReader();
var __result = result ?? new List<Item>();
while (reader.Read())
{
var item = new Item();
var value_0 = reader.GetValue(0);
item.StringValue = value_0 == DBNull.Value ? (string?)null : (string)value_0;
var value_1 = reader.GetValue(1);
item.Int32Value = (int)value_1;
var value_2 = reader.GetValue(2);
item.NullableInt32Value = value_2 == DBNull.Value ? (int?)null : (int)value_2;
__result.Add(item);
}

reader.Close();
result = __result;
}
}
}";
Assert.AreEqual(expectedOutput, output);
}

[TestMethod]
public void MapSingleObjectToProcedureConnection()
{
Expand Down
62 changes: 62 additions & 0 deletions SqlMarshal.Tests/StoredProcedureGenerationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,68 @@ partial class C
}
}
}
}";
Assert.AreEqual(expectedOutput, output);
}

[TestMethod]
public void MapListWithOtputParameter()
{
string source = @"
namespace Foo
{
class C
{
[SqlMarshal(""sp_TestSP"")]
public partial void M(int clientId, int? personId, out IList<Item> result)
}
}";
string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable);

Assert.IsNotNull(output);

var expectedOutput = @"// <auto-generated>
// Code generated by Stored Procedures Code Generator.
// Changes may cause incorrect behavior and will be lost if the code is
// regenerated.
// </auto-generated>
#nullable enable
#pragma warning disable 1591

namespace Foo
{
using System;
using System.Data.Common;
using System.Linq;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;

partial class C
{
public partial void M(int clientId, int? personId, out IList<Item> result)
{
var connection = this.dbContext.Database.GetDbConnection();
using var command = connection.CreateCommand();

var clientIdParameter = command.CreateParameter();
clientIdParameter.ParameterName = ""@client_id"";
clientIdParameter.Value = clientId;

var personIdParameter = command.CreateParameter();
personIdParameter.ParameterName = ""@person_id"";
personIdParameter.Value = personId == null ? (object)DBNull.Value : personId;

var parameters = new DbParameter[]
{
clientIdParameter,
personIdParameter,
};

var sqlQuery = @""sp_TestSP @client_id, @person_id"";
var __result = this.dbContext.Items.FromSqlRaw(sqlQuery, parameters).ToList();
result = __result;
}
}
}";
Assert.AreEqual(expectedOutput, output);
}
Expand Down
2 changes: 1 addition & 1 deletion SqlMarshal/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ internal static ITypeSymbol GetUnderlyingType(ITypeSymbol returnType)
return returnType;
}

internal static bool IsList(ITypeSymbol returnType) => returnType.Name == "IList" || returnType.Name == "List";
internal static bool IsList(this ITypeSymbol returnType) => returnType.Name == "IList" || returnType.Name == "List";

internal static bool IsEnumerable(ITypeSymbol returnType) => returnType.Name == "IEnumerable";

Expand Down
11 changes: 9 additions & 2 deletions SqlMarshal/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ private void MapResults(

if (isList)
{
source.AppendLine($@"var __result = new List<{(IsTuple(itemType) ? itemType.ToDisplayString() : itemType.Name)}>();");
source.AppendLine($@"var __result = {(methodGenerationContext.OutputResultsetParameter?.RefKind == RefKind.Ref ? $"{methodGenerationContext.OutputResultsetParameter.Name} ?? " : string.Empty)}new List<{(IsTuple(itemType) ? itemType.ToDisplayString() : itemType.Name)}>();");
if (isTask)
{
source.AppendLine($"while (await reader.ReadAsync({cancellationToken}).ConfigureAwait(false))");
Expand Down Expand Up @@ -1106,7 +1106,14 @@ private void ProcessMethod(
{
this.MapResults(source, methodGenerationContext, methodSymbol, parameters, itemType, hasNullableAnnotations, isList, isTask);
MarshalOutputParameters(source, parameters, hasNullableAnnotations);
source.AppendLine(ReturnStatement(IdentifierName("__result")).NormalizeWhitespace().ToFullString());
if (methodGenerationContext.OutputResultsetParameter != null)
{
source.AppendLine($"{methodGenerationContext.OutputResultsetParameter.Name} = __result;");
}
else
{
source.AppendLine(ReturnStatement(IdentifierName("__result")).NormalizeWhitespace().ToFullString());
}
}
}

Expand Down
25 changes: 23 additions & 2 deletions SqlMarshal/MethodGenerationContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@ internal MethodGenerationContext(ClassGenerationContext classGenerationContext,
this.DbContextParameter = GetDbContextParameter(methodSymbol);
this.CustomSqlParameter = GetCustomSqlParameter(methodSymbol);
this.CancellationTokenParameter = GetCancellationTokenParameter(methodSymbol);
this.OutputResultsetParameter = GetOutputResultsetParameter(methodSymbol);
var parameters = methodSymbol.Parameters;
if (this.ConnectionParameter != null)
{
parameters = parameters.Remove(this.ConnectionParameter);
}

if (this.OutputResultsetParameter != null)
{
parameters = parameters.Remove(this.OutputResultsetParameter);
}

if (this.DbContextParameter != null)
{
parameters = parameters.Remove(this.DbContextParameter);
Expand Down Expand Up @@ -60,6 +66,8 @@ internal MethodGenerationContext(ClassGenerationContext classGenerationContext,

internal IParameterSymbol? ConnectionParameter { get; }

internal IParameterSymbol? OutputResultsetParameter { get; }

internal IParameterSymbol? TransactionParameter { get; }

internal IParameterSymbol? DbContextParameter { get; }
Expand All @@ -74,9 +82,9 @@ internal MethodGenerationContext(ClassGenerationContext classGenerationContext,

internal bool IsDataReader => this.MethodSymbol.ReturnType.Name == "DbDataReader";

internal ITypeSymbol ReturnType => this.MethodSymbol.ReturnType.UnwrapTaskType();
internal ITypeSymbol ReturnType => this.OutputResultsetParameter?.Type ?? this.MethodSymbol.ReturnType.UnwrapTaskType();

internal bool IsList => IsList(this.ReturnType);
internal bool IsList => this.ReturnType.IsList();

internal bool IsEnumerable => IsEnumerable(this.ReturnType);

Expand All @@ -95,6 +103,19 @@ internal MethodGenerationContext(ClassGenerationContext classGenerationContext,
return null;
}

private static IParameterSymbol? GetOutputResultsetParameter(IMethodSymbol methodSymbol)
{
foreach (var parameterSymbol in methodSymbol.Parameters)
{
if (parameterSymbol.Type.IsList() && (parameterSymbol.RefKind == RefKind.Out || parameterSymbol.RefKind == RefKind.Ref))
{
return parameterSymbol;
}
}

return null;
}

private static IParameterSymbol? GetTransactionParameter(IMethodSymbol methodSymbol)
{
foreach (var parameterSymbol in methodSymbol.Parameters)
Expand Down
Loading