Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable CancellationToken for non-remoting actor implementations #1202

Merged
merged 4 commits into from
Jan 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@

try
{
var (header, body) = await runtime.DispatchWithRemotingAsync(actorTypeName, actorId, methodName, daprActorheader, context.Request.Body);
var (header, body) = await runtime.DispatchWithRemotingAsync(actorTypeName, actorId, methodName, daprActorheader, context.Request.Body, context.RequestAborted);

Check warning on line 106 in src/Dapr.Actors.AspNetCore/ActorsEndpointRouteBuilderExtensions.cs

View check run for this annotation

Codecov / codecov/patch

src/Dapr.Actors.AspNetCore/ActorsEndpointRouteBuilderExtensions.cs#L106

Added line #L106 was not covered by tests

// Item 1 is header , Item 2 is body
if (header != string.Empty)
Expand All @@ -112,14 +112,14 @@
context.Response.Headers[Constants.ErrorResponseHeaderName] = header; // add error header
}

await context.Response.Body.WriteAsync(body, 0, body.Length); // add response message body
await context.Response.Body.WriteAsync(body, 0, body.Length, context.RequestAborted); // add response message body

Check warning on line 115 in src/Dapr.Actors.AspNetCore/ActorsEndpointRouteBuilderExtensions.cs

View check run for this annotation

Codecov / codecov/patch

src/Dapr.Actors.AspNetCore/ActorsEndpointRouteBuilderExtensions.cs#L115

Added line #L115 was not covered by tests
}
catch (Exception ex)
{
var (header, body) = CreateExceptionResponseMessage(ex);

context.Response.Headers[Constants.ErrorResponseHeaderName] = header;
await context.Response.Body.WriteAsync(body, 0, body.Length);
await context.Response.Body.WriteAsync(body, 0, body.Length, context.RequestAborted);

Check warning on line 122 in src/Dapr.Actors.AspNetCore/ActorsEndpointRouteBuilderExtensions.cs

View check run for this annotation

Codecov / codecov/patch

src/Dapr.Actors.AspNetCore/ActorsEndpointRouteBuilderExtensions.cs#L122

Added line #L122 was not covered by tests
}
finally
{
Expand All @@ -130,7 +130,7 @@
{
try
{
await runtime.DispatchWithoutRemotingAsync(actorTypeName, actorId, methodName, context.Request.Body, context.Response.Body);
await runtime.DispatchWithoutRemotingAsync(actorTypeName, actorId, methodName, context.Request.Body, context.Response.Body, context.RequestAborted);
}
finally
{
Expand Down
8 changes: 4 additions & 4 deletions src/Dapr.Actors/Runtime/ActorManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,16 @@ async Task<object> RequestFunc(Actor actor, CancellationToken ct)
var parameters = methodInfo.GetParameters();
dynamic awaitable;

if (parameters.Length == 0)
if (parameters.Length == 0 || (parameters.Length == 1 && parameters[0].ParameterType == typeof(CancellationToken)))
{
awaitable = methodInfo.Invoke(actor, null);
awaitable = methodInfo.Invoke(actor, parameters.Length == 0 ? null : new object[] { ct });
}
else if (parameters.Length == 1)
else if (parameters.Length == 1 || (parameters.Length == 2 && parameters[1].ParameterType == typeof(CancellationToken)))
{
// deserialize using stream.
var type = parameters[0].ParameterType;
var deserializedType = await JsonSerializer.DeserializeAsync(requestBodyStream, type, jsonSerializerOptions);
awaitable = methodInfo.Invoke(actor, new object[] { deserializedType });
awaitable = methodInfo.Invoke(actor, parameters.Length == 1 ? new object[] { deserializedType } : new object[] { deserializedType, ct });
}
else
{
Expand Down
106 changes: 106 additions & 0 deletions test/Dapr.Actors.Test/Runtime/ActorRuntimeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace Dapr.Actors.Test
using Xunit;
using Dapr.Actors.Client;
using System.Reflection;
using System.Threading;

public sealed class ActorRuntimeTests
{
Expand Down Expand Up @@ -109,6 +110,111 @@ public async Task NoActivateMessageFromRuntime()
Assert.Contains(actorType.Name, runtime.RegisteredActors.Select(a => a.Type.ActorTypeName), StringComparer.InvariantCulture);
}

public interface INotRemotedActor : IActor
{
Task<string> NoArgumentsAsync();

Task<string> NoArgumentsWithCancellationAsync(CancellationToken cancellationToken = default);

Task<string> SingleArgumentAsync(bool arg);

Task<string> SingleArgumentWithCancellationAsync(bool arg, CancellationToken cancellationToken = default);
}

public sealed class NotRemotedActor : Actor, INotRemotedActor
{
public NotRemotedActor(ActorHost host)
: base(host)
{
}

public Task<string> NoArgumentsAsync()
{
return Task.FromResult(nameof(NoArgumentsAsync));
}

public Task<string> NoArgumentsWithCancellationAsync(CancellationToken cancellationToken = default)
{
return Task.FromResult(nameof(NoArgumentsWithCancellationAsync));
}

public Task<string> SingleArgumentAsync(bool arg)
{
return Task.FromResult(nameof(SingleArgumentAsync));
}

public Task<string> SingleArgumentWithCancellationAsync(bool arg, CancellationToken cancellationToken = default)
{
return Task.FromResult(nameof(SingleArgumentWithCancellationAsync));
}
}

public async Task<string> InvokeMethod<T>(string methodName, object arg = null) where T : Actor
{
var options = new ActorRuntimeOptions();

options.Actors.RegisterActor<T>();

var runtime = new ActorRuntime(options, loggerFactory, activatorFactory, proxyFactory);

using var input = new MemoryStream();

if (arg is not null)
{
JsonSerializer.Serialize(input, arg);

input.Seek(0, SeekOrigin.Begin);
}

using var output = new MemoryStream();

await runtime.DispatchWithoutRemotingAsync(typeof(T).Name, ActorId.CreateRandom().ToString(), methodName, input, output);

output.Seek(0, SeekOrigin.Begin);

return JsonSerializer.Deserialize<string>(output);
}

[Fact]
public async Task NoRemotingMethodWithNoArguments()
{
string methodName = nameof(INotRemotedActor.NoArgumentsAsync);

string result = await InvokeMethod<NotRemotedActor>(methodName);

Assert.Equal(methodName, result);
}

[Fact]
public async Task NoRemotingMethodWithNoArgumentsWithCancellation()
{
string methodName = nameof(INotRemotedActor.NoArgumentsWithCancellationAsync);

string result = await InvokeMethod<NotRemotedActor>(methodName);

Assert.Equal(methodName, result);
}

[Fact]
public async Task NoRemotingMethodWithSingleArgument()
{
string methodName = nameof(INotRemotedActor.SingleArgumentAsync);

string result = await InvokeMethod<NotRemotedActor>(methodName, true);

Assert.Equal(methodName, result);
}

[Fact]
public async Task NoRemotingMethodWithSingleArgumentWithCancellation()
{
string methodName = nameof(INotRemotedActor.SingleArgumentWithCancellationAsync);

string result = await InvokeMethod<NotRemotedActor>(methodName, true);

Assert.Equal(methodName, result);
}

[Fact]
public async Task Actor_UsesCustomActivator()
{
Expand Down
Loading