From de36f6dae852b973bd54aee3d7125127719f1666 Mon Sep 17 00:00:00 2001 From: Tanner Gooding Date: Fri, 30 Aug 2019 10:35:39 -0700 Subject: [PATCH] Adding basic support for handling vectors in the database loader. (#4138) * Adding basic support for handling vectors in the database loader. * Updating the DatabaseLoaderTests to run against an actual database. * Fixing a variable name to avoid a conflict * Fixing the DatabaseLoaderTests.IrisSdcaMaximumEntropy test to look in the TestModel folder * Adding back a type check, returning NaN as the default for float/double, and removing some dead code. * Fixing up the DatabaseLoader tests to use the Iris TestDatabase * Remove a call to pipeline.Preview * Fix the name of loaderColumns1 to loaderColumns * Fixing up doc comments and removing allocations from the database loader cursor. * Responding to PR feedback and removing dead code. * Mark the DatabaseLoaderTests as Windows specific for now. * Fixing the context writer to not nullref for a null segment * Ensure the Bindings segments is null if the count is zero. * Fix tests to build on netfx, and to be skipped on non-Windows. * Adding a connection timeout to give localdb a chance to initialize --- build/Dependencies.props | 2 + .../DataLoadSave/Database/DatabaseLoader.cs | 271 ++++++++++---- .../Database/DatabaseLoaderCursor.cs | 352 +++++++++++++++++- test/Microsoft.ML.TestFramework/Datasets.cs | 7 + .../Microsoft.ML.Tests/DatabaseLoaderTests.cs | 333 +++++------------ .../Microsoft.ML.Tests.csproj | 2 + 6 files changed, 658 insertions(+), 309 deletions(-) diff --git a/build/Dependencies.props b/build/Dependencies.props index ea5a6a8736..e1338a54e8 100644 --- a/build/Dependencies.props +++ b/build/Dependencies.props @@ -45,9 +45,11 @@ 0.11.3 1.0.0-beta1-63812-02 + 0.0.5-test 0.0.5-test 0.0.11-test 0.0.5-test + 4.6.1 diff --git a/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoader.cs b/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoader.cs index 9f6cbc3c3b..4bff0606be 100644 --- a/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoader.cs +++ b/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoader.cs @@ -5,11 +5,8 @@ using System; using System.Collections.Generic; using System.Data; -using System.Data.Common; using System.Linq; using System.Reflection; -using System.Runtime.CompilerServices; -using System.Text; using Microsoft.ML; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; @@ -133,7 +130,7 @@ internal static DatabaseLoader CreateDatabaseLoader(IHostEnvironment hos if (mappingAttr is object) { var sources = mappingAttr.Sources.Select((source) => Range.FromTextLoaderRange(source)).ToArray(); - column.Source = sources.Single().Min; + column.Source = sources; } InternalDataKind dk; @@ -172,6 +169,52 @@ internal static DatabaseLoader CreateDatabaseLoader(IHostEnvironment hos /// public sealed class Column { + /// + /// Initializes a new instance of the class. + /// + public Column() { } + + /// + /// Initializes a new instance of the class. + /// + /// Name of the column. + /// of the items in the column. + /// Index of the column. + public Column(string name, DbType dbType, int index) + : this(name, dbType, new[] { new Range(index) }) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// Name of the column. + /// of the items in the column. + /// The minimum inclusive index of the column. + /// The maximum-inclusive index of the column. + public Column(string name, DbType dbType, int minIndex, int maxIndex) + : this(name, dbType, new[] { new Range(minIndex, maxIndex) }) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// Name of the column. + /// of the items in the column. + /// Source index range(s) of the column. + /// For a key column, this defines the range of values. + public Column(string name, DbType dbType, Range[] source, KeyCount keyCount = null) + { + Contracts.CheckValue(name, nameof(name)); + Contracts.CheckValue(source, nameof(source)); + + Name = name; + Type = dbType; + Source = source; + KeyCount = keyCount; + } + /// /// Name of the column. /// @@ -185,10 +228,10 @@ public sealed class Column public DbType Type = DbType.Single; /// - /// Source index of the column. + /// Source index range(s) of the column. /// - [Argument(ArgumentType.Multiple, HelpText = "Source index of the column", ShortName = "src")] - public int? Source; + [Argument(ArgumentType.Multiple, HelpText = "Source index range(s) of the column", ShortName = "src")] + public Range[] Source; /// /// For a key column, this defines the range of values. @@ -207,7 +250,7 @@ public Range() { } /// /// A range representing a single value. Will result in a scalar column. /// - /// The index of the field of the text file to read. + /// The index of the field of the table to read. public Range(int index) { Contracts.CheckParam(index >= 0, nameof(index), "Must be non-negative"); @@ -219,20 +262,17 @@ public Range(int index) /// A range representing a set of values. Will result in a vector column. /// /// The minimum inclusive index of the column. - /// The maximum-inclusive index of the column. If null - /// indicates that the should auto-detect the legnth - /// of the lines, and read untill the end. - public Range(int min, int? max) + /// The maximum-inclusive index of the column. + public Range(int min, int max) { Contracts.CheckParam(min >= 0, nameof(min), "Must be non-negative"); - Contracts.CheckParam(!(max < min), nameof(max), "If specified, must be greater than or equal to " + nameof(min)); + Contracts.CheckParam(max >= min, nameof(max), "Must be greater than or equal to " + nameof(min)); Min = min; Max = max; // Note that without the following being set, in the case where there is a single range // where Min == Max, the result will not be a vector valued but a scalar column. ForceVector = true; - AutoEnd = max == null; } /// @@ -242,28 +282,10 @@ public Range(int min, int? max) public int Min; /// - /// The maximum index of the column, inclusive. If - /// indicates that the should auto-detect the legnth - /// of the lines, and read untill the end. - /// If is specified, the field is ignored. + /// The maximum index of the column, inclusive. /// [Argument(ArgumentType.AtMostOnce, HelpText = "Last index in the range")] - public int? Max; - - /// - /// Whether this range extends to the end of the line, but should be a fixed number of items. - /// If is specified, the field is ignored. - /// - [Argument(ArgumentType.AtMostOnce, - HelpText = "This range extends to the end of the line, but should be a fixed number of items", - ShortName = "auto")] - public bool AutoEnd; - - /// - /// Whether this range includes only other indices not specified. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "This range includes only other indices not specified", ShortName = "other")] - public bool AllOther; + public int Max; /// /// Force scalar columns to be treated as vectors of length one. @@ -273,7 +295,8 @@ public Range(int min, int? max) internal static Range FromTextLoaderRange(TextLoader.Range range) { - return new Range(range.Min, range.Max); + Contracts.Assert(range.Max.HasValue); + return new Range(range.Min, range.Max.Value); } } @@ -290,31 +313,102 @@ public sealed class Options public Column[] Columns; } + /// + /// Used as an input column range. + /// + internal readonly struct Segment + { + public readonly int Min; + public readonly int Lim; + public readonly bool ForceVector; + + public Segment(int min, int lim, bool forceVector) + { + Contracts.Assert(0 <= min & min < lim); + Min = min; + Lim = lim; + ForceVector = forceVector; + } + } + /// /// Information for an output column. /// private sealed class ColInfo { public readonly string Name; - public readonly int? SourceIndex; public readonly DataViewType ColType; + public readonly Segment[] Segments; - public ColInfo(string name, int? sourceIndex, DataViewType colType) + // BaseSize is the sum of the sizes of segments. + public readonly int SizeBase; + + private ColInfo(string name, DataViewType colType, Segment[] segs, int sizeBase) { Contracts.AssertNonEmpty(name); - Contracts.Assert(!sourceIndex.HasValue || sourceIndex >= 0); - Contracts.AssertValue(colType); + Contracts.AssertValueOrNull(segs); + Contracts.Assert(sizeBase > 0); Name = name; - SourceIndex = sourceIndex; + Contracts.Assert(colType.GetItemType().GetRawKind() != 0); ColType = colType; + Segments = segs; + SizeBase = sizeBase; + } + + public static ColInfo Create(string name, PrimitiveDataViewType itemType, Segment[] segs, bool user) + { + Contracts.AssertNonEmpty(name); + Contracts.AssertValue(itemType); + Contracts.AssertValueOrNull(segs); + + int size = 0; + DataViewType type = itemType; + + if (segs != null) + { + var order = Utils.GetIdentityPermutation(segs.Length); + Array.Sort(order, (x, y) => segs[x].Min.CompareTo(segs[y].Min)); + + // Check that the segments are disjoint. + for (int i = 1; i < order.Length; i++) + { + int a = order[i - 1]; + int b = order[i]; + Contracts.Assert(segs[a].Min <= segs[b].Min); + if (segs[a].Lim > segs[b].Min) + { + throw user ? + Contracts.ExceptUserArg(nameof(Column.Source), "Intervals specified for column '{0}' overlap", name) : + Contracts.ExceptDecode("Intervals specified for column '{0}' overlap", name); + } + } + + // Note: since we know that the segments don't overlap, we're guaranteed that + // the sum of their sizes doesn't overflow. + for (int i = 0; i < segs.Length; i++) + { + var seg = segs[i]; + size += seg.Lim - seg.Min; + } + Contracts.Assert(size >= segs.Length); + + if (size > 1 || segs[0].ForceVector) + type = new VectorDataViewType(itemType, size); + } + else + { + size++; + } + + return new ColInfo(name, type, segs, size); } } private sealed class Bindings { /// - /// [i] stores the i-th column's name and type. Columns are loaded from the input text file. + /// [i] stores the i-th column's name and type. Columns are loaded from the input database. /// public readonly ColInfo[] Infos; @@ -326,13 +420,6 @@ public Bindings(DatabaseLoader parent, Column[] cols) using (var ch = parent._host.Start("Binding")) { - // Make sure all columns have at least one source range. - foreach (var col in cols) - { - if (col.Source < 0) - throw ch.ExceptUserArg(nameof(Column.Source), "Source column index must be non-negative"); - } - Infos = new ColInfo[cols.Length]; // This dictionary is used only for detecting duplicated column names specified by user. @@ -354,10 +441,34 @@ public Bindings(DatabaseLoader parent, Column[] cols) } else { + ch.CheckUserArg(Enum.IsDefined(typeof(DbType), col.Type), nameof(Column.Type), "Bad item type"); itemType = ColumnTypeExtensions.PrimitiveTypeFromType(col.Type.ToType()); } - Infos[iinfo] = new ColInfo(name, col.Source, itemType); + Segment[] segs = null; + + if (col.Source != null) + { + segs = new Segment[col.Source.Length]; + + for (int i = 0; i < segs.Length; i++) + { + var range = col.Source[i]; + + int min = range.Min; + ch.CheckUserArg(0 <= min, nameof(range.Min)); + + Segment seg; + + int max = range.Max; + ch.CheckUserArg(min <= max, nameof(range.Max)); + seg = new Segment(min, max + 1, range.ForceVector); + + segs[i] = seg; + } + } + + Infos[iinfo] = ColInfo.Create(name, itemType, segs, true); nameToInfoIndex[name] = iinfo; } @@ -377,9 +488,11 @@ public Bindings(ModelLoadContext ctx, DatabaseLoader parent) // byte: bool of whether this is a key type // for a key type: // ulong: count for key range - // byte: bool of whether the source index is valid - // for a valid source index: - // int: source index + // int: number of segments + // foreach segment: + // int: min + // int: lim + // byte: force vector (verWrittenCur: verIsVectorSupported) int cinfo = ctx.Reader.ReadInt32(); Contracts.CheckDecode(cinfo > 0); Infos = new ColInfo[cinfo]; @@ -405,14 +518,32 @@ public Bindings(ModelLoadContext ctx, DatabaseLoader parent) else itemType = ColumnTypeExtensions.PrimitiveTypeFromKind(kind); - int? sourceIndex = null; - bool hasSourceIndex = ctx.Reader.ReadBoolByte(); - if (hasSourceIndex) + int cseg = ctx.Reader.ReadInt32(); + + Segment[] segs; + + if (cseg == 0) { - sourceIndex = ctx.Reader.ReadInt32(); + segs = null; + } + else + { + Contracts.CheckDecode(cseg > 0); + segs = new Segment[cseg]; + for (int iseg = 0; iseg < cseg; iseg++) + { + int min = ctx.Reader.ReadInt32(); + int lim = ctx.Reader.ReadInt32(); + Contracts.CheckDecode(0 <= min && min < lim); + bool forceVector = ctx.Reader.ReadBoolByte(); + segs[iseg] = new Segment(min, lim, forceVector); + } } - Infos[iinfo] = new ColInfo(name, sourceIndex, itemType); + // Note that this will throw if the segments are ill-structured, including the case + // of multiple variable segments (since those segments will overlap and overlapping + // segments are illegal). + Infos[iinfo] = ColInfo.Create(name, itemType, segs, false); } OutputSchema = ComputeOutputSchema(); @@ -430,9 +561,11 @@ internal void Save(ModelSaveContext ctx) // byte: bool of whether this is a key type // for a key type: // ulong: count for key range - // byte: bool of whether the source index is valid - // for a valid source index: - // int: source index + // int: number of segments + // foreach segment: + // int: min + // int: lim + // byte: force vector (verWrittenCur: verIsVectorSupported) ctx.Writer.Write(Infos.Length); for (int iinfo = 0; iinfo < Infos.Length; iinfo++) { @@ -445,9 +578,21 @@ internal void Save(ModelSaveContext ctx) ctx.Writer.WriteBoolByte(type is KeyDataViewType); if (type is KeyDataViewType key) ctx.Writer.Write(key.Count); - ctx.Writer.WriteBoolByte(info.SourceIndex.HasValue); - if (info.SourceIndex.HasValue) - ctx.Writer.Write(info.SourceIndex.GetValueOrDefault()); + + if (info.Segments is null) + { + ctx.Writer.Write(0); + } + else + { + ctx.Writer.Write(info.Segments.Length); + foreach (var seg in info.Segments) + { + ctx.Writer.Write(seg.Min); + ctx.Writer.Write(seg.Lim); + ctx.Writer.WriteBoolByte(seg.ForceVector); + } + } } } diff --git a/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoaderCursor.cs b/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoaderCursor.cs index 4ab48e5dfb..e6b87dfffa 100644 --- a/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoaderCursor.cs +++ b/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoaderCursor.cs @@ -225,6 +225,58 @@ private Delegate CreateGetterDelegate(int col) { getterDelegate = CreateUInt64GetterDelegate(colInfo); } + else if (typeof(TValue) == typeof(VBuffer)) + { + getterDelegate = CreateVBufferBooleanGetterDelegate(colInfo); + } + else if (typeof(TValue) == typeof(VBuffer)) + { + getterDelegate = CreateVBufferByteGetterDelegate(colInfo); + } + else if (typeof(TValue) == typeof(VBuffer)) + { + getterDelegate = CreateVBufferDateTimeGetterDelegate(colInfo); + } + else if (typeof(TValue) == typeof(VBuffer)) + { + getterDelegate = CreateVBufferDoubleGetterDelegate(colInfo); + } + else if (typeof(TValue) == typeof(VBuffer)) + { + getterDelegate = CreateVBufferInt16GetterDelegate(colInfo); + } + else if (typeof(TValue) == typeof(VBuffer)) + { + getterDelegate = CreateVBufferInt32GetterDelegate(colInfo); + } + else if (typeof(TValue) == typeof(VBuffer)) + { + getterDelegate = CreateVBufferInt64GetterDelegate(colInfo); + } + else if (typeof(TValue) == typeof(VBuffer)) + { + getterDelegate = CreateVBufferSByteGetterDelegate(colInfo); + } + else if (typeof(TValue) == typeof(VBuffer)) + { + getterDelegate = CreateVBufferSingleGetterDelegate(colInfo); + } + else if (typeof(TValue) == typeof(VBuffer>)) + { + getterDelegate = CreateVBufferStringGetterDelegate(colInfo); + } + else if (typeof(TValue) == typeof(VBuffer)) + { + getterDelegate = CreateVBufferUInt16GetterDelegate(colInfo); + } + else if (typeof(TValue) == typeof(VBuffer)) + { + getterDelegate = CreateVBufferUInt32GetterDelegate(colInfo); + } + else if (typeof(TValue) == typeof(VBuffer)) + { + getterDelegate = CreateVBufferUInt64GetterDelegate(colInfo); + } else { throw new NotSupportedException(); @@ -311,9 +363,307 @@ private ValueGetter CreateUInt64GetterDelegate(ColInfo colInfo) return (ref ulong value) => value = DataReader.IsDBNull(columnIndex) ? default : (ulong)DataReader.GetInt64(columnIndex); } + private ValueGetter> CreateVBufferBooleanGetterDelegate(ColInfo colInfo) + { + return (ref VBuffer value) => + { + int length = colInfo.SizeBase; + var editor = VBufferEditor.Create(ref value, length); + + int i = 0; + var segs = colInfo.Segments; + + foreach (var seg in segs) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetBoolean(columnIndex); + } + } + + value = editor.Commit(); + }; + } + + private ValueGetter> CreateVBufferByteGetterDelegate(ColInfo colInfo) + { + return (ref VBuffer value) => + { + int length = colInfo.SizeBase; + var editor = VBufferEditor.Create(ref value, length); + + int i = 0; + var segs = colInfo.Segments; + + foreach (var seg in segs) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetByte(columnIndex); + } + } + + value = editor.Commit(); + }; + } + + private ValueGetter> CreateVBufferDateTimeGetterDelegate(ColInfo colInfo) + { + return (ref VBuffer value) => + { + int length = colInfo.SizeBase; + var editor = VBufferEditor.Create(ref value, length); + + int i = 0; + var segs = colInfo.Segments; + + foreach (var seg in segs) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetDateTime(columnIndex); + } + } + + value = editor.Commit(); + }; + } + + private ValueGetter> CreateVBufferDoubleGetterDelegate(ColInfo colInfo) + { + return (ref VBuffer value) => + { + int length = colInfo.SizeBase; + var editor = VBufferEditor.Create(ref value, length); + + int i = 0; + var segs = colInfo.Segments; + + foreach (var seg in segs) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? double.NaN : DataReader.GetDouble(columnIndex); + } + } + + value = editor.Commit(); + }; + } + + private ValueGetter> CreateVBufferInt16GetterDelegate(ColInfo colInfo) + { + return (ref VBuffer value) => + { + int length = colInfo.SizeBase; + var editor = VBufferEditor.Create(ref value, length); + + int i = 0; + var segs = colInfo.Segments; + + foreach (var seg in segs) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetInt16(columnIndex); + } + } + + value = editor.Commit(); + }; + } + + private ValueGetter> CreateVBufferInt32GetterDelegate(ColInfo colInfo) + { + return (ref VBuffer value) => + { + int length = colInfo.SizeBase; + var editor = VBufferEditor.Create(ref value, length); + + int i = 0; + var segs = colInfo.Segments; + + foreach (var seg in segs) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetInt32(columnIndex); + } + } + + value = editor.Commit(); + }; + } + + private ValueGetter> CreateVBufferInt64GetterDelegate(ColInfo colInfo) + { + return (ref VBuffer value) => + { + int length = colInfo.SizeBase; + var editor = VBufferEditor.Create(ref value, length); + + int i = 0; + var segs = colInfo.Segments; + + foreach (var seg in segs) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetInt64(columnIndex); + } + } + + value = editor.Commit(); + }; + } + + private ValueGetter> CreateVBufferSByteGetterDelegate(ColInfo colInfo) + { + return (ref VBuffer value) => + { + int length = colInfo.SizeBase; + var editor = VBufferEditor.Create(ref value, length); + + int i = 0; + var segs = colInfo.Segments; + + foreach (var seg in segs) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : (sbyte)DataReader.GetByte(columnIndex); + } + } + + value = editor.Commit(); + }; + } + + private ValueGetter> CreateVBufferSingleGetterDelegate(ColInfo colInfo) + { + return (ref VBuffer value) => + { + int length = colInfo.SizeBase; + var editor = VBufferEditor.Create(ref value, length); + + int i = 0; + var segs = colInfo.Segments; + + foreach (var seg in segs) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? float.NaN : DataReader.GetFloat(columnIndex); + } + } + + value = editor.Commit(); + }; + } + + private ValueGetter>> CreateVBufferStringGetterDelegate(ColInfo colInfo) + { + return (ref VBuffer> value) => + { + int length = colInfo.SizeBase; + var editor = VBufferEditor.Create(ref value, length); + + int i = 0; + var segs = colInfo.Segments; + + foreach (var seg in segs) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetString(columnIndex).AsMemory(); + } + } + + value = editor.Commit(); + }; + } + + private ValueGetter> CreateVBufferUInt16GetterDelegate(ColInfo colInfo) + { + return (ref VBuffer value) => + { + int length = colInfo.SizeBase; + var editor = VBufferEditor.Create(ref value, length); + + int i = 0; + var segs = colInfo.Segments; + + foreach (var seg in segs) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : (ushort)DataReader.GetInt16(columnIndex); + } + } + + value = editor.Commit(); + }; + } + + private ValueGetter> CreateVBufferUInt32GetterDelegate(ColInfo colInfo) + { + return (ref VBuffer value) => + { + int length = colInfo.SizeBase; + var editor = VBufferEditor.Create(ref value, length); + + int i = 0; + var segs = colInfo.Segments; + + foreach (var seg in segs) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : (uint)DataReader.GetInt32(columnIndex); + } + } + + value = editor.Commit(); + }; + } + + private ValueGetter> CreateVBufferUInt64GetterDelegate(ColInfo colInfo) + { + return (ref VBuffer value) => + { + int length = colInfo.SizeBase; + var editor = VBufferEditor.Create(ref value, length); + + int i = 0; + var segs = colInfo.Segments; + + foreach (var seg in segs) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : (ulong)DataReader.GetInt64(columnIndex); + } + } + + value = editor.Commit(); + }; + } + private int GetColumnIndex(ColInfo colInfo) { - return colInfo.SourceIndex ?? DataReader.GetOrdinal(colInfo.Name); + var segs = colInfo.Segments; + + if (segs is null) + { + return DataReader.GetOrdinal(colInfo.Name); + } + + Contracts.Check(segs.Length == 1); + + var seg = segs[0]; + Contracts.Check(seg.Min == seg.Lim); + + return seg.Min; } } } diff --git a/test/Microsoft.ML.TestFramework/Datasets.cs b/test/Microsoft.ML.TestFramework/Datasets.cs index a0aa1d7dac..2a0dc6e6cf 100644 --- a/test/Microsoft.ML.TestFramework/Datasets.cs +++ b/test/Microsoft.ML.TestFramework/Datasets.cs @@ -462,6 +462,13 @@ public static class TestDatasets mamlExtraSettings = new[] { "xf=Term{col=Label}" } }; + public static TestDataset irisDb = new TestDataset() + { + name = "iris", + trainFilename = @"iris-train", + testFilename = @"iris-test", + }; + public static TestDataset irisMissing = new TestDataset() { name = "irisMissing", diff --git a/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs b/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs index eb6dc45329..e28ebd9813 100644 --- a/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs +++ b/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs @@ -2,11 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections; using System.Data; -using System.Data.Common; -using System.Linq; +using System.Data.SqlClient; +using System.IO; +using System.Runtime.InteropServices; using Microsoft.ML.Data; using Microsoft.ML.RunTests; using Microsoft.ML.TestFramework; @@ -26,9 +25,16 @@ public DatabaseLoaderTests(ITestOutputHelper output) [LightGBMFact] public void IrisLightGbm() { + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + // https://github.com/dotnet/machinelearning/issues/4156 + return; + } + var mlContext = new MLContext(seed: 1); - var connectionString = GetDataPath(TestDatasets.iris.trainFilename); - var commandText = "Label;SepalLength;SepalWidth;PetalLength;PetalWidth"; + + var connectionString = GetConnectionString(TestDatasets.irisDb.name); + var commandText = $@"SELECT * FROM ""{TestDatasets.irisDb.trainFilename}"""; var loaderColumns = new DatabaseLoader.Column[] { @@ -41,12 +47,11 @@ public void IrisLightGbm() var loader = mlContext.Data.CreateDatabaseLoader(loaderColumns); - var mockProviderFactory = new MockProviderFactory(mlContext, loader); - var databaseSource = new DatabaseSource(mockProviderFactory, connectionString, commandText); + var databaseSource = new DatabaseSource(SqlClientFactory.Instance, connectionString, commandText); var trainingData = loader.Load(databaseSource); - var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label") + IEstimator pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label") .Append(mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")) .Append(mlContext.MulticlassClassification.Trainers.LightGbm()) .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel")); @@ -72,17 +77,65 @@ public void IrisLightGbm() }).PredictedLabel); } + [LightGBMFact] + public void IrisVectorLightGbm() + { + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + // https://github.com/dotnet/machinelearning/issues/4156 + return; + } + + var mlContext = new MLContext(seed: 1); + + var connectionString = GetConnectionString(TestDatasets.irisDb.name); + var commandText = $@"SELECT * FROM ""{TestDatasets.irisDb.trainFilename}"""; + + var loader = mlContext.Data.CreateDatabaseLoader(); + + var databaseSource = new DatabaseSource(SqlClientFactory.Instance, connectionString, commandText); + + var trainingData = loader.Load(databaseSource); + + IEstimator pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label") + .Append(mlContext.Transforms.Concatenate("Features", "SepalInfo", "PetalInfo")) + .Append(mlContext.MulticlassClassification.Trainers.LightGbm()) + .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel")); + + var model = pipeline.Fit(trainingData); + + var engine = mlContext.Model.CreatePredictionEngine(model); + + Assert.Equal(0, engine.Predict(new IrisVectorData() + { + SepalInfo = new float[] { 4.5f, 5.6f }, + PetalInfo = new float[] { 0.5f, 0.5f }, + }).PredictedLabel); + + Assert.Equal(1, engine.Predict(new IrisVectorData() + { + SepalInfo = new float[] { 4.9f, 2.4f }, + PetalInfo = new float[] { 3.3f, 1.0f }, + }).PredictedLabel); + } + [Fact] public void IrisSdcaMaximumEntropy() { + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + // https://github.com/dotnet/machinelearning/issues/4156 + return; + } + var mlContext = new MLContext(seed: 1); - var connectionString = GetDataPath(TestDatasets.iris.trainFilename); - var commandText = "Label;SepalLength;SepalWidth;PetalLength;PetalWidth"; + + var connectionString = GetConnectionString(TestDatasets.irisDb.name); + var commandText = $@"SELECT * FROM ""{TestDatasets.irisDb.trainFilename}"""; var loader = mlContext.Data.CreateDatabaseLoader(); - var mockProviderFactory = new MockProviderFactory(mlContext, loader); - var databaseSource = new DatabaseSource(mockProviderFactory, connectionString, commandText); + var databaseSource = new DatabaseSource(SqlClientFactory.Instance, connectionString, commandText); var trainingData = loader.Load(databaseSource); @@ -112,258 +165,48 @@ public void IrisSdcaMaximumEntropy() }).PredictedLabel); } - public class IrisData - { - public int Label; - - public float SepalLength; - - public float SepalWidth; - - public float PetalLength; - - public float PetalWidth; - } - - public class IrisPrediction - { - public int PredictedLabel; - public float[] Score; - } - } - - internal sealed class MockProviderFactory : DbProviderFactory - { - private MLContext _context; - private DatabaseLoader _databaseLoader; - - public MockProviderFactory(MLContext context, DatabaseLoader databaseLoader) - { - _context = context; - _databaseLoader = databaseLoader; - } - - public override DbConnection CreateConnection() => new MockConnection(_context, _databaseLoader); - } - - internal sealed class MockConnection : DbConnection - { - private string _dataPath; - private TextLoader _textLoader; - - public MockConnection(MLContext context, DatabaseLoader databaseLoader) + private string GetTestDatabasePath(string databaseName) { - var outputSchema = databaseLoader.GetOutputSchema(); - var readerColumns = new TextLoader.Column[outputSchema.Count]; - - for (int i = 0; i < outputSchema.Count; i++) - { - var column = outputSchema[i]; - var columnType = column.Type.RawType; - - Assert.True(columnType.TryGetDataKind(out var internalDataKind)); - readerColumns[i] = new TextLoader.Column(column.Name, internalDataKind.ToDataKind(), i); - } - - _textLoader = context.Data.CreateTextLoader(readerColumns); + return Path.GetFullPath(Path.Combine("TestDatabases", $"{databaseName}.mdf")); } - public override string ConnectionString + private string GetConnectionString(string databaseName) { - get - { - return _dataPath; - } - - set - { - _dataPath = value; - } + var databaseFile = GetTestDatabasePath(databaseName); + return $@"Data Source=(LocalDB)\MSSQLLocalDB;AttachDbFilename={databaseFile};Database={databaseName};Integrated Security=True;Connect Timeout=120"; } - public override string Database => throw new NotImplementedException(); - - public override string DataSource => throw new NotImplementedException(); - - public IDataView DataView { get; private set; } - - public override string ServerVersion => throw new NotImplementedException(); - - public override ConnectionState State => throw new NotImplementedException(); - - public override void ChangeDatabase(string databaseName) => throw new NotImplementedException(); - - public override void Close() => throw new NotImplementedException(); - - public override void Open() - { - DataView = _textLoader.Load(_dataPath); - } - - protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) => throw new NotImplementedException(); - - protected override DbCommand CreateDbCommand() => new MockCommand(this); - } - - internal sealed class MockCommand : DbCommand - { - public MockCommand(MockConnection connection) - { - CommandText = string.Empty; - Connection = connection; - } - - public override string CommandText { get; set; } - - public override int CommandTimeout + public class IrisData { - get => throw new NotImplementedException(); - set => throw new NotImplementedException(); - } + public int Label; - public override CommandType CommandType - { - get - { - throw new NotImplementedException(); - } + public float SepalLength; - set - { - throw new NotImplementedException(); - } - } + public float SepalWidth; - public override bool DesignTimeVisible - { - get => throw new NotImplementedException(); - set => throw new NotImplementedException(); - } + public float PetalLength; - public override UpdateRowSource UpdatedRowSource - { - get => throw new NotImplementedException(); - set => throw new NotImplementedException(); + public float PetalWidth; } - protected override DbConnection DbConnection { get; set; } - - protected override DbParameterCollection DbParameterCollection => throw new NotImplementedException(); - - protected override DbTransaction DbTransaction + public class IrisVectorData { - get => throw new NotImplementedException(); - set => throw new NotImplementedException(); - } - - public override void Cancel() => throw new NotImplementedException(); - - public override int ExecuteNonQuery() => throw new NotImplementedException(); - - public override object ExecuteScalar() => throw new NotImplementedException(); - - public override void Prepare() => throw new NotImplementedException(); - - protected override DbParameter CreateDbParameter() => throw new NotImplementedException(); - - protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) => new MockDbDataReader(this); - } - - internal sealed class MockDbDataReader : DbDataReader - { - private MockCommand _command; - private DataViewRowCursor _rowCursor; - private IDataView _dataView; - - public MockDbDataReader(MockCommand command) - { - _command = command; - - var connection = (MockConnection)_command.Connection; - _dataView = connection.DataView; - - var inputColumns = _dataView.Schema.Where((column) => - { - var inputColumnNames = command.CommandText.Split(';'); - return inputColumnNames.Any((columnName) => column.Name.Equals(column.Name)); - }); - _rowCursor = _dataView.GetRowCursor(inputColumns); - } - - public override object this[int ordinal] => throw new NotImplementedException(); - - public override object this[string name] => throw new NotImplementedException(); - - public override int Depth => throw new NotImplementedException(); - - public override int FieldCount => throw new NotImplementedException(); - - public override bool HasRows => throw new NotImplementedException(); - - public override bool IsClosed => throw new NotImplementedException(); - - public override int RecordsAffected => throw new NotImplementedException(); - - public override bool GetBoolean(int ordinal) => throw new NotImplementedException(); - - public override byte GetByte(int ordinal) => throw new NotImplementedException(); - - public override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length) => throw new NotImplementedException(); - - public override char GetChar(int ordinal) => throw new NotImplementedException(); - - public override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length) => throw new NotImplementedException(); - - public override string GetDataTypeName(int ordinal) => throw new NotImplementedException(); - - public override DateTime GetDateTime(int ordinal) => throw new NotImplementedException(); - - public override decimal GetDecimal(int ordinal) => throw new NotImplementedException(); - - public override double GetDouble(int ordinal) => throw new NotImplementedException(); - - public override IEnumerator GetEnumerator() => throw new NotImplementedException(); + public int Label; - public override Type GetFieldType(int ordinal) => throw new NotImplementedException(); + [LoadColumn(1, 2)] + [VectorType(2)] + public float[] SepalInfo; - public override float GetFloat(int ordinal) - { - float result = 0; - _rowCursor.GetGetter(_dataView.Schema[ordinal])(ref result); - return result; + [LoadColumn(3, 4)] + [VectorType(2)] + public float[] PetalInfo; } - public override Guid GetGuid(int ordinal) => throw new NotImplementedException(); - - public override short GetInt16(int ordinal) => throw new NotImplementedException(); - - public override int GetInt32(int ordinal) + public class IrisPrediction { - int result = 0; - _rowCursor.GetGetter(_dataView.Schema[ordinal])(ref result); - return result; - } - - public override long GetInt64(int ordinal) => throw new NotImplementedException(); - - public override string GetName(int ordinal) => throw new NotImplementedException(); + public int PredictedLabel; - public override int GetOrdinal(string name) - { - var connection = (MockConnection)_command.Connection; - return connection.DataView.Schema.TryGetColumnIndex(name, out int ordinal) ? ordinal : -1; + public float[] Score; } - - public override string GetString(int ordinal) => throw new NotImplementedException(); - - public override object GetValue(int ordinal) => throw new NotImplementedException(); - - public override int GetValues(object[] values) => throw new NotImplementedException(); - - public override bool IsDBNull(int ordinal) => false; - - public override bool NextResult() => throw new NotImplementedException(); - - public override bool Read() => _rowCursor.MoveNext(); } } diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index 5712c755e4..fd4db7b3e4 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -50,6 +50,8 @@ + +