Commit 6724b3b

changes from PR comments
michaelgsharp committed Jan 8, 2025
1 parent 645b8b6 commit 6724b3b
Showing 5 changed files with 98 additions and 303 deletions.
27 changes: 3 additions & 24 deletions csharp/src/Microsoft.ML.OnnxRuntime/FixedBufferOnnxValue.shared.cs
@@ -4,17 +4,12 @@
using Microsoft.ML.OnnxRuntime.Tensors;
using System;

#if NET8_0_OR_GREATER
using DotnetTensors = System.Numerics.Tensors;
using TensorPrimitives = System.Numerics.Tensors.TensorPrimitives;
#endif

namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// This is a legacy class that is kept for backward compatibility.
/// Use OrtValue based API.
///
/// Represents an OrtValue with its underlying buffer pinned
/// </summary>
public class FixedBufferOnnxValue : IDisposable
@@ -44,22 +39,6 @@ public static FixedBufferOnnxValue CreateFromTensor<T>(Tensor<T> value)
return new FixedBufferOnnxValue(ref ortValue, OnnxValueType.ONNX_TYPE_TENSOR, elementType);
}

#if NET8_0_OR_GREATER
#pragma warning disable SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
/// <summary>
/// Creates a <see cref="FixedBufferOnnxValue"/> object from the tensor and pins its underlying buffer.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="value"></param>
/// <returns>a disposable instance of FixedBufferOnnxValue</returns>
public static FixedBufferOnnxValue CreateFromDotnetTensor<T>(DotnetTensors.Tensor<T> value) where T : unmanaged
{
var ortValue = OrtValue.CreateTensorValueFromDotnetTensorObject<T>(value);
return new FixedBufferOnnxValue(ref ortValue, OnnxValueType.ONNX_TYPE_TENSOR, TensorBase.GetTypeInfo(typeof(T)).ElementType);
}
#pragma warning restore SYSLIB5001 // System.Numerics.Tensors is only in preview so it can continue receiving API feedback
#endif

/// <summary>
/// This is a factory method that creates a disposable instance of FixedBufferOnnxValue
/// on top of a buffer. Internally, it will pin managed buffer and will create
@@ -83,7 +62,7 @@ public static FixedBufferOnnxValue CreateFromDotnetTensor<T>(DotnetTensors.Tenso
/// Here is an example of using a 3rd party library class for processing float16/bfloat16.
/// Currently, to pass tensor data and create a tensor one must copy data to Float16/BFloat16 structures
/// so DenseTensor can recognize it.
///
/// If you are using a library that has a class Half and it is blittable, that is, its managed in-memory representation
/// matches the native one and its size is 16 bits, you can use the following conceptual example
/// to feed/fetch data for inference using a Half array. This allows you to avoid copying data from your Half[] to Float16[]
@@ -94,7 +73,7 @@ public static FixedBufferOnnxValue CreateFromDotnetTensor<T>(DotnetTensors.Tenso
/// var input_shape = new long[] {input.Length};
/// Half[] output = new Half[40]; // Whatever the expected len/shape is must match
/// var output_shape = new long[] {output.Length};
///
/// var memInfo = OrtMemoryInfo.DefaultInstance; // CPU
///
/// using(var fixedBufferInput = FixedBufferOnnxValue.CreateFromMemory{Half}(memInfo,
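With the preview System.Numerics.Tensors overload removed, CreateFromTensor<T> is the remaining pinning factory on this class. Below is a minimal usage sketch, not part of the diff; the model path, the "input"/"output" names, and the 1x3 shapes are assumptions about a hypothetical model:

    using Microsoft.ML.OnnxRuntime;
    using Microsoft.ML.OnnxRuntime.Tensors;

    // Hypothetical model contract: one 1x3 float input, one 1x3 float output.
    var input = new DenseTensor<float>(new[] { 1, 3 });
    input[0, 0] = 1f; input[0, 1] = 2f; input[0, 2] = 3f;
    var output = new DenseTensor<float>(new[] { 1, 3 });

    using var session = new InferenceSession("model.onnx");
    using var inputValue = FixedBufferOnnxValue.CreateFromTensor(input);
    using var outputValue = FixedBufferOnnxValue.CreateFromTensor(output);

    // Both managed buffers stay pinned for the duration of Run, so no copies are made.
    session.Run(new[] { "input" }, new[] { inputValue },
                new[] { "output" }, new[] { outputValue });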
38 changes: 2 additions & 36 deletions csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs
@@ -6,13 +6,6 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;


#if NET8_0_OR_GREATER
using DotnetTensors = System.Numerics.Tensors;
using TensorPrimitives = System.Numerics.Tensors.TensorPrimitives;
#endif

namespace Microsoft.ML.OnnxRuntime
{
@@ -173,41 +166,13 @@ private static OrtValue CreateMapProjection(NamedOnnxValue node, NodeMetadata el
/// <exception cref="OnnxRuntimeException"></exception>
private static OrtValue CreateTensorProjection(NamedOnnxValue node, NodeMetadata elementMeta)
{
#if NET8_0_OR_GREATER
#pragma warning disable SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
if (node.Value is not TensorBase && node.Value.GetType().GetGenericTypeDefinition() != typeof(DotnetTensors.Tensor<>))
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
$"NamedOnnxValue contains: {node.Value.GetType()}, expecting a Tensor<T>");
}

OrtValue ortValue;
TensorElementType elementType;

if (node.Value is TensorBase)
{
ortValue = OrtValue.CreateFromTensorObject(node.Value as TensorBase, out elementType);
}
else
{
MethodInfo method = typeof(OrtValue).GetMethod(nameof(OrtValue.CreateTensorValueFromDotnetTensorObject), BindingFlags.Static | BindingFlags.Public);
Type tensorType = node.Value.GetType().GetGenericArguments()[0];
MethodInfo generic = method.MakeGenericMethod(tensorType);
ortValue = (OrtValue)generic.Invoke(null, [node.Value]);
elementType = TensorBase.GetTypeInfo(tensorType).ElementType;
}


#pragma warning restore SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
#else
if (node.Value is not TensorBase)
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
$"NamedOnnxValue contains: {node.Value.GetType()}, expecting a Tensor<T>");
}
OrtValue ortValue = OrtValue.CreateFromTensorObject(node.Value as TensorBase, out TensorElementType elementType);

#endif
OrtValue ortValue = OrtValue.CreateFromTensorObject(node.Value as TensorBase, out TensorElementType elementType);
try
{
if (elementType != elementMeta.ElementDataType)
@@ -226,3 +191,4 @@ private static OrtValue CreateTensorProjection(NamedOnnxValue node, NodeMetadata
}
}
}

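With the NET8-only branch gone, CreateTensorProjection accepts only TensorBase-derived values on every target framework. The deleted branch had to close an open generic method over the tensor's element type at runtime via reflection. A self-contained sketch of that dispatch pattern follows; GenericDispatch and Wrap<T> are hypothetical stand-ins (the deleted code closed the generic over Tensor<T>'s element type, while this sketch closes it over the argument's own runtime type for simplicity):

    using System;
    using System.Reflection;

    static class GenericDispatch
    {
        // Hypothetical stand-in for a generic factory whose T is only known at runtime.
        public static string Wrap<T>(T value) => $"wrapped a {typeof(T).Name}";

        public static string WrapBoxed(object boxed)
        {
            // Close the open generic over the runtime type, then invoke it.
            MethodInfo open = typeof(GenericDispatch).GetMethod(
                nameof(Wrap), BindingFlags.Static | BindingFlags.Public);
            MethodInfo closed = open.MakeGenericMethod(boxed.GetType());
            return (string)closed.Invoke(null, new[] { boxed });
        }
    }

    // GenericDispatch.WrapBoxed(3.14) returns "wrapped a Double".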
61 changes: 12 additions & 49 deletions csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs
@@ -8,11 +8,6 @@
using System.Diagnostics;
using System.Linq;

#if NET8_0_OR_GREATER
using DotnetTensors = System.Numerics.Tensors;
using TensorPrimitives = System.Numerics.Tensors.TensorPrimitives;
#endif

namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
Expand All @@ -35,37 +30,37 @@ internal MapHelper(TensorBase keys, TensorBase values)
/// <summary>
/// This is a legacy class that is kept for backward compatibility.
/// Use OrtValue based API.
///
/// The class associates a name with an Object.
/// The name of the class is a misnomer, it does not hold any Onnx values,
/// just managed representation of them.
///
/// The class is currently used for both inputs and outputs. Because it is non-
/// disposable, it cannot hold on to any native objects.
///
/// When used as input, we temporarily create OrtValues that map managed inputs
/// directly. Thus we are able to avoid copying of contiguous data.
///
/// For outputs, tensor buffers work the same as inputs, provided they match
/// the expected output shape. For other types (maps and sequences) we create a copy of the data.
/// This is because the class is not Disposable and it is a public interface, so it cannot own
/// the underlying OrtValues, which must be destroyed before Run() returns.
///
/// To avoid data copying on output, use DisposableNamedOnnxValue class that is returned from Run() methods.
/// This provides access to the native memory tensors and avoids copying.
///
/// It is a recursive structure that may contain Tensors (base case),
/// other sequences, and maps. Although the OnnxValueType is exposed,
/// the caller is supposed to know the actual data type contained.
///
/// The convention is that for tensors, it would contain a DenseTensor{T} instance or
/// anything derived from Tensor{T}.
///
/// For sequences, it would contain an IList{T} where T is an instance of NamedOnnxValue that
/// would contain a tensor or another type.
///
/// For Maps, it would contain an IDictionary{K, V} where K,V are primitive types or strings.
///
/// </summary>
public class NamedOnnxValue
{
@@ -145,23 +140,6 @@ public static NamedOnnxValue CreateFromTensor<T>(string name, Tensor<T> value)
return new NamedOnnxValue(name, value, OnnxValueType.ONNX_TYPE_TENSOR);
}

#if NET8_0_OR_GREATER
#pragma warning disable SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
/// <summary>
/// This is a factory method that instantiates NamedOnnxValue
/// and associated name with an instance of a Tensor<typeparamref name="T"/>
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="name">name</param>
/// <param name="value">Tensor<typeparamref name="T"/></param>
/// <returns></returns>
public static NamedOnnxValue CreateFromDotnetTensor<T>(string name, DotnetTensors.Tensor<T> value)
{
return new NamedOnnxValue(name, value, OnnxValueType.ONNX_TYPE_TENSOR);
}
#pragma warning restore SYSLIB5001 // System.Numerics.Tensors is only in preview so it can continue receiving API feedback
#endif

/// <summary>
/// This is a factory method that instantiates NamedOnnxValue.
/// It would contain a sequence of elements
@@ -218,21 +196,6 @@ public Tensor<T> AsTensor<T>()
return _value as Tensor<T>; // will return null if not castable
}


#if NET8_0_OR_GREATER
#pragma warning disable SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
/// <summary>
/// Try-get value as a Tensor&lt;T&gt;.
/// </summary>
/// <typeparam name="T">Type</typeparam>
/// <returns>Tensor object if contained value is a Tensor. Null otherwise</returns>
public DotnetTensors.Tensor<T> AsDotnetTensor<T>()
{
return _value as DotnetTensors.Tensor<T>; // will return null if not castable
}
#pragma warning restore SYSLIB5001 // System.Numerics.Tensors is only in preview so it can continue receiving API feedback
#endif

/// <summary>
/// Try-get value as an Enumerable&lt;T&gt;.
/// T is usually a NamedOnnxValue instance that may contain
@@ -303,7 +266,7 @@ internal virtual IntPtr OutputToOrtValueHandle(NodeMetadata metadata, out IDispo
}
}

throw new OnnxRuntimeException(ErrorCode.NotImplemented,
$"Can not create output OrtValue for NamedOnnxValue '{metadata.OnnxValueType}' type." +
$" Only tensors can be pre-allocated for outputs " +
$" Use Run() overloads that return DisposableNamedOnnxValue to get access to all Onnx value types that may be returned as output.");
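For comparison, the surviving Tensor<T>-based round trip is sketched below; "input" and model.onnx are placeholders, and the float element type is an assumption:

    using System.Linq;
    using Microsoft.ML.OnnxRuntime;
    using Microsoft.ML.OnnxRuntime.Tensors;

    var tensor = new DenseTensor<float>(new float[] { 1f, 2f, 3f }, new[] { 1, 3 });
    var input = NamedOnnxValue.CreateFromTensor("input", tensor);

    using var session = new InferenceSession("model.onnx");
    using var results = session.Run(new[] { input });

    // AsTensor<T> is an `as` cast underneath: it yields null, not an exception,
    // when the contained value is not a Tensor<float>.
    Tensor<float> firstOutput = results.First().AsTensor<float>();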
9 changes: 4 additions & 5 deletions csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
@@ -234,7 +234,7 @@ public DotnetTensors.ReadOnlyTensorSpan<T> GetTensorDataAsTensorSpan<T>() where

var typeSpan = MemoryMarshal.Cast<byte, T>(byteSpan);
var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
- var nArray = shape.Select(x => (nint)x).ToArray();
+ nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));

return new DotnetTensors.ReadOnlyTensorSpan<T>(typeSpan, nArray, []);
}
@@ -281,7 +281,7 @@ public DotnetTensors.TensorSpan<T> GetTensorMutableDataAsTensorSpan<T>() where T

var typeSpan = MemoryMarshal.Cast<byte, T>(byteSpan);
var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
- var nArray = shape.Select(x => (nint)x).ToArray();
+ nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));

return new DotnetTensors.TensorSpan<T>(typeSpan, nArray, []);
}
@@ -308,7 +308,7 @@ public DotnetTensors.TensorSpan<byte> GetTensorSpanMutableRawData<T>() where T :
var byteSpan = GetTensorBufferRawData(typeof(T));

var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
- var nArray = shape.Select(x => (nint)x).ToArray();
+ nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));

return new DotnetTensors.TensorSpan<byte>(byteSpan, nArray, []);
}
@@ -720,8 +720,7 @@ public static OrtValue CreateTensorValueFromDotnetTensorObject<T>(DotnetTensors.
}

var bufferLengthInBytes = tensor.FlattenedLength * sizeof(T);
-
- var shape = tensor.Lengths.ToArray().Select(x => (long)x).ToArray();
+ long[] shape = Array.ConvertAll(tensor.Lengths.ToArray(), new Converter<nint, long>(x => (long)x));

var typeInfo = TensorBase.GetTypeInfo(typeof(T)) ??
throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Tensor of type: {typeof(T)} is not supported");
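The recurring edit in this file replaces a LINQ Select(...).ToArray() pipeline with Array.ConvertAll, which allocates the destination array once at the exact size and avoids the Select iterator allocation (recent runtimes special-case Select(...).ToArray() over arrays, but ConvertAll stays allocation-lean on every TFM this package targets). A standalone before/after sketch:

    using System;
    using System.Linq;

    long[] shape = { 2, 3, 4 };

    // Before: allocates a Select iterator; ToArray() may need to size the result
    // (newer runtimes special-case array sources, older ones grow-and-copy).
    nint[] viaLinq = shape.Select(x => (nint)x).ToArray();

    // After: one destination array, sized up front from shape.Length; no LINQ machinery.
    nint[] viaConvertAll = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));

    Console.WriteLine(viaLinq.SequenceEqual(viaConvertAll)); // True

A plain lambda (x => (nint)x) also satisfies Array.ConvertAll; the explicit Converter<long, nint> simply mirrors the diff.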