Skip to content

Commit

Permalink
Forward code generation for function pointers from type aliases (#170)
Browse files Browse the repository at this point in the history
  • Loading branch information
lithiumtoast authored Jan 4, 2025
1 parent 3d9a40e commit 41eb0d2
Show file tree
Hide file tree
Showing 17 changed files with 193 additions and 51 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// <auto-generated>
// This code was generated by the following tool on 2025-01-03 14:36:43 GMT-05:00:
// https://github.com/bottlenoselabs/c2cs (v2025-01-03 14:36:43 GMT-05:00)
// This code was generated by the following tool on 2025-01-04 10:45:12 GMT-05:00:
// https://github.com/bottlenoselabs/c2cs (v2025-01-04 10:45:12 GMT-05:00)
//
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
// </auto-generated>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// <auto-generated>
// This code was generated by the following tool on 2025-01-03 14:36:43 GMT-05:00:
// This code was generated by the following tool on 2025-01-04 10:45:12 GMT-05:00:
// https://github.com/bottlenoselabs/c2cs (v0.0.0.0)
//
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,16 @@ namespace helloworld
{
public static unsafe partial class my_c_library
{
[global::System.Runtime.InteropServices.DllImportAttribute("my_c_library", EntryPoint = "hw_invoke_callback", ExactSpelling = true)]
public static extern partial void hw_invoke_callback(global::helloworld.my_c_library.FnPtr_CString_Void f, global::Interop.Runtime.CString s);
[global::System.Runtime.InteropServices.DllImportAttribute("my_c_library", EntryPoint = "hw_invoke_callback1", ExactSpelling = true)]
public static extern partial void hw_invoke_callback1(global::helloworld.my_c_library.hw_callback f, global::Interop.Runtime.CString s);
}
}
namespace helloworld
{
public static unsafe partial class my_c_library
{
[global::System.Runtime.InteropServices.DllImportAttribute("my_c_library", EntryPoint = "hw_invoke_callback2", ExactSpelling = true)]
public static extern partial void hw_invoke_callback2(global::helloworld.my_c_library.FnPtr_CString_Void f, global::Interop.Runtime.CString s);
}
}
namespace helloworld
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// <auto-generated>
// This code was generated by the following tool on 2025-01-03 14:36:43 GMT-05:00:
// This code was generated by the following tool on 2025-01-04 10:45:12 GMT-05:00:
// https://github.com/bottlenoselabs/c2cs (v0.0.0.0)
//
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
Expand Down Expand Up @@ -29,9 +29,13 @@ public static unsafe partial class my_c_library
[UnmanagedCallConv(CallConvs = new[] { typeof(CallConvCdecl) })]
public static partial void hw_hello_world();

[LibraryImport(LibraryName, EntryPoint = "hw_invoke_callback")]
[LibraryImport(LibraryName, EntryPoint = "hw_invoke_callback1")]
[UnmanagedCallConv(CallConvs = new[] { typeof(CallConvCdecl) })]
public static partial void hw_invoke_callback(FnPtr_CString_Void f, CString s);
public static partial void hw_invoke_callback1(hw_callback f, CString s);

[LibraryImport(LibraryName, EntryPoint = "hw_invoke_callback2")]
[UnmanagedCallConv(CallConvs = new[] { typeof(CallConvCdecl) })]
public static partial void hw_invoke_callback2(FnPtr_CString_Void f, CString s);

[LibraryImport(LibraryName, EntryPoint = "hw_pass_enum")]
[UnmanagedCallConv(CallConvs = new[] { typeof(CallConvCdecl) })]
Expand Down Expand Up @@ -62,6 +66,17 @@ public enum hw_my_enum_week_day : int
_HW_MY_ENUM_WEEK_DAY_FORCE_U32 = 2147483647
}

[StructLayout(LayoutKind.Sequential)]
public partial struct hw_callback
{
public delegate* unmanaged<CString, void> Pointer;

public hw_callback(delegate* unmanaged<CString, void> pointer)
{
Pointer = pointer;
}
}

