Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds the new System.Numerics.Tensors as an input/output type when using dotnet 8.0 and up. #23261

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions csharp/src/Microsoft.ML.OnnxRuntime/FixedBufferOnnxValue.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@
using Microsoft.ML.OnnxRuntime.Tensors;
using System;

#if NET8_0_OR_GREATER
using DotnetTensors = System.Numerics.Tensors;
using TensorPrimitives = System.Numerics.Tensors.TensorPrimitives;
#endif

namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// This is a legacy class that is kept for backward compatibility.
/// Use OrtValue based API.
///
///
/// Represents an OrtValue with its underlying buffer pinned
/// </summary>
public class FixedBufferOnnxValue : IDisposable
Expand Down Expand Up @@ -39,6 +44,22 @@ public static FixedBufferOnnxValue CreateFromTensor<T>(Tensor<T> value)
return new FixedBufferOnnxValue(ref ortValue, OnnxValueType.ONNX_TYPE_TENSOR, elementType);
}

#if NET8_0_OR_GREATER
#pragma warning disable SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
/// <summary>
/// Creates a <see cref="FixedBufferOnnxValue"/> object from the tensor and pins its underlying buffer.
/// </summary>
/// <typeparam name="T">Unmanaged tensor element type (e.g. float, int).</typeparam>
/// <param name="value">System.Numerics.Tensors tensor whose buffer is pinned for the lifetime of the returned value.</param>
/// <returns>a disposable instance of FixedBufferOnnxValue</returns>
public static FixedBufferOnnxValue CreateFromDotnetTensor<T>(DotnetTensors.Tensor<T> value) where T : unmanaged
{
    // Wrap the tensor's buffer in a native OrtValue, then map T to the ONNX element
    // type via the existing TensorBase type-info lookup (same path CreateFromTensor uses).
    var ortValue = OrtValue.CreateTensorValueFromDotnetTensorObject<T>(value);
    return new FixedBufferOnnxValue(ref ortValue, OnnxValueType.ONNX_TYPE_TENSOR, TensorBase.GetTypeInfo(typeof(T)).ElementType);
}
#pragma warning restore SYSLIB5001 // System.Numerics.Tensors is only in preview so it can continue receiving API feedback
#endif

/// <summary>
/// This is a factory method that creates a disposable instance of FixedBufferOnnxValue
/// on top of a buffer. Internally, it will pin managed buffer and will create
Expand All @@ -62,7 +83,7 @@ public static FixedBufferOnnxValue CreateFromTensor<T>(Tensor<T> value)
/// Here is an example of using a 3rd party library class for processing float16/bfloat16.
/// Currently, to pass tensor data and create a tensor one must copy data to Float16/BFloat16 structures
/// so DenseTensor can recognize it.
///
///
/// If you are using a library that has a class Half and it is blittable, that is its managed in memory representation
/// matches native one and its size is 16-bits, you can use the following conceptual example
/// to feed/fetch data for inference using Half array. This allows you to avoid copying data from your Half[] to Float16[]
Expand All @@ -73,7 +94,7 @@ public static FixedBufferOnnxValue CreateFromTensor<T>(Tensor<T> value)
/// var input_shape = new long[] {input.Length};
/// Half[] output = new Half[40]; // Whatever the expected len/shape is must match
/// var output_shape = new long[] {output.Length};
///
///
/// var memInfo = OrtMemoryInfo.DefaultInstance; // CPU
///
/// using(var fixedBufferInput = FixedBufferOnnxvalue.CreateFromMemory{Half}(memInfo,
Expand Down
38 changes: 36 additions & 2 deletions csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;


#if NET8_0_OR_GREATER
michaelgsharp marked this conversation as resolved.
Show resolved Hide resolved
using DotnetTensors = System.Numerics.Tensors;
using TensorPrimitives = System.Numerics.Tensors.TensorPrimitives;
#endif

