Skip to content

Commit

Permalink
Enable CancellationToken for non-remoting actor implementations (#1202
Browse files Browse the repository at this point in the history
)

* Sketch no arguments with cancellation.

Signed-off-by: Phillip Hoff <[email protected]>

* Sketch the other argument permutations.

Signed-off-by: Phillip Hoff <[email protected]>

* Refactor tests.

Signed-off-by: Phillip Hoff <[email protected]>

* Push HTTP request cancellation token down into handlers.

Signed-off-by: Phillip Hoff <[email protected]>

---------

Signed-off-by: Phillip Hoff <[email protected]>
  • Loading branch information
philliphoff authored Jan 6, 2024
1 parent 10ef818 commit 8d06a1f
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ private static IEndpointConventionBuilder MapActorMethodEndpoint(this IEndpointR

try
{
var (header, body) = await runtime.DispatchWithRemotingAsync(actorTypeName, actorId, methodName, daprActorheader, context.Request.Body);
var (header, body) = await runtime.DispatchWithRemotingAsync(actorTypeName, actorId, methodName, daprActorheader, context.Request.Body, context.RequestAborted);

// Item 1 is header , Item 2 is body
if (header != string.Empty)
Expand All @@ -112,14 +112,14 @@ private static IEndpointConventionBuilder MapActorMethodEndpoint(this IEndpointR
context.Response.Headers[Constants.ErrorResponseHeaderName] = header; // add error header
}

await context.Response.Body.WriteAsync(body, 0, body.Length); // add response message body
await context.Response.Body.WriteAsync(body, 0, body.Length, context.RequestAborted); // add response message body
}
catch (Exception ex)
{
var (header, body) = CreateExceptionResponseMessage(ex);

context.Response.Headers[Constants.ErrorResponseHeaderName] = header;
await context.Response.Body.WriteAsync(body, 0, body.Length);
await context.Response.Body.WriteAsync(body, 0, body.Length, context.RequestAborted);
}
finally
{
Expand All @@ -130,7 +130,7 @@ private static IEndpointConventionBuilder MapActorMethodEndpoint(this IEndpointR
{
try
{
await runtime.DispatchWithoutRemotingAsync(actorTypeName, actorId, methodName, context.Request.Body, context.Response.Body);
await runtime.DispatchWithoutRemotingAsync(actorTypeName, actorId, methodName, context.Request.Body, context.Response.Body, context.RequestAborted);
}
finally
{
Expand Down
8 changes: 4 additions & 4 deletions src/Dapr.Actors/Runtime/ActorManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,16 @@ async Task<object> RequestFunc(Actor actor, CancellationToken ct)
var parameters = methodInfo.GetParameters();
dynamic awaitable;

if (parameters.Length == 0)
if (parameters.Length == 0 || (parameters.Length == 1 && parameters[0].ParameterType == typeof(CancellationToken)))
{
awaitable = methodInfo.Invoke(actor, null);
awaitable = methodInfo.Invoke(actor, parameters.Length == 0 ? null : new object[] { ct });
}
else if (parameters.Length == 1)
else if (parameters.Length == 1 || (parameters.Length == 2 && parameters[1].ParameterType == typeof(CancellationToken)))
{
// deserialize using stream.
var type = parameters[0].ParameterType;
var deserializedType = await JsonSerializer.DeserializeAsync(requestBodyStream, type, jsonSerializerOptions);
awaitable = methodInfo.Invoke(actor, new object[] { deserializedType });
awaitable = methodInfo.Invoke(actor, parameters.Length == 1 ? new object[] { deserializedType } : new object[] { deserializedType, ct });
}
else
{
Expand Down
106 changes: 106 additions & 0 deletions test/Dapr.Actors.Test/Runtime/ActorRuntimeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace Dapr.Actors.Test
using Xunit;
using Dapr.Actors.Client;
using System.Reflection;
using System.Threading;

public sealed class ActorRuntimeTests
{
Expand Down Expand Up @@ -109,6 +110,111 @@ public async Task NoActivateMessageFromRuntime()
Assert.Contains(actorType.Name, runtime.RegisteredActors.Select(a => a.Type.ActorTypeName), StringComparer.InvariantCulture);
}

public interface INotRemotedActor : IActor
{
Task<string> NoArgumentsAsync();

Task<string> NoArgumentsWithCancellationAsync(CancellationToken cancellationToken = default);

Task<string> SingleArgumentAsync(bool arg);

Task<string> SingleArgumentWithCancellationAsync(bool arg, CancellationToken cancellationToken = default);
}

public sealed class NotRemotedActor : Actor, INotRemotedActor
{
public NotRemotedActor(ActorHost host)
: base(host)
{
}

public Task<string> NoArgumentsAsync()
{
return Task.FromResult(nameof(NoArgumentsAsync));
}

public Task<string> NoArgumentsWithCancellationAsync(CancellationToken cancellationToken = default)
{
return Task.FromResult(nameof(NoArgumentsWithCancellationAsync));
}

public Task<string> SingleArgumentAsync(bool arg)
{
return Task.FromResult(nameof(SingleArgumentAsync));
}

public Task<string> SingleArgumentWithCancellationAsync(bool arg, CancellationToken cancellationToken = default)
{
return Task.FromResult(nameof(SingleArgumentWithCancellationAsync));
}
}

public async Task<string> InvokeMethod<T>(string methodName, object arg = null) where T : Actor
{
var options = new ActorRuntimeOptions();

options.Actors.RegisterActor<T>();

var runtime = new ActorRuntime(options, loggerFactory, activatorFactory, proxyFactory);

using var input = new MemoryStream();

if (arg is not null)
{
JsonSerializer.Serialize(input, arg);

input.Seek(0, SeekOrigin.Begin);
}

using var output = new MemoryStream();

await runtime.DispatchWithoutRemotingAsync(typeof(T).Name, ActorId.CreateRandom().ToString(), methodName, input, output);

output.Seek(0, SeekOrigin.Begin);

return JsonSerializer.Deserialize<string>(output);
}

[Fact]
public async Task NoRemotingMethodWithNoArguments()
{
string methodName = nameof(INotRemotedActor.NoArgumentsAsync);

string result = await InvokeMethod<NotRemotedActor>(methodName);

Assert.Equal(methodName, result);
}

[Fact]
public async Task NoRemotingMethodWithNoArgumentsWithCancellation()
{
string methodName = nameof(INotRemotedActor.NoArgumentsWithCancellationAsync);

string result = await InvokeMethod<NotRemotedActor>(methodName);

Assert.Equal(methodName, result);
}

[Fact]
public async Task NoRemotingMethodWithSingleArgument()
{
string methodName = nameof(INotRemotedActor.SingleArgumentAsync);

string result = await InvokeMethod<NotRemotedActor>(methodName, true);

Assert.Equal(methodName, result);
}

[Fact]
public async Task NoRemotingMethodWithSingleArgumentWithCancellation()
{
string methodName = nameof(INotRemotedActor.SingleArgumentWithCancellationAsync);

string result = await InvokeMethod<NotRemotedActor>(methodName, true);

Assert.Equal(methodName, result);
}

[Fact]
public async Task Actor_UsesCustomActivator()
{
Expand Down

0 comments on commit 8d06a1f

Please sign in to comment.