[StructLayout(LayoutKind.Sequential)]
public partial struct FnPtr_CString_Void
{
Expand Down
9 changes: 6 additions & 3 deletions src/cs/examples/helloworld/helloworld-app/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,16 @@ private static unsafe void Main()
// - It uses the same naming as `System.Func<>`. The last type on the name is always the return type. In this case 'void`.
// Only available in C# 9 (.NET 5+). See https://learn.microsoft.com/en-us/dotnet/csharp/language-reference/unsafe-code#function-pointers
// Additionally function pointers need to use the `address-of` operator (&) to a C# static function marked with the UnmanagedCallersOnly attribute. See https://learn.microsoft.com/en-us/dotnet/api/system.runtime.interopservices.unmanagedcallersonlyattribute?view=net-9.0
var functionPointer = new FnPtr_CString_Void(&Callback);
var functionPointer = new hw_callback(&Callback);
#else
var functionPointer = new FnPtr_CString_Void(Callback);
#endif

using var cStringCallback = (CString)"Hello from callback!";
hw_invoke_callback(functionPointer, cStringCallback);
using var cStringCallback1 = (CString)"Hello from callback!";
hw_invoke_callback1(functionPointer, cStringCallback1);

// using var cStringCallback2 = (CString)"Hello again from callback!";
// hw_invoke_callback2(functionPointer, cStringCallback2);
}

#if NET5_0_OR_GREATER
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ typedef enum hw_my_enum_week_day {
_HW_MY_ENUM_WEEK_DAY_FORCE_U32 = 0x7FFFFFFF
} hw_my_enum_week_day;

typedef void (*hw_callback)(const char* s);

MY_C_LIBRARY_API_DECL void hw_hello_world(void);
MY_C_LIBRARY_API_DECL void hw_invoke_callback(void f(const char*), const char* s);
MY_C_LIBRARY_API_DECL void hw_invoke_callback1(hw_callback f, const char* s);
MY_C_LIBRARY_API_DECL void hw_invoke_callback2(void f(const char*), const char* s);
MY_C_LIBRARY_API_DECL void hw_pass_string(const char* s);
MY_C_LIBRARY_API_DECL void hw_pass_integers_by_value(uint16_t a, int32_t b, uint64_t c);
MY_C_LIBRARY_API_DECL void hw_pass_integers_by_reference(const uint16_t* a, const int32_t* b, const uint64_t* c);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,17 @@ void hw_hello_world(void)
printf("Hello world from C!\n");
}

void hw_invoke_callback(void f(const char*), const char* s)
void hw_invoke_callback1(hw_callback f, const char* s)
{
if (f == 0)
{
return;
}

f(s);
}

