diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 0000000000..509e5fce93
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,5 @@
+root = true
+
+[*.cs]
+# Sort using directives with System.* appearing first
+dotnet_sort_system_directives_first = true
\ No newline at end of file
diff --git a/.vsts-dotnet-ci.yml b/.vsts-dotnet-ci.yml
index 2699c9b7b3..ff6a0d6bec 100644
--- a/.vsts-dotnet-ci.yml
+++ b/.vsts-dotnet-ci.yml
@@ -23,6 +23,20 @@ phases:
queue:
name: Hosted VS2017
+- template: /build/ci/phase-template.yml
+ parameters:
+ name: core30
+ buildScript: build.cmd
+ customMatrixes:
+ Build_Debug_Intrinsics:
+ _configuration: Debug-Intrinsics
+ _config_short: DI
+ Build_Release_Intrinsics:
+ _configuration: Release-Intrinsics
+ _config_short: RI
+ queue:
+ name: Hosted VS2017
+
- template: /build/ci/phase-template.yml
parameters:
name: Windows_x86
diff --git a/DotnetCLIVersion.netcoreapp.latest.txt b/DotnetCLIVersion.netcoreapp.latest.txt
new file mode 100644
index 0000000000..23f4ef9a4c
--- /dev/null
+++ b/DotnetCLIVersion.netcoreapp.latest.txt
@@ -0,0 +1 @@
+3.0.100-alpha1-009622
\ No newline at end of file
diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln
index 1b7d335955..b5e7e1b5e3 100644
--- a/Microsoft.ML.sln
+++ b/Microsoft.ML.sln
@@ -29,8 +29,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.KMeansClusteri
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.PCA", "src\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj", "{58E06735-1129-4DD5-86E0-6BBFF049AAD9}"
EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Api", "src\Microsoft.ML.Api\Microsoft.ML.Api.csproj", "{2F636A2C-062C-49F4-85F3-60DCADAB6A43}"
-EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Tests", "test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj", "{64BC22D3-1E76-41EF-94D8-C79E471FF2DD}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TestFramework", "test\Microsoft.ML.TestFramework\Microsoft.ML.TestFramework.csproj", "{B5989C06-4FFA-46C1-9D85-9366B34AB0A2}"
@@ -139,6 +137,18 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.DnnImageFeatur
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.DnnImageFeaturizer.ResNet101", "src\Microsoft.ML.DnnImageFeaturizer.ResNet101\Microsoft.ML.DnnImageFeaturizer.ResNet101.csproj", "{DB7CEB5E-8BE6-48A7-87BE-B91D9AE96F71}"
EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.EntryPoints", "src\Microsoft.ML.EntryPoints\Microsoft.ML.EntryPoints.csproj", "{7504D46F-E4B3-43CB-9B1C-82F3131F1C99}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StaticPipe", "src\Microsoft.ML.StaticPipe\Microsoft.ML.StaticPipe.csproj", "{6B1B93D0-142A-4111-A20E-62B55A3E36A3}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TensorFlow.StaticPipe", "src\Microsoft.ML.TensorFlow.StaticPipe\Microsoft.ML.TensorFlow.StaticPipe.csproj", "{F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.HalLearners.StaticPipe", "src\Microsoft.ML.HalLearners.StaticPipe\Microsoft.ML.HalLearners.StaticPipe.csproj", "{2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.OnnxTransform.StaticPipe", "src\Microsoft.ML.OnnxTransform.StaticPipe\Microsoft.ML.OnnxTransform.StaticPipe.csproj", "{D1324668-9568-40F4-AA55-30A9A516C230}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.LightGBM.StaticPipe", "src\Microsoft.ML.LightGBM.StaticPipe\Microsoft.ML.LightGBM.StaticPipe.csproj", "{22C51B08-ACAE-47B2-A312-462DC239A23B}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -227,14 +237,6 @@ Global
{58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release|Any CPU.Build.0 = Release|Any CPU
{58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
{58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
- {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
- {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug|Any CPU.Build.0 = Debug|Any CPU
- {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
- {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
- {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Release|Any CPU.ActiveCfg = Release|Any CPU
- {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Release|Any CPU.Build.0 = Release|Any CPU
- {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
- {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
{64BC22D3-1E76-41EF-94D8-C79E471FF2DD}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{64BC22D3-1E76-41EF-94D8-C79E471FF2DD}.Debug|Any CPU.Build.0 = Debug|Any CPU
{64BC22D3-1E76-41EF-94D8-C79E471FF2DD}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
@@ -531,6 +533,54 @@ Global
{DB7CEB5E-8BE6-48A7-87BE-B91D9AE96F71}.Release|Any CPU.Build.0 = Release|Any CPU
{DB7CEB5E-8BE6-48A7-87BE-B91D9AE96F71}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
{DB7CEB5E-8BE6-48A7-87BE-B91D9AE96F71}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
+ {7504D46F-E4B3-43CB-9B1C-82F3131F1C99}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {7504D46F-E4B3-43CB-9B1C-82F3131F1C99}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {7504D46F-E4B3-43CB-9B1C-82F3131F1C99}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
+ {7504D46F-E4B3-43CB-9B1C-82F3131F1C99}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
+ {7504D46F-E4B3-43CB-9B1C-82F3131F1C99}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {7504D46F-E4B3-43CB-9B1C-82F3131F1C99}.Release|Any CPU.Build.0 = Release|Any CPU
+ {7504D46F-E4B3-43CB-9B1C-82F3131F1C99}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
+ {7504D46F-E4B3-43CB-9B1C-82F3131F1C99}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
+ {6B1B93D0-142A-4111-A20E-62B55A3E36A3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {6B1B93D0-142A-4111-A20E-62B55A3E36A3}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {6B1B93D0-142A-4111-A20E-62B55A3E36A3}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
+ {6B1B93D0-142A-4111-A20E-62B55A3E36A3}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
+ {6B1B93D0-142A-4111-A20E-62B55A3E36A3}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {6B1B93D0-142A-4111-A20E-62B55A3E36A3}.Release|Any CPU.Build.0 = Release|Any CPU
+ {6B1B93D0-142A-4111-A20E-62B55A3E36A3}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
+ {6B1B93D0-142A-4111-A20E-62B55A3E36A3}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
+ {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
+ {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
+ {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}.Release|Any CPU.Build.0 = Release|Any CPU
+ {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
+ {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
+ {2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
+ {2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
+ {2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}.Release|Any CPU.Build.0 = Release|Any CPU
+ {2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
+ {2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
+ {D1324668-9568-40F4-AA55-30A9A516C230}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {D1324668-9568-40F4-AA55-30A9A516C230}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {D1324668-9568-40F4-AA55-30A9A516C230}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
+ {D1324668-9568-40F4-AA55-30A9A516C230}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
+ {D1324668-9568-40F4-AA55-30A9A516C230}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {D1324668-9568-40F4-AA55-30A9A516C230}.Release|Any CPU.Build.0 = Release|Any CPU
+ {D1324668-9568-40F4-AA55-30A9A516C230}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
+ {D1324668-9568-40F4-AA55-30A9A516C230}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
+ {22C51B08-ACAE-47B2-A312-462DC239A23B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {22C51B08-ACAE-47B2-A312-462DC239A23B}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {22C51B08-ACAE-47B2-A312-462DC239A23B}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
+ {22C51B08-ACAE-47B2-A312-462DC239A23B}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
+ {22C51B08-ACAE-47B2-A312-462DC239A23B}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {22C51B08-ACAE-47B2-A312-462DC239A23B}.Release|Any CPU.Build.0 = Release|Any CPU
+ {22C51B08-ACAE-47B2-A312-462DC239A23B}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
+ {22C51B08-ACAE-47B2-A312-462DC239A23B}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -546,7 +596,6 @@ Global
{7288C084-11C0-43BE-AC7F-45DCFEAEEBF6} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{F1CAE3AB-4F86-4BC0-BBA8-C4A58E7E8A4A} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{58E06735-1129-4DD5-86E0-6BBFF049AAD9} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
- {2F636A2C-062C-49F4-85F3-60DCADAB6A43} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{64BC22D3-1E76-41EF-94D8-C79E471FF2DD} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{B5989C06-4FFA-46C1-9D85-9366B34AB0A2} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{6B047E09-39C9-4583-96F3-685D84CA4117} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
@@ -589,6 +638,12 @@ Global
{6C29AA9B-054B-4762-BEA5-D305B932AA80} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{4805129D-78C8-46D4-9519-0AD9B0574D6D} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{DB7CEB5E-8BE6-48A7-87BE-B91D9AE96F71} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
+ {7504D46F-E4B3-43CB-9B1C-82F3131F1C99} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
+ {6B1B93D0-142A-4111-A20E-62B55A3E36A3} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
+ {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
+ {2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
+ {D1324668-9568-40F4-AA55-30A9A516C230} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
+ {22C51B08-ACAE-47B2-A312-462DC239A23B} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
diff --git a/README.md b/README.md
index 15dd419948..38829cd751 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ Along with these ML capabilities, this first release of ML.NET also brings the f
ML.NET runs on Windows, Linux, and macOS using [.NET Core](https://github.com/dotnet/core), or Windows using .NET Framework. 64 bit is supported on all platforms. 32 bit is supported on Windows, except for TensorFlow, LightGBM, and ONNX related functionality.
-The current release is 0.7. Check out the [release notes](docs/release-notes/0.7/release-0.7.md) and [blog post](https://blogs.msdn.microsoft.com/dotnet/2018/11/08/announcing-ml-net-0-7-machine-learning-net/) to see what's new.
+The current release is 0.8. Check out the [release notes](docs/release-notes/0.8/release-0.8.md) and [blog post](https://blogs.msdn.microsoft.com/dotnet/2018/12/02/announcing-ml-net-0-8-machine-learning-for-net/) to see what's new.
First, ensure you have installed [.NET Core 2.1](https://www.microsoft.com/net/learn/get-started) or later. ML.NET also works on the .NET Framework 4.6.1 or later, but 4.7.2 or later is recommended.
diff --git a/build.proj b/build.proj
index de8c507de7..15fea4e309 100644
--- a/build.proj
+++ b/build.proj
@@ -78,7 +78,7 @@
- https://aka.ms/tlc-resources/benchmarks/%(Identity)
+ https://aka.ms/mlnet-resources/benchmarks/%(Identity)$(MSBuildThisFileDirectory)/test/data/external/%(Identity)
diff --git a/build/BranchInfo.props b/build/BranchInfo.props
index 02e98d6b5d..1627429558 100644
--- a/build/BranchInfo.props
+++ b/build/BranchInfo.props
@@ -1,7 +1,7 @@
0
- 8
+ 90preview
diff --git a/build/Dependencies.props b/build/Dependencies.props
index 46304fc56a..8d2f7abdc1 100644
--- a/build/Dependencies.props
+++ b/build/Dependencies.props
@@ -15,7 +15,7 @@
3.5.12.2.1.1
- 1.1.0
+ 0.1.50.0.0.72.1.34.5.0
diff --git a/build/ExternalBenchmarkDataFiles.props b/build/ExternalBenchmarkDataFiles.props
index ad3d350d60..42df4ccd96 100644
--- a/build/ExternalBenchmarkDataFiles.props
+++ b/build/ExternalBenchmarkDataFiles.props
@@ -1,5 +1,6 @@
+
diff --git a/build/ci/phase-template.yml b/build/ci/phase-template.yml
index 9929ba182d..a9d2a35943 100644
--- a/build/ci/phase-template.yml
+++ b/build/ci/phase-template.yml
@@ -3,6 +3,7 @@ parameters:
architecture: x64
buildScript: ''
queue: {}
+ customMatrixes: ''
phases:
- phase: ${{ parameters.name }}
@@ -14,12 +15,15 @@ phases:
timeoutInMinutes: 45
parallel: 99
matrix:
- Build_Debug:
- _configuration: Debug
- _config_short: D
- Build_Release:
- _configuration: Release
- _config_short: R
+ ${{ if eq(parameters.customMatrixes, '') }}:
+ Build_Debug:
+ _configuration: Debug
+ _config_short: D
+ Build_Release:
+ _configuration: Release
+ _config_short: R
+ ${{ if ne(parameters.customMatrixes, '') }}:
+ ${{ insert }}: ${{ parameters.customMatrixes }}
${{ insert }}: ${{ parameters.queue }}
steps:
- script: $(_buildScript) -$(_configuration) -buildArch=$(_arch)
diff --git a/build/vsts-ci.yml b/build/vsts-ci.yml
index d0171921ac..ab77d0ed4a 100644
--- a/build/vsts-ci.yml
+++ b/build/vsts-ci.yml
@@ -24,7 +24,7 @@ phases:
container: LinuxContainer
steps:
# Only build native assets to avoid conflicts.
- - script: ./build.sh -buildNative -$(BuildConfig)
+ - script: ./build.sh -buildNative -$(BuildConfig) -skipRIDAgnosticAssets
displayName: Build
- task: PublishBuildArtifacts@1
@@ -49,7 +49,7 @@ phases:
- agent.os -equals Darwin
steps:
# Only build native assets to avoid conflicts.
- - script: ./build.sh -buildNative -$(BuildConfig)
+ - script: ./build.sh -buildNative -$(BuildConfig) -skipRIDAgnosticAssets
displayName: Build
- task: PublishBuildArtifacts@1
@@ -89,7 +89,7 @@ phases:
condition: and(succeeded(), in(variables._SignType, 'real', 'test'))
# Only build native assets to avoid conflicts.
- - script: ./build.cmd -buildNative -$(BuildConfig) -buildArch=x86
+ - script: ./build.cmd -buildNative -$(BuildConfig) -buildArch=x86 -skipRIDAgnosticAssets
displayName: Build
- task: MSBuild@1
diff --git a/config.json b/config.json
index 8436586e61..99f72f7051 100644
--- a/config.json
+++ b/config.json
@@ -3,7 +3,7 @@
"Configuration": {
"description": "Sets the optimization level for the Build Configuration you want to build.",
"valueType": "property",
- "values": [ "Debug", "Release" ],
+ "values": [ "Debug", "Release", "Debug-Intrinsics", "Release-Intrinsics" ],
"defaultValue": "Debug"
},
"TargetArchitecture": {
@@ -30,6 +30,12 @@
"values": [],
"defaultValue": ""
},
+ "SkipRIDAgnosticAssets": {
+ "description": "Prevents RID agnostic assets in redist from being built.",
+ "valueType": "property",
+ "values": [],
+ "defaultValue": ""
+ },
"MsBuildLogging": {
"description": "MsBuild logging options.",
"valueType": "passThrough",
@@ -94,6 +100,18 @@
"Configuration": "Release"
}
},
+ "debug-intrinsics": {
+ "description": "Sets optimization level to debug for managed build configuration and builds against netcoreapp3.0. (/p:Configuration=Debug-Intrinsics)",
+ "settings": {
+ "Configuration": "Debug-Intrinsics"
+ }
+ },
+ "release-intrinsics": {
+ "description": "Sets optimization level to release for managed build configuration and builds against netcoreapp3.0. (/p:Configuration=Release-Intrinsics)",
+ "settings": {
+ "Configuration": "Release-Intrinsics"
+ }
+ },
"buildArch": {
"description": "Sets the architecture for the native build. (/p:TargetArchitecture=[value])",
"settings": {
@@ -106,6 +124,12 @@
"BuildNative": "default"
}
},
+ "skipRIDAgnosticAssets": {
+ "description": "Avoid building RID agnostic assets in redist.",
+ "settings": {
+ "SkipRIDAgnosticAssets": "default"
+ }
+ },
"buildPackages": {
"description": "Builds the NuGet packages.",
"settings": {
diff --git a/docs/building/netcoreapp3.0-instructions.md b/docs/building/netcoreapp3.0-instructions.md
index 338e06ef83..d65bcba34b 100644
--- a/docs/building/netcoreapp3.0-instructions.md
+++ b/docs/building/netcoreapp3.0-instructions.md
@@ -1,10 +1,8 @@
In order to build ML.NET for .NET Core 3.0, you need to do a few manual steps.
-1. Pick a version of the .NET Core 3.0 SDK you want to use. As of this writing, I'm using `3.0.100-alpha1-009622`. You can get the latest available version from the [dotnet/core-sdk README](https://github.com/dotnet/core-sdk#installers-and-binaries) page.
-2. Change the [DotnetCLIVersion.txt](https://github.com/dotnet/machinelearning/blob/master/DotnetCLIVersion.txt) file to use that version number.
-3. Delete the local `.\Tools\` folder from the root of the repo, to ensure you download the new version.
-4. Run `.\build.cmd -- /p:Configuration=Release-Intrinsics` from the root of the repo.
-5. If you want to build the NuGet packages, `.\build.cmd -buildPackages` after step 4.
+1. Delete the local `.\Tools\` folder from the root of the repo, to ensure you download the new version of the .NET Core SDK.
+2. Run `.\build.cmd -- /p:Configuration=Release-Intrinsics` or `.\build.cmd -Release-Intrinsics` from the root of the repo.
+3. If you want to build the NuGet packages, `.\build.cmd -buildPackages` after step 2.
If you are using Visual Studio, you will need to do the following:
diff --git a/docs/building/unix-instructions.md b/docs/building/unix-instructions.md
index 7efa1c414c..ec7e27e8d4 100644
--- a/docs/building/unix-instructions.md
+++ b/docs/building/unix-instructions.md
@@ -5,6 +5,7 @@ Building ML.NET on Linux and macOS
1. Install the prerequisites ([Linux](#user-content-linux), [macOS](#user-content-macos))
2. Clone the machine learning repo `git clone --recursive https://github.com/dotnet/machinelearning.git`
3. Navigate to the `machinelearning` directory
+4. Run `git submodule update --init` if you have not previously done so
4. Run the build script `./build.sh`
Calling the script `./build.sh` builds both the native and managed code.
diff --git a/docs/code/EntryPoints.md b/docs/code/EntryPoints.md
index dbcc4e6bc9..9fb11a6e99 100644
--- a/docs/code/EntryPoints.md
+++ b/docs/code/EntryPoints.md
@@ -220,12 +220,11 @@ parameter.
## How to create an entry point for an existing ML.NET component
-The steps to take, to create an entry point for an existing ML.NET component, are:
-1. Add the `SignatureEntryPointModule` signature to the `LoadableClass` assembly attribute.
+1. Add a `LoadableClass` assembly attribute with the `SignatureEntryPointModule` signature as shown [here](https://github.com/dotnet/machinelearning/blob/9db16c85888e7163c671543faee6ba1f47015d68/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs#L27).
2. Create a public static method, that:
- a. Takes as input, among others, an object representing the arguments of the component you want to expose.
- b. Initializes and run the components, returning one of the nested classes of `Microsoft.ML.Runtime.EntryPoints.CommonOutputs`
- c. Is annotated with the `TlcModule.EntryPoint` attribute
+ 1. Takes an object representing the arguments of the component you want to expose as shown [here](https://github.com/dotnet/machinelearning/blob/9db16c85888e7163c671543faee6ba1f47015d68/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs#L414)
+ 2. Initializes and runs the component, returning one of the nested classes of [`Microsoft.ML.EntryPoints.CommonOutputs`](https://github.com/dotnet/machinelearning/blob/master/src/Microsoft.ML.Data/EntryPoints/CommonOutputs.cs)
+ 3. Is annotated with the [`TlcModule.EntryPoint`](https://github.com/dotnet/machinelearning/blob/9db16c85888e7163c671543faee6ba1f47015d68/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs#L407) attribute
-Based on the type of entry point being created, there are further conventions on the name of the method, for example, the Trainers entry points are typically called: 'TrainMultiClass', 'TrainBinary' etc, based on the task.
-Look at [OnlineGradientDescent](../../src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs) for an example of a component and its entry point.
\ No newline at end of file
+For an example of a transformer as an entrypoint, see [OneHotVectorizer](https://github.com/dotnet/machinelearning/blob/9db16c85888e7163c671543faee6ba1f47015d68/src/Microsoft.ML.Transforms/OneHotEncoding.cs#L283).
+For a trainer-estimator, see [LogisticRegression](https://github.com/dotnet/machinelearning/blob/9db16c85888e7163c671543faee6ba1f47015d68/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs#L407).
diff --git a/docs/code/MlNetCookBook.md b/docs/code/MlNetCookBook.md
index 6ce3c2c67b..908010c721 100644
--- a/docs/code/MlNetCookBook.md
+++ b/docs/code/MlNetCookBook.md
@@ -95,7 +95,7 @@ This is how you can read this data:
var mlContext = new MLContext();
// Create the reader: define the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(ctx => (
+var reader = mlContext.Data.CreateTextReader(ctx => (
// A boolean column depicting the 'target label'.
IsOver50K: ctx.LoadBool(0),
// Three text columns.
@@ -115,9 +115,7 @@ If the schema of the data is not known at compile time, or too cumbersome, you c
var mlContext = new MLContext();
// Create the reader: define the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(new TextLoader.Arguments
-{
- Column = new[] {
+var reader = mlContext.Data.CreateTextReader(new[] {
// A boolean column depicting the 'label'.
new TextLoader.Column("IsOver50K", DataKind.BL, 0),
// Three text columns.
@@ -126,13 +124,51 @@ var reader = mlContext.Data.TextReader(new TextLoader.Arguments
new TextLoader.Column("MaritalStatus", DataKind.TX, 3)
},
// First line of the file is a header, not a data row.
- HasHeader = true
-});
+ hasHeader: true
+);
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
var data = reader.Read(dataPath);
```
+You can also create a data model class, and read the data based on this type.
+
+```csharp
+// The data model. This type will be used throughout the document.
+private class InspectedRow
+{
+ [LoadColumn(0)]
+ public bool IsOver50K { get; set; }
+
+ [LoadColumn(1)]
+ public string Workclass { get; set; }
+
+ [LoadColumn(2)]
+ public string Education { get; set; }
+
+ [LoadColumn(3)]
+ public string MaritalStatus { get; set; }
+
+    // Note: AllFeatures is declared on InspectedRowWithAllFeatures below.
+}
+
+private class InspectedRowWithAllFeatures : InspectedRow
+{
+ public string[] AllFeatures { get; set; }
+}
+
+// Create a new context for ML.NET operations. It can be used for exception tracking and logging,
+// as a catalog of available operations and as the source of randomness.
+var mlContext = new MLContext();
+
+// Read the data into a data view.
+var data = mlContext.Data.ReadFromTextFile<InspectedRow>(dataPath,
+ // First line of the file is a header, not a data row.
+ hasHeader: true
+);
+
+```
+
## How do I load data from multiple files?
You can again use the `TextLoader`, and specify an array of files to its Read method.
@@ -155,7 +191,7 @@ This is how you can read this data:
var mlContext = new MLContext();
// Create the reader: define the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(ctx => (
+var reader = mlContext.Data.CreateTextReader(ctx => (
// A boolean column depicting the 'target label'.
IsOver50K: ctx.LoadBool(14),
// Three text columns.
@@ -175,19 +211,17 @@ The code is very similar using the dynamic API:
var mlContext = new MLContext();
// Create the reader: define the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(new TextLoader.Arguments
-{
- Column = new[] {
+var reader = mlContext.Data.CreateTextReader(new[] {
// A boolean column depicting the 'label'.
- new TextLoader.Column("IsOver50k", DataKind.BL, 0),
+ new TextLoader.Column("IsOver50K", DataKind.BL, 0),
// Three text columns.
new TextLoader.Column("Workclass", DataKind.TX, 1),
new TextLoader.Column("Education", DataKind.TX, 2),
new TextLoader.Column("MaritalStatus", DataKind.TX, 3)
},
// First line of the file is a header, not a data row.
- HasHeader = true
-});
+ hasHeader: true
+);
var data = reader.Read(exampleFile1, exampleFile2);
```
@@ -211,14 +245,14 @@ Reading this file using `TextLoader`:
var mlContext = new MLContext();
// Create the reader: define the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(ctx => (
+var reader = mlContext.Data.CreateTextReader(ctx => (
// We read the first 11 values as a single float vector.
FeatureVector: ctx.LoadFloat(0, 10),
// Separately, read the target variable.
Target: ctx.LoadFloat(11)
),
// Default separator is tab, but we need a comma.
- separator: ',');
+ separatorChar: ',');
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
@@ -233,19 +267,43 @@ If the schema of the data is not known at compile time, or too cumbersome, you c
var mlContext = new MLContext();
// Create the reader: define the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(new[] {
+var reader = mlContext.Data.CreateTextReader(new[] {
// We read the first 10 values as a single float vector.
- new TextLoader.Column("FeatureVector", DataKind.R4, new[] {new TextLoader.Range(0, 9)}),
+ new TextLoader.Column("FeatureVector", DataKind.R4, new[] {new TextLoader.Range(0, 10)}),
// Separately, read the target variable.
- new TextLoader.Column("Target", DataKind.R4, 10)
+ new TextLoader.Column("Target", DataKind.R4, 11)
},
// Default separator is tab, but we need a comma.
- s => s.Separator = ",");
+ separatorChar: ',');
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
var data = reader.Read(dataPath);
```
+Or by creating a data model for it:
+
+```csharp
+private class AdultData
+{
+    [LoadColumn(0, 10), ColumnName("Features")]
+    public float[] FeatureVector { get; set; }
+
+    [LoadColumn(11)]
+    public float Target { get; set; }
+}
+
+// Create a new context for ML.NET operations. It can be used for exception tracking and logging,
+// as a catalog of available operations and as the source of randomness.
+var mlContext = new MLContext();
+
+// Read the data into a data view.
+var data = mlContext.Data.ReadFromTextFile<AdultData>(dataPath,
+    // Default separator is tab, but we need a comma.
+    separatorChar: ','
+);
+
+```
+
## How do I debug my experiment or preview my pipeline?
Most ML.NET operations are 'lazy': they are not actually processing data, they just validate that the operation is possible, and then defer execution until the output data is actually requested. This provides good efficiency, but makes it hard to step through and debug the experiment.
@@ -302,7 +360,7 @@ Label Workclass education marital-status
var mlContext = new MLContext();
// Create the reader: define the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(ctx => (
+var reader = mlContext.Data.CreateTextReader(ctx => (
// A boolean column depicting the 'target label'.
IsOver50K: ctx.LoadBool(0),
// Three text columns.
@@ -329,7 +387,7 @@ var transformedData = dataPipeline.Fit(data).Transform(data);
// 'transformedData' is a 'promise' of data. Let's actually read it.
var someRows = transformedData.AsDynamic
// Convert to an enumerable of user-defined type.
- .AsEnumerable(mlContext, reuseRowObject: false)
+ .AsEnumerable(mlContext, reuseRowObject: false)
// Take a couple values as an array.
.Take(4).ToArray();
@@ -346,54 +404,29 @@ var sameFeatureColumns = dynamicData.GetColumn(mlContext, "AllFeatures
.Take(20).ToArray();
```
-The above code assumes that we defined our `InspectedRow` class as follows:
-```csharp
-private class InspectedRow
-{
- public bool IsOver50K;
- public string Workclass;
- public string Education;
- public string MaritalStatus;
- public string[] AllFeatures;
-}
-```
-
You can also use the dynamic API to create the equivalent of the previous pipeline.
```csharp
// Create a new context for ML.NET operations. It can be used for exception tracking and logging,
// as a catalog of available operations and as the source of randomness.
var mlContext = new MLContext();
-// Create the reader: define the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(new TextLoader.Arguments
-{
- Column = new[] {
- // A boolean column depicting the 'label'.
- new TextLoader.Column("IsOver50k", DataKind.BL, 0),
- // Three text columns.
- new TextLoader.Column("Workclass", DataKind.TX, 1),
- new TextLoader.Column("Education", DataKind.TX, 2),
- new TextLoader.Column("MaritalStatus", DataKind.TX, 3)
- },
+// Read the data into a data view.
+var data = mlContext.Data.ReadFromTextFile(dataPath,
// First line of the file is a header, not a data row.
- HasHeader = true
-});
+ hasHeader: true
+);
// Start creating our processing pipeline. For now, let's just concatenate all the text columns
// together into one.
var dynamicPipeline = mlContext.Transforms.Concatenate("AllFeatures", "Education", "MaritalStatus");
-// Let's verify that the data has been read correctly.
-// First, we read the data file.
-var data = reader.Read(dataPath);
-
// Fit our data pipeline and transform data with it.
var transformedData = dynamicPipeline.Fit(data).Transform(data);
// 'transformedData' is a 'promise' of data. Let's actually read it.
var someRows = transformedData
// Convert to an enumerable of user-defined type.
- .AsEnumerable(mlContext, reuseRowObject: false)
+ .AsEnumerable(mlContext, reuseRowObject: false)
// Take a couple values as an array.
.Take(4).ToArray();
@@ -428,7 +461,7 @@ var mlContext = new MLContext();
// Step one: read the data as an IDataView.
// First, we define the reader: specify the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(ctx => (
+var reader = mlContext.Data.CreateTextReader(ctx => (
// We read the first 11 values as a single float vector.
FeatureVector: ctx.LoadFloat(0, 10),
// Separately, read the target variable.
@@ -437,16 +470,30 @@ var reader = mlContext.Data.TextReader(ctx => (
// The data file has header.
hasHeader: true,
// Default separator is tab, but we need a semicolon.
- separator: ';');
+ separatorChar: ';');
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
var trainData = reader.Read(trainDataPath);
+// Sometimes, caching data in-memory after its first access can save some loading time when the data is going to be used
+// several times somewhere. The caching mechanism is also lazy; it only caches things after being used.
+// Users can replace all the subsequent uses of "trainData" with "cachedTrainData". We still use "trainData" because
+// a caching step, which provides the same caching function, will be inserted in the considered "learningPipeline."
+var cachedTrainData = trainData.Cache();
+
// Step two: define the learning pipeline.
// We 'start' the pipeline with the output of the reader.
var learningPipeline = reader.MakeNewEstimator()
+ // We add a step for caching data in memory so that the downstream iterative training
+ // algorithm can efficiently scan through the data multiple times. Otherwise, the following
+ // trainer will read data from disk multiple times. The caching mechanism uses an on-demand strategy.
+ // The data accessed in any downstream step will be cached since its first use. In general, you only
+ // need to add a caching step before a trainable step, because caching is not helpful if the data is
+ // only scanned once. This step can be removed if the user doesn't have enough memory to store the whole
+ // data set.
+ .AppendCacheCheckpoint()
// Now we can add any 'training steps' to it. In our case we want to 'normalize' the data (rescale to be
// between -1 and 1 for all examples)
.Append(r => (
@@ -468,23 +515,17 @@ var mlContext = new MLContext();
// Step one: read the data as an IDataView.
// First, we define the reader: specify the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(new TextLoader.Arguments
-{
- Column = new[] {
- // We read the first 11 values as a single float vector.
- new TextLoader.Column("FeatureVector", DataKind.R4, 0, 10),
-
- // Separately, read the target variable.
- new TextLoader.Column("Target", DataKind.R4, 11),
- },
+// Read the data into a data view. Remember though, readers are lazy, so the actual reading will happen when the data is accessed.
+var trainData = mlContext.Data.ReadFromTextFile(dataPath,
// First line of the file is a header, not a data row.
- HasHeader = true,
- // Default separator is tab, but we need a semicolon.
- Separator = ";"
-});
+ separatorChar: ','
+);
-// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
-var trainData = reader.Read(trainDataPath);
+// Sometimes, caching data in-memory after its first access can save some loading time when the data is going to be used
+// several times somewhere. The caching mechanism is also lazy; it only caches things after being used.
+// Users can replace all the subsequent uses of "trainData" with "cachedTrainData". We still use "trainData" because
+// a caching step, which provides the same caching function, will be inserted in the considered "dynamicPipeline."
+var cachedTrainData = mlContext.Data.Cache(trainData);
// Step two: define the learning pipeline.
@@ -493,6 +534,15 @@ var dynamicPipeline =
// First 'normalize' the data (rescale to be
// between -1 and 1 for all examples)
mlContext.Transforms.Normalize("FeatureVector")
+ // We add a step for caching data in memory so that the downstream iterative training
+ // algorithm can efficiently scan through the data multiple times. Otherwise, the following
+ // trainer will read data from disk multiple times. The caching mechanism uses an on-demand strategy.
+ // The data accessed in any downstream step will be cached since its first use. In general, you only
+ // need to add a caching step before a trainable step, because caching is not helpful if the data is
+ // only scanned once. This step can be removed if the user doesn't have enough memory to store the whole
+ // data set. Notice that in the upstream Transforms.Normalize step, we only scan through the data
+ // once so adding a caching step before it is not helpful.
+ .AppendCacheCheckpoint(mlContext)
// Add the SDCA regression trainer.
.Append(mlContext.Regression.Trainers.StochasticDualCoordinateAscent(label: "Target", features: "FeatureVector"));
@@ -516,7 +566,10 @@ var metrics = mlContext.Regression.Evaluate(model.Transform(testData), label: r
Calculating the metrics with the dynamic API is as follows.
```csharp
// Read the test dataset.
-var testData = reader.Read(testDataPath);
+var testData = mlContext.Data.ReadFromTextFile(testDataPath,
+ // Default separator is tab, but the dataset has comma.
+ separatorChar: ','
+);
// Calculate metrics of the model on the test data.
var metrics = mlContext.Regression.Evaluate(model.Transform(testData), label: "Target");
```
@@ -563,7 +616,7 @@ Since any ML.NET model is a transformer, you can of course use `model.Transform`
A more typical case, though, is when there is no 'dataset' that we want to predict on, but instead we receive one example at a time. For instance, we run the model as part of the ASP.NET website, and we need to make a prediction for an incoming HTTP request.
-For this case, ML.NET offers a convenient `PredictionFunction` component, that essentially runs one example at a time through the prediction pipeline.
+For this case, ML.NET offers a convenient `PredictionEngine` component, that essentially runs one example at a time through the prediction pipeline.
Here is the full example. Let's imagine that we have built a model for the famous Iris prediction dataset:
@@ -574,7 +627,7 @@ var mlContext = new MLContext();
// Step one: read the data as an IDataView.
// First, we define the reader: specify the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(ctx => (
+var reader = mlContext.Data.CreateTextReader(ctx => (
// The four features of the Iris dataset.
SepalLength: ctx.LoadFloat(0),
SepalWidth: ctx.LoadFloat(1),
@@ -584,7 +637,7 @@ var reader = mlContext.Data.TextReader(ctx => (
Label: ctx.LoadText(4)
),
// Default separator is tab, but the dataset has comma.
- separator: ',');
+ separatorChar: ',');
// Retrieve the training data.
var trainData = reader.Read(irisDataPath);
@@ -595,6 +648,13 @@ var learningPipeline = reader.MakeNewEstimator()
r.Label,
// Concatenate all the features together into one column 'Features'.
Features: r.SepalLength.ConcatWith(r.SepalWidth, r.PetalLength, r.PetalWidth)))
+ // We add a step for caching data in memory so that the downstream iterative training
+ // algorithm can efficiently scan through the data multiple times. Otherwise, the following
+ // trainer will read data from disk multiple times. The caching mechanism uses an on-demand strategy.
+ // The data accessed in any downstream step will be cached since its first use. In general, you only
+ // need to add a caching step before a trainable step, because caching is not helpful if the data is
+ // only scanned once.
+ .AppendCacheCheckpoint()
.Append(r => (
r.Label,
// Train the multi-class SDCA model to predict the label using features.
@@ -616,23 +676,11 @@ You can also use the dynamic API to create the equivalent of the previous pipeli
var mlContext = new MLContext();
// Step one: read the data as an IDataView.
-// First, we define the reader: specify the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(new TextLoader.Arguments
-{
- Column = new[] {
- new TextLoader.Column("SepalLength", DataKind.R4, 0),
- new TextLoader.Column("SepalWidth", DataKind.R4, 1),
- new TextLoader.Column("PetalLength", DataKind.R4, 2),
- new TextLoader.Column("PetalWidth", DataKind.R4, 3),
- // Label: kind of iris.
- new TextLoader.Column("Label", DataKind.TX, 4),
- },
+ // Retrieve the training data.
+var trainData = mlContext.Data.ReadFromTextFile(irisDataPath,
// Default separator is tab, but the dataset has comma.
- Separator = ","
-});
-
-// Retrieve the training data.
-var trainData = reader.Read(irisDataPath);
+ separatorChar: ','
+);
// Build the training pipeline.
var dynamicPipeline =
@@ -640,6 +688,8 @@ var dynamicPipeline =
mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
// Note that the label is text, so it needs to be converted to key.
.Append(mlContext.Transforms.Categorical.MapValueToKey("Label"), TransformerScope.TrainTest)
+ // Cache data in memory for steps after the cache check point stage.
+ .AppendCacheCheckpoint(mlContext)
// Use the multi-class SDCA model to predict the label using features.
.Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent())
// Apply the inverse conversion from 'PredictedLabel' column back to string value.
@@ -679,10 +729,10 @@ var mlContext = new MLContext();
// Make the prediction function object. Note that, on average, this call takes around 200x longer
// than one prediction, so you might want to cache and reuse the prediction function, instead of
// creating one per prediction.
-var predictionFunc = model.MakePredictionFunction(mlContext);
+var predictionFunc = model.CreatePredictionEngine(mlContext);
// Obtain the prediction. Remember that 'Predict' is not reentrant. If you want to use multiple threads
-// for simultaneous prediction, make sure each thread is using its own PredictionFunction.
+// for simultaneous prediction, make sure each thread is using its own PredictionEngine.
var prediction = predictionFunc.Predict(new IrisInput
{
SepalLength = 4.1f,
@@ -741,6 +791,7 @@ var trainData = mlContext.CreateStreamingDataView(churnData);
var dynamicLearningPipeline = mlContext.Transforms.Categorical.OneHotEncoding("DemographicCategory")
.Append(mlContext.Transforms.Concatenate("Features", "DemographicCategory", "LastVisits"))
+ .AppendCacheCheckpoint(mlContext) // FastTree will benefit from caching data in memory.
.Append(mlContext.BinaryClassification.Trainers.FastTree("HasChurned", "Features", numTrees: 20));
var dynamicModel = dynamicLearningPipeline.Fit(trainData);
@@ -757,6 +808,7 @@ var staticLearningPipeline = staticData.MakeNewEstimator()
.Append(r => (
r.HasChurned,
Features: r.DemographicCategory.OneHotEncoding().ConcatWith(r.LastVisits)))
+ .AppendCacheCheckpoint() // FastTree will benefit from caching data in memory.
.Append(r => mlContext.BinaryClassification.Trainers.FastTree(r.HasChurned, r.Features, numTrees: 20));
var staticModel = staticLearningPipeline.Fit(staticData);
@@ -781,7 +833,7 @@ var mlContext = new MLContext();
// Step one: read the data as an IDataView.
// First, we define the reader: specify the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(ctx => (
+var reader = mlContext.Data.CreateTextReader(ctx => (
// The four features of the Iris dataset.
SepalLength: ctx.LoadFloat(0),
SepalWidth: ctx.LoadFloat(1),
@@ -791,7 +843,7 @@ var reader = mlContext.Data.TextReader(ctx => (
Label: ctx.LoadText(4)
),
// Default separator is tab, but the dataset has comma.
- separator: ',');
+ separatorChar: ',');
// Retrieve the training data.
var trainData = reader.Read(dataPath);
@@ -813,6 +865,8 @@ var learningPipeline = reader.MakeNewEstimator()
// When the normalizer is trained, the below delegate is going to be called.
// We use it to memorize the scales.
onFit: (scales, offsets) => normScales = scales)))
+ // Cache data in memory because the subsequent trainer needs to access the data multiple times.
+ .AppendCacheCheckpoint()
.Append(r => (
r.Label,
// Train the multi-class SDCA model to predict the label using features.
@@ -875,14 +929,14 @@ Here's a snippet of code that demonstrates normalization in learning pipelines.
var mlContext = new MLContext();
// Define the reader: specify the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(ctx => (
+var reader = mlContext.Data.CreateTextReader(ctx => (
// The four features of the Iris dataset will be grouped together as one Features column.
Features: ctx.LoadFloat(0, 3),
// Label: kind of iris.
Label: ctx.LoadText(4)
),
// Default separator is tab, but the dataset has comma.
- separator: ',');
+ separatorChar: ',');
// Read the training data.
var trainData = reader.Read(dataPath);
@@ -905,25 +959,26 @@ var meanVarValues = normalizedData.GetColumn(r => r.MeanVarNormalized).ToArray()
You can achieve the same results using the dynamic API.
```csharp
+// Data model for the Iris class.
+private class IrisInputAllFeatures
+{
+ // Unfortunately, we still need the dummy 'Label' column to be present.
+ [ColumnName("Label"), LoadColumn(4)]
+ public string IgnoredLabel { get; set; }
+
+ [LoadColumn(4, loadAllOthers:true)]
+ public float Features { get; set; }
+}
+
// Create a new context for ML.NET operations. It can be used for exception tracking and logging,
// as a catalog of available operations and as the source of randomness.
var mlContext = new MLContext();
-// Define the reader: specify the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(new TextLoader.Arguments
-{
- Column = new[] {
- // The four features of the Iris dataset will be grouped together as one Features column.
- new TextLoader.Column("Features", DataKind.R4, 0, 3),
- // Label: kind of iris.
- new TextLoader.Column("Label", DataKind.TX, 4),
- },
- // Default separator is tab, but the dataset has comma.
- Separator = ","
-});
-
// Read the training data.
-var trainData = reader.Read(dataPath);
+var trainData = mlContext.Data.ReadFromTextFile(dataPath,
+ // Default separator is tab, but the dataset has comma.
+ separatorChar: ','
+);
// Apply all kinds of standard ML.NET normalization to the raw features.
var pipeline =
@@ -969,7 +1024,7 @@ Label Workclass education marital-status occupation relationship ethnicity sex n
var mlContext = new MLContext();
// Define the reader: specify the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(ctx => (
+var reader = mlContext.Data.CreateTextReader(ctx => (
Label: ctx.LoadBool(0),
// We will load all the categorical features into one vector column of size 8.
CategoricalFeatures: ctx.LoadText(1, 8),
@@ -987,6 +1042,10 @@ var catColumns = data.GetColumn(r => r.CategoricalFeatures).Take(10).ToArray();
// Build several alternative featurization pipelines.
var learningPipeline = reader.MakeNewEstimator()
+ // Cache data in memory in an on-demand manner. Columns used in any downstream step will be
+ // cached in memory at their first uses. This step can be removed if user's machine doesn't
+ // have enough memory.
+ .AppendCacheCheckpoint()
.Append(r => (
r.Label,
r.NumericalFeatures,
@@ -1027,9 +1086,8 @@ You can achieve the same results using the dynamic API.
var mlContext = new MLContext();
// Define the reader: specify the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(new TextLoader.Arguments
-{
- Column = new[] {
+var reader = mlContext.Data.CreateTextReader(new[]
+ {
new TextLoader.Column("Label", DataKind.BL, 0),
// We will load all the categorical features into one vector column of size 8.
new TextLoader.Column("CategoricalFeatures", DataKind.TX, 1, 8),
@@ -1038,8 +1096,8 @@ var reader = mlContext.Data.TextReader(new TextLoader.Arguments
// Let's also separately load the 'Workclass' column.
new TextLoader.Column("Workclass", DataKind.TX, 1),
},
- HasHeader = true
-});
+ hasHeader: true
+);
// Read the data.
var data = reader.Read(dataPath);
@@ -1070,6 +1128,9 @@ var workclasses = transformedData.GetColumn(mlContext, "WorkclassOneHot
var fullLearningPipeline = dynamicPipeline
// Concatenate two of the 3 categorical pipelines, and the numeric features.
.Append(mlContext.Transforms.Concatenate("Features", "NumericalFeatures", "CategoricalBag", "WorkclassOneHotTrimmed"))
+ // Cache data in memory so that the following trainer will be able to access training examples without
+ // reading them from disk multiple times.
+ .AppendCacheCheckpoint(mlContext)
// Now we're ready to train. We chose our FastTree trainer for this classification task.
.Append(mlContext.BinaryClassification.Trainers.FastTree(numTrees: 50));
@@ -1108,7 +1169,7 @@ Sentiment SentimentText
var mlContext = new MLContext();
// Define the reader: specify the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(ctx => (
+var reader = mlContext.Data.CreateTextReader(ctx => (
IsToxic: ctx.LoadBool(0),
Message: ctx.LoadText(1)
), hasHeader: true);
@@ -1121,6 +1182,10 @@ var messageTexts = data.GetColumn(x => x.Message).Take(20).ToArray();
// Apply various kinds of text operations supported by ML.NET.
var learningPipeline = reader.MakeNewEstimator()
+ // Cache data in memory in an on-demand manner. Columns used in any downstream step will be
+ // cached in memory at their first uses. This step can be removed if user's machine doesn't
+ // have enough memory.
+ .AppendCacheCheckpoint()
.Append(r => (
// One-stop shop to run the full text featurization.
TextFeatures: r.Message.FeaturizeText(),
@@ -1154,14 +1219,13 @@ You can achieve the same results using the dynamic API.
var mlContext = new MLContext();
// Define the reader: specify the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(new TextLoader.Arguments
-{
- Column = new[] {
+var reader = mlContext.Data.CreateTextReader(new[]
+ {
new TextLoader.Column("IsToxic", DataKind.BL, 0),
new TextLoader.Column("Message", DataKind.TX, 1),
},
- HasHeader = true
-});
+ hasHeader: true
+);
// Read the data.
var data = reader.Read(dataPath);
@@ -1221,7 +1285,7 @@ var mlContext = new MLContext();
// Step one: read the data as an IDataView.
// First, we define the reader: specify the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(ctx => (
+var reader = mlContext.Data.CreateTextReader(ctx => (
// The four features of the Iris dataset.
SepalLength: ctx.LoadFloat(0),
SepalWidth: ctx.LoadFloat(1),
@@ -1231,7 +1295,7 @@ var reader = mlContext.Data.TextReader(ctx => (
Label: ctx.LoadText(4)
),
// Default separator is tab, but the dataset has comma.
- separator: ',');
+ separatorChar: ',');
// Read the data.
var data = reader.Read(dataPath);
@@ -1243,6 +1307,9 @@ var learningPipeline = reader.MakeNewEstimator()
Label: r.Label.ToKey(),
// Concatenate all the features together into one column 'Features'.
Features: r.SepalLength.ConcatWith(r.SepalWidth, r.PetalLength, r.PetalWidth)))
+ // Add a step for caching data in memory so that the downstream iterative training
+ // algorithm can efficiently scan through the data multiple times.
+ .AppendCacheCheckpoint()
.Append(r => (
r.Label,
// Train the multi-class SDCA model to predict the label using features.
@@ -1273,24 +1340,10 @@ You can achieve the same results using the dynamic API.
var mlContext = new MLContext();
// Step one: read the data as an IDataView.
-// First, we define the reader: specify the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(new TextLoader.Arguments
-{
- Column = new[] {
- // We read the first 11 values as a single float vector.
- new TextLoader.Column("SepalLength", DataKind.R4, 0),
- new TextLoader.Column("SepalWidth", DataKind.R4, 1),
- new TextLoader.Column("PetalLength", DataKind.R4, 2),
- new TextLoader.Column("PetalWidth", DataKind.R4, 3),
- // Label: kind of iris.
- new TextLoader.Column("Label", DataKind.TX, 4),
- },
+var data = mlContext.Data.ReadFromTextFile(dataPath,
// Default separator is tab, but the dataset has comma.
- Separator = ","
-});
-
-// Read the data.
-var data = reader.Read(dataPath);
+ separatorChar: ','
+);
// Build the training pipeline.
var dynamicPipeline =
@@ -1298,6 +1351,10 @@ var dynamicPipeline =
mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
// Note that the label is text, so it needs to be converted to key.
.Append(mlContext.Transforms.Conversions.MapValueToKey("Label"), TransformerScope.TrainTest)
+ // Cache data in memory so that SDCA trainer will be able to randomly access training examples without
+ // reading data from disk multiple times. Data will be cached at its first use in any downstream step.
+ // Notice that unused part in the data may not be cached.
+ .AppendCacheCheckpoint(mlContext)
// Use the multi-class SDCA model to predict the label using features.
.Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent());
@@ -1335,7 +1392,7 @@ var mlContext = new MLContext();
// Read the data as an IDataView.
// First, we define the reader: specify the data columns and where to find them in the text file.
-var reader = mlContext.Data.TextReader(ctx => (
+var reader = mlContext.Data.CreateTextReader(ctx => (
// The four features of the Iris dataset.
SepalLength: ctx.LoadFloat(0),
SepalWidth: ctx.LoadFloat(1),
@@ -1345,7 +1402,7 @@ var reader = mlContext.Data.TextReader(ctx => (
Label: ctx.LoadText(4)
),
// Default separator is tab, but the dataset has comma.
- separator: ',');
+ separatorChar: ',');
// Read the data.
var data = reader.Read(dataPath);
@@ -1439,6 +1496,7 @@ public static ITransformer TrainModel(MLContext mlContext, IDataView trainData)
Action mapping = (input, output) => output.Label = input.Income > 50000;
// Construct the learning pipeline.
var estimator = mlContext.Transforms.CustomMapping(mapping, null)
+ .AppendCacheCheckpoint(mlContext)
.Append(mlContext.BinaryClassification.Trainers.FastTree(label: "Label"));
return estimator.Fit(trainData);
@@ -1480,8 +1538,12 @@ public class CustomMappings
var estimator = mlContext.Transforms.CustomMapping(CustomMappings.IncomeMapping, nameof(CustomMappings.IncomeMapping))
.Append(mlContext.BinaryClassification.Trainers.FastTree(label: "Label"));
+// If memory is enough, we can cache the data in-memory to avoid reading them from file
+// when it will be accessed multiple times.
+var cachedTrainData = mlContext.Data.Cache(trainData);
+
// Train the model.
-var model = estimator.Fit(trainData);
+var model = estimator.Fit(cachedTrainData);
// Save the model.
using (var fs = File.Create(modelPath))
diff --git a/docs/code/SchemaComprehension.md b/docs/code/SchemaComprehension.md
index 37238e0c0c..93e14a02db 100644
--- a/docs/code/SchemaComprehension.md
+++ b/docs/code/SchemaComprehension.md
@@ -65,7 +65,7 @@ static void Main(string[] args)
};
// Create the ML.NET environment.
- var env = new Microsoft.ML.Runtime.Data.TlcEnvironment();
+ var env = new Microsoft.ML.Data.TlcEnvironment();
// Create the data view.
// This method will use the definition of IrisData to understand what columns there are in the
@@ -74,7 +74,7 @@ static void Main(string[] args)
// Now let's do something to the data view. For example, concatenate all four non-label columns
// into 'Features' column.
- dv = new Microsoft.ML.Runtime.Data.ConcatTransform(env, dv, "Features",
+ dv = new Microsoft.ML.Data.ConcatTransform(env, dv, "Features",
"SepalLength", "SepalWidth", "PetalLength", "PetalWidth");
// Read the data into an another array, this time we read the 'Features' and 'Label' columns
diff --git a/docs/code/VBufferCareFeeding.md b/docs/code/VBufferCareFeeding.md
index 36c19c3184..91918c5e9b 100644
--- a/docs/code/VBufferCareFeeding.md
+++ b/docs/code/VBufferCareFeeding.md
@@ -224,17 +224,17 @@ ML.NET's runtime code has a number of utilities for operating over `VBuffer`s
that we have written to be generally useful. We will not treat on these in
detail here, but:
-* `Microsoft.ML.Runtime.Data.VBuffer` itself contains a few methods for
+* `Microsoft.ML.Data.VBuffer` itself contains a few methods for
accessing and iterating over its values.
-* `Microsoft.ML.Runtime.Internal.Utilities.VBufferUtils` contains utilities
+* `Microsoft.ML.Internal.Utilities.VBufferUtils` contains utilities
mainly for non-numeric manipulation of `VBuffer`s.
-* `Microsoft.ML.Runtime.Numeric.VectorUtils` contains math operations
+* `Microsoft.ML.Numeric.VectorUtils` contains math operations
over `VBuffer` and `float[]`, like computing norms, dot-products, and
whatnot.
-* `Microsoft.ML.Runtime.Data.BufferBuilder` is an abstract class whose
+* `Microsoft.ML.Data.BufferBuilder` is an abstract class whose
concrete implementations are used throughout ML.NET to build up `VBuffer`
instances. Note that if one *can* simply build a `VBuffer` oneself easily
and do not need the niceties provided by the buffer builder, you should
diff --git a/docs/project-docs/developer-guide.md b/docs/project-docs/developer-guide.md
index 074ea03dab..20e94d4e0c 100644
--- a/docs/project-docs/developer-guide.md
+++ b/docs/project-docs/developer-guide.md
@@ -25,6 +25,12 @@ build -?
**Examples**
+- Initialize the repo to make build possible (if the build fails because it can't find `mf.cpp` then perhaps you missed this step)
+
+```
+git submodule update --init
+```
+
- Building in release mode for platform x64
```
build.cmd -Release -TargetArchitecture:x64
@@ -58,4 +64,5 @@ One can build in Debug or Release mode from the root by doing `build.cmd -Releas
### Building other Architectures
-We only support 64-bit binaries right now.
\ No newline at end of file
+We only support 64-bit binaries right now.
+
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs
new file mode 100644
index 0000000000..78c62668ec
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs
@@ -0,0 +1,113 @@
+using System;
+using System.Linq;
+using Microsoft.ML.Calibrator;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Samples.Dynamic
+{
+ ///
+ /// This example first trains a StochasticDualCoordinateAscentBinary Classifier and then converts its output to a probability by training a calibrator.
+ ///
+ public class CalibratorExample
+ {
+ public static void Calibration()
+ {
+ // Downloading the dataset from github.com/dotnet/machinelearning.
+ // This will create a sentiment.tsv file in the filesystem.
+ // The string, dataFile, is the path to the downloaded file.
+ // You can open this file, if you want to see the data.
+ string dataFile = SamplesUtils.DatasetUtils.DownloadSentimentDataset();
+
+ // A preview of the data.
+ // Sentiment SentimentText
+ // 0 " :Erm, thank you. "
+ // 1 ==You're cool==
+
+ // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
+ // as a catalog of available operations and as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Create a text loader.
+ var reader = mlContext.Data.CreateTextReader(new TextLoader.Arguments()
+ {
+ Separator = "tab",
+ HasHeader = true,
+ Column = new[]
+ {
+ new TextLoader.Column("Sentiment", DataKind.BL, 0),
+ new TextLoader.Column("SentimentText", DataKind.Text, 1)
+ }
+ });
+
+ // Read the data
+ var data = reader.Read(dataFile);
+
+ // Split the dataset into two parts: one used for training, the other to train the calibrator
+ var (trainData, calibratorTrainingData) = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1);
+
+ // Featurize the text column through the FeaturizeText API.
+ // Then append the StochasticDualCoordinateAscentBinary binary classifier, setting the "Label" column as the label of the dataset, and
+ // the "Features" column produced by FeaturizeText as the features column.
+ var pipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features")
+ .Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent(
+ labelColumn: "Sentiment",
+ featureColumn: "Features",
+ l2Const: 0.001f,
+ loss: new HingeLoss())); // By specifying loss: new HingeLoss(), StochasticDualCoordinateAscent will train a support vector machine (SVM).
+
+ // Fit the pipeline, and get a transformer that knows how to score new data.
+ var transformer = pipeline.Fit(trainData);
+ IPredictor model = transformer.LastTransformer.Model;
+
+ // Let's score the new data. The score will give us a numerical estimation of the chance that the particular sample
+ // bears positive sentiment. This estimate is relative to the numbers obtained.
+ var scoredData = transformer.Transform(calibratorTrainingData);
+ var scoredDataPreview = scoredData.Preview();
+
+ PrintRowViewValues(scoredDataPreview);
+ // Preview of scoredDataPreview.RowView
+ //
+ // Score - 0.458968
+ // Score - 0.7022135
+ // Score 1.138822
+ // Score 0.4807112
+ // Score 1.112813
+
+ // Let's train a calibrator estimator on this scored dataset. The trained calibrator estimator produces a transformer
+ // that can transform the scored data by adding a new column named "Probability".
+ var calibratorEstimator = new PlattCalibratorEstimator(mlContext, model, "Sentiment", "Features");
+ var calibratorTransformer = calibratorEstimator.Fit(scoredData);
+
+ // Transform the scored data with a calibrator transformer by adding a new column named "Probability".
+ // This column is a calibrated version of the "Score" column, meaning its values are a valid probability value in the [0, 1] interval
+ // representing the chance that the respective sample bears positive sentiment.
+ var finalData = calibratorTransformer.Transform(scoredData).Preview();
+
+ PrintRowViewValues(finalData);
+
+ //Preview of finalData.RowView
+ //
+ // Score - 0.458968 Probability 0.4670409
+ // Score - 0.7022135 Probability 0.3912723
+ // Score 1.138822 Probability 0.8703266
+ // Score 0.4807112 Probability 0.7437012
+ // Score 1.112813 Probability 0.8665403
+
+ }
+
+ private static void PrintRowViewValues(Data.DataDebuggerPreview data)
+ {
+ var firstRows = data.RowView.Take(5);
+
+ foreach(Data.DataDebuggerPreview.RowInfo row in firstRows)
+ {
+ foreach (var kvPair in row.Values)
+ {
+ if (kvPair.Key.Equals("Score") || kvPair.Key.Equals("Probability"))
+ Console.Write($" {kvPair.Key} {kvPair.Value} ");
+ }
+ Console.WriteLine();
+ }
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ConcatTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ConcatTransform.cs
index ba945f38fa..c14439ee78 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/ConcatTransform.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ConcatTransform.cs
@@ -1,8 +1,7 @@
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.Api;
-using System;
-using System.Linq;
+using System;
using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Data;
using Microsoft.ML.Transforms;
namespace Microsoft.ML.Samples.Dynamic
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs
new file mode 100644
index 0000000000..e747b4fc86
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs
@@ -0,0 +1,50 @@
+using System;
+using System.Collections.Generic;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Samples.Dynamic
+{
+ public class FastTreeRegressionExample
+ {
+ public static void FastTreeRegression()
+ {
+ // Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
+ // as well as the source of randomness.
+ var ml = new MLContext();
+
+ // Get a small dataset as an IEnumerable and convert it to an IDataView.
+ var data = SamplesUtils.DatasetUtils.GetInfertData();
+ var trainData = ml.CreateStreamingDataView(data);
+
+ // Preview of the data.
+ //
+ // Age Case Education Induced Parity PooledStratum RowNum ...
+ // 26 1 0-5yrs 1 6 3 1 ...
+ // 42 1 0-5yrs 1 1 1 2 ...
+ // 39 1 0-5yrs 2 6 4 3 ...
+ // 34 1 0-5yrs 2 4 2 4 ...
+ // 35 1 6-11yrs 1 3 32 5 ...
+
+ // A pipeline for concatenating the Parity and Induced columns together in the Features column.
+ // We will train a FastTreeRegression model with 1 tree on these two columns to predict Age.
+ string outputColumnName = "Features";
+ var pipeline = ml.Transforms.Concatenate(outputColumnName, new[] { "Parity", "Induced" })
+ .Append(ml.Regression.Trainers.FastTree(labelColumn: "Age", featureColumn: outputColumnName, numTrees: 1, numLeaves: 2, minDatapointsInLeaves: 1));
+
+ var model = pipeline.Fit(trainData);
+
+ // Get the trained model parameters.
+ var modelParams = model.LastTransformer.Model;
+
+ // Let's see where an example with Parity = 1 and Induced = 1 would end up in the single trained tree.
+ var testRow = new VBuffer(2, new[] { 1.0f, 1.0f });
+ // Use the path object to pass to GetLeaf, which will populate path with the IDs of the nodes from root to leaf.
+ List path = default;
+ // Get the ID of the leaf this example ends up in tree 0.
+ var leafID = modelParams.GetLeaf(0, in testRow, ref path);
+ // Get the leaf value for this leaf ID in tree 0.
+ var leafValue = modelParams.GetLeafValue(0, leafID);
+ Console.WriteLine("The leaf value in tree 0 is: " + leafValue);
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs
index 1c28df327f..96ef491bb7 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs
@@ -1,6 +1,5 @@
-using Microsoft.ML.Runtime.Api;
-using Microsoft.ML.Runtime.Data;
-using System;
+using System;
+using Microsoft.ML.Data;
namespace Microsoft.ML.Samples.Dynamic
{
@@ -19,11 +18,8 @@ public static void FeatureContributionCalculationTransform_Regression()
// Step 1: Read the data as an IDataView.
// First, we define the reader: specify the data columns and where to find them in the text file.
- var reader = mlContext.Data.TextReader(new TextLoader.Arguments()
- {
- Separator = "tab",
- HasHeader = true,
- Column = new[]
+ var reader = mlContext.Data.CreateTextReader(
+ columns: new[]
{
new TextLoader.Column("MedianHomeValue", DataKind.R4, 0),
new TextLoader.Column("CrimesPerCapita", DataKind.R4, 1),
@@ -37,8 +33,9 @@ public static void FeatureContributionCalculationTransform_Regression()
new TextLoader.Column("HighwayDistance", DataKind.R4, 9),
new TextLoader.Column("TaxRate", DataKind.R4, 10),
new TextLoader.Column("TeacherRatio", DataKind.R4, 11),
- }
- });
+ },
+ hasHeader: true
+ );
// Read the data
var data = reader.Read(dataFile);
@@ -50,22 +47,26 @@ public static void FeatureContributionCalculationTransform_Regression()
var transformPipeline = mlContext.Transforms.Concatenate("Features", "CrimesPerCapita", "PercentResidental",
"PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling", "PercentPre40s",
"EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio");
- var learner = mlContext.Regression.Trainers.StochasticDualCoordinateAscent(
+ var learner = mlContext.Regression.Trainers.OrdinaryLeastSquares(
labelColumn: "MedianHomeValue", featureColumn: "Features");
var transformedData = transformPipeline.Fit(data).Transform(data);
+ // Now we train the model and score it on the transformed data.
var model = learner.Fit(transformedData);
+ var scoredData = model.Transform(transformedData);
// Create a Feature Contribution Calculator
- // Calculate the feature contributions for all features
+ // Calculate the feature contributions for all features given trained model parameters
// And don't normalize the contribution scores
- var args = new FeatureContributionCalculationTransform.Arguments()
- {
- Top = 11,
- Normalize = false
- };
- var featureContributionCalculator = FeatureContributionCalculationTransform.Create(mlContext, args, transformedData, model.Model, model.FeatureColumn);
+ var featureContributionCalculator = mlContext.Model.Explainability.FeatureContributionCalculation(model.Model, model.FeatureColumn, top: 11, normalize: false);
+ var outputData = featureContributionCalculator.Fit(scoredData).Transform(scoredData);
+
+ // FeatureContributionCalculatingEstimator can be used as an intermediary step in a pipeline.
+ // The features retained by FeatureContributionCalculatingEstimator will be in the FeatureContribution column.
+ var pipeline = mlContext.Model.Explainability.FeatureContributionCalculation(model.Model, model.FeatureColumn, top: 11)
+ .Append(mlContext.Regression.Trainers.OrdinaryLeastSquares(featureColumn: "FeatureContributions"));
+ var outData = featureContributionCalculator.Fit(scoredData).Transform(scoredData);
// Let's extract the weights from the linear model to use as a comparison
var weights = new VBuffer<float>();
@@ -73,9 +74,9 @@ public static void FeatureContributionCalculationTransform_Regression()
// Let's now walk through the first ten reconds and see which feature drove the values the most
// Get prediction scores and contributions
- var scoringEnumerator = featureContributionCalculator.AsEnumerable(mlContext, true).GetEnumerator();
+ var scoringEnumerator = outputData.AsEnumerable(mlContext, true).GetEnumerator();
int index = 0;
- Console.WriteLine("Label\tScore\tBiggestFeature\tValue\tWeight\tContribution\tPercent");
+ Console.WriteLine("Label\tScore\tBiggestFeature\tValue\tWeight\tContribution");
while (scoringEnumerator.MoveNext() && index < 10)
{
var row = scoringEnumerator.Current;
@@ -86,26 +87,34 @@ public static void FeatureContributionCalculationTransform_Regression()
// And the corresponding information about the feature
var value = row.Features[featureOfInterest];
var contribution = row.FeatureContributions[featureOfInterest];
- var percentContribution = 100 * contribution / row.Score;
- var name = data.Schema.GetColumnName(featureOfInterest + 1);
+ var name = data.Schema[featureOfInterest + 1].Name;
var weight = weights.GetValues()[featureOfInterest];
- Console.WriteLine("{0:0.00}\t{1:0.00}\t{2}\t{3:0.00}\t{4:0.00}\t{5:0.00}\t{6:0.00}",
+ Console.WriteLine("{0:0.00}\t{1:0.00}\t{2}\t{3:0.00}\t{4:0.00}\t{5:0.00}",
row.MedianHomeValue,
row.Score,
name,
value,
weight,
- contribution,
- percentContribution
+ contribution
);
index++;
}
-
- // For bulk scoring, the ApplyToData API can also be used
- var scoredData = featureContributionCalculator.ApplyToData(mlContext, transformedData);
- var preview = scoredData.Preview(100);
+ Console.ReadLine();
+
+ // The output of the above code is:
+ // Label Score BiggestFeature Value Weight Contribution
+ // 24.00 27.74 RoomsPerDwelling 6.58 98.55 39.95
+ // 21.60 23.85 RoomsPerDwelling 6.42 98.55 39.01
+ // 34.70 29.29 RoomsPerDwelling 7.19 98.55 43.65
+ // 33.40 27.17 RoomsPerDwelling 7.00 98.55 42.52
+ // 36.20 27.68 RoomsPerDwelling 7.15 98.55 43.42
+ // 28.70 23.13 RoomsPerDwelling 6.43 98.55 39.07
+ // 22.90 22.71 RoomsPerDwelling 6.01 98.55 36.53
+ // 27.10 21.72 RoomsPerDwelling 6.17 98.55 37.50
+ // 16.50 18.04 RoomsPerDwelling 5.63 98.55 34.21
+ // 18.90 20.14 RoomsPerDwelling 6.00 98.55 36.48
}
private static int GetMostContributingFeature(float[] featureContributions)
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs
index 7508815dc4..aa3d4446bb 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs
@@ -1,7 +1,6 @@
-using Microsoft.ML.Data;
-using Microsoft.ML.Runtime.Data;
-using System;
+using System;
using System.Collections.Generic;
+using Microsoft.ML.Data;
namespace Microsoft.ML.Samples.Dynamic
{
@@ -31,16 +30,14 @@ public static void FeatureSelectionTransform()
// First, we define the reader: specify the data columns and where to find them in the text file. Notice that we combine entries from
// all the feature columns into entries of a vector of a single column named "Features".
- var reader = ml.Data.TextReader(new TextLoader.Arguments()
- {
- Separator = "tab",
- HasHeader = true,
- Column = new[]
+ var reader = ml.Data.CreateTextReader(
+ columns: new[]
{
new TextLoader.Column("Label", DataKind.BL, 0),
new TextLoader.Column("Features", DataKind.Num, new [] { new TextLoader.Range(1, 9) })
- }
- });
+ },
+ hasHeader: true
+ );
// Then, we use the reader to read the data as an IDataView.
var data = reader.Read(dataFilePath);
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/FieldAwareFactorizationMachine.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/FieldAwareFactorizationMachine.cs
new file mode 100644
index 0000000000..812de0fd27
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/FieldAwareFactorizationMachine.cs
@@ -0,0 +1,71 @@
+using System;
+using Microsoft.ML.Data;
+namespace Microsoft.ML.Samples.Dynamic
+{
+ public class FFM_BinaryClassificationExample
+ {
+ public static void FFM_BinaryClassification()
+ {
+ // Downloading the dataset from github.com/dotnet/machinelearning.
+ // This will create a sentiment.tsv file in the filesystem.
+ // You can open this file, if you want to see the data.
+ string dataFile = SamplesUtils.DatasetUtils.DownloadSentimentDataset();
+
+ // A preview of the data.
+ // Sentiment SentimentText
+ // 0 " :Erm, thank you. "
+ // 1 ==You're cool==
+
+ // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
+ // as a catalog of available operations and as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Step 1: Read the data as an IDataView.
+ // First, we define the reader: specify the data columns and where to find them in the text file.
+ var reader = mlContext.Data.CreateTextReader(
+ columns: new[]
+ {
+ new TextLoader.Column("Sentiment", DataKind.BL, 0),
+ new TextLoader.Column("SentimentText", DataKind.Text, 1)
+ },
+ hasHeader: true
+ );
+
+ // Read the data
+ var data = reader.Read(dataFile);
+
+ // ML.NET doesn't cache data set by default. Therefore, if one reads a data set from a file and accesses it many times, it can be slow due to
+ // expensive featurization and disk operations. When the considered data can fit into memory, a solution is to cache the data in memory. Caching is especially
+ // helpful when working with iterative algorithms which need many data passes. Since SDCA is the case, we cache. Inserting a
+ // cache step in a pipeline is also possible, please see the construction of pipeline below.
+ data = mlContext.Data.Cache(data);
+
+ // Step 2: Pipeline
+ // Featurize the text column through the FeaturizeText API.
+ // Then append a binary classifier, setting the "Label" column as the label of the dataset, and
+ // the "Features" column produced by FeaturizeText as the features column.
+ var pipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features")
+ .AppendCacheCheckpoint(mlContext) // Add a data-cache step within a pipeline.
+ .Append(mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine(labelColumn: "Sentiment", featureColumns: new[] { "Features" }));
+
+ // Fit the model.
+ var model = pipeline.Fit(data);
+
+ // Let's get the model parameters from the model.
+ var modelParams = model.LastTransformer.Model;
+
+ // Let's inspect the model parameters.
+ var featureCount = modelParams.GetFeatureCount();
+ var fieldCount = modelParams.GetFieldCount();
+ var latentDim = modelParams.GetLatentDim();
+ var linearWeights = modelParams.GetLinearWeights();
+ var latentWeights = modelParams.GetLatentWeights();
+
+ Console.WriteLine("The feature count is: " + featureCount);
+ Console.WriteLine("The number of fields is: " + fieldCount);
+ Console.WriteLine("The latent dimension is: " + latentDim);
+ Console.WriteLine("The linear weights of the features are: " + string.Join(", ", linearWeights));
+ Console.WriteLine("The weights of the latent features are: " + string.Join(", ", latentWeights));
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/GeneralizedAdditiveModels.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/GeneralizedAdditiveModels.cs
index 827d04a586..acd08979a2 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/GeneralizedAdditiveModels.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/GeneralizedAdditiveModels.cs
@@ -1,6 +1,6 @@
-using Microsoft.ML.Runtime.Data;
-using System;
+using System;
using System.Linq;
+using Microsoft.ML.Data;
namespace Microsoft.ML.Samples.Dynamic
{
@@ -19,11 +19,8 @@ public static void RunExample()
// Step 1: Read the data as an IDataView.
// First, we define the reader: specify the data columns and where to find them in the text file.
- var reader = mlContext.Data.TextReader(new TextLoader.Arguments()
- {
- Separator = "tab",
- HasHeader = true,
- Column = new[]
+ var reader = mlContext.Data.CreateTextReader(
+ columns: new[]
{
new TextLoader.Column("MedianHomeValue", DataKind.R4, 0),
new TextLoader.Column("CrimesPerCapita", DataKind.R4, 1),
@@ -37,8 +34,9 @@ public static void RunExample()
new TextLoader.Column("HighwayDistance", DataKind.R4, 9),
new TextLoader.Column("TaxRate", DataKind.R4, 10),
new TextLoader.Column("TeacherRatio", DataKind.R4, 11),
- }
- });
+ },
+ hasHeader: true
+ );
// Read the data
var data = reader.Read(dataFile);
@@ -50,8 +48,8 @@ public static void RunExample()
// and use a small number of bins to make it easy to visualize in the console window.
// For real appplications, it is recommended to start with the default number of bins.
var labelName = "MedianHomeValue";
- var featureNames = data.Schema.GetColumns()
- .Select(tuple => tuple.column.Name) // Get the column names
+ var featureNames = data.Schema
+ .Select(column => column.Name) // Get the column names
.Where(name => name != labelName) // Drop the Label
.ToArray();
var pipeline = mlContext.Transforms.Concatenate("Features", featureNames)
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/IidChangePointDetectorTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/IidChangePointDetectorTransform.cs
index 52729c6c97..1a476d3f59 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/IidChangePointDetectorTransform.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/IidChangePointDetectorTransform.cs
@@ -1,13 +1,15 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
using System;
-using System.Linq;
using System.Collections.Generic;
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.Api;
-using Microsoft.ML.Runtime.TimeSeriesProcessing;
-using Microsoft.ML.Core.Data;
-using Microsoft.ML.TimeSeries;
using System.IO;
+using System.Linq;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Data;
+using Microsoft.ML.TimeSeries;
+using Microsoft.ML.TimeSeriesProcessing;
namespace Microsoft.ML.Samples.Dynamic
{
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/IidSpikeDetectorTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/IidSpikeDetectorTransform.cs
index 5dcb7e1774..572c6b140d 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/IidSpikeDetectorTransform.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/IidSpikeDetectorTransform.cs
@@ -1,13 +1,11 @@
using System;
+using System.Collections.Generic;
using System.IO;
using System.Linq;
-using System.Collections.Generic;
-using Microsoft.ML.Data;
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.Api;
-using Microsoft.ML.Runtime.TimeSeriesProcessing;
using Microsoft.ML.Core.Data;
+using Microsoft.ML.Data;
using Microsoft.ML.TimeSeries;
+using Microsoft.ML.TimeSeriesProcessing;
namespace Microsoft.ML.Samples.Dynamic
{
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/KMeans.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/KMeans.cs
new file mode 100644
index 0000000000..1a5ad75b26
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/KMeans.cs
@@ -0,0 +1,45 @@
+using System;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Samples.Dynamic
+{
+ public class KMeans_example
+ {
+ public static void KMeans()
+ {
+ // Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
+ // as well as the source of randomness.
+ var ml = new MLContext();
+
+ // Get a small dataset as an IEnumerable and convert it to an IDataView.
+ var data = SamplesUtils.DatasetUtils.GetInfertData();
+ var trainData = ml.CreateStreamingDataView(data);
+
+ // Preview of the data.
+ //
+ // Age Case Education Induced Parity PooledStratum RowNum ...
+ // 26 1 0-5yrs 1 6 3 1 ...
+ // 42 1 0-5yrs 1 1 1 2 ...
+ // 39 1 0-5yrs 2 6 4 3 ...
+ // 34 1 0-5yrs 2 4 2 4 ...
+ // 35 1 6-11yrs 1 3 32 5 ...
+
+ // A pipeline for concatenating the age, parity and induced columns together in the Features column and training a KMeans model on them.
+ string outputColumnName = "Features";
+ var pipeline = ml.Transforms.Concatenate(outputColumnName, new[] { "Age", "Parity", "Induced" })
+ .Append(ml.Clustering.Trainers.KMeans(outputColumnName, clustersCount: 2));
+
+ var model = pipeline.Fit(trainData);
+
+ // Get cluster centroids and the number of clusters k from KMeansModelParameters.
+ VBuffer<float>[] centroids = default;
+ int k;
+
+ var modelParams = model.LastTransformer.Model;
+ modelParams.GetClusterCentroids(ref centroids, out k);
+
+ var centroid = centroids[0].GetValues();
+ Console.WriteLine("The coordinates of centroid 0 are: " + string.Join(", ", centroid.ToArray()));
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValue_Term.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValue_Term.cs
index d6e15eeb50..6d3deb9f84 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValue_Term.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValue_Term.cs
@@ -1,12 +1,8 @@
-using Microsoft.ML.Data;
-using Microsoft.ML.Runtime.Api;
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Transforms.Categorical;
+using System;
+using System.Collections.Generic;
+using Microsoft.ML.Data;
using Microsoft.ML.Transforms.Conversions;
using Microsoft.ML.Transforms.Text;
-using System;
-using System.Collections.Generic;
-using System.Linq;
namespace Microsoft.ML.Samples.Dynamic
{
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/LdaTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/LdaTransform.cs
index f2cbec7c28..4765fe06a4 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/LdaTransform.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/LdaTransform.cs
@@ -1,8 +1,6 @@
-using Microsoft.ML.Data;
-using Microsoft.ML.Runtime.Api;
-using Microsoft.ML.Runtime.Data;
-using System;
+using System;
using System.Collections.Generic;
+using Microsoft.ML.Data;
namespace Microsoft.ML.Samples.Dynamic
{
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/MatrixFactorization.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/MatrixFactorization.cs
index 063a554adc..ecd5336b03 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/MatrixFactorization.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/MatrixFactorization.cs
@@ -1,8 +1,6 @@
-using Microsoft.ML.Runtime.Api;
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Trainers;
using System;
using System.Collections.Generic;
+using Microsoft.ML.Data;
namespace Microsoft.ML.Samples.Dynamic
{
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs
index 3c9147f32a..524f04c834 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs
@@ -1,8 +1,6 @@
-using Microsoft.ML.Data;
-using Microsoft.ML.Runtime.Api;
-using Microsoft.ML.Runtime.Data;
-using System;
+using System;
using System.Collections.Generic;
+using Microsoft.ML.Data;
namespace Microsoft.ML.Samples.Dynamic
{
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs
index 121eede6e6..abeb454d51 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs
@@ -1,9 +1,8 @@
-using Microsoft.ML.Data;
-using Microsoft.ML.Runtime.Api;
-using Microsoft.ML.Transforms.Normalizers;
-using System;
+using System;
using System.Collections.Generic;
using System.Linq;
+using Microsoft.ML.Data;
+using Microsoft.ML.Transforms.Normalizers;
namespace Microsoft.ML.Samples.Dynamic
{
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs
new file mode 100644
index 0000000000..21d740b4ae
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs
@@ -0,0 +1,105 @@
+using System;
+using System.Linq;
+using Microsoft.ML.Data;
+using Microsoft.ML.OnnxRuntime;
+using Microsoft.ML.Transforms;
+
+namespace Microsoft.ML.Samples.Dynamic
+{
+ class OnnxTransformExample
+ {
+ /// <summary>
+ /// Example use of OnnxEstimator in an ML.NET pipeline
+ /// </summary>
+ public static void OnnxTransformSample()
+ {
+ // Download the squeezenet image model from ONNX model zoo, version 1.2
+ // https://github.com/onnx/models/tree/master/squeezenet
+ var modelPath = @"squeezenet\model.onnx";
+
+ // Inspect the model's inputs and outputs
+ var session = new InferenceSession(modelPath);
+ var inputInfo = session.InputMetadata.First();
+ var outputInfo = session.OutputMetadata.First();
+ Console.WriteLine($"Input Name is {String.Join(",", inputInfo.Key)}");
+ Console.WriteLine($"Input Dimensions are {String.Join(",", inputInfo.Value.Dimensions)}");
+ Console.WriteLine($"Output Name is {String.Join(",", outputInfo.Key)}");
+ Console.WriteLine($"Output Dimensions are {String.Join(",", outputInfo.Value.Dimensions)}");
+ // Results..
+ // Input Name is data_0
+ // Input Dimensions are 1,3,224,224
+ // Output Name is softmaxout_1
+ // Output Dimensions are 1,1000,1,1
+
+ // Create ML pipeline to score the data using OnnxScoringEstimator
+ var mlContext = new MLContext();
+ var data = GetTensorData();
+ var idv = mlContext.CreateStreamingDataView(data);
+ var pipeline = new OnnxScoringEstimator(mlContext, modelPath, new[] { inputInfo.Key }, new[] { outputInfo.Key });
+
+ // Run the pipeline and get the transformed values
+ var transformedValues = pipeline.Fit(idv).Transform(idv);
+
+ // Retrieve model scores into Prediction class
+ var predictions = transformedValues.AsEnumerable<Prediction>(mlContext, reuseRowObject: false);
+
+ // Iterate rows
+ foreach (var prediction in predictions)
+ {
+ int numClasses = 0;
+ foreach (var classScore in prediction.softmaxout_1.Take(3))
+ {
+ Console.WriteLine($"Class #{numClasses++} score = {classScore}");
+ }
+ Console.WriteLine(new string('-', 10));
+ }
+
+ // Results look like below...
+ // Class #0 score = 4.544065E-05
+ // Class #1 score = 0.003845858
+ // Class #2 score = 0.0001249467
+ // ----------
+ // Class #0 score = 4.491953E-05
+ // Class #1 score = 0.003848222
+ // Class #2 score = 0.0001245592
+ // ----------
+ }
+
+ /// <summary>
+ /// inputSize is the overall dimensions of the model input tensor.
+ /// </summary>
+ private const int inputSize = 224 * 224 * 3;
+
+ /// <summary>
+ /// A class to hold sample tensor data. Member name should match
+ /// the inputs that the model expects (in this case, data_0)
+ /// </summary>
+ public class TensorData
+ {
+ [VectorType(inputSize)]
+ public float[] data_0 { get; set; }
+ }
+
+ /// <summary>
+ /// Method to generate sample test data. Returns 2 sample rows.
+ /// </summary>
+ /// <returns></returns>
+ public static TensorData[] GetTensorData()
+ {
+ // This can be any numerical data. Assume image pixel values.
+ var image1 = Enumerable.Range(0, inputSize).Select(x => (float)x / inputSize).ToArray();
+ var image2 = Enumerable.Range(0, inputSize).Select(x => (float)(x + 10000) / inputSize).ToArray();
+ return new TensorData[] { new TensorData() { data_0 = image1 }, new TensorData() { data_0 = image2 } };
+ }
+
+ /// <summary>
+ /// Class to contain the output values from the transformation.
+ /// This model generates a vector of 1000 floats.
+ /// </summary>
+ class Prediction
+ {
+ [VectorType(1000)]
+ public float[] softmaxout_1 { get; set; }
+ }
+ }
+}
\ No newline at end of file
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance.cs
deleted file mode 100644
index 0c95abacb8..0000000000
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance.cs
+++ /dev/null
@@ -1,128 +0,0 @@
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.Learners;
-using System;
-using System.Linq;
-
-namespace Microsoft.ML.Samples.Dynamic
-{
- public class PFI_RegressionExample
- {
- public static void PFI_Regression()
- {
- // Download the dataset from github.com/dotnet/machinelearning.
- // This will create a housing.txt file in the filesystem.
- // You can open this file to see the data.
- string dataFile = SamplesUtils.DatasetUtils.DownloadHousingRegressionDataset();
-
- // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
- // as a catalog of available operations and as the source of randomness.
- var mlContext = new MLContext();
-
- // Step 1: Read the data as an IDataView.
- // First, we define the reader: specify the data columns and where to find them in the text file.
- // The data file is composed of rows of data, with each row having 11 numerical columns
- // separated by whitespace.
- var reader = mlContext.Data.TextReader(new TextLoader.Arguments()
- {
- Separator = "tab",
- HasHeader = true,
- Column = new[]
- {
- // Read the first column (indexed by 0) in the data file as an R4 (float)
- new TextLoader.Column("MedianHomeValue", DataKind.R4, 0),
- new TextLoader.Column("CrimesPerCapita", DataKind.R4, 1),
- new TextLoader.Column("PercentResidental", DataKind.R4, 2),
- new TextLoader.Column("PercentNonRetail", DataKind.R4, 3),
- new TextLoader.Column("CharlesRiver", DataKind.R4, 4),
- new TextLoader.Column("NitricOxides", DataKind.R4, 5),
- new TextLoader.Column("RoomsPerDwelling", DataKind.R4, 6),
- new TextLoader.Column("PercentPre40s", DataKind.R4, 7),
- new TextLoader.Column("EmploymentDistance", DataKind.R4, 8),
- new TextLoader.Column("HighwayDistance", DataKind.R4, 9),
- new TextLoader.Column("TaxRate", DataKind.R4, 10),
- new TextLoader.Column("TeacherRatio", DataKind.R4, 11),
- }
- });
-
- // Read the data
- var data = reader.Read(dataFile);
-
- // Step 2: Pipeline
- // Concatenate the features to create a Feature vector.
- // Normalize the data set so that for each feature, its maximum value is 1 while its minimum value is 0.
- // Then append a linear regression trainer, setting the "MedianHomeValue" column as the label of the dataset,
- // the "Features" column produced by concatenation as the features of the dataset.
- var labelName = "MedianHomeValue";
- var pipeline = mlContext.Transforms.Concatenate("Features", "CrimesPerCapita", "PercentResidental",
- "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling", "PercentPre40s",
- "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio")
- .Append(mlContext.Transforms.Normalize("Features"))
- .Append(mlContext.Regression.Trainers.StochasticDualCoordinateAscent(
- labelColumn: labelName, featureColumn: "Features"));
- var model = pipeline.Fit(data);
-
- // Extract the model from the pipeline
- var linearPredictor = model.LastTransformer;
- var weights = GetLinearModelWeights(linearPredictor.Model);
-
- // Compute the permutation metrics using the properly-featurized data.
- var transformedData = model.Transform(data);
- var permutationMetrics = mlContext.Regression.PermutationFeatureImportance(
- linearPredictor, transformedData, label: labelName, features: "Features");
-
- // Now let's look at which features are most important to the model overall
- // First, we have to prepare the data:
- // Get the feature names as an IEnumerable
- var featureNames = data.Schema.GetColumns()
- .Select(tuple => tuple.column.Name) // Get the column names
- .Where(name => name != labelName) // Drop the Label
- .ToArray();
-
- // Get the feature indices sorted by their impact on R-Squared
- var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.RSquared })
- .OrderByDescending(feature => Math.Abs(feature.RSquared))
- .Select(feature => feature.index);
-
- // Print out the permutation results, with the model weights, in order of their impact:
- // Expected console output:
- // Feature Model Weight Change in R - Squared
- // RoomsPerDwelling 50.80 -0.3695
- // EmploymentDistance -17.79 -0.2238
- // TeacherRatio -19.83 -0.1228
- // TaxRate -8.60 -0.1042
- // NitricOxides -15.95 -0.1025
- // HighwayDistance 5.37 -0.09345
- // CrimesPerCapita -15.05 -0.05797
- // PercentPre40s -4.64 -0.0385
- // PercentResidental 3.98 -0.02184
- // CharlesRiver 3.38 -0.01487
- // PercentNonRetail -1.94 -0.007231
- //
- // Let's dig into these results a little bit. First, if you look at the weights of the model, they generally correlate
- // with the results of PFI, but there are some significant misorderings. For example, "Tax Rate" is weighted lower than
- // "Nitric Oxides" and "Crimes Per Capita", but the permutation analysis shows this feature to have a larger effect
- // on the accuracy of the model even though it has a relatively small weight. To understand why the weights don't
- // reflect the same feature importance as PFI, we need to go back to the basics of linear models: one of the
- // assumptions of a linear model is that the features are uncorrelated. Now, the features in this dataset are clearly
- // correlated: the tax rate for a house and the student-to-teacher ratio at the nearest school, for example, are often
- // coupled through school levies. The tax rate, presence of pollution (e.g. nitric oxides), and the crime rate would also
- // seem to be correlated with each other through social dynamics. We could draw out similar relationships for all the
- // variables in this dataset. The reason why the linear model weights don't reflect the same feature importance as PFI
- // is that the solution to the linear model redistributes weights between correlated variables in unpredictable ways, so
- // that the weights themselves are no longer a good measure of feature importance.
- Console.WriteLine("Feature\tModel Weight\tChange in R-Squared");
- var rSquared = permutationMetrics.Select(x => x.RSquared).ToArray(); // Fetch r-squared as an array
- foreach (int i in sortedIndices)
- {
- Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{rSquared[i]:G4}");
- }
- }
-
- private static float[] GetLinearModelWeights(LinearRegressionPredictor linearModel)
- {
- var weights = new VBuffer<float>();
- linearModel.GetFeatureWeights(ref weights);
- return weights.GetValues().ToArray();
- }
- }
-}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIHelper.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIHelper.cs
new file mode 100644
index 0000000000..ebb4def616
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIHelper.cs
@@ -0,0 +1,88 @@
+using System;
+using System.Linq;
+using Microsoft.ML.Data;
+using Microsoft.ML.Learners;
+using Microsoft.ML.Trainers.HalLearners;
+
+namespace Microsoft.ML.Samples.Dynamic.PermutationFeatureImportance
+{
+ public class PfiHelper
+ {
+ public static IDataView GetHousingRegressionIDataView(MLContext mlContext, out string labelName, out string[] featureNames, bool binaryPrediction = false)
+ {
+ // Download the dataset from github.com/dotnet/machinelearning.
+ // This will create a housing.txt file in the filesystem.
+ // You can open this file to see the data.
+ string dataFile = SamplesUtils.DatasetUtils.DownloadHousingRegressionDataset();
+
+ // Read the data as an IDataView.
+ // First, we define the reader: specify the data columns and where to find them in the text file.
+ // The data file is composed of rows of data, with each row having 11 numerical columns
+ // separated by whitespace.
+ var reader = mlContext.Data.CreateTextReader(
+ columns: new[]
+ {
+ // Read the first column (indexed by 0) in the data file as an R4 (float)
+ new TextLoader.Column("MedianHomeValue", DataKind.R4, 0),
+ new TextLoader.Column("CrimesPerCapita", DataKind.R4, 1),
+ new TextLoader.Column("PercentResidental", DataKind.R4, 2),
+ new TextLoader.Column("PercentNonRetail", DataKind.R4, 3),
+ new TextLoader.Column("CharlesRiver", DataKind.R4, 4),
+ new TextLoader.Column("NitricOxides", DataKind.R4, 5),
+ new TextLoader.Column("RoomsPerDwelling", DataKind.R4, 6),
+ new TextLoader.Column("PercentPre40s", DataKind.R4, 7),
+ new TextLoader.Column("EmploymentDistance", DataKind.R4, 8),
+ new TextLoader.Column("HighwayDistance", DataKind.R4, 9),
+ new TextLoader.Column("TaxRate", DataKind.R4, 10),
+ new TextLoader.Column("TeacherRatio", DataKind.R4, 11),
+ },
+ hasHeader: true
+ );
+
+ // Read the data
+ var data = reader.Read(dataFile);
+ var labelColumn = "MedianHomeValue";
+
+ if (binaryPrediction)
+ {
+ labelColumn = nameof(BinaryOutputRow.AboveAverage);
+ data = mlContext.Transforms.CustomMappingTransformer(GreaterThanAverage, null).Transform(data);
+ data = mlContext.Transforms.DropColumns("MedianHomeValue").Fit(data).Transform(data);
+ }
+
+ labelName = labelColumn;
+ featureNames = data.Schema.AsEnumerable()
+ .Select(column => column.Name) // Get the column names
+ .Where(name => name != labelColumn) // Drop the Label
+ .ToArray();
+
+ return data;
+ }
+
+ // Define a class for all the input columns that we intend to consume.
+ private class ContinuousInputRow
+ {
+ public float MedianHomeValue { get; set; }
+ }
+
+ // Define a class for all output columns that we intend to produce.
+ private class BinaryOutputRow
+ {
+ public bool AboveAverage { get; set; }
+ }
+
+ // Define an Action to apply a custom mapping from one object to the other
+ private readonly static Action<ContinuousInputRow, BinaryOutputRow> GreaterThanAverage = (input, output)
+ => output.AboveAverage = input.MedianHomeValue > 22.6;
+
+ public static float[] GetLinearModelWeights(OlsLinearRegressionModelParameters linearModel)
+ {
+ return linearModel.Weights.ToArray();
+ }
+
+ public static float[] GetLinearModelWeights(LinearBinaryModelParameters linearModel)
+ {
+ return linearModel.Weights.ToArray();
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIRegressionExample.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIRegressionExample.cs
new file mode 100644
index 0000000000..e552a841ed
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIRegressionExample.cs
@@ -0,0 +1,77 @@
+using System;
+using System.Linq;
+
+namespace Microsoft.ML.Samples.Dynamic.PermutationFeatureImportance
+{
+ public class PfiRegressionExample
+ {
+ public static void RunExample()
+ {
+ // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
+ // as a catalog of available operations and as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Step 1: Read the data
+ var data = PfiHelper.GetHousingRegressionIDataView(mlContext, out string labelName, out string[] featureNames);
+
+ // Step 2: Pipeline
+ // Concatenate the features to create a Feature vector.
+ // Normalize the data set so that for each feature, its maximum value is 1 while its minimum value is 0.
+ // Then append a linear regression trainer.
+ var pipeline = mlContext.Transforms.Concatenate("Features", featureNames)
+ .Append(mlContext.Transforms.Normalize("Features"))
+ .Append(mlContext.Regression.Trainers.OrdinaryLeastSquares(
+ labelColumn: labelName, featureColumn: "Features"));
+ var model = pipeline.Fit(data);
+
+ // Extract the model from the pipeline
+ var linearPredictor = model.LastTransformer;
+ var weights = PfiHelper.GetLinearModelWeights(linearPredictor.Model);
+
+ // Compute the permutation metrics using the properly normalized data.
+ var transformedData = model.Transform(data);
+ var permutationMetrics = mlContext.Regression.PermutationFeatureImportance(
+ linearPredictor, transformedData, label: labelName, features: "Features", permutationCount: 3);
+
+ // Now let's look at which features are most important to the model overall
+ // Get the feature indices sorted by their impact on R-Squared
+ var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.RSquared })
+ .OrderByDescending(feature => Math.Abs(feature.RSquared.Mean))
+ .Select(feature => feature.index);
+
+ // Print out the permutation results, with the model weights, in order of their impact:
+ // Expected console output for 100 permutations:
+ // Feature Model Weight Change in R-Squared 95% Confidence Interval of the Mean
+ // RoomsPerDwelling 53.35 -0.4298 0.005705
+ // EmploymentDistance -19.21 -0.2609 0.004591
+ // NitricOxides -19.32 -0.1569 0.003701
+ // HighwayDistance 6.11 -0.1173 0.0025
+ // TeacherRatio -21.92 -0.1106 0.002207
+ // TaxRate -8.68 -0.1008 0.002083
+ // CrimesPerCapita -16.37 -0.05988 0.00178
+ // PercentPre40s -4.52 -0.03836 0.001432
+ // PercentResidental 3.91 -0.02006 0.001079
+ // CharlesRiver 3.49 -0.01839 0.000841
+ // PercentNonRetail -1.17 -0.002111 0.0003176
+ //
+ // Let's dig into these results a little bit. First, if you look at the weights of the model, they generally correlate
+ // with the results of PFI, but there are some significant misorderings. For example, "Tax Rate" and "Highway Distance"
+ // have relatively small model weights, but the permutation analysis shows these features to have a larger effect
+ // on the accuracy of the model than higher-weighted features. To understand why the weights don't reflect the same
+ // feature importance as PFI, we need to go back to the basics of linear models: one of the assumptions of a linear
+ // model is that the features are uncorrelated. Now, the features in this dataset are clearly correlated: the tax rate
+ // for a house and the student-to-teacher ratio at the nearest school, for example, are often coupled through school
+ // levies. The tax rate, distance to a highway, and the crime rate would also seem to be correlated through social
+ // dynamics. We could draw out similar relationships for all variables in this dataset. The reason why the linear
+ // model weights don't reflect the same feature importance as PFI is that the solution to the linear model redistributes
+ // weights between correlated variables in unpredictable ways, so that the weights themselves are no longer a good
+ // measure of feature importance.
+ Console.WriteLine("Feature\tModel Weight\tChange in R-Squared\t95% Confidence Interval of the Mean");
+ var rSquared = permutationMetrics.Select(x => x.RSquared).ToArray(); // Fetch r-squared as an array
+ foreach (int i in sortedIndices)
+ {
+ Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{rSquared[i].Mean:G4}\t{1.96 * rSquared[i].StandardError:G4}");
+ }
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs
new file mode 100644
index 0000000000..e75fa170c9
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs
@@ -0,0 +1,76 @@
+using System;
+using System.Linq;
+using Microsoft.ML.Learners;
+
+namespace Microsoft.ML.Samples.Dynamic.PermutationFeatureImportance
+{
+ public class PfiBinaryClassificationExample
+ {
+ public static void RunExample()
+ {
+ // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
+ // as a catalog of available operations and as the source of randomness.
+ var mlContext = new MLContext(seed:999123);
+
+ // Step 1: Read the data
+ var data = PfiHelper.GetHousingRegressionIDataView(mlContext,
+ out string labelName, out string[] featureNames, binaryPrediction: true);
+
+ // Step 2: Pipeline
+ // Concatenate the features to create a Feature vector.
+ // Normalize the data set so that for each feature, its maximum value is 1 while its minimum value is 0.
+ // Then append a logistic regression trainer.
+ var pipeline = mlContext.Transforms.Concatenate("Features", featureNames)
+ .Append(mlContext.Transforms.Normalize("Features"))
+ .Append(mlContext.BinaryClassification.Trainers.LogisticRegression(
+ labelColumn: labelName, featureColumn: "Features"));
+ var model = pipeline.Fit(data);
+
+ // Extract the model from the pipeline
+ var linearPredictor = model.LastTransformer;
+ // Linear models for binary classification are wrapped by a calibrator as a generic predictor
+ // To access it directly, we must extract it out and cast it to the proper class
+ var weights = PfiHelper.GetLinearModelWeights(linearPredictor.Model.SubPredictor as LinearBinaryModelParameters);
+
+ // Compute the permutation metrics using the properly normalized data.
+ var transformedData = model.Transform(data);
+ var permutationMetrics = mlContext.BinaryClassification.PermutationFeatureImportance(
+ linearPredictor, transformedData, label: labelName, features: "Features", permutationCount: 3);
+
+ // Now let's look at which features are most important to the model overall
+ // Get the feature indices sorted by their impact on AUC
+ var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.Auc })
+ .OrderByDescending(feature => Math.Abs(feature.Auc.Mean))
+ .Select(feature => feature.index);
+
+ // Print out the permutation results, with the model weights, in order of their impact:
+ // Expected console output (for 100 permutations):
+ // Feature Model Weight Change in AUC 95% Confidence in the Mean Change in AUC
+ // PercentPre40s -1.96 -0.06316 0.002377
+ // RoomsPerDwelling 3.71 -0.04385 0.001245
+ // EmploymentDistance -1.31 -0.02139 0.0006867
+ // TeacherRatio -2.46 -0.0203 0.0009566
+ // PercentNonRetail -1.58 -0.01846 0.001586
+ // CharlesRiver 0.66 -0.008605 0.0005136
+ // PercentResidental 0.60 0.002483 0.0004818
+ // TaxRate -0.95 -0.00221 0.0007394
+ // NitricOxides -0.32 0.00101 0.0001428
+ // CrimesPerCapita -0.04 -3.029E-05 1.678E-05
+ // HighwayDistance 0.00 0 0
+ // Let's look at these results.
+ // First, if you look at the weights of the model, they generally correlate with the results of PFI,
+ // but there are some significant misorderings. See the discussion in the Regression example for an
+ // explanation of why this happens and how to interpret it.
+ // Second, the logistic regression learner uses L1 regularization by default. Here, it causes the "Highway Distance"
+ // feature to be zeroed out from the model. PFI assigns zero importance to this variable, as expected.
+ // Third, some features show an *increase* in AUC. This means that the model actually improved
+ // when these features were shuffled. This is a sign to investigate these features further.
+ Console.WriteLine("Feature\tModel Weight\tChange in AUC\t95% Confidence in the Mean Change in AUC");
+ var auc = permutationMetrics.Select(x => x.Auc).ToArray(); // Fetch AUC as an array
+ foreach (int i in sortedIndices)
+ {
+ Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{auc[i].Mean:G4}\t{1.96 * auc[i].StandardError:G4}");
+ }
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ProjectionTransforms.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ProjectionTransforms.cs
index 0f7761cfba..15ffbf6653 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/ProjectionTransforms.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ProjectionTransforms.cs
@@ -1,9 +1,7 @@
-using Microsoft.ML.Runtime.Api;
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Data;
-using System;
+using System;
using System.Collections.Generic;
using System.Linq;
+using Microsoft.ML.Data;
namespace Microsoft.ML.Samples.Dynamic
{
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs
index 32a33cdc43..5f08dee906 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs
@@ -1,6 +1,6 @@
-using Microsoft.ML.Runtime.Data;
-using System;
+using System;
using System.Linq;
+using Microsoft.ML.Data;
namespace Microsoft.ML.Samples.Dynamic
{
@@ -24,25 +24,30 @@ public static void SDCA_BinaryClassification()
// Step 1: Read the data as an IDataView.
// First, we define the reader: specify the data columns and where to find them in the text file.
- var reader = mlContext.Data.TextReader(new TextLoader.Arguments()
- {
- Separator = "tab",
- HasHeader = true,
- Column = new[]
+ var reader = mlContext.Data.CreateTextReader(
+ columns: new[]
{
new TextLoader.Column("Sentiment", DataKind.BL, 0),
new TextLoader.Column("SentimentText", DataKind.Text, 1)
- }
- });
+ },
+ hasHeader: true
+ );
// Read the data
var data = reader.Read(dataFile);
+ // ML.NET doesn't cache data set by default. Therefore, if one reads a data set from a file and accesses it many times, it can be slow due to
+ // expensive featurization and disk operations. When the considered data can fit into memory, a solution is to cache the data in memory. Caching is especially
+ // helpful when working with iterative algorithms which need many data passes. Since SDCA is such an iterative algorithm, we cache. Inserting a
+ // cache step in a pipeline is also possible, please see the construction of pipeline below.
+ data = mlContext.Data.Cache(data);
+
// Step 2: Pipeline
// Featurize the text column through the FeaturizeText API.
// Then append a binary classifier, setting the "Label" column as the label of the dataset, and
- // the "Features" column produced by FeaturizeText as the features column.
+ // the "Features" column produced by FeaturizeText as the features column.
var pipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features")
+ .AppendCacheCheckpoint(mlContext) // Add a data-cache step within a pipeline.
.Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent(labelColumn: "Sentiment", featureColumn: "Features", l2Const: 0.001f));
// Step 3: Run Cross-Validation on this pipeline.
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/SDCARegression.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCARegression.cs
new file mode 100644
index 0000000000..c75872011c
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCARegression.cs
@@ -0,0 +1,43 @@
+using System;
+using System.Linq;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Samples.Dynamic
+{
+ public class SDCARegressionExample
+ {
+ public static void SDCARegression()
+ {
+ // Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
+ // as well as the source of randomness.
+ var ml = new MLContext();
+
+ // Get a small dataset as an IEnumerable and convert it to an IDataView.
+ var data = SamplesUtils.DatasetUtils.GetInfertData();
+ var trainData = ml.CreateStreamingDataView(data);
+
+ // Preview of the data.
+ //
+ // Age Case Education Induced Parity PooledStratum RowNum ...
+ // 26 1 0-5yrs 1 6 3 1 ...
+ // 42 1 0-5yrs 1 1 1 2 ...
+ // 39 1 0-5yrs 2 6 4 3 ...
+ // 34 1 0-5yrs 2 4 2 4 ...
+ // 35 1 6-11yrs 1 3 32 5 ...
+
+ // A pipeline for concatenating the Parity and Induced columns together in the Features column.
+ // We will train a StochasticDualCoordinateAscent regression model on these two columns to predict Age.
+ string outputColumnName = "Features";
+ var pipeline = ml.Transforms.Concatenate(outputColumnName, new[] { "Parity", "Induced" })
+ .Append(ml.Regression.Trainers.StochasticDualCoordinateAscent(labelColumn: "Age", featureColumn: outputColumnName, maxIterations:2));
+
+ var model = pipeline.Fit(trainData);
+
+ // Get the trained model parameters.
+ var modelParams = model.LastTransformer.Model;
+ // Inspect the bias and model weights.
+ Console.WriteLine("The bias term is: " + modelParams.Bias);
+ Console.WriteLine("The feature weights are: " + string.Join(", ", modelParams.Weights.ToArray()));
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/SsaChangePointDetectorTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/SsaChangePointDetectorTransform.cs
index e03ba68606..7eba634243 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/SsaChangePointDetectorTransform.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/SsaChangePointDetectorTransform.cs
@@ -1,12 +1,11 @@
-using Microsoft.ML.Core.Data;
-using Microsoft.ML.Data;
-using Microsoft.ML.Runtime.Api;
-using Microsoft.ML.Runtime.TimeSeriesProcessing;
-using Microsoft.ML.TimeSeries;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Data;
+using Microsoft.ML.TimeSeries;
+using Microsoft.ML.TimeSeriesProcessing;
namespace Microsoft.ML.Samples.Dynamic
{
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/SsaSpikeDetectorTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/SsaSpikeDetectorTransform.cs
index fe9e981ecd..ff915dac1f 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/SsaSpikeDetectorTransform.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/SsaSpikeDetectorTransform.cs
@@ -1,13 +1,11 @@
-using Microsoft.ML.Core.Data;
-using Microsoft.ML.Data;
-using Microsoft.ML.Runtime.Api;
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.TimeSeriesProcessing;
-using Microsoft.ML.TimeSeries;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Data;
+using Microsoft.ML.TimeSeries;
+using Microsoft.ML.TimeSeriesProcessing;
namespace Microsoft.ML.Samples.Dynamic
{
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlowTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlowTransform.cs
new file mode 100644
index 0000000000..41d1eac9fe
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlowTransform.cs
@@ -0,0 +1,91 @@
+using System;
+using System.Linq;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Samples.Dynamic
+{
+ class TensorFlowTransformExample
+ {
+ /// <summary>
+ /// Example use of the TensorFlowEstimator in an ML.NET pipeline.
+ /// </summary>
+ public static void TensorFlowScoringSample()
+ {
+ // Download the ResNet 101 model from the location below.
+ // https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz
+ var modelLocation = @"resnet_v2_101/resnet_v2_101_299_frozen.pb";
+
+ var mlContext = new MLContext();
+ var data = GetTensorData();
+ var idv = mlContext.CreateStreamingDataView(data);
+
+ // Create a ML pipeline.
+ var pipeline = mlContext.Transforms.ScoreTensorFlowModel(
+ modelLocation,
+ new[] { nameof(TensorData.input) },
+ new[] { nameof(OutputScores.output) });
+
+ // Run the pipeline and get the transformed values.
+ var estimator = pipeline.Fit(idv);
+ var transformedValues = estimator.Transform(idv);
+
+ // Retrieve model scores.
+ var outScores = transformedValues.AsEnumerable<OutputScores>(mlContext, reuseRowObject: false);
+
+ // Display scores. (for the sake of brevity we display scores of the first 3 classes)
+ foreach (var prediction in outScores)
+ {
+ int numClasses = 0;
+ foreach (var classScore in prediction.output.Take(3))
+ {
+ Console.WriteLine($"Class #{numClasses++} score = {classScore}");
+ }
+ Console.WriteLine(new string('-', 10));
+ }
+
+ // Results look like below...
+ //Class #0 score = -0.8092947
+ //Class #1 score = -0.3310375
+ //Class #2 score = 0.1119193
+ //----------
+ //Class #0 score = -0.7807726
+ //Class #1 score = -0.2158062
+ //Class #2 score = 0.1153686
+ //----------
+ }
+
+ private const int imageHeight = 224;
+ private const int imageWidth = 224;
+ private const int numChannels = 3;
+ private const int inputSize = imageHeight * imageWidth * numChannels;
+
+ /// <summary>
+ /// A class to hold sample tensor data.
+ /// Member name should match the inputs that the model expects (in this case, input).
+ /// </summary>
+ public class TensorData
+ {
+ [VectorType(imageHeight, imageWidth, numChannels)]
+ public float[] input { get; set; }
+ }
+
+ /// <summary>
+ /// Method to generate sample test data. Returns 2 sample rows.
+ /// </summary>
+ public static TensorData[] GetTensorData()
+ {
+ // This can be any numerical data. Assume image pixel values.
+ var image1 = Enumerable.Range(0, inputSize).Select(x => (float)x / inputSize).ToArray();
+ var image2 = Enumerable.Range(0, inputSize).Select(x => (float)(x + 10000) / inputSize).ToArray();
+ return new TensorData[] { new TensorData() { input = image1 }, new TensorData() { input = image2 } };
+ }
+
+ /// <summary>
+ /// Class to contain the output values from the transformation.
+ /// </summary>
+ class OutputScores
+ {
+ public float[] output { get; set; }
+ }
+ }
+}
\ No newline at end of file
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs
index 57fbc7e493..3c8964b6f7 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs
@@ -1,9 +1,7 @@
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.Api;
+using System;
+using System.Collections.Generic;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms.Text;
-using System;
-using System.Collections.Generic;
namespace Microsoft.ML.Samples.Dynamic
{
diff --git a/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj b/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj
index abea9f3e8d..44b673640f 100644
--- a/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj
+++ b/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj
@@ -7,17 +7,23 @@
+
+
+
+
+
+ false
diff --git a/docs/samples/Microsoft.ML.Samples/Program.cs b/docs/samples/Microsoft.ML.Samples/Program.cs
index 83a7e8c2e8..41f465ef5e 100644
--- a/docs/samples/Microsoft.ML.Samples/Program.cs
+++ b/docs/samples/Microsoft.ML.Samples/Program.cs
@@ -6,7 +6,7 @@ internal static class Program
{
static void Main(string[] args)
{
- LdaTransformExample.LdaTransform();
+ FeatureContributionCalculationTransform_RegressionExample.FeatureContributionCalculationTransform_Regression();
}
}
}
diff --git a/docs/samples/Microsoft.ML.Samples/Static/AveragedPerceptronBinaryClassification.cs b/docs/samples/Microsoft.ML.Samples/Static/AveragedPerceptronBinaryClassification.cs
index d2359526e9..176f55eff6 100644
--- a/docs/samples/Microsoft.ML.Samples/Static/AveragedPerceptronBinaryClassification.cs
+++ b/docs/samples/Microsoft.ML.Samples/Static/AveragedPerceptronBinaryClassification.cs
@@ -1,8 +1,6 @@
-using Microsoft.ML.Runtime.Data;
+using System;
+using Microsoft.ML.Data;
using Microsoft.ML.StaticPipe;
-using Microsoft.ML.Transforms;
-using Microsoft.ML.Transforms.Categorical;
-using System;
namespace Microsoft.ML.Samples.Static
{
diff --git a/docs/samples/Microsoft.ML.Samples/Static/FastTreeBinaryClassification.cs b/docs/samples/Microsoft.ML.Samples/Static/FastTreeBinaryClassification.cs
index d28b7c79de..55835aa0b3 100644
--- a/docs/samples/Microsoft.ML.Samples/Static/FastTreeBinaryClassification.cs
+++ b/docs/samples/Microsoft.ML.Samples/Static/FastTreeBinaryClassification.cs
@@ -1,9 +1,6 @@
-using Microsoft.ML.Runtime.Data;
+using System;
+using Microsoft.ML.Data;
using Microsoft.ML.StaticPipe;
-using Microsoft.ML.Transforms;
-using Microsoft.ML.Transforms.Categorical;
-using Microsoft.ML.Transforms.FeatureSelection;
-using System;
namespace Microsoft.ML.Samples.Static
{
diff --git a/docs/samples/Microsoft.ML.Samples/Static/FastTreeRegression.cs b/docs/samples/Microsoft.ML.Samples/Static/FastTreeRegression.cs
index 66ddc6772c..10e3edb7ea 100644
--- a/docs/samples/Microsoft.ML.Samples/Static/FastTreeRegression.cs
+++ b/docs/samples/Microsoft.ML.Samples/Static/FastTreeRegression.cs
@@ -1,8 +1,8 @@
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Trainers.FastTree;
-using Microsoft.ML.StaticPipe;
-using System;
+using System;
using System.Linq;
+using Microsoft.ML.Data;
+using Microsoft.ML.StaticPipe;
+using Microsoft.ML.Trainers.FastTree;
namespace Microsoft.ML.Samples.Static
{
@@ -30,7 +30,7 @@ public static void FastTreeRegression()
var data = reader.Read(dataFile);
// The predictor that gets produced out of training
- FastTreeRegressionPredictor pred = null;
+ FastTreeRegressionModelParameters pred = null;
// Create the estimator
var learningPipeline = reader.MakeNewEstimator()
diff --git a/docs/samples/Microsoft.ML.Samples/Static/FeatureSelectionTransform.cs b/docs/samples/Microsoft.ML.Samples/Static/FeatureSelectionTransform.cs
index d38f428eea..6c130dfc4b 100644
--- a/docs/samples/Microsoft.ML.Samples/Static/FeatureSelectionTransform.cs
+++ b/docs/samples/Microsoft.ML.Samples/Static/FeatureSelectionTransform.cs
@@ -1,8 +1,7 @@
-using Microsoft.ML.Data;
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.StaticPipe;
-using System;
+using System;
using System.Collections.Generic;
+using Microsoft.ML.Data;
+using Microsoft.ML.StaticPipe;
namespace Microsoft.ML.Samples.Dynamic
{
diff --git a/docs/samples/Microsoft.ML.Samples/Static/LightGBMBinaryClassification.cs b/docs/samples/Microsoft.ML.Samples/Static/LightGBMBinaryClassification.cs
index 53899d95fe..bee5be6c21 100644
--- a/docs/samples/Microsoft.ML.Samples/Static/LightGBMBinaryClassification.cs
+++ b/docs/samples/Microsoft.ML.Samples/Static/LightGBMBinaryClassification.cs
@@ -1,9 +1,7 @@
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.StaticPipe;
-using Microsoft.ML.Transforms;
-using Microsoft.ML.Transforms.Categorical;
-using Microsoft.ML.Transforms.FeatureSelection;
using System;
+using Microsoft.ML.Data;
+using Microsoft.ML.LightGBM.StaticPipe;
+using Microsoft.ML.StaticPipe;
namespace Microsoft.ML.Samples.Static
{
diff --git a/docs/samples/Microsoft.ML.Samples/Static/LightGBMRegression.cs b/docs/samples/Microsoft.ML.Samples/Static/LightGBMRegression.cs
index ca257d864f..45bf277157 100644
--- a/docs/samples/Microsoft.ML.Samples/Static/LightGBMRegression.cs
+++ b/docs/samples/Microsoft.ML.Samples/Static/LightGBMRegression.cs
@@ -1,7 +1,7 @@
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.LightGBM;
-using Microsoft.ML.StaticPipe;
using System;
+using Microsoft.ML.Data;
+using Microsoft.ML.LightGBM;
+using Microsoft.ML.LightGBM.StaticPipe;
namespace Microsoft.ML.Samples.Static
{
@@ -30,7 +30,7 @@ public static void LightGbmRegression()
var (trainData, testData) = mlContext.Regression.TrainTestSplit(data, testFraction: 0.1);
// The predictor that gets produced out of training
- LightGbmRegressionPredictor pred = null;
+ LightGbmRegressionModelParameters pred = null;
// Create the estimator
var learningPipeline = reader.MakeNewEstimator()
diff --git a/docs/samples/Microsoft.ML.Samples/Static/SDCABinaryClassification.cs b/docs/samples/Microsoft.ML.Samples/Static/SDCABinaryClassification.cs
index 886342c416..122057a96e 100644
--- a/docs/samples/Microsoft.ML.Samples/Static/SDCABinaryClassification.cs
+++ b/docs/samples/Microsoft.ML.Samples/Static/SDCABinaryClassification.cs
@@ -1,9 +1,6 @@
-using Microsoft.ML.Runtime.Data;
+using System;
+using Microsoft.ML.Data;
using Microsoft.ML.StaticPipe;
-using Microsoft.ML.Transforms;
-using Microsoft.ML.Transforms.Categorical;
-using Microsoft.ML.Transforms.FeatureSelection;
-using System;
namespace Microsoft.ML.Samples.Static
{
diff --git a/docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs b/docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs
index 4d35a28257..93c680d950 100644
--- a/docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs
+++ b/docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs
@@ -1,7 +1,7 @@
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.Learners;
-using Microsoft.ML.StaticPipe;
using System;
+using Microsoft.ML.Data;
+using Microsoft.ML.Learners;
+using Microsoft.ML.StaticPipe;
namespace Microsoft.ML.Samples.Static
{
@@ -29,7 +29,7 @@ public static void SdcaRegression()
var (trainData, testData) = mlContext.Regression.TrainTestSplit(data, testFraction: 0.1);
// The predictor that gets produced out of training
- LinearRegressionPredictor pred = null;
+ LinearRegressionModelParameters pred = null;
// Create the estimator
var learningPipeline = reader.MakeNewEstimator()
diff --git a/init-tools.cmd b/init-tools.cmd
index 3743cb413d..349f7b1461 100644
--- a/init-tools.cmd
+++ b/init-tools.cmd
@@ -46,15 +46,17 @@ if exist "%DotNetBuildToolsDir%" (
echo Running %0 > "%INIT_TOOLS_LOG%"
set /p DOTNET_VERSION=< "%~dp0DotnetCLIVersion.txt"
-if exist "%DOTNET_CMD%" goto :afterdotnetrestore
:Arg_Loop
if [%1] == [] goto :ArchSet
-if /i [%1] == [x86] ( set ARCH=x86&&goto ArchSet)
+if /i [%1] == [x86] ( set ARCH=x86)
+if /i [%1] == [-Debug-Intrinsics] ( set /p DOTNET_VERSION=< "%~dp0DotnetCLIVersion.netcoreapp.latest.txt")
+if /i [%1] == [-Release-Intrinsics] ( set /p DOTNET_VERSION=< "%~dp0DotnetCLIVersion.netcoreapp.latest.txt")
shift
goto :Arg_Loop
:ArchSet
+if exist "%DOTNET_CMD%" goto :afterdotnetrestore
echo Installing dotnet cli...
if NOT exist "%DOTNET_PATH%" mkdir "%DOTNET_PATH%"
diff --git a/pkg/Microsoft.ML.EntryPoints/Microsoft.ML.EntryPoints.nupkgproj b/pkg/Microsoft.ML.EntryPoints/Microsoft.ML.EntryPoints.nupkgproj
new file mode 100644
index 0000000000..b845fdeb45
--- /dev/null
+++ b/pkg/Microsoft.ML.EntryPoints/Microsoft.ML.EntryPoints.nupkgproj
@@ -0,0 +1,12 @@
+
+
+
+ netstandard2.0
+ Microsoft.ML.EntryPoints contains the ML.NET entry point API catalog.
+
+
+
+
+
+
+
diff --git a/pkg/Microsoft.ML.EntryPoints/Microsoft.ML.EntryPoints.symbols.nupkgproj b/pkg/Microsoft.ML.EntryPoints/Microsoft.ML.EntryPoints.symbols.nupkgproj
new file mode 100644
index 0000000000..3fa0255960
--- /dev/null
+++ b/pkg/Microsoft.ML.EntryPoints/Microsoft.ML.EntryPoints.symbols.nupkgproj
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/pkg/Microsoft.ML.OnnxTransform/Microsoft.ML.OnnxTransform.nupkgproj b/pkg/Microsoft.ML.OnnxTransform/Microsoft.ML.OnnxTransform.nupkgproj
index 4d64d756fe..b817e809d1 100644
--- a/pkg/Microsoft.ML.OnnxTransform/Microsoft.ML.OnnxTransform.nupkgproj
+++ b/pkg/Microsoft.ML.OnnxTransform/Microsoft.ML.OnnxTransform.nupkgproj
@@ -7,7 +7,7 @@
-
+
diff --git a/pkg/Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj b/pkg/Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj
index 595b6c1b7d..42d3387355 100644
--- a/pkg/Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj
+++ b/pkg/Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj
@@ -6,7 +6,7 @@
$(MSBuildProjectName) contains the TensorFlow C library version $(TensorFlowVersion) redistributed as a NuGet package.https://github.com/tensorflow/tensorflow/blob/master/LICENSEtrue
- Copyright 2018 The TensorFlow Authors. All rights reserved.
+ Copyright 2018 The TensorFlow Authors. All rights reserved.https://www.tensorflow.orghttps://github.com/tensorflow/tensorflow/releases/tag/v$(TensorFlowVersion)$(PackageTags) TensorFlow
diff --git a/src/Directory.Build.props b/src/Directory.Build.props
index 1ebabb9213..be6d1d53bc 100644
--- a/src/Directory.Build.props
+++ b/src/Directory.Build.props
@@ -32,4 +32,10 @@
Include="StyleCop.Analyzers" Version="1.1.0-beta008" PrivateAssets="All" />
+
+
+ stylecop.json
+
+
+
diff --git a/src/Microsoft.ML.Api/CodeGenerationUtils.cs b/src/Microsoft.ML.Api/CodeGenerationUtils.cs
deleted file mode 100644
index 1c39f1b9dd..0000000000
--- a/src/Microsoft.ML.Api/CodeGenerationUtils.cs
+++ /dev/null
@@ -1,142 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System.Collections.Generic;
-using System.Text;
-using System.Text.RegularExpressions;
-using System.CodeDom;
-using Microsoft.ML.Runtime.Data;
-using Microsoft.CSharp;
-using System.IO;
-
-namespace Microsoft.ML.Runtime.Api
-{
- ///
- /// Utility methods for code generation.
- ///
- internal static class CodeGenerationUtils
- {
- ///
- /// Replace placeholders with provided values. Assert that every placeholder is found.
- ///
- public static string MultiReplace(string text, Dictionary replacementMap)
- {
- Contracts.AssertValue(text);
- Contracts.AssertValue(replacementMap);
- var pattern = @"\/\*#(.*)#\*\/.*\/\*#\/\1#\*\/[\n\r]{0,2}";
-
- int seenTags = 0;
- var result = Regex.Replace(text, pattern,
- match =>
- {
- var tag = match.Groups[1].Value;
- string replacement;
- bool found = replacementMap.TryGetValue(tag, out replacement);
- Contracts.Assert(found);
- seenTags++;
- return replacement;
- }, RegexOptions.Singleline);
-
- Contracts.Assert(seenTags == replacementMap.Count);
- return result;
- }
-
- ///
- /// Append a field declaration to the provided .
- ///
- public static void AppendFieldDeclaration(CSharpCodeProvider codeProvider, StringBuilder target, int columnIndex,
- string fieldName, ColumnType colType, bool appendInitializer, bool useVBuffer)
- {
- Contracts.AssertValueOrNull(codeProvider);
- Contracts.AssertValue(target);
- Contracts.Assert(columnIndex >= 0);
- Contracts.AssertNonEmpty(fieldName);
- Contracts.AssertValue(colType);
-
- var attributes = new List();
- string generatedCsTypeName = GetBackingTypeName(colType, useVBuffer, attributes);
-
- if (codeProvider != null && !codeProvider.IsValidIdentifier(fieldName))
- {
- attributes.Add(string.Format("[ColumnName({0})]", GetCSharpString(codeProvider, fieldName)));
- fieldName = string.Format("Column{0}", columnIndex);
- }
-
- const string indent = " ";
- if (attributes.Count > 0)
- {
- foreach (var attr in attributes)
- {
- target.Append(indent);
- target.AppendLine(attr);
- }
- }
- target.Append(indent);
- target.AppendFormat("public {0} {1}", generatedCsTypeName, fieldName);
-
- if (appendInitializer && colType is VectorType vecColType && vecColType.Size > 0 && !useVBuffer)
- {
- Contracts.Assert(generatedCsTypeName.EndsWith("[]"));
- var csItemType = generatedCsTypeName.Substring(0, generatedCsTypeName.Length - 2);
- target.AppendFormat(" = new {0}[{1}]", csItemType, vecColType.Size);
- }
- target.AppendLine(";");
- }
-
- ///
- /// Generates a C# string for a given input (with proper escaping).
- ///
- public static string GetCSharpString(CSharpCodeProvider codeProvider, string value)
- {
- using (var writer = new StringWriter())
- {
- codeProvider.GenerateCodeFromExpression(new CodePrimitiveExpression(value), writer, null);
- return writer.ToString();
- }
- }
-
- ///
- /// Gets the C# strings representing the type name for a variable corresponding to
- /// the column type.
- ///
- /// If the type is a vector, then controls whether the array field is
- /// generated or .
- ///
- /// If additional attributes are required, they are appended to the list.
- ///
- private static string GetBackingTypeName(ColumnType colType, bool useVBuffer, List attributes)
- {
- Contracts.AssertValue(colType);
- Contracts.AssertValue(attributes);
- if (colType is VectorType vecColType)
- {
- if (vecColType.Size > 0)
- {
- // By default, arrays are assumed variable length, unless a [VectorType(dim1, dim2, ...)]
- // attribute is applied to the fields.
- attributes.Add(string.Format("[VectorType({0})]", string.Join(", ", vecColType.Dimensions)));
- }
-
- var itemType = GetBackingTypeName(colType.ItemType, false, attributes);
- return useVBuffer ? string.Format("VBuffer<{0}>", itemType) : string.Format("{0}[]", itemType);
- }
-
- if (colType.IsText)
- return "string";
-
- if (colType.IsKey)
- {
- // The way to define a key type in C# is by specifying the [KeyType] attribute and
- // making the field type equal to the underlying raw type.
- var key = colType.AsKey;
- attributes.Add(string.Format("[KeyType(Count={0}, Min={1}, Contiguous={2})]",
- key.Count,
- key.Min,
- key.Contiguous ? "true" : "false"));
- }
-
- return colType.AsPrimitive.RawType.Name;
- }
- }
-}
diff --git a/src/Microsoft.ML.Api/GenerateCodeCommand.cs b/src/Microsoft.ML.Api/GenerateCodeCommand.cs
deleted file mode 100644
index 4855a94219..0000000000
--- a/src/Microsoft.ML.Api/GenerateCodeCommand.cs
+++ /dev/null
@@ -1,152 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System.Collections.Generic;
-using System.IO;
-using System.Reflection;
-using System.Text;
-using Microsoft.CSharp;
-using Microsoft.ML.Runtime;
-using Microsoft.ML.Runtime.Api;
-using Microsoft.ML.Runtime.Command;
-using Microsoft.ML.Runtime.CommandLine;
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.Model;
-
-[assembly: LoadableClass(typeof(GenerateCodeCommand), typeof(GenerateCodeCommand.Arguments), typeof(SignatureCommand),
- "Generate Sample Prediction Code", GenerateCodeCommand.LoadName, "codegen")]
-
-namespace Microsoft.ML.Runtime.Api
-{
- ///
- /// Generates the sample prediction code for a given model file, with correct input and output classes.
- ///
- /// REVIEW: Consider adding support for generating VBuffers instead of arrays, maybe for high dimensionality vectors.
- ///
- internal sealed class GenerateCodeCommand : ICommand
- {
- public const string LoadName = "GenerateSamplePredictionCode";
- private const string CodeTemplatePath = "Microsoft.ML.Api.GeneratedCodeTemplate.csresource";
-
-#pragma warning disable 0649 // The command is internal, suppress a warning about fields never assigned to.
- public sealed class Arguments
- {
- [Argument(ArgumentType.Required, HelpText = "Input model file", ShortName = "in", IsInputFileName = true)]
- public string InputModelFile;
-
- [Argument(ArgumentType.Required, HelpText = "File to output generated C# code", ShortName = "cs")]
- public string CSharpOutput;
-
- ///
- /// Whether to use the to represent vector columns (supports sparse vectors).
- ///
- [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use the VBuffer to represent vector columns (supports sparse vectors)",
- ShortName = "sparse", SortOrder = 102)]
- public bool SparseVectorDeclaration;
-
- // REVIEW: currently, it's only used in unit testing to not generate the paths into the test output folder.
- // However, it might be handy for automation scenarios, so I've added this as a hidden option.
- [Argument(ArgumentType.AtMostOnce, HelpText = "A location of the model file to put into generated file", Hide = true)]
- public string ModelNameOverride;
- }
-#pragma warning restore 0649
-
- private readonly IHost _host;
- private readonly Arguments _args;
-
- public GenerateCodeCommand(IHostEnvironment env, Arguments args)
- {
- Contracts.CheckValue(env, nameof(env));
- _host = env.Register("GenerateCodeCommand");
- _host.CheckValue(args, nameof(args));
- _host.CheckUserArg(!string.IsNullOrWhiteSpace(args.InputModelFile),
- nameof(args.InputModelFile), "input model file is required");
- _host.CheckUserArg(!string.IsNullOrWhiteSpace(args.CSharpOutput),
- nameof(args.CSharpOutput), "Output file is required");
- _args = args;
- }
-
- public void Run()
- {
- string template;
- using (var stream = Assembly.GetExecutingAssembly().GetManifestResourceStream(CodeTemplatePath))
- using (var reader = new StreamReader(stream))
- template = reader.ReadToEnd();
-
- var codeProvider = new CSharpCodeProvider();
- using (var fs = File.OpenRead(_args.InputModelFile))
- {
- var transformPipe = ModelFileUtils.LoadPipeline(_host, fs, new MultiFileSource(null), true);
- var pred = _host.LoadPredictorOrNull(fs);
-
- IDataView root;
- for (root = transformPipe; root is IDataTransform; root = ((IDataTransform)root).Source)
- ;
-
- // root is now the loader.
- _host.Assert(root is IDataLoader);
-
- // Loader columns.
- var loaderSb = new StringBuilder();
- for (int i = 0; i < root.Schema.ColumnCount; i++)
- {
- if (root.Schema.IsHidden(i))
- continue;
- if (loaderSb.Length > 0)
- loaderSb.AppendLine();
-
- ColumnType colType = root.Schema.GetColumnType(i);
- CodeGenerationUtils.AppendFieldDeclaration(codeProvider, loaderSb, i, root.Schema.GetColumnName(i), colType, true, _args.SparseVectorDeclaration);
- }
-
- // Scored example columns.
- IDataView scorer;
- if (pred == null)
- scorer = transformPipe;
- else
- {
- var roles = ModelFileUtils.LoadRoleMappingsOrNull(_host, fs);
- scorer = roles != null
- ? _host.CreateDefaultScorer(new RoleMappedData(transformPipe, roles, opt: true), pred)
- : _host.CreateDefaultScorer(new RoleMappedData(transformPipe, label: null, "Features"), pred);
- }
-
- var nonScoreSb = new StringBuilder();
- var scoreSb = new StringBuilder();
- for (int i = 0; i < scorer.Schema.ColumnCount; i++)
- {
- if (scorer.Schema.IsHidden(i))
- continue;
- bool isScoreColumn = scorer.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.ScoreColumnSetId, i) != null;
-
- var sb = isScoreColumn ? scoreSb : nonScoreSb;
-
- if (sb.Length > 0)
- sb.AppendLine();
-
- ColumnType colType = scorer.Schema.GetColumnType(i);
- CodeGenerationUtils.AppendFieldDeclaration(codeProvider, sb, i, scorer.Schema.GetColumnName(i), colType, false, _args.SparseVectorDeclaration);
- }
-
- // Turn model path into a C# identifier and insert it.
- var modelPath = !string.IsNullOrWhiteSpace(_args.ModelNameOverride) ? _args.ModelNameOverride : _args.InputModelFile;
- modelPath = CodeGenerationUtils.GetCSharpString(codeProvider, modelPath);
- modelPath = string.Format("modelPath = {0};", modelPath);
-
- // Replace values inside the template.
- var replacementMap =
- new Dictionary
- {
- { "EXAMPLE_CLASS_DECL", loaderSb.ToString() },
- { "SCORED_EXAMPLE_CLASS_DECL", nonScoreSb.ToString() },
- { "SCORE_CLASS_DECL", scoreSb.ToString() },
- { "MODEL_PATH", modelPath }
- };
-
- var classSource = CodeGenerationUtils.MultiReplace(template, replacementMap);
- File.WriteAllText(_args.CSharpOutput, classSource);
- }
- }
- }
-}
diff --git a/src/Microsoft.ML.Api/Microsoft.ML.Api.csproj b/src/Microsoft.ML.Api/Microsoft.ML.Api.csproj
deleted file mode 100644
index 88324dca5e..0000000000
--- a/src/Microsoft.ML.Api/Microsoft.ML.Api.csproj
+++ /dev/null
@@ -1,26 +0,0 @@
-
-
-
- netstandard2.0
- Microsoft.ML
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/src/Microsoft.ML.Api/PredictionFunction.cs b/src/Microsoft.ML.Api/PredictionFunction.cs
deleted file mode 100644
index 7243301f8d..0000000000
--- a/src/Microsoft.ML.Api/PredictionFunction.cs
+++ /dev/null
@@ -1,64 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using Microsoft.ML.Core.Data;
-using Microsoft.ML.Runtime.Api;
-
-namespace Microsoft.ML.Runtime.Data
-{
- ///
- /// A prediction engine class, that takes instances of through
- /// the transformer pipeline and produces instances of as outputs.
- ///
- public sealed class PredictionFunction
- where TSrc : class
- where TDst : class, new()
- {
- private readonly PredictionEngine _engine;
-
- ///
- /// Create an instance of .
- ///
- /// The host environment.
- /// The model (transformer) to use for prediction.
- public PredictionFunction(IHostEnvironment env, ITransformer transformer)
- {
- Contracts.CheckValue(env, nameof(env));
- env.CheckValue(transformer, nameof(transformer));
-
- IDataView dv = env.CreateDataView(new TSrc[0]);
- _engine = env.CreatePredictionEngine(transformer);
- }
-
- ///
- /// Perform one prediction using the model.
- ///
- /// The object that holds values to predict from.
- /// The object populated with prediction results.
- public TDst Predict(TSrc example) => _engine.Predict(example);
-
- ///
- /// Perform one prediction using the model.
- /// Reuses the provided prediction object, which is more efficient in high-load scenarios.
- ///
- /// The object that holds values to predict from.
- /// The object to store the predictions in. If it's null, a new object is created,
- /// otherwise the provided object is used.
- public void Predict(TSrc example, ref TDst prediction) => _engine.Predict(example, ref prediction);
- }
-
- public static class PredictionFunctionExtensions
- {
- ///
- /// Create an instance of the 'prediction function', or 'prediction machine', from a model
- /// denoted by .
- /// It will be accepting instances of as input, and produce
- /// instances of as output.
- ///
- public static PredictionFunction MakePredictionFunction(this ITransformer transformer, IHostEnvironment env)
- where TSrc : class
- where TDst : class, new()
- => new PredictionFunction(env, transformer);
- }
-}
diff --git a/src/Microsoft.ML.Api/Predictor.cs b/src/Microsoft.ML.Api/Predictor.cs
deleted file mode 100644
index 137c58bfb8..0000000000
--- a/src/Microsoft.ML.Api/Predictor.cs
+++ /dev/null
@@ -1,40 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-
-namespace Microsoft.ML.Runtime.Api
-{
- ///
- /// An opaque 'holder' of the predictor, meant to insulate the user from the internal TLC predictor structure,
- /// which is subject to change.
- ///
- public sealed class Predictor
- {
- ///
- /// The actual predictor.
- ///
- internal readonly IPredictor Pred;
-
- internal Predictor(IPredictor pred)
- {
- Contracts.AssertValue(pred);
- Pred = pred;
- }
-
- ///
- /// A way for the user to extract the predictor object and 'delve into the underworld' of unsupported non-API methods.
- /// This is needed, for instance, to inspect the weights of a predictor programmatically.
- /// The intention is to expose most methods through the API and make usage of this method increasingly unnecessary.
- ///
- [Obsolete("Welcome adventurous stranger, to the Underdark! By calling the mysterious GetPredictorObject method,"+
- " you have entered a world of shifting realities, where nothing is as it seems. Your code may work today, but"+
- " the churning impermanence of the Underdark means the strong foothold today may be nothing but empty air"+
- " tomorrow. Brace yourself!")]
- public object GetPredictorObject()
- {
- return Pred;
- }
- }
-}
diff --git a/src/Microsoft.ML.Console/Console.cs b/src/Microsoft.ML.Console/Console.cs
index 152d65951a..549f222de7 100644
--- a/src/Microsoft.ML.Console/Console.cs
+++ b/src/Microsoft.ML.Console/Console.cs
@@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Runtime.Tools.Console
+namespace Microsoft.ML.Tools.Console
{
public static class Console
{
diff --git a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
index f7c87c0abd..7fadb9ea75 100644
--- a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
+++ b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
@@ -4,7 +4,7 @@
netcoreapp2.1ExeMML
- Microsoft.ML.Runtime.Tools.Console.Console
+ Microsoft.ML.Tools.Console.Console
diff --git a/src/Microsoft.ML.Core/BestFriendAttribute.cs b/src/Microsoft.ML.Core/BestFriendAttribute.cs
index 1470f95c34..19c70922e7 100644
--- a/src/Microsoft.ML.Core/BestFriendAttribute.cs
+++ b/src/Microsoft.ML.Core/BestFriendAttribute.cs
@@ -6,7 +6,7 @@
#if CPUMATH_INFRASTRUCTURE
// CpuMath has its own BestFriend and WantsToBeBestFriends attributes for making itself a standalone module
-namespace Microsoft.ML.Runtime.Internal.CpuMath.Core
+namespace Microsoft.ML.Internal.CpuMath.Core
#else
// This namespace contains the BestFriend and WantsToBeBestFriends attributes generally used in ML.NET project settings
namespace Microsoft.ML
diff --git a/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs b/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs
index 70f9ec8d98..fdc296d7f8 100644
--- a/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs
+++ b/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs
@@ -5,7 +5,7 @@
using System;
using System.Linq;
-namespace Microsoft.ML.Runtime.CommandLine
+namespace Microsoft.ML.CommandLine
{
///
/// Allows control of command line parsing.
diff --git a/src/Microsoft.ML.Core/CommandLine/ArgumentType.cs b/src/Microsoft.ML.Core/CommandLine/ArgumentType.cs
index 5840615fd5..d27acf9387 100644
--- a/src/Microsoft.ML.Core/CommandLine/ArgumentType.cs
+++ b/src/Microsoft.ML.Core/CommandLine/ArgumentType.cs
@@ -4,7 +4,7 @@
using System;
-namespace Microsoft.ML.Runtime.CommandLine
+namespace Microsoft.ML.CommandLine
{
///
/// Used to control parsing of command line arguments.
diff --git a/src/Microsoft.ML.Core/CommandLine/CharCursor.cs b/src/Microsoft.ML.Core/CommandLine/CharCursor.cs
index d8a591331c..85a60a7f36 100644
--- a/src/Microsoft.ML.Core/CommandLine/CharCursor.cs
+++ b/src/Microsoft.ML.Core/CommandLine/CharCursor.cs
@@ -2,9 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Runtime.CommandLine
+namespace Microsoft.ML.CommandLine
{
internal sealed class CharCursor
{
diff --git a/src/Microsoft.ML.Core/CommandLine/CmdLexer.cs b/src/Microsoft.ML.Core/CommandLine/CmdLexer.cs
index a5c14259a8..ce4d6ef4d3 100644
--- a/src/Microsoft.ML.Core/CommandLine/CmdLexer.cs
+++ b/src/Microsoft.ML.Core/CommandLine/CmdLexer.cs
@@ -4,7 +4,7 @@
using System.Text;
-namespace Microsoft.ML.Runtime.CommandLine
+namespace Microsoft.ML.CommandLine
{
[BestFriend]
internal sealed class CmdLexer
@@ -302,7 +302,8 @@ public static string UnquoteValue(string str)
}
}
- public sealed class CmdQuoter
+ [BestFriend]
+ internal sealed class CmdQuoter
{
private readonly string _str;
private StringBuilder _sb;
diff --git a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
index 81018dad3e..cecb73fb4b 100644
--- a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
+++ b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
@@ -9,12 +9,10 @@
using System.IO;
using System.Linq;
using System.Reflection;
-using System.Runtime.InteropServices;
using System.Text;
-using Microsoft.ML.Runtime.EntryPoints;
-using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Runtime.CommandLine
+namespace Microsoft.ML.CommandLine
{
///
@@ -537,16 +535,10 @@ private static ArgumentInfo GetArgumentInfo(Type type, object defaults)
private static ArgumentAttribute GetAttribute(FieldInfo field)
{
- var attrs = field.GetCustomAttributes(typeof(ArgumentAttribute), false).ToArray();
- if (attrs.Length == 1)
- {
- var argumentAttribute = (ArgumentAttribute)attrs[0];
- if (argumentAttribute.Visibility == ArgumentAttribute.VisibilityType.EntryPointsOnly)
- return null;
- return argumentAttribute;
- }
- Contracts.Assert(attrs.Length == 0);
- return null;
+ var argumentAttribute = field.GetCustomAttribute(false);
+ if (argumentAttribute?.Visibility == ArgumentAttribute.VisibilityType.EntryPointsOnly)
+ return null;
+ return argumentAttribute;
}
private void ReportUnrecognizedArgument(string argument)
diff --git a/src/Microsoft.ML.Core/CommandLine/DefaultArgumentAttribute.cs b/src/Microsoft.ML.Core/CommandLine/DefaultArgumentAttribute.cs
index 2d676f1ece..12121df9ea 100644
--- a/src/Microsoft.ML.Core/CommandLine/DefaultArgumentAttribute.cs
+++ b/src/Microsoft.ML.Core/CommandLine/DefaultArgumentAttribute.cs
@@ -4,7 +4,7 @@
using System;
-namespace Microsoft.ML.Runtime.CommandLine
+namespace Microsoft.ML.CommandLine
{
///
/// Indicates that this argument is the default argument.
diff --git a/src/Microsoft.ML.Core/CommandLine/EnumValueDisplayAttribute.cs b/src/Microsoft.ML.Core/CommandLine/EnumValueDisplayAttribute.cs
index b6cf4254ca..9b3652d9b4 100644
--- a/src/Microsoft.ML.Core/CommandLine/EnumValueDisplayAttribute.cs
+++ b/src/Microsoft.ML.Core/CommandLine/EnumValueDisplayAttribute.cs
@@ -4,7 +4,7 @@
using System;
-namespace Microsoft.ML.Runtime.CommandLine
+namespace Microsoft.ML.CommandLine
{
///
/// On an enum value - specifies the display name.
diff --git a/src/Microsoft.ML.Core/CommandLine/HideEnumValueAttribute.cs b/src/Microsoft.ML.Core/CommandLine/HideEnumValueAttribute.cs
index 964a5cc3f3..078a8abfa8 100644
--- a/src/Microsoft.ML.Core/CommandLine/HideEnumValueAttribute.cs
+++ b/src/Microsoft.ML.Core/CommandLine/HideEnumValueAttribute.cs
@@ -4,7 +4,7 @@
using System;
-namespace Microsoft.ML.Runtime.CommandLine
+namespace Microsoft.ML.CommandLine
{
///
/// On an enum value - indicates that the value should not be shown in help or UI.
diff --git a/src/Microsoft.ML.Core/CommandLine/SpecialPurpose.cs b/src/Microsoft.ML.Core/CommandLine/SpecialPurpose.cs
index 46423d43d9..491ef4b21c 100644
--- a/src/Microsoft.ML.Core/CommandLine/SpecialPurpose.cs
+++ b/src/Microsoft.ML.Core/CommandLine/SpecialPurpose.cs
@@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Runtime.CommandLine
+namespace Microsoft.ML.CommandLine
{
[BestFriend]
internal static class SpecialPurpose
diff --git a/src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs b/src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs
index e947776bc9..a194088d3b 100644
--- a/src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs
+++ b/src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs
@@ -2,13 +2,13 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.IO;
using System.IO.Compression;
using System.Reflection;
+using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML
{
[Obsolete("The usage for this is intended for the internal command line utilities and is not intended for anything related to the API. " +
"Please consider another way of doing whatever it is you're attempting to accomplish.")]
@@ -159,7 +159,7 @@ private static bool ShouldSkipPath(string path)
case "neuraltreeevaluator.dll":
case "optimizationbuilderdotnet.dll":
case "parallelcommunicator.dll":
- case "microsoft.ml.runtime.runtests.dll":
+                case "microsoft.ml.runtests.dll":
case "scopecompiler.dll":
case "symsgdnative.dll":
case "tbb.dll":
diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
index d0ab610a60..0b50c559a7 100644
--- a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
+++ b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
@@ -2,17 +2,17 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Runtime.CommandLine;
-using Microsoft.ML.Runtime.EntryPoints;
-using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;
+using Microsoft.ML.CommandLine;
+using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Internal.Utilities;
// REVIEW: Determine ideal namespace.
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML
{
///
/// This catalogs instantiatable components (aka, loadable classes). Components are registered via
@@ -40,7 +40,8 @@ internal ComponentCatalog()
///
/// Provides information on an instantiatable component, aka, loadable class.
///
- public sealed class LoadableClassInfo
+ [BestFriend]
+ internal sealed class LoadableClassInfo
{
///
/// Used for dictionary lookup based on signature and name.
@@ -264,7 +265,8 @@ public object CreateArguments()
///
/// A description of a single entry point.
///
- public sealed class EntryPointInfo
+ [BestFriend]
+ internal sealed class EntryPointInfo
{
public readonly string Name;
public readonly string Description;
@@ -333,7 +335,8 @@ private Type[] FindEntryPointKinds(Type type)
/// The 'component' is a non-standalone building block that is used to parametrize entry points or other ML.NET components.
/// For example, 'Loss function', or 'similarity calculator' could be components.
///
- public sealed class ComponentInfo
+ [BestFriend]
+ internal sealed class ComponentInfo
{
public readonly string Name;
public readonly string Description;
@@ -616,7 +619,8 @@ public void RegisterAssembly(Assembly assembly, bool throwOnError = true)
/// Return an array containing information for all instantiatable components.
/// If provided, the given set of assemblies is loaded first.
///
- public LoadableClassInfo[] GetAllClasses()
+ [BestFriend]
+ internal LoadableClassInfo[] GetAllClasses()
{
return _classes.ToArray();
}
@@ -625,7 +629,8 @@ public LoadableClassInfo[] GetAllClasses()
/// Return an array containing information for instantiatable components with the given
/// signature and base type. If provided, the given set of assemblies is loaded first.
///
- public LoadableClassInfo[] GetAllDerivedClasses(Type typeBase, Type typeSig)
+ [BestFriend]
+ internal LoadableClassInfo[] GetAllDerivedClasses(Type typeBase, Type typeSig)
{
Contracts.CheckValue(typeBase, nameof(typeBase));
Contracts.CheckValueOrNull(typeSig);
@@ -643,7 +648,8 @@ public LoadableClassInfo[] GetAllDerivedClasses(Type typeBase, Type typeSig)
/// Return an array containing all the known signature types. If provided, the given set of assemblies
/// is loaded first.
///
- public Type[] GetAllSignatureTypes()
+ [BestFriend]
+ internal Type[] GetAllSignatureTypes()
{
return _signatures.Select(kvp => kvp.Key).ToArray();
}
@@ -651,7 +657,8 @@ public Type[] GetAllSignatureTypes()
///
/// Returns a string name for a given signature type.
///
- public static string SignatureToString(Type sig)
+ [BestFriend]
+ internal static string SignatureToString(Type sig)
{
Contracts.CheckValue(sig, nameof(sig));
Contracts.CheckParam(sig.BaseType == typeof(MulticastDelegate), nameof(sig), "Must be a delegate type");
@@ -670,7 +677,8 @@ private LoadableClassInfo FindClassCore(LoadableClassInfo.Key key)
return null;
}
- public LoadableClassInfo[] FindLoadableClasses(string name)
+ [BestFriend]
+ internal LoadableClassInfo[] FindLoadableClasses(string name)
{
name = name.ToLowerInvariant().Trim();
@@ -680,14 +688,16 @@ public LoadableClassInfo[] FindLoadableClasses(string name)
return res;
}
- public LoadableClassInfo[] FindLoadableClasses()
+ [BestFriend]
+ internal LoadableClassInfo[] FindLoadableClasses()
{
return _classes
.Where(ci => ci.SignatureTypes.Contains(typeof(TSig)))
.ToArray();
}
- public LoadableClassInfo[] FindLoadableClasses()
+ [BestFriend]
+ internal LoadableClassInfo[] FindLoadableClasses()
{
// REVIEW: this and above methods perform a linear search over all the loadable classes.
// On 6/15/2015, TLC release build contained 431 of them, so adding extra lookups looks unnecessary at this time.
@@ -696,12 +706,14 @@ public LoadableClassInfo[] FindLoadableClasses()
.ToArray();
}
- public LoadableClassInfo GetLoadableClassInfo(string loadName)
+ [BestFriend]
+ internal LoadableClassInfo GetLoadableClassInfo(string loadName)
{
return GetLoadableClassInfo(loadName, typeof(TSig));
}
- public LoadableClassInfo GetLoadableClassInfo(string loadName, Type signatureType)
+ [BestFriend]
+ internal LoadableClassInfo GetLoadableClassInfo(string loadName, Type signatureType)
{
Contracts.CheckParam(signatureType.BaseType == typeof(MulticastDelegate), nameof(signatureType), "signatureType must be a delegate type");
Contracts.CheckValueOrNull(loadName);
@@ -712,18 +724,21 @@ public LoadableClassInfo GetLoadableClassInfo(string loadName, Type signatureTyp
///
/// Get all registered entry points.
///
- public IEnumerable AllEntryPoints()
+ [BestFriend]
+ internal IEnumerable AllEntryPoints()
{
return _entryPoints.AsEnumerable();
}
- public bool TryFindEntryPoint(string name, out EntryPointInfo entryPoint)
+ [BestFriend]
+ internal bool TryFindEntryPoint(string name, out EntryPointInfo entryPoint)
{
Contracts.CheckNonEmpty(name, nameof(name));
return _entryPointMap.TryGetValue(name, out entryPoint);
}
- public bool TryFindComponent(string kind, string alias, out ComponentInfo component)
+ [BestFriend]
+ internal bool TryFindComponent(string kind, string alias, out ComponentInfo component)
{
Contracts.CheckNonEmpty(kind, nameof(kind));
Contracts.CheckNonEmpty(alias, nameof(alias));
@@ -733,7 +748,8 @@ public bool TryFindComponent(string kind, string alias, out ComponentInfo compon
return _componentMap.TryGetValue($"{kind}:{alias}", out component);
}
- public bool TryFindComponent(Type argumentType, out ComponentInfo component)
+ [BestFriend]
+ internal bool TryFindComponent(Type argumentType, out ComponentInfo component)
{
Contracts.CheckValue(argumentType, nameof(argumentType));
@@ -741,7 +757,8 @@ public bool TryFindComponent(Type argumentType, out ComponentInfo component)
return component != null;
}
- public bool TryFindComponent(Type interfaceType, Type argumentType, out ComponentInfo component)
+ [BestFriend]
+ internal bool TryFindComponent(Type interfaceType, Type argumentType, out ComponentInfo component)
{
Contracts.CheckValue(interfaceType, nameof(interfaceType));
Contracts.CheckParam(interfaceType.IsInterface, nameof(interfaceType), "Must be interface");
@@ -751,7 +768,8 @@ public bool TryFindComponent(Type interfaceType, Type argumentType, out Componen
return component != null;
}
- public bool TryFindComponent(Type interfaceType, string alias, out ComponentInfo component)
+ [BestFriend]
+ internal bool TryFindComponent(Type interfaceType, string alias, out ComponentInfo component)
{
Contracts.CheckValue(interfaceType, nameof(interfaceType));
Contracts.CheckParam(interfaceType.IsInterface, nameof(interfaceType), "Must be interface");
@@ -764,7 +782,8 @@ public bool TryFindComponent(Type interfaceType, string alias, out ComponentInfo
/// Akin to , except if the regular (case sensitive) comparison fails, it will
/// attempt to back off to a case-insensitive comparison.
///
- public bool TryFindComponentCaseInsensitive(Type interfaceType, string alias, out ComponentInfo component)
+ [BestFriend]
+ internal bool TryFindComponentCaseInsensitive(Type interfaceType, string alias, out ComponentInfo component)
{
Contracts.CheckValue(interfaceType, nameof(interfaceType));
Contracts.CheckParam(interfaceType.IsInterface, nameof(interfaceType), "Must be interface");
@@ -786,7 +805,8 @@ private static bool AnyMatch(string name, string[] aliases)
///
/// Returns all valid component kinds.
///
- public IEnumerable GetAllComponentKinds()
+ [BestFriend]
+ internal IEnumerable GetAllComponentKinds()
{
return _components.Select(x => x.Kind).Distinct().OrderBy(x => x);
}
@@ -794,7 +814,8 @@ public IEnumerable GetAllComponentKinds()
///
/// Returns all components of the specified kind.
///
- public IEnumerable GetAllComponents(string kind)
+ [BestFriend]
+ internal IEnumerable GetAllComponents(string kind)
{
Contracts.CheckNonEmpty(kind, nameof(kind));
Contracts.CheckParam(IsValidName(kind), nameof(kind), "Invalid component kind");
@@ -804,13 +825,15 @@ public IEnumerable GetAllComponents(string kind)
///
/// Returns all components that implement the specified interface.
///
- public IEnumerable GetAllComponents(Type interfaceType)
+ [BestFriend]
+ internal IEnumerable GetAllComponents(Type interfaceType)
{
Contracts.CheckValue(interfaceType, nameof(interfaceType));
return _components.Where(x => x.InterfaceType == interfaceType).OrderBy(x => x.Name);
}
- public bool TryGetComponentKind(Type signatureType, out string kind)
+ [BestFriend]
+ internal bool TryGetComponentKind(Type signatureType, out string kind)
{
Contracts.CheckValue(signatureType, nameof(signatureType));
// REVIEW: replace with a dictionary lookup.
@@ -821,7 +844,8 @@ public bool TryGetComponentKind(Type signatureType, out string kind)
return faceAttr != null;
}
- public bool TryGetComponentShortName(Type type, out string name)
+ [BestFriend]
+ internal bool TryGetComponentShortName(Type type, out string name)
{
ComponentInfo component;
if (!TryFindComponent(type, out component))
@@ -850,7 +874,8 @@ private static bool IsValidName(string name)
///
/// Create an instance of the indicated component with the given extra parameters.
///
- public static TRes CreateInstance(IHostEnvironment env, Type signatureType, string name, string options, params object[] extra)
+ [BestFriend]
+ internal static TRes CreateInstance(IHostEnvironment env, Type signatureType, string name, string options, params object[] extra)
where TRes : class
{
TRes result;
@@ -863,13 +888,15 @@ public static TRes CreateInstance(IHostEnvironment env, Type signatureType
/// Try to create an instance of the indicated component and settings with the given extra parameters.
/// If there is no such component in the catalog, returns false. Any other error results in an exception.
///
- public static bool TryCreateInstance(IHostEnvironment env, out TRes result, string name, string options, params object[] extra)
+ [BestFriend]
+ internal static bool TryCreateInstance(IHostEnvironment env, out TRes result, string name, string options, params object[] extra)
where TRes : class
{
return TryCreateInstance(env, typeof(TSig), out result, name, options, extra);
}
- private static bool TryCreateInstance(IHostEnvironment env, Type signatureType, out TRes result, string name, string options, params object[] extra)
+ [BestFriend]
+ internal static bool TryCreateInstance(IHostEnvironment env, Type signatureType, out TRes result, string name, string options, params object[] extra)
where TRes : class
{
Contracts.CheckValue(env, nameof(env));
diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentFactory.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentFactory.cs
index 58ea3c6baf..93ebdf6397 100644
--- a/src/Microsoft.ML.Core/ComponentModel/ComponentFactory.cs
+++ b/src/Microsoft.ML.Core/ComponentModel/ComponentFactory.cs
@@ -4,7 +4,7 @@
using System;
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML
{
///
/// This is a token interface that all component factories must implement.
@@ -48,7 +48,8 @@ public interface IComponentFactory
///
/// A utility class for creating instances.
///
- public static class ComponentFactoryUtils
+ [BestFriend]
+ internal static class ComponentFactoryUtils
{
///
/// Creates a component factory with no extra parameters (other than an )
diff --git a/src/Microsoft.ML.Core/ComponentModel/LoadableClassAttribute.cs b/src/Microsoft.ML.Core/ComponentModel/LoadableClassAttribute.cs
index 328a48be2b..7e8ea83e73 100644
--- a/src/Microsoft.ML.Core/ComponentModel/LoadableClassAttribute.cs
+++ b/src/Microsoft.ML.Core/ComponentModel/LoadableClassAttribute.cs
@@ -5,17 +5,19 @@
using System;
using System.Linq;
using System.Reflection;
-using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML
{
///
/// Common signature type with no extra parameters.
///
- public delegate void SignatureDefault();
+ [BestFriend]
+ internal delegate void SignatureDefault();
[AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
- public sealed class LoadableClassAttribute : LoadableClassAttributeBase
+ [BestFriend]
+ internal sealed class LoadableClassAttribute : LoadableClassAttributeBase
{
///
/// Assembly attribute used to specify that a class is loadable by a machine learning
@@ -98,7 +100,7 @@ public LoadableClassAttribute(string summary, Type instType, Type loaderType, Ty
}
}
- public abstract class LoadableClassAttributeBase : Attribute
+ internal abstract class LoadableClassAttributeBase : Attribute
{
// Note: these properties have private setters to make attribute parsing easier - the values
// are all guaranteed to be in the ConstructorArguments of the CustomAttributeData
diff --git a/src/Microsoft.ML.Core/Data/ColumnType.cs b/src/Microsoft.ML.Core/Data/ColumnType.cs
index 83b5caa27d..eb3d080ecd 100644
--- a/src/Microsoft.ML.Core/Data/ColumnType.cs
+++ b/src/Microsoft.ML.Core/Data/ColumnType.cs
@@ -7,12 +7,11 @@
using System;
using System.Collections.Immutable;
using System.Linq;
-using System.Reflection;
using System.Text;
using System.Threading;
-using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
///
/// This is the abstract base class for all types in the type system.
@@ -80,12 +79,6 @@ private protected ColumnType(Type rawType, DataKind rawKind)
[BestFriend]
internal bool IsPrimitive { get; }
- ///
- /// Equivalent to as .
- ///
- [BestFriend]
- internal PrimitiveType AsPrimitive => IsPrimitive ? (PrimitiveType)this : null;
-
///
/// Whether this type is a standard numeric type. External code should use is .
///
@@ -140,12 +133,6 @@ internal bool IsBool
[BestFriend]
internal bool IsKey { get; }
- ///
- /// Equivalent to as .
- ///
- [BestFriend]
- internal KeyType AsKey => IsKey ? (KeyType)this : null;
-
///
/// Zero return means either it's not a key type or the cardinality is unknown. External code should first
/// test whether this is of type , then if so get the property
@@ -166,12 +153,6 @@ internal bool IsBool
[BestFriend]
internal bool IsVector { get; }
- ///
- /// Equivalent to as .
- ///
- [BestFriend]
- internal VectorType AsVector => IsVector ? (VectorType)this : null;
-
///
/// For non-vector types, this returns the column type itself (i.e., return this).
///
@@ -766,8 +747,7 @@ public override bool Equals(ColumnType other)
if (other == this)
return true;
- var tmp = other as KeyType;
- if (tmp == null)
+ if (!(other is KeyType tmp))
return false;
if (RawKind != tmp.RawKind)
return false;
diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs
index ad8d8fbfe0..db5a75326e 100644
--- a/src/Microsoft.ML.Core/Data/DataKind.cs
+++ b/src/Microsoft.ML.Core/Data/DataKind.cs
@@ -3,9 +3,8 @@
// See the LICENSE file in the project root for more information.
using System;
-using System.Text;
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
///
/// Data type specifier.
@@ -52,7 +51,8 @@ public enum DataKind : byte
///
/// Extension methods related to the DataKind enum.
///
- public static class DataKindExtensions
+ [BestFriend]
+ internal static class DataKindExtensions
{
public const DataKind KindMin = DataKind.I1;
public const DataKind KindLim = DataKind.U16 + 1;
@@ -171,7 +171,7 @@ public static Type ToType(this DataKind kind)
case DataKind.DZ:
return typeof(DateTimeOffset);
case DataKind.UG:
- return typeof(UInt128);
+ return typeof(RowId);
}
return null;
@@ -215,7 +215,7 @@ public static bool TryGetDataKind(this Type type, out DataKind kind)
kind = DataKind.DT;
else if (type == typeof(DateTimeOffset))
kind = DataKind.DZ;
- else if (type == typeof(UInt128))
+ else if (type == typeof(RowId))
kind = DataKind.UG;
else
{
diff --git a/src/Microsoft.ML.Core/Data/ICommand.cs b/src/Microsoft.ML.Core/Data/ICommand.cs
index 44d4c7340b..4c68218436 100644
--- a/src/Microsoft.ML.Core/Data/ICommand.cs
+++ b/src/Microsoft.ML.Core/Data/ICommand.cs
@@ -2,11 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
-using Microsoft.ML.Runtime.Internal.Utilities;
-using Microsoft.ML.Runtime.Data;
-
-namespace Microsoft.ML.Runtime.Command
+namespace Microsoft.ML.Command
{
///
/// The signature for commands.
diff --git a/src/Microsoft.ML.Core/Data/ICursor.cs b/src/Microsoft.ML.Core/Data/ICursor.cs
deleted file mode 100644
index 476f07245c..0000000000
--- a/src/Microsoft.ML.Core/Data/ICursor.cs
+++ /dev/null
@@ -1,106 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using Float = System.Single;
-
-using System;
-
-namespace Microsoft.ML.Runtime.Data
-{
- ///
- /// This is a base interface for an and . It contains only the
- /// positional properties, no behavioral methods, and no data.
- ///
- public interface ICounted
- {
- ///
- /// This is incremented for ICursor when the underlying contents changes, giving clients a way to detect change.
- /// Generally it's -1 when the object is in an invalid state. In particular, for an , this is -1
- /// when the is or .
- ///
- /// Note that this position is not position within the underlying data, but position of this cursor only.
- /// If one, for example, opened a set of parallel streaming cursors, or a shuffled cursor, each such cursor's
- /// first valid entry would always have position 0.
- ///
- long Position { get; }
-
- ///
- /// This provides a means for reconciling multiple streams of counted things. Generally, in each stream,
- /// batch numbers should be non-decreasing. Furthermore, any given batch number should only appear in one
- /// of the streams. Order is determined by batch number. The reconciler ensures that each stream (that is
- /// still active) has at least one item available, then takes the item with the smallest batch number.
- ///
- /// Note that there is no suggestion that the batches for a particular entry will be consistent from
- /// cursoring to cursoring, except for the consistency in resulting in the same overall ordering. The same
- /// entry could have different batch numbers from one cursoring to another. There is also no requirement
- /// that any given batch number must appear, at all.
- ///
- long Batch { get; }
-
- ///
- /// A getter for a 128-bit ID value. It is common for objects to serve multiple
- /// instances to iterate over what is supposed to be the same data, for example, in a
- /// a cursor set will produce the same data as a serial cursor, just partitioned, and a shuffled cursor
- /// will produce the same data as a serial cursor or any other shuffled cursor, only shuffled. The ID
- /// exists for applications that need to reconcile which entry is actually which. Ideally this ID should
- /// be unique, but for practical reasons, it suffices if collisions are simply extremely improbable.
- ///
- /// Note that this ID, while it must be consistent for multiple streams according to the semantics
- /// above, is not considered part of the data per se. So, to take the example of a data view specifically,
- /// a single data view must render consistent IDs across all cursorings, but there is no suggestion at
- /// all that if the "same" data were presented in a different data view (as by, say, being transformed,
- /// cached, saved, or whatever), that the IDs between the two different data views would have any
- /// discernable relationship.
- ValueGetter GetIdGetter();
- }
-
- ///
- /// Defines the possible states of a cursor.
- ///
- public enum CursorState
- {
- NotStarted,
- Good,
- Done
- }
-
- ///
- /// The basic cursor interface. is incremented by
- /// and . When the cursor state is or
- /// , is -1. Otherwise,
- /// >= 0.
- ///
- public interface ICursor : ICounted, IDisposable
- {
- ///
- /// Returns the state of the cursor. Before the first call to or
- /// this should be . After
- /// any call those move functions that returns true, this should return
- /// ,
- ///
- CursorState State { get; }
-
- ///
- /// Advance to the next row. When the cursor is first created, this method should be called to
- /// move to the first row. Returns false if there are no more rows.
- ///
- bool MoveNext();
-
- ///
- /// Logically equivalent to calling the given number of times. The
- /// parameter must be positive. Note that cursor implementations may be
- /// able to optimize this.
- ///
- bool MoveMany(long count);
-
- ///
- /// Returns a cursor that can be used for invoking , ,
- /// , and , with results identical to calling those
- /// on this cursor. Generally, if the root cursor is not the same as this cursor, using the
- /// root cursor will be faster. As an aside, note that this is not necessarily the case of
- /// values from .
- ///
- ICursor GetRootCursor();
- }
-}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Core/Data/IDataView.cs b/src/Microsoft.ML.Core/Data/IDataView.cs
index cb0da9a194..893b17a26a 100644
--- a/src/Microsoft.ML.Core/Data/IDataView.cs
+++ b/src/Microsoft.ML.Core/Data/IDataView.cs
@@ -2,17 +2,17 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Data;
using System;
using System.Collections.Generic;
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
///
/// Legacy interface for schema information.
/// Please avoid implementing this interface, use .
///
- public interface ISchema
+ [BestFriend]
+ internal interface ISchema
{
///
/// Number of columns.
@@ -60,22 +60,11 @@ public interface ISchema
void GetMetadata(string kind, int col, ref TValue value);
}
- ///
- /// Base interface for schematized information. IDataView and IRowCursor both derive from this.
- ///
- public interface ISchematized
- {
- ///
- /// Gets an instance of Schema.
- ///
- Schema Schema { get; }
- }
-
///
/// The input and output of Query Operators (Transforms). This is the fundamental data pipeline
- /// type, comparable to IEnumerable for LINQ.
+ /// type, comparable to for LINQ.
///
- public interface IDataView : ISchematized
+ public interface IDataView
{
///
/// Whether this IDataView supports shuffling of rows, to any degree.
@@ -99,75 +88,187 @@ public interface IDataView : ISchematized
/// a getter for an inactive columns will throw. The predicate must be
/// non-null. To activate all columns, pass "col => true".
///
- IRowCursor GetRowCursor(Func needCol, IRandom rand = null);
+ RowCursor GetRowCursor(Func needCol, Random rand = null);
///
- /// This constructs a set of parallel batch cursors. The value n is a recommended limit
- /// on cardinality. If is non-positive, this indicates that the caller
- /// has no recommendation, and the implementation should have some default behavior to cover
- /// this case. Note that this is strictly a recommendation: it is entirely possible that
- /// an implementation can return a different number of cursors.
+ /// This constructs a set of parallel batch cursors. The value is a recommended limit on
+ /// cardinality. If is non-positive, this indicates that the caller has no recommendation,
+ /// and the implementation should have some default behavior to cover this case. Note that this is strictly a
+ /// recommendation: it is entirely possible that an implementation can return a different number of cursors.
///
/// The cursors should return the same data as returned through
- /// , except partitioned: no two cursors
- /// should return the "same" row as would have been returned through the regular serial cursor,
- /// but all rows should be returned by exactly one of the cursors returned from this cursor.
- /// The cursors can have their values reconciled downstream through the use of the
- /// property.
- ///
- /// This is an object that can be used to reconcile the
- /// returned array of cursors. When the array of cursors is of length 1, it is legal,
- /// indeed expected, that this parameter should be null.
+ /// , except partitioned: no two cursors should return the
+ /// "same" row as would have been returned through the regular serial cursor, but all rows should be returned by
+ /// exactly one of the cursors returned from this cursor. The cursors can have their values reconciled
+ /// downstream through the use of the property.
+ ///
+ /// The typical usage pattern is that a set of cursors is requested, each of them is then given to a set of
+ /// working threads that consume from them independently while, ultimately, the results are finally collated in
+ /// the end by exploiting the ordering of the property described above. More typical
+ /// scenarios will be content with pulling from the single serial cursor of
+ /// .
+ ///
/// The predicate, where a column is active if this returns true.
/// The suggested degree of parallelism.
/// An instance
///
- IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator,
- Func needCol, int n, IRandom rand = null);
- }
+ RowCursor[] GetRowCursorSet(Func needCol, int n, Random rand = null);
- ///
- /// This is used to consolidate parallel cursors into a single cursor. The object that determines
- /// the number of cursors and splits the row "stream" provides the consolidator object.
- ///
- public interface IRowCursorConsolidator
- {
///
- /// Create a consolidated cursor from the given parallel cursor set.
+ /// Gets an instance of Schema.
///
- IRowCursor CreateCursor(IChannelProvider provider, IRowCursor[] inputs);
+ Schema Schema { get; }
}
///
- /// Delegate type to get a value. This can used for efficient access to data in an IRow
- /// or IRowCursor.
+ /// Delegate type to get a value. This can be used for efficient access to data in a
+ /// or .
///
public delegate void ValueGetter(ref TValue value);
///
- /// A logical row. May be a row of an IDataView or a stand-alone row. If/when its contents
- /// change, its ICounted.Counter value is incremented.
+ /// A logical row. May be a row of an or a stand-alone row. If/when its contents
+ /// change, its value is changed.
///
- public interface IRow : ISchematized, ICounted
+ public abstract class Row : IDisposable
{
+ ///
+ /// This is incremented when the underlying contents changes, giving clients a way to detect change. Generally
+ /// it's -1 when the object is in an invalid state. In particular, for an , this is -1
+ /// when the is or .
+ ///
+ /// Note that this position is not position within the underlying data, but position of this cursor only. If
+ /// one, for example, opened a set of parallel streaming cursors, or a shuffled cursor, each such cursor's first
+ /// valid entry would always have position 0.
+ ///
+ public abstract long Position { get; }
+
+ ///
+ /// This provides a means for reconciling multiple rows that have been produced generally from
+ /// . When getting a set, there is a need
+ /// to, while allowing parallel processing to proceed, always have an aim that the original order should be
+ /// recoverable. Note, whether or not a user cares about that original order in one's specific application is
+ /// another story altogether (most callers of this as a practical matter do not, otherwise they would not call
+ /// it), but at least in principle it should be possible to reconstruct the original order one would get from an
+ /// identically configured . So: for any cursor
+ /// implementation, batch numbers should be non-decreasing. Furthermore, any given batch number should only
+ /// appear in one of the cursors as returned by
+ /// . In this way, order is determined by
+ /// batch number. An operation that reconciles these cursors to produce a consistent single cursoring, could do
+ /// so by drawing from the single cursor, among all cursors in the set, that has the smallest batch number
+ /// available.
+ ///
+ /// Note that there is no suggestion that the batches for a particular entry will be consistent from cursoring
+ /// to cursoring, except for the consistency in resulting in the same overall ordering. The same entry could
+ /// have different batch numbers from one cursoring to another. There is also no requirement that any given
+ /// batch number must appear, at all. It is merely a mechanism for recovering ordering from a possibly arbitrary
+ /// partitioning of the data. It also follows from this, of course, that considering the batch to be a property
+ /// of the data is completely invalid.
+ ///
+ public abstract long Batch { get; }
+
+ ///
+ /// A getter for a 128-bit ID value. It is common for objects to serve multiple
+ /// instances to iterate over what is supposed to be the same data, for example, in a
+ /// a cursor set will produce the same data as a serial cursor, just partitioned, and a shuffled cursor will
+ /// produce the same data as a serial cursor or any other shuffled cursor, only shuffled. The ID exists for
+ /// applications that need to reconcile which entry is actually which. Ideally this ID should be unique, but for
+ /// practical reasons, it suffices if collisions are simply extremely improbable.
+ ///
+ /// Note that this ID, while it must be consistent for multiple streams according to the semantics above, is not
+ /// considered part of the data per se. So, to take the example of a data view specifically, a single data view
+ /// must render consistent IDs across all cursorings, but there is no suggestion at all that if the "same" data
+ /// were presented in a different data view (as by, say, being transformed, cached, saved, or whatever), that
+ /// the IDs between the two different data views would have any discernable relationship.
+ public abstract ValueGetter GetIdGetter();
+
///
/// Returns whether the given column is active in this row.
///
- bool IsColumnActive(int col);
+ public abstract bool IsColumnActive(int col);
///
/// Returns a value getter delegate to fetch the given column value from the row.
/// This throws if the column is not active in this row, or if the type
/// differs from this column's type.
///
- ValueGetter GetGetter(int col);
+ public abstract ValueGetter GetGetter(int col);
+
+ ///
+ /// Gets a , which provides name and type information for variables
+ /// (i.e., columns in ML.NET's type system) stored in this row.
+ ///
+ public abstract Schema Schema { get; }
+
+ ///
+ /// Implementation of dispose. Calls with .
+ ///
+ public void Dispose()
+ {
+ Dispose(true);
+ GC.SuppressFinalize(this);
+ }
+
+ ///
+ /// The disposable method for the disposable pattern. This default implementation does nothing.
+ ///
+ /// Whether this was called from .
+ /// Subclasses that implement should call this method with
+ /// , but I hasten to add that implementing finalizers should be
+ /// avoided if at all possible.
+ protected virtual void Dispose(bool disposing)
+ {
+ }
+ }
+
+ ///
+ /// Defines the possible states of a cursor.
+ ///
+ public enum CursorState
+ {
+ NotStarted,
+ Good,
+ Done
}
///
- /// A cursor through rows of an . Note that this includes/is an
- /// , as well as an .
+ /// The basic cursor base class to cursor through rows of an . Note that
+ /// this is also an . The is incremented by
+ /// and . When the cursor state is or
+ /// , is -1. Otherwise,
+ /// >= 0.
///
- public interface IRowCursor : ICursor, IRow
+ public abstract class RowCursor : Row
{
+ ///
+ /// Returns the state of the cursor. Before the first call to or
+ /// this should be . After
+ /// any call to those move functions that returns , this should return
+ /// ,
+ ///
+ public abstract CursorState State { get; }
+
+ ///
+ /// Advance to the next row. When the cursor is first created, this method should be called to
+ /// move to the first row. Returns false if there are no more rows.
+ ///
+ public abstract bool MoveNext();
+
+ ///
+ /// Logically equivalent to calling the given number of times. The
+ /// parameter must be positive. Note that cursor implementations may be
+ /// able to optimize this.
+ ///
+ public abstract bool MoveMany(long count);
+
+ ///
+ /// Returns a cursor that can be used for invoking , ,
+ /// , and , with results identical to calling those
+ /// on this cursor. Generally, if the root cursor is not the same as this cursor, using the
+ /// root cursor will be faster. As an aside, note that this is not necessarily the case of
+ /// values from .
+ ///
+ public abstract RowCursor GetRootCursor();
}
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs
index b82ad4f5a7..b0300795ec 100644
--- a/src/Microsoft.ML.Core/Data/IEstimator.cs
+++ b/src/Microsoft.ML.Core/Data/IEstimator.cs
@@ -2,27 +2,29 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Data;
-using Microsoft.ML.Runtime;
-using Microsoft.ML.Runtime.Data;
-using System;
+using System.Collections;
using System.Collections.Generic;
using System.Linq;
+using Microsoft.ML.Data;
namespace Microsoft.ML.Core.Data
{
///
/// A set of 'requirements' to the incoming schema, as well as a set of 'promises' of the outgoing schema.
- /// This is more relaxed than the proper , since it's only a subset of the columns,
+ /// This is more relaxed than the proper , since it's only a subset of the columns,
/// and also since it doesn't specify exact 's for vectors and keys.
///
- public sealed class SchemaShape
+ public sealed class SchemaShape : IReadOnlyList
{
- public readonly Column[] Columns;
+ private readonly Column[] _columns;
private static readonly SchemaShape _empty = new SchemaShape(Enumerable.Empty());
- public sealed class Column
+ public int Count => _columns.Count();
+
+ public Column this[int index] => _columns[index];
+
+ public struct Column
{
public enum VectorKind
{
@@ -55,13 +57,13 @@ public enum VectorKind
///
public readonly SchemaShape Metadata;
- public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, SchemaShape metadata = null)
+ [BestFriend]
+ internal Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, SchemaShape metadata = null)
{
Contracts.CheckNonEmpty(name, nameof(name));
Contracts.CheckValueOrNull(metadata);
Contracts.CheckParam(!itemType.IsKey, nameof(itemType), "Item type cannot be a key");
Contracts.CheckParam(!itemType.IsVector, nameof(itemType), "Item type cannot be a vector");
-
Contracts.CheckParam(!isKey || KeyType.IsValidDataKind(itemType.RawKind), nameof(itemType), "The item type must be valid for a key");
Name = name;
@@ -80,9 +82,10 @@ public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey,
/// - The columns of of is a superset of our columns.
/// - Each such metadata column is itself compatible with the input metadata column.
///
- public bool IsCompatibleWith(Column inputColumn)
+ [BestFriend]
+ internal bool IsCompatibleWith(Column inputColumn)
{
- Contracts.CheckValue(inputColumn, nameof(inputColumn));
+ Contracts.Check(inputColumn.IsValid, nameof(inputColumn));
if (Name != inputColumn.Name)
return false;
if (Kind != inputColumn.Kind)
@@ -91,7 +94,7 @@ public bool IsCompatibleWith(Column inputColumn)
return false;
if (IsKey != inputColumn.IsKey)
return false;
- foreach (var metaCol in Metadata.Columns)
+ foreach (var metaCol in Metadata)
{
if (!inputColumn.Metadata.TryFindColumn(metaCol.Name, out var inputMetaCol))
return false;
@@ -101,7 +104,8 @@ public bool IsCompatibleWith(Column inputColumn)
return true;
}
- public string GetTypeString()
+ [BestFriend]
+ internal string GetTypeString()
{
string result = ItemType.ToString();
if (IsKey)
@@ -112,13 +116,20 @@ public string GetTypeString()
result = $"VarVector<{result}>";
return result;
}
+
+ ///
+ /// Return if this structure is not identical to the default value of . If true,
+ /// it means this structure is initialized properly and therefore considered as valid.
+ ///
+ [BestFriend]
+ internal bool IsValid => Name != null;
}
public SchemaShape(IEnumerable columns)
{
Contracts.CheckValue(columns, nameof(columns));
- Columns = columns.ToArray();
- Contracts.CheckParam(columns.All(c => c != null), nameof(columns), "No items should be null.");
+ _columns = columns.ToArray();
+ Contracts.CheckParam(columns.All(c => c.IsValid), nameof(columns), "Some items are not initialized properly.");
}
///
@@ -151,26 +162,27 @@ internal static void GetColumnTypeShape(ColumnType type,
///
/// Create a schema shape out of the fully defined schema.
///
- public static SchemaShape Create(Schema schema)
+ [BestFriend]
+ internal static SchemaShape Create(Schema schema)
{
Contracts.CheckValue(schema, nameof(schema));
var cols = new List();
for (int iCol = 0; iCol < schema.Count; iCol++)
{
- if (!schema.IsHidden(iCol))
+ if (!schema[iCol].IsHidden)
{
// First create the metadata.
var mCols = new List();
- foreach (var metaNameType in schema.GetMetadataTypes(iCol))
+ foreach (var metaColumn in schema[iCol].Metadata.Schema)
{
- GetColumnTypeShape(metaNameType.Value, out var mVecKind, out var mItemType, out var mIsKey);
- mCols.Add(new Column(metaNameType.Key, mVecKind, mItemType, mIsKey));
+ GetColumnTypeShape(metaColumn.Type, out var mVecKind, out var mItemType, out var mIsKey);
+ mCols.Add(new Column(metaColumn.Name, mVecKind, mItemType, mIsKey));
}
var metadata = mCols.Count > 0 ? new SchemaShape(mCols) : _empty;
// Next create the single column.
- GetColumnTypeShape(schema.GetColumnType(iCol), out var vecKind, out var itemType, out var isKey);
- cols.Add(new Column(schema.GetColumnName(iCol), vecKind, itemType, isKey, metadata));
+ GetColumnTypeShape(schema[iCol].Type, out var vecKind, out var itemType, out var isKey);
+ cols.Add(new Column(schema[iCol].Name, vecKind, itemType, isKey, metadata));
}
}
return new SchemaShape(cols);
@@ -179,25 +191,23 @@ public static SchemaShape Create(Schema schema)
///
/// Returns if there is a column with a specified and if so stores it in .
///
- public bool TryFindColumn(string name, out Column column)
+ [BestFriend]
+ internal bool TryFindColumn(string name, out Column column)
{
Contracts.CheckValue(name, nameof(name));
- column = Columns.FirstOrDefault(x => x.Name == name);
- return column != null;
+ column = _columns.FirstOrDefault(x => x.Name == name);
+ return column.IsValid;
}
+ public IEnumerator GetEnumerator() => ((IEnumerable)_columns).GetEnumerator();
+
+ IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
+
// REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape
// as an input to another schema shape. I started writing, but realized that there's more than one way to check for
// the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'.
}
- ///
- /// Exception class for schema validation errors.
- ///
- public class SchemaException : Exception
- {
- }
-
///
/// The 'data reader' takes a certain kind of input and turns it into an .
///
@@ -246,7 +256,6 @@ public interface ITransformer
///
/// Schema propagation for transformers.
/// Returns the output schema of the data, if the input schema is like the one provided.
- /// Throws if the input schema is not valid for the transformer.
///
Schema GetOutputSchema(Schema inputSchema);
@@ -275,7 +284,7 @@ public interface ITransformer
///
/// The estimator (in Spark terminology) is an 'untrained transformer'. It needs to 'fit' on the data to manufacture
/// a transformer.
- /// It also provides the 'schema propagation' like transformers do, but over instead of .
+ /// It also provides the 'schema propagation' like transformers do, but over instead of .
///
public interface IEstimator
where TTransformer : ITransformer
@@ -288,7 +297,6 @@ public interface IEstimator
///
/// Schema propagation for estimators.
/// Returns the output schema shape of the estimator, if the input schema shape is like the one provided.
- /// Throws iff the input schema is not valid for the estimator.
///
SchemaShape GetOutputSchema(SchemaShape inputSchema);
}
diff --git a/src/Microsoft.ML.Core/Data/IFileHandle.cs b/src/Microsoft.ML.Core/Data/IFileHandle.cs
index 37b871b7b6..b5b2ae8183 100644
--- a/src/Microsoft.ML.Core/Data/IFileHandle.cs
+++ b/src/Microsoft.ML.Core/Data/IFileHandle.cs
@@ -5,9 +5,9 @@
using System;
using System.Collections.Generic;
using System.IO;
-using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML
{
///
/// A file handle.
diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
index eb0d57845c..72639c6ac2 100644
--- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
+++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
@@ -5,7 +5,7 @@
using System;
using System.ComponentModel.Composition.Hosting;
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML
{
///
/// A channel provider can create new channels and generic information pipes.
@@ -94,7 +94,7 @@ public interface IHost : IHostEnvironment
/// The random number generator issued to this component. Note that random number
/// generators are NOT thread safe.
///
- IRandom Rand { get; }
+ Random Rand { get; }
///
/// Signal to stop exection in this host and all its children.
@@ -233,7 +233,8 @@ public interface IChannel : IPipe
/// that do not belong in more specific areas, for example, or
/// component creation.
///
- public static class HostExtensions
+ [BestFriend]
+ internal static class HostExtensions
{
public static T Apply(this IHost host, string channelName, Func func)
{
diff --git a/src/Microsoft.ML.Core/Data/IProgressChannel.cs b/src/Microsoft.ML.Core/Data/IProgressChannel.cs
index 924e7806f9..26dd7e1831 100644
--- a/src/Microsoft.ML.Core/Data/IProgressChannel.cs
+++ b/src/Microsoft.ML.Core/Data/IProgressChannel.cs
@@ -4,7 +4,7 @@
using System;
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML
{
///
/// This is a factory interface for .
diff --git a/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs b/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs
new file mode 100644
index 0000000000..d9c55227f3
--- /dev/null
+++ b/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs
@@ -0,0 +1,53 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace Microsoft.ML.Data
+{
+ ///
+ /// This interface maps an input to an output . Typically, the output contains
+ /// both the input columns and new columns added by the implementing class, although some implementations may
+ /// return a subset of the input columns.
+ /// This interface is similar to , except it does not have any input role mappings,
+ /// so to rebind, the same input column names must be used.
+ /// Implementations of this interface are typically created over defined input .
+ ///
+ public interface IRowToRowMapper
+ {
+ ///
+ /// Mappers are defined as accepting inputs with this very specific schema.
+ ///
+ Schema InputSchema { get; }
+
+ ///
+ /// Gets an instance of which describes the columns' names and types in the output generated by this mapper.
+ ///
+ Schema OutputSchema { get; }
+
+ ///
+ /// Given a predicate specifying which columns are needed, return a predicate indicating which input columns are
+ /// needed. The domain of the function is defined over the indices of the columns of
+ /// for .
+ ///
+ Func GetDependencies(Func predicate);
+
+ ///
+ /// Get an with the indicated active columns, based on the input .
+ /// The active columns are those for which returns true. Getting values on inactive
+ /// columns of the returned row will throw. Null predicates are disallowed.
+ ///
+ /// The of should be the same object as
+ /// . Implementors of this method should throw if that is not the case. Conversely,
+ /// the returned value must have the same schema as .
+ ///
+ /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the
+ /// getters of the input row and base the output values on the current values of the input .
+ /// The output values are re-computed when requested through the getters. Also, the returned
+ /// will dispose when it is disposed.
+ ///
+ Row GetRow(Row input, Func active);
+ }
+}
diff --git a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs
index 525c7926a0..34b09fcf44 100644
--- a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs
+++ b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs
@@ -2,11 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Data;
-using System;
using System.Collections.Generic;
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
///
/// A mapper that can be bound to a (which is an ISchema, with mappings from column kinds
@@ -21,7 +19,8 @@ namespace Microsoft.ML.Runtime.Data
/// for the output schema of the . In case the interface is implemented,
/// the SimpleRow class can be used in the method.
///
- public interface ISchemaBindableMapper
+ [BestFriend]
+ internal interface ISchemaBindableMapper
{
ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema);
}
@@ -30,13 +29,19 @@ public interface ISchemaBindableMapper
/// This interface is used to map a schema from input columns to output columns. The should keep track
/// of the input columns that are needed for the mapping.
///
- public interface ISchemaBoundMapper : ISchematized
+ [BestFriend]
+ internal interface ISchemaBoundMapper
{
///
/// The that was passed to the in the binding process.
///
RoleMappedSchema InputRoleMappedSchema { get; }
+ ///
+ /// Gets schema of this mapper's output.
+ ///
+ Schema OutputSchema { get; }
+
///
/// A property to get back the that produced this .
///
@@ -51,53 +56,13 @@ public interface ISchemaBoundMapper : ISchematized
///
/// This interface combines with .
///
- public interface ISchemaBoundRowMapper : ISchemaBoundMapper, IRowToRowMapper
- {
- }
-
- ///
- /// This interface maps an input to an output . Typically, the output contains
- /// both the input columns and new columns added by the implementing class, although some implementations may
- /// return a subset of the input columns.
- /// This interface is similar to , except it does not have any input role mappings,
- /// so to rebind, the same input column names must be used.
- /// Implementing of this object are typically created using a definie input .
- ///
- public interface IRowToRowMapper : ISchematized
+ [BestFriend]
+ internal interface ISchemaBoundRowMapper : ISchemaBoundMapper, IRowToRowMapper
{
///
- /// Mappers are defined as accepting inputs with this very specific schema.
- ///
- Schema InputSchema { get; }
-
- ///
- /// Given a predicate specifying which columns are needed, return a predicate indicating which input columns are
- /// needed. The domain of the function is defined over the indices of the columns of
- /// for .
- ///
- Func GetDependencies(Func predicate);
-
- ///
- /// Get an with the indicated active columns, based on the input .
- /// The active columns are those for which returns true. Getting values on inactive
- /// columns of the returned row will throw. Null predicates are disallowed.
- ///
- /// The of should be the same object as
- /// . Implementors of this method should throw if that is not the case. Conversely,
- /// the returned value must have the same schema as .
- ///
- /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the
- /// getters of the input row and base the output values on the current values of the input .
- /// The output values are re-computed when requested through the getters.
- ///
- /// The optional should be invoked by any user of this row mapping, once it no
- /// longer needs the . If no action is needed when the cursor is Disposed, the implementation
- /// should set to null, otherwise it should be set to a delegate to be
- /// invoked by the code calling this object. (For example, a wrapping cursor's
- /// method. It's best for this action to be idempotent - calling it multiple times should be equivalent to
- /// calling it once.
+ /// There are two schemas from <see cref="ISchemaBoundMapper"/> and <see cref="IRowToRowMapper"/>.
+ /// Since the two parent schemas are identical in all derived classes, we merge them into <see cref="OutputSchema"/>.
///
- IRow GetRow(IRow input, Func active, out Action disposer);
+ new Schema OutputSchema { get; }
}
}
diff --git a/src/Microsoft.ML.Core/Data/IValueMapper.cs b/src/Microsoft.ML.Core/Data/IValueMapper.cs
index c6abfbc02d..dcdf6706c7 100644
--- a/src/Microsoft.ML.Core/Data/IValueMapper.cs
+++ b/src/Microsoft.ML.Core/Data/IValueMapper.cs
@@ -2,20 +2,20 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
-
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
///
/// Delegate type to map/convert a value.
///
- public delegate void ValueMapper(in TSrc src, ref TDst dst);
+ [BestFriend]
+ internal delegate void ValueMapper(in TSrc src, ref TDst dst);
///
/// Delegate type to map/convert among three values, for example, one input with two
/// outputs, or two inputs with one output.
///
- public delegate void ValueMapper(in TVal1 val1, ref TVal2 val2, ref TVal3 val3);
+ [BestFriend]
+ internal delegate void ValueMapper(in TVal1 val1, ref TVal2 val2, ref TVal3 val3);
///
/// Interface for mapping a single input value (of an indicated ColumnType) to
@@ -24,7 +24,8 @@ namespace Microsoft.ML.Runtime.Data
/// type arguments for GetMapper, but typically contain additional information like
/// vector lengths.
///
- public interface IValueMapper
+ [BestFriend]
+ internal interface IValueMapper
{
ColumnType InputType { get; }
ColumnType OutputType { get; }
@@ -43,7 +44,8 @@ public interface IValueMapper
/// type arguments for GetMapper, but typically contain additional information like
/// vector lengths.
///
- public interface IValueMapperDist : IValueMapper
+ [BestFriend]
+ internal interface IValueMapperDist : IValueMapper
{
ColumnType DistType { get; }
diff --git a/src/Microsoft.ML.Core/Data/InPredicate.cs b/src/Microsoft.ML.Core/Data/InPredicate.cs
index 74d16e906d..3c35bf600a 100644
--- a/src/Microsoft.ML.Core/Data/InPredicate.cs
+++ b/src/Microsoft.ML.Core/Data/InPredicate.cs
@@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
public delegate bool InPredicate(in T value);
}
diff --git a/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs
index ecec0b1a0d..efb785a835 100644
--- a/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs
+++ b/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs
@@ -2,43 +2,45 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
///
/// Base class for a cursor has an input cursor, but still needs to do work on
- /// MoveNext/MoveMany.
+ /// / .
///
[BestFriend]
- internal abstract class LinkedRootCursorBase : RootCursorBase
- where TInput : class, ICursor
+ internal abstract class LinkedRootCursorBase : RootCursorBase
{
- private readonly ICursor _root;
/// Gets the input cursor.
- protected TInput Input { get; }
+ protected RowCursor Input { get; }
///
/// Returns the root cursor of the input. It should be used to perform MoveNext or MoveMany operations.
- /// Note that GetRootCursor() returns "this", NOT Root. Root is used to advance our input, not for
- /// clients of this cursor. That's why it is protected, not public.
+ /// Note that returns , not .
+ /// is used to advance our input, not for clients of this cursor. That is why it is
+ /// protected, not public.
///
- protected ICursor Root { get { return _root; } }
+ protected RowCursor Root { get; }
- protected LinkedRootCursorBase(IChannelProvider provider, TInput input)
+ protected LinkedRootCursorBase(IChannelProvider provider, RowCursor input)
: base(provider)
{
Ch.AssertValue(input, nameof(input));
Input = input;
- _root = Input.GetRootCursor();
+ Root = Input.GetRootCursor();
}
- public override void Dispose()
+ protected override void Dispose(bool disposing)
{
- if (State != CursorState.Done)
+ if (State == CursorState.Done)
+ return;
+ if (disposing)
{
Input.Dispose();
- base.Dispose();
+ // The base class should set the state to done under these circumstances.
+ base.Dispose(true);
}
}
}
diff --git a/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs
index 0ed4dd19f9..fa35240ad1 100644
--- a/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs
+++ b/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs
@@ -2,9 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Data;
-
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
///
/// Base class for creating a cursor of rows that filters out some input rows.
@@ -14,12 +12,12 @@ internal abstract class LinkedRowFilterCursorBase : LinkedRowRootCursorBase
{
public override long Batch => Input.Batch;
- protected LinkedRowFilterCursorBase(IChannelProvider provider, IRowCursor input, Schema schema, bool[] active)
+ protected LinkedRowFilterCursorBase(IChannelProvider provider, RowCursor input, Schema schema, bool[] active)
: base(provider, input, schema, active)
{
}
- public override ValueGetter GetIdGetter()
+ public override ValueGetter GetIdGetter()
{
return Input.GetIdGetter();
}
diff --git a/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs
index 188522ad3d..1f427ddeb9 100644
--- a/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs
+++ b/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs
@@ -2,25 +2,23 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Data;
-
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
///
- /// A base class for a that has an input cursor, but still needs
- /// to do work on /. Note
+ /// A base class for a that has an input cursor, but still needs
+ /// to do work on /. Note
/// that the default assumes
/// that each input column is exposed as an output column with the same column index.
///
[BestFriend]
- internal abstract class LinkedRowRootCursorBase : LinkedRootCursorBase, IRowCursor
+ internal abstract class LinkedRowRootCursorBase : LinkedRootCursorBase
{
private readonly bool[] _active;
/// Gets row's schema.
- public Schema Schema { get; }
+ public sealed override Schema Schema { get; }
- protected LinkedRowRootCursorBase(IChannelProvider provider, IRowCursor input, Schema schema, bool[] active)
+ protected LinkedRowRootCursorBase(IChannelProvider provider, RowCursor input, Schema schema, bool[] active)
: base(provider, input)
{
Ch.CheckValue(schema, nameof(schema));
@@ -29,13 +27,13 @@ protected LinkedRowRootCursorBase(IChannelProvider provider, IRowCursor input, S
Schema = schema;
}
- public bool IsColumnActive(int col)
+ public sealed override bool IsColumnActive(int col)
{
Ch.Check(0 <= col && col < Schema.Count);
return _active == null || _active[col];
}
- public virtual ValueGetter GetGetter(int col)
+ public override ValueGetter GetGetter(int col)
{
return Input.GetGetter(col);
}
diff --git a/src/Microsoft.ML.Core/Data/MetadataBuilder.cs b/src/Microsoft.ML.Core/Data/MetadataBuilder.cs
index 0b72f53a25..65cafd2627 100644
--- a/src/Microsoft.ML.Core/Data/MetadataBuilder.cs
+++ b/src/Microsoft.ML.Core/Data/MetadataBuilder.cs
@@ -2,12 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Runtime;
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.Collections.Generic;
using System.Linq;
+using Microsoft.ML.Internal.Utilities;
namespace Microsoft.ML.Data
{
@@ -16,11 +14,11 @@ namespace Microsoft.ML.Data
///
public sealed class MetadataBuilder
{
- private readonly List<(string Name, ColumnType Type, Delegate Getter)> _items;
+ private readonly List<(string Name, ColumnType Type, Delegate Getter, Schema.Metadata Metadata)> _items;
public MetadataBuilder()
{
- _items = new List<(string Name, ColumnType Type, Delegate Getter)>();
+ _items = new List<(string Name, ColumnType Type, Delegate Getter, Schema.Metadata Metadata)>();
}
///
@@ -40,7 +38,7 @@ public void Add(Schema.Metadata metadata, Func selector)
foreach (var column in metadata.Schema)
{
if (selector(column.Name))
- _items.Add((column.Name, column.Type, metadata.Getters[column.Index]));
+ _items.Add((column.Name, column.Type, metadata.Getters[column.Index], column.Metadata));
}
}
@@ -51,13 +49,17 @@ public void Add(Schema.Metadata metadata, Func selector)
/// The metadata name.
/// The metadata type.
/// The getter delegate.
- public void Add(string name, ColumnType type, ValueGetter getter)
+ /// Metadata of the input column. Note that metadata on a metadata column is somewhat rare
+ /// except for certain types (for example, slot names for a vector, key values for something of key type).
+ public void Add(string name, ColumnType type, ValueGetter getter, Schema.Metadata metadata = null)
{
Contracts.CheckNonEmpty(name, nameof(name));
Contracts.CheckValue(type, nameof(type));
Contracts.CheckValue(getter, nameof(getter));
- Contracts.CheckParam(type.RawType == typeof(TValue), nameof(getter));
- _items.Add((name, type, getter));
+ Contracts.CheckParam(type.RawType == typeof(TValue), nameof(type));
+ Contracts.CheckValueOrNull(metadata);
+
+ _items.Add((name, type, getter, metadata));
}
///
@@ -67,11 +69,31 @@ public void Add(string name, ColumnType type, ValueGetter getter
/// The metadata type.
/// The getter delegate that provides the value. Note that the type of the getter is still checked
/// inside this method.
- public void Add(string name, ColumnType type, Delegate getter)
+ /// Metadata of the input column. Note that metadata on a metadata column is somewhat rare
+ /// except for certain types (for example, slot names for a vector, key values for something of key type).
+ public void Add(string name, ColumnType type, Delegate getter, Schema.Metadata metadata = null)
{
Contracts.CheckNonEmpty(name, nameof(name));
Contracts.CheckValue(type, nameof(type));
- Utils.MarshalActionInvoke(AddDelegate, type.RawType, name, type, getter);
+ Contracts.CheckValueOrNull(metadata);
+ Utils.MarshalActionInvoke(AddDelegate, type.RawType, name, type, getter, metadata);
+ }
+
+ ///
+ /// Add one metadata column for a primitive value type.
+ ///
+ /// The metadata name.
+ /// The metadata type.
+ /// The value of the metadata.
+ /// Metadata of the input column. Note that metadata on a metadata column is somewhat rare
+ /// except for certain types (for example, slot names for a vector, key values for something of key type).
+ public void AddPrimitiveValue(string name, PrimitiveType type, TValue value, Schema.Metadata metadata = null)
+ {
+ Contracts.CheckNonEmpty(name, nameof(name));
+ Contracts.CheckValue(type, nameof(type));
+ Contracts.CheckParam(type.RawType == typeof(TValue), nameof(type));
+ Contracts.CheckValueOrNull(metadata);
+ Add(name, type, (ref TValue dst) => dst = value, metadata);
}
///
@@ -100,11 +122,11 @@ public Schema.Metadata GetMetadata()
{
var builder = new SchemaBuilder();
foreach (var item in _items)
- builder.AddColumn(item.Name, item.Type, null);
+ builder.AddColumn(item.Name, item.Type, item.Metadata);
return new Schema.Metadata(builder.GetSchema(), _items.Select(x => x.Getter).ToArray());
}
- private void AddDelegate(string name, ColumnType type, Delegate getter)
+ private void AddDelegate(string name, ColumnType type, Delegate getter, Schema.Metadata metadata)
{
Contracts.AssertNonEmpty(name);
Contracts.AssertValue(type);
@@ -112,7 +134,7 @@ private void AddDelegate(string name, ColumnType type, Delegate getter)
var typedGetter = getter as ValueGetter;
Contracts.CheckParam(typedGetter != null, nameof(getter));
- _items.Add((name, type, typedGetter));
+ _items.Add((name, type, typedGetter, metadata));
}
}
}
diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs
index 9178c1c129..efa22d134a 100644
--- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs
+++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs
@@ -9,10 +9,9 @@
using System.Linq;
using System.Threading;
using Microsoft.ML.Core.Data;
-using Microsoft.ML.Data;
-using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
///
/// Utilities for implementing and using the metadata API of .
@@ -117,23 +116,20 @@ public static class ScoreValueKind
///
/// Returns a standard exception for responding to an invalid call to GetMetadata.
///
- public static Exception ExceptGetMetadata()
- {
- return Contracts.Except("Invalid call to GetMetadata");
- }
+ [BestFriend]
+ internal static Exception ExceptGetMetadata() => Contracts.Except("Invalid call to GetMetadata");
///
/// Returns a standard exception for responding to an invalid call to GetMetadata.
///
- public static Exception ExceptGetMetadata(this IExceptionContext ctx)
- {
- return ctx.Except("Invalid call to GetMetadata");
- }
+ [BestFriend]
+ internal static Exception ExceptGetMetadata(this IExceptionContext ctx) => ctx.Except("Invalid call to GetMetadata");
///
/// Helper to marshal a call to GetMetadata{TValue} to a specific type.
///
- public static void Marshal(this MetadataGetter getter, int col, ref TNeed dst)
+ [BestFriend]
+ internal static void Marshal(this MetadataGetter getter, int col, ref TNeed dst)
{
Contracts.CheckValue(getter, nameof(getter));
@@ -147,7 +143,8 @@ public static void Marshal(this MetadataGetter getter, int
/// Returns a vector type with item type text and the given size. The size must be positive.
/// This is a standard type for metadata consisting of multiple text values, eg SlotNames.
///
- public static VectorType GetNamesType(int size)
+ [BestFriend]
+ internal static VectorType GetNamesType(int size)
{
Contracts.CheckParam(size > 0, nameof(size), "must be known size");
return new VectorType(TextType.Instance, size);
@@ -159,18 +156,20 @@ public static VectorType GetNamesType(int size)
/// This is a standard type for metadata consisting of multiple int values that represent
/// categorical slot ranges with in a column.
///
- public static VectorType GetCategoricalType(int rangeCount)
+ [BestFriend]
+ internal static VectorType GetCategoricalType(int rangeCount)
{
Contracts.CheckParam(rangeCount > 0, nameof(rangeCount), "must be known size");
return new VectorType(NumberType.I4, rangeCount, 2);
}
- private static volatile ColumnType _scoreColumnSetIdType;
+ private static volatile KeyType _scoreColumnSetIdType;
///
/// The type of the ScoreColumnSetId metadata.
///
- public static ColumnType ScoreColumnSetIdType
+ [BestFriend]
+ internal static KeyType ScoreColumnSetIdType
{
get
{
@@ -186,7 +185,8 @@ public static ColumnType ScoreColumnSetIdType
///
/// Returns a key-value pair useful when implementing GetMetadataTypes(col).
///
- public static KeyValuePair GetSlotNamesPair(int size)
+ [BestFriend]
+ internal static KeyValuePair GetSlotNamesPair(int size)
{
return GetNamesType(size).GetPair(Kinds.SlotNames);
}
@@ -195,7 +195,8 @@ public static KeyValuePair GetSlotNamesPair(int size)
/// Returns a key-value pair useful when implementing GetMetadataTypes(col). This assumes
/// that the values of the key type are Text.
///
- public static KeyValuePair GetKeyNamesPair(int size)
+ [BestFriend]
+ internal static KeyValuePair GetKeyNamesPair(int size)
{
return GetNamesType(size).GetPair(Kinds.KeyValues);
}
@@ -204,7 +205,8 @@ public static KeyValuePair GetKeyNamesPair(int size)
/// Given a type and metadata kind string, returns a key-value pair. This is useful when
/// implementing GetMetadataTypes(col).
///
- public static KeyValuePair GetPair(this ColumnType type, string kind)
+ [BestFriend]
+ internal static KeyValuePair GetPair(this ColumnType type, string kind)
{
Contracts.CheckValue(type, nameof(type));
return new KeyValuePair(kind, type);
@@ -215,7 +217,8 @@ public static KeyValuePair GetPair(this ColumnType type, str
///
/// Prepends a params array to an enumerable. Useful when implementing GetMetadataTypes.
///
- public static IEnumerable Prepend(this IEnumerable tail, params T[] head)
+ [BestFriend]
+ internal static IEnumerable Prepend(this IEnumerable tail, params T[] head)
{
return head.Concat(tail);
}
@@ -234,13 +237,13 @@ public static uint GetMaxMetadataKind(this Schema schema, out int colMax, string
colMax = -1;
for (int col = 0; col < schema.Count; col++)
{
- var columnType = schema.GetMetadataTypeOrNull(metadataKind, col);
+ var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type;
if (columnType == null || !columnType.IsKey || columnType.RawKind != DataKind.U4)
continue;
if (filterFunc != null && !filterFunc(schema, col))
continue;
uint value = 0;
- schema.GetMetadata(metadataKind, col, ref value);
+ schema[col].Metadata.GetValue(metadataKind, ref value);
if (max < value)
{
max = value;
@@ -254,15 +257,16 @@ public static uint GetMaxMetadataKind(this Schema schema, out int colMax, string
/// Returns the set of column ids which match the value of specified metadata kind.
/// The metadata type should be a KeyType with raw type U4.
///
- public static IEnumerable GetColumnSet(this Schema schema, string metadataKind, uint value)
+ [BestFriend]
+ internal static IEnumerable GetColumnSet(this Schema schema, string metadataKind, uint value)
{
for (int col = 0; col < schema.Count; col++)
{
- var columnType = schema.GetMetadataTypeOrNull(metadataKind, col);
+ var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type;
if (columnType != null && columnType.IsKey && columnType.RawKind == DataKind.U4)
{
uint val = 0;
- schema.GetMetadata(metadataKind, col, ref val);
+ schema[col].Metadata.GetValue(metadataKind, ref val);
if (val == value)
yield return col;
}
@@ -273,15 +277,16 @@ public static IEnumerable GetColumnSet(this Schema schema, string metadataK
/// Returns the set of column ids which match the value of specified metadata kind.
/// The metadata type should be of type text.
///
- public static IEnumerable GetColumnSet(this Schema schema, string metadataKind, string value)
+ [BestFriend]
+ internal static IEnumerable GetColumnSet(this Schema schema, string metadataKind, string value)
{
for (int col = 0; col < schema.Count; col++)
{
- var columnType = schema.GetMetadataTypeOrNull(metadataKind, col);
+ var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type;
if (columnType != null && columnType.IsText)
{
ReadOnlyMemory val = default;
- schema.GetMetadata(metadataKind, col, ref val);
+ schema[col].Metadata.GetValue(metadataKind, ref val);
if (ReadOnlyMemoryUtils.EqualsStr(value, val))
yield return col;
}
@@ -290,48 +295,60 @@ public static IEnumerable GetColumnSet(this Schema schema, string metadataK
///
/// Returns true if the specified column:
- /// * is a vector of length N (including 0)
+ /// * is a vector of length N
/// * has a SlotNames metadata
/// * metadata type is VBuffer<ReadOnlyMemory<char>> of length N
///
- public static bool HasSlotNames(this Schema schema, int col, int vectorSize)
+ public static bool HasSlotNames(this Schema.Column column)
+ => column.Type.IsKnownSizeVector && column.HasSlotNames(column.Type.VectorSize);
+
+ ///
+ /// Returns true if the specified column:
+ /// * has a SlotNames metadata
+ /// * metadata type is VBuffer<ReadOnlyMemory<char>> of length .
+ ///
+ [BestFriend]
+ internal static bool HasSlotNames(this Schema.Column column, int vectorSize)
{
if (vectorSize == 0)
return false;
- var type = schema.GetMetadataTypeOrNull(Kinds.SlotNames, col);
+ var metaColumn = column.Metadata.Schema.GetColumnOrNull(Kinds.SlotNames);
return
- type != null
- && type.IsVector
- && type.VectorSize == vectorSize
- && type.ItemType.IsText;
+ metaColumn != null
+ && metaColumn.Value.Type.IsVector
+ && metaColumn.Value.Type.VectorSize == vectorSize
+ && metaColumn.Value.Type.ItemType.IsText;
}
- public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer> slotNames)
+ public static void GetSlotNames(this Schema.Column column, ref VBuffer> slotNames)
+ => column.Metadata.GetValue(Kinds.SlotNames, ref slotNames);
+
+ [BestFriend]
+ internal static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer> slotNames)
{
Contracts.CheckValueOrNull(schema);
Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize));
- IReadOnlyList list;
- if ((list = schema?.GetColumns(role)) == null || list.Count != 1 || !schema.Schema.HasSlotNames(list[0].Index, vectorSize))
- {
+ IReadOnlyList list = schema?.GetColumns(role);
+ if (list?.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize))
VBufferUtils.Resize(ref slotNames, vectorSize, 0);
- }
else
- schema.Schema.GetMetadata(Kinds.SlotNames, list[0].Index, ref slotNames);
+ schema.Schema[list[0].Index].Metadata.GetValue(Kinds.SlotNames, ref slotNames);
}
- public static bool HasKeyValues(this Schema schema, int col, int keyCount)
+ [BestFriend]
+ internal static bool HasKeyValues(this Schema.Column column, int keyCount)
{
if (keyCount == 0)
return false;
- var type = schema.GetMetadataTypeOrNull(Kinds.KeyValues, col);
+ var metaColumn = column.Metadata.Schema.GetColumnOrNull(Kinds.KeyValues);
return
- type != null
- && type.IsVector
- && type.VectorSize == keyCount
- && type.ItemType.IsText;
+ metaColumn != null
+ && metaColumn.Value.Type.IsVector
+ && metaColumn.Value.Type.VectorSize == keyCount
+ && metaColumn.Value.Type.ItemType.IsText;
}
[BestFriend]
@@ -343,19 +360,17 @@ internal static bool HasKeyValues(this SchemaShape.Column col)
}
///
- /// Returns whether a column has the metadata set to true.
- /// That metadata should be set when the data has undergone transforms that would render it
- /// "normalized."
+ /// Returns true iff <paramref name="column"/> has IsNormalized metadata set to true.
///
- /// The schema to query
- /// Which column in the schema to query
- /// True if and only if the column has the metadata
- /// set to the scalar value true
- public static bool IsNormalized(this Schema schema, int col)
+ public static bool IsNormalized(this Schema.Column column)
{
- Contracts.CheckValue(schema, nameof(schema));
- var value = default(bool);
- return schema.TryGetMetadata(BoolType.Instance, Kinds.IsNormalized, col, ref value) && value;
+ var metaColumn = column.Metadata.Schema.GetColumnOrNull((Kinds.IsNormalized));
+ if (metaColumn == null || !metaColumn.Value.Type.IsBool)
+ return false;
+
+ bool value = default;
+ column.Metadata.GetValue(Kinds.IsNormalized, ref value);
+ return value;
}
///
@@ -367,7 +382,7 @@ public static bool IsNormalized(this Schema schema, int col)
/// of a scalar type, which we assume, if set, should be true.
public static bool IsNormalized(this SchemaShape.Column col)
{
- Contracts.CheckValue(col, nameof(col));
+ Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly");
return col.Metadata.TryFindColumn(Kinds.IsNormalized, out var metaCol)
&& metaCol.Kind == SchemaShape.Column.VectorKind.Scalar && !metaCol.IsKey
&& metaCol.ItemType == BoolType.Instance;
@@ -382,7 +397,7 @@ public static bool IsNormalized(this SchemaShape.Column col)
/// metadata of definite sized vectors of text.
public static bool HasSlotNames(this SchemaShape.Column col)
{
- Contracts.CheckValue(col, nameof(col));
+ Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly");
return col.Kind == SchemaShape.Column.VectorKind.Vector
&& col.Metadata.TryFindColumn(Kinds.SlotNames, out var metaCol)
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector && !metaCol.IsKey
@@ -399,23 +414,19 @@ public static bool HasSlotNames(this SchemaShape.Column col)
/// The column
/// The value to return, if successful
/// True if the metadata of the right type exists, false otherwise
- public static bool TryGetMetadata(this Schema schema, PrimitiveType type, string kind, int col, ref T value)
+ [BestFriend]
+ internal static bool TryGetMetadata(this Schema schema, PrimitiveType type, string kind, int col, ref T value)
{
Contracts.CheckValue(schema, nameof(schema));
Contracts.CheckValue(type, nameof(type));
- var metadataType = schema.GetMetadataTypeOrNull(kind, col);
+ var metadataType = schema[col].Metadata.Schema.GetColumnOrNull(kind)?.Type;
if (!type.Equals(metadataType))
return false;
- schema.GetMetadata(kind, col, ref value);
+ schema[col].Metadata.GetValue(kind, ref value);
return true;
}
- ///
- /// Return whether the given column index is hidden in the given schema.
- ///
- public static bool IsHidden(this Schema schema, int col) => schema[col].IsHidden;
-
///
/// The categoricalFeatures is a vector of the indices of categorical features slots.
/// This vector should always have an even number of elements, and the elements should be parsed in groups of two consecutive numbers.
@@ -424,21 +435,22 @@ public static bool TryGetMetadata(this Schema schema, PrimitiveType type, str
/// The way to interpret that is: feature with indices 0, 1, and 2 are one categorical
/// Features with indices 3 and 4 are another categorical. Features 5 and 6 don't appear there, so they are not categoricals.
///
- public static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex, out int[] categoricalFeatures)
+ [BestFriend]
+ internal static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex, out int[] categoricalFeatures)
{
Contracts.CheckValue(schema, nameof(schema));
Contracts.Check(colIndex >= 0, nameof(colIndex));
bool isValid = false;
categoricalFeatures = null;
- if (!(schema.GetColumnType(colIndex) is VectorType vecType && vecType.Size > 0))
+ if (!(schema[colIndex].Type is VectorType vecType && vecType.Size > 0))
return isValid;
- var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex);
+ var type = schema[colIndex].Metadata.Schema.GetColumnOrNull(Kinds.CategoricalSlotRanges)?.Type;
if (type?.RawType == typeof(VBuffer))
{
VBuffer catIndices = default(VBuffer);
- schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex, ref catIndices);
+ schema[colIndex].Metadata.GetValue(Kinds.CategoricalSlotRanges, ref catIndices);
VBufferUtils.Densify(ref catIndices);
int columnSlotsCount = vecType.Size;
if (catIndices.Length > 0 && catIndices.Length % 2 == 0 && catIndices.Length <= columnSlotsCount * 2)
@@ -471,7 +483,8 @@ public static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex,
/// Produces sequence of columns that are generated by trainer estimators.
///
/// whether we should also append 'IsNormalized' (typically for probability column)
- public static IEnumerable GetTrainerOutputMetadata(bool isNormalized = false)
+ [BestFriend]
+ internal static IEnumerable GetTrainerOutputMetadata(bool isNormalized = false)
{
var cols = new List();
cols.Add(new SchemaShape.Column(Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true));
@@ -483,16 +496,48 @@ public static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex,
}
///
- /// Produces sequence of columns that are generated by multiclass trainer estimators.
+ /// Produces metadata for the score column generated by trainer estimators for multiclass classification.
+ /// If input LabelColumn is not available it produces slotnames metadata by default.
///
/// Label column.
- public static IEnumerable MetadataForMulticlassScoreColumn(SchemaShape.Column labelColumn)
+ [BestFriend]
+ internal static IEnumerable MetadataForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
{
var cols = new List();
- if (labelColumn.IsKey && HasKeyValues(labelColumn))
+ if (labelColumn != null && labelColumn.Value.IsKey && HasKeyValues(labelColumn.Value))
cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false));
cols.AddRange(GetTrainerOutputMetadata());
return cols;
}
+
+ private sealed class MetadataRow : Row
+ {
+ private readonly Schema.Metadata _metadata;
+
+ public MetadataRow(Schema.Metadata metadata)
+ {
+ Contracts.AssertValue(metadata);
+ _metadata = metadata;
+ }
+
+ public override Schema Schema => _metadata.Schema;
+ public override long Position => 0;
+ public override long Batch => 0;
+ public override ValueGetter GetGetter(int col) => _metadata.GetGetter(col);
+ public override ValueGetter GetIdGetter() => (ref RowId dst) => dst = default;
+ public override bool IsColumnActive(int col) => true;
+ }
+
+ ///
+ /// Presents a as a .
+ ///
+ /// The metadata to wrap.
+ /// A row that wraps an input metadata.
+ [BestFriend]
+ internal static Row MetadataAsRow(Schema.Metadata metadata)
+ {
+ Contracts.CheckValue(metadata, nameof(metadata));
+ return new MetadataRow(metadata);
+ }
}
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Core/Data/ProgressReporter.cs b/src/Microsoft.ML.Core/Data/ProgressReporter.cs
index 191364e2a3..e49c559571 100644
--- a/src/Microsoft.ML.Core/Data/ProgressReporter.cs
+++ b/src/Microsoft.ML.Core/Data/ProgressReporter.cs
@@ -7,9 +7,9 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading;
-using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
///
/// The progress reporting classes used by descendants.
diff --git a/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs b/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs
index 20ebb85b04..e776701030 100644
--- a/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs
+++ b/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs
@@ -2,13 +2,13 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
+using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
[BestFriend]
internal static class ReadOnlyMemoryUtils
diff --git a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
index a8d22319b3..acd5c957e2 100644
--- a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
+++ b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
@@ -3,76 +3,12 @@
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
-using Microsoft.ML.Data;
-using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
///
- /// This contains information about a column in an . It is essentially a convenience cache
- /// containing the name, column index, and column type for the column. The intended usage is that users of
- /// will have a convenient method of getting the index and type without having to separately query it through the ,
- /// since practically the first thing a consumer of a will want to do once they get a mappping is get
- /// the type and index of the corresponding column.
- ///
- public sealed class ColumnInfo
- {
- public readonly string Name;
- public readonly int Index;
- public readonly ColumnType Type;
-
- private ColumnInfo(string name, int index, ColumnType type)
- {
- Name = name;
- Index = index;
- Type = type;
- }
-
- ///
- /// Create a ColumnInfo for the column with the given name in the given schema. Throws if the name
- /// doesn't map to a column.
- ///
- public static ColumnInfo CreateFromName(ISchema schema, string name, string descName)
- {
- if (!TryCreateFromName(schema, name, out var colInfo))
- throw Contracts.ExceptParam(nameof(name), $"{descName} column '{name}' not found");
-
- return colInfo;
- }
-
- ///
- /// Tries to create a ColumnInfo for the column with the given name in the given schema. Returns
- /// false if the name doesn't map to a column.
- ///
- public static bool TryCreateFromName(ISchema schema, string name, out ColumnInfo colInfo)
- {
- Contracts.CheckValue(schema, nameof(schema));
- Contracts.CheckNonEmpty(name, nameof(name));
-
- colInfo = null;
- if (!schema.TryGetColumnIndex(name, out int index))
- return false;
-
- colInfo = new ColumnInfo(name, index, schema.GetColumnType(index));
- return true;
- }
-
- ///
- /// Creates a ColumnInfo for the column with the given column index. Note that the name
- /// of the column might actually map to a different column, so this should be used with care
- /// and rarely.
- ///
- public static ColumnInfo CreateFromIndex(ISchema schema, int index)
- {
- Contracts.CheckValue(schema, nameof(schema));
- Contracts.CheckParam(0 <= index && index < schema.ColumnCount, nameof(index));
-
- return new ColumnInfo(schema.GetColumnName(index), index, schema.GetColumnType(index));
- }
- }
-
- ///
- /// Encapsulates an plus column role mapping information. The purpose of role mappings is to
+ /// Encapsulates an plus column role mapping information. The purpose of role mappings is to
/// provide information on what the intended usage is for. That is: while a given data view may have a column named
/// "Features", by itself that is insufficient: the trainer must be fed a role mapping that says that the role
/// mapping for features is filled by that "Features" column. This allows things like columns not named "Features"
@@ -88,7 +24,7 @@ public static ColumnInfo CreateFromIndex(ISchema schema, int index)
/// in this schema.
///
///
- /// Note that instances of this class are, like instances of , immutable.
+ /// Note that instances of this class are, like instances of , immutable.
///
/// It is often the case that one wishes to bundle the actual data with the role mappings, not just the schema. For
/// that case, please use the class.
@@ -100,7 +36,8 @@ public static ColumnInfo CreateFromIndex(ISchema schema, int index)
///
///
///
- public sealed class RoleMappedSchema
+ [BestFriend]
+ internal sealed class RoleMappedSchema
{
private const string FeatureString = "Feature";
private const string LabelString = "Label";
@@ -184,39 +121,39 @@ public static KeyValuePair CreatePair(ColumnRole role, strin
=> new KeyValuePair(role, name);
///
- /// The source .
+ /// The source .
///
public Schema Schema { get; }
///
/// The column, when there is exactly one (null otherwise).
///
- public ColumnInfo Feature { get; }
+ public Schema.Column? Feature { get; }
///
/// The column, when there is exactly one (null otherwise).
///
- public ColumnInfo Label { get; }
+ public Schema.Column? Label { get; }
///
/// The column, when there is exactly one (null otherwise).
///
- public ColumnInfo Group { get; }
+ public Schema.Column? Group { get; }
///
/// The column, when there is exactly one (null otherwise).
///
- public ColumnInfo Weight { get; }
+ public Schema.Column? Weight { get; }
///
/// The column, when there is exactly one (null otherwise).
///
- public ColumnInfo Name { get; }
+ public Schema.Column? Name { get; }
// Maps from role to the associated column infos.
- private readonly Dictionary> _map;
+ private readonly Dictionary> _map;
- private RoleMappedSchema(Schema schema, Dictionary> map)
+ private RoleMappedSchema(Schema schema, Dictionary> map)
{
Contracts.AssertValue(schema);
Contracts.AssertValue(map);
@@ -229,7 +166,7 @@ private RoleMappedSchema(Schema schema, Dictionary> map)
+ private RoleMappedSchema(Schema schema, Dictionary> map)
: this(schema, Copy(map))
{
}
- private static void Add(Dictionary> map, ColumnRole role, ColumnInfo info)
+ private static void Add(Dictionary> map, ColumnRole role, Schema.Column column)
{
Contracts.AssertValue(map);
Contracts.AssertNonEmpty(role.Value);
- Contracts.AssertValue(info);
if (!map.TryGetValue(role.Value, out var list))
{
- list = new List();
+ list = new List();
map.Add(role.Value, list);
}
- list.Add(info);
+ list.Add(column);
}
- private static Dictionary> MapFromNames(ISchema schema, IEnumerable> roles, bool opt = false)
+ private static Dictionary> MapFromNames(Schema schema, IEnumerable> roles, bool opt = false)
{
Contracts.AssertValue(schema);
Contracts.AssertValue(roles);
- var map = new Dictionary>();
+ var map = new Dictionary>();
foreach (var kvp in roles)
{
Contracts.AssertNonEmpty(kvp.Key.Value);
if (string.IsNullOrEmpty(kvp.Value))
continue;
- ColumnInfo info;
- if (!opt)
- info = ColumnInfo.CreateFromName(schema, kvp.Value, kvp.Key.Value);
- else if (!ColumnInfo.TryCreateFromName(schema, kvp.Value, out info))
- continue;
- Add(map, kvp.Key.Value, info);
+ var info = schema.GetColumnOrNull(kvp.Value);
+ if (info.HasValue)
+ Add(map, kvp.Key.Value, info.Value);
+ else if (!opt)
+ throw Contracts.ExceptParam(nameof(schema), $"{kvp.Value} column '{kvp.Key.Value}' not found");
}
return map;
}
@@ -317,18 +252,18 @@ public bool HasMultiple(ColumnRole role)
/// If there are columns of the given role, this returns the infos as a readonly list. Otherwise,
/// it returns null.
///
- public IReadOnlyList GetColumns(ColumnRole role)
+ public IReadOnlyList GetColumns(ColumnRole role)
=> _map.TryGetValue(role.Value, out var list) ? list : null;
///
/// An enumerable over all role-column associations within this object.
///
- public IEnumerable> GetColumnRoles()
+ public IEnumerable> GetColumnRoles()
{
foreach (var roleAndList in _map)
{
foreach (var info in roleAndList.Value)
- yield return new KeyValuePair(roleAndList.Key, info);
+ yield return new KeyValuePair(roleAndList.Key, info);
}
}
@@ -358,13 +293,13 @@ public IEnumerable> GetColumnRoleNames(ColumnRo
}
///
- /// Returns the corresponding to if there is
+ /// Returns the corresponding to if there is
/// exactly one such mapping, and otherwise throws an exception.
///
/// The role to look up
- /// The info corresponding to that role, assuming there was only one column
+ /// The column corresponding to that role, assuming there was only one column
/// mapped to that
- public ColumnInfo GetUniqueColumn(ColumnRole role)
+ public Schema.Column GetUniqueColumn(ColumnRole role)
{
var infos = GetColumns(role);
if (Utils.Size(infos) != 1)
@@ -372,9 +307,9 @@ public ColumnInfo GetUniqueColumn(ColumnRole role)
return infos[0];
}
- private static Dictionary> Copy(Dictionary> map)
+ private static Dictionary> Copy(Dictionary> map)
{
- var copy = new Dictionary>(map.Count);
+ var copy = new Dictionary>(map.Count);
foreach (var kvp in map)
{
Contracts.Assert(Utils.Size(kvp.Value) > 0);
@@ -467,9 +402,10 @@ public RoleMappedSchema(Schema schema, string label, string feature,
///
/// Encapsulates an plus a corresponding .
/// Note that the schema of of is
- /// guaranteed to equal the the of .
+ /// guaranteed to equal the of .
///
- public sealed class RoleMappedData
+ [BestFriend]
+ internal sealed class RoleMappedData
{
///
/// The data.
@@ -478,7 +414,7 @@ public sealed class RoleMappedData
///
/// The role mapped schema. Note that 's is
- /// guaranteed to be the same as 's .
+ /// guaranteed to be the same as 's .
///
public RoleMappedSchema Schema { get; }
diff --git a/src/Microsoft.ML.Core/Data/RootCursorBase.cs b/src/Microsoft.ML.Core/Data/RootCursorBase.cs
index 5b64a40d6a..dae6c32888 100644
--- a/src/Microsoft.ML.Core/Data/RootCursorBase.cs
+++ b/src/Microsoft.ML.Core/Data/RootCursorBase.cs
@@ -2,9 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
-
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
// REVIEW: Since each cursor will create a channel, it would be great that the RootCursorBase takes
// ownership of the channel so the derived classes don't have to.
@@ -15,23 +13,21 @@ namespace Microsoft.ML.Runtime.Data
/// This cursor base class returns "this" from . That is, all
/// / calls will be seen by this cursor. For a cursor
/// that has an input cursor and does NOT need notification on /,
- /// use .
+ /// use .
///
[BestFriend]
- internal abstract class RootCursorBase : ICursor
+ internal abstract class RootCursorBase : RowCursor
{
protected readonly IChannel Ch;
+ private CursorState _state;
+ private long _position;
///
/// Zero-based position of the cursor.
///
- public long Position { get; private set; }
-
- public abstract long Batch { get; }
-
- public abstract ValueGetter GetIdGetter();
+ public sealed override long Position => _position;
- public CursorState State { get; private set; }
+ public sealed override CursorState State => _state;
///
/// Convenience property for checking whether the current state of the cursor is .
@@ -39,7 +35,7 @@ internal abstract class RootCursorBase : ICursor
protected bool IsGood => State == CursorState.Good;
///
- /// Creates an instance of the RootCursorBase class
+ /// Creates an instance of the class
///
/// Channel provider
protected RootCursorBase(IChannelProvider provider)
@@ -47,21 +43,21 @@ protected RootCursorBase(IChannelProvider provider)
Contracts.CheckValue(provider, nameof(provider));
Ch = provider.Start("Cursor");
- Position = -1;
- State = CursorState.NotStarted;
+ _position = -1;
+ _state = CursorState.NotStarted;
}
- public virtual void Dispose()
+ protected override void Dispose(bool disposing)
{
- if (State != CursorState.Done)
- {
+ if (State == CursorState.Done)
+ return;
+ if (disposing)
Ch.Dispose();
- Position = -1;
- State = CursorState.Done;
- }
+ _position = -1;
+ _state = CursorState.Done;
}
- public bool MoveNext()
+ public sealed override bool MoveNext()
{
if (State == CursorState.Done)
return false;
@@ -71,8 +67,8 @@ public bool MoveNext()
{
Ch.Assert(State == CursorState.NotStarted || State == CursorState.Good);
- Position++;
- State = CursorState.Good;
+ _position++;
+ _state = CursorState.Good;
return true;
}
@@ -80,7 +76,7 @@ public bool MoveNext()
return false;
}
- public bool MoveMany(long count)
+ public sealed override bool MoveMany(long count)
{
// Note: If we decide to allow count == 0, then we need to special case
// that MoveNext() has never been called. It's not entirely clear what the return
@@ -95,8 +91,8 @@ public bool MoveMany(long count)
{
Ch.Assert(State == CursorState.NotStarted || State == CursorState.Good);
- Position += count;
- State = CursorState.Good;
+ _position += count;
+ _state = CursorState.Good;
return true;
}
@@ -137,6 +133,6 @@ protected virtual bool MoveManyCore(long count)
/// those on this cursor. Generally, if the root cursor is not the same as this cursor, using
/// the root cursor will be faster.
///
- public ICursor GetRootCursor() => this;
+ public override RowCursor GetRootCursor() => this;
}
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Core/Data/UInt128.cs b/src/Microsoft.ML.Core/Data/RowId.cs
similarity index 58%
rename from src/Microsoft.ML.Core/Data/UInt128.cs
rename to src/Microsoft.ML.Core/Data/RowId.cs
index 017068a65d..315b1939b1 100644
--- a/src/Microsoft.ML.Core/Data/UInt128.cs
+++ b/src/Microsoft.ML.Core/Data/RowId.cs
@@ -4,114 +4,123 @@
using System;
using System.Runtime.CompilerServices;
-using Microsoft.ML.Runtime.Internal.Utilities;
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
///
- /// A sixteen-byte unsigned integer.
+ /// A structure serving as a sixteen-byte unsigned integer. It is used as the row id of .
+ /// For datasets with millions of records, those IDs need to be unique, therefore the need for such a large structure to hold the values.
+ /// Those ids are derived from the ids of the previous components of the pipeline; dividing the structure in two — high-order and low-order bits —
+ /// reduces the chances of such collisions even further.
///
- public readonly struct UInt128 : IComparable, IEquatable
+ ///
+ public readonly struct RowId : IComparable, IEquatable
{
- // The low order bits. Corresponds to H1 in the Murmur algorithms.
- public readonly ulong Lo;
- // The high order bits. Corresponds to H2 in the Murmur algorithms.
- public readonly ulong Hi;
+ ///The low order bits. Corresponds to H1 in the Murmur algorithms.
+ public readonly ulong Low;
- public UInt128(ulong lo, ulong hi)
+ /// The high order bits. Corresponds to H2 in the Murmur algorithms.
+ public readonly ulong High;
+
+ ///
+ /// Initializes a new instance of
+ ///
+ /// The low order ulong.
+ /// The high order ulong.
+ public RowId(ulong low, ulong high)
{
- Lo = lo;
- Hi = hi;
+ Low = low;
+ High = high;
}
public override string ToString()
{
// Since H1 are the low order bits, they are printed second.
- return string.Format("{0:x16}{1:x16}", Hi, Lo);
+ return string.Format("{0:x16}{1:x16}", High, Low);
}
- public int CompareTo(UInt128 other)
+ public int CompareTo(RowId other)
{
- int result = Hi.CompareTo(other.Hi);
- return result == 0 ? Lo.CompareTo(other.Lo) : result;
+ int result = High.CompareTo(other.High);
+ return result == 0 ? Low.CompareTo(other.Low) : result;
}
- public bool Equals(UInt128 other)
+ public bool Equals(RowId other)
{
- return Lo == other.Lo && Hi == other.Hi;
+ return Low == other.Low && High == other.High;
}
public override bool Equals(object obj)
{
- if (obj != null && obj is UInt128)
+ if (obj != null && obj is RowId)
{
- var item = (UInt128)obj;
+ var item = (RowId)obj;
return Equals(item);
}
return false;
}
- public static UInt128 operator +(UInt128 first, ulong second)
+ public static RowId operator +(RowId first, ulong second)
{
- ulong resHi = first.Hi;
- ulong resLo = first.Lo + second;
+ ulong resHi = first.High;
+ ulong resLo = first.Low + second;
if (resLo < second)
resHi++;
- return new UInt128(resLo, resHi);
+ return new RowId(resLo, resHi);
}
- public static UInt128 operator -(UInt128 first, ulong second)
+ public static RowId operator -(RowId first, ulong second)
{
- ulong resHi = first.Hi;
- ulong resLo = first.Lo - second;
- if (resLo > first.Lo)
+ ulong resHi = first.High;
+ ulong resLo = first.Low - second;
+ if (resLo > first.Low)
resHi--;
- return new UInt128(resLo, resHi);
+ return new RowId(resLo, resHi);
}
- public static bool operator ==(UInt128 first, ulong second)
+ public static bool operator ==(RowId first, ulong second)
{
- return first.Hi == 0 && first.Lo == second;
+ return first.High == 0 && first.Low == second;
}
- public static bool operator !=(UInt128 first, ulong second)
+ public static bool operator !=(RowId first, ulong second)
{
return !(first == second);
}
- public static bool operator <(UInt128 first, ulong second)
+ public static bool operator <(RowId first, ulong second)
{
- return first.Hi == 0 && first.Lo < second;
+ return first.High == 0 && first.Low < second;
}
- public static bool operator >(UInt128 first, ulong second)
+ public static bool operator >(RowId first, ulong second)
{
- return first.Hi > 0 || first.Lo > second;
+ return first.High > 0 || first.Low > second;
}
- public static bool operator <=(UInt128 first, ulong second)
+ public static bool operator <=(RowId first, ulong second)
{
- return first.Hi == 0 && first.Lo <= second;
+ return first.High == 0 && first.Low <= second;
}
- public static bool operator >=(UInt128 first, ulong second)
+ public static bool operator >=(RowId first, ulong second)
{
- return first.Hi > 0 || first.Lo >= second;
+ return first.High > 0 || first.Low >= second;
}
- public static explicit operator double(UInt128 x)
+ public static explicit operator double(RowId x)
{
// REVIEW: The 64-bit JIT has a bug where rounding might be not quite
// correct when converting a ulong to double with the high bit set. Should we
// care and compensate? See the DoubleParser code for a work-around.
- return x.Hi * ((double)(1UL << 32) * (1UL << 32)) + x.Lo;
+ return x.High * ((double)(1UL << 32) * (1UL << 32)) + x.Low;
}
public override int GetHashCode()
{
return (int)(
- (uint)Lo ^ (uint)(Lo >> 32) ^
- (uint)(Hi << 7) ^ (uint)(Hi >> 57) ^ (uint)(Hi >> (57 - 32)));
+ (uint)Low ^ (uint)(Low >> 32) ^
+ (uint)(High << 7) ^ (uint)(High >> 57) ^ (uint)(High >> (57 - 32)));
}
#region Hashing style
@@ -167,10 +176,10 @@ private static void FinalMix(ref ulong h1, ref ulong h2, int len)
/// that were all zeros, except for the last bit which is one.
///
[MethodImpl(MethodImplOptions.AggressiveInlining)]
- public UInt128 Fork()
+ public RowId Fork()
{
- ulong h1 = Lo;
- ulong h2 = Hi;
+ ulong h1 = Low;
+ ulong h2 = High;
// Here it's as if k1=1, k2=0.
h1 = RotL(h1, 27);
h1 += h2;
@@ -179,7 +188,7 @@ public UInt128 Fork()
h2 += h1;
h2 = h2 * 5 + 0x38495ab5;
h1 ^= RotL(_c1, 31) * _c2;
- return new UInt128(h1, h2);
+ return new RowId(h1, h2);
}
///
@@ -188,10 +197,10 @@ public UInt128 Fork()
/// that were all zeros.
///
[MethodImpl(MethodImplOptions.AggressiveInlining)]
- public UInt128 Next()
+ public RowId Next()
{
- ulong h1 = Lo;
- ulong h2 = Hi;
+ ulong h1 = Low;
+ ulong h2 = High;
// Here it's as if k1=0, k2=0.
h1 = RotL(h1, 27);
h1 += h2;
@@ -199,7 +208,7 @@ public UInt128 Next()
h2 = RotL(h2, 31);
h2 += h1;
h2 = h2 * 5 + 0x38495ab5;
- return new UInt128(h1, h2);
+ return new RowId(h1, h2);
}
///
@@ -210,14 +219,14 @@ public UInt128 Next()
///
///
[MethodImpl(MethodImplOptions.AggressiveInlining)]
- public UInt128 Combine(UInt128 other)
+ public RowId Combine(RowId other)
{
- var h1 = Lo;
- var h2 = Hi;
+ var h1 = Low;
+ var h2 = High;
other = other.Fork();
- ulong k1 = other.Lo; // First 8 bytes.
- ulong k2 = other.Hi; // Second 8 bytes.
+ ulong k1 = other.Low; // First 8 bytes.
+ ulong k2 = other.High; // Second 8 bytes.
k1 *= _c1;
k1 = RotL(k1, 31);
@@ -235,7 +244,7 @@ public UInt128 Combine(UInt128 other)
h2 += h1;
h2 = h2 * 5 + 0x38495ab5;
- return new UInt128(h1, h2);
+ return new RowId(h1, h2);
}
#endregion
}
diff --git a/src/Microsoft.ML.Core/Data/Schema.cs b/src/Microsoft.ML.Core/Data/Schema.cs
index 9be55ca48f..0977865544 100644
--- a/src/Microsoft.ML.Core/Data/Schema.cs
+++ b/src/Microsoft.ML.Core/Data/Schema.cs
@@ -2,19 +2,17 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Runtime;
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
+using Microsoft.ML.Internal.Utilities;
namespace Microsoft.ML.Data
{
///
- /// This class represents the schema of an object (like an or an ).
+ /// This class represents the of an object like, for instance, an or an .
/// On the high level, the schema is a collection of 'columns'. Each column has the following properties:
/// - Column name.
/// - Column type.
@@ -22,16 +20,11 @@ namespace Microsoft.ML.Data
/// and values.
///
[System.Diagnostics.DebuggerTypeProxy(typeof(SchemaDebuggerProxy))]
- public sealed class Schema : ISchema, IReadOnlyList
+ public sealed class Schema : IReadOnlyList
{
private readonly Column[] _columns;
private readonly Dictionary _nameMap;
- ///
- /// Number of columns in the schema.
- ///
- public int ColumnCount => _columns.Length;
-
///
/// Number of columns in the schema.
///
@@ -194,7 +187,7 @@ public sealed class Metadata
///
public Schema Schema { get; }
- public static Metadata Empty { get; } = new Metadata(new Schema(Enumerable.Empty()), new Delegate[0]);
+ public static Metadata Empty { get; } = new Metadata(new Schema(new Column[0]), new Delegate[0]);
///
/// Create a metadata row by supplying the schema columns and the getter delegates for all the values.
@@ -249,18 +242,19 @@ public void GetValue(string kind, ref TValue value)
GetGetter(column.Value.Index)(ref value);
}
- public override string ToString() => string.Join(", ", Schema.GetColumns().Select(x => x.column.Name));
+ public override string ToString() => string.Join(", ", Schema.Select(x => x.Name));
}
///
/// This constructor should only be called by .
///
- internal Schema(IEnumerable columns)
+ /// The input columns. The constructed instance takes ownership of the array.
+ internal Schema(Column[] columns)
{
Contracts.CheckValue(columns, nameof(columns));
- _columns = columns.ToArray();
+ _columns = columns;
_nameMap = new Dictionary();
for (int i = 0; i < _columns.Length; i++)
{
@@ -269,11 +263,6 @@ internal Schema(IEnumerable columns)
}
}
- ///
- /// Get all non-hidden columns as pairs of (index, ).
- ///
- public IEnumerable<(int index, Column column)> GetColumns() => _nameMap.Values.Select(idx => (idx, _columns[idx]));
-
///
/// Manufacture an instance of out of any .
///
@@ -282,9 +271,6 @@ internal static Schema Create(ISchema inputSchema)
{
Contracts.CheckValue(inputSchema, nameof(inputSchema));
- if (inputSchema is Schema s)
- return s;
-
var builder = new SchemaBuilder();
for (int i = 0; i < inputSchema.ColumnCount; i++)
{
@@ -309,42 +295,15 @@ private static Delegate GetMetadataGetterDelegate(ISchema schema, int co
return getter;
}
- #region Legacy schema API to be removed
- public string GetColumnName(int col) => this[col].Name;
-
- public ColumnType GetColumnType(int col) => this[col].Type;
-
- public IEnumerable> GetMetadataTypes(int col)
- {
- var meta = this[col].Metadata;
- if (meta == null)
- return Enumerable.Empty>();
- return meta.Schema.GetColumns().Select(c => new KeyValuePair(c.column.Name, c.column.Type));
- }
-
- public ColumnType GetMetadataTypeOrNull(string kind, int col)
- {
- var meta = this[col].Metadata;
- if (meta == null)
- return null;
- if (((ISchema)meta.Schema).TryGetColumnIndex(kind, out int metaCol))
- return meta.Schema[metaCol].Type;
- return null;
- }
-
- public void GetMetadata(string kind, int col, ref TValue value)
- {
- var meta = this[col].Metadata;
- if (meta == null)
- throw MetadataUtils.ExceptGetMetadata();
- meta.GetValue(kind, ref value);
- }
-
- public bool TryGetColumnIndex(string name, out int col)
+ ///
+ /// Legacy method to get the column index.
+ /// DO NOT USE: use instead.
+ ///
+ [BestFriend]
+ internal bool TryGetColumnIndex(string name, out int col)
{
col = GetColumnOrNull(name)?.Index ?? -1;
return col >= 0;
}
- #endregion
}
}
diff --git a/src/Microsoft.ML.Core/Data/SchemaBuilder.cs b/src/Microsoft.ML.Core/Data/SchemaBuilder.cs
index 9fb22a7026..41f109530f 100644
--- a/src/Microsoft.ML.Core/Data/SchemaBuilder.cs
+++ b/src/Microsoft.ML.Core/Data/SchemaBuilder.cs
@@ -2,10 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Runtime.Data;
-using System;
using System.Collections.Generic;
-using System.Text;
namespace Microsoft.ML.Data
{
@@ -30,8 +27,11 @@ public SchemaBuilder()
/// The column name.
/// The column type.
/// The column metadata.
- public void AddColumn(string name, ColumnType type, Schema.Metadata metadata)
+ public void AddColumn(string name, ColumnType type, Schema.Metadata metadata = null)
{
+ Contracts.CheckNonEmpty(name, nameof(name));
+ Contracts.CheckValue(type, nameof(type));
+ Contracts.CheckValueOrNull(metadata);
_items.Add((name, type, metadata));
}
diff --git a/src/Microsoft.ML.Core/Data/SchemaDebuggerProxy.cs b/src/Microsoft.ML.Core/Data/SchemaDebuggerProxy.cs
index 7e7c40fe03..00fff2d6a5 100644
--- a/src/Microsoft.ML.Core/Data/SchemaDebuggerProxy.cs
+++ b/src/Microsoft.ML.Core/Data/SchemaDebuggerProxy.cs
@@ -2,10 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.Internal.Utilities;
using System.Collections.Generic;
using System.Linq;
+using Microsoft.ML.Internal.Utilities;
namespace Microsoft.ML.Data
{
@@ -39,10 +38,10 @@ public MetadataDebuggerProxy(Schema.Metadata metadata)
private static List> BuildValues(Schema.Metadata metadata)
{
var result = new List>();
- foreach ((var index, var column) in metadata.Schema.GetColumns())
+ foreach (var column in metadata.Schema)
{
var name = column.Name;
- var value = Utils.MarshalInvoke(GetValue, column.Type.RawType, metadata, index);
+ var value = Utils.MarshalInvoke(GetValue, column.Type.RawType, metadata, column.Index);
result.Add(new KeyValuePair(name, value));
}
return result;
diff --git a/src/Microsoft.ML.Core/Data/ServerChannel.cs b/src/Microsoft.ML.Core/Data/ServerChannel.cs
index a9b33d1986..c85be7eb82 100644
--- a/src/Microsoft.ML.Core/Data/ServerChannel.cs
+++ b/src/Microsoft.ML.Core/Data/ServerChannel.cs
@@ -4,11 +4,9 @@
using System;
using System.Collections.Generic;
-using System.Reflection;
-using Microsoft.ML.Runtime.CommandLine;
-using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.EntryPoints;
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML
{
///
/// Instances of this class are used to set up a bundle of named delegates. These
diff --git a/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs b/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs
index da60c84ccf..e8b94e7dbd 100644
--- a/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs
+++ b/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs
@@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
///
/// Base class for creating a cursor on top of another cursor that does not add or remove rows.
@@ -11,28 +11,27 @@ namespace Microsoft.ML.Runtime.Data
/// Dispose is virtual with the default implementation delegating to the input cursor.
///
[BestFriend]
- internal abstract class SynchronizedCursorBase : ICursor
- where TBase : class, ICursor
+ internal abstract class SynchronizedCursorBase : RowCursor
{
protected readonly IChannel Ch;
- private readonly ICursor _root;
+ private readonly RowCursor _root;
private bool _disposed;
- protected TBase Input { get; }
+ protected RowCursor Input { get; }
- public long Position => _root.Position;
+ public sealed override long Position => _root.Position;
- public long Batch => _root.Batch;
+ public sealed override long Batch => _root.Batch;
- public CursorState State => _root.State;
+ public sealed override CursorState State => _root.State;
///
/// Convenience property for checking whether the current state is CursorState.Good.
///
protected bool IsGood => _root.State == CursorState.Good;
- protected SynchronizedCursorBase(IChannelProvider provider, TBase input)
+ protected SynchronizedCursorBase(IChannelProvider provider, RowCursor input)
{
Contracts.AssertValue(provider, "provider");
Ch = provider.Start("Cursor");
@@ -42,34 +41,25 @@ protected SynchronizedCursorBase(IChannelProvider provider, TBase input)
_root = Input.GetRootCursor();
}
- public virtual void Dispose()
+ protected override void Dispose(bool disposing)
{
- if (!_disposed)
+ if (_disposed)
+ return;
+ if (disposing)
{
Input.Dispose();
Ch.Dispose();
- _disposed = true;
}
+ base.Dispose(disposing);
+ _disposed = true;
}
- public bool MoveNext()
- {
- return _root.MoveNext();
- }
+ public sealed override bool MoveNext() => _root.MoveNext();
- public bool MoveMany(long count)
- {
- return _root.MoveMany(count);
- }
+ public sealed override bool MoveMany(long count) => _root.MoveMany(count);
- public ICursor GetRootCursor()
- {
- return _root;
- }
+ public sealed override RowCursor GetRootCursor() => _root;
- public ValueGetter GetIdGetter()
- {
- return Input.GetIdGetter();
- }
+ public sealed override ValueGetter GetIdGetter() => Input.GetIdGetter();
}
}
diff --git a/src/Microsoft.ML.Core/Data/VBuffer.cs b/src/Microsoft.ML.Core/Data/VBuffer.cs
index a86f0bdae4..949857bc2b 100644
--- a/src/Microsoft.ML.Core/Data/VBuffer.cs
+++ b/src/Microsoft.ML.Core/Data/VBuffer.cs
@@ -4,9 +4,9 @@
using System;
using System.Collections.Generic;
-using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
///
/// A buffer that supports both dense and sparse representations. This is the
diff --git a/src/Microsoft.ML.Core/Data/VBufferEditor.cs b/src/Microsoft.ML.Core/Data/VBufferEditor.cs
index 8da19b641f..1dafac0fa5 100644
--- a/src/Microsoft.ML.Core/Data/VBufferEditor.cs
+++ b/src/Microsoft.ML.Core/Data/VBufferEditor.cs
@@ -4,7 +4,7 @@
using System;
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
///
/// Various methods for creating instances.
diff --git a/src/Microsoft.ML.Core/Data/WrappingRow.cs b/src/Microsoft.ML.Core/Data/WrappingRow.cs
new file mode 100644
index 0000000000..f6568cf205
--- /dev/null
+++ b/src/Microsoft.ML.Core/Data/WrappingRow.cs
@@ -0,0 +1,64 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace Microsoft.ML.Data
+{
+ /// <summary>
+ /// Convenient base class for <see cref="Row"/> implementors that wrap a single <see cref="Row"/>
+ /// as their input. The <see cref="Row.Batch"/>, <see cref="Row.Position"/>, and
+ /// <see cref="Row.GetIdGetter"/> are taken from this <see cref="Input"/>.
+ /// </summary>
+ [BestFriend]
+ internal abstract class WrappingRow : Row
+ {
+ private bool _disposed;
+
+ /// <summary>
+ /// The wrapped input row.
+ /// </summary>
+ protected Row Input { get; }
+
+ public sealed override long Batch => Input.Batch;
+ public sealed override long Position => Input.Position;
+ public override ValueGetter GetIdGetter() => Input.GetIdGetter();
+
+ [BestFriend]
+ private protected WrappingRow(Row input)
+ {
+ Contracts.AssertValue(input);
+ Input = input;
+ }
+
+ /// <summary>
+ /// This override of the dispose method by default only calls <see cref="Input"/>'s
+ /// <see cref="System.IDisposable.Dispose()"/> method, but subclasses can enable additional functionality
+ /// via the <see cref="DisposeCore(bool)"/> functionality.
+ /// </summary>
+ /// <param name="disposing"></param>
+ protected sealed override void Dispose(bool disposing)
+ {
+ if (_disposed)
+ return;
+ // Since the input was created first, and this instance may depend on it, we should
+ // dispose local resources first before potentially disposing the input row resources.
+ DisposeCore(disposing);
+ if (disposing)
+ Input.Dispose();
+ _disposed = true;
+ }
+
+ /// <summary>
+ /// Called from <see cref="Dispose(bool)"/> with <see langword="true"/> in the case where
+ /// that method has never been called before, and right after <see cref="Input"/> has been
+ /// disposed. The default implementation does nothing.
+ /// </summary>
+ /// <param name="disposing">Whether this was called through the dispose path, as opposed
+ /// to the finalizer path.</param>
+ protected virtual void DisposeCore(bool disposing)
+ {
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs b/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs
index 0163222fc1..175e7cb09a 100644
--- a/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs
@@ -4,7 +4,7 @@
using System;
-namespace Microsoft.ML.Runtime.EntryPoints
+namespace Microsoft.ML.EntryPoints
{
///
/// This is a signature for classes that are 'holders' of entry points and components.
diff --git a/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs b/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
index a7c3ddd298..82edc2bfaa 100644
--- a/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
@@ -2,13 +2,13 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Data;
-using Microsoft.ML.Runtime.CommandLine;
-using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.Linq;
+using Microsoft.ML.CommandLine;
+using Microsoft.ML.Data;
+using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Runtime.EntryPoints
+namespace Microsoft.ML.EntryPoints
{
[BestFriend]
internal static class EntryPointUtils
diff --git a/src/Microsoft.ML.Core/EntryPoints/IMlState.cs b/src/Microsoft.ML.Core/EntryPoints/IMlState.cs
deleted file mode 100644
index 41ea062861..0000000000
--- a/src/Microsoft.ML.Core/EntryPoints/IMlState.cs
+++ /dev/null
@@ -1,14 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-namespace Microsoft.ML.Runtime.EntryPoints
-{
- ///
- /// Dummy interface to allow reference to the AutoMlState object in the C# API (since AutoMlState
- /// has things that reference C# API, leading to circular dependency). Makes state object an opaque
- /// black box to the graph. The macro itself will then case to the concrete type.
- ///
- public interface IMlState
- { }
-}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs b/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs
index d538f636ce..1308008c6e 100644
--- a/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs
@@ -4,15 +4,11 @@
using System;
using System.Collections.Generic;
-using System.Diagnostics;
using System.Linq;
-using System.Net.Sockets;
-using System.Reflection;
using System.Text;
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Data;
-namespace Microsoft.ML.Runtime.EntryPoints
+namespace Microsoft.ML.EntryPoints
{
///
/// This class defines attributes to annotate module inputs, outputs, entry points etc. when defining
@@ -577,11 +573,11 @@ public enum DataKind
///
FileHandle,
///
- /// A transform model, represented by an .
+ /// A transform model, represented by an <see cref="TransformModel"/>.
///
TransformModel,
///
- /// A predictor model, represented by an .
+ /// A predictor model, represented by an <see cref="PredictorModel"/>.
///
PredictorModel,
///
@@ -602,11 +598,7 @@ public enum DataKind
/// optionally, a set of parameters, unique to each component. Example: "BinaryClassifierEvaluator{threshold=0.5}".
/// The C# representation is .
///
- Component,
- ///
- /// An C# object that represents state, such as .
- ///
- State
+ Component
}
public static DataKind GetDataType(Type type)
@@ -632,9 +624,9 @@ public static DataKind GetDataType(Type type)
return DataKind.Float;
if (typeof(IDataView).IsAssignableFrom(type))
return DataKind.DataView;
- if (typeof(ITransformModel).IsAssignableFrom(type))
+ if (typeof(TransformModel).IsAssignableFrom(type))
return DataKind.TransformModel;
- if (typeof(IPredictorModel).IsAssignableFrom(type))
+ if (typeof(PredictorModel).IsAssignableFrom(type))
return DataKind.PredictorModel;
if (typeof(IFileHandle).IsAssignableFrom(type))
return DataKind.FileHandle;
@@ -649,8 +641,6 @@ public static DataKind GetDataType(Type type)
}
if (typeof(IComponentFactory).IsAssignableFrom(type))
return DataKind.Component;
- if (typeof(IMlState).IsAssignableFrom(type))
- return DataKind.State;
return DataKind.Unknown;
}
@@ -673,7 +663,7 @@ public abstract class Optional
public abstract object GetValue();
- protected Optional(bool isExplicit)
+ private protected Optional(bool isExplicit)
{
IsExplicit = isExplicit;
}
diff --git a/src/Microsoft.ML.Core/EntryPoints/IPredictorModel.cs b/src/Microsoft.ML.Core/EntryPoints/PredictorModel.cs
similarity index 58%
rename from src/Microsoft.ML.Core/EntryPoints/IPredictorModel.cs
rename to src/Microsoft.ML.Core/EntryPoints/PredictorModel.cs
index 0ec1151134..30872a5faa 100644
--- a/src/Microsoft.ML.Core/EntryPoints/IPredictorModel.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/PredictorModel.cs
@@ -2,42 +2,51 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using System.IO;
-using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Data;
-namespace Microsoft.ML.Runtime.EntryPoints
+namespace Microsoft.ML.EntryPoints
{
///
- /// Interface for standard predictor model port type.
+ /// Base type for standard predictor model port type.
///
- public interface IPredictorModel
+ public abstract class PredictorModel
{
+ [BestFriend]
+ private protected PredictorModel()
+ {
+ }
+
///
/// Save the model to the given stream.
///
- void Save(IHostEnvironment env, Stream stream);
+ [BestFriend]
+ internal abstract void Save(IHostEnvironment env, Stream stream);
///
/// Extract only the transform portion of the predictor model.
///
- ITransformModel TransformModel { get; }
+ [BestFriend]
+ internal abstract TransformModel TransformModel { get; }
///
/// Extract the predictor object out of the predictor model.
///
- IPredictor Predictor { get; }
+ [BestFriend]
+ internal abstract IPredictor Predictor { get; }
///
/// Apply the predictor model to the transform model and return the resulting predictor model.
///
- IPredictorModel Apply(IHostEnvironment env, ITransformModel transformModel);
+ [BestFriend]
+ internal abstract PredictorModel Apply(IHostEnvironment env, TransformModel transformModel);
///
/// For a given input data, return role mapped data and the predictor object.
/// The scoring entry point will hopefully know how to construct a scorer out of them.
///
- void PrepareData(IHostEnvironment env, IDataView input, out RoleMappedData roleMappedData, out IPredictor predictor);
+ [BestFriend]
+ internal abstract void PrepareData(IHostEnvironment env, IDataView input, out RoleMappedData roleMappedData, out IPredictor predictor);
///
/// Returns a string array containing the label names of the label column type predictor was trained on.
@@ -46,13 +55,13 @@ public interface IPredictorModel
///
///
/// The column type of the label the predictor was trained on.
- string[] GetLabelInfo(IHostEnvironment env, out ColumnType labelType);
+ [BestFriend]
+ internal abstract string[] GetLabelInfo(IHostEnvironment env, out ColumnType labelType);
///
- /// Returns the RoleMappedSchema that was used in training.
+ /// Returns the <see cref="RoleMappedSchema"/> that was used in training.
///
- ///
- ///
- RoleMappedSchema GetTrainingSchema(IHostEnvironment env);
+ [BestFriend]
+ internal abstract RoleMappedSchema GetTrainingSchema(IHostEnvironment env);
}
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Core/Data/ITransformModel.cs b/src/Microsoft.ML.Core/EntryPoints/TransformModel.cs
similarity index 66%
rename from src/Microsoft.ML.Core/Data/ITransformModel.cs
rename to src/Microsoft.ML.Core/EntryPoints/TransformModel.cs
index 37ea9c1a2d..110c75c7aa 100644
--- a/src/Microsoft.ML.Core/Data/ITransformModel.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/TransformModel.cs
@@ -2,58 +2,67 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using System.IO;
using Microsoft.ML.Data;
-using Microsoft.ML.Runtime.Data;
-namespace Microsoft.ML.Runtime.EntryPoints
+namespace Microsoft.ML.EntryPoints
{
///
/// Interface for standard transform model port type.
///
- public interface ITransformModel
+ public abstract class TransformModel
{
+ [BestFriend]
+ private protected TransformModel()
+ {
+ }
+
///
/// The input schema that this transform model was originally instantiated on.
/// Note that the schema may have columns that aren't needed by this transform model.
- /// If an IDataView exists with this schema, then applying this transform model to it
+ /// If an <see cref="IDataView"/> exists with this schema, then applying this transform model to it
/// shouldn't fail because of column type issues.
///
// REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note
// however that doing so may cause issues for composing transform models. For example,
// if transform model A needs column X and model B needs Y, that is NOT produced by A,
// then trimming A's input schema would cause composition to fail.
- Schema InputSchema { get; }
+ [BestFriend]
+ internal abstract Schema InputSchema { get; }
///
/// The output schema that this transform model was originally instantiated on. The schema resulting
- /// from may differ from this, similarly to how
+ /// from <see cref="Apply(IHostEnvironment, IDataView)"/> may differ from this, similarly to how
/// may differ from the schema of dataviews we apply this transform model to.
///
- Schema OutputSchema { get; }
+ [BestFriend]
+ internal abstract Schema OutputSchema { get; }
///
/// Apply the transform(s) in the model to the given input data.
///
- IDataView Apply(IHostEnvironment env, IDataView input);
+ [BestFriend]
+ internal abstract IDataView Apply(IHostEnvironment env, IDataView input);
///
/// Apply the transform(s) in the model to the given transform model.
///
- ITransformModel Apply(IHostEnvironment env, ITransformModel input);
+ [BestFriend]
+ internal abstract TransformModel Apply(IHostEnvironment env, TransformModel input);
///
/// Save the model to the given stream.
///
- void Save(IHostEnvironment env, Stream stream);
+ [BestFriend]
+ internal abstract void Save(IHostEnvironment env, Stream stream);
///
/// Returns the transform model as an that can output a row
/// given a row with the same schema as .
///
/// The transform model as an . If not all transforms
- /// in the pipeline are then it returns null.
- IRowToRowMapper AsRowToRowMapper(IExceptionContext ectx);
+ /// in the pipeline are <see cref="IRowToRowMapper"/> then it returns <see langword="null"/>.
+ [BestFriend]
+ internal abstract IRowToRowMapper AsRowToRowMapper(IExceptionContext ectx);
}
}
diff --git a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
index 3e27ce2516..057969c81f 100644
--- a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
+++ b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
@@ -9,7 +9,7 @@
using System.Linq;
using System.Threading;
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
using Stopwatch = System.Diagnostics.Stopwatch;
@@ -368,7 +368,7 @@ public ConsoleEnvironment(int? seed = null, bool verbose = false,
/// Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.
/// Text writer to print normal messages to.
/// Text writer to print error messages to.
- private ConsoleEnvironment(IRandom rand, bool verbose = false,
+ private ConsoleEnvironment(Random rand, bool verbose = false,
MessageSensitivity sensitivity = MessageSensitivity.All, int conc = 0,
TextWriter outWriter = null, TextWriter errWriter = null)
: base(rand, verbose, conc, nameof(ConsoleEnvironment))
@@ -401,7 +401,7 @@ protected override IFileHandle CreateTempFileCore(IHostEnvironment env, string s
return base.CreateTempFileCore(env, suffix, "TLC_" + prefix);
}
- protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc)
+ protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose, int? conc)
{
Contracts.AssertValue(rand);
Contracts.AssertValueOrNull(parentFullName);
@@ -472,7 +472,7 @@ public void Dispose()
private sealed class Host : HostBase
{
- public Host(HostEnvironmentBase source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc)
+ public Host(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose, int? conc)
: base(source, shortName, parentFullName, rand, verbose, conc)
{
IsCancelled = source.IsCancelled;
@@ -494,7 +494,7 @@ protected override IPipe CreatePipe(ChannelProviderBase pare
return new Pipe(parent, name, GetDispatchDelegate());
}
- protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc)
+ protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose, int? conc)
{
return new Host(source, shortName, parentFullName, rand, verbose, conc);
}
diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
index 31ab23e28f..deaf372cb9 100644
--- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
+++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
@@ -5,11 +5,10 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
-using System.ComponentModel.Composition;
using System.ComponentModel.Composition.Hosting;
using System.IO;
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Data
{
///
/// Base class for channel providers. This is a common base class for.
@@ -107,12 +106,12 @@ public abstract class HostBase : HostEnvironmentBase, IHost
{
public override int Depth { get; }
- public IRandom Rand => _rand;
+ public Random Rand => _rand;
// We don't have dispose mechanism for hosts, so to let GC collect children hosts we make them WeakReference.
private readonly List> _children;
- public HostBase(HostEnvironmentBase source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc)
+ public HostBase(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose, int? conc)
: base(source, rand, verbose, conc, shortName, parentFullName)
{
Depth = source.Depth + 1;
@@ -139,7 +138,7 @@ public void StopExecution()
IHost host;
lock (_cancelLock)
{
- IRandom rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand);
+ Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand);
host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose, conc ?? _conc);
if (!IsCancelled)
_children.Add(new WeakReference(host));
@@ -342,7 +341,7 @@ public void RemoveListener(Action listenerFunc)
private readonly object _cancelLock;
// The random number generator for this host.
- private readonly IRandom _rand;
+ private readonly Random _rand;
// A dictionary mapping the type of message to the Dispatcher that gets the strongly typed dispatch delegate.
protected readonly ConcurrentDictionary ListenerDict;
@@ -367,7 +366,7 @@ public void RemoveListener(Action listenerFunc)
///
/// The main constructor.
///
- protected HostEnvironmentBase(IRandom rand, bool verbose, int conc,
+ protected HostEnvironmentBase(Random rand, bool verbose, int conc,
string shortName = null, string parentFullName = null)
: base(shortName, parentFullName, verbose)
{
@@ -386,7 +385,7 @@ protected HostEnvironmentBase(IRandom rand, bool verbose, int conc,
///
/// This constructor is for forking.
///
- protected HostEnvironmentBase(HostEnvironmentBase source, IRandom rand, bool verbose,
+ protected HostEnvironmentBase(HostEnvironmentBase source, Random rand, bool verbose,
int? conc, string shortName = null, string parentFullName = null)
: base(shortName, parentFullName, verbose)
{
@@ -433,12 +432,12 @@ public virtual void Dispose()
public IHost Register(string name, int? seed = null, bool? verbose = null, int? conc = null)
{
Contracts.CheckNonEmpty(name, nameof(name));
- IRandom rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand);
+ Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand);
return RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose, conc);
}
protected abstract IHost RegisterCore(HostEnvironmentBase source, string shortName,
- string parentFullName, IRandom rand, bool verbose, int? conc);
+ string parentFullName, Random rand, bool verbose, int? conc);
public IFileHandle OpenInputFile(string path)
{
diff --git a/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs b/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs
index 72b08e2715..5501a99090 100644
--- a/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs
+++ b/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs
@@ -4,11 +4,8 @@
using System;
using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading.Tasks;
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML
{
///
/// A telemetry message.
diff --git a/src/Microsoft.ML.Core/Prediction/IPredictor.cs b/src/Microsoft.ML.Core/Prediction/IPredictor.cs
index 6bd1ac2056..682cec417d 100644
--- a/src/Microsoft.ML.Core/Prediction/IPredictor.cs
+++ b/src/Microsoft.ML.Core/Prediction/IPredictor.cs
@@ -2,9 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
-
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML
{
///
/// Type of prediction task
diff --git a/src/Microsoft.ML.Core/Prediction/ITrainer.cs b/src/Microsoft.ML.Core/Prediction/ITrainer.cs
index 5e796aa602..ad5a0b2539 100644
--- a/src/Microsoft.ML.Core/Prediction/ITrainer.cs
+++ b/src/Microsoft.ML.Core/Prediction/ITrainer.cs
@@ -2,10 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Runtime.Data;
-using System;
+using Microsoft.ML.Data;
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML
{
// REVIEW: Would be nice if the registration under SignatureTrainer were automatic
// given registration for one of the "sub-class" signatures.
diff --git a/src/Microsoft.ML.Core/Prediction/ITree.cs b/src/Microsoft.ML.Core/Prediction/ITree.cs
index 67642ecfc5..9b07acdef7 100644
--- a/src/Microsoft.ML.Core/Prediction/ITree.cs
+++ b/src/Microsoft.ML.Core/Prediction/ITree.cs
@@ -2,10 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using System.Collections.Generic;
-namespace Microsoft.ML.Runtime.TreePredictor
+namespace Microsoft.ML.TreePredictor
{
// The interfaces contained herein are meant to allow tree visualizer to run without an explicit dependency
// on FastTree, so as to allow it greater generality. These should probably be moved somewhere else, but where?
diff --git a/src/Microsoft.ML.Core/Prediction/TrainContext.cs b/src/Microsoft.ML.Core/Prediction/TrainContext.cs
index e5e4bbad1c..ff37caf598 100644
--- a/src/Microsoft.ML.Core/Prediction/TrainContext.cs
+++ b/src/Microsoft.ML.Core/Prediction/TrainContext.cs
@@ -2,9 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Data;
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML
{
///
/// Holds information relevant to trainers. Instances of this class are meant to be constructed and passed
diff --git a/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs b/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs
index 4f97c0c893..53398bc651 100644
--- a/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs
+++ b/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs
@@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML
{
///
/// Instances of this class posses information about trainers, in terms of their requirements and capabilities.
diff --git a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs
index a61e96ce7b..58bda21829 100644
--- a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs
+++ b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs
@@ -9,10 +9,12 @@
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests" + PublicKey.TestValue)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Sweeper.Tests" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.InferenceTesting" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipelineTesting" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransformTest" + PublicKey.TestValue)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.EntryPoints" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Maml" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.ResultProcessor" + PublicKey.Value)]
@@ -31,7 +33,7 @@
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.PCA" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.PipelineInference" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Recommender" + PublicKey.Value)]
-[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Runtime.ImageAnalytics" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.ImageAnalytics" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Scoring" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StandardLearners" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Sweeper" + PublicKey.Value)]
@@ -39,4 +41,10 @@
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Transforms" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipe" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TensorFlow.StaticPipe" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.HalLearners.StaticPipe" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransform.StaticPipe" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.LightGBM.StaticPipe" + PublicKey.Value)]
+
[assembly: WantsToBeBestFriends]
diff --git a/src/Microsoft.ML.Core/PublicKey.cs b/src/Microsoft.ML.Core/PublicKey.cs
index 9a944c3d18..63718c3f8e 100644
--- a/src/Microsoft.ML.Core/PublicKey.cs
+++ b/src/Microsoft.ML.Core/PublicKey.cs
@@ -7,8 +7,8 @@
namespace Microsoft.ML
#else
// CpuMath module has its own PublicKey for isolating itself from Microsoft.ML.Core
-// Note that CpuMath uses its own BestFriend defined in Microsoft.ML.Runtime.Internal.CpuMath.Core.
-namespace Microsoft.ML.Runtime.Internal.CpuMath.Core
+// Note that CpuMath uses its own BestFriend defined in Microsoft.ML.Internal.CpuMath.Core.
+namespace Microsoft.ML.Internal.CpuMath.Core
#endif
{
[BestFriend]
diff --git a/src/Microsoft.ML.Core/Utilities/BigArray.cs b/src/Microsoft.ML.Core/Utilities/BigArray.cs
index 3bfb4f688f..bda630cbc0 100644
--- a/src/Microsoft.ML.Core/Utilities/BigArray.cs
+++ b/src/Microsoft.ML.Core/Utilities/BigArray.cs
@@ -6,7 +6,7 @@
using System.Collections;
using System.Collections.Generic;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
///
/// An array-like data structure that supports storing more than
diff --git a/src/Microsoft.ML.Core/Utilities/BinFinder.cs b/src/Microsoft.ML.Core/Utilities/BinFinder.cs
index cdfd0ad08b..d5456e2956 100644
--- a/src/Microsoft.ML.Core/Utilities/BinFinder.cs
+++ b/src/Microsoft.ML.Core/Utilities/BinFinder.cs
@@ -2,12 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Float = System.Single;
-
using System;
using System.Collections.Generic;
+using Float = System.Single;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
[BestFriend]
internal abstract class BinFinderBase
@@ -273,7 +272,7 @@ public static Double GetSplitValue(Double a, Double b)
}
}
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
// This needs to be large enough to represent a product of 2 ints without losing precision
using EnergyType = System.Int64;
@@ -525,7 +524,7 @@ private void UpdatePeg(Peg peg)
}
}
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
// Reasonable choices are Double and System.Int64.
using EnergyType = System.Double;
diff --git a/src/Microsoft.ML.Core/Utilities/BitUtils.cs b/src/Microsoft.ML.Core/Utilities/BitUtils.cs
index 376b0ac8ef..b38b4ad01a 100644
--- a/src/Microsoft.ML.Core/Utilities/BitUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/BitUtils.cs
@@ -2,10 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using System.Runtime.CompilerServices;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
internal static partial class Utils
{
diff --git a/src/Microsoft.ML.Core/Utilities/CharUtils.cs b/src/Microsoft.ML.Core/Utilities/CharUtils.cs
index d88197c8e7..398b97f851 100644
--- a/src/Microsoft.ML.Core/Utilities/CharUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/CharUtils.cs
@@ -4,11 +4,10 @@
#pragma warning disable 420 // volatile with Interlocked.CompareExchange
-using System;
using System.Runtime.CompilerServices;
using System.Threading;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
[BestFriend]
internal static class CharUtils
diff --git a/src/Microsoft.ML.Core/Utilities/CmdIndenter.cs b/src/Microsoft.ML.Core/Utilities/CmdIndenter.cs
index 6b4dbc48db..da6e8b821f 100644
--- a/src/Microsoft.ML.Core/Utilities/CmdIndenter.cs
+++ b/src/Microsoft.ML.Core/Utilities/CmdIndenter.cs
@@ -2,15 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using System.CodeDom.Compiler;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading.Tasks;
-using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.CommandLine;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
[BestFriend]
internal static class CmdIndenter
diff --git a/src/Microsoft.ML.Core/Utilities/Contracts.cs b/src/Microsoft.ML.Core/Utilities/Contracts.cs
index d567d6e883..6fc94f6304 100644
--- a/src/Microsoft.ML.Core/Utilities/Contracts.cs
+++ b/src/Microsoft.ML.Core/Utilities/Contracts.cs
@@ -17,9 +17,9 @@
using System.Threading;
#if CPUMATH_INFRASTRUCTURE
-namespace Microsoft.ML.Runtime.Internal.CpuMath.Core
+namespace Microsoft.ML.Internal.CpuMath.Core
#else
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML
#endif
{
using Conditional = System.Diagnostics.ConditionalAttribute;
@@ -758,6 +758,10 @@ public static void CheckAlive(this IHostEnvironment env)
public static void CheckValueOrNull(T val) where T : class
{
}
+
+ /// <summary>
+ /// This documents that the parameter can legally be null.
+ /// </summary>
[Conditional("INVARIANT_CHECKS")]
public static void CheckValueOrNull(this IExceptionContext ctx, T val) where T : class
{
diff --git a/src/Microsoft.ML.Core/Utilities/DoubleParser.cs b/src/Microsoft.ML.Core/Utilities/DoubleParser.cs
index 4f2ef9ad80..695c1981be 100644
--- a/src/Microsoft.ML.Core/Utilities/DoubleParser.cs
+++ b/src/Microsoft.ML.Core/Utilities/DoubleParser.cs
@@ -4,14 +4,9 @@
#undef COMPARE_BCL
-using Float = System.Single;
-
using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
[BestFriend]
internal static class DoubleParser
diff --git a/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs b/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs
index 7263304046..ab2d938d22 100644
--- a/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs
+++ b/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs
@@ -2,9 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
-
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
using Conditional = System.Diagnostics.ConditionalAttribute;
diff --git a/src/Microsoft.ML.Core/Utilities/FloatUtils.cs b/src/Microsoft.ML.Core/Utilities/FloatUtils.cs
index 06d403da9a..b93541eb3e 100644
--- a/src/Microsoft.ML.Core/Utilities/FloatUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/FloatUtils.cs
@@ -6,7 +6,7 @@
using System.Globalization;
using System.Runtime.InteropServices;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
[BestFriend]
internal static class FloatUtils
diff --git a/src/Microsoft.ML.Core/Utilities/HashArray.cs b/src/Microsoft.ML.Core/Utilities/HashArray.cs
index 64dced9792..c9f31b5361 100644
--- a/src/Microsoft.ML.Core/Utilities/HashArray.cs
+++ b/src/Microsoft.ML.Core/Utilities/HashArray.cs
@@ -6,7 +6,7 @@
using System;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
// REVIEW: May want to add an IEnumerable>.
diff --git a/src/Microsoft.ML.Core/Utilities/Hashing.cs b/src/Microsoft.ML.Core/Utilities/Hashing.cs
index ae36fae95d..8345a319ac 100644
--- a/src/Microsoft.ML.Core/Utilities/Hashing.cs
+++ b/src/Microsoft.ML.Core/Utilities/Hashing.cs
@@ -3,11 +3,10 @@
// See the LICENSE file in the project root for more information.
using System;
-using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
[BestFriend]
internal static class Hashing
diff --git a/src/Microsoft.ML.Core/Utilities/Heap.cs b/src/Microsoft.ML.Core/Utilities/Heap.cs
index 7652163f10..fe1178abd3 100644
--- a/src/Microsoft.ML.Core/Utilities/Heap.cs
+++ b/src/Microsoft.ML.Core/Utilities/Heap.cs
@@ -5,7 +5,7 @@
using System;
using System.Collections.Generic;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
using Conditional = System.Diagnostics.ConditionalAttribute;
diff --git a/src/Microsoft.ML.Core/Utilities/HybridMemoryStream.cs b/src/Microsoft.ML.Core/Utilities/HybridMemoryStream.cs
index 02f713dd6e..c1db619510 100644
--- a/src/Microsoft.ML.Core/Utilities/HybridMemoryStream.cs
+++ b/src/Microsoft.ML.Core/Utilities/HybridMemoryStream.cs
@@ -5,7 +5,7 @@
using System;
using System.IO;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
using Conditional = System.Diagnostics.ConditionalAttribute;
diff --git a/src/Microsoft.ML.Core/Utilities/IndentedTextWriterExtensions.cs b/src/Microsoft.ML.Core/Utilities/IndentedTextWriterExtensions.cs
index fe8fd12e96..fd733dc60c 100644
--- a/src/Microsoft.ML.Core/Utilities/IndentedTextWriterExtensions.cs
+++ b/src/Microsoft.ML.Core/Utilities/IndentedTextWriterExtensions.cs
@@ -4,10 +4,8 @@
using System;
using System.CodeDom.Compiler;
-using System.IO;
-using System.Text;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
[BestFriend]
internal static class IndentedTextWriterExtensions
diff --git a/src/Microsoft.ML.Core/Utilities/LineParser.cs b/src/Microsoft.ML.Core/Utilities/LineParser.cs
index 73d9bb6158..bc8eabcdda 100644
--- a/src/Microsoft.ML.Core/Utilities/LineParser.cs
+++ b/src/Microsoft.ML.Core/Utilities/LineParser.cs
@@ -5,7 +5,7 @@
using System;
using System.Runtime.CompilerServices;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
[BestFriend]
internal static class LineParser
diff --git a/src/Microsoft.ML.Core/Utilities/LruCache.cs b/src/Microsoft.ML.Core/Utilities/LruCache.cs
index e041efde02..ad874b60c1 100644
--- a/src/Microsoft.ML.Core/Utilities/LruCache.cs
+++ b/src/Microsoft.ML.Core/Utilities/LruCache.cs
@@ -5,7 +5,7 @@
using System.Collections.Generic;
using System.Linq;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
///
/// Implements a least recently used cache.
diff --git a/src/Microsoft.ML.Core/Utilities/MathUtils.cs b/src/Microsoft.ML.Core/Utilities/MathUtils.cs
index 4ce1bb4cf0..7ab349da12 100644
--- a/src/Microsoft.ML.Core/Utilities/MathUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/MathUtils.cs
@@ -2,13 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Float = System.Single;
-
using System;
using System.Collections.Generic;
-using System.Linq;
+using Float = System.Single;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
///
/// Some useful math methods.
@@ -507,7 +505,15 @@ public static Float Tanh(Float x)
///
public static Float SigmoidSlow(Float x)
{
- return 1 / (1 + ExpSlow(-x));
+ // The following two expressions are mathematically equivalent. Due to the potential of getting overflow we should
+ // not call exp(x) for large positive x: instead, we modify the expression to compute exp(-x).
+ if (x > 0)
+ return 1 / (1 + ExpSlow(-x));
+ else
+ {
+ var ex = ExpSlow(x);
+ return ex / (1 + ex);
+ }
}
///
@@ -726,7 +732,7 @@ public static Float GetMedianInPlace(Float[] src, int count)
return (src[iv - 1] + src[iv]) / 2;
}
- public static Double CosineSimilarity(Float[] a, Float[] b, int aIdx, int bIdx, int len)
+ public static Double CosineSimilarity(ReadOnlySpan a, ReadOnlySpan b, int aIdx, int bIdx, int len)
{
const Double epsilon = 1e-12f;
Contracts.Assert(len > 0);
diff --git a/src/Microsoft.ML.Core/Utilities/MatrixTransposeOps.cs b/src/Microsoft.ML.Core/Utilities/MatrixTransposeOps.cs
index cb945e56a2..2e25599c1d 100644
--- a/src/Microsoft.ML.Core/Utilities/MatrixTransposeOps.cs
+++ b/src/Microsoft.ML.Core/Utilities/MatrixTransposeOps.cs
@@ -7,7 +7,7 @@
using System.Linq;
using System.Threading.Tasks;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
[BestFriend]
internal static class MatrixTransposeOps
diff --git a/src/Microsoft.ML.Core/Utilities/MinWaiter.cs b/src/Microsoft.ML.Core/Utilities/MinWaiter.cs
index fbaf8fb6d6..8a671ab817 100644
--- a/src/Microsoft.ML.Core/Utilities/MinWaiter.cs
+++ b/src/Microsoft.ML.Core/Utilities/MinWaiter.cs
@@ -2,10 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using System.Threading;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
///
/// A synchronization primitive meant to address situations where you have a set of
diff --git a/src/Microsoft.ML.Core/Utilities/NormStr.cs b/src/Microsoft.ML.Core/Utilities/NormStr.cs
index c79e2425d1..9859ea2724 100644
--- a/src/Microsoft.ML.Core/Utilities/NormStr.cs
+++ b/src/Microsoft.ML.Core/Utilities/NormStr.cs
@@ -3,14 +3,12 @@
// See the LICENSE file in the project root for more information.
using System;
-using System.Collections.Generic;
using System.Collections;
+using System.Collections.Generic;
using System.Linq;
-using System.Threading;
using System.Text;
-using Microsoft.ML.Runtime.Data;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
using Conditional = System.Diagnostics.ConditionalAttribute;
diff --git a/src/Microsoft.ML.Core/Utilities/ObjectPool.cs b/src/Microsoft.ML.Core/Utilities/ObjectPool.cs
index a06202af76..a5a9591fae 100644
--- a/src/Microsoft.ML.Core/Utilities/ObjectPool.cs
+++ b/src/Microsoft.ML.Core/Utilities/ObjectPool.cs
@@ -6,7 +6,7 @@
using System.Collections.Concurrent;
using System.Threading;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
[BestFriend]
internal sealed class ObjectPool : ObjectPoolBase where T : class, new()
diff --git a/src/Microsoft.ML.Core/Utilities/OrderedWaiter.cs b/src/Microsoft.ML.Core/Utilities/OrderedWaiter.cs
index ca0ef23445..5a71c6205d 100644
--- a/src/Microsoft.ML.Core/Utilities/OrderedWaiter.cs
+++ b/src/Microsoft.ML.Core/Utilities/OrderedWaiter.cs
@@ -5,7 +5,7 @@
using System;
using System.Threading;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
///
/// The primary use case for this structure is to impose ordering among
diff --git a/src/Microsoft.ML.Core/Utilities/PathUtils.cs b/src/Microsoft.ML.Core/Utilities/PathUtils.cs
index 98407e24a1..c10e36b1eb 100644
--- a/src/Microsoft.ML.Core/Utilities/PathUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/PathUtils.cs
@@ -6,7 +6,7 @@
using System.IO;
using System.Threading;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
internal static partial class Utils
{
diff --git a/src/Microsoft.ML.Core/Utilities/PlatformUtils.cs b/src/Microsoft.ML.Core/Utilities/PlatformUtils.cs
index 2bd8acab3e..16cadbdb5b 100644
--- a/src/Microsoft.ML.Core/Utilities/PlatformUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/PlatformUtils.cs
@@ -6,7 +6,7 @@
using System.Collections.ObjectModel;
using System.Reflection;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
///
/// Contains extension methods that aid in building cross platform.
diff --git a/src/Microsoft.ML.Core/Utilities/Random.cs b/src/Microsoft.ML.Core/Utilities/Random.cs
index d5bbf2d4ec..c79a22426b 100644
--- a/src/Microsoft.ML.Core/Utilities/Random.cs
+++ b/src/Microsoft.ML.Core/Utilities/Random.cs
@@ -4,48 +4,46 @@
using System;
using System.IO;
-using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML
{
- public interface IRandom
+ [BestFriend]
+ internal static class RandomUtils
{
- ///
- /// Generates a Single in the range [0, 1).
- ///
- Single NextSingle();
-
- ///
- /// Generates a Double in the range [0, 1).
- ///
- Double NextDouble();
-
- ///
- /// Generates an int in the range [0, int.MaxValue]. Note that this differs
- /// from the contract for System.Random.Next, which claims to never return
- /// int.MaxValue.
- ///
- int Next();
-
- ///
- /// Generates an int in the range [int.MinValue, int.MaxValue].
- ///
- int NextSigned();
-
- ///
- /// Generates an int in the range [0, limit), unless limit == 0, in which case this advances the generator
- /// and returns 0.
- /// Throws if limit is less than 0.
- ///
- int Next(int limit);
- }
+ public static float NextSingle(this Random random)
+ {
+ if (random is TauswortheHybrid tauswortheHybrd)
+ {
+ return tauswortheHybrd.NextSingle();
+ }
+
+ for (; ; )
+ {
+ // Since the largest value that NextDouble() can return rounds to 1 when cast to float,
+ // we need to protect against returning 1.
+ var res = (float)random.NextDouble();
+ if (res < 1.0f)
+ return res;
+ }
+ }
+
+ public static int NextSigned(this Random random)
+ {
+ if (random is TauswortheHybrid tauswortheHybrd)
+ {
+ return tauswortheHybrd.NextSigned();
+ }
+
+ // Note that, according to the documentation for System.Random,
+ // this won't ever achieve int.MaxValue, but oh well.
+ return random.Next(int.MinValue, int.MaxValue);
+ }
- public static class RandomUtils
- {
public static TauswortheHybrid Create()
{
// Seed from a system random.
- return new TauswortheHybrid(new SysRandom());
+ return new TauswortheHybrid(new Random());
}
public static TauswortheHybrid Create(int? seed)
@@ -74,81 +72,17 @@ public static TauswortheHybrid Create(uint seed)
return new TauswortheHybrid(state);
}
- public static TauswortheHybrid Create(IRandom seed)
+ public static TauswortheHybrid Create(Random seed)
{
return new TauswortheHybrid(seed);
}
}
- public sealed class SysRandom : IRandom
- {
- private readonly Random _rnd;
-
- public SysRandom()
- {
- _rnd = new Random();
- }
-
- public SysRandom(int seed)
- {
- _rnd = new Random(seed);
- }
-
- public static SysRandom Wrap(Random rnd)
- {
- if (rnd != null)
- return new SysRandom(rnd);
- return null;
- }
-
- private SysRandom(Random rnd)
- {
- Contracts.AssertValue(rnd);
- _rnd = rnd;
- }
-
- public Single NextSingle()
- {
- // Since the largest value that NextDouble() can return rounds to 1 when cast to Single,
- // we need to protect against returning 1.
- for (;;)
- {
- var res = (Single)_rnd.NextDouble();
- if (res < 1.0f)
- return res;
- }
- }
-
- public Double NextDouble()
- {
- return _rnd.NextDouble();
- }
-
- public int Next()
- {
- // Note that, according to the documentation for System.Random,
- // this won't ever achieve int.MaxValue, but oh well.
- return _rnd.Next();
- }
-
- public int Next(int limit)
- {
- Contracts.CheckParam(limit >= 0, nameof(limit), "limit must be non-negative");
- return _rnd.Next(limit);
- }
-
- public int NextSigned()
- {
- // Note that, according to the documentation for System.Random,
- // this won't ever achieve int.MaxValue, but oh well.
- return _rnd.Next(int.MinValue, int.MaxValue);
- }
- }
-
///
/// Tausworthe hybrid random number generator.
///
- public sealed class TauswortheHybrid : IRandom
+ [BestFriend]
+ internal sealed class TauswortheHybrid : Random
{
public readonly struct State
{
@@ -204,7 +138,7 @@ public TauswortheHybrid(State state)
_z4 = state.U4;
}
- public TauswortheHybrid(IRandom rng)
+ public TauswortheHybrid(Random rng)
{
_z1 = GetSeed(rng);
_z2 = GetSeed(rng);
@@ -212,14 +146,14 @@ public TauswortheHybrid(IRandom rng)
_z4 = GetU(rng);
}
- private static uint GetU(IRandom rng)
+ private static uint GetU(Random rng)
{
return ((uint)rng.Next(0x00010000) << 16) | ((uint)rng.Next(0x00010000));
}
- private static uint GetSeed(IRandom rng)
+ private static uint GetSeed(Random rng)
{
- for (;;)
+ for (; ; )
{
uint u = GetU(rng);
if (u >= 128)
@@ -227,19 +161,19 @@ private static uint GetSeed(IRandom rng)
}
}
- public Single NextSingle()
+ public float NextSingle()
{
NextState();
return GetSingle();
}
- public Double NextDouble()
+ public override double NextDouble()
{
NextState();
return GetDouble();
}
- public int Next()
+ public override int Next()
{
NextState();
uint u = GetUint();
@@ -248,17 +182,35 @@ public int Next()
return n;
}
- public int Next(int limit)
+ public override int Next(int maxValue)
{
- Contracts.CheckParam(limit >= 0, nameof(limit), "limit must be non-negative");
+ Contracts.CheckParam(maxValue >= 0, nameof(maxValue), "maxValue must be non-negative");
NextState();
uint u = GetUint();
- ulong uu = (ulong)u * (ulong)limit;
+ ulong uu = (ulong)u * (ulong)maxValue;
int res = (int)(uu >> 32);
- Contracts.Assert(0 <= res && (res < limit || res == 0));
+ Contracts.Assert(0 <= res && (res < maxValue || res == 0));
return res;
}
+ public override int Next(int minValue, int maxValue)
+ {
+ Contracts.CheckParam(minValue <= maxValue, nameof(minValue), "minValue must be less than or equal to maxValue.");
+
+ long range = (long)maxValue - minValue;
+ return (int)((long)(NextDouble() * range) + minValue);
+ }
+
+ public override void NextBytes(byte[] buffer)
+ {
+ Contracts.CheckValue(buffer, nameof(buffer));
+
+ for (int i = 0; i < buffer.Length; i++)
+ {
+ buffer[i] = (byte)Next();
+ }
+ }
+
public int NextSigned()
{
NextState();
@@ -270,23 +222,23 @@ private uint GetUint()
return _z1 ^ _z2 ^ _z3 ^ _z4;
}
- private Single GetSingle()
+ private float GetSingle()
{
- const Single scale = (Single)1 / (1 << 23);
+ const float scale = (float)1 / (1 << 23);
// Drop the low 9 bits so the conversion to Single is exact. Allowing rounding would cause
// issues with biasing values and, worse, the possibility of returning exactly 1.
uint u = GetUint() >> 9;
- Contracts.Assert((uint)(Single)u == u);
+ Contracts.Assert((uint)(float)u == u);
- return (Single)u * scale;
+ return (float)u * scale;
}
- private Double GetDouble()
+ private double GetDouble()
{
- const Double scale = (Double)1 / (1 << 16) / (1 << 16);
+ const double scale = (double)1 / (1 << 16) / (1 << 16);
uint u = GetUint();
- return (Double)u * scale;
+ return (double)u * scale;
}
private void NextState()
@@ -315,97 +267,4 @@ public State GetState()
return new State(_z1, _z2, _z3, _z4);
}
}
-
-#if false // REVIEW: This was written for NN drop out but turned out to be too slow, so I inlined it instead.
- public sealed class BooleanSampler
- {
- public const int CbitRand = 25;
-
- private readonly IRandom _rand;
- private readonly uint _k; // probability of "true" is _k / (1U << _qlog).
- private readonly int _qlog; // Number of bits consumed by each call to Sample().
- private readonly int _cv; // Number of calls to Sample() covered by a call to _rand.Next(...).
- private readonly uint _mask; // (1U << _qlog) - 1
-
- // Mutable state.
- private int _c;
- private uint _v;
-
- ///
- /// Create a boolean sampler using the given random number generator, quantizing the true rate
- /// to cbitQuant bits, assuming that sampling the random number generator is capable of producing
- /// cbitRand good bits.
- ///
- /// For example, new BooleanSampler(0.5f, 1, 25, new Random()) will produce a reasonable fair coin flipper.
- /// Note that this reduces the parameters, so new BooleanSampler(0.5f, 6, 25, new Random()) will produce
- /// the same flipper. In other words, since 0.5 quantized to 6 bits can be reduced to only needing one
- /// bit, it reduces cbitQuant to 1.
- ///
- public static BooleanSampler Create(Single rate, int cbitQuant, IRandom rand)
- {
- Contracts.Assert(0 < rate && rate < 1);
- Contracts.Assert(0 < cbitQuant && cbitQuant <= CbitRand / 2);
-
- int qlog = cbitQuant;
- uint k = (uint)(rate * (1 << qlog));
- if (k == 0)
- k = 1;
- Contracts.Assert(0 <= k && k < (1U << qlog));
-
- while ((k & 1) == 0 && k > 0)
- {
- qlog--;
- k >>= 1;
- }
- Contracts.Assert(qlog > 0);
- uint q = 1U << qlog;
- Contracts.Assert(0 < k && k < q);
-
- int cv = CbitRand / qlog;
- Contracts.Assert(cv > 1);
- return new BooleanSampler(qlog, k, rand);
- }
-
- private BooleanSampler(int qlog, uint k, IRandom rand)
- {
- _qlog = qlog;
- _k = k;
- _rand = rand;
- _qlog = qlog;
- _cv = CbitRand / _qlog;
- _mask = (1U << _qlog) - 1;
- }
-
- public bool Sample()
- {
- _v >>= _qlog;
- if (--_c <= 0)
- {
- _v = (uint)_rand.Next(1 << (_cv * _qlog));
- _c = _cv;
- }
- return (_v & _mask) < _k;
- }
-
- public void SampleMany(out uint bits, out int count)
- {
- uint u = (uint)_rand.Next(1 << (_cv * _qlog));
- count = _cv;
- if (_qlog == 1)
- {
- bits = u;
- return;
- }
-
- bits = 0;
- for (int i = 0; i < count; i++)
- {
- bits <<= 1;
- if ((u & _mask) < _k)
- bits |= 1;
- u >>= _qlog;
- }
- }
- }
-#endif
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Core/Utilities/ReservoirSampler.cs b/src/Microsoft.ML.Core/Utilities/ReservoirSampler.cs
index 8438a2e742..f442fd8e59 100644
--- a/src/Microsoft.ML.Core/Utilities/ReservoirSampler.cs
+++ b/src/Microsoft.ML.Core/Utilities/ReservoirSampler.cs
@@ -2,11 +2,12 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System;
using System.Collections.Generic;
using System.Linq;
-using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Data;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
///
/// This is an interface for creating samples of a requested size from a stream of data of type .
@@ -56,7 +57,7 @@ internal sealed class ReservoirSamplerWithoutReplacement : IReservoirSampler<
// This array contains a cache of the elements composing the reservoir.
private readonly T[] _cache;
- private readonly IRandom _rnd;
+ private readonly Random _rnd;
private long _numSampled;
private readonly ValueGetter _getter;
@@ -67,7 +68,7 @@ internal sealed class ReservoirSamplerWithoutReplacement : IReservoirSampler<
public long NumSampled { get { return _numSampled; } }
- public ReservoirSamplerWithoutReplacement(IRandom rnd, int size, ValueGetter getter)
+ public ReservoirSamplerWithoutReplacement(Random rnd, int size, ValueGetter getter)
{
Contracts.CheckValue(rnd, nameof(rnd));
Contracts.CheckParam(size > 0, nameof(size), "Reservoir size must be positive");
@@ -135,7 +136,7 @@ internal sealed class ReservoirSamplerWithReplacement : IReservoirSampler
private readonly T[] _cache;
private readonly int[] _counts;
- private readonly IRandom _rnd;
+ private readonly Random _rnd;
private long _numSampled;
private readonly ValueGetter _getter;
@@ -146,7 +147,7 @@ internal sealed class ReservoirSamplerWithReplacement : IReservoirSampler
public long NumSampled { get { return _numSampled; } }
- public ReservoirSamplerWithReplacement(IRandom rnd, int size, ValueGetter getter)
+ public ReservoirSamplerWithReplacement(Random rnd, int size, ValueGetter getter)
{
Contracts.CheckValue(rnd, nameof(rnd));
Contracts.CheckParam(size > 0, nameof(size), "Reservoir size must be positive");
diff --git a/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs b/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs
index 9e9c6f80bb..0813dac0fe 100644
--- a/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs
@@ -10,7 +10,7 @@
using System.Threading;
using System.Threading.Tasks;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
///
/// This class takes care of downloading resources needed by ML.NET components. Resources are located in
diff --git a/src/Microsoft.ML.Core/Utilities/Stats.cs b/src/Microsoft.ML.Core/Utilities/Stats.cs
index 182239adde..fad719e80d 100644
--- a/src/Microsoft.ML.Core/Utilities/Stats.cs
+++ b/src/Microsoft.ML.Core/Utilities/Stats.cs
@@ -2,11 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Float = System.Single;
-
using System;
+using Float = System.Single;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
///
/// A class containing common statistical functions
@@ -20,7 +19,7 @@ internal static class Stats
/// Size of range to sample from, between 0 and int.MaxValue^2
/// Random number generator
/// Sampled value
- public static long SampleLong(long rangeSize, IRandom rand)
+ public static long SampleLong(long rangeSize, Random rand)
{
Contracts.CheckParam(rangeSize > 0, nameof(rangeSize), "rangeSize must be positive.");
@@ -51,7 +50,7 @@ public static long SampleLong(long rangeSize, IRandom rand)
/// A Random to use for the sampling
/// a sample
/// uses Joseph L. Leva's algorithm from "A fast normal random number generator", 1992
- public static double SampleFromGaussian(IRandom rand)
+ public static double SampleFromGaussian(Random rand)
{
double u;
double v;
@@ -75,7 +74,7 @@ public static double SampleFromGaussian(IRandom rand)
/// The random number generator to use
/// Sample from gamma distribution
/// Uses Marsaglia and Tsang's fast algorithm
- public static double SampleFromGamma(IRandom r, double alpha)
+ public static double SampleFromGamma(Random r, double alpha)
{
Contracts.CheckParam(alpha > 0, nameof(alpha), "alpha must be positive");
@@ -111,7 +110,7 @@ public static double SampleFromGamma(IRandom r, double alpha)
/// first parameter
/// second parameter
/// Sample from distribution
- public static double SampleFromBeta(IRandom rand, double alpha1, double alpha2)
+ public static double SampleFromBeta(Random rand, double alpha1, double alpha2)
{
double gamma1 = SampleFromGamma(rand, alpha1);
double gamma2 = SampleFromGamma(rand, alpha2);
@@ -124,7 +123,7 @@ public static double SampleFromBeta(IRandom rand, double alpha1, double alpha2)
/// Random generator to use
/// array of parameters
/// array in which to store resulting sample
- public static void SampleFromDirichlet(IRandom rand, double[] alphas, double[] result)
+ public static void SampleFromDirichlet(Random rand, double[] alphas, double[] result)
{
Contracts.Check(alphas.Length == result.Length,
"Dirichlet parameters must have the same dimensionality as sample space.");
@@ -141,7 +140,7 @@ public static void SampleFromDirichlet(IRandom rand, double[] alphas, double[] r
}
}
- public static int SampleFromPoisson(IRandom rand, double lambda)
+ public static int SampleFromPoisson(Random rand, double lambda)
{
if (lambda < 5)
{
@@ -201,7 +200,7 @@ public static int SampleFromPoisson(IRandom rand, double lambda)
// Mean refers to the mu parameter. Scale refers to the b parameter.
// https://en.wikipedia.org/wiki/Laplace_distribution
- public static Float SampleFromLaplacian(IRandom rand, Float mean, Float scale)
+ public static Float SampleFromLaplacian(Random rand, Float mean, Float scale)
{
Float u = rand.NextSingle();
u = u - 0.5f;
@@ -220,7 +219,7 @@ public static Float SampleFromLaplacian(IRandom rand, Float mean, Float scale)
///
///
///
- public static Float SampleFromCauchy(IRandom rand)
+ public static Float SampleFromCauchy(Random rand)
{
return (Float)Math.Tan(Math.PI * (rand.NextSingle() - 0.5));
}
@@ -233,7 +232,7 @@ public static Float SampleFromCauchy(IRandom rand)
/// Parameter p of binomial
///
/// Should be robust for all values of n, p
- public static int SampleFromBinomial(IRandom r, int n, double p)
+ public static int SampleFromBinomial(Random r, int n, double p)
{
return BinoRand.Next(r, n, p);
}
@@ -266,7 +265,7 @@ private static double Fc(int k)
}
}
- public static int Next(IRandom rand, int n, double p)
+ public static int Next(Random rand, int n, double p)
{
int x;
double pin = Math.Min(p, 1 - p);
@@ -284,7 +283,7 @@ public static int Next(IRandom rand, int n, double p)
// For small n
// Inverse transformation algorithm
// Described in Kachitvichyanukul and Schmeiser: "Binomial Random Variate Generation"
- private static int InvTransform(int n, double p, IRandom rn)
+ private static int InvTransform(int n, double p, Random rn)
{
int x = 0;
double u = rn.NextDouble();
@@ -308,7 +307,7 @@ private static int InvTransform(int n, double p, IRandom rn)
// For large n
// Algorithm from W. Hormann: "The Generation of Binomial Random Variables"
// This is algorithm BTRD
- private static int GenerateLarge(int n, double p, IRandom rn)
+ private static int GenerateLarge(int n, double p, Random rn)
{
double np = n * p;
double q = 1 - p;
diff --git a/src/Microsoft.ML.Core/Utilities/Stream.cs b/src/Microsoft.ML.Core/Utilities/Stream.cs
index 4fe21e7df7..633ee672f7 100644
--- a/src/Microsoft.ML.Core/Utilities/Stream.cs
+++ b/src/Microsoft.ML.Core/Utilities/Stream.cs
@@ -9,7 +9,7 @@
using System.Text;
using System.Threading;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
internal static partial class Utils
{
diff --git a/src/Microsoft.ML.Core/Utilities/SubsetStream.cs b/src/Microsoft.ML.Core/Utilities/SubsetStream.cs
index 3bf9ad2c9e..022ca28cb1 100644
--- a/src/Microsoft.ML.Core/Utilities/SubsetStream.cs
+++ b/src/Microsoft.ML.Core/Utilities/SubsetStream.cs
@@ -5,7 +5,7 @@
using System;
using System.IO;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
///
/// Returns a "view" stream, which appears to be a possibly truncated
diff --git a/src/Microsoft.ML.Core/Utilities/SummaryStatistics.cs b/src/Microsoft.ML.Core/Utilities/SummaryStatistics.cs
index 3e36191564..3382885e9c 100644
--- a/src/Microsoft.ML.Core/Utilities/SummaryStatistics.cs
+++ b/src/Microsoft.ML.Core/Utilities/SummaryStatistics.cs
@@ -4,7 +4,7 @@
using System;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
internal abstract class SummaryStatisticsBase
{
diff --git a/src/Microsoft.ML.Core/Utilities/SupervisedBinFinder.cs b/src/Microsoft.ML.Core/Utilities/SupervisedBinFinder.cs
index 63257823a2..f640043009 100644
--- a/src/Microsoft.ML.Core/Utilities/SupervisedBinFinder.cs
+++ b/src/Microsoft.ML.Core/Utilities/SupervisedBinFinder.cs
@@ -2,13 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Float = System.Single;
-
using System;
using System.Collections.Generic;
using System.Diagnostics;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
///
/// This class performs discretization of (value, label) pairs into bins in a way that minimizes
diff --git a/src/Microsoft.ML.Core/Utilities/TextReaderStream.cs b/src/Microsoft.ML.Core/Utilities/TextReaderStream.cs
index 682a0336cb..691876c0c7 100644
--- a/src/Microsoft.ML.Core/Utilities/TextReaderStream.cs
+++ b/src/Microsoft.ML.Core/Utilities/TextReaderStream.cs
@@ -6,7 +6,7 @@
using System.IO;
using System.Text;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
///
/// A readable that is backed by a .
diff --git a/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs b/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs
index 46a82a4e7c..3585b36f23 100644
--- a/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs
@@ -8,7 +8,7 @@
using System.Threading;
using System.Threading.Tasks;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
internal static partial class Utils
{
diff --git a/src/Microsoft.ML.Core/Utilities/Tree.cs b/src/Microsoft.ML.Core/Utilities/Tree.cs
index 8d4c0f7585..cbec20fe8e 100644
--- a/src/Microsoft.ML.Core/Utilities/Tree.cs
+++ b/src/Microsoft.ML.Core/Utilities/Tree.cs
@@ -5,7 +5,7 @@
using System.Collections;
using System.Collections.Generic;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
///
/// The tree structure is simultaneously a tree, and a node in a tree. The interface to
diff --git a/src/Microsoft.ML.Core/Utilities/Utils.cs b/src/Microsoft.ML.Core/Utilities/Utils.cs
index d0f6f6256e..c6e21486ca 100644
--- a/src/Microsoft.ML.Core/Utilities/Utils.cs
+++ b/src/Microsoft.ML.Core/Utilities/Utils.cs
@@ -2,8 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Float = System.Single;
-
using System;
using System.Collections;
using System.Collections.Generic;
@@ -13,7 +11,7 @@
using System.Text.RegularExpressions;
using System.Threading;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
[BestFriend]
@@ -218,10 +216,10 @@ public static int FindIndexSorted(this IList input, int value)
/// In case of duplicates it returns the index of the first one.
/// It guarantees that items before the returned index are < value, while those at and after the returned index are >= value.
///
- public static int FindIndexSorted(this Single[] input, Single value)
+ public static int FindIndexSorted(this IList input, float value)
{
Contracts.AssertValue(input);
- return FindIndexSorted(input, 0, input.Length, value);
+ return FindIndexSorted(input, 0, input.Count, value);
}
///
@@ -342,11 +340,11 @@ public static int FindIndexSorted(this IList input, int min, int lim, int v
/// In case of duplicates it returns the index of the first one.
/// It guarantees that items before the returned index are < value, while those at and after the returned index are >= value.
///
- public static int FindIndexSorted(this Single[] input, int min, int lim, Single value)
+ public static int FindIndexSorted(this IList input, int min, int lim, float value)
{
Contracts.AssertValue(input);
- Contracts.Assert(0 <= min & min <= lim & lim <= input.Length);
- Contracts.Assert(!Single.IsNaN(value));
+ Contracts.Assert(0 <= min & min <= lim & lim <= input.Count);
+ Contracts.Assert(!float.IsNaN(value));
int minCur = min;
int limCur = lim;
@@ -354,7 +352,7 @@ public static int FindIndexSorted(this Single[] input, int min, int lim, Single
{
int mid = (int)(((uint)minCur + (uint)limCur) / 2);
Contracts.Assert(minCur <= mid & mid < limCur);
- Contracts.Assert(!Single.IsNaN(input[mid]));
+ Contracts.Assert(!float.IsNaN(input[mid]));
if (input[mid] >= value)
limCur = mid;
@@ -530,56 +528,10 @@ public static int[] GetRandomPermutation(Random rand, int size)
Contracts.Assert(size >= 0);
var res = GetIdentityPermutation(size);
- Shuffle(rand, res);
+ Shuffle(rand, res);
return res;
}
- public static int[] GetRandomPermutation(IRandom rand, int size)
- {
- Contracts.AssertValue(rand);
- Contracts.Assert(size >= 0);
-
- var res = GetIdentityPermutation(size);
- Shuffle(rand, res);
- return res;
- }
-
- public static void Shuffle(IRandom rand, T[] rgv)
- {
- Contracts.AssertValue(rand);
- Contracts.AssertValue(rgv);
-
- Shuffle(rand, rgv, 0, rgv.Length);
- }
-
- public static void Shuffle(Random rand, T[] rgv)
- {
- Contracts.AssertValue(rand);
- Contracts.AssertValue(rgv);
-
- Shuffle(rand, rgv, 0, rgv.Length);
- }
-
- public static void Shuffle(IRandom rand, T[] rgv, int min, int lim)
- {
- Contracts.AssertValue(rand);
- Contracts.AssertValue(rgv);
- Contracts.Check(0 <= min & min <= lim & lim <= rgv.Length);
-
- for (int iv = min; iv < lim; iv++)
- Swap(ref rgv[iv], ref rgv[iv + rand.Next(lim - iv)]);
- }
-
- public static void Shuffle(Random rand, T[] rgv, int min, int lim)
- {
- Contracts.AssertValue(rand);
- Contracts.AssertValue(rgv);
- Contracts.Check(0 <= min & min <= lim & lim <= rgv.Length);
-
- for (int iv = min; iv < lim; iv++)
- Swap(ref rgv[iv], ref rgv[iv + rand.Next(lim - iv)]);
- }
-
public static bool AreEqual(Single[] arr1, Single[] arr2)
{
if (arr1 == arr2)
@@ -614,6 +566,14 @@ public static bool AreEqual(Double[] arr1, Double[] arr2)
return true;
}
+ public static void Shuffle(Random rand, Span rgv)
+ {
+ Contracts.AssertValue(rand);
+
+ for (int iv = 0; iv < rgv.Length; iv++)
+ Swap(ref rgv[iv], ref rgv[iv + rand.Next(rgv.Length - iv)]);
+ }
+
public static bool AreEqual(int[] arr1, int[] arr2)
{
if (arr1 == arr2)
@@ -653,14 +613,14 @@ public static string ExtractLettersAndNumbers(string value)
return Regex.Replace(value, "[^A-Za-z0-9]", "");
}
- public static bool IsSorted(Float[] values)
+ public static bool IsSorted(IList values)
{
if (Utils.Size(values) <= 1)
return true;
var prev = values[0];
- for (int i = 1; i < values.Length; i++)
+ for (int i = 1; i < values.Count; i++)
{
if (!(values[i] >= prev))
return false;
@@ -1104,6 +1064,15 @@ public static void MarshalActionInvoke(Action
+ /// A four-argument version of .
+ ///
+ public static void MarshalActionInvoke(Action act, Type genArg, TArg1 arg1, TArg2 arg2, TArg3 arg3, TArg4 arg4)
+ {
+ var meth = MarshalActionInvokeCheckAndCreate(genArg, act);
+ meth.Invoke(act.Target, new object[] { arg1, arg2, arg3, arg4 });
+ }
+
public static string GetDescription(this Enum value)
{
Type type = value.GetType();
diff --git a/src/Microsoft.ML.Core/Utilities/VBufferUtils.cs b/src/Microsoft.ML.Core/Utilities/VBufferUtils.cs
index f730d61724..9120ccbedd 100644
--- a/src/Microsoft.ML.Core/Utilities/VBufferUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/VBufferUtils.cs
@@ -4,9 +4,9 @@
using System;
using System.Collections.Generic;
-using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Data;
-namespace Microsoft.ML.Runtime.Internal.Utilities
+namespace Microsoft.ML.Internal.Utilities
{
// REVIEW: Consider automatic densification in some of the operations, where appropriate.
// REVIEW: Once we do the conversions from Vector/WritableVector, review names of methods,
@@ -405,7 +405,7 @@ public static void ApplyAt(ref VBuffer dst, int slot, SlotValueManipulator
// we are modifying in the sparse vector, in which case the vector becomes
// dense. Then there is no need to do anything with indices.
bool needIndices = dstValuesCount + 1 < dst.Length;
- editor = VBufferEditor.Create(ref dst, dst.Length, dstValuesCount + 1);
+ editor = VBufferEditor.Create(ref dst, dst.Length, dstValuesCount + 1, keepOldOnResize: true);
if (idx != dstValuesCount)
{
// We have to do some sort of shift copy.
@@ -1322,7 +1322,7 @@ public static void ApplyInto(in VBuffer a, in VBuffer
// REVIEW: Worth optimizing the newCount == a.Length case?
// Probably not...
- editor = VBufferEditor.Create(ref dst, a.Length, newCount);
+ editor = VBufferEditor.Create(ref dst, a.Length, newCount, requireIndicesOnDense: true);
Span indices = editor.Indices;
if (newCount == bValues.Length)
diff --git a/src/Microsoft.ML.CpuMath/AlignedArray.cs b/src/Microsoft.ML.CpuMath/AlignedArray.cs
index 4602e884e4..6687022e8f 100644
--- a/src/Microsoft.ML.CpuMath/AlignedArray.cs
+++ b/src/Microsoft.ML.CpuMath/AlignedArray.cs
@@ -2,20 +2,18 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Runtime.Internal.CpuMath.Core;
using System;
+using Microsoft.ML.Internal.CpuMath.Core;
-namespace Microsoft.ML.Runtime.Internal.CpuMath
+namespace Microsoft.ML.Internal.CpuMath
{
- using Float = System.Single;
-
///
- /// This implements a logical array of Floats that is automatically aligned for SSE/AVX operations.
+ /// This implements a logical array of floats that is automatically aligned for SSE/AVX operations.
/// To pin and force alignment, call the GetPin method, typically wrapped in a using (since it
/// returns a Pin struct that is IDisposable). From the pin, you can get the IntPtr to pass to
/// native code.
///
- /// The ctor takes an alignment value, which must be a power of two at least sizeof(Float).
+ /// The ctor takes an alignment value, which must be a power of two at least sizeof(float).
///
[BestFriend]
internal sealed class AlignedArray
@@ -24,7 +22,7 @@ internal sealed class AlignedArray
// items, also filled with NaN. Note that _size * sizeof(Float) is divisible by _cbAlign.
// It is illegal to access any slot outsize [_base, _base + _size). This is internal so clients
// can easily pin it.
- public Float[] Items;
+ public float[] Items;
private readonly int _size; // Must be divisible by (_cbAlign / sizeof(Float)).
private readonly int _cbAlign; // The alignment in bytes, a power of two, divisible by sizeof(Float).
@@ -40,12 +38,12 @@ public AlignedArray(int size, int cbAlign)
{
Contracts.Assert(0 < size);
// cbAlign should be a power of two.
- Contracts.Assert(sizeof(Float) <= cbAlign);
+ Contracts.Assert(sizeof(float) <= cbAlign);
Contracts.Assert((cbAlign & (cbAlign - 1)) == 0);
// cbAlign / sizeof(Float) should divide size.
- Contracts.Assert((size * sizeof(Float)) % cbAlign == 0);
+ Contracts.Assert((size * sizeof(float)) % cbAlign == 0);
- Items = new Float[size + cbAlign / sizeof(Float)];
+ Items = new float[size + cbAlign / sizeof(float)];
_size = size;
_cbAlign = cbAlign;
_lock = new object();
@@ -54,15 +52,15 @@ public AlignedArray(int size, int cbAlign)
public unsafe int GetBase(long addr)
{
#if DEBUG
- fixed (Float* pv = Items)
- Contracts.Assert((Float*)addr == pv);
+ fixed (float* pv = Items)
+ Contracts.Assert((float*)addr == pv);
#endif
int cbLow = (int)(addr & (_cbAlign - 1));
int ibMin = cbLow == 0 ? 0 : _cbAlign - cbLow;
- Contracts.Assert(ibMin % sizeof(Float) == 0);
+ Contracts.Assert(ibMin % sizeof(float) == 0);
- int ifltMin = ibMin / sizeof(Float);
+ int ifltMin = ibMin / sizeof(float);
if (ifltMin == _base)
return _base;
@@ -71,9 +69,9 @@ public unsafe int GetBase(long addr)
// Anything outsize [_base, _base + _size) should not be accessed, so
// set them to NaN, for debug validation.
for (int i = 0; i < _base; i++)
- Items[i] = Float.NaN;
+ Items[i] = float.NaN;
for (int i = _base + _size; i < Items.Length; i++)
- Items[i] = Float.NaN;
+ Items[i] = float.NaN;
#endif
return _base;
}
@@ -96,7 +94,7 @@ private void MoveData(int newBase)
public int CbAlign { get { return _cbAlign; } }
- public Float this[int index]
+ public float this[int index]
{
get
{
@@ -110,30 +108,30 @@ public Float this[int index]
}
}
- public void CopyTo(Float[] dst, int index, int count)
+ public void CopyTo(Span dst, int index, int count)
{
Contracts.Assert(0 <= count && count <= _size);
Contracts.Assert(dst != null);
Contracts.Assert(0 <= index && index <= dst.Length - count);
- Array.Copy(Items, _base, dst, index, count);
+ Items.AsSpan(_base, count).CopyTo(dst.Slice(index));
}
- public void CopyTo(int start, Float[] dst, int index, int count)
+ public void CopyTo(int start, Span dst, int index, int count)
{
Contracts.Assert(0 <= count);
Contracts.Assert(0 <= start && start <= _size - count);
Contracts.Assert(dst != null);
Contracts.Assert(0 <= index && index <= dst.Length - count);
- Array.Copy(Items, start + _base, dst, index, count);
+ Items.AsSpan(start + _base, count).CopyTo(dst.Slice(index));
}
- public void CopyFrom(ReadOnlySpan src)
+ public void CopyFrom(ReadOnlySpan src)
{
Contracts.Assert(src.Length <= _size);
src.CopyTo(Items.AsSpan(_base));
}
- public void CopyFrom(int start, ReadOnlySpan src)
+ public void CopyFrom(int start, ReadOnlySpan src)
{
Contracts.Assert(0 <= start && start <= _size - src.Length);
src.CopyTo(Items.AsSpan(start + _base));
@@ -143,7 +141,7 @@ public void CopyFrom(int start, ReadOnlySpan src)
// valuesSrc contains only the non-zero entries. Those are copied into their logical positions in the dense array.
// rgposSrc contains the logical positions + offset of the non-zero entries in the dense array.
// rgposSrc runs parallel to the valuesSrc array.
- public void CopyFrom(ReadOnlySpan rgposSrc, ReadOnlySpan valuesSrc, int posMin, int iposMin, int iposLim, bool zeroItems)
+ public void CopyFrom(ReadOnlySpan rgposSrc, ReadOnlySpan valuesSrc, int posMin, int iposMin, int iposLim, bool zeroItems)
{
Contracts.Assert(rgposSrc != null);
Contracts.Assert(valuesSrc != null);
@@ -202,7 +200,7 @@ public void ZeroItems(int[] rgposSrc, int posMin, int iposMin, int iposLim)
// REVIEW: This is hackish and slightly dangerous. Perhaps we should wrap this in an
// IDisposable that "locks" this, prohibiting GetBase from being called, while the buffer
// is "checked out".
- public void GetRawBuffer(out Float[] items, out int offset)
+ public void GetRawBuffer(out float[] items, out int offset)
{
items = Items;
offset = _base;
diff --git a/src/Microsoft.ML.CpuMath/AlignedMatrix.cs b/src/Microsoft.ML.CpuMath/AlignedMatrix.cs
index 6d550fc3fc..d9dee9a868 100644
--- a/src/Microsoft.ML.CpuMath/AlignedMatrix.cs
+++ b/src/Microsoft.ML.CpuMath/AlignedMatrix.cs
@@ -2,14 +2,13 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Float = System.Single;
-
-using Microsoft.ML.Runtime.Internal.CpuMath.Core;
using System;
using System.Collections;
using System.Collections.Generic;
+using Microsoft.ML.Internal.CpuMath.Core;
+using Float = System.Single;
-namespace Microsoft.ML.Runtime.Internal.CpuMath
+namespace Microsoft.ML.Internal.CpuMath
{
using Conditional = System.Diagnostics.ConditionalAttribute;
diff --git a/src/Microsoft.ML.CpuMath/AssemblyInfo.cs b/src/Microsoft.ML.CpuMath/AssemblyInfo.cs
index 7710703c29..65557d90f9 100644
--- a/src/Microsoft.ML.CpuMath/AssemblyInfo.cs
+++ b/src/Microsoft.ML.CpuMath/AssemblyInfo.cs
@@ -2,8 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Runtime.Internal.CpuMath.Core;
using System.Runtime.CompilerServices;
+using Microsoft.ML.Internal.CpuMath.Core;
[assembly: InternalsVisibleTo("Microsoft.ML.CpuMath.UnitTests.netstandard" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo("Microsoft.ML.CpuMath.UnitTests.netcoreapp" + PublicKey.TestValue)]
diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs
index 55e692eb63..ba86b9158f 100644
--- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs
+++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs
@@ -14,9 +14,10 @@
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
+using Microsoft.ML.Internal.CpuMath.Core;
using nuint = System.UInt64;
-namespace Microsoft.ML.Runtime.Internal.CpuMath
+namespace Microsoft.ML.Internal.CpuMath
{
internal static class AvxIntrinsics
{
@@ -46,6 +47,25 @@ internal static class AvxIntrinsics
private static readonly Vector256 _absMask256 = Avx.StaticCast(Avx.SetAllVector256(0x7FFFFFFF));
+ private const int Vector256Alignment = 32;
+
+ [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
+ private static bool HasCompatibleAlignment(AlignedArray alignedArray)
+ {
+ Contracts.AssertValue(alignedArray);
+ Contracts.Assert(alignedArray.Size > 0);
+ return (alignedArray.CbAlign % Vector256Alignment) == 0;
+ }
+
+ [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
+ private static unsafe float* GetAlignedBase(AlignedArray alignedArray, float* unalignedBase)
+ {
+ Contracts.AssertValue(alignedArray);
+ float* alignedBase = unalignedBase + alignedArray.GetBase((long)unalignedBase);
+ Contracts.Assert(((long)alignedBase % Vector256Alignment) == 0);
+ return alignedBase;
+ }
+
[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
private static Vector128 GetHigh(in Vector256 x)
=> Avx.ExtractVector128(x, 1);
@@ -153,20 +173,18 @@ private static Vector256 MultiplyAdd(Vector256 src1, Vector256 mat, ReadOnlySpan src, Span dst, int crow, int ccol)
- {
- Contracts.Assert(crow % 4 == 0);
- Contracts.Assert(ccol % 4 == 0);
+ Contracts.Assert(HasCompatibleAlignment(mat));
+ Contracts.Assert(HasCompatibleAlignment(src));
+ Contracts.Assert(HasCompatibleAlignment(dst));
- fixed (float* psrc = &MemoryMarshal.GetReference(src))
- fixed (float* pdst = &MemoryMarshal.GetReference(dst))
- fixed (float* pmat = &MemoryMarshal.GetReference(mat))
- fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
- fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
+ fixed (float* pSrcStart = &src.Items[0])
+ fixed (float* pDstStart = &dst.Items[0])
+ fixed (float* pMatStart = &mat.Items[0])
{
+ float* psrc = GetAlignedBase(src, pSrcStart);
+ float* pdst = GetAlignedBase(dst, pDstStart);
+ float* pmat = GetAlignedBase(mat, pMatStart);
+
float* pSrcEnd = psrc + ccol;
float* pDstEnd = pdst + crow;
float* pDstCurrent = pdst;
@@ -175,118 +193,36 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr
while (pDstCurrent < pDstEnd)
{
Vector256 res0 = Avx.SetZeroVector256();
- Vector256 res1 = Avx.SetZeroVector256();
- Vector256 res2 = Avx.SetZeroVector256();
- Vector256 res3 = Avx.SetZeroVector256();
+ Vector256 res1 = res0;
+ Vector256 res2 = res0;
+ Vector256 res3 = res0;
- int length = ccol;
float* pSrcCurrent = psrc;
- nuint address = (nuint)(pMatCurrent);
- int misalignment = (int)(address % 32);
- int remainder = 0;
-
- if ((misalignment & 3) != 0)
+ while (pSrcCurrent < pSrcEnd)
{
- // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations
- while (pSrcCurrent < pSrcEnd)
- {
- Vector256 vector = Avx.LoadVector256(pSrcCurrent);
-
- float* pMatTemp = pMatCurrent;
- res0 = MultiplyAdd(pMatTemp, vector, res0);
- res1 = MultiplyAdd(pMatTemp += ccol, vector, res1);
- res2 = MultiplyAdd(pMatTemp += ccol, vector, res2);
- res3 = MultiplyAdd(pMatTemp += ccol, vector, res3);
-
- pSrcCurrent += 8;
- pMatCurrent += 8;
- }
- }
- else
- {
- if (misalignment != 0)
- {
- // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then
- // masking any elements that will be included in the first aligned read
- misalignment >>= 2;
- misalignment = 8 - misalignment;
-
- Vector256 mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));
-
- // We only align pMat since it has significantly more reads.
- float* pMatTemp = pMatCurrent;
- Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp));
- Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol));
- Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol));
- Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol));
- Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent));
-
- res0 = Avx.Multiply(x01, vector);
- res1 = Avx.Multiply(x11, vector);
- res2 = Avx.Multiply(x21, vector);
- res3 = Avx.Multiply(x31, vector);
-
- pMatCurrent += misalignment;
- pSrcCurrent += misalignment;
- length -= misalignment;
- }
-
- if (length > 7)
- {
- // Handle all the 256-bit blocks that we can now that we have offset to an aligned address
- remainder = length % 8;
-
- while (pSrcCurrent + 8 <= pSrcEnd)
- {
- // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads
- // (due to semantics of the legacy encoding).
- // We don't need an assert, since the instruction will throw for unaligned inputs.
- Vector256 vector = Avx.LoadVector256(pSrcCurrent);
-
- float* pMatTemp = pMatCurrent;
- res0 = MultiplyAdd(pMatTemp, vector, res0);
- res1 = MultiplyAdd(pMatTemp += ccol, vector, res1);
- res2 = MultiplyAdd(pMatTemp += ccol, vector, res2);
- res3 = MultiplyAdd(pMatTemp += ccol, vector, res3);
-
- pSrcCurrent += 8;
- pMatCurrent += 8;
- }
- }
- else
- {
- // Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not
- // 256-bit aligned. This means we can't do any aligned loads and will just end up doing two
- // unaligned loads where we mask the input each time.
- remainder = length;
- }
-
- if (remainder != 0)
- {
- // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next
- // unaligned load will read to the end of the array and then mask out any elements already processed
-
- pMatCurrent -= (8 - remainder);
- pSrcCurrent -= (8 - remainder);
-
- Vector256 mask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
-
- float* pMatTemp = pMatCurrent;
- Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp));
- Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol));
- Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol));
- Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol));
- Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent));
-
- res0 = MultiplyAdd(x01, vector, res0);
- res1 = MultiplyAdd(x11, vector, res1);
- res2 = MultiplyAdd(x21, vector, res2);
- res3 = MultiplyAdd(x31, vector, res3);
-
- pMatCurrent += 8;
- pSrcCurrent += 8;
- }
+ float* pMatTemp = pMatCurrent;
+ Contracts.Assert(((nuint)(pMatTemp) % 32) == 0);
+ Contracts.Assert(((nuint)(pSrcCurrent) % 32) == 0);
+
+ // The JIT will only fold away unaligned loads due to the semantics behind
+ // the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since
+ // modern hardware has unaligned loads that are as fast as aligned loads,
+ // when it doesn't cross a cache-line/page boundary, we will just assert
+ // that the alignment is correct and allow for the more-efficient codegen.
+ Vector256 x01 = Avx.LoadVector256(pMatTemp);
+ Vector256 x11 = Avx.LoadVector256(pMatTemp += ccol);
+ Vector256 x21 = Avx.LoadVector256(pMatTemp += ccol);
+ Vector256 x31 = Avx.LoadVector256(pMatTemp += ccol);
+ Vector256 x02 = Avx.LoadVector256(pSrcCurrent);
+
+ res0 = MultiplyAdd(x01, x02, res0);
+ res1 = MultiplyAdd(x11, x02, res1);
+ res2 = MultiplyAdd(x21, x02, res2);
+ res3 = MultiplyAdd(x31, x02, res3);
+
+ pSrcCurrent += 8;
+ pMatCurrent += 8;
}
// Add up the entries of each, with the 4 results in res0
@@ -295,7 +231,7 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr
res0 = Avx.HorizontalAdd(res0, res2);
Vector128 sum = Sse.Add(Avx.GetLowerHalf(res0), GetHigh(in res0));
- Sse.Store(pDstCurrent, sum);
+ Sse.StoreAligned(pDstCurrent, sum);
pDstCurrent += 4;
pMatCurrent += 3 * ccol;
@@ -307,24 +243,21 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr
public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan rgposSrc, AlignedArray src,
int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
{
- MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol);
- }
-
- public static unsafe void MatMulP(ReadOnlySpan mat, ReadOnlySpan rgposSrc, ReadOnlySpan src,
- int posMin, int iposMin, int iposEnd, Span dst, int crow, int ccol)
- {
- Contracts.Assert(crow % 8 == 0);
- Contracts.Assert(ccol % 8 == 0);
-
// REVIEW: For extremely sparse inputs, interchanging the loops would
// likely be more efficient.
- fixed (float* psrc = &MemoryMarshal.GetReference(src))
- fixed (float* pdst = &MemoryMarshal.GetReference(dst))
- fixed (float* pmat = &MemoryMarshal.GetReference(mat))
- fixed (int* pposSrc = &MemoryMarshal.GetReference(rgposSrc))
- fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
- fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
+ Contracts.Assert(HasCompatibleAlignment(mat));
+ Contracts.Assert(HasCompatibleAlignment(src));
+ Contracts.Assert(HasCompatibleAlignment(dst));
+
+ fixed (float* pSrcStart = &src.Items[0])
+ fixed (float* pDstStart = &dst.Items[0])
+ fixed (float* pMatStart = &mat.Items[0])
+ fixed (int* pposSrc = &rgposSrc[0])
{
+ float* psrc = GetAlignedBase(src, pSrcStart);
+ float* pdst = GetAlignedBase(dst, pDstStart);
+ float* pmat = GetAlignedBase(mat, pMatStart);
+
int* pposMin = pposSrc + iposMin;
int* pposEnd = pposSrc + iposEnd;
float* pDstEnd = pdst + crow;
@@ -332,116 +265,7 @@ public static unsafe void MatMulP(ReadOnlySpan mat, ReadOnlySpan rgp
float* pSrcCurrent = psrc - posMin;
float* pDstCurrent = pdst;
- nuint address = (nuint)(pDstCurrent);
- int misalignment = (int)(address % 32);
- int length = crow;
- int remainder = 0;
-
- if ((misalignment & 3) != 0)
- {
- // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations
- while (pDstCurrent < pDstEnd)
- {
- Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow());
- pDstCurrent += 8;
- pm0 += 8 * ccol;
- }
- }
- else
- {
- if (misalignment != 0)
- {
- // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then
- // masking any elements that will be included in the first aligned read
- misalignment >>= 2;
- misalignment = 8 - misalignment;
-
- Vector256 mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));
-
- float* pm1 = pm0 + ccol;
- float* pm2 = pm1 + ccol;
- float* pm3 = pm2 + ccol;
- Vector256 result = Avx.SetZeroVector256();
-
- int* ppos = pposMin;
-
- while (ppos < pposEnd)
- {
- int col1 = *ppos;
- int col2 = col1 + 4 * ccol;
- Vector256 x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
- pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
-
- x1 = Avx.And(mask, x1);
- Vector256 x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
- result = MultiplyAdd(x2, x1, result);
- ppos++;
- }
-
- Avx.Store(pDstCurrent, result);
- pDstCurrent += misalignment;
- pm0 += misalignment * ccol;
- length -= misalignment;
- }
-
- if (length > 7)
- {
- // Handle all the 256-bit blocks that we can now that we have offset to an aligned address
- remainder = length % 8;
- while (pDstCurrent < pDstEnd)
- {
- Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow());
- pDstCurrent += 8;
- pm0 += 8 * ccol;
- }
- }
- else
- {
- // Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not
- // 256-bit aligned. This means we can't do any aligned loads and will just end up doing two
- // unaligned loads where we mask the input each time.
- remainder = length;
- }
-
- if (remainder != 0)
- {
- // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next
- // unaligned load will read to the end of the array and then mask out any elements already processed
-
- pDstCurrent -= (8 - remainder);
- pm0 -= (8 - remainder) * ccol;
- Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
- Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8));
-
- float* pm1 = pm0 + ccol;
- float* pm2 = pm1 + ccol;
- float* pm3 = pm2 + ccol;
- Vector256 result = Avx.SetZeroVector256();
-
- int* ppos = pposMin;
-
- while (ppos < pposEnd)
- {
- int col1 = *ppos;
- int col2 = col1 + 4 * ccol;
- Vector256 x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
- pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
- x1 = Avx.And(x1, trailingMask);
-
- Vector256 x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
- result = MultiplyAdd(x2, x1, result);
- ppos++;
- }
-
- result = Avx.Add(result, Avx.And(leadingMask, Avx.LoadVector256(pDstCurrent)));
-
- Avx.Store(pDstCurrent, result);
- pDstCurrent += 8;
- pm0 += 8 * ccol;
- }
- }
-
- Vector256 SparseMultiplicationAcrossRow()
+ while (pDstCurrent < pDstEnd)
{
float* pm1 = pm0 + ccol;
float* pm2 = pm1 + ccol;
@@ -455,329 +279,133 @@ Vector256 SparseMultiplicationAcrossRow()
int col1 = *ppos;
int col2 = col1 + 4 * ccol;
Vector256 x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
- pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
+ pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
Vector256 x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
- result = MultiplyAdd(x2, x1, result);
+ x2 = Avx.Multiply(x2, x1);
+ result = Avx.Add(result, x2);
+
ppos++;
}
- return result;
+ Avx.StoreAligned(pDstCurrent, result);
+ pDstCurrent += 8;
+ pm0 += 8 * ccol;
}
}
}
public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
{
- MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol);
- }
-
- public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan src, Span dst, int crow, int ccol)
- {
- Contracts.Assert(crow % 4 == 0);
- Contracts.Assert(ccol % 4 == 0);
+ Contracts.Assert(HasCompatibleAlignment(mat));
+ Contracts.Assert(HasCompatibleAlignment(src));
+ Contracts.Assert(HasCompatibleAlignment(dst));
- fixed (float* psrc = &MemoryMarshal.GetReference(src))
- fixed (float* pdst = &MemoryMarshal.GetReference(dst))
- fixed (float* pmat = &MemoryMarshal.GetReference(mat))
- fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
- fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
+ fixed (float* pSrcStart = &src.Items[0])
+ fixed (float* pDstStart = &dst.Items[0])
+ fixed (float* pMatStart = &mat.Items[0])
{
+ float* psrc = GetAlignedBase(src, pSrcStart);
+ float* pdst = GetAlignedBase(dst, pDstStart);
+ float* pmat = GetAlignedBase(mat, pMatStart);
+
float* pSrcEnd = psrc + ccol;
float* pDstEnd = pdst + crow;
float* pSrcCurrent = psrc;
float* pMatCurrent = pmat;
- // The reason behind adding the if condtion instead of boolean flag
- // is to avoid branching in codegen.
- if (pSrcCurrent < pSrcEnd)
- {
- Vector128 h01 = Sse.LoadVector128(pSrcCurrent);
- // Replicate each slot of h01 (ABCD) into its own register.
- Vector128 h11 = Avx.Permute(h01, 0x55); // B
- Vector128 h21 = Avx.Permute(h01, 0xAA); // C
- Vector128 h31 = Avx.Permute(h01, 0xFF); // D
- h01 = Avx.Permute(h01, 0x00); // A
+ // We do 4-way unrolling
+ Vector128 h01 = Sse.LoadAlignedVector128(pSrcCurrent);
+ // Replicate each slot of h01 (ABCD) into its own register.
+ Vector128 h11 = Sse.Shuffle(h01, h01, 0x55); // B
+ Vector128 h21 = Sse.Shuffle(h01, h01, 0xAA); // C
+ Vector128 h31 = Sse.Shuffle(h01, h01, 0xFF); // D
+ h01 = Sse.Shuffle(h01, h01, 0x00); // A
- Vector256 x01 = Avx.SetHighLow(h01, h01);
- Vector256 x11 = Avx.SetHighLow(h11, h11);
- Vector256 x21 = Avx.SetHighLow(h21, h21);
- Vector256 x31 = Avx.SetHighLow(h31, h31);
+ Vector256 x01 = Avx.SetHighLow(h01, h01);
+ Vector256 x11 = Avx.SetHighLow(h11, h11);
+ Vector256 x21 = Avx.SetHighLow(h21, h21);
+ Vector256 x31 = Avx.SetHighLow(h31, h31);
- int length = crow;
- float* pDstCurrent = pdst;
+ pSrcCurrent += 4;
- nuint address = (nuint)(pMatCurrent);
- int misalignment = (int)(address % 32);
+ float* pDstCurrent = pdst;
- if ((misalignment & 3) != 0)
- {
- // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations
- while (pDstCurrent < pDstEnd)
- {
- float* pMatTemp = pMatCurrent;
- Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp));
- Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow));
- Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow));
- Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow));
-
- x02 = Avx.Add(x02, x12);
- x22 = Avx.Add(x22, x32);
- x02 = Avx.Add(x02, x22);
-
- Avx.Store(pDstCurrent, x02);
- pDstCurrent += 8;
- pMatCurrent += 8;
- }
- }
- else
- {
- int remainder = 0;
- if (misalignment != 0)
- {
- // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then
- // masking any elements that will be included in the first aligned read
- misalignment >>= 2;
- misalignment = 8 - misalignment;
-
- Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));
-
- // We only align pMat since it has significantly more reads.
- float* pMatTemp = pMatCurrent;
- Vector256 x02 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp));
- Vector256 x12 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow));
- Vector256 x22 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow));
- Vector256 x32 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow));
-
- x02 = Avx.Multiply(x01, x02);
- x12 = Avx.Multiply(x11, x12);
- x22 = Avx.Multiply(x21, x22);
- x32 = Avx.Multiply(x31, x32);
-
- x02 = Avx.Add(x02, x12);
- x22 = Avx.Add(x22, x32);
- x02 = Avx.Add(x02, x22);
-
- Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8));
- Vector256 x3 = Avx.LoadVector256(pDstCurrent);
- x02 = Avx.Or(x02, Avx.And(x3, trailingMask));
-
- Avx.Store(pDstCurrent, x02);
- pMatCurrent += misalignment;
- pDstCurrent += misalignment;
- length -= misalignment;
- }
- if (length > 7)
- {
- // Handle all the 256-bit blocks that we can now that we have offset to an aligned address
- remainder = length % 8;
-
- while (pDstCurrent + 8 <= pDstEnd)
- {
- // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads
- // (due to semantics of the legacy encoding).
- // We don't need an assert, since the instruction will throw for unaligned inputs.
- float* pMatTemp = pMatCurrent;
-
- Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp));
- Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow));
- Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow));
- Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow));
-
- x02 = Avx.Add(x02, x12);
- x22 = Avx.Add(x22, x32);
- x02 = Avx.Add(x02, x22);
-
- Avx.Store(pDstCurrent, x02);
- pDstCurrent += 8;
- pMatCurrent += 8;
- }
- }
- else
- {
- // Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not
- // 256-bit aligned. This means we can't do any aligned loads and will just end up doing two
- // unaligned loads where we mask the input each time.
- remainder = length;
- }
+ while (pDstCurrent < pDstEnd)
+ {
+ float* pMatTemp = pMatCurrent;
+ Contracts.Assert(((nuint)(pMatTemp) % 32) == 0);
- if (remainder != 0)
- {
- // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next
- // unaligned load will read to the end of the array and then mask out any elements already processed
-
- pMatCurrent -= (8 - remainder);
- pDstCurrent -= (8 - remainder);
- Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
-
- float* pMatTemp = pMatCurrent;
- Vector256 x02 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp));
- Vector256 x12 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow));
- Vector256 x22 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow));
- Vector256 x32 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow));
-
- x02 = Avx.Multiply(x01, x02);
- x12 = Avx.Multiply(x11, x12);
- x22 = Avx.Multiply(x21, x22);
- x32 = Avx.Multiply(x31, x32);
-
- x02 = Avx.Add(x02, x12);
- x22 = Avx.Add(x22, x32);
- x02 = Avx.Add(x02, x22);
-
- Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8));
- Vector256 x3 = Avx.LoadVector256(pDstCurrent);
- x02 = Avx.Or(x02, Avx.And(x3, leadingMask));
-
- Avx.Store(pDstCurrent, x02);
- pDstCurrent += 8;
- pMatCurrent += 8;
- }
- }
+ // The JIT will only fold away unaligned loads due to the semantics behind
+ // the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since
+ // modern hardware has unaligned loads that are as fast as aligned loads,
+ // when it doesn't cross a cache-line/page boundary, we will just assert
+ // that the alignment is correct and allow for the more-efficient codegen.
+ Vector256 x02 = Avx.LoadVector256(pMatTemp);
+ Vector256 x12 = Avx.LoadVector256(pMatTemp += crow);
+ Vector256 x22 = Avx.LoadVector256(pMatTemp += crow);
+ Vector256 x32 = Avx.LoadVector256(pMatTemp += crow);
- pMatCurrent += 3 * crow;
- pSrcCurrent += 4;
+ x02 = Avx.Multiply(x01, x02);
+ x02 = MultiplyAdd(x11, x12, x02);
+
+ x22 = Avx.Multiply(x21, x22);
+ x22 = MultiplyAdd(x31, x32, x22);
+
+ x02 = Avx.Add(x02, x22);
+ Avx.StoreAligned(pDstCurrent, x02);
+
+ pDstCurrent += 8;
+ pMatCurrent += 8;
}
- // We do 4-way unrolling
+ pMatCurrent += 3 * crow;
+
while (pSrcCurrent < pSrcEnd)
{
- Vector128 h01 = Sse.LoadVector128(pSrcCurrent);
+ h01 = Sse.LoadAlignedVector128(pSrcCurrent);
// Replicate each slot of h01 (ABCD) into its own register.
- Vector128 h11 = Avx.Permute(h01, 0x55); // B
- Vector128 h21 = Avx.Permute(h01, 0xAA); // C
- Vector128 h31 = Avx.Permute(h01, 0xFF); // D
- h01 = Avx.Permute(h01, 0x00); // A
-
- Vector256 x01 = Avx.SetHighLow(h01, h01);
- Vector256 x11 = Avx.SetHighLow(h11, h11);
- Vector256 x21 = Avx.SetHighLow(h21, h21);
- Vector256 x31 = Avx.SetHighLow(h31, h31);
+ h11 = Sse.Shuffle(h01, h01, 0x55); // B
+ h21 = Sse.Shuffle(h01, h01, 0xAA); // C
+ h31 = Sse.Shuffle(h01, h01, 0xFF); // D
+ h01 = Sse.Shuffle(h01, h01, 0x00); // A
- int length = crow;
- float* pDstCurrent = pdst;
+ x01 = Avx.SetHighLow(h01, h01);
+ x11 = Avx.SetHighLow(h11, h11);
+ x21 = Avx.SetHighLow(h21, h21);
+ x31 = Avx.SetHighLow(h31, h31);
- nuint address = (nuint)(pMatCurrent);
- int misalignment = (int)(address % 32);
+ pDstCurrent = pdst;
- if ((misalignment & 3) != 0)
+ while (pDstCurrent < pDstEnd)
{
- while (pDstCurrent < pDstEnd)
- {
- float* pMatTemp = pMatCurrent;
- Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp));
- Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow));
- Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow));
- Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow));
+ float* pMatTemp = pMatCurrent;
- x02 = Avx.Add(x02, x12);
- x22 = Avx.Add(x22, x32);
- x02 = Avx.Add(x02, x22);
+ Contracts.Assert(((nuint)(pMatTemp) % 32) == 0);
+ Contracts.Assert(((nuint)(pDstCurrent) % 32) == 0);
- x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent));
+ // The JIT will only fold away unaligned loads due to the semantics behind
+ // the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since
+ // modern hardware has unaligned loads that are as fast as aligned loads,
+ // when it doesn't cross a cache-line/page boundary, we will just assert
+ // that the alignment is correct and allow for the more-efficient codegen.
+ Vector256 x02 = Avx.LoadVector256(pMatTemp);
+ Vector256 x12 = Avx.LoadVector256(pMatTemp += crow);
+ Vector256 x22 = Avx.LoadVector256(pMatTemp += crow);
+ Vector256 x32 = Avx.LoadVector256(pMatTemp += crow);
+ Vector256 x3 = Avx.LoadVector256(pDstCurrent);
- Avx.Store(pDstCurrent, x02);
- pDstCurrent += 8;
- pMatCurrent += 8;
- }
- }
- else
- {
- int remainder = 0;
- if (misalignment != 0)
- {
- // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then
- // masking any elements that will be included in the first aligned read
- misalignment >>= 2;
- misalignment = 8 - misalignment;
-
- Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));
-
- // We only align pMat since it has significantly more reads.
- float* pMatTemp = pMatCurrent;
- Vector256 x02 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp));
- Vector256 x12 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow));
- Vector256 x22 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow));
- Vector256