namespace Microsoft.ML.OnnxRuntime
{
Expand Down Expand Up @@ -166,13 +173,41 @@ private static OrtValue CreateMapProjection(NamedOnnxValue node, NodeMetadata el
/// <exception cref="OnnxRuntimeException"></exception>
private static OrtValue CreateTensorProjection(NamedOnnxValue node, NodeMetadata elementMeta)
{
if (node.Value is not TensorBase)
#if NET8_0_OR_GREATER
#pragma warning disable SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
if (node.Value is not TensorBase && node.Value.GetType().GetGenericTypeDefinition() != typeof(DotnetTensors.Tensor<>))
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
$"NamedOnnxValue contains: {node.Value.GetType()}, expecting a Tensor<T>");
}

OrtValue ortValue;
TensorElementType elementType;

if (node.Value is TensorBase)
{
ortValue = OrtValue.CreateFromTensorObject(node.Value as TensorBase, out elementType);
}
else
{
MethodInfo method = typeof(OrtValue).GetMethod(nameof(OrtValue.CreateTensorValueFromDotnetTensorObject), BindingFlags.Static | BindingFlags.Public);
Type tensorType = node.Value.GetType().GetGenericArguments()[0];
MethodInfo generic = method.MakeGenericMethod(tensorType);
ortValue = (OrtValue)generic.Invoke(null, [node.Value]);
elementType = TensorBase.GetTypeInfo(tensorType).ElementType;
}


#pragma warning restore SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
#else
if (node.Value is not TensorBase)
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
$"NamedOnnxValue contains: {node.Value.GetType()}, expecting a Tensor<T>");
}
OrtValue ortValue = OrtValue.CreateFromTensorObject(node.Value as TensorBase, out TensorElementType elementType);

#endif
try
{
if (elementType != elementMeta.ElementDataType)
Expand All @@ -191,4 +226,3 @@ private static OrtValue CreateTensorProjection(NamedOnnxValue node, NodeMetadata
}
}
}

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="MSBuild.Sdk.Extras/3.0.22">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<!--- packaging properties -->
<OrtPackageId Condition="'$(OrtPackageId)' == ''">Microsoft.ML.OnnxRuntime</OrtPackageId>
Expand Down Expand Up @@ -184,6 +184,10 @@
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="8.0.0" PrivateAssets="All" />
</ItemGroup>

<ItemGroup Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net8.0'))">
<PackageReference Include="System.Numerics.Tensors" Version="9.0.0" />
</ItemGroup>

<!-- debug output - makes finding/fixing any issues with the the conditions easy. -->
<Target Name="DumpValues" BeforeTargets="PreBuildEvent">
<Message Text="SolutionName='$(SolutionName)'" />
Expand Down
61 changes: 49 additions & 12 deletions csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
using System.Diagnostics;
using System.Linq;

michaelgsharp marked this conversation as resolved.
Show resolved Hide resolved
#if NET8_0_OR_GREATER
using DotnetTensors = System.Numerics.Tensors;
using TensorPrimitives = System.Numerics.Tensors.TensorPrimitives;
#endif

namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
Expand All @@ -30,37 +35,37 @@ internal MapHelper(TensorBase keys, TensorBase values)
/// <summary>
/// This is a legacy class that is kept for backward compatibility.
/// Use OrtValue based API.
///
/// The class associates a name with an Object.
///
/// The class associates a name with an Object.
/// The name of the class is a misnomer, it does not hold any Onnx values,
/// just managed representation of them.
///
///
/// The class is currently used as both inputs and outputs. Because it is non-
/// disposable, it can not hold on to any native objects.
///
///
/// When used as input, we temporarily create OrtValues that map managed inputs
/// directly. Thus we are able to avoid copying of contiguous data.
///
///
/// For outputs, tensor buffers works the same as input, providing it matches
/// the expected output shape. For other types (maps and sequences) we create a copy of the data.
/// This is because, the class is not Disposable and it is a public interface, thus it can not own
/// the underlying OrtValues that must be destroyed before Run() returns.
///
///
/// To avoid data copying on output, use DisposableNamedOnnxValue class that is returned from Run() methods.
/// This provides access to the native memory tensors and avoids copying.
///
///
/// It is a recursive structure that may contain Tensors (base case)
/// Other sequences and maps. Although the OnnxValueType is exposed,
/// the caller is supposed to know the actual data type contained.
///
///
/// The convention is that for tensors, it would contain a DenseTensor{T} instance or
/// anything derived from Tensor{T}.
///
///
/// For sequences, it would contain a IList{T} where T is an instance of NamedOnnxValue that
/// would contain a tensor or another type.
///
///
/// For Maps, it would contain a IDictionary{K, V} where K,V are primitive types or strings.
///
///
/// </summary>
public class NamedOnnxValue
{
Expand Down Expand Up @@ -140,6 +145,23 @@ public static NamedOnnxValue CreateFromTensor<T>(string name, Tensor<T> value)
return new NamedOnnxValue(name, value, OnnxValueType.ONNX_TYPE_TENSOR);
}

#if NET8_0_OR_GREATER
#pragma warning disable SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
/// <summary>
/// This is a factory method that instantiates NamedOnnxValue
/// and associates the name with an instance of a System.Numerics.Tensors Tensor<typeparamref name="T"/>
/// </summary>
/// <typeparam name="T">Tensor element type.</typeparam>
/// <param name="name">input/output name to associate with the tensor</param>
/// <param name="value">Tensor<typeparamref name="T"/> to hold under the given name</param>
/// <returns>a NamedOnnxValue holding the tensor, tagged as ONNX_TYPE_TENSOR</returns>
public static NamedOnnxValue CreateFromDotnetTensor<T>(string name, DotnetTensors.Tensor<T> value)
{
return new NamedOnnxValue(name, value, OnnxValueType.ONNX_TYPE_TENSOR);
}
#pragma warning restore SYSLIB5001 // System.Numerics.Tensors is only in preview so it can continue receiving API feedback
#endif

/// <summary>
/// This is a factory method that instantiates NamedOnnxValue.
/// It would contain a sequence of elements
Expand Down Expand Up @@ -196,6 +218,21 @@ public Tensor<T> AsTensor<T>()
return _value as Tensor<T>; // will return null if not castable
}


#if NET8_0_OR_GREATER
#pragma warning disable SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
/// <summary>
/// Attempts to interpret the stored value as a System.Numerics.Tensors tensor.
/// </summary>
/// <typeparam name="T">Expected tensor element type.</typeparam>
/// <returns>The contained DotnetTensors.Tensor&lt;T&gt; if the value is one; null otherwise.</returns>
public DotnetTensors.Tensor<T> AsDotnetTensor<T>()
{
    // Pattern match instead of an 'as' cast; yields null when _value is not the requested tensor type.
    return _value is DotnetTensors.Tensor<T> tensor ? tensor : null;
}
#pragma warning restore SYSLIB5001 // System.Numerics.Tensors is only in preview so it can continue receiving API feedback
#endif

/// <summary>
/// Try-get value as an Enumerable&lt;T&gt;.
/// T is usually a NamedOnnxValue instance that may contain
Expand Down Expand Up @@ -266,7 +303,7 @@ internal virtual IntPtr OutputToOrtValueHandle(NodeMetadata metadata, out IDispo
}
}

throw new OnnxRuntimeException(ErrorCode.NotImplemented,
throw new OnnxRuntimeException(ErrorCode.NotImplemented,
$"Can not create output OrtValue for NamedOnnxValue '{metadata.OnnxValueType}' type." +
$" Only tensors can be pre-allocated for outputs " +
$" Use Run() overloads that return DisposableNamedOnnxValue to get access to all Onnx value types that may be returned as output.");
Expand Down
Loading
Loading