void hw_invoke_callback2(void f(const char*), const char* s)
{
if (f == 0)
{
Expand Down
76 changes: 74 additions & 2 deletions src/cs/production/c2cs.Tool/GenerateCSharpCode/CodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
using C2CS.GenerateCSharpCode.Generators;
using c2ffi.Data;
using c2ffi.Data.Nodes;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.FindSymbols;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

Expand Down Expand Up @@ -58,9 +62,11 @@ public CodeProject GenerateCodeProject(CFfiCrossPlatform ffi, DiagnosticsSink di
AddDocumentInteropRuntime(options, documents);
AddDocumentAssemblyAttributes(options, documents);

var newDocuments = PostProcessDocuments(documents.ToImmutableArray());

return new CodeProject
{
Documents = documents.ToImmutable()
Documents = [..newDocuments]
};
}
#pragma warning disable CA1031
Expand All @@ -72,12 +78,78 @@ public CodeProject GenerateCodeProject(CFfiCrossPlatform ffi, DiagnosticsSink di
}
}

private static ImmutableArray<CodeProjectDocument> PostProcessDocuments(
ImmutableArray<CodeProjectDocument> documents)
{
using var workspace = new AdhocWorkspace();
var project = workspace.CurrentSolution.AddProject(
"TemporaryProject", "TemporaryAssembly", LanguageNames.CSharp)
.AddMetadataReferences([
MetadataReference.CreateFromFile(typeof(object).Assembly.Location)
]);
foreach (var document in documents)
{
project = project.AddDocument(document.FileName, document.Code).Project;
}

var newDocuments = new List<CodeProjectDocument>();
var compilation = project.GetCompilationAsync().Result!;
foreach (var syntaxTree in compilation.SyntaxTrees)
{
var code = PostProcessDocumentSyntaxTree(syntaxTree, compilation, project.Solution);
var newDocument = new CodeProjectDocument
{
Code = code,
FileName = syntaxTree.FilePath
};
newDocuments.Add(newDocument);
}

return [..newDocuments];
}

private static string PostProcessDocumentSyntaxTree(
SyntaxTree syntaxTree,
Compilation compilation,
Solution solution)
{
var root = syntaxTree.GetRoot();
var semanticModel = compilation.GetSemanticModel(syntaxTree);

var functionPointerStructWrappers = root.DescendantNodes()
.OfType<StructDeclarationSyntax>()
.Where(x =>
x.Identifier.Text.StartsWith("FnPtr_", StringComparison.InvariantCultureIgnoreCase))
.ToImmutableArray();
if (functionPointerStructWrappers.IsDefaultOrEmpty)
{
return root.SyntaxTree.ToString();
}

var newRoot = root;
foreach (var functionPointerStructWrapper in functionPointerStructWrappers)
{
var symbol = semanticModel.GetDeclaredSymbol(functionPointerStructWrapper)!;
var references = SymbolFinder
.FindReferencesAsync(symbol, solution).Result
.ToImmutableArray();

var referencesCount = references.Sum(x => x.Locations.Count());
if (referencesCount == 0)
{
newRoot = newRoot.RemoveNode(functionPointerStructWrapper, SyntaxRemoveOptions.KeepNoTrivia)!;
}
}

return newRoot.SyntaxTree.ToString();
}

private void AddDocumentPInvoke(
CodeGeneratorDocumentOptions options,
ImmutableArray<CodeProjectDocument>.Builder documents,
CFfiCrossPlatform ffi)
{
var context = new CodeGeneratorContext(_input, _nodeCodeGenerators);
var context = new CodeGeneratorContext(_input, ffi, _nodeCodeGenerators);
var document = _codeGeneratorDocumentPInvoke.Generate(options, context, ffi);
documents.Add(document);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Immutable;
using bottlenoselabs.Common.Tools;
using C2CS.GenerateCSharpCode.Generators;
using c2ffi.Data;
using c2ffi.Data.Nodes;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
Expand All @@ -23,26 +24,26 @@ public sealed class CodeGeneratorContext

public InputSanitized Input { get; }

public CFfiCrossPlatform Ffi { get; }

public NameMapper NameMapper => _nameMapper;

public CodeGeneratorContext(
InputSanitized input,
CFfiCrossPlatform ffi,
ImmutableDictionary<Type, BaseGenerator> nodeCodeGenerators)
{
Input = input;
Ffi = ffi;
_nameMapper = new NameMapper(this);
_nodeCodeGenerators = nodeCodeGenerators;
}

internal TMemberDeclarationSyntax? ProcessCNode<TNode, TMemberDeclarationSyntax>(TNode node)
public TMemberDeclarationSyntax? ProcessCNode<TNode, TMemberDeclarationSyntax>(TNode node)
where TNode : CNode
where TMemberDeclarationSyntax : MemberDeclarationSyntax
{
var type = typeof(TNode);
if (!_nodeCodeGenerators.TryGetValue(type, out var codeGenerator))
{
throw new ToolException($"A code generator '{nameof(BaseGenerator)}' does not exist for the type '{type.FullName ?? type.Name}'.");
}
var codeGenerator = GetCodeGenerator<TNode>();

var nameCSharp = _nameMapper.GetNodeNameCSharp(node);
var isAlreadyAdded = !_existingNamesCSharp.Add(nameCSharp);
Expand All @@ -51,7 +52,12 @@ public CodeGeneratorContext(
return null;
}

var code = codeGenerator.GenerateCode(nameCSharp, this, node);
var code = codeGenerator.GenerateCode(this, nameCSharp, node);
if (string.IsNullOrEmpty(code))
{
return null;
}

var memberDeclarationSyntax = SyntaxFactory.ParseMemberDeclaration(code.Trim())!;
if (memberDeclarationSyntax is not TMemberDeclarationSyntax typedMemberDeclarationSyntax)
{
Expand All @@ -64,4 +70,17 @@ public CodeGeneratorContext(

return memberDeclarationSyntaxWithTrivia;
}

public BaseGenerator<TNode> GetCodeGenerator<TNode>()
where TNode : CNode
{
var type = typeof(TNode);
if (!_nodeCodeGenerators.TryGetValue(type, out var codeGenerator))
{
throw new ToolException(
$"A code generator does not exist for the C node '{type.Name}'.");
}

return (BaseGenerator<TNode>)codeGenerator;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public abstract class BaseGenerator(ILogger<BaseGenerator> logger)
{
protected readonly ILogger<BaseGenerator> Logger = logger;

protected internal abstract string GenerateCode(
protected abstract string? GenerateCode(
string nameCSharp,
CodeGeneratorContext context,
object obj);
Expand All @@ -20,15 +20,15 @@ public abstract class BaseGenerator<TNode>(ILogger<BaseGenerator<TNode>> logger)
: BaseGenerator(logger)
where TNode : CNode
{
protected internal override string GenerateCode(
public abstract string? GenerateCode(CodeGeneratorContext context, string nameCSharp, TNode node);

protected override string? GenerateCode(
string nameCSharp,
CodeGeneratorContext context,
object obj)
{
var node = (TNode)obj;
var code = GenerateCode(nameCSharp, context, node);
var code = GenerateCode(context, nameCSharp, node);
return code;
}

protected abstract string GenerateCode(string nameCSharp, CodeGeneratorContext context, TNode node);
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Bottlenose Labs Inc. (https://github.com/bottlenoselabs). All rights reserved.
// Licensed under the MIT license. See LICENSE file in the Git repository root directory for full license information.

using c2ffi.Data;
using c2ffi.Data.Nodes;
using JetBrains.Annotations;
using Microsoft.Extensions.Logging;
Expand All @@ -11,13 +12,21 @@ namespace C2CS.GenerateCSharpCode.Generators;
public class GeneratorAliasType(ILogger<GeneratorAliasType> logger)
: BaseGenerator<CTypeAlias>(logger)
{
protected override string GenerateCode(
string nameCSharp, CodeGeneratorContext context, CTypeAlias node)
public override string? GenerateCode(CodeGeneratorContext context, string nameCSharp, CTypeAlias node)
{
var underlyingTypeNameCSharp = context.NameMapper.GetTypeNameCSharp(node.UnderlyingType);
var sizeOf = node.UnderlyingType.SizeOf;
var alignOf = node.UnderlyingType.AlignOf;

if (node.UnderlyingType.NodeKind == CNodeKind.FunctionPointer)
{
var functionPointer = context.Ffi.FunctionPointers[node.UnderlyingType.Name];
var functionPointerCodeGenerator = context.GetCodeGenerator<CFunctionPointer>();
var functionPointerCode = functionPointerCodeGenerator.GenerateCode(
context, nameCSharp, functionPointer);
return functionPointerCode;
}

var code = $$"""
[StructLayout(LayoutKind.Explicit, Size = {{sizeOf}}, Pack = {{alignOf}})]
public partial struct {{nameCSharp}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@ namespace C2CS.GenerateCSharpCode.Generators;
public class GeneratorEnum(ILogger<GeneratorEnum> logger)
: BaseGenerator<CEnum>(logger)
{
protected override string GenerateCode(
string nameCSharp,
CodeGeneratorContext context,
CEnum node)
public override string? GenerateCode(CodeGeneratorContext context, string nameCSharp, CEnum node)
{
var integerTypeNameCSharp = node.SizeOf switch
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@ namespace C2CS.GenerateCSharpCode.Generators;
public class GeneratorFunction(ILogger<GeneratorFunction> logger)
: BaseGenerator<CFunction>(logger)
{
protected override string GenerateCode(
string nameCSharp,
CodeGeneratorContext context,
CFunction function)
public override string? GenerateCode(CodeGeneratorContext context, string nameCSharp, CFunction function)
{
var returnTypeNameCSharp = context.NameMapper.GetTypeNameCSharp(function.ReturnType);
var parametersStringCSharp = string.Join(',', function.Parameters.Select(
Expand Down
Loading

0 comments on commit 41eb0d2

Please sign in to comment.