diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000000..509e5fce93 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,5 @@ +root = true + +[*.cs] +# Sort using directives with System.* appearing first +dotnet_sort_system_directives_first = true \ No newline at end of file diff --git a/.vsts-dotnet-ci.yml b/.vsts-dotnet-ci.yml index 2699c9b7b3..ff6a0d6bec 100644 --- a/.vsts-dotnet-ci.yml +++ b/.vsts-dotnet-ci.yml @@ -23,6 +23,20 @@ phases: queue: name: Hosted VS2017 +- template: /build/ci/phase-template.yml + parameters: + name: core30 + buildScript: build.cmd + customMatrixes: + Build_Debug_Intrinsics: + _configuration: Debug-Intrinsics + _config_short: DI + Build_Release_Intrinsics: + _configuration: Release-Intrinsics + _config_short: RI + queue: + name: Hosted VS2017 + - template: /build/ci/phase-template.yml parameters: name: Windows_x86 diff --git a/DotnetCLIVersion.netcoreapp.latest.txt b/DotnetCLIVersion.netcoreapp.latest.txt new file mode 100644 index 0000000000..23f4ef9a4c --- /dev/null +++ b/DotnetCLIVersion.netcoreapp.latest.txt @@ -0,0 +1 @@ +3.0.100-alpha1-009622 \ No newline at end of file diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 1b7d335955..b5e7e1b5e3 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -29,8 +29,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.KMeansClusteri EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.PCA", "src\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj", "{58E06735-1129-4DD5-86E0-6BBFF049AAD9}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Api", "src\Microsoft.ML.Api\Microsoft.ML.Api.csproj", "{2F636A2C-062C-49F4-85F3-60DCADAB6A43}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Tests", "test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj", "{64BC22D3-1E76-41EF-94D8-C79E471FF2DD}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TestFramework", 
"test\Microsoft.ML.TestFramework\Microsoft.ML.TestFramework.csproj", "{B5989C06-4FFA-46C1-9D85-9366B34AB0A2}" @@ -139,6 +137,18 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.DnnImageFeatur EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.DnnImageFeaturizer.ResNet101", "src\Microsoft.ML.DnnImageFeaturizer.ResNet101\Microsoft.ML.DnnImageFeaturizer.ResNet101.csproj", "{DB7CEB5E-8BE6-48A7-87BE-B91D9AE96F71}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.EntryPoints", "src\Microsoft.ML.EntryPoints\Microsoft.ML.EntryPoints.csproj", "{7504D46F-E4B3-43CB-9B1C-82F3131F1C99}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StaticPipe", "src\Microsoft.ML.StaticPipe\Microsoft.ML.StaticPipe.csproj", "{6B1B93D0-142A-4111-A20E-62B55A3E36A3}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TensorFlow.StaticPipe", "src\Microsoft.ML.TensorFlow.StaticPipe\Microsoft.ML.TensorFlow.StaticPipe.csproj", "{F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.HalLearners.StaticPipe", "src\Microsoft.ML.HalLearners.StaticPipe\Microsoft.ML.HalLearners.StaticPipe.csproj", "{2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.OnnxTransform.StaticPipe", "src\Microsoft.ML.OnnxTransform.StaticPipe\Microsoft.ML.OnnxTransform.StaticPipe.csproj", "{D1324668-9568-40F4-AA55-30A9A516C230}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.LightGBM.StaticPipe", "src\Microsoft.ML.LightGBM.StaticPipe\Microsoft.ML.LightGBM.StaticPipe.csproj", "{22C51B08-ACAE-47B2-A312-462DC239A23B}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -227,14 +237,6 @@ Global {58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release|Any CPU.Build.0 = Release|Any CPU 
{58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU {58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU - {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug|Any CPU.Build.0 = Debug|Any CPU - {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU - {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU - {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Release|Any CPU.ActiveCfg = Release|Any CPU - {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Release|Any CPU.Build.0 = Release|Any CPU - {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU - {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU {64BC22D3-1E76-41EF-94D8-C79E471FF2DD}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {64BC22D3-1E76-41EF-94D8-C79E471FF2DD}.Debug|Any CPU.Build.0 = Debug|Any CPU {64BC22D3-1E76-41EF-94D8-C79E471FF2DD}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU @@ -531,6 +533,54 @@ Global {DB7CEB5E-8BE6-48A7-87BE-B91D9AE96F71}.Release|Any CPU.Build.0 = Release|Any CPU {DB7CEB5E-8BE6-48A7-87BE-B91D9AE96F71}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU {DB7CEB5E-8BE6-48A7-87BE-B91D9AE96F71}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU + {7504D46F-E4B3-43CB-9B1C-82F3131F1C99}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {7504D46F-E4B3-43CB-9B1C-82F3131F1C99}.Debug|Any CPU.Build.0 = Debug|Any CPU + {7504D46F-E4B3-43CB-9B1C-82F3131F1C99}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU + {7504D46F-E4B3-43CB-9B1C-82F3131F1C99}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU + {7504D46F-E4B3-43CB-9B1C-82F3131F1C99}.Release|Any CPU.ActiveCfg = Release|Any CPU + 
{7504D46F-E4B3-43CB-9B1C-82F3131F1C99}.Release|Any CPU.Build.0 = Release|Any CPU + {7504D46F-E4B3-43CB-9B1C-82F3131F1C99}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU + {7504D46F-E4B3-43CB-9B1C-82F3131F1C99}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU + {6B1B93D0-142A-4111-A20E-62B55A3E36A3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {6B1B93D0-142A-4111-A20E-62B55A3E36A3}.Debug|Any CPU.Build.0 = Debug|Any CPU + {6B1B93D0-142A-4111-A20E-62B55A3E36A3}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU + {6B1B93D0-142A-4111-A20E-62B55A3E36A3}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU + {6B1B93D0-142A-4111-A20E-62B55A3E36A3}.Release|Any CPU.ActiveCfg = Release|Any CPU + {6B1B93D0-142A-4111-A20E-62B55A3E36A3}.Release|Any CPU.Build.0 = Release|Any CPU + {6B1B93D0-142A-4111-A20E-62B55A3E36A3}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU + {6B1B93D0-142A-4111-A20E-62B55A3E36A3}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU + {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU + {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU + {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}.Release|Any CPU.Build.0 = Release|Any CPU + {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU + {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU + {2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}.Debug|Any CPU.Build.0 = Debug|Any CPU + 
{2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU + {2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU + {2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}.Release|Any CPU.ActiveCfg = Release|Any CPU + {2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}.Release|Any CPU.Build.0 = Release|Any CPU + {2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU + {2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU + {D1324668-9568-40F4-AA55-30A9A516C230}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D1324668-9568-40F4-AA55-30A9A516C230}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D1324668-9568-40F4-AA55-30A9A516C230}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU + {D1324668-9568-40F4-AA55-30A9A516C230}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU + {D1324668-9568-40F4-AA55-30A9A516C230}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D1324668-9568-40F4-AA55-30A9A516C230}.Release|Any CPU.Build.0 = Release|Any CPU + {D1324668-9568-40F4-AA55-30A9A516C230}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU + {D1324668-9568-40F4-AA55-30A9A516C230}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU + {22C51B08-ACAE-47B2-A312-462DC239A23B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {22C51B08-ACAE-47B2-A312-462DC239A23B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {22C51B08-ACAE-47B2-A312-462DC239A23B}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU + {22C51B08-ACAE-47B2-A312-462DC239A23B}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU + {22C51B08-ACAE-47B2-A312-462DC239A23B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {22C51B08-ACAE-47B2-A312-462DC239A23B}.Release|Any CPU.Build.0 = Release|Any CPU + {22C51B08-ACAE-47B2-A312-462DC239A23B}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU + 
{22C51B08-ACAE-47B2-A312-462DC239A23B}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -546,7 +596,6 @@ Global {7288C084-11C0-43BE-AC7F-45DCFEAEEBF6} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {F1CAE3AB-4F86-4BC0-BBA8-C4A58E7E8A4A} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {58E06735-1129-4DD5-86E0-6BBFF049AAD9} = {09EADF06-BE25-4228-AB53-95AE3E15B530} - {2F636A2C-062C-49F4-85F3-60DCADAB6A43} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {64BC22D3-1E76-41EF-94D8-C79E471FF2DD} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {B5989C06-4FFA-46C1-9D85-9366B34AB0A2} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {6B047E09-39C9-4583-96F3-685D84CA4117} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} @@ -589,6 +638,12 @@ Global {6C29AA9B-054B-4762-BEA5-D305B932AA80} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {4805129D-78C8-46D4-9519-0AD9B0574D6D} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {DB7CEB5E-8BE6-48A7-87BE-B91D9AE96F71} = {09EADF06-BE25-4228-AB53-95AE3E15B530} + {7504D46F-E4B3-43CB-9B1C-82F3131F1C99} = {09EADF06-BE25-4228-AB53-95AE3E15B530} + {6B1B93D0-142A-4111-A20E-62B55A3E36A3} = {09EADF06-BE25-4228-AB53-95AE3E15B530} + {F95F7AFB-03AF-4D20-BD75-1740B5FF71D3} = {09EADF06-BE25-4228-AB53-95AE3E15B530} + {2F25EF6A-C754-45BE-AD9E-7DDF46A1B51A} = {09EADF06-BE25-4228-AB53-95AE3E15B530} + {D1324668-9568-40F4-AA55-30A9A516C230} = {09EADF06-BE25-4228-AB53-95AE3E15B530} + {22C51B08-ACAE-47B2-A312-462DC239A23B} = {09EADF06-BE25-4228-AB53-95AE3E15B530} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D} diff --git a/README.md b/README.md index 15dd419948..38829cd751 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ Along with these ML capabilities, this first release of ML.NET also brings the f ML.NET runs on Windows, Linux, and macOS using [.NET Core](https://github.com/dotnet/core), or Windows using .NET Framework. 
64 bit is supported on all platforms. 32 bit is supported on Windows, except for TensorFlow, LightGBM, and ONNX related functionality. -The current release is 0.7. Check out the [release notes](docs/release-notes/0.7/release-0.7.md) and [blog post](https://blogs.msdn.microsoft.com/dotnet/2018/11/08/announcing-ml-net-0-7-machine-learning-net/) to see what's new. +The current release is 0.8. Check out the [release notes](docs/release-notes/0.8/release-0.8.md) and [blog post](https://blogs.msdn.microsoft.com/dotnet/2018/12/02/announcing-ml-net-0-8-machine-learning-for-net/) to see what's new. First, ensure you have installed [.NET Core 2.1](https://www.microsoft.com/net/learn/get-started) or later. ML.NET also works on the .NET Framework 4.6.1 or later, but 4.7.2 or later is recommended. diff --git a/build.proj b/build.proj index de8c507de7..15fea4e309 100644 --- a/build.proj +++ b/build.proj @@ -78,7 +78,7 @@ - https://aka.ms/tlc-resources/benchmarks/%(Identity) + https://aka.ms/mlnet-resources/benchmarks/%(Identity) $(MSBuildThisFileDirectory)/test/data/external/%(Identity) diff --git a/build/BranchInfo.props b/build/BranchInfo.props index 02e98d6b5d..1627429558 100644 --- a/build/BranchInfo.props +++ b/build/BranchInfo.props @@ -1,7 +1,7 @@ 0 - 8 + 9 0 preview diff --git a/build/Dependencies.props b/build/Dependencies.props index 46304fc56a..8d2f7abdc1 100644 --- a/build/Dependencies.props +++ b/build/Dependencies.props @@ -15,7 +15,7 @@ 3.5.1 2.2.1.1 - 1.1.0 + 0.1.5 0.0.0.7 2.1.3 4.5.0 diff --git a/build/ExternalBenchmarkDataFiles.props b/build/ExternalBenchmarkDataFiles.props index ad3d350d60..42df4ccd96 100644 --- a/build/ExternalBenchmarkDataFiles.props +++ b/build/ExternalBenchmarkDataFiles.props @@ -1,5 +1,6 @@ + diff --git a/build/ci/phase-template.yml b/build/ci/phase-template.yml index 9929ba182d..a9d2a35943 100644 --- a/build/ci/phase-template.yml +++ b/build/ci/phase-template.yml @@ -3,6 +3,7 @@ parameters: architecture: x64 buildScript: '' queue: {} + 
customMatrixes: '' phases: - phase: ${{ parameters.name }} @@ -14,12 +15,15 @@ phases: timeoutInMinutes: 45 parallel: 99 matrix: - Build_Debug: - _configuration: Debug - _config_short: D - Build_Release: - _configuration: Release - _config_short: R + ${{ if eq(parameters.customMatrixes, '') }}: + Build_Debug: + _configuration: Debug + _config_short: D + Build_Release: + _configuration: Release + _config_short: R + ${{ if ne(parameters.customMatrixes, '') }}: + ${{ insert }}: ${{ parameters.customMatrixes }} ${{ insert }}: ${{ parameters.queue }} steps: - script: $(_buildScript) -$(_configuration) -buildArch=$(_arch) diff --git a/build/vsts-ci.yml b/build/vsts-ci.yml index d0171921ac..ab77d0ed4a 100644 --- a/build/vsts-ci.yml +++ b/build/vsts-ci.yml @@ -24,7 +24,7 @@ phases: container: LinuxContainer steps: # Only build native assets to avoid conflicts. - - script: ./build.sh -buildNative -$(BuildConfig) + - script: ./build.sh -buildNative -$(BuildConfig) -skipRIDAgnosticAssets displayName: Build - task: PublishBuildArtifacts@1 @@ -49,7 +49,7 @@ phases: - agent.os -equals Darwin steps: # Only build native assets to avoid conflicts. - - script: ./build.sh -buildNative -$(BuildConfig) + - script: ./build.sh -buildNative -$(BuildConfig) -skipRIDAgnosticAssets displayName: Build - task: PublishBuildArtifacts@1 @@ -89,7 +89,7 @@ phases: condition: and(succeeded(), in(variables._SignType, 'real', 'test')) # Only build native assets to avoid conflicts. 
- - script: ./build.cmd -buildNative -$(BuildConfig) -buildArch=x86 + - script: ./build.cmd -buildNative -$(BuildConfig) -buildArch=x86 -skipRIDAgnosticAssets displayName: Build - task: MSBuild@1 diff --git a/config.json b/config.json index 8436586e61..99f72f7051 100644 --- a/config.json +++ b/config.json @@ -3,7 +3,7 @@ "Configuration": { "description": "Sets the optimization level for the Build Configuration you want to build.", "valueType": "property", - "values": [ "Debug", "Release" ], + "values": [ "Debug", "Release", "Debug-Intrinsics", "Release-Intrinsics" ], "defaultValue": "Debug" }, "TargetArchitecture": { @@ -30,6 +30,12 @@ "values": [], "defaultValue": "" }, + "SkipRIDAgnosticAssets": { + "description": "Prevents RID agnostic assets in redist from being built.", + "valueType": "property", + "values": [], + "defaultValue": "" + }, "MsBuildLogging": { "description": "MsBuild logging options.", "valueType": "passThrough", @@ -94,6 +100,18 @@ "Configuration": "Release" } }, + "debug-intrinsics": { + "description": "Sets optimization level to debug for managed build configuration and builds against netcoreapp3.0. (/p:Configuration=Debug-Intrinsics)", + "settings": { + "Configuration": "Debug-Intrinsics" + } + }, + "release-intrinsics": { + "description": "Sets optimization level to release for managed build configuration and builds against netcoreapp3.0. (/p:Configuration=Release-Intrinsics)", + "settings": { + "Configuration": "Release-Intrinsics" + } + }, "buildArch": { "description": "Sets the architecture for the native build. 
(/p:TargetArchitecture=[value])", "settings": { @@ -106,6 +124,12 @@ "BuildNative": "default" } }, + "skipRIDAgnosticAssets": { + "description": "Avoid building RID agnostic assets in redist.", + "settings": { + "SkipRIDAgnosticAssets": "default" + } + }, "buildPackages": { "description": "Builds the NuGet packages.", "settings": { diff --git a/docs/building/netcoreapp3.0-instructions.md b/docs/building/netcoreapp3.0-instructions.md index 338e06ef83..d65bcba34b 100644 --- a/docs/building/netcoreapp3.0-instructions.md +++ b/docs/building/netcoreapp3.0-instructions.md @@ -1,10 +1,8 @@ In order to build ML.NET for .NET Core 3.0, you need to do a few manual steps. -1. Pick a version of the .NET Core 3.0 SDK you want to use. As of this writing, I'm using `3.0.100-alpha1-009622`. You can get the latest available version from the [dotnet/core-sdk README](https://github.com/dotnet/core-sdk#installers-and-binaries) page. -2. Change the [DotnetCLIVersion.txt](https://github.com/dotnet/machinelearning/blob/master/DotnetCLIVersion.txt) file to use that version number. -3. Delete the local `.\Tools\` folder from the root of the repo, to ensure you download the new version. -4. Run `.\build.cmd -- /p:Configuration=Release-Intrinsics` from the root of the repo. -5. If you want to build the NuGet packages, `.\build.cmd -buildPackages` after step 4. +1. Delete the local `.\Tools\` folder from the root of the repo, to ensure you download the new version of the .NET Core SDK. +2. Run `.\build.cmd -- /p:Configuration=Release-Intrinsics` or `.\build.cmd -Release-Intrinsics` from the root of the repo. +3. If you want to build the NuGet packages, `.\build.cmd -buildPackages` after step 2. 
If you are using Visual Studio, you will need to do the following: diff --git a/docs/building/unix-instructions.md b/docs/building/unix-instructions.md index 7efa1c414c..ec7e27e8d4 100644 --- a/docs/building/unix-instructions.md +++ b/docs/building/unix-instructions.md @@ -5,6 +5,7 @@ Building ML.NET on Linux and macOS 1. Install the prerequisites ([Linux](#user-content-linux), [macOS](#user-content-macos)) 2. Clone the machine learning repo `git clone --recursive https://github.com/dotnet/machinelearning.git` 3. Navigate to the `machinelearning` directory +4. Run `git submodule update --init` if you have not previously done so 4. Run the build script `./build.sh` Calling the script `./build.sh` builds both the native and managed code. diff --git a/docs/code/EntryPoints.md b/docs/code/EntryPoints.md index dbcc4e6bc9..9fb11a6e99 100644 --- a/docs/code/EntryPoints.md +++ b/docs/code/EntryPoints.md @@ -220,12 +220,11 @@ parameter. ## How to create an entry point for an existing ML.NET component -The steps to take, to create an entry point for an existing ML.NET component, are: -1. Add the `SignatureEntryPointModule` signature to the `LoadableClass` assembly attribute. +1. Add a `LoadableClass` assembly attribute with the `SignatureEntryPointModule` signature as shown [here](https://github.com/dotnet/machinelearning/blob/9db16c85888e7163c671543faee6ba1f47015d68/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs#L27). 2. Create a public static method, that: - a. Takes as input, among others, an object representing the arguments of the component you want to expose. - b. Initializes and run the components, returning one of the nested classes of `Microsoft.ML.Runtime.EntryPoints.CommonOutputs` - c. Is annotated with the `TlcModule.EntryPoint` attribute + 1. 
Takes an object representing the arguments of the component you want to expose as shown [here](https://github.com/dotnet/machinelearning/blob/9db16c85888e7163c671543faee6ba1f47015d68/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs#L414) + 2. Initializes and runs the component, returning one of the nested classes of [`Microsoft.ML.EntryPoints.CommonOutputs`](https://github.com/dotnet/machinelearning/blob/master/src/Microsoft.ML.Data/EntryPoints/CommonOutputs.cs) + 3. Is annotated with the [`TlcModule.EntryPoint`](https://github.com/dotnet/machinelearning/blob/9db16c85888e7163c671543faee6ba1f47015d68/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs#L407) attribute -Based on the type of entry point being created, there are further conventions on the name of the method, for example, the Trainers entry points are typically called: 'TrainMultiClass', 'TrainBinary' etc, based on the task. -Look at [OnlineGradientDescent](../../src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs) for an example of a component and its entry point. \ No newline at end of file +For an example of a transformer as an entrypoint, see [OneHotVectorizer](https://github.com/dotnet/machinelearning/blob/9db16c85888e7163c671543faee6ba1f47015d68/src/Microsoft.ML.Transforms/OneHotEncoding.cs#L283). +For a trainer-estimator, see [LogisticRegression](https://github.com/dotnet/machinelearning/blob/9db16c85888e7163c671543faee6ba1f47015d68/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs#L407). diff --git a/docs/code/MlNetCookBook.md b/docs/code/MlNetCookBook.md index 6ce3c2c67b..908010c721 100644 --- a/docs/code/MlNetCookBook.md +++ b/docs/code/MlNetCookBook.md @@ -95,7 +95,7 @@ This is how you can read this data: var mlContext = new MLContext(); // Create the reader: define the data columns and where to find them in the text file. 
-var reader = mlContext.Data.TextReader(ctx => ( +var reader = mlContext.Data.CreateTextReader(ctx => ( // A boolean column depicting the 'target label'. IsOver50K: ctx.LoadBool(0), // Three text columns. @@ -115,9 +115,7 @@ If the schema of the data is not known at compile time, or too cumbersome, you c var mlContext = new MLContext(); // Create the reader: define the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(new TextLoader.Arguments -{ - Column = new[] { +var reader = mlContext.Data.CreateTextReader(new[] { // A boolean column depicting the 'label'. new TextLoader.Column("IsOver50K", DataKind.BL, 0), // Three text columns. @@ -126,13 +124,51 @@ var reader = mlContext.Data.TextReader(new TextLoader.Arguments new TextLoader.Column("MaritalStatus", DataKind.TX, 3) }, // First line of the file is a header, not a data row. - HasHeader = true -}); + hasHeader: true +); // Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed). var data = reader.Read(dataPath); ``` +You can also create a data model class, and read the data based on this type. + +```csharp +// The data model. This type will be used through the document. +private class InspectedRow +{ + [LoadColumn(0)] + public bool IsOver50K { get; set; } + + [LoadColumn(1)] + public string Workclass { get; set; } + + [LoadColumn(2)] + public string Education { get; set; } + + [LoadColumn(3)] + public string MaritalStatus { get; set; } + + public string[] AllFeatures { get; set; } +} + +private class InspectedRowWithAllFeatures : InspectedRow +{ + public string[] AllFeatures { get; set; } +} + +// Create a new context for ML.NET operations. It can be used for exception tracking and logging, +// as a catalog of available operations and as the source of randomness. +var mlContext = new MLContext(); + +// Read the data into a data view. 
+var data = mlContext.Data.ReadFromTextFile(dataPath, + // First line of the file is a header, not a data row. + hasHeader: true +) + +``` + ## How do I load data from multiple files? You can again use the `TextLoader`, and specify an array of files to its Read method. @@ -155,7 +191,7 @@ This is how you can read this data: var mlContext = new MLContext(); // Create the reader: define the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(ctx => ( +var reader = mlContext.Data.CreateTextReader(ctx => ( // A boolean column depicting the 'target label'. IsOver50K: ctx.LoadBool(14), // Three text columns. @@ -175,19 +211,17 @@ The code is very similar using the dynamic API: var mlContext = new MLContext(); // Create the reader: define the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(new TextLoader.Arguments -{ - Column = new[] { +var reader = mlContext.Data.CreateTextReader(new[] { // A boolean column depicting the 'label'. - new TextLoader.Column("IsOver50k", DataKind.BL, 0), + new TextLoader.Column("IsOver50K", DataKind.BL, 0), // Three text columns. new TextLoader.Column("Workclass", DataKind.TX, 1), new TextLoader.Column("Education", DataKind.TX, 2), new TextLoader.Column("MaritalStatus", DataKind.TX, 3) }, // First line of the file is a header, not a data row. - HasHeader = true -}); + hasHeader: true +); var data = reader.Read(exampleFile1, exampleFile2); ``` @@ -211,14 +245,14 @@ Reading this file using `TextLoader`: var mlContext = new MLContext(); // Create the reader: define the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(ctx => ( +var reader = mlContext.Data.CreateTextReader(ctx => ( // We read the first 11 values as a single float vector. FeatureVector: ctx.LoadFloat(0, 10), // Separately, read the target variable. Target: ctx.LoadFloat(11) ), // Default separator is tab, but we need a comma. 
- separator: ','); + separatorChar: ','); // Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed). @@ -233,19 +267,43 @@ If the schema of the data is not known at compile time, or too cumbersome, you c var mlContext = new MLContext(); // Create the reader: define the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(new[] { +var reader = mlContext.Data.CreateTextReader(new[] { // We read the first 10 values as a single float vector. - new TextLoader.Column("FeatureVector", DataKind.R4, new[] {new TextLoader.Range(0, 9)}), + new TextLoader.Column("FeatureVector", DataKind.R4, new[] {new TextLoader.Range(0, 10)}), // Separately, read the target variable. - new TextLoader.Column("Target", DataKind.R4, 10) + new TextLoader.Column("Target", DataKind.R4, 11) }, // Default separator is tab, but we need a comma. - s => s.Separator = ","); + separatorChar: ','); // Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed). var data = reader.Read(dataPath); ``` +Or by creating a data model for it: + +```csharp +private class AdultData +{ + [LoadColumn("0", "10"), ColumnName("Features")] + public float FeatureVector { get; } + + [LoadColumn(11)] + public float Target { get; } +} + +// Create a new context for ML.NET operations. It can be used for exception tracking and logging, +// as a catalog of available operations and as the source of randomness. +var mlContext = new MLContext(); + +// Read the data into a data view. +var data = mlContext.Data.ReadFromTextFile(dataPath, + // First line of the file is a header, not a data row. + separatorChar: ',' +); + +``` + ## How do I debug my experiment or preview my pipeline? Most ML.NET operations are 'lazy': they are not actually processing data, they just validate that the operation is possible, and then defer execution until the output data is actually requested. 
This provides good efficiency, but makes it hard to step through and debug the experiment. @@ -302,7 +360,7 @@ Label Workclass education marital-status var mlContext = new MLContext(); // Create the reader: define the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(ctx => ( +var reader = mlContext.Data.CreateTextReader(ctx => ( // A boolean column depicting the 'target label'. IsOver50K: ctx.LoadBool(0), // Three text columns. @@ -329,7 +387,7 @@ var transformedData = dataPipeline.Fit(data).Transform(data); // 'transformedData' is a 'promise' of data. Let's actually read it. var someRows = transformedData.AsDynamic // Convert to an enumerable of user-defined type. - .AsEnumerable(mlContext, reuseRowObject: false) + .AsEnumerable(mlContext, reuseRowObject: false) // Take a couple values as an array. .Take(4).ToArray(); @@ -346,54 +404,29 @@ var sameFeatureColumns = dynamicData.GetColumn(mlContext, "AllFeatures .Take(20).ToArray(); ``` -The above code assumes that we defined our `InspectedRow` class as follows: -```csharp -private class InspectedRow -{ - public bool IsOver50K; - public string Workclass; - public string Education; - public string MaritalStatus; - public string[] AllFeatures; -} -``` - You can also use the dynamic API to create the equivalent of the previous pipeline. ```csharp // Create a new context for ML.NET operations. It can be used for exception tracking and logging, // as a catalog of available operations and as the source of randomness. var mlContext = new MLContext(); -// Create the reader: define the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(new TextLoader.Arguments -{ - Column = new[] { - // A boolean column depicting the 'label'. - new TextLoader.Column("IsOver50k", DataKind.BL, 0), - // Three text columns. 
- new TextLoader.Column("Workclass", DataKind.TX, 1), - new TextLoader.Column("Education", DataKind.TX, 2), - new TextLoader.Column("MaritalStatus", DataKind.TX, 3) - }, +// Read the data into a data view. +var data = mlContext.Data.ReadFromTextFile(dataPath, // First line of the file is a header, not a data row. - HasHeader = true -}); + hasHeader: true +); // Start creating our processing pipeline. For now, let's just concatenate all the text columns // together into one. var dynamicPipeline = mlContext.Transforms.Concatenate("AllFeatures", "Education", "MaritalStatus"); -// Let's verify that the data has been read correctly. -// First, we read the data file. -var data = reader.Read(dataPath); - // Fit our data pipeline and transform data with it. var transformedData = dynamicPipeline.Fit(data).Transform(data); // 'transformedData' is a 'promise' of data. Let's actually read it. var someRows = transformedData // Convert to an enumerable of user-defined type. - .AsEnumerable(mlContext, reuseRowObject: false) + .AsEnumerable(mlContext, reuseRowObject: false) // Take a couple values as an array. .Take(4).ToArray(); @@ -428,7 +461,7 @@ var mlContext = new MLContext(); // Step one: read the data as an IDataView. // First, we define the reader: specify the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(ctx => ( +var reader = mlContext.Data.CreateTextReader(ctx => ( // We read the first 11 values as a single float vector. FeatureVector: ctx.LoadFloat(0, 10), // Separately, read the target variable. @@ -437,16 +470,30 @@ var reader = mlContext.Data.TextReader(ctx => ( // The data file has header. hasHeader: true, // Default separator is tab, but we need a semicolon. - separator: ';'); + separatorChar: ';'); // Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed). 
var trainData = reader.Read(trainDataPath); +// Sometime, caching data in-memory after its first access can save some loading time when the data is going to be used +// several times somewhere. The caching mechanism is also lazy; it only caches things after being used. +// User can replace all the subsequently uses of "trainData" with "cachedTrainData". We still use "trainData" because +// a caching step, which provides the same caching function, will be inserted in the considered "learningPipeline." +var cachedTrainData = trainData.Cache(); + // Step two: define the learning pipeline. // We 'start' the pipeline with the output of the reader. var learningPipeline = reader.MakeNewEstimator() + // We add a step for caching data in memory so that the downstream iterative training + // algorithm can efficiently scan through the data multiple times. Otherwise, the following + // trainer will read data from disk multiple times. The caching mechanism uses an on-demand strategy. + // The data accessed in any downstream step will be cached since its first use. In general, you only + // need to add a caching step before trainable step, because caching is not helpful if the data is + // only scanned once. This step can be removed if user doesn't have enough memory to store the whole + // data set. + .AppendCacheCheckpoint() // Now we can add any 'training steps' to it. In our case we want to 'normalize' the data (rescale to be // between -1 and 1 for all examples) .Append(r => ( @@ -468,23 +515,17 @@ var mlContext = new MLContext(); // Step one: read the data as an IDataView. // First, we define the reader: specify the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(new TextLoader.Arguments -{ - Column = new[] { - // We read the first 11 values as a single float vector. - new TextLoader.Column("FeatureVector", DataKind.R4, 0, 10), - - // Separately, read the target variable. 
- new TextLoader.Column("Target", DataKind.R4, 11), - }, +// Read the data into a data view. Remember though, readers are lazy, so the actual reading will happen when the data is accessed. +var trainData = mlContext.Data.ReadFromTextFile(dataPath, // First line of the file is a header, not a data row. - HasHeader = true, - // Default separator is tab, but we need a semicolon. - Separator = ";" -}); + separatorChar: ',' +); -// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed). -var trainData = reader.Read(trainDataPath); +// Sometime, caching data in-memory after its first access can save some loading time when the data is going to be used +// several times somewhere. The caching mechanism is also lazy; it only caches things after being used. +// User can replace all the subsequently uses of "trainData" with "cachedTrainData". We still use "trainData" because +// a caching step, which provides the same caching function, will be inserted in the considered "dynamicPipeline." +var cachedTrainData = mlContext.Data.Cache(trainData); // Step two: define the learning pipeline. @@ -493,6 +534,15 @@ var dynamicPipeline = // First 'normalize' the data (rescale to be // between -1 and 1 for all examples) mlContext.Transforms.Normalize("FeatureVector") + // We add a step for caching data in memory so that the downstream iterative training + // algorithm can efficiently scan through the data multiple times. Otherwise, the following + // trainer will read data from disk multiple times. The caching mechanism uses an on-demand strategy. + // The data accessed in any downstream step will be cached since its first use. In general, you only + // need to add a caching step before trainable step, because caching is not helpful if the data is + // only scanned once. This step can be removed if user doesn't have enough memory to store the whole + // data set. 
Notice that in the upstream Transforms.Normalize step, we only scan through the data + // once so adding a caching step before it is not helpful. + .AppendCacheCheckpoint(mlContext) // Add the SDCA regression trainer. .Append(mlContext.Regression.Trainers.StochasticDualCoordinateAscent(label: "Target", features: "FeatureVector")); @@ -516,7 +566,10 @@ var metrics = mlContext.Regression.Evaluate(model.Transform(testData), label: r Calculating the metrics with the dynamic API is as follows. ```csharp // Read the test dataset. -var testData = reader.Read(testDataPath); +var testData = mlContext.Data.ReadFromTextFile(testDataPath, + // First line of the file is a header, not a data row. + separatorChar: ',' +); // Calculate metrics of the model on the test data. var metrics = mlContext.Regression.Evaluate(model.Transform(testData), label: "Target"); ``` @@ -563,7 +616,7 @@ Since any ML.NET model is a transformer, you can of course use `model.Transform` A more typical case, though, is when there is no 'dataset' that we want to predict on, but instead we receive one example at a time. For instance, we run the model as part of the ASP.NET website, and we need to make a prediction for an incoming HTTP request. -For this case, ML.NET offers a convenient `PredictionFunction` component, that essentially runs one example at a time through the prediction pipeline. +For this case, ML.NET offers a convenient `PredictionEngine` component, that essentially runs one example at a time through the prediction pipeline. Here is the full example. Let's imagine that we have built a model for the famous Iris prediction dataset: @@ -574,7 +627,7 @@ var mlContext = new MLContext(); // Step one: read the data as an IDataView. // First, we define the reader: specify the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(ctx => ( +var reader = mlContext.Data.CreateTextReader(ctx => ( // The four features of the Iris dataset. 
SepalLength: ctx.LoadFloat(0), SepalWidth: ctx.LoadFloat(1), @@ -584,7 +637,7 @@ var reader = mlContext.Data.TextReader(ctx => ( Label: ctx.LoadText(4) ), // Default separator is tab, but the dataset has comma. - separator: ','); + separatorChar: ','); // Retrieve the training data. var trainData = reader.Read(irisDataPath); @@ -595,6 +648,13 @@ var learningPipeline = reader.MakeNewEstimator() r.Label, // Concatenate all the features together into one column 'Features'. Features: r.SepalLength.ConcatWith(r.SepalWidth, r.PetalLength, r.PetalWidth))) + // We add a step for caching data in memory so that the downstream iterative training + // algorithm can efficiently scan through the data multiple times. Otherwise, the following + // trainer will read data from disk multiple times. The caching mechanism uses an on-demand strategy. + // The data accessed in any downstream step will be cached since its first use. In general, you only + // need to add a caching step before trainable step, because caching is not helpful if the data is + // only scanned once. + .AppendCacheCheckpoint() .Append(r => ( r.Label, // Train the multi-class SDCA model to predict the label using features. @@ -616,23 +676,11 @@ You can also use the dynamic API to create the equivalent of the previous pipeli var mlContext = new MLContext(); // Step one: read the data as an IDataView. -// First, we define the reader: specify the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(new TextLoader.Arguments -{ - Column = new[] { - new TextLoader.Column("SepalLength", DataKind.R4, 0), - new TextLoader.Column("SepalWidth", DataKind.R4, 1), - new TextLoader.Column("PetalLength", DataKind.R4, 2), - new TextLoader.Column("PetalWidth", DataKind.R4, 3), - // Label: kind of iris. - new TextLoader.Column("Label", DataKind.TX, 4), - }, + // Retrieve the training data. 
+var trainData = mlContext.Data.ReadFromTextFile(irisDataPath, // Default separator is tab, but the dataset has comma. - Separator = "," -}); - -// Retrieve the training data. -var trainData = reader.Read(irisDataPath); + separatorChar: ',' +); // Build the training pipeline. var dynamicPipeline = @@ -640,6 +688,8 @@ var dynamicPipeline = mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") // Note that the label is text, so it needs to be converted to key. .Append(mlContext.Transforms.Categorical.MapValueToKey("Label"), TransformerScope.TrainTest) + // Cache data in memory for steps after the cache check point stage. + .AppendCacheCheckpoint(mlContext) // Use the multi-class SDCA model to predict the label using features. .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent()) // Apply the inverse conversion from 'PredictedLabel' column back to string value. @@ -679,10 +729,10 @@ var mlContext = new MLContext(); // Make the prediction function object. Note that, on average, this call takes around 200x longer // than one prediction, so you might want to cache and reuse the prediction function, instead of // creating one per prediction. -var predictionFunc = model.MakePredictionFunction(mlContext); +var predictionFunc = model.CreatePredictionEngine(mlContext); // Obtain the prediction. Remember that 'Predict' is not reentrant. If you want to use multiple threads -// for simultaneous prediction, make sure each thread is using its own PredictionFunction. +// for simultaneous prediction, make sure each thread is using its own PredictionEngine. 
var prediction = predictionFunc.Predict(new IrisInput { SepalLength = 4.1f, @@ -741,6 +791,7 @@ var trainData = mlContext.CreateStreamingDataView(churnData); var dynamicLearningPipeline = mlContext.Transforms.Categorical.OneHotEncoding("DemographicCategory") .Append(mlContext.Transforms.Concatenate("Features", "DemographicCategory", "LastVisits")) + .AppendCacheCheckpoint(mlContext) // FastTree will benefit from caching data in memory. .Append(mlContext.BinaryClassification.Trainers.FastTree("HasChurned", "Features", numTrees: 20)); var dynamicModel = dynamicLearningPipeline.Fit(trainData); @@ -757,6 +808,7 @@ var staticLearningPipeline = staticData.MakeNewEstimator() .Append(r => ( r.HasChurned, Features: r.DemographicCategory.OneHotEncoding().ConcatWith(r.LastVisits))) + .AppendCacheCheckpoint() // FastTree will benefit from caching data in memory. .Append(r => mlContext.BinaryClassification.Trainers.FastTree(r.HasChurned, r.Features, numTrees: 20)); var staticModel = staticLearningPipeline.Fit(staticData); @@ -781,7 +833,7 @@ var mlContext = new MLContext(); // Step one: read the data as an IDataView. // First, we define the reader: specify the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(ctx => ( +var reader = mlContext.Data.CreateTextReader(ctx => ( // The four features of the Iris dataset. SepalLength: ctx.LoadFloat(0), SepalWidth: ctx.LoadFloat(1), @@ -791,7 +843,7 @@ var reader = mlContext.Data.TextReader(ctx => ( Label: ctx.LoadText(4) ), // Default separator is tab, but the dataset has comma. - separator: ','); + separatorChar: ','); // Retrieve the training data. var trainData = reader.Read(dataPath); @@ -813,6 +865,8 @@ var learningPipeline = reader.MakeNewEstimator() // When the normalizer is trained, the below delegate is going to be called. // We use it to memorize the scales. 
onFit: (scales, offsets) => normScales = scales))) + // Cache data used in memory because the subsequently trainer needs to access the data multiple times. + .AppendCacheCheckpoint() .Append(r => ( r.Label, // Train the multi-class SDCA model to predict the label using features. @@ -875,14 +929,14 @@ Here's a snippet of code that demonstrates normalization in learning pipelines. var mlContext = new MLContext(); // Define the reader: specify the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(ctx => ( +var reader = mlContext.Data.CreateTextReader(ctx => ( // The four features of the Iris dataset will be grouped together as one Features column. Features: ctx.LoadFloat(0, 3), // Label: kind of iris. Label: ctx.LoadText(4) ), // Default separator is tab, but the dataset has comma. - separator: ','); + separatorChar: ','); // Read the training data. var trainData = reader.Read(dataPath); @@ -905,25 +959,26 @@ var meanVarValues = normalizedData.GetColumn(r => r.MeanVarNormalized).ToArray() You can achieve the same results using the dynamic API. ```csharp +//data model for the Iris class +private class IrisInputAllFeatures +{ + // Unfortunately, we still need the dummy 'Label' column to be present. + [ColumnName("Label"), LoadColumn(4)] + public string IgnoredLabel { get; set; } + + [LoadColumn(4, loadAllOthers:true)] + public float Features { get; set; } +} + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, // as a catalog of available operations and as the source of randomness. var mlContext = new MLContext(); -// Define the reader: specify the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(new TextLoader.Arguments -{ - Column = new[] { - // The four features of the Iris dataset will be grouped together as one Features column. - new TextLoader.Column("Features", DataKind.R4, 0, 3), - // Label: kind of iris. 
- new TextLoader.Column("Label", DataKind.TX, 4), - }, - // Default separator is tab, but the dataset has comma. - Separator = "," -}); - // Read the training data. -var trainData = reader.Read(dataPath); +var trainData = mlContext.Data.ReadFromTextFile(dataPath, + // Default separator is tab, but the dataset has comma. + separatorChar: ',' +); // Apply all kinds of standard ML.NET normalization to the raw features. var pipeline = @@ -969,7 +1024,7 @@ Label Workclass education marital-status occupation relationship ethnicity sex n var mlContext = new MLContext(); // Define the reader: specify the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(ctx => ( +var reader = mlContext.Data.CreateTextReader(ctx => ( Label: ctx.LoadBool(0), // We will load all the categorical features into one vector column of size 8. CategoricalFeatures: ctx.LoadText(1, 8), @@ -987,6 +1042,10 @@ var catColumns = data.GetColumn(r => r.CategoricalFeatures).Take(10).ToArray(); // Build several alternative featurization pipelines. var learningPipeline = reader.MakeNewEstimator() + // Cache data in memory in an on-demand manner. Columns used in any downstream step will be + // cached in memory at their first uses. This step can be removed if user's machine doesn't + // have enough memory. + .AppendCacheCheckpoint() .Append(r => ( r.Label, r.NumericalFeatures, @@ -1027,9 +1086,8 @@ You can achieve the same results using the dynamic API. var mlContext = new MLContext(); // Define the reader: specify the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(new TextLoader.Arguments -{ - Column = new[] { +var reader = mlContext.Data.CreateTextReader(new[] + { new TextLoader.Column("Label", DataKind.BL, 0), // We will load all the categorical features into one vector column of size 8. 
new TextLoader.Column("CategoricalFeatures", DataKind.TX, 1, 8), @@ -1038,8 +1096,8 @@ var reader = mlContext.Data.TextReader(new TextLoader.Arguments // Let's also separately load the 'Workclass' column. new TextLoader.Column("Workclass", DataKind.TX, 1), }, - HasHeader = true -}); + hasHeader: true +); // Read the data. var data = reader.Read(dataPath); @@ -1070,6 +1128,9 @@ var workclasses = transformedData.GetColumn(mlContext, "WorkclassOneHot var fullLearningPipeline = dynamicPipeline // Concatenate two of the 3 categorical pipelines, and the numeric features. .Append(mlContext.Transforms.Concatenate("Features", "NumericalFeatures", "CategoricalBag", "WorkclassOneHotTrimmed")) + // Cache data in memory so that the following trainer will be able to access training examples without + // reading them from disk multiple times. + .AppendCacheCheckpoint(mlContext) // Now we're ready to train. We chose our FastTree trainer for this classification task. .Append(mlContext.BinaryClassification.Trainers.FastTree(numTrees: 50)); @@ -1108,7 +1169,7 @@ Sentiment SentimentText var mlContext = new MLContext(); // Define the reader: specify the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(ctx => ( +var reader = mlContext.Data.CreateTextReader(ctx => ( IsToxic: ctx.LoadBool(0), Message: ctx.LoadText(1) ), hasHeader: true); @@ -1121,6 +1182,10 @@ var messageTexts = data.GetColumn(x => x.Message).Take(20).ToArray(); // Apply various kinds of text operations supported by ML.NET. var learningPipeline = reader.MakeNewEstimator() + // Cache data in memory in an on-demand manner. Columns used in any downstream step will be + // cached in memory at their first uses. This step can be removed if user's machine doesn't + // have enough memory. + .AppendCacheCheckpoint() .Append(r => ( // One-stop shop to run the full text featurization. 
TextFeatures: r.Message.FeaturizeText(), @@ -1154,14 +1219,13 @@ You can achieve the same results using the dynamic API. var mlContext = new MLContext(); // Define the reader: specify the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(new TextLoader.Arguments -{ - Column = new[] { +var reader = mlContext.Data.CreateTextReader(new[] + { new TextLoader.Column("IsToxic", DataKind.BL, 0), new TextLoader.Column("Message", DataKind.TX, 1), }, - HasHeader = true -}); + hasHeader: true +); // Read the data. var data = reader.Read(dataPath); @@ -1221,7 +1285,7 @@ var mlContext = new MLContext(); // Step one: read the data as an IDataView. // First, we define the reader: specify the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(ctx => ( +var reader = mlContext.Data.CreateTextReader(ctx => ( // The four features of the Iris dataset. SepalLength: ctx.LoadFloat(0), SepalWidth: ctx.LoadFloat(1), @@ -1231,7 +1295,7 @@ var reader = mlContext.Data.TextReader(ctx => ( Label: ctx.LoadText(4) ), // Default separator is tab, but the dataset has comma. - separator: ','); + separatorChar: ','); // Read the data. var data = reader.Read(dataPath); @@ -1243,6 +1307,9 @@ var learningPipeline = reader.MakeNewEstimator() Label: r.Label.ToKey(), // Concatenate all the features together into one column 'Features'. Features: r.SepalLength.ConcatWith(r.SepalWidth, r.PetalLength, r.PetalWidth))) + // Add a step for caching data in memory so that the downstream iterative training + // algorithm can efficiently scan through the data multiple times. + .AppendCacheCheckpoint() .Append(r => ( r.Label, // Train the multi-class SDCA model to predict the label using features. @@ -1273,24 +1340,10 @@ You can achieve the same results using the dynamic API. var mlContext = new MLContext(); // Step one: read the data as an IDataView. 
-// First, we define the reader: specify the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(new TextLoader.Arguments -{ - Column = new[] { - // We read the first 11 values as a single float vector. - new TextLoader.Column("SepalLength", DataKind.R4, 0), - new TextLoader.Column("SepalWidth", DataKind.R4, 1), - new TextLoader.Column("PetalLength", DataKind.R4, 2), - new TextLoader.Column("PetalWidth", DataKind.R4, 3), - // Label: kind of iris. - new TextLoader.Column("Label", DataKind.TX, 4), - }, +var data = mlContext.Data.ReadFromTextFile(dataPath, // Default separator is tab, but the dataset has comma. - Separator = "," -}); - -// Read the data. -var data = reader.Read(dataPath); + separatorChar: ',' +); // Build the training pipeline. var dynamicPipeline = @@ -1298,6 +1351,10 @@ var dynamicPipeline = mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") // Note that the label is text, so it needs to be converted to key. .Append(mlContext.Transforms.Conversions.MapValueToKey("Label"), TransformerScope.TrainTest) + // Cache data in memory so that SDCA trainer will be able to randomly access training examples without + // reading data from disk multiple times. Data will be cached at its first use in any downstream step. + // Notice that unused part in the data may not be cached. + .AppendCacheCheckpoint(mlContext) // Use the multi-class SDCA model to predict the label using features. .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent()); @@ -1335,7 +1392,7 @@ var mlContext = new MLContext(); // Read the data as an IDataView. // First, we define the reader: specify the data columns and where to find them in the text file. -var reader = mlContext.Data.TextReader(ctx => ( +var reader = mlContext.Data.CreateTextReader(ctx => ( // The four features of the Iris dataset. 
SepalLength: ctx.LoadFloat(0), SepalWidth: ctx.LoadFloat(1), @@ -1345,7 +1402,7 @@ var reader = mlContext.Data.TextReader(ctx => ( Label: ctx.LoadText(4) ), // Default separator is tab, but the dataset has comma. - separator: ','); + separatorChar: ','); // Read the data. var data = reader.Read(dataPath); @@ -1439,6 +1496,7 @@ public static ITransformer TrainModel(MLContext mlContext, IDataView trainData) Action mapping = (input, output) => output.Label = input.Income > 50000; // Construct the learning pipeline. var estimator = mlContext.Transforms.CustomMapping(mapping, null) + .AppendCacheCheckpoint(mlContext) .Append(mlContext.BinaryClassification.Trainers.FastTree(label: "Label")); return estimator.Fit(trainData); @@ -1480,8 +1538,12 @@ public class CustomMappings var estimator = mlContext.Transforms.CustomMapping(CustomMappings.IncomeMapping, nameof(CustomMappings.IncomeMapping)) .Append(mlContext.BinaryClassification.Trainers.FastTree(label: "Label")); +// If memory is enough, we can cache the data in-memory to avoid reading them from file +// when it will be accessed multiple times. +var cachedTrainData = mlContext.Data.Cache(trainData); + // Train the model. -var model = estimator.Fit(trainData); +var model = estimator.Fit(cachedTrainData); // Save the model. using (var fs = File.Create(modelPath)) diff --git a/docs/code/SchemaComprehension.md b/docs/code/SchemaComprehension.md index 37238e0c0c..93e14a02db 100644 --- a/docs/code/SchemaComprehension.md +++ b/docs/code/SchemaComprehension.md @@ -65,7 +65,7 @@ static void Main(string[] args) }; // Create the ML.NET environment. - var env = new Microsoft.ML.Runtime.Data.TlcEnvironment(); + var env = new Microsoft.ML.Data.TlcEnvironment(); // Create the data view. // This method will use the definition of IrisData to understand what columns there are in the @@ -74,7 +74,7 @@ static void Main(string[] args) // Now let's do something to the data view. 
For example, concatenate all four non-label columns // into 'Features' column. - dv = new Microsoft.ML.Runtime.Data.ConcatTransform(env, dv, "Features", + dv = new Microsoft.ML.Data.ConcatTransform(env, dv, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"); // Read the data into an another array, this time we read the 'Features' and 'Label' columns diff --git a/docs/code/VBufferCareFeeding.md b/docs/code/VBufferCareFeeding.md index 36c19c3184..91918c5e9b 100644 --- a/docs/code/VBufferCareFeeding.md +++ b/docs/code/VBufferCareFeeding.md @@ -224,17 +224,17 @@ ML.NET's runtime code has a number of utilities for operating over `VBuffer`s that we have written to be generally useful. We will not treat on these in detail here, but: -* `Microsoft.ML.Runtime.Data.VBuffer` itself contains a few methods for +* `Microsoft.ML.Data.VBuffer` itself contains a few methods for accessing and iterating over its values. -* `Microsoft.ML.Runtime.Internal.Utilities.VBufferUtils` contains utilities +* `Microsoft.ML.Internal.Utilities.VBufferUtils` contains utilities mainly for non-numeric manipulation of `VBuffer`s. -* `Microsoft.ML.Runtime.Numeric.VectorUtils` contains math operations +* `Microsoft.ML.Numeric.VectorUtils` contains math operations over `VBuffer` and `float[]`, like computing norms, dot-products, and whatnot. -* `Microsoft.ML.Runtime.Data.BufferBuilder` is an abstract class whose +* `Microsoft.ML.Data.BufferBuilder` is an abstract class whose concrete implementations are used throughout ML.NET to build up `VBuffer` instances. Note that if one *can* simply build a `VBuffer` oneself easily and do not need the niceties provided by the buffer builder, you should diff --git a/docs/project-docs/developer-guide.md b/docs/project-docs/developer-guide.md index 074ea03dab..20e94d4e0c 100644 --- a/docs/project-docs/developer-guide.md +++ b/docs/project-docs/developer-guide.md @@ -25,6 +25,12 @@ build -? 
**Examples** +- Initialize the repo to make build possible (if the build fails because it can't find `mf.cpp` then perhaps you missed this step) + +``` +git submodule update --init +``` + - Building in release mode for platform x64 ``` build.cmd -Release -TargetArchitecture:x64 @@ -58,4 +64,5 @@ One can build in Debug or Release mode from the root by doing `build.cmd -Releas ### Building other Architectures -We only support 64-bit binaries right now. \ No newline at end of file +We only support 64-bit binaries right now. + diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs new file mode 100644 index 0000000000..78c62668ec --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs @@ -0,0 +1,113 @@ +using System; +using System.Linq; +using Microsoft.ML.Calibrator; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Samples.Dynamic +{ + /// + /// This example first trains a StochasticDualCoordinateAscentBinary Classifier and then convert its output to probability via training a calibrator. + /// + public class CalibratorExample + { + public static void Calibration() + { + // Downloading the dataset from github.com/dotnet/machinelearning. + // This will create a sentiment.tsv file in the filesystem. + // The string, dataFile, is the path to the downloaded file. + // You can open this file, if you want to see the data. + string dataFile = SamplesUtils.DatasetUtils.DownloadSentimentDataset(); + + // A preview of the data. + // Sentiment SentimentText + // 0 " :Erm, thank you. " + // 1 ==You're cool== + + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + var mlContext = new MLContext(); + + // Create a text loader. 
+ var reader = mlContext.Data.CreateTextReader(new TextLoader.Arguments() + { + Separator = "tab", + HasHeader = true, + Column = new[] + { + new TextLoader.Column("Sentiment", DataKind.BL, 0), + new TextLoader.Column("SentimentText", DataKind.Text, 1) + } + }); + + // Read the data + var data = reader.Read(dataFile); + + // Split the dataset into two parts: one used for training, the other to train the calibrator + var (trainData, calibratorTrainingData) = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1); + + // Featurize the text column through the FeaturizeText API. + // Then append the StochasticDualCoordinateAscentBinary binary classifier, setting the "Label" column as the label of the dataset, and + // the "Features" column produced by FeaturizeText as the features column. + var pipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features") + .Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + labelColumn: "Sentiment", + featureColumn: "Features", + l2Const: 0.001f, + loss: new HingeLoss())); // By specifying loss: new HingeLoss(), StochasticDualCoordinateAscent will train a support vector machine (SVM). + + // Fit the pipeline, and get a transformer that knows how to score new data. + var transformer = pipeline.Fit(trainData); + IPredictor model = transformer.LastTransformer.Model; + + // Let's score the new data. The score will give us a numerical estimation of the chance that the particular sample + // bears positive sentiment. This estimate is relative to the numbers obtained. + var scoredData = transformer.Transform(calibratorTrainingData); + var scoredDataPreview = scoredData.Preview(); + + PrintRowViewValues(scoredDataPreview); + // Preview of scoredDataPreview.RowView + // + // Score - 0.458968 + // Score - 0.7022135 + // Score 1.138822 + // Score 0.4807112 + // Score 1.112813 + + // Let's train a calibrator estimator on this scored dataset. 
The trained calibrator estimator produces a transformer + // that can transform the scored data by adding a new column names "Probability". + var calibratorEstimator = new PlattCalibratorEstimator(mlContext, model, "Sentiment", "Features"); + var calibratorTransformer = calibratorEstimator.Fit(scoredData); + + // Transform the scored data with a calibrator transfomer by adding a new column names "Probability". + // This column is a calibrated version of the "Score" column, meaning its values are a valid probability value in the [0, 1] interval + // representing the chance that the respective sample bears positive sentiment. + var finalData = calibratorTransformer.Transform(scoredData).Preview(); + + PrintRowViewValues(finalData); + + //Preview of finalData.RowView + // + // Score - 0.458968 Probability 0.4670409 + // Score - 0.7022135 Probability 0.3912723 + // Score 1.138822 Probability 0.8703266 + // Score 0.4807112 Probability 0.7437012 + // Score 1.112813 Probability 0.8665403 + + } + + private static void PrintRowViewValues(Data.DataDebuggerPreview data) + { + var firstRows = data.RowView.Take(5); + + foreach(Data.DataDebuggerPreview.RowInfo row in firstRows) + { + foreach (var kvPair in row.Values) + { + if (kvPair.Key.Equals("Score") || kvPair.Key.Equals("Probability")) + Console.Write($" {kvPair.Key} {kvPair.Value} "); + } + Console.WriteLine(); + } + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ConcatTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ConcatTransform.cs index ba945f38fa..c14439ee78 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ConcatTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ConcatTransform.cs @@ -1,8 +1,7 @@ -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Api; -using System; -using System.Linq; +using System; using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; using Microsoft.ML.Transforms; namespace Microsoft.ML.Samples.Dynamic diff --git 
a/docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs new file mode 100644 index 0000000000..e747b4fc86 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs @@ -0,0 +1,50 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Samples.Dynamic +{ + public class FastTreeRegressionExample + { + public static void FastTreeRegression() + { + // Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging, + // as well as the source of randomness. + var ml = new MLContext(); + + // Get a small dataset as an IEnumerable and convert it to an IDataView. + var data = SamplesUtils.DatasetUtils.GetInfertData(); + var trainData = ml.CreateStreamingDataView(data); + + // Preview of the data. + // + // Age Case Education Induced Parity PooledStratum RowNum ... + // 26 1 0-5yrs 1 6 3 1 ... + // 42 1 0-5yrs 1 1 1 2 ... + // 39 1 0-5yrs 2 6 4 3 ... + // 34 1 0-5yrs 2 4 2 4 ... + // 35 1 6-11yrs 1 3 32 5 ... + + // A pipeline for concatenating the Parity and Induced columns together in the Features column. + // We will train a FastTreeRegression model with 1 tree on these two columns to predict Age. + string outputColumnName = "Features"; + var pipeline = ml.Transforms.Concatenate(outputColumnName, new[] { "Parity", "Induced" }) + .Append(ml.Regression.Trainers.FastTree(labelColumn: "Age", featureColumn: outputColumnName, numTrees: 1, numLeaves: 2, minDatapointsInLeaves: 1)); + + var model = pipeline.Fit(trainData); + + // Get the trained model parameters. + var modelParams = model.LastTransformer.Model; + + // Let's see where an example with Parity = 1 and Induced = 1 would end up in the single trained tree. + var testRow = new VBuffer(2, new[] { 1.0f, 1.0f }); + // Use the path object to pass to GetLeaf, which will populate path with the IDs of th nodes from root to leaf. 
+ List path = default; + // Get the ID of the leaf this example ends up in tree 0. + var leafID = modelParams.GetLeaf(0, in testRow, ref path); + // Get the leaf value for this leaf ID in tree 0. + var leafValue = modelParams.GetLeafValue(0, leafID); + Console.WriteLine("The leaf value in tree 0 is: " + leafValue); + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs index 1c28df327f..96ef491bb7 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs @@ -1,6 +1,5 @@ -using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Runtime.Data; -using System; +using System; +using Microsoft.ML.Data; namespace Microsoft.ML.Samples.Dynamic { @@ -19,11 +18,8 @@ public static void FeatureContributionCalculationTransform_Regression() // Step 1: Read the data as an IDataView. // First, we define the reader: specify the data columns and where to find them in the text file. 
- var reader = mlContext.Data.TextReader(new TextLoader.Arguments() - { - Separator = "tab", - HasHeader = true, - Column = new[] + var reader = mlContext.Data.CreateTextReader( + columns: new[] { new TextLoader.Column("MedianHomeValue", DataKind.R4, 0), new TextLoader.Column("CrimesPerCapita", DataKind.R4, 1), @@ -37,8 +33,9 @@ public static void FeatureContributionCalculationTransform_Regression() new TextLoader.Column("HighwayDistance", DataKind.R4, 9), new TextLoader.Column("TaxRate", DataKind.R4, 10), new TextLoader.Column("TeacherRatio", DataKind.R4, 11), - } - }); + }, + hasHeader: true + ); // Read the data var data = reader.Read(dataFile); @@ -50,22 +47,26 @@ public static void FeatureContributionCalculationTransform_Regression() var transformPipeline = mlContext.Transforms.Concatenate("Features", "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling", "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"); - var learner = mlContext.Regression.Trainers.StochasticDualCoordinateAscent( + var learner = mlContext.Regression.Trainers.OrdinaryLeastSquares( labelColumn: "MedianHomeValue", featureColumn: "Features"); var transformedData = transformPipeline.Fit(data).Transform(data); + // Now we train the model and score it on the transformed data. 
var model = learner.Fit(transformedData); + var scoredData = model.Transform(transformedData); // Create a Feature Contribution Calculator - // Calculate the feature contributions for all features + // Calculate the feature contributions for all features given trained model parameters // And don't normalize the contribution scores - var args = new FeatureContributionCalculationTransform.Arguments() - { - Top = 11, - Normalize = false - }; - var featureContributionCalculator = FeatureContributionCalculationTransform.Create(mlContext, args, transformedData, model.Model, model.FeatureColumn); + var featureContributionCalculator = mlContext.Model.Explainability.FeatureContributionCalculation(model.Model, model.FeatureColumn, top: 11, normalize: false); + var outputData = featureContributionCalculator.Fit(scoredData).Transform(scoredData); + + // FeatureContributionCalculatingEstimator can be use as an intermediary step in a pipeline. + // The features retained by FeatureContributionCalculatingEstimator will be in the FeatureContribution column. 
+ var pipeline = mlContext.Model.Explainability.FeatureContributionCalculation(model.Model, model.FeatureColumn, top: 11) + .Append(mlContext.Regression.Trainers.OrdinaryLeastSquares(featureColumn: "FeatureContributions")); + var outData = featureContributionCalculator.Fit(scoredData).Transform(scoredData); // Let's extract the weights from the linear model to use as a comparison var weights = new VBuffer(); @@ -73,9 +74,9 @@ public static void FeatureContributionCalculationTransform_Regression() // Let's now walk through the first ten reconds and see which feature drove the values the most // Get prediction scores and contributions - var scoringEnumerator = featureContributionCalculator.AsEnumerable(mlContext, true).GetEnumerator(); + var scoringEnumerator = outputData.AsEnumerable(mlContext, true).GetEnumerator(); int index = 0; - Console.WriteLine("Label\tScore\tBiggestFeature\tValue\tWeight\tContribution\tPercent"); + Console.WriteLine("Label\tScore\tBiggestFeature\tValue\tWeight\tContribution"); while (scoringEnumerator.MoveNext() && index < 10) { var row = scoringEnumerator.Current; @@ -86,26 +87,34 @@ public static void FeatureContributionCalculationTransform_Regression() // And the corresponding information about the feature var value = row.Features[featureOfInterest]; var contribution = row.FeatureContributions[featureOfInterest]; - var percentContribution = 100 * contribution / row.Score; - var name = data.Schema.GetColumnName(featureOfInterest + 1); + var name = data.Schema[featureOfInterest + 1].Name; var weight = weights.GetValues()[featureOfInterest]; - Console.WriteLine("{0:0.00}\t{1:0.00}\t{2}\t{3:0.00}\t{4:0.00}\t{5:0.00}\t{6:0.00}", + Console.WriteLine("{0:0.00}\t{1:0.00}\t{2}\t{3:0.00}\t{4:0.00}\t{5:0.00}", row.MedianHomeValue, row.Score, name, value, weight, - contribution, - percentContribution + contribution ); index++; } - - // For bulk scoring, the ApplyToData API can also be used - var scoredData = 
featureContributionCalculator.ApplyToData(mlContext, transformedData); - var preview = scoredData.Preview(100); + Console.ReadLine(); + + // The output of the above code is: + // Label Score BiggestFeature Value Weight Contribution + // 24.00 27.74 RoomsPerDwelling 6.58 98.55 39.95 + // 21.60 23.85 RoomsPerDwelling 6.42 98.55 39.01 + // 34.70 29.29 RoomsPerDwelling 7.19 98.55 43.65 + // 33.40 27.17 RoomsPerDwelling 7.00 98.55 42.52 + // 36.20 27.68 RoomsPerDwelling 7.15 98.55 43.42 + // 28.70 23.13 RoomsPerDwelling 6.43 98.55 39.07 + // 22.90 22.71 RoomsPerDwelling 6.01 98.55 36.53 + // 27.10 21.72 RoomsPerDwelling 6.17 98.55 37.50 + // 16.50 18.04 RoomsPerDwelling 5.63 98.55 34.21 + // 18.90 20.14 RoomsPerDwelling 6.00 98.55 36.48 } private static int GetMostContributingFeature(float[] featureContributions) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs index 7508815dc4..aa3d4446bb 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs @@ -1,7 +1,6 @@ -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Data; -using System; +using System; using System.Collections.Generic; +using Microsoft.ML.Data; namespace Microsoft.ML.Samples.Dynamic { @@ -31,16 +30,14 @@ public static void FeatureSelectionTransform() // First, we define the reader: specify the data columns and where to find them in the text file. Notice that we combine entries from // all the feature columns into entries of a vector of a single column named "Features". 
- var reader = ml.Data.TextReader(new TextLoader.Arguments() - { - Separator = "tab", - HasHeader = true, - Column = new[] + var reader = ml.Data.CreateTextReader( + columns: new[] { new TextLoader.Column("Label", DataKind.BL, 0), new TextLoader.Column("Features", DataKind.Num, new [] { new TextLoader.Range(1, 9) }) - } - }); + }, + hasHeader: true + ); // Then, we use the reader to read the data as an IDataView. var data = reader.Read(dataFilePath); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/FieldAwareFactorizationMachine.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/FieldAwareFactorizationMachine.cs new file mode 100644 index 0000000000..812de0fd27 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/FieldAwareFactorizationMachine.cs @@ -0,0 +1,71 @@ +using System; +using Microsoft.ML.Data; +namespace Microsoft.ML.Samples.Dynamic +{ + public class FFM_BinaryClassificationExample + { + public static void FFM_BinaryClassification() + { + // Downloading the dataset from github.com/dotnet/machinelearning. + // This will create a sentiment.tsv file in the filesystem. + // You can open this file, if you want to see the data. + string dataFile = SamplesUtils.DatasetUtils.DownloadSentimentDataset(); + + // A preview of the data. + // Sentiment SentimentText + // 0 " :Erm, thank you. " + // 1 ==You're cool== + + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + var mlContext = new MLContext(); + + // Step 1: Read the data as an IDataView. + // First, we define the reader: specify the data columns and where to find them in the text file. 
+ var reader = mlContext.Data.CreateTextReader( + columns: new[] + { + new TextLoader.Column("Sentiment", DataKind.BL, 0), + new TextLoader.Column("SentimentText", DataKind.Text, 1) + }, + hasHeader: true + ); + + // Read the data + var data = reader.Read(dataFile); + + // ML.NET doesn't cache data set by default. Therefore, if one reads a data set from a file and accesses it many times, it can be slow due to + // expensive featurization and disk operations. When the considered data can fit into memory, a solution is to cache the data in memory. Caching is especially + // helpful when working with iterative algorithms which needs many data passes. Since SDCA is the case, we cache. Inserting a + // cache step in a pipeline is also possible, please see the construction of pipeline below. + data = mlContext.Data.Cache(data); + + // Step 2: Pipeline + // Featurize the text column through the FeaturizeText API. + // Then append a binary classifier, setting the "Label" column as the label of the dataset, and + // the "Features" column produced by FeaturizeText as the features column. + var pipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features") + .AppendCacheCheckpoint(mlContext) // Add a data-cache step within a pipeline. + .Append(mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine(labelColumn: "Sentiment", featureColumns: new[] { "Features" })); + + // Fit the model. + var model = pipeline.Fit(data); + + // Let's get the model parameters from the model. + var modelParams = model.LastTransformer.Model; + + // Let's inspect the model parameters. 
+ var featureCount = modelParams.GetFeatureCount(); + var fieldCount = modelParams.GetFieldCount(); + var latentDim = modelParams.GetLatentDim(); + var linearWeights = modelParams.GetLinearWeights(); + var latentWeights = modelParams.GetLatentWeights(); + + Console.WriteLine("The feature count is: " + featureCount); + Console.WriteLine("The number of fields is: " + fieldCount); + Console.WriteLine("The latent dimension is: " + latentDim); + Console.WriteLine("The linear weights of the features are: " + string.Join(", ", linearWeights)); + Console.WriteLine("The weights of the latent features are: " + string.Join(", ", latentWeights)); + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/GeneralizedAdditiveModels.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/GeneralizedAdditiveModels.cs index 827d04a586..acd08979a2 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/GeneralizedAdditiveModels.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/GeneralizedAdditiveModels.cs @@ -1,6 +1,6 @@ -using Microsoft.ML.Runtime.Data; -using System; +using System; using System.Linq; +using Microsoft.ML.Data; namespace Microsoft.ML.Samples.Dynamic { @@ -19,11 +19,8 @@ public static void RunExample() // Step 1: Read the data as an IDataView. // First, we define the reader: specify the data columns and where to find them in the text file. 
- var reader = mlContext.Data.TextReader(new TextLoader.Arguments() - { - Separator = "tab", - HasHeader = true, - Column = new[] + var reader = mlContext.Data.CreateTextReader( + columns: new[] { new TextLoader.Column("MedianHomeValue", DataKind.R4, 0), new TextLoader.Column("CrimesPerCapita", DataKind.R4, 1), @@ -37,8 +34,9 @@ public static void RunExample() new TextLoader.Column("HighwayDistance", DataKind.R4, 9), new TextLoader.Column("TaxRate", DataKind.R4, 10), new TextLoader.Column("TeacherRatio", DataKind.R4, 11), - } - }); + }, + hasHeader: true + ); // Read the data var data = reader.Read(dataFile); @@ -50,8 +48,8 @@ public static void RunExample() // and use a small number of bins to make it easy to visualize in the console window. // For real appplications, it is recommended to start with the default number of bins. var labelName = "MedianHomeValue"; - var featureNames = data.Schema.GetColumns() - .Select(tuple => tuple.column.Name) // Get the column names + var featureNames = data.Schema + .Select(column => column.Name) // Get the column names .Where(name => name != labelName) // Drop the Label .ToArray(); var pipeline = mlContext.Transforms.Concatenate("Features", featureNames) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/IidChangePointDetectorTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/IidChangePointDetectorTransform.cs index 52729c6c97..1a476d3f59 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/IidChangePointDetectorTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/IidChangePointDetectorTransform.cs @@ -1,13 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ using System; -using System.Linq; using System.Collections.Generic; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Runtime.TimeSeriesProcessing; -using Microsoft.ML.Core.Data; -using Microsoft.ML.TimeSeries; using System.IO; +using System.Linq; +using Microsoft.ML.Core.Data; using Microsoft.ML.Data; +using Microsoft.ML.TimeSeries; +using Microsoft.ML.TimeSeriesProcessing; namespace Microsoft.ML.Samples.Dynamic { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/IidSpikeDetectorTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/IidSpikeDetectorTransform.cs index 5dcb7e1774..572c6b140d 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/IidSpikeDetectorTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/IidSpikeDetectorTransform.cs @@ -1,13 +1,11 @@ using System; +using System.Collections.Generic; using System.IO; using System.Linq; -using System.Collections.Generic; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Runtime.TimeSeriesProcessing; using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; using Microsoft.ML.TimeSeries; +using Microsoft.ML.TimeSeriesProcessing; namespace Microsoft.ML.Samples.Dynamic { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/KMeans.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/KMeans.cs new file mode 100644 index 0000000000..1a5ad75b26 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/KMeans.cs @@ -0,0 +1,45 @@ +using System; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Samples.Dynamic +{ + public class KMeans_example + { + public static void KMeans() + { + // Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging, + // as well as the source of randomness. + var ml = new MLContext(); + + // Get a small dataset as an IEnumerable and convert it to an IDataView. 
+ var data = SamplesUtils.DatasetUtils.GetInfertData(); + var trainData = ml.CreateStreamingDataView(data); + + // Preview of the data. + // + // Age Case Education Induced Parity PooledStratum RowNum ... + // 26 1 0-5yrs 1 6 3 1 ... + // 42 1 0-5yrs 1 1 1 2 ... + // 39 1 0-5yrs 2 6 4 3 ... + // 34 1 0-5yrs 2 4 2 4 ... + // 35 1 6-11yrs 1 3 32 5 ... + + // A pipeline for concatenating the age, parity and induced columns together in the Features column and training a KMeans model on them. + string outputColumnName = "Features"; + var pipeline = ml.Transforms.Concatenate(outputColumnName, new[] { "Age", "Parity", "Induced" }) + .Append(ml.Clustering.Trainers.KMeans(outputColumnName, clustersCount: 2)); + + var model = pipeline.Fit(trainData); + + // Get cluster centroids and the number of clusters k from KMeansModelParameters. + VBuffer[] centroids = default; + int k; + + var modelParams = model.LastTransformer.Model; + modelParams.GetClusterCentroids(ref centroids, out k); + + var centroid = centroids[0].GetValues(); + Console.WriteLine("The coordinates of centroid 0 are: " + string.Join(", ", centroid.ToArray())); + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValue_Term.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValue_Term.cs index d6e15eeb50..6d3deb9f84 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValue_Term.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValue_Term.cs @@ -1,12 +1,8 @@ -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Transforms.Categorical; +using System; +using System.Collections.Generic; +using Microsoft.ML.Data; using Microsoft.ML.Transforms.Conversions; using Microsoft.ML.Transforms.Text; -using System; -using System.Collections.Generic; -using System.Linq; namespace Microsoft.ML.Samples.Dynamic { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/LdaTransform.cs 
b/docs/samples/Microsoft.ML.Samples/Dynamic/LdaTransform.cs index f2cbec7c28..4765fe06a4 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/LdaTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/LdaTransform.cs @@ -1,8 +1,6 @@ -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Runtime.Data; -using System; +using System; using System.Collections.Generic; +using Microsoft.ML.Data; namespace Microsoft.ML.Samples.Dynamic { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/MatrixFactorization.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/MatrixFactorization.cs index 063a554adc..ecd5336b03 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/MatrixFactorization.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/MatrixFactorization.cs @@ -1,8 +1,6 @@ -using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Trainers; using System; using System.Collections.Generic; +using Microsoft.ML.Data; namespace Microsoft.ML.Samples.Dynamic { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs index 3c9147f32a..524f04c834 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs @@ -1,8 +1,6 @@ -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Runtime.Data; -using System; +using System; using System.Collections.Generic; +using Microsoft.ML.Data; namespace Microsoft.ML.Samples.Dynamic { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs index 121eede6e6..abeb454d51 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs @@ -1,9 +1,8 @@ -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Transforms.Normalizers; -using System; +using System; 
using System.Collections.Generic; using System.Linq; +using Microsoft.ML.Data; +using Microsoft.ML.Transforms.Normalizers; namespace Microsoft.ML.Samples.Dynamic { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs new file mode 100644 index 0000000000..21d740b4ae --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs @@ -0,0 +1,105 @@ +using System; +using System.Linq; +using Microsoft.ML.Data; +using Microsoft.ML.OnnxRuntime; +using Microsoft.ML.Transforms; + +namespace Microsoft.ML.Samples.Dynamic +{ + class OnnxTransformExample + { + /// + /// Example use of OnnxEstimator in an ML.NET pipeline + /// + public static void OnnxTransformSample() + { + // Download the squeeznet image model from ONNX model zoo, version 1.2 + // https://github.com/onnx/models/tree/master/squeezenet + var modelPath = @"squeezenet\model.onnx"; + + // Inspect the model's inputs and outputs + var session = new InferenceSession(modelPath); + var inputInfo = session.InputMetadata.First(); + var outputInfo = session.OutputMetadata.First(); + Console.WriteLine($"Input Name is {String.Join(",", inputInfo.Key)}"); + Console.WriteLine($"Input Dimensions are {String.Join(",", inputInfo.Value.Dimensions)}"); + Console.WriteLine($"Output Name is {String.Join(",", outputInfo.Key)}"); + Console.WriteLine($"Output Dimensions are {String.Join(",", outputInfo.Value.Dimensions)}"); + // Results.. 
+ // Input Name is data_0 + // Input Dimensions are 1,3,224,224 + // Output Name is softmaxout_1 + // Output Dimensions are 1,1000,1,1 + + // Create ML pipeline to score the data using OnnxScoringEstimator + var mlContext = new MLContext(); + var data = GetTensorData(); + var idv = mlContext.CreateStreamingDataView(data); + var pipeline = new OnnxScoringEstimator(mlContext, modelPath, new[] { inputInfo.Key }, new[] { outputInfo.Key }); + + // Run the pipeline and get the transformed values + var transformedValues = pipeline.Fit(idv).Transform(idv); + + // Retrieve model scores into Prediction class + var predictions = transformedValues.AsEnumerable(mlContext, reuseRowObject: false); + + // Iterate rows + foreach (var prediction in predictions) + { + int numClasses = 0; + foreach (var classScore in prediction.softmaxout_1.Take(3)) + { + Console.WriteLine($"Class #{numClasses++} score = {classScore}"); + } + Console.WriteLine(new string('-', 10)); + } + + // Results look like below... + // Class #0 score = 4.544065E-05 + // Class #1 score = 0.003845858 + // Class #2 score = 0.0001249467 + // ---------- + // Class #0 score = 4.491953E-05 + // Class #1 score = 0.003848222 + // Class #2 score = 0.0001245592 + // ---------- + } + + /// + /// inputSize is the overall dimensions of the model input tensor. + /// + private const int inputSize = 224 * 224 * 3; + + /// + /// A class to hold sample tensor data. Member name should match + /// the inputs that the model expects (in this case, data_0) + /// + public class TensorData + { + [VectorType(inputSize)] + public float[] data_0 { get; set; } + } + + /// + /// Method to generate sample test data. Returns 2 sample rows. + /// + /// + public static TensorData[] GetTensorData() + { + // This can be any numerical data. Assume image pixel values. 
+ var image1 = Enumerable.Range(0, inputSize).Select(x => (float)x / inputSize).ToArray(); + var image2 = Enumerable.Range(0, inputSize).Select(x => (float)(x + 10000) / inputSize).ToArray(); + return new TensorData[] { new TensorData() { data_0 = image1 }, new TensorData() { data_0 = image2 } }; + } + + /// + /// Class to contain the output values from the transformation. + /// This model generates a vector of 1000 floats. + /// + class Prediction + { + [VectorType(1000)] + public float[] softmaxout_1 { get; set; } + } + } +} \ No newline at end of file diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance.cs deleted file mode 100644 index 0c95abacb8..0000000000 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance.cs +++ /dev/null @@ -1,128 +0,0 @@ -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Learners; -using System; -using System.Linq; - -namespace Microsoft.ML.Samples.Dynamic -{ - public class PFI_RegressionExample - { - public static void PFI_Regression() - { - // Download the dataset from github.com/dotnet/machinelearning. - // This will create a housing.txt file in the filesystem. - // You can open this file to see the data. - string dataFile = SamplesUtils.DatasetUtils.DownloadHousingRegressionDataset(); - - // Create a new context for ML.NET operations. It can be used for exception tracking and logging, - // as a catalog of available operations and as the source of randomness. - var mlContext = new MLContext(); - - // Step 1: Read the data as an IDataView. - // First, we define the reader: specify the data columns and where to find them in the text file. - // The data file is composed of rows of data, with each row having 11 numerical columns - // separated by whitespace. 
- var reader = mlContext.Data.TextReader(new TextLoader.Arguments() - { - Separator = "tab", - HasHeader = true, - Column = new[] - { - // Read the first column (indexed by 0) in the data file as an R4 (float) - new TextLoader.Column("MedianHomeValue", DataKind.R4, 0), - new TextLoader.Column("CrimesPerCapita", DataKind.R4, 1), - new TextLoader.Column("PercentResidental", DataKind.R4, 2), - new TextLoader.Column("PercentNonRetail", DataKind.R4, 3), - new TextLoader.Column("CharlesRiver", DataKind.R4, 4), - new TextLoader.Column("NitricOxides", DataKind.R4, 5), - new TextLoader.Column("RoomsPerDwelling", DataKind.R4, 6), - new TextLoader.Column("PercentPre40s", DataKind.R4, 7), - new TextLoader.Column("EmploymentDistance", DataKind.R4, 8), - new TextLoader.Column("HighwayDistance", DataKind.R4, 9), - new TextLoader.Column("TaxRate", DataKind.R4, 10), - new TextLoader.Column("TeacherRatio", DataKind.R4, 11), - } - }); - - // Read the data - var data = reader.Read(dataFile); - - // Step 2: Pipeline - // Concatenate the features to create a Feature vector. - // Normalize the data set so that for each feature, its maximum value is 1 while its minimum value is 0. - // Then append a linear regression trainer, setting the "MedianHomeValue" column as the label of the dataset, - // the "Features" column produced by concatenation as the features of the dataset. 
- var labelName = "MedianHomeValue"; - var pipeline = mlContext.Transforms.Concatenate("Features", "CrimesPerCapita", "PercentResidental", - "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling", "PercentPre40s", - "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio") - .Append(mlContext.Transforms.Normalize("Features")) - .Append(mlContext.Regression.Trainers.StochasticDualCoordinateAscent( - labelColumn: labelName, featureColumn: "Features")); - var model = pipeline.Fit(data); - - // Extract the model from the pipeline - var linearPredictor = model.LastTransformer; - var weights = GetLinearModelWeights(linearPredictor.Model); - - // Compute the permutation metrics using the properly-featurized data. - var transformedData = model.Transform(data); - var permutationMetrics = mlContext.Regression.PermutationFeatureImportance( - linearPredictor, transformedData, label: labelName, features: "Features"); - - // Now let's look at which features are most important to the model overall - // First, we have to prepare the data: - // Get the feature names as an IEnumerable - var featureNames = data.Schema.GetColumns() - .Select(tuple => tuple.column.Name) // Get the column names - .Where(name => name != labelName) // Drop the Label - .ToArray(); - - // Get the feature indices sorted by their impact on R-Squared - var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.RSquared }) - .OrderByDescending(feature => Math.Abs(feature.RSquared)) - .Select(feature => feature.index); - - // Print out the permutation results, with the model weights, in order of their impact: - // Expected console output: - // Feature Model Weight Change in R - Squared - // RoomsPerDwelling 50.80 -0.3695 - // EmploymentDistance -17.79 -0.2238 - // TeacherRatio -19.83 -0.1228 - // TaxRate -8.60 -0.1042 - // NitricOxides -15.95 -0.1025 - // HighwayDistance 5.37 -0.09345 - // CrimesPerCapita -15.05 -0.05797 - // PercentPre40s -4.64 -0.0385 - // 
PercentResidental 3.98 -0.02184 - // CharlesRiver 3.38 -0.01487 - // PercentNonRetail -1.94 -0.007231 - // - // Let's dig into these results a little bit. First, if you look at the weights of the model, they generally correlate - // with the results of PFI, but there are some significant misorderings. For example, "Tax Rate" is weighted lower than - // "Nitric Oxides" and "Crimes Per Capita", but the permutation analysis shows this feature to have a larger effect - // on the accuracy of the model even though it has a relatively small weight. To understand why the weights don't - // reflect the same feature importance as PFI, we need to go back to the basics of linear models: one of the - // assumptions of a linear model is that the features are uncorrelated. Now, the features in this dataset are clearly - // correlated: the tax rate for a house and the student-to-teacher ratio at the nearest school, for example, are often - // coupled through school levies. The tax rate, presence of pollution (e.g. nitric oxides), and the crime rate would also - // seem to be correlated with each other through social dynamics. We could draw out similar relationships for all the - // variables in this dataset. The reason why the linear model weights don't reflect the same feature importance as PFI - // is that the solution to the linear model redistributes weights between correlated variables in unpredictable ways, so - // that the weights themselves are no longer a good measure of feature importance. 
- Console.WriteLine("Feature\tModel Weight\tChange in R-Squared"); - var rSquared = permutationMetrics.Select(x => x.RSquared).ToArray(); // Fetch r-squared as an array - foreach (int i in sortedIndices) - { - Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{rSquared[i]:G4}"); - } - } - - private static float[] GetLinearModelWeights(LinearRegressionPredictor linearModel) - { - var weights = new VBuffer(); - linearModel.GetFeatureWeights(ref weights); - return weights.GetValues().ToArray(); - } - } -} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIHelper.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIHelper.cs new file mode 100644 index 0000000000..ebb4def616 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIHelper.cs @@ -0,0 +1,88 @@ +using System; +using System.Linq; +using Microsoft.ML.Data; +using Microsoft.ML.Learners; +using Microsoft.ML.Trainers.HalLearners; + +namespace Microsoft.ML.Samples.Dynamic.PermutationFeatureImportance +{ + public class PfiHelper + { + public static IDataView GetHousingRegressionIDataView(MLContext mlContext, out string labelName, out string[] featureNames, bool binaryPrediction = false) + { + // Download the dataset from github.com/dotnet/machinelearning. + // This will create a housing.txt file in the filesystem. + // You can open this file to see the data. + string dataFile = SamplesUtils.DatasetUtils.DownloadHousingRegressionDataset(); + + // Read the data as an IDataView. + // First, we define the reader: specify the data columns and where to find them in the text file. + // The data file is composed of rows of data, with each row having 11 numerical columns + // separated by whitespace. 
+ var reader = mlContext.Data.CreateTextReader( + columns: new[] + { + // Read the first column (indexed by 0) in the data file as an R4 (float) + new TextLoader.Column("MedianHomeValue", DataKind.R4, 0), + new TextLoader.Column("CrimesPerCapita", DataKind.R4, 1), + new TextLoader.Column("PercentResidental", DataKind.R4, 2), + new TextLoader.Column("PercentNonRetail", DataKind.R4, 3), + new TextLoader.Column("CharlesRiver", DataKind.R4, 4), + new TextLoader.Column("NitricOxides", DataKind.R4, 5), + new TextLoader.Column("RoomsPerDwelling", DataKind.R4, 6), + new TextLoader.Column("PercentPre40s", DataKind.R4, 7), + new TextLoader.Column("EmploymentDistance", DataKind.R4, 8), + new TextLoader.Column("HighwayDistance", DataKind.R4, 9), + new TextLoader.Column("TaxRate", DataKind.R4, 10), + new TextLoader.Column("TeacherRatio", DataKind.R4, 11), + }, + hasHeader: true + ); + + // Read the data + var data = reader.Read(dataFile); + var labelColumn = "MedianHomeValue"; + + if (binaryPrediction) + { + labelColumn = nameof(BinaryOutputRow.AboveAverage); + data = mlContext.Transforms.CustomMappingTransformer(GreaterThanAverage, null).Transform(data); + data = mlContext.Transforms.DropColumns("MedianHomeValue").Fit(data).Transform(data); + } + + labelName = labelColumn; + featureNames = data.Schema.AsEnumerable() + .Select(column => column.Name) // Get the column names + .Where(name => name != labelColumn) // Drop the Label + .ToArray(); + + return data; + } + + // Define a class for all the input columns that we intend to consume. + private class ContinuousInputRow + { + public float MedianHomeValue { get; set; } + } + + // Define a class for all output columns that we intend to produce. 
+ private class BinaryOutputRow + { + public bool AboveAverage { get; set; } + } + + // Define an Action to apply a custom mapping from one object to the other + private readonly static Action GreaterThanAverage = (input, output) + => output.AboveAverage = input.MedianHomeValue > 22.6; + + public static float[] GetLinearModelWeights(OlsLinearRegressionModelParameters linearModel) + { + return linearModel.Weights.ToArray(); + } + + public static float[] GetLinearModelWeights(LinearBinaryModelParameters linearModel) + { + return linearModel.Weights.ToArray(); + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIRegressionExample.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIRegressionExample.cs new file mode 100644 index 0000000000..e552a841ed --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIRegressionExample.cs @@ -0,0 +1,77 @@ +using System; +using System.Linq; + +namespace Microsoft.ML.Samples.Dynamic.PermutationFeatureImportance +{ + public class PfiRegressionExample + { + public static void RunExample() + { + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + var mlContext = new MLContext(); + + // Step 1: Read the data + var data = PfiHelper.GetHousingRegressionIDataView(mlContext, out string labelName, out string[] featureNames); + + // Step 2: Pipeline + // Concatenate the features to create a Feature vector. + // Normalize the data set so that for each feature, its maximum value is 1 while its minimum value is 0. + // Then append a linear regression trainer. 
+ var pipeline = mlContext.Transforms.Concatenate("Features", featureNames) + .Append(mlContext.Transforms.Normalize("Features")) + .Append(mlContext.Regression.Trainers.OrdinaryLeastSquares( + labelColumn: labelName, featureColumn: "Features")); + var model = pipeline.Fit(data); + + // Extract the model from the pipeline + var linearPredictor = model.LastTransformer; + var weights = PfiHelper.GetLinearModelWeights(linearPredictor.Model); + + // Compute the permutation metrics using the properly normalized data. + var transformedData = model.Transform(data); + var permutationMetrics = mlContext.Regression.PermutationFeatureImportance( + linearPredictor, transformedData, label: labelName, features: "Features", permutationCount: 3); + + // Now let's look at which features are most important to the model overall + // Get the feature indices sorted by their impact on R-Squared + var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.RSquared }) + .OrderByDescending(feature => Math.Abs(feature.RSquared.Mean)) + .Select(feature => feature.index); + + // Print out the permutation results, with the model weights, in order of their impact: + // Expected console output for 100 permutations: + // Feature Model Weight Change in R-Squared 95% Confidence Interval of the Mean + // RoomsPerDwelling 53.35 -0.4298 0.005705 + // EmploymentDistance -19.21 -0.2609 0.004591 + // NitricOxides -19.32 -0.1569 0.003701 + // HighwayDistance 6.11 -0.1173 0.0025 + // TeacherRatio -21.92 -0.1106 0.002207 + // TaxRate -8.68 -0.1008 0.002083 + // CrimesPerCapita -16.37 -0.05988 0.00178 + // PercentPre40s -4.52 -0.03836 0.001432 + // PercentResidental 3.91 -0.02006 0.001079 + // CharlesRiver 3.49 -0.01839 0.000841 + // PercentNonRetail -1.17 -0.002111 0.0003176 + // + // Let's dig into these results a little bit. First, if you look at the weights of the model, they generally correlate + // with the results of PFI, but there are some significant misorderings. 
For example, "Tax Rate" and "Highway Distance" + // have relatively small model weights, but the permutation analysis shows these feature to have a larger effect + // on the accuracy of the model than higher-weighted features. To understand why the weights don't reflect the same + // feature importance as PFI, we need to go back to the basics of linear models: one of the assumptions of a linear + // model is that the features are uncorrelated. Now, the features in this dataset are clearly correlated: the tax rate + // for a house and the student-to-teacher ratio at the nearest school, for example, are often coupled through school + // levies. The tax rate, distance to a highway, and the crime rate would also seem to be correlated through social + // dynamics. We could draw out similar relationships for all variables in this dataset. The reason why the linear + // model weights don't reflect the same feature importance as PFI is that the solution to the linear model redistributes + // weights between correlated variables in unpredictable ways, so that the weights themselves are no longer a good + // measure of feature importance. 
+ Console.WriteLine("Feature\tModel Weight\tChange in R-Squared\t95% Confidence Interval of the Mean"); + var rSquared = permutationMetrics.Select(x => x.RSquared).ToArray(); // Fetch r-squared as an array + foreach (int i in sortedIndices) + { + Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{rSquared[i].Mean:G4}\t{1.96 * rSquared[i].StandardError:G4}"); + } + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs new file mode 100644 index 0000000000..e75fa170c9 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs @@ -0,0 +1,76 @@ +using System; +using System.Linq; +using Microsoft.ML.Learners; + +namespace Microsoft.ML.Samples.Dynamic.PermutationFeatureImportance +{ + public class PfiBinaryClassificationExample + { + public static void RunExample() + { + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + var mlContext = new MLContext(seed:999123); + + // Step 1: Read the data + var data = PfiHelper.GetHousingRegressionIDataView(mlContext, + out string labelName, out string[] featureNames, binaryPrediction: true); + + // Step 2: Pipeline + // Concatenate the features to create a Feature vector. + // Normalize the data set so that for each feature, its maximum value is 1 while its minimum value is 0. + // Then append a logistic regression trainer. 
+ var pipeline = mlContext.Transforms.Concatenate("Features", featureNames) + .Append(mlContext.Transforms.Normalize("Features")) + .Append(mlContext.BinaryClassification.Trainers.LogisticRegression( + labelColumn: labelName, featureColumn: "Features")); + var model = pipeline.Fit(data); + + // Extract the model from the pipeline + var linearPredictor = model.LastTransformer; + // Linear models for binary classification are wrapped by a calibrator as a generic predictor + // To access it directly, we must extract it out and cast it to the proper class + var weights = PfiHelper.GetLinearModelWeights(linearPredictor.Model.SubPredictor as LinearBinaryModelParameters); + + // Compute the permutation metrics using the properly normalized data. + var transformedData = model.Transform(data); + var permutationMetrics = mlContext.BinaryClassification.PermutationFeatureImportance( + linearPredictor, transformedData, label: labelName, features: "Features", permutationCount: 3); + + // Now let's look at which features are most important to the model overall + // Get the feature indices sorted by their impact on AUC + var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.Auc }) + .OrderByDescending(feature => Math.Abs(feature.Auc.Mean)) + .Select(feature => feature.index); + + // Print out the permutation results, with the model weights, in order of their impact: + // Expected console output (for 100 permutations): + // Feature Model Weight Change in AUC 95% Confidence in the Mean Change in AUC + // PercentPre40s -1.96 -0.06316 0.002377 + // RoomsPerDwelling 3.71 -0.04385 0.001245 + // EmploymentDistance -1.31 -0.02139 0.0006867 + // TeacherRatio -2.46 -0.0203 0.0009566 + // PercentNonRetail -1.58 -0.01846 0.001586 + // CharlesRiver 0.66 -0.008605 0.0005136 + // PercentResidental 0.60 0.002483 0.0004818 + // TaxRate -0.95 -0.00221 0.0007394 + // NitricOxides -0.32 0.00101 0.0001428 + // CrimesPerCapita -0.04 -3.029E-05 1.678E-05 + // 
HighwayDistance 0.00 0 0 + // Let's look at these results. + // First, if you look at the weights of the model, they generally correlate with the results of PFI, + // but there are some significant misorderings. See the discussion in the Regression example for an + // explanation of why this happens and how to interpret it. + // Second, the logistic regression learner uses L1 regularization by default. Here, it causes the "HighWay Distance" + // feature to be zeroed out from the model. PFI assigns zero importance to this variable, as expected. + // Third, some features show an *increase* in AUC. This means that the model actually improved + // when these features were shuffled. This is a sign to investigate these features further. + Console.WriteLine("Feature\tModel Weight\tChange in AUC\t95% Confidence in the Mean Change in AUC"); + var auc = permutationMetrics.Select(x => x.Auc).ToArray(); // Fetch AUC as an array + foreach (int i in sortedIndices) + { + Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{auc[i].Mean:G4}\t{1.96 * auc[i].StandardError:G4}"); + } + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ProjectionTransforms.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ProjectionTransforms.cs index 0f7761cfba..15ffbf6653 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ProjectionTransforms.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ProjectionTransforms.cs @@ -1,9 +1,7 @@ -using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Data; -using System; +using System; using System.Collections.Generic; using System.Linq; +using Microsoft.ML.Data; namespace Microsoft.ML.Samples.Dynamic { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs index 32a33cdc43..5f08dee906 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs @@ -1,6 +1,6 @@ -using Microsoft.ML.Runtime.Data; 
-using System; +using System; using System.Linq; +using Microsoft.ML.Data; namespace Microsoft.ML.Samples.Dynamic { @@ -24,25 +24,30 @@ public static void SDCA_BinaryClassification() // Step 1: Read the data as an IDataView. // First, we define the reader: specify the data columns and where to find them in the text file. - var reader = mlContext.Data.TextReader(new TextLoader.Arguments() - { - Separator = "tab", - HasHeader = true, - Column = new[] + var reader = mlContext.Data.CreateTextReader( + columns: new[] { new TextLoader.Column("Sentiment", DataKind.BL, 0), new TextLoader.Column("SentimentText", DataKind.Text, 1) - } - }); + }, + hasHeader: true + ); // Read the data var data = reader.Read(dataFile); + // ML.NET doesn't cache data set by default. Therefore, if one reads a data set from a file and accesses it many times, it can be slow due to + // expensive featurization and disk operations. When the considered data can fit into memory, a solution is to cache the data in memory. Caching is especially + // helpful when working with iterative algorithms which needs many data passes. Since SDCA is the case, we cache. Inserting a + // cache step in a pipeline is also possible, please see the construction of pipeline below. + data = mlContext.Data.Cache(data); + // Step 2: Pipeline // Featurize the text column through the FeaturizeText API. // Then append a binary classifier, setting the "Label" column as the label of the dataset, and - // the "Features" column produced by FeaturizeText as the features column. + // the "Features" column produced by FeaturizeText as the features column. var pipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features") + .AppendCacheCheckpoint(mlContext) // Add a data-cache step within a pipeline. .Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent(labelColumn: "Sentiment", featureColumn: "Features", l2Const: 0.001f)); // Step 3: Run Cross-Validation on this pipeline. 
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/SDCARegression.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCARegression.cs new file mode 100644 index 0000000000..c75872011c --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCARegression.cs @@ -0,0 +1,43 @@ +using System; +using System.Linq; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Samples.Dynamic +{ + public class SDCARegressionExample + { + public static void SDCARegression() + { + // Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging, + // as well as the source of randomness. + var ml = new MLContext(); + + // Get a small dataset as an IEnumerable and convert it to an IDataView. + var data = SamplesUtils.DatasetUtils.GetInfertData(); + var trainData = ml.CreateStreamingDataView(data); + + // Preview of the data. + // + // Age Case Education Induced Parity PooledStratum RowNum ... + // 26 1 0-5yrs 1 6 3 1 ... + // 42 1 0-5yrs 1 1 1 2 ... + // 39 1 0-5yrs 2 6 4 3 ... + // 34 1 0-5yrs 2 4 2 4 ... + // 35 1 6-11yrs 1 3 32 5 ... + + // A pipeline for concatenating the Parity and Induced columns together in the Features column. + // We will train a FastTreeRegression model with 1 tree on these two columns to predict Age. + string outputColumnName = "Features"; + var pipeline = ml.Transforms.Concatenate(outputColumnName, new[] { "Parity", "Induced" }) + .Append(ml.Regression.Trainers.StochasticDualCoordinateAscent(labelColumn: "Age", featureColumn: outputColumnName, maxIterations:2)); + + var model = pipeline.Fit(trainData); + + // Get the trained model parameters. + var modelParams = model.LastTransformer.Model; + // Inspect the bias and model weights. 
+ Console.WriteLine("The bias term is: " + modelParams.Bias); + Console.WriteLine("The feature weights are: " + string.Join(", ", modelParams.Weights.ToArray())); + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/SsaChangePointDetectorTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/SsaChangePointDetectorTransform.cs index e03ba68606..7eba634243 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/SsaChangePointDetectorTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/SsaChangePointDetectorTransform.cs @@ -1,12 +1,11 @@ -using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Runtime.TimeSeriesProcessing; -using Microsoft.ML.TimeSeries; using System; using System.Collections.Generic; using System.IO; using System.Linq; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.TimeSeries; +using Microsoft.ML.TimeSeriesProcessing; namespace Microsoft.ML.Samples.Dynamic { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/SsaSpikeDetectorTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/SsaSpikeDetectorTransform.cs index fe9e981ecd..ff915dac1f 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/SsaSpikeDetectorTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/SsaSpikeDetectorTransform.cs @@ -1,13 +1,11 @@ -using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.TimeSeriesProcessing; -using Microsoft.ML.TimeSeries; using System; using System.Collections.Generic; using System.IO; using System.Linq; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.TimeSeries; +using Microsoft.ML.TimeSeriesProcessing; namespace Microsoft.ML.Samples.Dynamic { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlowTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlowTransform.cs new file mode 100644 index 
0000000000..41d1eac9fe --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlowTransform.cs @@ -0,0 +1,91 @@ +using System; +using System.Linq; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Samples.Dynamic +{ + class TensorFlowTransformExample + { + /// + /// Example use of the TensorFlowEstimator in a ML.NET pipeline. + /// + public static void TensorFlowScoringSample() + { + // Download the ResNet 101 model from the location below. + // https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz + var modelLocation = @"resnet_v2_101/resnet_v2_101_299_frozen.pb"; + + var mlContext = new MLContext(); + var data = GetTensorData(); + var idv = mlContext.CreateStreamingDataView(data); + + // Create a ML pipeline. + var pipeline = mlContext.Transforms.ScoreTensorFlowModel( + modelLocation, + new[] { nameof(TensorData.input) }, + new[] { nameof(OutputScores.output) }); + + // Run the pipeline and get the transformed values. + var estimator = pipeline.Fit(idv); + var transformedValues = estimator.Transform(idv); + + // Retrieve model scores. + var outScores = transformedValues.AsEnumerable(mlContext, reuseRowObject: false); + + // Display scores. (for the sake of brevity we display scores of the first 3 classes) + foreach (var prediction in outScores) + { + int numClasses = 0; + foreach (var classScore in prediction.output.Take(3)) + { + Console.WriteLine($"Class #{numClasses++} score = {classScore}"); + } + Console.WriteLine(new string('-', 10)); + } + + // Results look like below... 
+ //Class #0 score = -0.8092947 + //Class #1 score = -0.3310375 + //Class #2 score = 0.1119193 + //---------- + //Class #0 score = -0.7807726 + //Class #1 score = -0.2158062 + //Class #2 score = 0.1153686 + //---------- + } + + private const int imageHeight = 224; + private const int imageWidth = 224; + private const int numChannels = 3; + private const int inputSize = imageHeight * imageWidth * numChannels; + + /// + /// A class to hold sample tensor data. + /// Member name should match the inputs that the model expects (in this case, input). + /// + public class TensorData + { + [VectorType(imageHeight, imageWidth, numChannels)] + public float[] input { get; set; } + } + + /// + /// Method to generate sample test data. Returns 2 sample rows. + /// + public static TensorData[] GetTensorData() + { + // This can be any numerical data. Assume image pixel values. + var image1 = Enumerable.Range(0, inputSize).Select(x => (float)x / inputSize).ToArray(); + var image2 = Enumerable.Range(0, inputSize).Select(x => (float)(x + 10000) / inputSize).ToArray(); + return new TensorData[] { new TensorData() { input = image1 }, new TensorData() { input = image2 } }; + } + + /// + /// Class to contain the output values from the transformation. 
+ /// + class OutputScores + { + public float[] output { get; set; } + } + } +} \ No newline at end of file diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs index 57fbc7e493..3c8964b6f7 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs @@ -1,9 +1,7 @@ -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Api; +using System; +using System.Collections.Generic; using Microsoft.ML.Data; using Microsoft.ML.Transforms.Text; -using System; -using System.Collections.Generic; namespace Microsoft.ML.Samples.Dynamic { diff --git a/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj b/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj index abea9f3e8d..44b673640f 100644 --- a/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj +++ b/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj @@ -7,17 +7,23 @@ + + + + + + false diff --git a/docs/samples/Microsoft.ML.Samples/Program.cs b/docs/samples/Microsoft.ML.Samples/Program.cs index 83a7e8c2e8..41f465ef5e 100644 --- a/docs/samples/Microsoft.ML.Samples/Program.cs +++ b/docs/samples/Microsoft.ML.Samples/Program.cs @@ -6,7 +6,7 @@ internal static class Program { static void Main(string[] args) { - LdaTransformExample.LdaTransform(); + FeatureContributionCalculationTransform_RegressionExample.FeatureContributionCalculationTransform_Regression(); } } } diff --git a/docs/samples/Microsoft.ML.Samples/Static/AveragedPerceptronBinaryClassification.cs b/docs/samples/Microsoft.ML.Samples/Static/AveragedPerceptronBinaryClassification.cs index d2359526e9..176f55eff6 100644 --- a/docs/samples/Microsoft.ML.Samples/Static/AveragedPerceptronBinaryClassification.cs +++ b/docs/samples/Microsoft.ML.Samples/Static/AveragedPerceptronBinaryClassification.cs @@ -1,8 +1,6 @@ -using Microsoft.ML.Runtime.Data; +using System; +using 
Microsoft.ML.Data; using Microsoft.ML.StaticPipe; -using Microsoft.ML.Transforms; -using Microsoft.ML.Transforms.Categorical; -using System; namespace Microsoft.ML.Samples.Static { diff --git a/docs/samples/Microsoft.ML.Samples/Static/FastTreeBinaryClassification.cs b/docs/samples/Microsoft.ML.Samples/Static/FastTreeBinaryClassification.cs index d28b7c79de..55835aa0b3 100644 --- a/docs/samples/Microsoft.ML.Samples/Static/FastTreeBinaryClassification.cs +++ b/docs/samples/Microsoft.ML.Samples/Static/FastTreeBinaryClassification.cs @@ -1,9 +1,6 @@ -using Microsoft.ML.Runtime.Data; +using System; +using Microsoft.ML.Data; using Microsoft.ML.StaticPipe; -using Microsoft.ML.Transforms; -using Microsoft.ML.Transforms.Categorical; -using Microsoft.ML.Transforms.FeatureSelection; -using System; namespace Microsoft.ML.Samples.Static { diff --git a/docs/samples/Microsoft.ML.Samples/Static/FastTreeRegression.cs b/docs/samples/Microsoft.ML.Samples/Static/FastTreeRegression.cs index 66ddc6772c..10e3edb7ea 100644 --- a/docs/samples/Microsoft.ML.Samples/Static/FastTreeRegression.cs +++ b/docs/samples/Microsoft.ML.Samples/Static/FastTreeRegression.cs @@ -1,8 +1,8 @@ -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.StaticPipe; -using System; +using System; using System.Linq; +using Microsoft.ML.Data; +using Microsoft.ML.StaticPipe; +using Microsoft.ML.Trainers.FastTree; namespace Microsoft.ML.Samples.Static { @@ -30,7 +30,7 @@ public static void FastTreeRegression() var data = reader.Read(dataFile); // The predictor that gets produced out of training - FastTreeRegressionPredictor pred = null; + FastTreeRegressionModelParameters pred = null; // Create the estimator var learningPipeline = reader.MakeNewEstimator() diff --git a/docs/samples/Microsoft.ML.Samples/Static/FeatureSelectionTransform.cs b/docs/samples/Microsoft.ML.Samples/Static/FeatureSelectionTransform.cs index d38f428eea..6c130dfc4b 100644 --- 
a/docs/samples/Microsoft.ML.Samples/Static/FeatureSelectionTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Static/FeatureSelectionTransform.cs @@ -1,8 +1,7 @@ -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.StaticPipe; -using System; +using System; using System.Collections.Generic; +using Microsoft.ML.Data; +using Microsoft.ML.StaticPipe; namespace Microsoft.ML.Samples.Dynamic { diff --git a/docs/samples/Microsoft.ML.Samples/Static/LightGBMBinaryClassification.cs b/docs/samples/Microsoft.ML.Samples/Static/LightGBMBinaryClassification.cs index 53899d95fe..bee5be6c21 100644 --- a/docs/samples/Microsoft.ML.Samples/Static/LightGBMBinaryClassification.cs +++ b/docs/samples/Microsoft.ML.Samples/Static/LightGBMBinaryClassification.cs @@ -1,9 +1,7 @@ -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.StaticPipe; -using Microsoft.ML.Transforms; -using Microsoft.ML.Transforms.Categorical; -using Microsoft.ML.Transforms.FeatureSelection; using System; +using Microsoft.ML.Data; +using Microsoft.ML.LightGBM.StaticPipe; +using Microsoft.ML.StaticPipe; namespace Microsoft.ML.Samples.Static { diff --git a/docs/samples/Microsoft.ML.Samples/Static/LightGBMRegression.cs b/docs/samples/Microsoft.ML.Samples/Static/LightGBMRegression.cs index ca257d864f..45bf277157 100644 --- a/docs/samples/Microsoft.ML.Samples/Static/LightGBMRegression.cs +++ b/docs/samples/Microsoft.ML.Samples/Static/LightGBMRegression.cs @@ -1,7 +1,7 @@ -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.LightGBM; -using Microsoft.ML.StaticPipe; using System; +using Microsoft.ML.Data; +using Microsoft.ML.LightGBM; +using Microsoft.ML.LightGBM.StaticPipe; namespace Microsoft.ML.Samples.Static { @@ -30,7 +30,7 @@ public static void LightGbmRegression() var (trainData, testData) = mlContext.Regression.TrainTestSplit(data, testFraction: 0.1); // The predictor that gets produced out of training - LightGbmRegressionPredictor pred = null; + LightGbmRegressionModelParameters 
pred = null; // Create the estimator var learningPipeline = reader.MakeNewEstimator() diff --git a/docs/samples/Microsoft.ML.Samples/Static/SDCABinaryClassification.cs b/docs/samples/Microsoft.ML.Samples/Static/SDCABinaryClassification.cs index 886342c416..122057a96e 100644 --- a/docs/samples/Microsoft.ML.Samples/Static/SDCABinaryClassification.cs +++ b/docs/samples/Microsoft.ML.Samples/Static/SDCABinaryClassification.cs @@ -1,9 +1,6 @@ -using Microsoft.ML.Runtime.Data; +using System; +using Microsoft.ML.Data; using Microsoft.ML.StaticPipe; -using Microsoft.ML.Transforms; -using Microsoft.ML.Transforms.Categorical; -using Microsoft.ML.Transforms.FeatureSelection; -using System; namespace Microsoft.ML.Samples.Static { diff --git a/docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs b/docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs index 4d35a28257..93c680d950 100644 --- a/docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs +++ b/docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs @@ -1,7 +1,7 @@ -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Learners; -using Microsoft.ML.StaticPipe; using System; +using Microsoft.ML.Data; +using Microsoft.ML.Learners; +using Microsoft.ML.StaticPipe; namespace Microsoft.ML.Samples.Static { @@ -29,7 +29,7 @@ public static void SdcaRegression() var (trainData, testData) = mlContext.Regression.TrainTestSplit(data, testFraction: 0.1); // The predictor that gets produced out of training - LinearRegressionPredictor pred = null; + LinearRegressionModelParameters pred = null; // Create the estimator var learningPipeline = reader.MakeNewEstimator() diff --git a/init-tools.cmd b/init-tools.cmd index 3743cb413d..349f7b1461 100644 --- a/init-tools.cmd +++ b/init-tools.cmd @@ -46,15 +46,17 @@ if exist "%DotNetBuildToolsDir%" ( echo Running %0 > "%INIT_TOOLS_LOG%" set /p DOTNET_VERSION=< "%~dp0DotnetCLIVersion.txt" -if exist "%DOTNET_CMD%" goto :afterdotnetrestore :Arg_Loop if [%1] == [] goto :ArchSet 
-if /i [%1] == [x86] ( set ARCH=x86&&goto ArchSet) +if /i [%1] == [x86] ( set ARCH=x86) +if /i [%1] == [-Debug-Intrinsics] ( set /p DOTNET_VERSION=< "%~dp0DotnetCLIVersion.netcoreapp.latest.txt") +if /i [%1] == [-Release-Intrinsics] ( set /p DOTNET_VERSION=< "%~dp0DotnetCLIVersion.netcoreapp.latest.txt") shift goto :Arg_Loop :ArchSet +if exist "%DOTNET_CMD%" goto :afterdotnetrestore echo Installing dotnet cli... if NOT exist "%DOTNET_PATH%" mkdir "%DOTNET_PATH%" diff --git a/pkg/Microsoft.ML.EntryPoints/Microsoft.ML.EntryPoints.nupkgproj b/pkg/Microsoft.ML.EntryPoints/Microsoft.ML.EntryPoints.nupkgproj new file mode 100644 index 0000000000..b845fdeb45 --- /dev/null +++ b/pkg/Microsoft.ML.EntryPoints/Microsoft.ML.EntryPoints.nupkgproj @@ -0,0 +1,12 @@ + + + + netstandard2.0 + Microsoft.ML.EntryPoints contains the ML.NET entry point API catalog. + + + + + + + diff --git a/pkg/Microsoft.ML.EntryPoints/Microsoft.ML.EntryPoints.symbols.nupkgproj b/pkg/Microsoft.ML.EntryPoints/Microsoft.ML.EntryPoints.symbols.nupkgproj new file mode 100644 index 0000000000..3fa0255960 --- /dev/null +++ b/pkg/Microsoft.ML.EntryPoints/Microsoft.ML.EntryPoints.symbols.nupkgproj @@ -0,0 +1,5 @@ + + + + + diff --git a/pkg/Microsoft.ML.OnnxTransform/Microsoft.ML.OnnxTransform.nupkgproj b/pkg/Microsoft.ML.OnnxTransform/Microsoft.ML.OnnxTransform.nupkgproj index 4d64d756fe..b817e809d1 100644 --- a/pkg/Microsoft.ML.OnnxTransform/Microsoft.ML.OnnxTransform.nupkgproj +++ b/pkg/Microsoft.ML.OnnxTransform/Microsoft.ML.OnnxTransform.nupkgproj @@ -7,7 +7,7 @@ - + diff --git a/pkg/Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj b/pkg/Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj index 595b6c1b7d..42d3387355 100644 --- a/pkg/Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj +++ b/pkg/Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj @@ -6,7 +6,7 @@ $(MSBuildProjectName) contains the TensorFlow C library version 
$(TensorFlowVersion) redistributed as a NuGet package. https://github.com/tensorflow/tensorflow/blob/master/LICENSE true - Copyright 2018 The TensorFlow Authors. All rights reserved. + Copyright 2018 The TensorFlow Authors. All rights reserved. https://www.tensorflow.org https://github.com/tensorflow/tensorflow/releases/tag/v$(TensorFlowVersion) $(PackageTags) TensorFlow diff --git a/src/Directory.Build.props b/src/Directory.Build.props index 1ebabb9213..be6d1d53bc 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -32,4 +32,10 @@ Include="StyleCop.Analyzers" Version="1.1.0-beta008" PrivateAssets="All" /> + + + stylecop.json + + + diff --git a/src/Microsoft.ML.Api/CodeGenerationUtils.cs b/src/Microsoft.ML.Api/CodeGenerationUtils.cs deleted file mode 100644 index 1c39f1b9dd..0000000000 --- a/src/Microsoft.ML.Api/CodeGenerationUtils.cs +++ /dev/null @@ -1,142 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System.Collections.Generic; -using System.Text; -using System.Text.RegularExpressions; -using System.CodeDom; -using Microsoft.ML.Runtime.Data; -using Microsoft.CSharp; -using System.IO; - -namespace Microsoft.ML.Runtime.Api -{ - /// - /// Utility methods for code generation. - /// - internal static class CodeGenerationUtils - { - /// - /// Replace placeholders with provided values. Assert that every placeholder is found. 
- /// - public static string MultiReplace(string text, Dictionary replacementMap) - { - Contracts.AssertValue(text); - Contracts.AssertValue(replacementMap); - var pattern = @"\/\*#(.*)#\*\/.*\/\*#\/\1#\*\/[\n\r]{0,2}"; - - int seenTags = 0; - var result = Regex.Replace(text, pattern, - match => - { - var tag = match.Groups[1].Value; - string replacement; - bool found = replacementMap.TryGetValue(tag, out replacement); - Contracts.Assert(found); - seenTags++; - return replacement; - }, RegexOptions.Singleline); - - Contracts.Assert(seenTags == replacementMap.Count); - return result; - } - - /// - /// Append a field declaration to the provided . - /// - public static void AppendFieldDeclaration(CSharpCodeProvider codeProvider, StringBuilder target, int columnIndex, - string fieldName, ColumnType colType, bool appendInitializer, bool useVBuffer) - { - Contracts.AssertValueOrNull(codeProvider); - Contracts.AssertValue(target); - Contracts.Assert(columnIndex >= 0); - Contracts.AssertNonEmpty(fieldName); - Contracts.AssertValue(colType); - - var attributes = new List(); - string generatedCsTypeName = GetBackingTypeName(colType, useVBuffer, attributes); - - if (codeProvider != null && !codeProvider.IsValidIdentifier(fieldName)) - { - attributes.Add(string.Format("[ColumnName({0})]", GetCSharpString(codeProvider, fieldName))); - fieldName = string.Format("Column{0}", columnIndex); - } - - const string indent = " "; - if (attributes.Count > 0) - { - foreach (var attr in attributes) - { - target.Append(indent); - target.AppendLine(attr); - } - } - target.Append(indent); - target.AppendFormat("public {0} {1}", generatedCsTypeName, fieldName); - - if (appendInitializer && colType is VectorType vecColType && vecColType.Size > 0 && !useVBuffer) - { - Contracts.Assert(generatedCsTypeName.EndsWith("[]")); - var csItemType = generatedCsTypeName.Substring(0, generatedCsTypeName.Length - 2); - target.AppendFormat(" = new {0}[{1}]", csItemType, vecColType.Size); - } - 
target.AppendLine(";"); - } - - /// - /// Generates a C# string for a given input (with proper escaping). - /// - public static string GetCSharpString(CSharpCodeProvider codeProvider, string value) - { - using (var writer = new StringWriter()) - { - codeProvider.GenerateCodeFromExpression(new CodePrimitiveExpression(value), writer, null); - return writer.ToString(); - } - } - - /// - /// Gets the C# strings representing the type name for a variable corresponding to - /// the column type. - /// - /// If the type is a vector, then controls whether the array field is - /// generated or . - /// - /// If additional attributes are required, they are appended to the list. - /// - private static string GetBackingTypeName(ColumnType colType, bool useVBuffer, List attributes) - { - Contracts.AssertValue(colType); - Contracts.AssertValue(attributes); - if (colType is VectorType vecColType) - { - if (vecColType.Size > 0) - { - // By default, arrays are assumed variable length, unless a [VectorType(dim1, dim2, ...)] - // attribute is applied to the fields. - attributes.Add(string.Format("[VectorType({0})]", string.Join(", ", vecColType.Dimensions))); - } - - var itemType = GetBackingTypeName(colType.ItemType, false, attributes); - return useVBuffer ? string.Format("VBuffer<{0}>", itemType) : string.Format("{0}[]", itemType); - } - - if (colType.IsText) - return "string"; - - if (colType.IsKey) - { - // The way to define a key type in C# is by specifying the [KeyType] attribute and - // making the field type equal to the underlying raw type. - var key = colType.AsKey; - attributes.Add(string.Format("[KeyType(Count={0}, Min={1}, Contiguous={2})]", - key.Count, - key.Min, - key.Contiguous ? 
"true" : "false")); - } - - return colType.AsPrimitive.RawType.Name; - } - } -} diff --git a/src/Microsoft.ML.Api/GenerateCodeCommand.cs b/src/Microsoft.ML.Api/GenerateCodeCommand.cs deleted file mode 100644 index 4855a94219..0000000000 --- a/src/Microsoft.ML.Api/GenerateCodeCommand.cs +++ /dev/null @@ -1,152 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System.Collections.Generic; -using System.IO; -using System.Reflection; -using System.Text; -using Microsoft.CSharp; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Model; - -[assembly: LoadableClass(typeof(GenerateCodeCommand), typeof(GenerateCodeCommand.Arguments), typeof(SignatureCommand), - "Generate Sample Prediction Code", GenerateCodeCommand.LoadName, "codegen")] - -namespace Microsoft.ML.Runtime.Api -{ - /// - /// Generates the sample prediction code for a given model file, with correct input and output classes. - /// - /// REVIEW: Consider adding support for generating VBuffers instead of arrays, maybe for high dimensionality vectors. - /// - internal sealed class GenerateCodeCommand : ICommand - { - public const string LoadName = "GenerateSamplePredictionCode"; - private const string CodeTemplatePath = "Microsoft.ML.Api.GeneratedCodeTemplate.csresource"; - -#pragma warning disable 0649 // The command is internal, suppress a warning about fields never assigned to. 
- public sealed class Arguments - { - [Argument(ArgumentType.Required, HelpText = "Input model file", ShortName = "in", IsInputFileName = true)] - public string InputModelFile; - - [Argument(ArgumentType.Required, HelpText = "File to output generated C# code", ShortName = "cs")] - public string CSharpOutput; - - /// - /// Whether to use the to represent vector columns (supports sparse vectors). - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use the VBuffer to represent vector columns (supports sparse vectors)", - ShortName = "sparse", SortOrder = 102)] - public bool SparseVectorDeclaration; - - // REVIEW: currently, it's only used in unit testing to not generate the paths into the test output folder. - // However, it might be handy for automation scenarios, so I've added this as a hidden option. - [Argument(ArgumentType.AtMostOnce, HelpText = "A location of the model file to put into generated file", Hide = true)] - public string ModelNameOverride; - } -#pragma warning restore 0649 - - private readonly IHost _host; - private readonly Arguments _args; - - public GenerateCodeCommand(IHostEnvironment env, Arguments args) - { - Contracts.CheckValue(env, nameof(env)); - _host = env.Register("GenerateCodeCommand"); - _host.CheckValue(args, nameof(args)); - _host.CheckUserArg(!string.IsNullOrWhiteSpace(args.InputModelFile), - nameof(args.InputModelFile), "input model file is required"); - _host.CheckUserArg(!string.IsNullOrWhiteSpace(args.CSharpOutput), - nameof(args.CSharpOutput), "Output file is required"); - _args = args; - } - - public void Run() - { - string template; - using (var stream = Assembly.GetExecutingAssembly().GetManifestResourceStream(CodeTemplatePath)) - using (var reader = new StreamReader(stream)) - template = reader.ReadToEnd(); - - var codeProvider = new CSharpCodeProvider(); - using (var fs = File.OpenRead(_args.InputModelFile)) - { - var transformPipe = ModelFileUtils.LoadPipeline(_host, fs, new MultiFileSource(null), true); - 
var pred = _host.LoadPredictorOrNull(fs); - - IDataView root; - for (root = transformPipe; root is IDataTransform; root = ((IDataTransform)root).Source) - ; - - // root is now the loader. - _host.Assert(root is IDataLoader); - - // Loader columns. - var loaderSb = new StringBuilder(); - for (int i = 0; i < root.Schema.ColumnCount; i++) - { - if (root.Schema.IsHidden(i)) - continue; - if (loaderSb.Length > 0) - loaderSb.AppendLine(); - - ColumnType colType = root.Schema.GetColumnType(i); - CodeGenerationUtils.AppendFieldDeclaration(codeProvider, loaderSb, i, root.Schema.GetColumnName(i), colType, true, _args.SparseVectorDeclaration); - } - - // Scored example columns. - IDataView scorer; - if (pred == null) - scorer = transformPipe; - else - { - var roles = ModelFileUtils.LoadRoleMappingsOrNull(_host, fs); - scorer = roles != null - ? _host.CreateDefaultScorer(new RoleMappedData(transformPipe, roles, opt: true), pred) - : _host.CreateDefaultScorer(new RoleMappedData(transformPipe, label: null, "Features"), pred); - } - - var nonScoreSb = new StringBuilder(); - var scoreSb = new StringBuilder(); - for (int i = 0; i < scorer.Schema.ColumnCount; i++) - { - if (scorer.Schema.IsHidden(i)) - continue; - bool isScoreColumn = scorer.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.ScoreColumnSetId, i) != null; - - var sb = isScoreColumn ? scoreSb : nonScoreSb; - - if (sb.Length > 0) - sb.AppendLine(); - - ColumnType colType = scorer.Schema.GetColumnType(i); - CodeGenerationUtils.AppendFieldDeclaration(codeProvider, sb, i, scorer.Schema.GetColumnName(i), colType, false, _args.SparseVectorDeclaration); - } - - // Turn model path into a C# identifier and insert it. - var modelPath = !string.IsNullOrWhiteSpace(_args.ModelNameOverride) ? _args.ModelNameOverride : _args.InputModelFile; - modelPath = CodeGenerationUtils.GetCSharpString(codeProvider, modelPath); - modelPath = string.Format("modelPath = {0};", modelPath); - - // Replace values inside the template. 
- var replacementMap = - new Dictionary - { - { "EXAMPLE_CLASS_DECL", loaderSb.ToString() }, - { "SCORED_EXAMPLE_CLASS_DECL", nonScoreSb.ToString() }, - { "SCORE_CLASS_DECL", scoreSb.ToString() }, - { "MODEL_PATH", modelPath } - }; - - var classSource = CodeGenerationUtils.MultiReplace(template, replacementMap); - File.WriteAllText(_args.CSharpOutput, classSource); - } - } - } -} diff --git a/src/Microsoft.ML.Api/Microsoft.ML.Api.csproj b/src/Microsoft.ML.Api/Microsoft.ML.Api.csproj deleted file mode 100644 index 88324dca5e..0000000000 --- a/src/Microsoft.ML.Api/Microsoft.ML.Api.csproj +++ /dev/null @@ -1,26 +0,0 @@ - - - - netstandard2.0 - Microsoft.ML - - - - - - - - - - - - - - - - - - - - - diff --git a/src/Microsoft.ML.Api/PredictionFunction.cs b/src/Microsoft.ML.Api/PredictionFunction.cs deleted file mode 100644 index 7243301f8d..0000000000 --- a/src/Microsoft.ML.Api/PredictionFunction.cs +++ /dev/null @@ -1,64 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime.Api; - -namespace Microsoft.ML.Runtime.Data -{ - /// - /// A prediction engine class, that takes instances of through - /// the transformer pipeline and produces instances of as outputs. - /// - public sealed class PredictionFunction - where TSrc : class - where TDst : class, new() - { - private readonly PredictionEngine _engine; - - /// - /// Create an instance of . - /// - /// The host environment. - /// The model (transformer) to use for prediction. - public PredictionFunction(IHostEnvironment env, ITransformer transformer) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(transformer, nameof(transformer)); - - IDataView dv = env.CreateDataView(new TSrc[0]); - _engine = env.CreatePredictionEngine(transformer); - } - - /// - /// Perform one prediction using the model. 
- /// - /// The object that holds values to predict from. - /// The object populated with prediction results. - public TDst Predict(TSrc example) => _engine.Predict(example); - - /// - /// Perform one prediction using the model. - /// Reuses the provided prediction object, which is more efficient in high-load scenarios. - /// - /// The object that holds values to predict from. - /// The object to store the predictions in. If it's null, a new object is created, - /// otherwise the provided object is used. - public void Predict(TSrc example, ref TDst prediction) => _engine.Predict(example, ref prediction); - } - - public static class PredictionFunctionExtensions - { - /// - /// Create an instance of the 'prediction function', or 'prediction machine', from a model - /// denoted by . - /// It will be accepting instances of as input, and produce - /// instances of as output. - /// - public static PredictionFunction MakePredictionFunction(this ITransformer transformer, IHostEnvironment env) - where TSrc : class - where TDst : class, new() - => new PredictionFunction(env, transformer); - } -} diff --git a/src/Microsoft.ML.Api/Predictor.cs b/src/Microsoft.ML.Api/Predictor.cs deleted file mode 100644 index 137c58bfb8..0000000000 --- a/src/Microsoft.ML.Api/Predictor.cs +++ /dev/null @@ -1,40 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; - -namespace Microsoft.ML.Runtime.Api -{ - /// - /// An opaque 'holder' of the predictor, meant to insulate the user from the internal TLC predictor structure, - /// which is subject to change. - /// - public sealed class Predictor - { - /// - /// The actual predictor. 
- /// - internal readonly IPredictor Pred; - - internal Predictor(IPredictor pred) - { - Contracts.AssertValue(pred); - Pred = pred; - } - - /// - /// A way for the user to extract the predictor object and 'delve into the underworld' of unsupported non-API methods. - /// This is needed, for instance, to inspect the weights of a predictor programmatically. - /// The intention is to expose most methods through the API and make usage of this method increasingly unnecessary. - /// - [Obsolete("Welcome adventurous stranger, to the Underdark! By calling the mysterious GetPredictorObject method,"+ - " you have entered a world of shifting realities, where nothing is as it seems. Your code may work today, but"+ - " the churning impermanence of the Underdark means the strong foothold today may be nothing but empty air"+ - " tomorrow. Brace yourself!")] - public object GetPredictorObject() - { - return Pred; - } - } -} diff --git a/src/Microsoft.ML.Console/Console.cs b/src/Microsoft.ML.Console/Console.cs index 152d65951a..549f222de7 100644 --- a/src/Microsoft.ML.Console/Console.cs +++ b/src/Microsoft.ML.Console/Console.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-namespace Microsoft.ML.Runtime.Tools.Console +namespace Microsoft.ML.Tools.Console { public static class Console { diff --git a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj index f7c87c0abd..7fadb9ea75 100644 --- a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj +++ b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj @@ -4,7 +4,7 @@ netcoreapp2.1 Exe MML - Microsoft.ML.Runtime.Tools.Console.Console + Microsoft.ML.Tools.Console.Console diff --git a/src/Microsoft.ML.Core/BestFriendAttribute.cs b/src/Microsoft.ML.Core/BestFriendAttribute.cs index 1470f95c34..19c70922e7 100644 --- a/src/Microsoft.ML.Core/BestFriendAttribute.cs +++ b/src/Microsoft.ML.Core/BestFriendAttribute.cs @@ -6,7 +6,7 @@ #if CPUMATH_INFRASTRUCTURE // CpuMath has its own BestFriend and WantsToBeBestFriends attributes for making itself a standalone module -namespace Microsoft.ML.Runtime.Internal.CpuMath.Core +namespace Microsoft.ML.Internal.CpuMath.Core #else // This namespace contains the BestFriend and WantsToBeBestFriends attributes generally used in ML.NET project settings namespace Microsoft.ML diff --git a/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs b/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs index 70f9ec8d98..fdc296d7f8 100644 --- a/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs +++ b/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs @@ -5,7 +5,7 @@ using System; using System.Linq; -namespace Microsoft.ML.Runtime.CommandLine +namespace Microsoft.ML.CommandLine { /// /// Allows control of command line parsing. 
diff --git a/src/Microsoft.ML.Core/CommandLine/ArgumentType.cs b/src/Microsoft.ML.Core/CommandLine/ArgumentType.cs index 5840615fd5..d27acf9387 100644 --- a/src/Microsoft.ML.Core/CommandLine/ArgumentType.cs +++ b/src/Microsoft.ML.Core/CommandLine/ArgumentType.cs @@ -4,7 +4,7 @@ using System; -namespace Microsoft.ML.Runtime.CommandLine +namespace Microsoft.ML.CommandLine { /// /// Used to control parsing of command line arguments. diff --git a/src/Microsoft.ML.Core/CommandLine/CharCursor.cs b/src/Microsoft.ML.Core/CommandLine/CharCursor.cs index d8a591331c..85a60a7f36 100644 --- a/src/Microsoft.ML.Core/CommandLine/CharCursor.cs +++ b/src/Microsoft.ML.Core/CommandLine/CharCursor.cs @@ -2,9 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.CommandLine +namespace Microsoft.ML.CommandLine { internal sealed class CharCursor { diff --git a/src/Microsoft.ML.Core/CommandLine/CmdLexer.cs b/src/Microsoft.ML.Core/CommandLine/CmdLexer.cs index a5c14259a8..ce4d6ef4d3 100644 --- a/src/Microsoft.ML.Core/CommandLine/CmdLexer.cs +++ b/src/Microsoft.ML.Core/CommandLine/CmdLexer.cs @@ -4,7 +4,7 @@ using System.Text; -namespace Microsoft.ML.Runtime.CommandLine +namespace Microsoft.ML.CommandLine { [BestFriend] internal sealed class CmdLexer @@ -302,7 +302,8 @@ public static string UnquoteValue(string str) } } - public sealed class CmdQuoter + [BestFriend] + internal sealed class CmdQuoter { private readonly string _str; private StringBuilder _sb; diff --git a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs index 81018dad3e..cecb73fb4b 100644 --- a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs +++ b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs @@ -9,12 +9,10 @@ using System.IO; using System.Linq; using System.Reflection; 
-using System.Runtime.InteropServices; using System.Text; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.CommandLine +namespace Microsoft.ML.CommandLine { /// @@ -537,16 +535,10 @@ private static ArgumentInfo GetArgumentInfo(Type type, object defaults) private static ArgumentAttribute GetAttribute(FieldInfo field) { - var attrs = field.GetCustomAttributes(typeof(ArgumentAttribute), false).ToArray(); - if (attrs.Length == 1) - { - var argumentAttribute = (ArgumentAttribute)attrs[0]; - if (argumentAttribute.Visibility == ArgumentAttribute.VisibilityType.EntryPointsOnly) - return null; - return argumentAttribute; - } - Contracts.Assert(attrs.Length == 0); - return null; + var argumentAttribute = field.GetCustomAttribute(false); + if (argumentAttribute?.Visibility == ArgumentAttribute.VisibilityType.EntryPointsOnly) + return null; + return argumentAttribute; } private void ReportUnrecognizedArgument(string argument) diff --git a/src/Microsoft.ML.Core/CommandLine/DefaultArgumentAttribute.cs b/src/Microsoft.ML.Core/CommandLine/DefaultArgumentAttribute.cs index 2d676f1ece..12121df9ea 100644 --- a/src/Microsoft.ML.Core/CommandLine/DefaultArgumentAttribute.cs +++ b/src/Microsoft.ML.Core/CommandLine/DefaultArgumentAttribute.cs @@ -4,7 +4,7 @@ using System; -namespace Microsoft.ML.Runtime.CommandLine +namespace Microsoft.ML.CommandLine { /// /// Indicates that this argument is the default argument. 
diff --git a/src/Microsoft.ML.Core/CommandLine/EnumValueDisplayAttribute.cs b/src/Microsoft.ML.Core/CommandLine/EnumValueDisplayAttribute.cs index b6cf4254ca..9b3652d9b4 100644 --- a/src/Microsoft.ML.Core/CommandLine/EnumValueDisplayAttribute.cs +++ b/src/Microsoft.ML.Core/CommandLine/EnumValueDisplayAttribute.cs @@ -4,7 +4,7 @@ using System; -namespace Microsoft.ML.Runtime.CommandLine +namespace Microsoft.ML.CommandLine { /// /// On an enum value - specifies the display name. diff --git a/src/Microsoft.ML.Core/CommandLine/HideEnumValueAttribute.cs b/src/Microsoft.ML.Core/CommandLine/HideEnumValueAttribute.cs index 964a5cc3f3..078a8abfa8 100644 --- a/src/Microsoft.ML.Core/CommandLine/HideEnumValueAttribute.cs +++ b/src/Microsoft.ML.Core/CommandLine/HideEnumValueAttribute.cs @@ -4,7 +4,7 @@ using System; -namespace Microsoft.ML.Runtime.CommandLine +namespace Microsoft.ML.CommandLine { /// /// On an enum value - indicates that the value should not be shown in help or UI. diff --git a/src/Microsoft.ML.Core/CommandLine/SpecialPurpose.cs b/src/Microsoft.ML.Core/CommandLine/SpecialPurpose.cs index 46423d43d9..491ef4b21c 100644 --- a/src/Microsoft.ML.Core/CommandLine/SpecialPurpose.cs +++ b/src/Microsoft.ML.Core/CommandLine/SpecialPurpose.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Runtime.CommandLine +namespace Microsoft.ML.CommandLine { [BestFriend] internal static class SpecialPurpose diff --git a/src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs b/src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs index e947776bc9..a194088d3b 100644 --- a/src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs +++ b/src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs @@ -2,13 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. 
// See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.IO; using System.IO.Compression; using System.Reflection; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { [Obsolete("The usage for this is intended for the internal command line utilities and is not intended for anything related to the API. " + "Please consider another way of doing whatever it is you're attempting to accomplish.")] @@ -159,7 +159,7 @@ private static bool ShouldSkipPath(string path) case "neuraltreeevaluator.dll": case "optimizationbuilderdotnet.dll": case "parallelcommunicator.dll": - case "microsoft.ml.runtime.runtests.dll": + case "Microsoft.ML.runtests.dll": case "scopecompiler.dll": case "symsgdnative.dll": case "tbb.dll": diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs index d0ab610a60..0b50c559a7 100644 --- a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs +++ b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs @@ -2,17 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Collections.Generic; using System.Linq; using System.Reflection; using System.Text.RegularExpressions; +using Microsoft.ML.CommandLine; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; // REVIEW: Determine ideal namespace. -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { /// /// This catalogs instantiatable components (aka, loadable classes). Components are registered via @@ -40,7 +40,8 @@ internal ComponentCatalog() /// /// Provides information on an instantiatable component, aka, loadable class. 
/// - public sealed class LoadableClassInfo + [BestFriend] + internal sealed class LoadableClassInfo { /// /// Used for dictionary lookup based on signature and name. @@ -264,7 +265,8 @@ public object CreateArguments() /// /// A description of a single entry point. /// - public sealed class EntryPointInfo + [BestFriend] + internal sealed class EntryPointInfo { public readonly string Name; public readonly string Description; @@ -333,7 +335,8 @@ private Type[] FindEntryPointKinds(Type type) /// The 'component' is a non-standalone building block that is used to parametrize entry points or other ML.NET components. /// For example, 'Loss function', or 'similarity calculator' could be components. /// - public sealed class ComponentInfo + [BestFriend] + internal sealed class ComponentInfo { public readonly string Name; public readonly string Description; @@ -616,7 +619,8 @@ public void RegisterAssembly(Assembly assembly, bool throwOnError = true) /// Return an array containing information for all instantiatable components. /// If provided, the given set of assemblies is loaded first. /// - public LoadableClassInfo[] GetAllClasses() + [BestFriend] + internal LoadableClassInfo[] GetAllClasses() { return _classes.ToArray(); } @@ -625,7 +629,8 @@ public LoadableClassInfo[] GetAllClasses() /// Return an array containing information for instantiatable components with the given /// signature and base type. If provided, the given set of assemblies is loaded first. /// - public LoadableClassInfo[] GetAllDerivedClasses(Type typeBase, Type typeSig) + [BestFriend] + internal LoadableClassInfo[] GetAllDerivedClasses(Type typeBase, Type typeSig) { Contracts.CheckValue(typeBase, nameof(typeBase)); Contracts.CheckValueOrNull(typeSig); @@ -643,7 +648,8 @@ public LoadableClassInfo[] GetAllDerivedClasses(Type typeBase, Type typeSig) /// Return an array containing all the known signature types. If provided, the given set of assemblies /// is loaded first. 
/// - public Type[] GetAllSignatureTypes() + [BestFriend] + internal Type[] GetAllSignatureTypes() { return _signatures.Select(kvp => kvp.Key).ToArray(); } @@ -651,7 +657,8 @@ public Type[] GetAllSignatureTypes() /// /// Returns a string name for a given signature type. /// - public static string SignatureToString(Type sig) + [BestFriend] + internal static string SignatureToString(Type sig) { Contracts.CheckValue(sig, nameof(sig)); Contracts.CheckParam(sig.BaseType == typeof(MulticastDelegate), nameof(sig), "Must be a delegate type"); @@ -670,7 +677,8 @@ private LoadableClassInfo FindClassCore(LoadableClassInfo.Key key) return null; } - public LoadableClassInfo[] FindLoadableClasses(string name) + [BestFriend] + internal LoadableClassInfo[] FindLoadableClasses(string name) { name = name.ToLowerInvariant().Trim(); @@ -680,14 +688,16 @@ public LoadableClassInfo[] FindLoadableClasses(string name) return res; } - public LoadableClassInfo[] FindLoadableClasses() + [BestFriend] + internal LoadableClassInfo[] FindLoadableClasses() { return _classes .Where(ci => ci.SignatureTypes.Contains(typeof(TSig))) .ToArray(); } - public LoadableClassInfo[] FindLoadableClasses() + [BestFriend] + internal LoadableClassInfo[] FindLoadableClasses() { // REVIEW: this and above methods perform a linear search over all the loadable classes. // On 6/15/2015, TLC release build contained 431 of them, so adding extra lookups looks unnecessary at this time. 
@@ -696,12 +706,14 @@ public LoadableClassInfo[] FindLoadableClasses() .ToArray(); } - public LoadableClassInfo GetLoadableClassInfo(string loadName) + [BestFriend] + internal LoadableClassInfo GetLoadableClassInfo(string loadName) { return GetLoadableClassInfo(loadName, typeof(TSig)); } - public LoadableClassInfo GetLoadableClassInfo(string loadName, Type signatureType) + [BestFriend] + internal LoadableClassInfo GetLoadableClassInfo(string loadName, Type signatureType) { Contracts.CheckParam(signatureType.BaseType == typeof(MulticastDelegate), nameof(signatureType), "signatureType must be a delegate type"); Contracts.CheckValueOrNull(loadName); @@ -712,18 +724,21 @@ public LoadableClassInfo GetLoadableClassInfo(string loadName, Type signatureTyp /// /// Get all registered entry points. /// - public IEnumerable AllEntryPoints() + [BestFriend] + internal IEnumerable AllEntryPoints() { return _entryPoints.AsEnumerable(); } - public bool TryFindEntryPoint(string name, out EntryPointInfo entryPoint) + [BestFriend] + internal bool TryFindEntryPoint(string name, out EntryPointInfo entryPoint) { Contracts.CheckNonEmpty(name, nameof(name)); return _entryPointMap.TryGetValue(name, out entryPoint); } - public bool TryFindComponent(string kind, string alias, out ComponentInfo component) + [BestFriend] + internal bool TryFindComponent(string kind, string alias, out ComponentInfo component) { Contracts.CheckNonEmpty(kind, nameof(kind)); Contracts.CheckNonEmpty(alias, nameof(alias)); @@ -733,7 +748,8 @@ public bool TryFindComponent(string kind, string alias, out ComponentInfo compon return _componentMap.TryGetValue($"{kind}:{alias}", out component); } - public bool TryFindComponent(Type argumentType, out ComponentInfo component) + [BestFriend] + internal bool TryFindComponent(Type argumentType, out ComponentInfo component) { Contracts.CheckValue(argumentType, nameof(argumentType)); @@ -741,7 +757,8 @@ public bool TryFindComponent(Type argumentType, out ComponentInfo component) 
return component != null; } - public bool TryFindComponent(Type interfaceType, Type argumentType, out ComponentInfo component) + [BestFriend] + internal bool TryFindComponent(Type interfaceType, Type argumentType, out ComponentInfo component) { Contracts.CheckValue(interfaceType, nameof(interfaceType)); Contracts.CheckParam(interfaceType.IsInterface, nameof(interfaceType), "Must be interface"); @@ -751,7 +768,8 @@ public bool TryFindComponent(Type interfaceType, Type argumentType, out Componen return component != null; } - public bool TryFindComponent(Type interfaceType, string alias, out ComponentInfo component) + [BestFriend] + internal bool TryFindComponent(Type interfaceType, string alias, out ComponentInfo component) { Contracts.CheckValue(interfaceType, nameof(interfaceType)); Contracts.CheckParam(interfaceType.IsInterface, nameof(interfaceType), "Must be interface"); @@ -764,7 +782,8 @@ public bool TryFindComponent(Type interfaceType, string alias, out ComponentInfo /// Akin to , except if the regular (case sensitive) comparison fails, it will /// attempt to back off to a case-insensitive comparison. /// - public bool TryFindComponentCaseInsensitive(Type interfaceType, string alias, out ComponentInfo component) + [BestFriend] + internal bool TryFindComponentCaseInsensitive(Type interfaceType, string alias, out ComponentInfo component) { Contracts.CheckValue(interfaceType, nameof(interfaceType)); Contracts.CheckParam(interfaceType.IsInterface, nameof(interfaceType), "Must be interface"); @@ -786,7 +805,8 @@ private static bool AnyMatch(string name, string[] aliases) /// /// Returns all valid component kinds. /// - public IEnumerable GetAllComponentKinds() + [BestFriend] + internal IEnumerable GetAllComponentKinds() { return _components.Select(x => x.Kind).Distinct().OrderBy(x => x); } @@ -794,7 +814,8 @@ public IEnumerable GetAllComponentKinds() /// /// Returns all components of the specified kind. 
/// - public IEnumerable GetAllComponents(string kind) + [BestFriend] + internal IEnumerable GetAllComponents(string kind) { Contracts.CheckNonEmpty(kind, nameof(kind)); Contracts.CheckParam(IsValidName(kind), nameof(kind), "Invalid component kind"); @@ -804,13 +825,15 @@ public IEnumerable GetAllComponents(string kind) /// /// Returns all components that implement the specified interface. /// - public IEnumerable GetAllComponents(Type interfaceType) + [BestFriend] + internal IEnumerable GetAllComponents(Type interfaceType) { Contracts.CheckValue(interfaceType, nameof(interfaceType)); return _components.Where(x => x.InterfaceType == interfaceType).OrderBy(x => x.Name); } - public bool TryGetComponentKind(Type signatureType, out string kind) + [BestFriend] + internal bool TryGetComponentKind(Type signatureType, out string kind) { Contracts.CheckValue(signatureType, nameof(signatureType)); // REVIEW: replace with a dictionary lookup. @@ -821,7 +844,8 @@ public bool TryGetComponentKind(Type signatureType, out string kind) return faceAttr != null; } - public bool TryGetComponentShortName(Type type, out string name) + [BestFriend] + internal bool TryGetComponentShortName(Type type, out string name) { ComponentInfo component; if (!TryFindComponent(type, out component)) @@ -850,7 +874,8 @@ private static bool IsValidName(string name) /// /// Create an instance of the indicated component with the given extra parameters. /// - public static TRes CreateInstance(IHostEnvironment env, Type signatureType, string name, string options, params object[] extra) + [BestFriend] + internal static TRes CreateInstance(IHostEnvironment env, Type signatureType, string name, string options, params object[] extra) where TRes : class { TRes result; @@ -863,13 +888,15 @@ public static TRes CreateInstance(IHostEnvironment env, Type signatureType /// Try to create an instance of the indicated component and settings with the given extra parameters. 
/// If there is no such component in the catalog, returns false. Any other error results in an exception. /// - public static bool TryCreateInstance(IHostEnvironment env, out TRes result, string name, string options, params object[] extra) + [BestFriend] + internal static bool TryCreateInstance(IHostEnvironment env, out TRes result, string name, string options, params object[] extra) where TRes : class { return TryCreateInstance(env, typeof(TSig), out result, name, options, extra); } - private static bool TryCreateInstance(IHostEnvironment env, Type signatureType, out TRes result, string name, string options, params object[] extra) + [BestFriend] + internal static bool TryCreateInstance(IHostEnvironment env, Type signatureType, out TRes result, string name, string options, params object[] extra) where TRes : class { Contracts.CheckValue(env, nameof(env)); diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentFactory.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentFactory.cs index 58ea3c6baf..93ebdf6397 100644 --- a/src/Microsoft.ML.Core/ComponentModel/ComponentFactory.cs +++ b/src/Microsoft.ML.Core/ComponentModel/ComponentFactory.cs @@ -4,7 +4,7 @@ using System; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { /// /// This is a token interface that all component factories must implement. @@ -48,7 +48,8 @@ public interface IComponentFactory /// /// A utility class for creating instances. 
/// - public static class ComponentFactoryUtils + [BestFriend] + internal static class ComponentFactoryUtils { /// /// Creates a component factory with no extra parameters (other than an ) diff --git a/src/Microsoft.ML.Core/ComponentModel/LoadableClassAttribute.cs b/src/Microsoft.ML.Core/ComponentModel/LoadableClassAttribute.cs index 328a48be2b..7e8ea83e73 100644 --- a/src/Microsoft.ML.Core/ComponentModel/LoadableClassAttribute.cs +++ b/src/Microsoft.ML.Core/ComponentModel/LoadableClassAttribute.cs @@ -5,17 +5,19 @@ using System; using System.Linq; using System.Reflection; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { /// /// Common signature type with no extra parameters. /// - public delegate void SignatureDefault(); + [BestFriend] + internal delegate void SignatureDefault(); [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] - public sealed class LoadableClassAttribute : LoadableClassAttributeBase + [BestFriend] + internal sealed class LoadableClassAttribute : LoadableClassAttributeBase { /// /// Assembly attribute used to specify that a class is loadable by a machine learning @@ -98,7 +100,7 @@ public LoadableClassAttribute(string summary, Type instType, Type loaderType, Ty } } - public abstract class LoadableClassAttributeBase : Attribute + internal abstract class LoadableClassAttributeBase : Attribute { // Note: these properties have private setters to make attribute parsing easier - the values // are all guaranteed to be in the ConstructorArguments of the CustomAttributeData diff --git a/src/Microsoft.ML.Core/Data/ColumnType.cs b/src/Microsoft.ML.Core/Data/ColumnType.cs index 83b5caa27d..eb3d080ecd 100644 --- a/src/Microsoft.ML.Core/Data/ColumnType.cs +++ b/src/Microsoft.ML.Core/Data/ColumnType.cs @@ -7,12 +7,11 @@ using System; using System.Collections.Immutable; using System.Linq; -using System.Reflection; using System.Text; using 
System.Threading; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// This is the abstract base class for all types in the type system. @@ -80,12 +79,6 @@ private protected ColumnType(Type rawType, DataKind rawKind) [BestFriend] internal bool IsPrimitive { get; } - /// - /// Equivalent to as . - /// - [BestFriend] - internal PrimitiveType AsPrimitive => IsPrimitive ? (PrimitiveType)this : null; - /// /// Whether this type is a standard numeric type. External code should use is . /// @@ -140,12 +133,6 @@ internal bool IsBool [BestFriend] internal bool IsKey { get; } - /// - /// Equivalent to as . - /// - [BestFriend] - internal KeyType AsKey => IsKey ? (KeyType)this : null; - /// /// Zero return means either it's not a key type or the cardinality is unknown. External code should first /// test whether this is of type , then if so get the property @@ -166,12 +153,6 @@ internal bool IsBool [BestFriend] internal bool IsVector { get; } - /// - /// Equivalent to as . - /// - [BestFriend] - internal VectorType AsVector => IsVector ? (VectorType)this : null; - /// /// For non-vector types, this returns the column type itself (i.e., return this). /// @@ -766,8 +747,7 @@ public override bool Equals(ColumnType other) if (other == this) return true; - var tmp = other as KeyType; - if (tmp == null) + if (!(other is KeyType tmp)) return false; if (RawKind != tmp.RawKind) return false; diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs index ad8d8fbfe0..db5a75326e 100644 --- a/src/Microsoft.ML.Core/Data/DataKind.cs +++ b/src/Microsoft.ML.Core/Data/DataKind.cs @@ -3,9 +3,8 @@ // See the LICENSE file in the project root for more information. using System; -using System.Text; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Data type specifier. 
@@ -52,7 +51,8 @@ public enum DataKind : byte /// /// Extension methods related to the DataKind enum. /// - public static class DataKindExtensions + [BestFriend] + internal static class DataKindExtensions { public const DataKind KindMin = DataKind.I1; public const DataKind KindLim = DataKind.U16 + 1; @@ -171,7 +171,7 @@ public static Type ToType(this DataKind kind) case DataKind.DZ: return typeof(DateTimeOffset); case DataKind.UG: - return typeof(UInt128); + return typeof(RowId); } return null; @@ -215,7 +215,7 @@ public static bool TryGetDataKind(this Type type, out DataKind kind) kind = DataKind.DT; else if (type == typeof(DateTimeOffset)) kind = DataKind.DZ; - else if (type == typeof(UInt128)) + else if (type == typeof(RowId)) kind = DataKind.UG; else { diff --git a/src/Microsoft.ML.Core/Data/ICommand.cs b/src/Microsoft.ML.Core/Data/ICommand.cs index 44d4c7340b..4c68218436 100644 --- a/src/Microsoft.ML.Core/Data/ICommand.cs +++ b/src/Microsoft.ML.Core/Data/ICommand.cs @@ -2,11 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Data; - -namespace Microsoft.ML.Runtime.Command +namespace Microsoft.ML.Command { /// /// The signature for commands. diff --git a/src/Microsoft.ML.Core/Data/ICursor.cs b/src/Microsoft.ML.Core/Data/ICursor.cs deleted file mode 100644 index 476f07245c..0000000000 --- a/src/Microsoft.ML.Core/Data/ICursor.cs +++ /dev/null @@ -1,106 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Float = System.Single; - -using System; - -namespace Microsoft.ML.Runtime.Data -{ - /// - /// This is a base interface for an and . 
It contains only the - /// positional properties, no behavioral methods, and no data. - /// - public interface ICounted - { - /// - /// This is incremented for ICursor when the underlying contents changes, giving clients a way to detect change. - /// Generally it's -1 when the object is in an invalid state. In particular, for an , this is -1 - /// when the is or . - /// - /// Note that this position is not position within the underlying data, but position of this cursor only. - /// If one, for example, opened a set of parallel streaming cursors, or a shuffled cursor, each such cursor's - /// first valid entry would always have position 0. - /// - long Position { get; } - - /// - /// This provides a means for reconciling multiple streams of counted things. Generally, in each stream, - /// batch numbers should be non-decreasing. Furthermore, any given batch number should only appear in one - /// of the streams. Order is determined by batch number. The reconciler ensures that each stream (that is - /// still active) has at least one item available, then takes the item with the smallest batch number. - /// - /// Note that there is no suggestion that the batches for a particular entry will be consistent from - /// cursoring to cursoring, except for the consistency in resulting in the same overall ordering. The same - /// entry could have different batch numbers from one cursoring to another. There is also no requirement - /// that any given batch number must appear, at all. - /// - long Batch { get; } - - /// - /// A getter for a 128-bit ID value. It is common for objects to serve multiple - /// instances to iterate over what is supposed to be the same data, for example, in a - /// a cursor set will produce the same data as a serial cursor, just partitioned, and a shuffled cursor - /// will produce the same data as a serial cursor or any other shuffled cursor, only shuffled. The ID - /// exists for applications that need to reconcile which entry is actually which. 
Ideally this ID should - /// be unique, but for practical reasons, it suffices if collisions are simply extremely improbable. - /// - /// Note that this ID, while it must be consistent for multiple streams according to the semantics - /// above, is not considered part of the data per se. So, to take the example of a data view specifically, - /// a single data view must render consistent IDs across all cursorings, but there is no suggestion at - /// all that if the "same" data were presented in a different data view (as by, say, being transformed, - /// cached, saved, or whatever), that the IDs between the two different data views would have any - /// discernable relationship. - ValueGetter GetIdGetter(); - } - - /// - /// Defines the possible states of a cursor. - /// - public enum CursorState - { - NotStarted, - Good, - Done - } - - /// - /// The basic cursor interface. is incremented by - /// and . When the cursor state is or - /// , is -1. Otherwise, - /// >= 0. - /// - public interface ICursor : ICounted, IDisposable - { - /// - /// Returns the state of the cursor. Before the first call to or - /// this should be . After - /// any call those move functions that returns true, this should return - /// , - /// - CursorState State { get; } - - /// - /// Advance to the next row. When the cursor is first created, this method should be called to - /// move to the first row. Returns false if there are no more rows. - /// - bool MoveNext(); - - /// - /// Logically equivalent to calling the given number of times. The - /// parameter must be positive. Note that cursor implementations may be - /// able to optimize this. - /// - bool MoveMany(long count); - - /// - /// Returns a cursor that can be used for invoking , , - /// , and , with results identical to calling those - /// on this cursor. Generally, if the root cursor is not the same as this cursor, using the - /// root cursor will be faster. 
As an aside, note that this is not necessarily the case of - /// values from . - /// - ICursor GetRootCursor(); - } -} \ No newline at end of file diff --git a/src/Microsoft.ML.Core/Data/IDataView.cs b/src/Microsoft.ML.Core/Data/IDataView.cs index cb0da9a194..893b17a26a 100644 --- a/src/Microsoft.ML.Core/Data/IDataView.cs +++ b/src/Microsoft.ML.Core/Data/IDataView.cs @@ -2,17 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; using System; using System.Collections.Generic; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Legacy interface for schema information. /// Please avoid implementing this interface, use . /// - public interface ISchema + [BestFriend] + internal interface ISchema { /// /// Number of columns. @@ -60,22 +60,11 @@ public interface ISchema void GetMetadata(string kind, int col, ref TValue value); } - /// - /// Base interface for schematized information. IDataView and IRowCursor both derive from this. - /// - public interface ISchematized - { - /// - /// Gets an instance of Schema. - /// - Schema Schema { get; } - } - /// /// The input and output of Query Operators (Transforms). This is the fundamental data pipeline - /// type, comparable to IEnumerable for LINQ. + /// type, comparable to for LINQ. /// - public interface IDataView : ISchematized + public interface IDataView { /// /// Whether this IDataView supports shuffling of rows, to any degree. @@ -99,75 +88,187 @@ public interface IDataView : ISchematized /// a getter for an inactive columns will throw. The predicate must be /// non-null. To activate all columns, pass "col => true". /// - IRowCursor GetRowCursor(Func needCol, IRandom rand = null); + RowCursor GetRowCursor(Func needCol, Random rand = null); /// - /// This constructs a set of parallel batch cursors. The value n is a recommended limit - /// on cardinality. 
If is non-positive, this indicates that the caller - /// has no recommendation, and the implementation should have some default behavior to cover - /// this case. Note that this is strictly a recommendation: it is entirely possible that - /// an implementation can return a different number of cursors. + /// This constructs a set of parallel batch cursors. The value is a recommended limit on + /// cardinality. If is non-positive, this indicates that the caller has no recommendation, + /// and the implementation should have some default behavior to cover this case. Note that this is strictly a + /// recommendation: it is entirely possible that an implementation can return a different number of cursors. /// /// The cursors should return the same data as returned through - /// , except partitioned: no two cursors - /// should return the "same" row as would have been returned through the regular serial cursor, - /// but all rows should be returned by exactly one of the cursors returned from this cursor. - /// The cursors can have their values reconciled downstream through the use of the - /// property. - /// - /// This is an object that can be used to reconcile the - /// returned array of cursors. When the array of cursors is of length 1, it is legal, - /// indeed expected, that this parameter should be null. + /// , except partitioned: no two cursors should return the + /// "same" row as would have been returned through the regular serial cursor, but all rows should be returned by + /// exactly one of the cursors returned from this cursor. The cursors can have their values reconciled + /// downstream through the use of the property. + /// + /// The typical usage pattern is that a set of cursors is requested, each of them is then given to a set of + /// working threads that consume from them independently while, ultimately, the results are finally collated in + /// the end by exploiting the ordering of the property described above. 
More typical + /// scenarios will be content with pulling from the single serial cursor of + /// . + /// /// The predicate, where a column is active if this returns true. /// The suggested degree of parallelism. /// An instance /// - IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func needCol, int n, IRandom rand = null); - } + RowCursor[] GetRowCursorSet(Func needCol, int n, Random rand = null); - /// - /// This is used to consolidate parallel cursors into a single cursor. The object that determines - /// the number of cursors and splits the row "stream" provides the consolidator object. - /// - public interface IRowCursorConsolidator - { /// - /// Create a consolidated cursor from the given parallel cursor set. + /// Gets an instance of Schema. /// - IRowCursor CreateCursor(IChannelProvider provider, IRowCursor[] inputs); + Schema Schema { get; } } /// - /// Delegate type to get a value. This can used for efficient access to data in an IRow - /// or IRowCursor. + /// Delegate type to get a value. This can be used for efficient access to data in a + /// or . /// public delegate void ValueGetter(ref TValue value); /// - /// A logical row. May be a row of an IDataView or a stand-alone row. If/when its contents - /// change, its ICounted.Counter value is incremented. + /// A logical row. May be a row of an or a stand-alone row. If/when its contents + /// change, its value is changed. /// - public interface IRow : ISchematized, ICounted + public abstract class Row : IDisposable { + /// + /// This is incremented when the underlying contents changes, giving clients a way to detect change. Generally + /// it's -1 when the object is in an invalid state. In particular, for an , this is -1 + /// when the is or . + /// + /// Note that this position is not position within the underlying data, but position of this cursor only. 
If + /// one, for example, opened a set of parallel streaming cursors, or a shuffled cursor, each such cursor's first + /// valid entry would always have position 0. + /// + public abstract long Position { get; } + + /// + /// This provides a means for reconciling multiple rows that have been produced generally from + /// . When getting a set, there is a need + /// to, while allowing parallel processing to proceed, always have an aim that the original order should be + /// reconverable. Note, whether or not a user cares about that original order in ones specific application is + /// another story altogether (most callers of this as a practical matter do not, otherwise they would not call + /// it), but at least in principle it should be possible to reconstruct the original order one would get from an + /// identically configured . So: for any cursor + /// implementation, batch numbers should be non-decreasing. Furthermore, any given batch number should only + /// appear in one of the cursors as returned by + /// . In this way, order is determined by + /// batch number. An operation that reconciles these cursors to produce a consistent single cursoring, could do + /// so by drawing from the single cursor, among all cursors in the set, that has the smallest batch number + /// available. + /// + /// Note that there is no suggestion that the batches for a particular entry will be consistent from cursoring + /// to cursoring, except for the consistency in resulting in the same overall ordering. The same entry could + /// have different batch numbers from one cursoring to another. There is also no requirement that any given + /// batch number must appear, at all. It is merely a mechanism for recovering ordering from a possibly arbitrary + /// partitioning of the data. It also follows from this, of course, that considering the batch to be a property + /// of the data is completely invalid. 
+ /// + public abstract long Batch { get; } + + /// + /// A getter for a 128-bit ID value. It is common for objects to serve multiple + /// instances to iterate over what is supposed to be the same data, for example, in a + /// a cursor set will produce the same data as a serial cursor, just partitioned, and a shuffled cursor will + /// produce the same data as a serial cursor or any other shuffled cursor, only shuffled. The ID exists for + /// applications that need to reconcile which entry is actually which. Ideally this ID should be unique, but for + /// practical reasons, it suffices if collisions are simply extremely improbable. + /// + /// Note that this ID, while it must be consistent for multiple streams according to the semantics above, is not + /// considered part of the data per se. So, to take the example of a data view specifically, a single data view + /// must render consistent IDs across all cursorings, but there is no suggestion at all that if the "same" data + /// were presented in a different data view (as by, say, being transformed, cached, saved, or whatever), that + /// the IDs between the two different data views would have any discernable relationship. + public abstract ValueGetter GetIdGetter(); + /// /// Returns whether the given column is active in this row. /// - bool IsColumnActive(int col); + public abstract bool IsColumnActive(int col); /// /// Returns a value getter delegate to fetch the given column value from the row. /// This throws if the column is not active in this row, or if the type /// differs from this column's type. /// - ValueGetter GetGetter(int col); + public abstract ValueGetter GetGetter(int col); + + /// + /// Gets a , which provides name and type information for variables + /// (i.e., columns in ML.NET's type system) stored in this row. + /// + public abstract Schema Schema { get; } + + /// + /// Implementation of dispose. Calls with . 
+ /// + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + /// + /// The disposable method for the disposable pattern. This default implementation does nothing. + /// + /// Whether this was called from . + /// Subclasses that implement should call this method with + /// , but I hasten to add that implementing finalizers should be + /// avoided if at all possible. + protected virtual void Dispose(bool disposing) + { + } + } + + /// + /// Defines the possible states of a cursor. + /// + public enum CursorState + { + NotStarted, + Good, + Done } /// - /// A cursor through rows of an . Note that this includes/is an - /// , as well as an . + /// The basic cursor base class to cursor through rows of an . Note that + /// this is also an . The is incremented by + /// and . When the cursor state is or + /// , is -1. Otherwise, + /// >= 0. /// - public interface IRowCursor : ICursor, IRow + public abstract class RowCursor : Row { + /// + /// Returns the state of the cursor. Before the first call to or + /// this should be . After + /// any call to those move functions that returns , this should return + /// , + /// + public abstract CursorState State { get; } + + /// + /// Advance to the next row. When the cursor is first created, this method should be called to + /// move to the first row. Returns false if there are no more rows. + /// + public abstract bool MoveNext(); + + /// + /// Logically equivalent to calling the given number of times. The + /// parameter must be positive. Note that cursor implementations may be + /// able to optimize this. + /// + public abstract bool MoveMany(long count); + + /// + /// Returns a cursor that can be used for invoking , , + /// , and , with results identical to calling those + /// on this cursor. Generally, if the root cursor is not the same as this cursor, using the + /// root cursor will be faster. As an aside, note that this is not necessarily the case of + /// values from .
+ /// + public abstract RowCursor GetRootCursor(); } } \ No newline at end of file diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index b82ad4f5a7..b0300795ec 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -2,27 +2,29 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using System; +using System.Collections; using System.Collections.Generic; using System.Linq; +using Microsoft.ML.Data; namespace Microsoft.ML.Core.Data { /// /// A set of 'requirements' to the incoming schema, as well as a set of 'promises' of the outgoing schema. - /// This is more relaxed than the proper , since it's only a subset of the columns, + /// This is more relaxed than the proper , since it's only a subset of the columns, /// and also since it doesn't specify exact 's for vectors and keys. 
/// - public sealed class SchemaShape + public sealed class SchemaShape : IReadOnlyList { - public readonly Column[] Columns; + private readonly Column[] _columns; private static readonly SchemaShape _empty = new SchemaShape(Enumerable.Empty()); - public sealed class Column + public int Count => _columns.Count(); + + public Column this[int index] => _columns[index]; + + public struct Column { public enum VectorKind { @@ -55,13 +57,13 @@ public enum VectorKind /// public readonly SchemaShape Metadata; - public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, SchemaShape metadata = null) + [BestFriend] + internal Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, SchemaShape metadata = null) { Contracts.CheckNonEmpty(name, nameof(name)); Contracts.CheckValueOrNull(metadata); Contracts.CheckParam(!itemType.IsKey, nameof(itemType), "Item type cannot be a key"); Contracts.CheckParam(!itemType.IsVector, nameof(itemType), "Item type cannot be a vector"); - Contracts.CheckParam(!isKey || KeyType.IsValidDataKind(itemType.RawKind), nameof(itemType), "The item type must be valid for a key"); Name = name; @@ -80,9 +82,10 @@ public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, /// - The columns of of is a superset of our columns. /// - Each such metadata column is itself compatible with the input metadata column. 
/// - public bool IsCompatibleWith(Column inputColumn) + [BestFriend] + internal bool IsCompatibleWith(Column inputColumn) { - Contracts.CheckValue(inputColumn, nameof(inputColumn)); + Contracts.Check(inputColumn.IsValid, nameof(inputColumn)); if (Name != inputColumn.Name) return false; if (Kind != inputColumn.Kind) @@ -91,7 +94,7 @@ public bool IsCompatibleWith(Column inputColumn) return false; if (IsKey != inputColumn.IsKey) return false; - foreach (var metaCol in Metadata.Columns) + foreach (var metaCol in Metadata) { if (!inputColumn.Metadata.TryFindColumn(metaCol.Name, out var inputMetaCol)) return false; @@ -101,7 +104,8 @@ public bool IsCompatibleWith(Column inputColumn) return true; } - public string GetTypeString() + [BestFriend] + internal string GetTypeString() { string result = ItemType.ToString(); if (IsKey) @@ -112,13 +116,20 @@ public string GetTypeString() result = $"VarVector<{result}>"; return result; } + + /// + /// Return if this structure is not identical to the default value of . If true, + /// it means this structure is initialized properly and therefore considered as valid. + /// + [BestFriend] + internal bool IsValid => Name != null; } public SchemaShape(IEnumerable columns) { Contracts.CheckValue(columns, nameof(columns)); - Columns = columns.ToArray(); - Contracts.CheckParam(columns.All(c => c != null), nameof(columns), "No items should be null."); + _columns = columns.ToArray(); + Contracts.CheckParam(columns.All(c => c.IsValid), nameof(columns), "Some items are not initialized properly."); } /// @@ -151,26 +162,27 @@ internal static void GetColumnTypeShape(ColumnType type, /// /// Create a schema shape out of the fully defined schema. 
/// - public static SchemaShape Create(Schema schema) + [BestFriend] + internal static SchemaShape Create(Schema schema) { Contracts.CheckValue(schema, nameof(schema)); var cols = new List(); for (int iCol = 0; iCol < schema.Count; iCol++) { - if (!schema.IsHidden(iCol)) + if (!schema[iCol].IsHidden) { // First create the metadata. var mCols = new List(); - foreach (var metaNameType in schema.GetMetadataTypes(iCol)) + foreach (var metaColumn in schema[iCol].Metadata.Schema) { - GetColumnTypeShape(metaNameType.Value, out var mVecKind, out var mItemType, out var mIsKey); - mCols.Add(new Column(metaNameType.Key, mVecKind, mItemType, mIsKey)); + GetColumnTypeShape(metaColumn.Type, out var mVecKind, out var mItemType, out var mIsKey); + mCols.Add(new Column(metaColumn.Name, mVecKind, mItemType, mIsKey)); } var metadata = mCols.Count > 0 ? new SchemaShape(mCols) : _empty; // Next create the single column. - GetColumnTypeShape(schema.GetColumnType(iCol), out var vecKind, out var itemType, out var isKey); - cols.Add(new Column(schema.GetColumnName(iCol), vecKind, itemType, isKey, metadata)); + GetColumnTypeShape(schema[iCol].Type, out var vecKind, out var itemType, out var isKey); + cols.Add(new Column(schema[iCol].Name, vecKind, itemType, isKey, metadata)); } } return new SchemaShape(cols); @@ -179,25 +191,23 @@ public static SchemaShape Create(Schema schema) /// /// Returns if there is a column with a specified and if so stores it in . 
/// - public bool TryFindColumn(string name, out Column column) + [BestFriend] + internal bool TryFindColumn(string name, out Column column) { Contracts.CheckValue(name, nameof(name)); - column = Columns.FirstOrDefault(x => x.Name == name); - return column != null; + column = _columns.FirstOrDefault(x => x.Name == name); + return column.IsValid; } + public IEnumerator GetEnumerator() => ((IEnumerable)_columns).GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + // REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape // as an input to another schema shape. I started writing, but realized that there's more than one way to check for // the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'. } - /// - /// Exception class for schema validation errors. - /// - public class SchemaException : Exception - { - } - /// /// The 'data reader' takes a certain kind of input and turns it into an . /// @@ -246,7 +256,6 @@ public interface ITransformer /// /// Schema propagation for transformers. /// Returns the output schema of the data, if the input schema is like the one provided. - /// Throws if the input schema is not valid for the transformer. /// Schema GetOutputSchema(Schema inputSchema); @@ -275,7 +284,7 @@ public interface ITransformer /// /// The estimator (in Spark terminology) is an 'untrained transformer'. It needs to 'fit' on the data to manufacture /// a transformer. - /// It also provides the 'schema propagation' like transformers do, but over instead of . + /// It also provides the 'schema propagation' like transformers do, but over instead of . /// public interface IEstimator where TTransformer : ITransformer @@ -288,7 +297,6 @@ public interface IEstimator /// /// Schema propagation for estimators. /// Returns the output schema shape of the estimator, if the input schema shape is like the one provided. - /// Throws iff the input schema is not valid for the estimator. 
/// SchemaShape GetOutputSchema(SchemaShape inputSchema); } diff --git a/src/Microsoft.ML.Core/Data/IFileHandle.cs b/src/Microsoft.ML.Core/Data/IFileHandle.cs index 37b871b7b6..b5b2ae8183 100644 --- a/src/Microsoft.ML.Core/Data/IFileHandle.cs +++ b/src/Microsoft.ML.Core/Data/IFileHandle.cs @@ -5,9 +5,9 @@ using System; using System.Collections.Generic; using System.IO; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { /// /// A file handle. diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs index eb0d57845c..72639c6ac2 100644 --- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs +++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs @@ -5,7 +5,7 @@ using System; using System.ComponentModel.Composition.Hosting; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { /// /// A channel provider can create new channels and generic information pipes. @@ -94,7 +94,7 @@ public interface IHost : IHostEnvironment /// The random number generator issued to this component. Note that random number /// generators are NOT thread safe. /// - IRandom Rand { get; } + Random Rand { get; } /// /// Signal to stop exection in this host and all its children. @@ -233,7 +233,8 @@ public interface IChannel : IPipe /// that do not belong in more specific areas, for example, or /// component creation. /// - public static class HostExtensions + [BestFriend] + internal static class HostExtensions { public static T Apply(this IHost host, string channelName, Func func) { diff --git a/src/Microsoft.ML.Core/Data/IProgressChannel.cs b/src/Microsoft.ML.Core/Data/IProgressChannel.cs index 924e7806f9..26dd7e1831 100644 --- a/src/Microsoft.ML.Core/Data/IProgressChannel.cs +++ b/src/Microsoft.ML.Core/Data/IProgressChannel.cs @@ -4,7 +4,7 @@ using System; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { /// /// This is a factory interface for . 
diff --git a/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs b/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs new file mode 100644 index 0000000000..d9c55227f3 --- /dev/null +++ b/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.Data +{ + /// + /// This interface maps an input to an output . Typically, the output contains + /// both the input columns and new columns added by the implementing class, although some implementations may + /// return a subset of the input columns. + /// This interface is similar to , except it does not have any input role mappings, + /// so to rebind, the same input column names must be used. + /// Implementations of this interface are typically created over defined input . + /// + public interface IRowToRowMapper + { + /// + /// Mappers are defined as accepting inputs with this very specific schema. + /// + Schema InputSchema { get; } + + /// + /// Gets an instance of which describes the columns' names and types in the output generated by this mapper. + /// + Schema OutputSchema { get; } + + /// + /// Given a predicate specifying which columns are needed, return a predicate indicating which input columns are + /// needed. The domain of the function is defined over the indices of the columns of + /// for . + /// + Func GetDependencies(Func predicate); + + /// + /// Get an with the indicated active columns, based on the input . + /// The active columns are those for which returns true. Getting values on inactive + /// columns of the returned row will throw. Null predicates are disallowed. + /// + /// The of should be the same object as + /// . Implementors of this method should throw if that is not the case. Conversely, + /// the returned value must have the same schema as . 
+ /// + /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the + /// getters of the input row and base the output values on the current values of the input . + /// The output values are re-computed when requested through the getters. Also, the returned + /// will dispose when it is disposed. + /// + Row GetRow(Row input, Func active); + } +} diff --git a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs index 525c7926a0..34b09fcf44 100644 --- a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs +++ b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs @@ -2,11 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; -using System; using System.Collections.Generic; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// A mapper that can be bound to a (which is an ISchema, with mappings from column kinds @@ -21,7 +19,8 @@ namespace Microsoft.ML.Runtime.Data /// for the output schema of the . In case the interface is implemented, /// the SimpleRow class can be used in the method. /// - public interface ISchemaBindableMapper + [BestFriend] + internal interface ISchemaBindableMapper { ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema); } @@ -30,13 +29,19 @@ public interface ISchemaBindableMapper /// This interface is used to map a schema from input columns to output columns. The should keep track /// of the input columns that are needed for the mapping. /// - public interface ISchemaBoundMapper : ISchematized + [BestFriend] + internal interface ISchemaBoundMapper { /// /// The that was passed to the in the binding process. /// RoleMappedSchema InputRoleMappedSchema { get; } + /// + /// Gets schema of this mapper's output. 
+ /// + Schema OutputSchema { get; } + /// /// A property to get back the that produced this . /// @@ -51,53 +56,13 @@ public interface ISchemaBoundMapper : ISchematized /// /// This interface combines with . /// - public interface ISchemaBoundRowMapper : ISchemaBoundMapper, IRowToRowMapper - { - } - - /// - /// This interface maps an input to an output . Typically, the output contains - /// both the input columns and new columns added by the implementing class, although some implementations may - /// return a subset of the input columns. - /// This interface is similar to , except it does not have any input role mappings, - /// so to rebind, the same input column names must be used. - /// Implementing of this object are typically created using a definie input . - /// - public interface IRowToRowMapper : ISchematized + [BestFriend] + internal interface ISchemaBoundRowMapper : ISchemaBoundMapper, IRowToRowMapper { /// - /// Mappers are defined as accepting inputs with this very specific schema. - /// - Schema InputSchema { get; } - - /// - /// Given a predicate specifying which columns are needed, return a predicate indicating which input columns are - /// needed. The domain of the function is defined over the indices of the columns of - /// for . - /// - Func GetDependencies(Func predicate); - - /// - /// Get an with the indicated active columns, based on the input . - /// The active columns are those for which returns true. Getting values on inactive - /// columns of the returned row will throw. Null predicates are disallowed. - /// - /// The of should be the same object as - /// . Implementors of this method should throw if that is not the case. Conversely, - /// the returned value must have the same schema as . - /// - /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the - /// getters of the input row and base the output values on the current values of the input . 
- /// The output values are re-computed when requested through the getters. - /// - /// The optional should be invoked by any user of this row mapping, once it no - /// longer needs the . If no action is needed when the cursor is Disposed, the implementation - /// should set to null, otherwise it should be set to a delegate to be - /// invoked by the code calling this object. (For example, a wrapping cursor's - /// method. It's best for this action to be idempotent - calling it multiple times should be equivalent to - /// calling it once. + /// There are two schemas from and . + /// Since the two parent schema's are identical in all derived classes, we merge them into . /// - IRow GetRow(IRow input, Func active, out Action disposer); + new Schema OutputSchema { get; } } } diff --git a/src/Microsoft.ML.Core/Data/IValueMapper.cs b/src/Microsoft.ML.Core/Data/IValueMapper.cs index c6abfbc02d..dcdf6706c7 100644 --- a/src/Microsoft.ML.Core/Data/IValueMapper.cs +++ b/src/Microsoft.ML.Core/Data/IValueMapper.cs @@ -2,20 +2,20 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; - -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Delegate type to map/convert a value. /// - public delegate void ValueMapper(in TSrc src, ref TDst dst); + [BestFriend] + internal delegate void ValueMapper(in TSrc src, ref TDst dst); /// /// Delegate type to map/convert among three values, for example, one input with two /// outputs, or two inputs with one output. 
/// - public delegate void ValueMapper(in TVal1 val1, ref TVal2 val2, ref TVal3 val3); + [BestFriend] + internal delegate void ValueMapper(in TVal1 val1, ref TVal2 val2, ref TVal3 val3); /// /// Interface for mapping a single input value (of an indicated ColumnType) to @@ -24,7 +24,8 @@ namespace Microsoft.ML.Runtime.Data /// type arguments for GetMapper, but typically contain additional information like /// vector lengths. /// - public interface IValueMapper + [BestFriend] + internal interface IValueMapper { ColumnType InputType { get; } ColumnType OutputType { get; } @@ -43,7 +44,8 @@ public interface IValueMapper /// type arguments for GetMapper, but typically contain additional information like /// vector lengths. /// - public interface IValueMapperDist : IValueMapper + [BestFriend] + internal interface IValueMapperDist : IValueMapper { ColumnType DistType { get; } diff --git a/src/Microsoft.ML.Core/Data/InPredicate.cs b/src/Microsoft.ML.Core/Data/InPredicate.cs index 74d16e906d..3c35bf600a 100644 --- a/src/Microsoft.ML.Core/Data/InPredicate.cs +++ b/src/Microsoft.ML.Core/Data/InPredicate.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public delegate bool InPredicate(in T value); } diff --git a/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs index ecec0b1a0d..efb785a835 100644 --- a/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs @@ -2,43 +2,45 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Base class for a cursor has an input cursor, but still needs to do work on - /// MoveNext/MoveMany. + /// / . 
/// [BestFriend] - internal abstract class LinkedRootCursorBase : RootCursorBase - where TInput : class, ICursor + internal abstract class LinkedRootCursorBase : RootCursorBase { - private readonly ICursor _root; /// Gets the input cursor. - protected TInput Input { get; } + protected RowCursor Input { get; } /// /// Returns the root cursor of the input. It should be used to perform MoveNext or MoveMany operations. - /// Note that GetRootCursor() returns "this", NOT Root. Root is used to advance our input, not for - /// clients of this cursor. That's why it is protected, not public. + /// Note that returns , not . + /// is used to advance our input, not for clients of this cursor. That is why it is + /// protected, not public. /// - protected ICursor Root { get { return _root; } } + protected RowCursor Root { get; } - protected LinkedRootCursorBase(IChannelProvider provider, TInput input) + protected LinkedRootCursorBase(IChannelProvider provider, RowCursor input) : base(provider) { Ch.AssertValue(input, nameof(input)); Input = input; - _root = Input.GetRootCursor(); + Root = Input.GetRootCursor(); } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (State != CursorState.Done) + if (State == CursorState.Done) + return; + if (disposing) { Input.Dispose(); - base.Dispose(); + // The base class should set the state to done under these circumstances. + base.Dispose(true); } } } diff --git a/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs index 0ed4dd19f9..fa35240ad1 100644 --- a/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs @@ -2,9 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Data; - -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Base class for creating a cursor of rows that filters out some input rows. @@ -14,12 +12,12 @@ internal abstract class LinkedRowFilterCursorBase : LinkedRowRootCursorBase { public override long Batch => Input.Batch; - protected LinkedRowFilterCursorBase(IChannelProvider provider, IRowCursor input, Schema schema, bool[] active) + protected LinkedRowFilterCursorBase(IChannelProvider provider, RowCursor input, Schema schema, bool[] active) : base(provider, input, schema, active) { } - public override ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { return Input.GetIdGetter(); } diff --git a/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs index 188522ad3d..1f427ddeb9 100644 --- a/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs @@ -2,25 +2,23 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; - -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// - /// A base class for a that has an input cursor, but still needs - /// to do work on /. Note + /// A base class for a that has an input cursor, but still needs + /// to do work on /. Note /// that the default assumes /// that each input column is exposed as an output column with the same column index. /// [BestFriend] - internal abstract class LinkedRowRootCursorBase : LinkedRootCursorBase, IRowCursor + internal abstract class LinkedRowRootCursorBase : LinkedRootCursorBase { private readonly bool[] _active; /// Gets row's schema. 
- public Schema Schema { get; } + public sealed override Schema Schema { get; } - protected LinkedRowRootCursorBase(IChannelProvider provider, IRowCursor input, Schema schema, bool[] active) + protected LinkedRowRootCursorBase(IChannelProvider provider, RowCursor input, Schema schema, bool[] active) : base(provider, input) { Ch.CheckValue(schema, nameof(schema)); @@ -29,13 +27,13 @@ protected LinkedRowRootCursorBase(IChannelProvider provider, IRowCursor input, S Schema = schema; } - public bool IsColumnActive(int col) + public sealed override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < Schema.Count); return _active == null || _active[col]; } - public virtual ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { return Input.GetGetter(col); } diff --git a/src/Microsoft.ML.Core/Data/MetadataBuilder.cs b/src/Microsoft.ML.Core/Data/MetadataBuilder.cs index 0b72f53a25..65cafd2627 100644 --- a/src/Microsoft.ML.Core/Data/MetadataBuilder.cs +++ b/src/Microsoft.ML.Core/Data/MetadataBuilder.cs @@ -2,12 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Collections.Generic; using System.Linq; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Data { @@ -16,11 +14,11 @@ namespace Microsoft.ML.Data /// public sealed class MetadataBuilder { - private readonly List<(string Name, ColumnType Type, Delegate Getter)> _items; + private readonly List<(string Name, ColumnType Type, Delegate Getter, Schema.Metadata Metadata)> _items; public MetadataBuilder() { - _items = new List<(string Name, ColumnType Type, Delegate Getter)>(); + _items = new List<(string Name, ColumnType Type, Delegate Getter, Schema.Metadata Metadata)>(); } /// @@ -40,7 +38,7 @@ public void Add(Schema.Metadata metadata, Func selector) foreach (var column in metadata.Schema) { if (selector(column.Name)) - _items.Add((column.Name, column.Type, metadata.Getters[column.Index])); + _items.Add((column.Name, column.Type, metadata.Getters[column.Index], column.Metadata)); } } @@ -51,13 +49,17 @@ public void Add(Schema.Metadata metadata, Func selector) /// The metadata name. /// The metadata type. /// The getter delegate. - public void Add(string name, ColumnType type, ValueGetter getter) + /// Metadata of the input column. Note that metadata on a metadata column is somewhat rare + /// except for certain types (for example, slot names for a vector, key values for something of key type). 
+ public void Add(string name, ColumnType type, ValueGetter getter, Schema.Metadata metadata = null) { Contracts.CheckNonEmpty(name, nameof(name)); Contracts.CheckValue(type, nameof(type)); Contracts.CheckValue(getter, nameof(getter)); - Contracts.CheckParam(type.RawType == typeof(TValue), nameof(getter)); - _items.Add((name, type, getter)); + Contracts.CheckParam(type.RawType == typeof(TValue), nameof(type)); + Contracts.CheckValueOrNull(metadata); + + _items.Add((name, type, getter, metadata)); } /// @@ -67,11 +69,31 @@ public void Add(string name, ColumnType type, ValueGetter getter /// The metadata type. /// The getter delegate that provides the value. Note that the type of the getter is still checked /// inside this method. - public void Add(string name, ColumnType type, Delegate getter) + /// Metadata of the input column. Note that metadata on a metadata column is somewhat rare + /// except for certain types (for example, slot names for a vector, key values for something of key type). + public void Add(string name, ColumnType type, Delegate getter, Schema.Metadata metadata = null) { Contracts.CheckNonEmpty(name, nameof(name)); Contracts.CheckValue(type, nameof(type)); - Utils.MarshalActionInvoke(AddDelegate, type.RawType, name, type, getter); + Contracts.CheckValueOrNull(metadata); + Utils.MarshalActionInvoke(AddDelegate, type.RawType, name, type, getter, metadata); + } + + /// + /// Add one metadata column for a primitive value type. + /// + /// The metadata name. + /// The metadata type. + /// The value of the metadata. + /// Metadata of the input column. Note that metadata on a metadata column is somewhat rare + /// except for certain types (for example, slot names for a vector, key values for something of key type). 
+ public void AddPrimitiveValue(string name, PrimitiveType type, TValue value, Schema.Metadata metadata = null) + { + Contracts.CheckNonEmpty(name, nameof(name)); + Contracts.CheckValue(type, nameof(type)); + Contracts.CheckParam(type.RawType == typeof(TValue), nameof(type)); + Contracts.CheckValueOrNull(metadata); + Add(name, type, (ref TValue dst) => dst = value, metadata); } /// @@ -100,11 +122,11 @@ public Schema.Metadata GetMetadata() { var builder = new SchemaBuilder(); foreach (var item in _items) - builder.AddColumn(item.Name, item.Type, null); + builder.AddColumn(item.Name, item.Type, item.Metadata); return new Schema.Metadata(builder.GetSchema(), _items.Select(x => x.Getter).ToArray()); } - private void AddDelegate(string name, ColumnType type, Delegate getter) + private void AddDelegate(string name, ColumnType type, Delegate getter, Schema.Metadata metadata) { Contracts.AssertNonEmpty(name); Contracts.AssertValue(type); @@ -112,7 +134,7 @@ private void AddDelegate(string name, ColumnType type, Delegate getter) var typedGetter = getter as ValueGetter; Contracts.CheckParam(typedGetter != null, nameof(getter)); - _items.Add((name, type, typedGetter)); + _items.Add((name, type, typedGetter, metadata)); } } } diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index 9178c1c129..efa22d134a 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -9,10 +9,9 @@ using System.Linq; using System.Threading; using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Utilities for implementing and using the metadata API of . @@ -117,23 +116,20 @@ public static class ScoreValueKind /// /// Returns a standard exception for responding to an invalid call to GetMetadata. 
/// - public static Exception ExceptGetMetadata() - { - return Contracts.Except("Invalid call to GetMetadata"); - } + [BestFriend] + internal static Exception ExceptGetMetadata() => Contracts.Except("Invalid call to GetMetadata"); /// /// Returns a standard exception for responding to an invalid call to GetMetadata. /// - public static Exception ExceptGetMetadata(this IExceptionContext ctx) - { - return ctx.Except("Invalid call to GetMetadata"); - } + [BestFriend] + internal static Exception ExceptGetMetadata(this IExceptionContext ctx) => ctx.Except("Invalid call to GetMetadata"); /// /// Helper to marshal a call to GetMetadata{TValue} to a specific type. /// - public static void Marshal(this MetadataGetter getter, int col, ref TNeed dst) + [BestFriend] + internal static void Marshal(this MetadataGetter getter, int col, ref TNeed dst) { Contracts.CheckValue(getter, nameof(getter)); @@ -147,7 +143,8 @@ public static void Marshal(this MetadataGetter getter, int /// Returns a vector type with item type text and the given size. The size must be positive. /// This is a standard type for metadata consisting of multiple text values, eg SlotNames. /// - public static VectorType GetNamesType(int size) + [BestFriend] + internal static VectorType GetNamesType(int size) { Contracts.CheckParam(size > 0, nameof(size), "must be known size"); return new VectorType(TextType.Instance, size); @@ -159,18 +156,20 @@ public static VectorType GetNamesType(int size) /// This is a standard type for metadata consisting of multiple int values that represent /// categorical slot ranges with in a column. 
/// - public static VectorType GetCategoricalType(int rangeCount) + [BestFriend] + internal static VectorType GetCategoricalType(int rangeCount) { Contracts.CheckParam(rangeCount > 0, nameof(rangeCount), "must be known size"); return new VectorType(NumberType.I4, rangeCount, 2); } - private static volatile ColumnType _scoreColumnSetIdType; + private static volatile KeyType _scoreColumnSetIdType; /// /// The type of the ScoreColumnSetId metadata. /// - public static ColumnType ScoreColumnSetIdType + [BestFriend] + internal static KeyType ScoreColumnSetIdType { get { @@ -186,7 +185,8 @@ public static ColumnType ScoreColumnSetIdType /// /// Returns a key-value pair useful when implementing GetMetadataTypes(col). /// - public static KeyValuePair GetSlotNamesPair(int size) + [BestFriend] + internal static KeyValuePair GetSlotNamesPair(int size) { return GetNamesType(size).GetPair(Kinds.SlotNames); } @@ -195,7 +195,8 @@ public static KeyValuePair GetSlotNamesPair(int size) /// Returns a key-value pair useful when implementing GetMetadataTypes(col). This assumes /// that the values of the key type are Text. /// - public static KeyValuePair GetKeyNamesPair(int size) + [BestFriend] + internal static KeyValuePair GetKeyNamesPair(int size) { return GetNamesType(size).GetPair(Kinds.KeyValues); } @@ -204,7 +205,8 @@ public static KeyValuePair GetKeyNamesPair(int size) /// Given a type and metadata kind string, returns a key-value pair. This is useful when /// implementing GetMetadataTypes(col). /// - public static KeyValuePair GetPair(this ColumnType type, string kind) + [BestFriend] + internal static KeyValuePair GetPair(this ColumnType type, string kind) { Contracts.CheckValue(type, nameof(type)); return new KeyValuePair(kind, type); @@ -215,7 +217,8 @@ public static KeyValuePair GetPair(this ColumnType type, str /// /// Prepends a params array to an enumerable. Useful when implementing GetMetadataTypes. 
/// - public static IEnumerable Prepend(this IEnumerable tail, params T[] head) + [BestFriend] + internal static IEnumerable Prepend(this IEnumerable tail, params T[] head) { return head.Concat(tail); } @@ -234,13 +237,13 @@ public static uint GetMaxMetadataKind(this Schema schema, out int colMax, string colMax = -1; for (int col = 0; col < schema.Count; col++) { - var columnType = schema.GetMetadataTypeOrNull(metadataKind, col); + var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type; if (columnType == null || !columnType.IsKey || columnType.RawKind != DataKind.U4) continue; if (filterFunc != null && !filterFunc(schema, col)) continue; uint value = 0; - schema.GetMetadata(metadataKind, col, ref value); + schema[col].Metadata.GetValue(metadataKind, ref value); if (max < value) { max = value; @@ -254,15 +257,16 @@ public static uint GetMaxMetadataKind(this Schema schema, out int colMax, string /// Returns the set of column ids which match the value of specified metadata kind. /// The metadata type should be a KeyType with raw type U4. /// - public static IEnumerable GetColumnSet(this Schema schema, string metadataKind, uint value) + [BestFriend] + internal static IEnumerable GetColumnSet(this Schema schema, string metadataKind, uint value) { for (int col = 0; col < schema.Count; col++) { - var columnType = schema.GetMetadataTypeOrNull(metadataKind, col); + var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type; if (columnType != null && columnType.IsKey && columnType.RawKind == DataKind.U4) { uint val = 0; - schema.GetMetadata(metadataKind, col, ref val); + schema[col].Metadata.GetValue(metadataKind, ref val); if (val == value) yield return col; } @@ -273,15 +277,16 @@ public static IEnumerable GetColumnSet(this Schema schema, string metadataK /// Returns the set of column ids which match the value of specified metadata kind. /// The metadata type should be of type text. 
/// - public static IEnumerable GetColumnSet(this Schema schema, string metadataKind, string value) + [BestFriend] + internal static IEnumerable GetColumnSet(this Schema schema, string metadataKind, string value) { for (int col = 0; col < schema.Count; col++) { - var columnType = schema.GetMetadataTypeOrNull(metadataKind, col); + var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type; if (columnType != null && columnType.IsText) { ReadOnlyMemory val = default; - schema.GetMetadata(metadataKind, col, ref val); + schema[col].Metadata.GetValue(metadataKind, ref val); if (ReadOnlyMemoryUtils.EqualsStr(value, val)) yield return col; } @@ -290,48 +295,60 @@ public static IEnumerable GetColumnSet(this Schema schema, string metadataK /// /// Returns true if the specified column: - /// * is a vector of length N (including 0) + /// * is a vector of length N /// * has a SlotNames metadata /// * metadata type is VBuffer<ReadOnlyMemory<char>> of length N /// - public static bool HasSlotNames(this Schema schema, int col, int vectorSize) + public static bool HasSlotNames(this Schema.Column column) + => column.Type.IsKnownSizeVector && column.HasSlotNames(column.Type.VectorSize); + + /// + /// Returns true if the specified column: + /// * has a SlotNames metadata + /// * metadata type is VBuffer<ReadOnlyMemory<char>> of length . 
+ /// + [BestFriend] + internal static bool HasSlotNames(this Schema.Column column, int vectorSize) { if (vectorSize == 0) return false; - var type = schema.GetMetadataTypeOrNull(Kinds.SlotNames, col); + var metaColumn = column.Metadata.Schema.GetColumnOrNull(Kinds.SlotNames); return - type != null - && type.IsVector - && type.VectorSize == vectorSize - && type.ItemType.IsText; + metaColumn != null + && metaColumn.Value.Type.IsVector + && metaColumn.Value.Type.VectorSize == vectorSize + && metaColumn.Value.Type.ItemType.IsText; } - public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer> slotNames) + public static void GetSlotNames(this Schema.Column column, ref VBuffer> slotNames) + => column.Metadata.GetValue(Kinds.SlotNames, ref slotNames); + + [BestFriend] + internal static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer> slotNames) { Contracts.CheckValueOrNull(schema); Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize)); - IReadOnlyList list; - if ((list = schema?.GetColumns(role)) == null || list.Count != 1 || !schema.Schema.HasSlotNames(list[0].Index, vectorSize)) - { + IReadOnlyList list = schema?.GetColumns(role); + if (list?.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize)) VBufferUtils.Resize(ref slotNames, vectorSize, 0); - } else - schema.Schema.GetMetadata(Kinds.SlotNames, list[0].Index, ref slotNames); + schema.Schema[list[0].Index].Metadata.GetValue(Kinds.SlotNames, ref slotNames); } - public static bool HasKeyValues(this Schema schema, int col, int keyCount) + [BestFriend] + internal static bool HasKeyValues(this Schema.Column column, int keyCount) { if (keyCount == 0) return false; - var type = schema.GetMetadataTypeOrNull(Kinds.KeyValues, col); + var metaColumn = column.Metadata.Schema.GetColumnOrNull(Kinds.KeyValues); return - type != null - && type.IsVector - && type.VectorSize == keyCount - && 
type.ItemType.IsText; + metaColumn != null + && metaColumn.Value.Type.IsVector + && metaColumn.Value.Type.VectorSize == keyCount + && metaColumn.Value.Type.ItemType.IsText; } [BestFriend] @@ -343,19 +360,17 @@ internal static bool HasKeyValues(this SchemaShape.Column col) } /// - /// Returns whether a column has the metadata set to true. - /// That metadata should be set when the data has undergone transforms that would render it - /// "normalized." + /// Returns true iff has IsNormalized metadata set to true. /// - /// The schema to query - /// Which column in the schema to query - /// True if and only if the column has the metadata - /// set to the scalar value true - public static bool IsNormalized(this Schema schema, int col) + public static bool IsNormalized(this Schema.Column column) { - Contracts.CheckValue(schema, nameof(schema)); - var value = default(bool); - return schema.TryGetMetadata(BoolType.Instance, Kinds.IsNormalized, col, ref value) && value; + var metaColumn = column.Metadata.Schema.GetColumnOrNull((Kinds.IsNormalized)); + if (metaColumn == null || !metaColumn.Value.Type.IsBool) + return false; + + bool value = default; + column.Metadata.GetValue(Kinds.IsNormalized, ref value); + return value; } /// @@ -367,7 +382,7 @@ public static bool IsNormalized(this Schema schema, int col) /// of a scalar type, which we assume, if set, should be true. public static bool IsNormalized(this SchemaShape.Column col) { - Contracts.CheckValue(col, nameof(col)); + Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly"); return col.Metadata.TryFindColumn(Kinds.IsNormalized, out var metaCol) && metaCol.Kind == SchemaShape.Column.VectorKind.Scalar && !metaCol.IsKey && metaCol.ItemType == BoolType.Instance; @@ -382,7 +397,7 @@ public static bool IsNormalized(this SchemaShape.Column col) /// metadata of definite sized vectors of text. 
public static bool HasSlotNames(this SchemaShape.Column col) { - Contracts.CheckValue(col, nameof(col)); + Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly"); return col.Kind == SchemaShape.Column.VectorKind.Vector && col.Metadata.TryFindColumn(Kinds.SlotNames, out var metaCol) && metaCol.Kind == SchemaShape.Column.VectorKind.Vector && !metaCol.IsKey @@ -399,23 +414,19 @@ public static bool HasSlotNames(this SchemaShape.Column col) /// The column /// The value to return, if successful /// True if the metadata of the right type exists, false otherwise - public static bool TryGetMetadata(this Schema schema, PrimitiveType type, string kind, int col, ref T value) + [BestFriend] + internal static bool TryGetMetadata(this Schema schema, PrimitiveType type, string kind, int col, ref T value) { Contracts.CheckValue(schema, nameof(schema)); Contracts.CheckValue(type, nameof(type)); - var metadataType = schema.GetMetadataTypeOrNull(kind, col); + var metadataType = schema[col].Metadata.Schema.GetColumnOrNull(kind)?.Type; if (!type.Equals(metadataType)) return false; - schema.GetMetadata(kind, col, ref value); + schema[col].Metadata.GetValue(kind, ref value); return true; } - /// - /// Return whether the given column index is hidden in the given schema. - /// - public static bool IsHidden(this Schema schema, int col) => schema[col].IsHidden; - /// /// The categoricalFeatures is a vector of the indices of categorical features slots. /// This vector should always have an even number of elements, and the elements should be parsed in groups of two consecutive numbers. @@ -424,21 +435,22 @@ public static bool TryGetMetadata(this Schema schema, PrimitiveType type, str /// The way to interpret that is: feature with indices 0, 1, and 2 are one categorical /// Features with indices 3 and 4 are another categorical. Features 5 and 6 don't appear there, so they are not categoricals. 
/// - public static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex, out int[] categoricalFeatures) + [BestFriend] + internal static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex, out int[] categoricalFeatures) { Contracts.CheckValue(schema, nameof(schema)); Contracts.Check(colIndex >= 0, nameof(colIndex)); bool isValid = false; categoricalFeatures = null; - if (!(schema.GetColumnType(colIndex) is VectorType vecType && vecType.Size > 0)) + if (!(schema[colIndex].Type is VectorType vecType && vecType.Size > 0)) return isValid; - var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex); + var type = schema[colIndex].Metadata.Schema.GetColumnOrNull(Kinds.CategoricalSlotRanges)?.Type; if (type?.RawType == typeof(VBuffer)) { VBuffer catIndices = default(VBuffer); - schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex, ref catIndices); + schema[colIndex].Metadata.GetValue(Kinds.CategoricalSlotRanges, ref catIndices); VBufferUtils.Densify(ref catIndices); int columnSlotsCount = vecType.Size; if (catIndices.Length > 0 && catIndices.Length % 2 == 0 && catIndices.Length <= columnSlotsCount * 2) @@ -471,7 +483,8 @@ public static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex, /// Produces sequence of columns that are generated by trainer estimators. /// /// whether we should also append 'IsNormalized' (typically for probability column) - public static IEnumerable GetTrainerOutputMetadata(bool isNormalized = false) + [BestFriend] + internal static IEnumerable GetTrainerOutputMetadata(bool isNormalized = false) { var cols = new List(); cols.Add(new SchemaShape.Column(Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true)); @@ -483,16 +496,48 @@ public static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex, } /// - /// Produces sequence of columns that are generated by multiclass trainer estimators. 
+ /// Produces metadata for the score column generated by trainer estimators for multiclass classification. + /// If input LabelColumn is not available it produces slotnames metadata by default. /// /// Label column. - public static IEnumerable MetadataForMulticlassScoreColumn(SchemaShape.Column labelColumn) + [BestFriend] + internal static IEnumerable MetadataForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null) { var cols = new List(); - if (labelColumn.IsKey && HasKeyValues(labelColumn)) + if (labelColumn != null && labelColumn.Value.IsKey && HasKeyValues(labelColumn.Value)) cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)); cols.AddRange(GetTrainerOutputMetadata()); return cols; } + + private sealed class MetadataRow : Row + { + private readonly Schema.Metadata _metadata; + + public MetadataRow(Schema.Metadata metadata) + { + Contracts.AssertValue(metadata); + _metadata = metadata; + } + + public override Schema Schema => _metadata.Schema; + public override long Position => 0; + public override long Batch => 0; + public override ValueGetter GetGetter(int col) => _metadata.GetGetter(col); + public override ValueGetter GetIdGetter() => (ref RowId dst) => dst = default; + public override bool IsColumnActive(int col) => true; + } + + /// + /// Presents a as a an . + /// + /// The metadata to wrap. + /// A row that wraps an input metadata. 
+ [BestFriend] + internal static Row MetadataAsRow(Schema.Metadata metadata) + { + Contracts.CheckValue(metadata, nameof(metadata)); + return new MetadataRow(metadata); + } } } \ No newline at end of file diff --git a/src/Microsoft.ML.Core/Data/ProgressReporter.cs b/src/Microsoft.ML.Core/Data/ProgressReporter.cs index 191364e2a3..e49c559571 100644 --- a/src/Microsoft.ML.Core/Data/ProgressReporter.cs +++ b/src/Microsoft.ML.Core/Data/ProgressReporter.cs @@ -7,9 +7,9 @@ using System.Collections.Generic; using System.Linq; using System.Threading; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// The progress reporting classes used by descendants. diff --git a/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs b/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs index 20ebb85b04..e776701030 100644 --- a/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs +++ b/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs @@ -2,13 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Collections.Generic; using System.Runtime.InteropServices; using System.Text; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { [BestFriend] internal static class ReadOnlyMemoryUtils diff --git a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs index a8d22319b3..acd5c957e2 100644 --- a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs +++ b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs @@ -3,76 +3,12 @@ // See the LICENSE file in the project root for more information. 
using System.Collections.Generic; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// - /// This contains information about a column in an . It is essentially a convenience cache - /// containing the name, column index, and column type for the column. The intended usage is that users of - /// will have a convenient method of getting the index and type without having to separately query it through the , - /// since practically the first thing a consumer of a will want to do once they get a mappping is get - /// the type and index of the corresponding column. - /// - public sealed class ColumnInfo - { - public readonly string Name; - public readonly int Index; - public readonly ColumnType Type; - - private ColumnInfo(string name, int index, ColumnType type) - { - Name = name; - Index = index; - Type = type; - } - - /// - /// Create a ColumnInfo for the column with the given name in the given schema. Throws if the name - /// doesn't map to a column. - /// - public static ColumnInfo CreateFromName(ISchema schema, string name, string descName) - { - if (!TryCreateFromName(schema, name, out var colInfo)) - throw Contracts.ExceptParam(nameof(name), $"{descName} column '{name}' not found"); - - return colInfo; - } - - /// - /// Tries to create a ColumnInfo for the column with the given name in the given schema. Returns - /// false if the name doesn't map to a column. - /// - public static bool TryCreateFromName(ISchema schema, string name, out ColumnInfo colInfo) - { - Contracts.CheckValue(schema, nameof(schema)); - Contracts.CheckNonEmpty(name, nameof(name)); - - colInfo = null; - if (!schema.TryGetColumnIndex(name, out int index)) - return false; - - colInfo = new ColumnInfo(name, index, schema.GetColumnType(index)); - return true; - } - - /// - /// Creates a ColumnInfo for the column with the given column index. 
Note that the name - /// of the column might actually map to a different column, so this should be used with care - /// and rarely. - /// - public static ColumnInfo CreateFromIndex(ISchema schema, int index) - { - Contracts.CheckValue(schema, nameof(schema)); - Contracts.CheckParam(0 <= index && index < schema.ColumnCount, nameof(index)); - - return new ColumnInfo(schema.GetColumnName(index), index, schema.GetColumnType(index)); - } - } - - /// - /// Encapsulates an plus column role mapping information. The purpose of role mappings is to + /// Encapsulates an plus column role mapping information. The purpose of role mappings is to /// provide information on what the intended usage is for. That is: while a given data view may have a column named /// "Features", by itself that is insufficient: the trainer must be fed a role mapping that says that the role /// mapping for features is filled by that "Features" column. This allows things like columns not named "Features" @@ -88,7 +24,7 @@ public static ColumnInfo CreateFromIndex(ISchema schema, int index) /// in this schema. /// /// - /// Note that instances of this class are, like instances of , immutable. + /// Note that instances of this class are, like instances of , immutable. /// /// It is often the case that one wishes to bundle the actual data with the role mappings, not just the schema. For /// that case, please use the class. @@ -100,7 +36,8 @@ public static ColumnInfo CreateFromIndex(ISchema schema, int index) /// /// /// - public sealed class RoleMappedSchema + [BestFriend] + internal sealed class RoleMappedSchema { private const string FeatureString = "Feature"; private const string LabelString = "Label"; @@ -184,39 +121,39 @@ public static KeyValuePair CreatePair(ColumnRole role, strin => new KeyValuePair(role, name); /// - /// The source . + /// The source . /// public Schema Schema { get; } /// /// The column, when there is exactly one (null otherwise). 
/// - public ColumnInfo Feature { get; } + public Schema.Column? Feature { get; } /// /// The column, when there is exactly one (null otherwise). /// - public ColumnInfo Label { get; } + public Schema.Column? Label { get; } /// /// The column, when there is exactly one (null otherwise). /// - public ColumnInfo Group { get; } + public Schema.Column? Group { get; } /// /// The column, when there is exactly one (null otherwise). /// - public ColumnInfo Weight { get; } + public Schema.Column? Weight { get; } /// /// The column, when there is exactly one (null otherwise). /// - public ColumnInfo Name { get; } + public Schema.Column? Name { get; } // Maps from role to the associated column infos. - private readonly Dictionary> _map; + private readonly Dictionary> _map; - private RoleMappedSchema(Schema schema, Dictionary> map) + private RoleMappedSchema(Schema schema, Dictionary> map) { Contracts.AssertValue(schema); Contracts.AssertValue(map); @@ -229,7 +166,7 @@ private RoleMappedSchema(Schema schema, Dictionary - public IReadOnlyList GetColumns(ColumnRole role) + public IReadOnlyList GetColumns(ColumnRole role) => _map.TryGetValue(role.Value, out var list) ? list : null; /// /// An enumerable over all role-column associations within this object. /// - public IEnumerable> GetColumnRoles() + public IEnumerable> GetColumnRoles() { foreach (var roleAndList in _map) { foreach (var info in roleAndList.Value) - yield return new KeyValuePair(roleAndList.Key, info); + yield return new KeyValuePair(roleAndList.Key, info); } } @@ -358,13 +293,13 @@ public IEnumerable> GetColumnRoleNames(ColumnRo } /// - /// Returns the corresponding to if there is + /// Returns the corresponding to if there is /// exactly one such mapping, and otherwise throws an exception. 
/// /// The role to look up - /// The info corresponding to that role, assuming there was only one column + /// The column corresponding to that role, assuming there was only one column /// mapped to that - public ColumnInfo GetUniqueColumn(ColumnRole role) + public Schema.Column GetUniqueColumn(ColumnRole role) { var infos = GetColumns(role); if (Utils.Size(infos) != 1) @@ -372,9 +307,9 @@ public ColumnInfo GetUniqueColumn(ColumnRole role) return infos[0]; } - private static Dictionary> Copy(Dictionary> map) + private static Dictionary> Copy(Dictionary> map) { - var copy = new Dictionary>(map.Count); + var copy = new Dictionary>(map.Count); foreach (var kvp in map) { Contracts.Assert(Utils.Size(kvp.Value) > 0); @@ -467,9 +402,10 @@ public RoleMappedSchema(Schema schema, string label, string feature, /// /// Encapsulates an plus a corresponding . /// Note that the schema of of is - /// guaranteed to equal the the of . + /// guaranteed to equal the the of . /// - public sealed class RoleMappedData + [BestFriend] + internal sealed class RoleMappedData { /// /// The data. @@ -478,7 +414,7 @@ public sealed class RoleMappedData /// /// The role mapped schema. Note that 's is - /// guaranteed to be the same as 's . + /// guaranteed to be the same as 's . /// public RoleMappedSchema Schema { get; } diff --git a/src/Microsoft.ML.Core/Data/RootCursorBase.cs b/src/Microsoft.ML.Core/Data/RootCursorBase.cs index 5b64a40d6a..dae6c32888 100644 --- a/src/Microsoft.ML.Core/Data/RootCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/RootCursorBase.cs @@ -2,9 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; - -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { // REVIEW: Since each cursor will create a channel, it would be great that the RootCursorBase takes // ownership of the channel so the derived classes don't have to. 
@@ -15,23 +13,21 @@ namespace Microsoft.ML.Runtime.Data /// This cursor base class returns "this" from . That is, all /// / calls will be seen by this cursor. For a cursor /// that has an input cursor and does NOT need notification on /, - /// use . + /// use . /// [BestFriend] - internal abstract class RootCursorBase : ICursor + internal abstract class RootCursorBase : RowCursor { protected readonly IChannel Ch; + private CursorState _state; + private long _position; /// /// Zero-based position of the cursor. /// - public long Position { get; private set; } - - public abstract long Batch { get; } - - public abstract ValueGetter GetIdGetter(); + public sealed override long Position => _position; - public CursorState State { get; private set; } + public sealed override CursorState State => _state; /// /// Convenience property for checking whether the current state of the cursor is . @@ -39,7 +35,7 @@ internal abstract class RootCursorBase : ICursor protected bool IsGood => State == CursorState.Good; /// - /// Creates an instance of the RootCursorBase class + /// Creates an instance of the class /// /// Channel provider protected RootCursorBase(IChannelProvider provider) @@ -47,21 +43,21 @@ protected RootCursorBase(IChannelProvider provider) Contracts.CheckValue(provider, nameof(provider)); Ch = provider.Start("Cursor"); - Position = -1; - State = CursorState.NotStarted; + _position = -1; + _state = CursorState.NotStarted; } - public virtual void Dispose() + protected override void Dispose(bool disposing) { - if (State != CursorState.Done) - { + if (State == CursorState.Done) + return; + if (disposing) Ch.Dispose(); - Position = -1; - State = CursorState.Done; - } + _position = -1; + _state = CursorState.Done; } - public bool MoveNext() + public sealed override bool MoveNext() { if (State == CursorState.Done) return false; @@ -71,8 +67,8 @@ public bool MoveNext() { Ch.Assert(State == CursorState.NotStarted || State == CursorState.Good); - Position++; - State = 
CursorState.Good; + _position++; + _state = CursorState.Good; return true; } @@ -80,7 +76,7 @@ public bool MoveNext() return false; } - public bool MoveMany(long count) + public sealed override bool MoveMany(long count) { // Note: If we decide to allow count == 0, then we need to special case // that MoveNext() has never been called. It's not entirely clear what the return @@ -95,8 +91,8 @@ public bool MoveMany(long count) { Ch.Assert(State == CursorState.NotStarted || State == CursorState.Good); - Position += count; - State = CursorState.Good; + _position += count; + _state = CursorState.Good; return true; } @@ -137,6 +133,6 @@ protected virtual bool MoveManyCore(long count) /// those on this cursor. Generally, if the root cursor is not the same as this cursor, using /// the root cursor will be faster. /// - public ICursor GetRootCursor() => this; + public override RowCursor GetRootCursor() => this; } } \ No newline at end of file diff --git a/src/Microsoft.ML.Core/Data/UInt128.cs b/src/Microsoft.ML.Core/Data/RowId.cs similarity index 58% rename from src/Microsoft.ML.Core/Data/UInt128.cs rename to src/Microsoft.ML.Core/Data/RowId.cs index 017068a65d..315b1939b1 100644 --- a/src/Microsoft.ML.Core/Data/UInt128.cs +++ b/src/Microsoft.ML.Core/Data/RowId.cs @@ -4,114 +4,123 @@ using System; using System.Runtime.CompilerServices; -using Microsoft.ML.Runtime.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// - /// A sixteen-byte unsigned integer. + /// A structure serving as a sixteen-byte unsigned integer. It is used as the row id of . + /// For datasets with millions of records, those IDs need to be unique, therefore the need for such a large structure to hold the values. + /// Those Ids are derived from other Ids of the previous components of the pipelines, and dividing the structure in two: high order and low order of bits, + /// and reduces the changes of those collisions even further. 
/// - public readonly struct UInt128 : IComparable, IEquatable + /// + public readonly struct RowId : IComparable, IEquatable { - // The low order bits. Corresponds to H1 in the Murmur algorithms. - public readonly ulong Lo; - // The high order bits. Corresponds to H2 in the Murmur algorithms. - public readonly ulong Hi; + ///The low order bits. Corresponds to H1 in the Murmur algorithms. + public readonly ulong Low; - public UInt128(ulong lo, ulong hi) + /// The high order bits. Corresponds to H2 in the Murmur algorithms. + public readonly ulong High; + + /// + /// Initializes a new instance of + /// + /// The low order ulong. + /// The high order ulong. + public RowId(ulong low, ulong high) { - Lo = lo; - Hi = hi; + Low = low; + High = high; } public override string ToString() { // Since H1 are the low order bits, they are printed second. - return string.Format("{0:x16}{1:x16}", Hi, Lo); + return string.Format("{0:x16}{1:x16}", High, Low); } - public int CompareTo(UInt128 other) + public int CompareTo(RowId other) { - int result = Hi.CompareTo(other.Hi); - return result == 0 ? Lo.CompareTo(other.Lo) : result; + int result = High.CompareTo(other.High); + return result == 0 ? 
Low.CompareTo(other.Low) : result; } - public bool Equals(UInt128 other) + public bool Equals(RowId other) { - return Lo == other.Lo && Hi == other.Hi; + return Low == other.Low && High == other.High; } public override bool Equals(object obj) { - if (obj != null && obj is UInt128) + if (obj != null && obj is RowId) { - var item = (UInt128)obj; + var item = (RowId)obj; return Equals(item); } return false; } - public static UInt128 operator +(UInt128 first, ulong second) + public static RowId operator +(RowId first, ulong second) { - ulong resHi = first.Hi; - ulong resLo = first.Lo + second; + ulong resHi = first.High; + ulong resLo = first.Low + second; if (resLo < second) resHi++; - return new UInt128(resLo, resHi); + return new RowId(resLo, resHi); } - public static UInt128 operator -(UInt128 first, ulong second) + public static RowId operator -(RowId first, ulong second) { - ulong resHi = first.Hi; - ulong resLo = first.Lo - second; - if (resLo > first.Lo) + ulong resHi = first.High; + ulong resLo = first.Low - second; + if (resLo > first.Low) resHi--; - return new UInt128(resLo, resHi); + return new RowId(resLo, resHi); } - public static bool operator ==(UInt128 first, ulong second) + public static bool operator ==(RowId first, ulong second) { - return first.Hi == 0 && first.Lo == second; + return first.High == 0 && first.Low == second; } - public static bool operator !=(UInt128 first, ulong second) + public static bool operator !=(RowId first, ulong second) { return !(first == second); } - public static bool operator <(UInt128 first, ulong second) + public static bool operator <(RowId first, ulong second) { - return first.Hi == 0 && first.Lo < second; + return first.High == 0 && first.Low < second; } - public static bool operator >(UInt128 first, ulong second) + public static bool operator >(RowId first, ulong second) { - return first.Hi > 0 || first.Lo > second; + return first.High > 0 || first.Low > second; } - public static bool operator <=(UInt128 first, 
ulong second) + public static bool operator <=(RowId first, ulong second) { - return first.Hi == 0 && first.Lo <= second; + return first.High == 0 && first.Low <= second; } - public static bool operator >=(UInt128 first, ulong second) + public static bool operator >=(RowId first, ulong second) { - return first.Hi > 0 || first.Lo >= second; + return first.High > 0 || first.Low >= second; } - public static explicit operator double(UInt128 x) + public static explicit operator double(RowId x) { // REVIEW: The 64-bit JIT has a bug where rounding might be not quite // correct when converting a ulong to double with the high bit set. Should we // care and compensate? See the DoubleParser code for a work-around. - return x.Hi * ((double)(1UL << 32) * (1UL << 32)) + x.Lo; + return x.High * ((double)(1UL << 32) * (1UL << 32)) + x.Low; } public override int GetHashCode() { return (int)( - (uint)Lo ^ (uint)(Lo >> 32) ^ - (uint)(Hi << 7) ^ (uint)(Hi >> 57) ^ (uint)(Hi >> (57 - 32))); + (uint)Low ^ (uint)(Low >> 32) ^ + (uint)(High << 7) ^ (uint)(High >> 57) ^ (uint)(High >> (57 - 32))); } #region Hashing style @@ -167,10 +176,10 @@ private static void FinalMix(ref ulong h1, ref ulong h2, int len) /// that were all zeros, except for the last bit which is one. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - public UInt128 Fork() + public RowId Fork() { - ulong h1 = Lo; - ulong h2 = Hi; + ulong h1 = Low; + ulong h2 = High; // Here it's as if k1=1, k2=0. h1 = RotL(h1, 27); h1 += h2; @@ -179,7 +188,7 @@ public UInt128 Fork() h2 += h1; h2 = h2 * 5 + 0x38495ab5; h1 ^= RotL(_c1, 31) * _c2; - return new UInt128(h1, h2); + return new RowId(h1, h2); } /// @@ -188,10 +197,10 @@ public UInt128 Fork() /// that were all zeros. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - public UInt128 Next() + public RowId Next() { - ulong h1 = Lo; - ulong h2 = Hi; + ulong h1 = Low; + ulong h2 = High; // Here it's as if k1=0, k2=0. 
h1 = RotL(h1, 27); h1 += h2; @@ -199,7 +208,7 @@ public UInt128 Next() h2 = RotL(h2, 31); h2 += h1; h2 = h2 * 5 + 0x38495ab5; - return new UInt128(h1, h2); + return new RowId(h1, h2); } /// @@ -210,14 +219,14 @@ public UInt128 Next() /// /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - public UInt128 Combine(UInt128 other) + public RowId Combine(RowId other) { - var h1 = Lo; - var h2 = Hi; + var h1 = Low; + var h2 = High; other = other.Fork(); - ulong k1 = other.Lo; // First 8 bytes. - ulong k2 = other.Hi; // Second 8 bytes. + ulong k1 = other.Low; // First 8 bytes. + ulong k2 = other.High; // Second 8 bytes. k1 *= _c1; k1 = RotL(k1, 31); @@ -235,7 +244,7 @@ public UInt128 Combine(UInt128 other) h2 += h1; h2 = h2 * 5 + 0x38495ab5; - return new UInt128(h1, h2); + return new RowId(h1, h2); } #endregion } diff --git a/src/Microsoft.ML.Core/Data/Schema.cs b/src/Microsoft.ML.Core/Data/Schema.cs index 9be55ca48f..0977865544 100644 --- a/src/Microsoft.ML.Core/Data/Schema.cs +++ b/src/Microsoft.ML.Core/Data/Schema.cs @@ -2,19 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Collections; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Data { /// - /// This class represents the schema of an object (like an or an ). + /// This class represents the of an object like, for interstance, an or an . /// On the high level, the schema is a collection of 'columns'. Each column has the following properties: /// - Column name. /// - Column type. @@ -22,16 +20,11 @@ namespace Microsoft.ML.Data /// and values. 
/// [System.Diagnostics.DebuggerTypeProxy(typeof(SchemaDebuggerProxy))] - public sealed class Schema : ISchema, IReadOnlyList + public sealed class Schema : IReadOnlyList { private readonly Column[] _columns; private readonly Dictionary _nameMap; - /// - /// Number of columns in the schema. - /// - public int ColumnCount => _columns.Length; - /// /// Number of columns in the schema. /// @@ -194,7 +187,7 @@ public sealed class Metadata /// public Schema Schema { get; } - public static Metadata Empty { get; } = new Metadata(new Schema(Enumerable.Empty()), new Delegate[0]); + public static Metadata Empty { get; } = new Metadata(new Schema(new Column[0]), new Delegate[0]); /// /// Create a metadata row by supplying the schema columns and the getter delegates for all the values. @@ -249,18 +242,19 @@ public void GetValue(string kind, ref TValue value) GetGetter(column.Value.Index)(ref value); } - public override string ToString() => string.Join(", ", Schema.GetColumns().Select(x => x.column.Name)); + public override string ToString() => string.Join(", ", Schema.Select(x => x.Name)); } /// /// This constructor should only be called by . /// - internal Schema(IEnumerable columns) + /// The input columns. The constructed instance takes ownership of the array. + internal Schema(Column[] columns) { Contracts.CheckValue(columns, nameof(columns)); - _columns = columns.ToArray(); + _columns = columns; _nameMap = new Dictionary(); for (int i = 0; i < _columns.Length; i++) { @@ -269,11 +263,6 @@ internal Schema(IEnumerable columns) } } - /// - /// Get all non-hidden columns as pairs of (index, ). - /// - public IEnumerable<(int index, Column column)> GetColumns() => _nameMap.Values.Select(idx => (idx, _columns[idx])); - /// /// Manufacture an instance of out of any . 
/// @@ -282,9 +271,6 @@ internal static Schema Create(ISchema inputSchema) { Contracts.CheckValue(inputSchema, nameof(inputSchema)); - if (inputSchema is Schema s) - return s; - var builder = new SchemaBuilder(); for (int i = 0; i < inputSchema.ColumnCount; i++) { @@ -309,42 +295,15 @@ private static Delegate GetMetadataGetterDelegate(ISchema schema, int co return getter; } - #region Legacy schema API to be removed - public string GetColumnName(int col) => this[col].Name; - - public ColumnType GetColumnType(int col) => this[col].Type; - - public IEnumerable> GetMetadataTypes(int col) - { - var meta = this[col].Metadata; - if (meta == null) - return Enumerable.Empty>(); - return meta.Schema.GetColumns().Select(c => new KeyValuePair(c.column.Name, c.column.Type)); - } - - public ColumnType GetMetadataTypeOrNull(string kind, int col) - { - var meta = this[col].Metadata; - if (meta == null) - return null; - if (((ISchema)meta.Schema).TryGetColumnIndex(kind, out int metaCol)) - return meta.Schema[metaCol].Type; - return null; - } - - public void GetMetadata(string kind, int col, ref TValue value) - { - var meta = this[col].Metadata; - if (meta == null) - throw MetadataUtils.ExceptGetMetadata(); - meta.GetValue(kind, ref value); - } - - public bool TryGetColumnIndex(string name, out int col) + /// + /// Legacy method to get the column index. + /// DO NOT USE: use instead. + /// + [BestFriend] + internal bool TryGetColumnIndex(string name, out int col) { col = GetColumnOrNull(name)?.Index ?? -1; return col >= 0; } - #endregion } } diff --git a/src/Microsoft.ML.Core/Data/SchemaBuilder.cs b/src/Microsoft.ML.Core/Data/SchemaBuilder.cs index 9fb22a7026..41f109530f 100644 --- a/src/Microsoft.ML.Core/Data/SchemaBuilder.cs +++ b/src/Microsoft.ML.Core/Data/SchemaBuilder.cs @@ -2,10 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime.Data; -using System; using System.Collections.Generic; -using System.Text; namespace Microsoft.ML.Data { @@ -30,8 +27,11 @@ public SchemaBuilder() /// The column name. /// The column type. /// The column metadata. - public void AddColumn(string name, ColumnType type, Schema.Metadata metadata) + public void AddColumn(string name, ColumnType type, Schema.Metadata metadata = null) { + Contracts.CheckNonEmpty(name, nameof(name)); + Contracts.CheckValue(type, nameof(type)); + Contracts.CheckValueOrNull(metadata); _items.Add((name, type, metadata)); } diff --git a/src/Microsoft.ML.Core/Data/SchemaDebuggerProxy.cs b/src/Microsoft.ML.Core/Data/SchemaDebuggerProxy.cs index 7e7c40fe03..00fff2d6a5 100644 --- a/src/Microsoft.ML.Core/Data/SchemaDebuggerProxy.cs +++ b/src/Microsoft.ML.Core/Data/SchemaDebuggerProxy.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; using System.Collections.Generic; using System.Linq; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Data { @@ -39,10 +38,10 @@ public MetadataDebuggerProxy(Schema.Metadata metadata) private static List> BuildValues(Schema.Metadata metadata) { var result = new List>(); - foreach ((var index, var column) in metadata.Schema.GetColumns()) + foreach (var column in metadata.Schema) { var name = column.Name; - var value = Utils.MarshalInvoke(GetValue, column.Type.RawType, metadata, index); + var value = Utils.MarshalInvoke(GetValue, column.Type.RawType, metadata, column.Index); result.Add(new KeyValuePair(name, value)); } return result; diff --git a/src/Microsoft.ML.Core/Data/ServerChannel.cs b/src/Microsoft.ML.Core/Data/ServerChannel.cs index a9b33d1986..c85be7eb82 100644 --- a/src/Microsoft.ML.Core/Data/ServerChannel.cs +++ b/src/Microsoft.ML.Core/Data/ServerChannel.cs @@ -4,11 +4,9 @@ using 
System; using System.Collections.Generic; -using System.Reflection; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.EntryPoints; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { /// /// Instances of this class are used to set up a bundle of named delegates. These diff --git a/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs b/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs index da60c84ccf..e8b94e7dbd 100644 --- a/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Base class for creating a cursor on top of another cursor that does not add or remove rows. @@ -11,28 +11,27 @@ namespace Microsoft.ML.Runtime.Data /// Dispose is virtual with the default implementation delegating to the input cursor. /// [BestFriend] - internal abstract class SynchronizedCursorBase : ICursor - where TBase : class, ICursor + internal abstract class SynchronizedCursorBase : RowCursor { protected readonly IChannel Ch; - private readonly ICursor _root; + private readonly RowCursor _root; private bool _disposed; - protected TBase Input { get; } + protected RowCursor Input { get; } - public long Position => _root.Position; + public sealed override long Position => _root.Position; - public long Batch => _root.Batch; + public sealed override long Batch => _root.Batch; - public CursorState State => _root.State; + public sealed override CursorState State => _root.State; /// /// Convenience property for checking whether the current state is CursorState.Good. 
/// protected bool IsGood => _root.State == CursorState.Good; - protected SynchronizedCursorBase(IChannelProvider provider, TBase input) + protected SynchronizedCursorBase(IChannelProvider provider, RowCursor input) { Contracts.AssertValue(provider, "provider"); Ch = provider.Start("Cursor"); @@ -42,34 +41,25 @@ protected SynchronizedCursorBase(IChannelProvider provider, TBase input) _root = Input.GetRootCursor(); } - public virtual void Dispose() + protected override void Dispose(bool disposing) { - if (!_disposed) + if (_disposed) + return; + if (disposing) { Input.Dispose(); Ch.Dispose(); - _disposed = true; } + base.Dispose(disposing); + _disposed = true; } - public bool MoveNext() - { - return _root.MoveNext(); - } + public sealed override bool MoveNext() => _root.MoveNext(); - public bool MoveMany(long count) - { - return _root.MoveMany(count); - } + public sealed override bool MoveMany(long count) => _root.MoveMany(count); - public ICursor GetRootCursor() - { - return _root; - } + public sealed override RowCursor GetRootCursor() => _root; - public ValueGetter GetIdGetter() - { - return Input.GetIdGetter(); - } + public sealed override ValueGetter GetIdGetter() => Input.GetIdGetter(); } } diff --git a/src/Microsoft.ML.Core/Data/VBuffer.cs b/src/Microsoft.ML.Core/Data/VBuffer.cs index a86f0bdae4..949857bc2b 100644 --- a/src/Microsoft.ML.Core/Data/VBuffer.cs +++ b/src/Microsoft.ML.Core/Data/VBuffer.cs @@ -4,9 +4,9 @@ using System; using System.Collections.Generic; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// A buffer that supports both dense and sparse representations. 
This is the diff --git a/src/Microsoft.ML.Core/Data/VBufferEditor.cs b/src/Microsoft.ML.Core/Data/VBufferEditor.cs index 8da19b641f..1dafac0fa5 100644 --- a/src/Microsoft.ML.Core/Data/VBufferEditor.cs +++ b/src/Microsoft.ML.Core/Data/VBufferEditor.cs @@ -4,7 +4,7 @@ using System; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Various methods for creating instances. diff --git a/src/Microsoft.ML.Core/Data/WrappingRow.cs b/src/Microsoft.ML.Core/Data/WrappingRow.cs new file mode 100644 index 0000000000..f6568cf205 --- /dev/null +++ b/src/Microsoft.ML.Core/Data/WrappingRow.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.Data +{ + /// + /// Convenient base class for implementors that wrap a single + /// as their input. The , , and + /// are taken from this . + /// + [BestFriend] + internal abstract class WrappingRow : Row + { + private bool _disposed; + + /// + /// The wrapped input row. + /// + protected Row Input { get; } + + public sealed override long Batch => Input.Batch; + public sealed override long Position => Input.Position; + public override ValueGetter GetIdGetter() => Input.GetIdGetter(); + + [BestFriend] + private protected WrappingRow(Row input) + { + Contracts.AssertValue(input); + Input = input; + } + + /// + /// This override of the dispose method by default only calls 's + /// method, but subclasses can enable additional functionality + /// via the functionality. + /// + /// + protected sealed override void Dispose(bool disposing) + { + if (_disposed) + return; + // Since the input was created first, and this instance may depend on it, we should + // dispose local resources first before potentially disposing the input row resources. 
+ DisposeCore(disposing); + if (disposing) + Input.Dispose(); + _disposed = true; + } + + /// + /// Called from with in the case where + /// that method has never been called before, and right after has been + /// disposed. The default implementation does nothing. + /// + /// Whether this was called through the dispose path, as opposed + /// to the finalizer path. + protected virtual void DisposeCore(bool disposing) + { + } + } +} diff --git a/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs b/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs index 0163222fc1..175e7cb09a 100644 --- a/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs +++ b/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs @@ -4,7 +4,7 @@ using System; -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { /// /// This is a signature for classes that are 'holders' of entry points and components. diff --git a/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs b/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs index a7c3ddd298..82edc2bfaa 100644 --- a/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs +++ b/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs @@ -2,13 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Data; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Linq; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { [BestFriend] internal static class EntryPointUtils diff --git a/src/Microsoft.ML.Core/EntryPoints/IMlState.cs b/src/Microsoft.ML.Core/EntryPoints/IMlState.cs deleted file mode 100644 index 41ea062861..0000000000 --- a/src/Microsoft.ML.Core/EntryPoints/IMlState.cs +++ /dev/null @@ -1,14 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -namespace Microsoft.ML.Runtime.EntryPoints -{ - /// - /// Dummy interface to allow reference to the AutoMlState object in the C# API (since AutoMlState - /// has things that reference C# API, leading to circular dependency). Makes state object an opaque - /// black box to the graph. The macro itself will then case to the concrete type. - /// - public interface IMlState - { } -} \ No newline at end of file diff --git a/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs b/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs index d538f636ce..1308008c6e 100644 --- a/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs +++ b/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs @@ -4,15 +4,11 @@ using System; using System.Collections.Generic; -using System.Diagnostics; using System.Linq; -using System.Net.Sockets; -using System.Reflection; using System.Text; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { /// /// This class defines attributes to annotate module inputs, outputs, entry points etc. 
when defining @@ -577,11 +573,11 @@ public enum DataKind /// FileHandle, /// - /// A transform model, represented by an . + /// A transform model, represented by an . /// TransformModel, /// - /// A predictor model, represented by an . + /// A predictor model, represented by an . /// PredictorModel, /// @@ -602,11 +598,7 @@ public enum DataKind /// optionally, a set of parameters, unique to each component. Example: "BinaryClassifierEvaluator{threshold=0.5}". /// The C# representation is . /// - Component, - /// - /// An C# object that represents state, such as . - /// - State + Component } public static DataKind GetDataType(Type type) @@ -632,9 +624,9 @@ public static DataKind GetDataType(Type type) return DataKind.Float; if (typeof(IDataView).IsAssignableFrom(type)) return DataKind.DataView; - if (typeof(ITransformModel).IsAssignableFrom(type)) + if (typeof(TransformModel).IsAssignableFrom(type)) return DataKind.TransformModel; - if (typeof(IPredictorModel).IsAssignableFrom(type)) + if (typeof(PredictorModel).IsAssignableFrom(type)) return DataKind.PredictorModel; if (typeof(IFileHandle).IsAssignableFrom(type)) return DataKind.FileHandle; @@ -649,8 +641,6 @@ public static DataKind GetDataType(Type type) } if (typeof(IComponentFactory).IsAssignableFrom(type)) return DataKind.Component; - if (typeof(IMlState).IsAssignableFrom(type)) - return DataKind.State; return DataKind.Unknown; } @@ -673,7 +663,7 @@ public abstract class Optional public abstract object GetValue(); - protected Optional(bool isExplicit) + private protected Optional(bool isExplicit) { IsExplicit = isExplicit; } diff --git a/src/Microsoft.ML.Core/EntryPoints/IPredictorModel.cs b/src/Microsoft.ML.Core/EntryPoints/PredictorModel.cs similarity index 58% rename from src/Microsoft.ML.Core/EntryPoints/IPredictorModel.cs rename to src/Microsoft.ML.Core/EntryPoints/PredictorModel.cs index 0ec1151134..30872a5faa 100644 --- a/src/Microsoft.ML.Core/EntryPoints/IPredictorModel.cs +++ 
b/src/Microsoft.ML.Core/EntryPoints/PredictorModel.cs @@ -2,42 +2,51 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.IO; -using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { /// - /// Interface for standard predictor model port type. + /// Base type for standard predictor model port type. /// - public interface IPredictorModel + public abstract class PredictorModel { + [BestFriend] + private protected PredictorModel() + { + } + /// /// Save the model to the given stream. /// - void Save(IHostEnvironment env, Stream stream); + [BestFriend] + internal abstract void Save(IHostEnvironment env, Stream stream); /// /// Extract only the transform portion of the predictor model. /// - ITransformModel TransformModel { get; } + [BestFriend] + internal abstract TransformModel TransformModel { get; } /// /// Extract the predictor object out of the predictor model. /// - IPredictor Predictor { get; } + [BestFriend] + internal abstract IPredictor Predictor { get; } /// /// Apply the predictor model to the transform model and return the resulting predictor model. /// - IPredictorModel Apply(IHostEnvironment env, ITransformModel transformModel); + [BestFriend] + internal abstract PredictorModel Apply(IHostEnvironment env, TransformModel transformModel); /// /// For a given input data, return role mapped data and the predictor object. /// The scoring entry point will hopefully know how to construct a scorer out of them. 
/// - void PrepareData(IHostEnvironment env, IDataView input, out RoleMappedData roleMappedData, out IPredictor predictor); + [BestFriend] + internal abstract void PrepareData(IHostEnvironment env, IDataView input, out RoleMappedData roleMappedData, out IPredictor predictor); /// /// Returns a string array containing the label names of the label column type predictor was trained on. @@ -46,13 +55,13 @@ public interface IPredictorModel /// /// /// The column type of the label the predictor was trained on. - string[] GetLabelInfo(IHostEnvironment env, out ColumnType labelType); + [BestFriend] + internal abstract string[] GetLabelInfo(IHostEnvironment env, out ColumnType labelType); /// - /// Returns the RoleMappedSchema that was used in training. + /// Returns the that was used in training. /// - /// - /// - RoleMappedSchema GetTrainingSchema(IHostEnvironment env); + [BestFriend] + internal abstract RoleMappedSchema GetTrainingSchema(IHostEnvironment env); } } \ No newline at end of file diff --git a/src/Microsoft.ML.Core/Data/ITransformModel.cs b/src/Microsoft.ML.Core/EntryPoints/TransformModel.cs similarity index 66% rename from src/Microsoft.ML.Core/Data/ITransformModel.cs rename to src/Microsoft.ML.Core/EntryPoints/TransformModel.cs index 37ea9c1a2d..110c75c7aa 100644 --- a/src/Microsoft.ML.Core/Data/ITransformModel.cs +++ b/src/Microsoft.ML.Core/EntryPoints/TransformModel.cs @@ -2,58 +2,67 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.IO; using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Data; -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { /// /// Interface for standard transform model port type. 
/// - public interface ITransformModel + public abstract class TransformModel { + [BestFriend] + private protected TransformModel() + { + } + /// /// The input schema that this transform model was originally instantiated on. /// Note that the schema may have columns that aren't needed by this transform model. - /// If an IDataView exists with this schema, then applying this transform model to it + /// If an exists with this schema, then applying this transform model to it /// shouldn't fail because of column type issues. /// // REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note // however that doing so may cause issues for composing transform models. For example, // if transform model A needs column X and model B needs Y, that is NOT produced by A, // then trimming A's input schema would cause composition to fail. - Schema InputSchema { get; } + [BestFriend] + internal abstract Schema InputSchema { get; } /// /// The output schema that this transform model was originally instantiated on. The schema resulting - /// from may differ from this, similarly to how + /// from may differ from this, similarly to how /// may differ from the schema of dataviews we apply this transform model to. /// - Schema OutputSchema { get; } + [BestFriend] + internal abstract Schema OutputSchema { get; } /// /// Apply the transform(s) in the model to the given input data. /// - IDataView Apply(IHostEnvironment env, IDataView input); + [BestFriend] + internal abstract IDataView Apply(IHostEnvironment env, IDataView input); /// /// Apply the transform(s) in the model to the given transform model. /// - ITransformModel Apply(IHostEnvironment env, ITransformModel input); + [BestFriend] + internal abstract TransformModel Apply(IHostEnvironment env, TransformModel input); /// /// Save the model to the given stream. 
/// - void Save(IHostEnvironment env, Stream stream); + [BestFriend] + internal abstract void Save(IHostEnvironment env, Stream stream); /// /// Returns the transform model as an that can output a row /// given a row with the same schema as . /// /// The transform model as an . If not all transforms - /// in the pipeline are then it returns null. - IRowToRowMapper AsRowToRowMapper(IExceptionContext ectx); + /// in the pipeline are then it returns . + [BestFriend] + internal abstract IRowToRowMapper AsRowToRowMapper(IExceptionContext ectx); } } diff --git a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs index 3e27ce2516..057969c81f 100644 --- a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs +++ b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs @@ -9,7 +9,7 @@ using System.Linq; using System.Threading; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { using Stopwatch = System.Diagnostics.Stopwatch; @@ -368,7 +368,7 @@ public ConsoleEnvironment(int? seed = null, bool verbose = false, /// Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically. /// Text writer to print normal messages to. /// Text writer to print error messages to. - private ConsoleEnvironment(IRandom rand, bool verbose = false, + private ConsoleEnvironment(Random rand, bool verbose = false, MessageSensitivity sensitivity = MessageSensitivity.All, int conc = 0, TextWriter outWriter = null, TextWriter errWriter = null) : base(rand, verbose, conc, nameof(ConsoleEnvironment)) @@ -401,7 +401,7 @@ protected override IFileHandle CreateTempFileCore(IHostEnvironment env, string s return base.CreateTempFileCore(env, suffix, "TLC_" + prefix); } - protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, IRandom rand, bool verbose, int? 
conc) + protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose, int? conc) { Contracts.AssertValue(rand); Contracts.AssertValueOrNull(parentFullName); @@ -472,7 +472,7 @@ public void Dispose() private sealed class Host : HostBase { - public Host(HostEnvironmentBase source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc) + public Host(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose, int? conc) : base(source, shortName, parentFullName, rand, verbose, conc) { IsCancelled = source.IsCancelled; @@ -494,7 +494,7 @@ protected override IPipe CreatePipe(ChannelProviderBase pare return new Pipe(parent, name, GetDispatchDelegate()); } - protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc) + protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose, int? conc) { return new Host(source, shortName, parentFullName, rand, verbose, conc); } diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs index 31ab23e28f..deaf372cb9 100644 --- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs +++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs @@ -5,11 +5,10 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.ComponentModel.Composition; using System.ComponentModel.Composition.Hosting; using System.IO; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Base class for channel providers. This is a common base class for. 
@@ -107,12 +106,12 @@ public abstract class HostBase : HostEnvironmentBase, IHost { public override int Depth { get; } - public IRandom Rand => _rand; + public Random Rand => _rand; // We don't have dispose mechanism for hosts, so to let GC collect children hosts we make them WeakReference. private readonly List> _children; - public HostBase(HostEnvironmentBase source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc) + public HostBase(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose, int? conc) : base(source, rand, verbose, conc, shortName, parentFullName) { Depth = source.Depth + 1; @@ -139,7 +138,7 @@ public void StopExecution() IHost host; lock (_cancelLock) { - IRandom rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand); + Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand); host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose, conc ?? _conc); if (!IsCancelled) _children.Add(new WeakReference(host)); @@ -342,7 +341,7 @@ public void RemoveListener(Action listenerFunc) private readonly object _cancelLock; // The random number generator for this host. - private readonly IRandom _rand; + private readonly Random _rand; // A dictionary mapping the type of message to the Dispatcher that gets the strongly typed dispatch delegate. protected readonly ConcurrentDictionary ListenerDict; @@ -367,7 +366,7 @@ public void RemoveListener(Action listenerFunc) /// /// The main constructor. /// - protected HostEnvironmentBase(IRandom rand, bool verbose, int conc, + protected HostEnvironmentBase(Random rand, bool verbose, int conc, string shortName = null, string parentFullName = null) : base(shortName, parentFullName, verbose) { @@ -386,7 +385,7 @@ protected HostEnvironmentBase(IRandom rand, bool verbose, int conc, /// /// This constructor is for forking. 
/// - protected HostEnvironmentBase(HostEnvironmentBase source, IRandom rand, bool verbose, + protected HostEnvironmentBase(HostEnvironmentBase source, Random rand, bool verbose, int? conc, string shortName = null, string parentFullName = null) : base(shortName, parentFullName, verbose) { @@ -433,12 +432,12 @@ public virtual void Dispose() public IHost Register(string name, int? seed = null, bool? verbose = null, int? conc = null) { Contracts.CheckNonEmpty(name, nameof(name)); - IRandom rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand); + Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand); return RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose, conc); } protected abstract IHost RegisterCore(HostEnvironmentBase source, string shortName, - string parentFullName, IRandom rand, bool verbose, int? conc); + string parentFullName, Random rand, bool verbose, int? conc); public IFileHandle OpenInputFile(string path) { diff --git a/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs b/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs index 72b08e2715..5501a99090 100644 --- a/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs +++ b/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs @@ -4,11 +4,8 @@ using System; using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { /// /// A telemetry message. diff --git a/src/Microsoft.ML.Core/Prediction/IPredictor.cs b/src/Microsoft.ML.Core/Prediction/IPredictor.cs index 6bd1ac2056..682cec417d 100644 --- a/src/Microsoft.ML.Core/Prediction/IPredictor.cs +++ b/src/Microsoft.ML.Core/Prediction/IPredictor.cs @@ -2,9 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; - -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { /// /// Type of prediction task diff --git a/src/Microsoft.ML.Core/Prediction/ITrainer.cs b/src/Microsoft.ML.Core/Prediction/ITrainer.cs index 5e796aa602..ad5a0b2539 100644 --- a/src/Microsoft.ML.Core/Prediction/ITrainer.cs +++ b/src/Microsoft.ML.Core/Prediction/ITrainer.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.Data; -using System; +using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { // REVIEW: Would be nice if the registration under SignatureTrainer were automatic // given registration for one of the "sub-class" signatures. diff --git a/src/Microsoft.ML.Core/Prediction/ITree.cs b/src/Microsoft.ML.Core/Prediction/ITree.cs index 67642ecfc5..9b07acdef7 100644 --- a/src/Microsoft.ML.Core/Prediction/ITree.cs +++ b/src/Microsoft.ML.Core/Prediction/ITree.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.Collections.Generic; -namespace Microsoft.ML.Runtime.TreePredictor +namespace Microsoft.ML.TreePredictor { // The interfaces contained herein are meant to allow tree visualizer to run without an explicit dependency // on FastTree, so as to allow it greater generality. These should probably be moved somewhere else, but where? diff --git a/src/Microsoft.ML.Core/Prediction/TrainContext.cs b/src/Microsoft.ML.Core/Prediction/TrainContext.cs index e5e4bbad1c..ff37caf598 100644 --- a/src/Microsoft.ML.Core/Prediction/TrainContext.cs +++ b/src/Microsoft.ML.Core/Prediction/TrainContext.cs @@ -2,9 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { /// /// Holds information relevant to trainers. Instances of this class are meant to be constructed and passed diff --git a/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs b/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs index 4f97c0c893..53398bc651 100644 --- a/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs +++ b/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { /// /// Instances of this class posses information about trainers, in terms of their requirements and capabilities. diff --git a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs index a61e96ce7b..58bda21829 100644 --- a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs @@ -9,10 +9,12 @@ [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Sweeper.Tests" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.InferenceTesting" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipelineTesting" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransformTest" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.EntryPoints" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: 
"Microsoft.ML.Maml" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.ResultProcessor" + PublicKey.Value)] @@ -31,7 +33,7 @@ [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.PCA" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.PipelineInference" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Recommender" + PublicKey.Value)] -[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Runtime.ImageAnalytics" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.ImageAnalytics" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Scoring" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StandardLearners" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Sweeper" + PublicKey.Value)] @@ -39,4 +41,10 @@ [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Transforms" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipe" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TensorFlow.StaticPipe" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.HalLearners.StaticPipe" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransform.StaticPipe" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.LightGBM.StaticPipe" + PublicKey.Value)] + [assembly: WantsToBeBestFriends] diff --git a/src/Microsoft.ML.Core/PublicKey.cs b/src/Microsoft.ML.Core/PublicKey.cs index 9a944c3d18..63718c3f8e 100644 --- a/src/Microsoft.ML.Core/PublicKey.cs +++ b/src/Microsoft.ML.Core/PublicKey.cs @@ -7,8 +7,8 @@ namespace Microsoft.ML #else // CpuMath module has its own PublicKey for isolating itself from Microsoft.ML.Core -// Note that CpuMath uses its own 
BestFriend defined in Microsoft.ML.Runtime.Internal.CpuMath.Core. -namespace Microsoft.ML.Runtime.Internal.CpuMath.Core +// Note that CpuMath uses its own BestFriend defined in Microsoft.ML.Internal.CpuMath.Core. +namespace Microsoft.ML.Internal.CpuMath.Core #endif { [BestFriend] diff --git a/src/Microsoft.ML.Core/Utilities/BigArray.cs b/src/Microsoft.ML.Core/Utilities/BigArray.cs index 3bfb4f688f..bda630cbc0 100644 --- a/src/Microsoft.ML.Core/Utilities/BigArray.cs +++ b/src/Microsoft.ML.Core/Utilities/BigArray.cs @@ -6,7 +6,7 @@ using System.Collections; using System.Collections.Generic; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { /// /// An array-like data structure that supports storing more than diff --git a/src/Microsoft.ML.Core/Utilities/BinFinder.cs b/src/Microsoft.ML.Core/Utilities/BinFinder.cs index cdfd0ad08b..d5456e2956 100644 --- a/src/Microsoft.ML.Core/Utilities/BinFinder.cs +++ b/src/Microsoft.ML.Core/Utilities/BinFinder.cs @@ -2,12 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - using System; using System.Collections.Generic; +using Float = System.Single; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { [BestFriend] internal abstract class BinFinderBase @@ -273,7 +272,7 @@ public static Double GetSplitValue(Double a, Double b) } } -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { // This needs to be large enough to represent a product of 2 ints without losing precision using EnergyType = System.Int64; @@ -525,7 +524,7 @@ private void UpdatePeg(Peg peg) } } -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { // Reasonable choices are Double and System.Int64. 
using EnergyType = System.Double; diff --git a/src/Microsoft.ML.Core/Utilities/BitUtils.cs b/src/Microsoft.ML.Core/Utilities/BitUtils.cs index 376b0ac8ef..b38b4ad01a 100644 --- a/src/Microsoft.ML.Core/Utilities/BitUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/BitUtils.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.Runtime.CompilerServices; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { internal static partial class Utils { diff --git a/src/Microsoft.ML.Core/Utilities/CharUtils.cs b/src/Microsoft.ML.Core/Utilities/CharUtils.cs index d88197c8e7..398b97f851 100644 --- a/src/Microsoft.ML.Core/Utilities/CharUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/CharUtils.cs @@ -4,11 +4,10 @@ #pragma warning disable 420 // volatile with Interlocked.CompareExchange -using System; using System.Runtime.CompilerServices; using System.Threading; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { [BestFriend] internal static class CharUtils diff --git a/src/Microsoft.ML.Core/Utilities/CmdIndenter.cs b/src/Microsoft.ML.Core/Utilities/CmdIndenter.cs index 6b4dbc48db..da6e8b821f 100644 --- a/src/Microsoft.ML.Core/Utilities/CmdIndenter.cs +++ b/src/Microsoft.ML.Core/Utilities/CmdIndenter.cs @@ -2,15 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; using System.CodeDom.Compiler; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; -using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.CommandLine; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { [BestFriend] internal static class CmdIndenter diff --git a/src/Microsoft.ML.Core/Utilities/Contracts.cs b/src/Microsoft.ML.Core/Utilities/Contracts.cs index d567d6e883..6fc94f6304 100644 --- a/src/Microsoft.ML.Core/Utilities/Contracts.cs +++ b/src/Microsoft.ML.Core/Utilities/Contracts.cs @@ -17,9 +17,9 @@ using System.Threading; #if CPUMATH_INFRASTRUCTURE -namespace Microsoft.ML.Runtime.Internal.CpuMath.Core +namespace Microsoft.ML.Internal.CpuMath.Core #else -namespace Microsoft.ML.Runtime +namespace Microsoft.ML #endif { using Conditional = System.Diagnostics.ConditionalAttribute; @@ -758,6 +758,10 @@ public static void CheckAlive(this IHostEnvironment env) public static void CheckValueOrNull(T val) where T : class { } + + /// + /// This documents that the parameter can legally be null. 
+ /// [Conditional("INVARIANT_CHECKS")] public static void CheckValueOrNull(this IExceptionContext ctx, T val) where T : class { diff --git a/src/Microsoft.ML.Core/Utilities/DoubleParser.cs b/src/Microsoft.ML.Core/Utilities/DoubleParser.cs index 4f2ef9ad80..695c1981be 100644 --- a/src/Microsoft.ML.Core/Utilities/DoubleParser.cs +++ b/src/Microsoft.ML.Core/Utilities/DoubleParser.cs @@ -4,14 +4,9 @@ #undef COMPARE_BCL -using Float = System.Single; - using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { [BestFriend] internal static class DoubleParser diff --git a/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs b/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs index 7263304046..ab2d938d22 100644 --- a/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs +++ b/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs @@ -2,9 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; - -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { using Conditional = System.Diagnostics.ConditionalAttribute; diff --git a/src/Microsoft.ML.Core/Utilities/FloatUtils.cs b/src/Microsoft.ML.Core/Utilities/FloatUtils.cs index 06d403da9a..b93541eb3e 100644 --- a/src/Microsoft.ML.Core/Utilities/FloatUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/FloatUtils.cs @@ -6,7 +6,7 @@ using System.Globalization; using System.Runtime.InteropServices; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { [BestFriend] internal static class FloatUtils diff --git a/src/Microsoft.ML.Core/Utilities/HashArray.cs b/src/Microsoft.ML.Core/Utilities/HashArray.cs index 64dced9792..c9f31b5361 100644 --- a/src/Microsoft.ML.Core/Utilities/HashArray.cs +++ b/src/Microsoft.ML.Core/Utilities/HashArray.cs @@ -6,7 +6,7 @@ using System; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { // REVIEW: May want to add an IEnumerable>. diff --git a/src/Microsoft.ML.Core/Utilities/Hashing.cs b/src/Microsoft.ML.Core/Utilities/Hashing.cs index ae36fae95d..8345a319ac 100644 --- a/src/Microsoft.ML.Core/Utilities/Hashing.cs +++ b/src/Microsoft.ML.Core/Utilities/Hashing.cs @@ -3,11 +3,10 @@ // See the LICENSE file in the project root for more information. 
using System; -using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Text; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { [BestFriend] internal static class Hashing diff --git a/src/Microsoft.ML.Core/Utilities/Heap.cs b/src/Microsoft.ML.Core/Utilities/Heap.cs index 7652163f10..fe1178abd3 100644 --- a/src/Microsoft.ML.Core/Utilities/Heap.cs +++ b/src/Microsoft.ML.Core/Utilities/Heap.cs @@ -5,7 +5,7 @@ using System; using System.Collections.Generic; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { using Conditional = System.Diagnostics.ConditionalAttribute; diff --git a/src/Microsoft.ML.Core/Utilities/HybridMemoryStream.cs b/src/Microsoft.ML.Core/Utilities/HybridMemoryStream.cs index 02f713dd6e..c1db619510 100644 --- a/src/Microsoft.ML.Core/Utilities/HybridMemoryStream.cs +++ b/src/Microsoft.ML.Core/Utilities/HybridMemoryStream.cs @@ -5,7 +5,7 @@ using System; using System.IO; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { using Conditional = System.Diagnostics.ConditionalAttribute; diff --git a/src/Microsoft.ML.Core/Utilities/IndentedTextWriterExtensions.cs b/src/Microsoft.ML.Core/Utilities/IndentedTextWriterExtensions.cs index fe8fd12e96..fd733dc60c 100644 --- a/src/Microsoft.ML.Core/Utilities/IndentedTextWriterExtensions.cs +++ b/src/Microsoft.ML.Core/Utilities/IndentedTextWriterExtensions.cs @@ -4,10 +4,8 @@ using System; using System.CodeDom.Compiler; -using System.IO; -using System.Text; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { [BestFriend] internal static class IndentedTextWriterExtensions diff --git a/src/Microsoft.ML.Core/Utilities/LineParser.cs b/src/Microsoft.ML.Core/Utilities/LineParser.cs index 73d9bb6158..bc8eabcdda 100644 --- a/src/Microsoft.ML.Core/Utilities/LineParser.cs +++ b/src/Microsoft.ML.Core/Utilities/LineParser.cs 
@@ -5,7 +5,7 @@ using System; using System.Runtime.CompilerServices; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { [BestFriend] internal static class LineParser diff --git a/src/Microsoft.ML.Core/Utilities/LruCache.cs b/src/Microsoft.ML.Core/Utilities/LruCache.cs index e041efde02..ad874b60c1 100644 --- a/src/Microsoft.ML.Core/Utilities/LruCache.cs +++ b/src/Microsoft.ML.Core/Utilities/LruCache.cs @@ -5,7 +5,7 @@ using System.Collections.Generic; using System.Linq; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { /// /// Implements a least recently used cache. diff --git a/src/Microsoft.ML.Core/Utilities/MathUtils.cs b/src/Microsoft.ML.Core/Utilities/MathUtils.cs index 4ce1bb4cf0..7ab349da12 100644 --- a/src/Microsoft.ML.Core/Utilities/MathUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/MathUtils.cs @@ -2,13 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - using System; using System.Collections.Generic; -using System.Linq; +using Float = System.Single; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { /// /// Some useful math methods. @@ -507,7 +505,15 @@ public static Float Tanh(Float x) /// public static Float SigmoidSlow(Float x) { - return 1 / (1 + ExpSlow(-x)); + // The following two expressions are mathematically equivalent. Due to the potential of getting overflow we should + // not call exp(x) for large positive x: instead, we modify the expression to compute exp(-x). 
+ if (x > 0) + return 1 / (1 + ExpSlow(-x)); + else + { + var ex = ExpSlow(x); + return ex / (1 + ex); + } } /// @@ -726,7 +732,7 @@ public static Float GetMedianInPlace(Float[] src, int count) return (src[iv - 1] + src[iv]) / 2; } - public static Double CosineSimilarity(Float[] a, Float[] b, int aIdx, int bIdx, int len) + public static Double CosineSimilarity(ReadOnlySpan a, ReadOnlySpan b, int aIdx, int bIdx, int len) { const Double epsilon = 1e-12f; Contracts.Assert(len > 0); diff --git a/src/Microsoft.ML.Core/Utilities/MatrixTransposeOps.cs b/src/Microsoft.ML.Core/Utilities/MatrixTransposeOps.cs index cb945e56a2..2e25599c1d 100644 --- a/src/Microsoft.ML.Core/Utilities/MatrixTransposeOps.cs +++ b/src/Microsoft.ML.Core/Utilities/MatrixTransposeOps.cs @@ -7,7 +7,7 @@ using System.Linq; using System.Threading.Tasks; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { [BestFriend] internal static class MatrixTransposeOps diff --git a/src/Microsoft.ML.Core/Utilities/MinWaiter.cs b/src/Microsoft.ML.Core/Utilities/MinWaiter.cs index fbaf8fb6d6..8a671ab817 100644 --- a/src/Microsoft.ML.Core/Utilities/MinWaiter.cs +++ b/src/Microsoft.ML.Core/Utilities/MinWaiter.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.Threading; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { /// /// A synchronization primitive meant to address situations where you have a set of diff --git a/src/Microsoft.ML.Core/Utilities/NormStr.cs b/src/Microsoft.ML.Core/Utilities/NormStr.cs index c79e2425d1..9859ea2724 100644 --- a/src/Microsoft.ML.Core/Utilities/NormStr.cs +++ b/src/Microsoft.ML.Core/Utilities/NormStr.cs @@ -3,14 +3,12 @@ // See the LICENSE file in the project root for more information. 
using System; -using System.Collections.Generic; using System.Collections; +using System.Collections.Generic; using System.Linq; -using System.Threading; using System.Text; -using Microsoft.ML.Runtime.Data; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { using Conditional = System.Diagnostics.ConditionalAttribute; diff --git a/src/Microsoft.ML.Core/Utilities/ObjectPool.cs b/src/Microsoft.ML.Core/Utilities/ObjectPool.cs index a06202af76..a5a9591fae 100644 --- a/src/Microsoft.ML.Core/Utilities/ObjectPool.cs +++ b/src/Microsoft.ML.Core/Utilities/ObjectPool.cs @@ -6,7 +6,7 @@ using System.Collections.Concurrent; using System.Threading; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { [BestFriend] internal sealed class ObjectPool : ObjectPoolBase where T : class, new() diff --git a/src/Microsoft.ML.Core/Utilities/OrderedWaiter.cs b/src/Microsoft.ML.Core/Utilities/OrderedWaiter.cs index ca0ef23445..5a71c6205d 100644 --- a/src/Microsoft.ML.Core/Utilities/OrderedWaiter.cs +++ b/src/Microsoft.ML.Core/Utilities/OrderedWaiter.cs @@ -5,7 +5,7 @@ using System; using System.Threading; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { /// /// The primary use case for this structure is to impose ordering among diff --git a/src/Microsoft.ML.Core/Utilities/PathUtils.cs b/src/Microsoft.ML.Core/Utilities/PathUtils.cs index 98407e24a1..c10e36b1eb 100644 --- a/src/Microsoft.ML.Core/Utilities/PathUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/PathUtils.cs @@ -6,7 +6,7 @@ using System.IO; using System.Threading; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { internal static partial class Utils { diff --git a/src/Microsoft.ML.Core/Utilities/PlatformUtils.cs b/src/Microsoft.ML.Core/Utilities/PlatformUtils.cs index 2bd8acab3e..16cadbdb5b 100644 --- a/src/Microsoft.ML.Core/Utilities/PlatformUtils.cs +++ 
b/src/Microsoft.ML.Core/Utilities/PlatformUtils.cs @@ -6,7 +6,7 @@ using System.Collections.ObjectModel; using System.Reflection; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { /// /// Contains extension methods that aid in building cross platform. diff --git a/src/Microsoft.ML.Core/Utilities/Random.cs b/src/Microsoft.ML.Core/Utilities/Random.cs index d5bbf2d4ec..c79a22426b 100644 --- a/src/Microsoft.ML.Core/Utilities/Random.cs +++ b/src/Microsoft.ML.Core/Utilities/Random.cs @@ -4,48 +4,46 @@ using System; using System.IO; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { - public interface IRandom + [BestFriend] + internal static class RandomUtils { - /// - /// Generates a Single in the range [0, 1). - /// - Single NextSingle(); - - /// - /// Generates a Double in the range [0, 1). - /// - Double NextDouble(); - - /// - /// Generates an int in the range [0, int.MaxValue]. Note that this differs - /// from the contract for System.Random.Next, which claims to never return - /// int.MaxValue. - /// - int Next(); - - /// - /// Generates an int in the range [int.MinValue, int.MaxValue]. - /// - int NextSigned(); - - /// - /// Generates an int in the range [0, limit), unless limit == 0, in which case this advances the generator - /// and returns 0. - /// Throws if limit is less than 0. - /// - int Next(int limit); - } + public static float NextSingle(this Random random) + { + if (random is TauswortheHybrid tauswortheHybrd) + { + return tauswortheHybrd.NextSingle(); + } + + for (; ; ) + { + // Since the largest value that NextDouble() can return rounds to 1 when cast to float, + // we need to protect against returning 1. 
+ var res = (float)random.NextDouble(); + if (res < 1.0f) + return res; + } + } + + public static int NextSigned(this Random random) + { + if (random is TauswortheHybrid tauswortheHybrd) + { + return tauswortheHybrd.NextSigned(); + } + + // Note that, according to the documentation for System.Random, + // this won't ever achieve int.MaxValue, but oh well. + return random.Next(int.MinValue, int.MaxValue); + } - public static class RandomUtils - { public static TauswortheHybrid Create() { // Seed from a system random. - return new TauswortheHybrid(new SysRandom()); + return new TauswortheHybrid(new Random()); } public static TauswortheHybrid Create(int? seed) @@ -74,81 +72,17 @@ public static TauswortheHybrid Create(uint seed) return new TauswortheHybrid(state); } - public static TauswortheHybrid Create(IRandom seed) + public static TauswortheHybrid Create(Random seed) { return new TauswortheHybrid(seed); } } - public sealed class SysRandom : IRandom - { - private readonly Random _rnd; - - public SysRandom() - { - _rnd = new Random(); - } - - public SysRandom(int seed) - { - _rnd = new Random(seed); - } - - public static SysRandom Wrap(Random rnd) - { - if (rnd != null) - return new SysRandom(rnd); - return null; - } - - private SysRandom(Random rnd) - { - Contracts.AssertValue(rnd); - _rnd = rnd; - } - - public Single NextSingle() - { - // Since the largest value that NextDouble() can return rounds to 1 when cast to Single, - // we need to protect against returning 1. - for (;;) - { - var res = (Single)_rnd.NextDouble(); - if (res < 1.0f) - return res; - } - } - - public Double NextDouble() - { - return _rnd.NextDouble(); - } - - public int Next() - { - // Note that, according to the documentation for System.Random, - // this won't ever achieve int.MaxValue, but oh well. 
- return _rnd.Next(); - } - - public int Next(int limit) - { - Contracts.CheckParam(limit >= 0, nameof(limit), "limit must be non-negative"); - return _rnd.Next(limit); - } - - public int NextSigned() - { - // Note that, according to the documentation for System.Random, - // this won't ever achieve int.MaxValue, but oh well. - return _rnd.Next(int.MinValue, int.MaxValue); - } - } - /// /// Tausworthe hybrid random number generator. /// - public sealed class TauswortheHybrid : IRandom + [BestFriend] + internal sealed class TauswortheHybrid : Random { public readonly struct State { @@ -204,7 +138,7 @@ public TauswortheHybrid(State state) _z4 = state.U4; } - public TauswortheHybrid(IRandom rng) + public TauswortheHybrid(Random rng) { _z1 = GetSeed(rng); _z2 = GetSeed(rng); @@ -212,14 +146,14 @@ public TauswortheHybrid(IRandom rng) _z4 = GetU(rng); } - private static uint GetU(IRandom rng) + private static uint GetU(Random rng) { return ((uint)rng.Next(0x00010000) << 16) | ((uint)rng.Next(0x00010000)); } - private static uint GetSeed(IRandom rng) + private static uint GetSeed(Random rng) { - for (;;) + for (; ; ) { uint u = GetU(rng); if (u >= 128) @@ -227,19 +161,19 @@ private static uint GetSeed(IRandom rng) } } - public Single NextSingle() + public float NextSingle() { NextState(); return GetSingle(); } - public Double NextDouble() + public override double NextDouble() { NextState(); return GetDouble(); } - public int Next() + public override int Next() { NextState(); uint u = GetUint(); @@ -248,17 +182,35 @@ public int Next() return n; } - public int Next(int limit) + public override int Next(int maxValue) { - Contracts.CheckParam(limit >= 0, nameof(limit), "limit must be non-negative"); + Contracts.CheckParam(maxValue >= 0, nameof(maxValue), "maxValue must be non-negative"); NextState(); uint u = GetUint(); - ulong uu = (ulong)u * (ulong)limit; + ulong uu = (ulong)u * (ulong)maxValue; int res = (int)(uu >> 32); - Contracts.Assert(0 <= res && (res < limit || res == 
0)); + Contracts.Assert(0 <= res && (res < maxValue || res == 0)); return res; } + public override int Next(int minValue, int maxValue) + { + Contracts.CheckParam(minValue <= maxValue, nameof(minValue), "minValue must be less than or equal to maxValue."); + + long range = (long)maxValue - minValue; + return (int)((long)(NextDouble() * range) + minValue); + } + + public override void NextBytes(byte[] buffer) + { + Contracts.CheckValue(buffer, nameof(buffer)); + + for (int i = 0; i < buffer.Length; i++) + { + buffer[i] = (byte)Next(); + } + } + public int NextSigned() { NextState(); @@ -270,23 +222,23 @@ private uint GetUint() return _z1 ^ _z2 ^ _z3 ^ _z4; } - private Single GetSingle() + private float GetSingle() { - const Single scale = (Single)1 / (1 << 23); + const float scale = (float)1 / (1 << 23); // Drop the low 9 bits so the conversion to Single is exact. Allowing rounding would cause // issues with biasing values and, worse, the possibility of returning exactly 1. uint u = GetUint() >> 9; - Contracts.Assert((uint)(Single)u == u); + Contracts.Assert((uint)(float)u == u); - return (Single)u * scale; + return (float)u * scale; } - private Double GetDouble() + private double GetDouble() { - const Double scale = (Double)1 / (1 << 16) / (1 << 16); + const double scale = (double)1 / (1 << 16) / (1 << 16); uint u = GetUint(); - return (Double)u * scale; + return (double)u * scale; } private void NextState() @@ -315,97 +267,4 @@ public State GetState() return new State(_z1, _z2, _z3, _z4); } } - -#if false // REVIEW: This was written for NN drop out but turned out to be too slow, so I inlined it instead. - public sealed class BooleanSampler - { - public const int CbitRand = 25; - - private readonly IRandom _rand; - private readonly uint _k; // probability of "true" is _k / (1U << _qlog). - private readonly int _qlog; // Number of bits consumed by each call to Sample(). - private readonly int _cv; // Number of calls to Sample() covered by a call to _rand.Next(...). 
- private readonly uint _mask; // (1U << _qlog) - 1 - - // Mutable state. - private int _c; - private uint _v; - - /// - /// Create a boolean sampler using the given random number generator, quantizing the true rate - /// to cbitQuant bits, assuming that sampling the random number generator is capable of producing - /// cbitRand good bits. - /// - /// For example, new BooleanSampler(0.5f, 1, 25, new Random()) will produce a reasonable fair coin flipper. - /// Note that this reduces the parameters, so new BooleanSampler(0.5f, 6, 25, new Random()) will produce - /// the same flipper. In other words, since 0.5 quantized to 6 bits can be reduced to only needing one - /// bit, it reduces cbitQuant to 1. - /// - public static BooleanSampler Create(Single rate, int cbitQuant, IRandom rand) - { - Contracts.Assert(0 < rate && rate < 1); - Contracts.Assert(0 < cbitQuant && cbitQuant <= CbitRand / 2); - - int qlog = cbitQuant; - uint k = (uint)(rate * (1 << qlog)); - if (k == 0) - k = 1; - Contracts.Assert(0 <= k && k < (1U << qlog)); - - while ((k & 1) == 0 && k > 0) - { - qlog--; - k >>= 1; - } - Contracts.Assert(qlog > 0); - uint q = 1U << qlog; - Contracts.Assert(0 < k && k < q); - - int cv = CbitRand / qlog; - Contracts.Assert(cv > 1); - return new BooleanSampler(qlog, k, rand); - } - - private BooleanSampler(int qlog, uint k, IRandom rand) - { - _qlog = qlog; - _k = k; - _rand = rand; - _qlog = qlog; - _cv = CbitRand / _qlog; - _mask = (1U << _qlog) - 1; - } - - public bool Sample() - { - _v >>= _qlog; - if (--_c <= 0) - { - _v = (uint)_rand.Next(1 << (_cv * _qlog)); - _c = _cv; - } - return (_v & _mask) < _k; - } - - public void SampleMany(out uint bits, out int count) - { - uint u = (uint)_rand.Next(1 << (_cv * _qlog)); - count = _cv; - if (_qlog == 1) - { - bits = u; - return; - } - - bits = 0; - for (int i = 0; i < count; i++) - { - bits <<= 1; - if ((u & _mask) < _k) - bits |= 1; - u >>= _qlog; - } - } - } -#endif } \ No newline at end of file diff --git 
a/src/Microsoft.ML.Core/Utilities/ReservoirSampler.cs b/src/Microsoft.ML.Core/Utilities/ReservoirSampler.cs index 8438a2e742..f442fd8e59 100644 --- a/src/Microsoft.ML.Core/Utilities/ReservoirSampler.cs +++ b/src/Microsoft.ML.Core/Utilities/ReservoirSampler.cs @@ -2,11 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections.Generic; using System.Linq; -using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { /// /// This is an interface for creating samples of a requested size from a stream of data of type . @@ -56,7 +57,7 @@ internal sealed class ReservoirSamplerWithoutReplacement : IReservoirSampler< // This array contains a cache of the elements composing the reservoir. private readonly T[] _cache; - private readonly IRandom _rnd; + private readonly Random _rnd; private long _numSampled; private readonly ValueGetter _getter; @@ -67,7 +68,7 @@ internal sealed class ReservoirSamplerWithoutReplacement : IReservoirSampler< public long NumSampled { get { return _numSampled; } } - public ReservoirSamplerWithoutReplacement(IRandom rnd, int size, ValueGetter getter) + public ReservoirSamplerWithoutReplacement(Random rnd, int size, ValueGetter getter) { Contracts.CheckValue(rnd, nameof(rnd)); Contracts.CheckParam(size > 0, nameof(size), "Reservoir size must be positive"); @@ -135,7 +136,7 @@ internal sealed class ReservoirSamplerWithReplacement : IReservoirSampler private readonly T[] _cache; private readonly int[] _counts; - private readonly IRandom _rnd; + private readonly Random _rnd; private long _numSampled; private readonly ValueGetter _getter; @@ -146,7 +147,7 @@ internal sealed class ReservoirSamplerWithReplacement : IReservoirSampler public long NumSampled { get { return _numSampled; } } - public ReservoirSamplerWithReplacement(IRandom 
rnd, int size, ValueGetter getter) + public ReservoirSamplerWithReplacement(Random rnd, int size, ValueGetter getter) { Contracts.CheckValue(rnd, nameof(rnd)); Contracts.CheckParam(size > 0, nameof(size), "Reservoir size must be positive"); diff --git a/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs b/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs index 9e9c6f80bb..0813dac0fe 100644 --- a/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs @@ -10,7 +10,7 @@ using System.Threading; using System.Threading.Tasks; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { /// /// This class takes care of downloading resources needed by ML.NET components. Resources are located in diff --git a/src/Microsoft.ML.Core/Utilities/Stats.cs b/src/Microsoft.ML.Core/Utilities/Stats.cs index 182239adde..fad719e80d 100644 --- a/src/Microsoft.ML.Core/Utilities/Stats.cs +++ b/src/Microsoft.ML.Core/Utilities/Stats.cs @@ -2,11 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - using System; +using Float = System.Single; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { /// /// A class containing common statistical functions @@ -20,7 +19,7 @@ internal static class Stats /// Size of range to sample from, between 0 and int.MaxValue^2 /// Random number generator /// Sampled value - public static long SampleLong(long rangeSize, IRandom rand) + public static long SampleLong(long rangeSize, Random rand) { Contracts.CheckParam(rangeSize > 0, nameof(rangeSize), "rangeSize must be positive."); @@ -51,7 +50,7 @@ public static long SampleLong(long rangeSize, IRandom rand) /// A Random to use for the sampling /// a sample /// uses Joseph L. 
Leva's algorithm from "A fast normal random number generator", 1992 - public static double SampleFromGaussian(IRandom rand) + public static double SampleFromGaussian(Random rand) { double u; double v; @@ -75,7 +74,7 @@ public static double SampleFromGaussian(IRandom rand) /// The random number generator to use /// Sample from gamma distribution /// Uses Marsaglia and Tsang's fast algorithm - public static double SampleFromGamma(IRandom r, double alpha) + public static double SampleFromGamma(Random r, double alpha) { Contracts.CheckParam(alpha > 0, nameof(alpha), "alpha must be positive"); @@ -111,7 +110,7 @@ public static double SampleFromGamma(IRandom r, double alpha) /// first parameter /// second parameter /// Sample from distribution - public static double SampleFromBeta(IRandom rand, double alpha1, double alpha2) + public static double SampleFromBeta(Random rand, double alpha1, double alpha2) { double gamma1 = SampleFromGamma(rand, alpha1); double gamma2 = SampleFromGamma(rand, alpha2); @@ -124,7 +123,7 @@ public static double SampleFromBeta(IRandom rand, double alpha1, double alpha2) /// Random generator to use /// array of parameters /// array in which to store resulting sample - public static void SampleFromDirichlet(IRandom rand, double[] alphas, double[] result) + public static void SampleFromDirichlet(Random rand, double[] alphas, double[] result) { Contracts.Check(alphas.Length == result.Length, "Dirichlet parameters must have the same dimensionality as sample space."); @@ -141,7 +140,7 @@ public static void SampleFromDirichlet(IRandom rand, double[] alphas, double[] r } } - public static int SampleFromPoisson(IRandom rand, double lambda) + public static int SampleFromPoisson(Random rand, double lambda) { if (lambda < 5) { @@ -201,7 +200,7 @@ public static int SampleFromPoisson(IRandom rand, double lambda) // Mean refers to the mu parameter. Scale refers to the b parameter. 
// https://en.wikipedia.org/wiki/Laplace_distribution - public static Float SampleFromLaplacian(IRandom rand, Float mean, Float scale) + public static Float SampleFromLaplacian(Random rand, Float mean, Float scale) { Float u = rand.NextSingle(); u = u - 0.5f; @@ -220,7 +219,7 @@ public static Float SampleFromLaplacian(IRandom rand, Float mean, Float scale) /// /// /// - public static Float SampleFromCauchy(IRandom rand) + public static Float SampleFromCauchy(Random rand) { return (Float)Math.Tan(Math.PI * (rand.NextSingle() - 0.5)); } @@ -233,7 +232,7 @@ public static Float SampleFromCauchy(IRandom rand) /// Parameter p of binomial /// /// Should be robust for all values of n, p - public static int SampleFromBinomial(IRandom r, int n, double p) + public static int SampleFromBinomial(Random r, int n, double p) { return BinoRand.Next(r, n, p); } @@ -266,7 +265,7 @@ private static double Fc(int k) } } - public static int Next(IRandom rand, int n, double p) + public static int Next(Random rand, int n, double p) { int x; double pin = Math.Min(p, 1 - p); @@ -284,7 +283,7 @@ public static int Next(IRandom rand, int n, double p) // For small n // Inverse transformation algorithm // Described in Kachitvichyanukul and Schmeiser: "Binomial Random Variate Generation" - private static int InvTransform(int n, double p, IRandom rn) + private static int InvTransform(int n, double p, Random rn) { int x = 0; double u = rn.NextDouble(); @@ -308,7 +307,7 @@ private static int InvTransform(int n, double p, IRandom rn) // For large n // Algorithm from W. 
Hormann: "The Generation of Binomial Random Variables" // This is algorithm BTRD - private static int GenerateLarge(int n, double p, IRandom rn) + private static int GenerateLarge(int n, double p, Random rn) { double np = n * p; double q = 1 - p; diff --git a/src/Microsoft.ML.Core/Utilities/Stream.cs b/src/Microsoft.ML.Core/Utilities/Stream.cs index 4fe21e7df7..633ee672f7 100644 --- a/src/Microsoft.ML.Core/Utilities/Stream.cs +++ b/src/Microsoft.ML.Core/Utilities/Stream.cs @@ -9,7 +9,7 @@ using System.Text; using System.Threading; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { internal static partial class Utils { diff --git a/src/Microsoft.ML.Core/Utilities/SubsetStream.cs b/src/Microsoft.ML.Core/Utilities/SubsetStream.cs index 3bf9ad2c9e..022ca28cb1 100644 --- a/src/Microsoft.ML.Core/Utilities/SubsetStream.cs +++ b/src/Microsoft.ML.Core/Utilities/SubsetStream.cs @@ -5,7 +5,7 @@ using System; using System.IO; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { /// /// Returns a "view" stream, which appears to be a possibly truncated diff --git a/src/Microsoft.ML.Core/Utilities/SummaryStatistics.cs b/src/Microsoft.ML.Core/Utilities/SummaryStatistics.cs index 3e36191564..3382885e9c 100644 --- a/src/Microsoft.ML.Core/Utilities/SummaryStatistics.cs +++ b/src/Microsoft.ML.Core/Utilities/SummaryStatistics.cs @@ -4,7 +4,7 @@ using System; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { internal abstract class SummaryStatisticsBase { diff --git a/src/Microsoft.ML.Core/Utilities/SupervisedBinFinder.cs b/src/Microsoft.ML.Core/Utilities/SupervisedBinFinder.cs index 63257823a2..f640043009 100644 --- a/src/Microsoft.ML.Core/Utilities/SupervisedBinFinder.cs +++ b/src/Microsoft.ML.Core/Utilities/SupervisedBinFinder.cs @@ -2,13 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. 
// See the LICENSE file in the project root for more information. -using Float = System.Single; - using System; using System.Collections.Generic; using System.Diagnostics; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { /// /// This class performs discretization of (value, label) pairs into bins in a way that minimizes diff --git a/src/Microsoft.ML.Core/Utilities/TextReaderStream.cs b/src/Microsoft.ML.Core/Utilities/TextReaderStream.cs index 682a0336cb..691876c0c7 100644 --- a/src/Microsoft.ML.Core/Utilities/TextReaderStream.cs +++ b/src/Microsoft.ML.Core/Utilities/TextReaderStream.cs @@ -6,7 +6,7 @@ using System.IO; using System.Text; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { /// /// A readable that is backed by a . diff --git a/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs b/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs index 46a82a4e7c..3585b36f23 100644 --- a/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs @@ -8,7 +8,7 @@ using System.Threading; using System.Threading.Tasks; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { internal static partial class Utils { diff --git a/src/Microsoft.ML.Core/Utilities/Tree.cs b/src/Microsoft.ML.Core/Utilities/Tree.cs index 8d4c0f7585..cbec20fe8e 100644 --- a/src/Microsoft.ML.Core/Utilities/Tree.cs +++ b/src/Microsoft.ML.Core/Utilities/Tree.cs @@ -5,7 +5,7 @@ using System.Collections; using System.Collections.Generic; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { /// /// The tree structure is simultaneously a tree, and a node in a tree. 
The interface to diff --git a/src/Microsoft.ML.Core/Utilities/Utils.cs b/src/Microsoft.ML.Core/Utilities/Utils.cs index d0f6f6256e..c6e21486ca 100644 --- a/src/Microsoft.ML.Core/Utilities/Utils.cs +++ b/src/Microsoft.ML.Core/Utilities/Utils.cs @@ -2,8 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - using System; using System.Collections; using System.Collections.Generic; @@ -13,7 +11,7 @@ using System.Text.RegularExpressions; using System.Threading; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { [BestFriend] @@ -218,10 +216,10 @@ public static int FindIndexSorted(this IList input, int value) /// In case of duplicates it returns the index of the first one. /// It guarantees that items before the returned index are < value, while those at and after the returned index are >= value. /// - public static int FindIndexSorted(this Single[] input, Single value) + public static int FindIndexSorted(this IList input, float value) { Contracts.AssertValue(input); - return FindIndexSorted(input, 0, input.Length, value); + return FindIndexSorted(input, 0, input.Count, value); } /// @@ -342,11 +340,11 @@ public static int FindIndexSorted(this IList input, int min, int lim, int v /// In case of duplicates it returns the index of the first one. /// It guarantees that items before the returned index are < value, while those at and after the returned index are >= value. 
/// - public static int FindIndexSorted(this Single[] input, int min, int lim, Single value) + public static int FindIndexSorted(this IList input, int min, int lim, float value) { Contracts.AssertValue(input); - Contracts.Assert(0 <= min & min <= lim & lim <= input.Length); - Contracts.Assert(!Single.IsNaN(value)); + Contracts.Assert(0 <= min & min <= lim & lim <= input.Count); + Contracts.Assert(!float.IsNaN(value)); int minCur = min; int limCur = lim; @@ -354,7 +352,7 @@ public static int FindIndexSorted(this Single[] input, int min, int lim, Single { int mid = (int)(((uint)minCur + (uint)limCur) / 2); Contracts.Assert(minCur <= mid & mid < limCur); - Contracts.Assert(!Single.IsNaN(input[mid])); + Contracts.Assert(!float.IsNaN(input[mid])); if (input[mid] >= value) limCur = mid; @@ -530,56 +528,10 @@ public static int[] GetRandomPermutation(Random rand, int size) Contracts.Assert(size >= 0); var res = GetIdentityPermutation(size); - Shuffle(rand, res); + Shuffle(rand, res); return res; } - public static int[] GetRandomPermutation(IRandom rand, int size) - { - Contracts.AssertValue(rand); - Contracts.Assert(size >= 0); - - var res = GetIdentityPermutation(size); - Shuffle(rand, res); - return res; - } - - public static void Shuffle(IRandom rand, T[] rgv) - { - Contracts.AssertValue(rand); - Contracts.AssertValue(rgv); - - Shuffle(rand, rgv, 0, rgv.Length); - } - - public static void Shuffle(Random rand, T[] rgv) - { - Contracts.AssertValue(rand); - Contracts.AssertValue(rgv); - - Shuffle(rand, rgv, 0, rgv.Length); - } - - public static void Shuffle(IRandom rand, T[] rgv, int min, int lim) - { - Contracts.AssertValue(rand); - Contracts.AssertValue(rgv); - Contracts.Check(0 <= min & min <= lim & lim <= rgv.Length); - - for (int iv = min; iv < lim; iv++) - Swap(ref rgv[iv], ref rgv[iv + rand.Next(lim - iv)]); - } - - public static void Shuffle(Random rand, T[] rgv, int min, int lim) - { - Contracts.AssertValue(rand); - Contracts.AssertValue(rgv); - Contracts.Check(0 
<= min & min <= lim & lim <= rgv.Length); - - for (int iv = min; iv < lim; iv++) - Swap(ref rgv[iv], ref rgv[iv + rand.Next(lim - iv)]); - } - public static bool AreEqual(Single[] arr1, Single[] arr2) { if (arr1 == arr2) @@ -614,6 +566,14 @@ public static bool AreEqual(Double[] arr1, Double[] arr2) return true; } + public static void Shuffle(Random rand, Span rgv) + { + Contracts.AssertValue(rand); + + for (int iv = 0; iv < rgv.Length; iv++) + Swap(ref rgv[iv], ref rgv[iv + rand.Next(rgv.Length - iv)]); + } + public static bool AreEqual(int[] arr1, int[] arr2) { if (arr1 == arr2) @@ -653,14 +613,14 @@ public static string ExtractLettersAndNumbers(string value) return Regex.Replace(value, "[^A-Za-z0-9]", ""); } - public static bool IsSorted(Float[] values) + public static bool IsSorted(IList values) { if (Utils.Size(values) <= 1) return true; var prev = values[0]; - for (int i = 1; i < values.Length; i++) + for (int i = 1; i < values.Count; i++) { if (!(values[i] >= prev)) return false; @@ -1104,6 +1064,15 @@ public static void MarshalActionInvoke(Action + /// A four-argument version of . 
+ /// + public static void MarshalActionInvoke(Action act, Type genArg, TArg1 arg1, TArg2 arg2, TArg3 arg3, TArg4 arg4) + { + var meth = MarshalActionInvokeCheckAndCreate(genArg, act); + meth.Invoke(act.Target, new object[] { arg1, arg2, arg3, arg4 }); + } + public static string GetDescription(this Enum value) { Type type = value.GetType(); diff --git a/src/Microsoft.ML.Core/Utilities/VBufferUtils.cs b/src/Microsoft.ML.Core/Utilities/VBufferUtils.cs index f730d61724..9120ccbedd 100644 --- a/src/Microsoft.ML.Core/Utilities/VBufferUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/VBufferUtils.cs @@ -4,9 +4,9 @@ using System; using System.Collections.Generic; -using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { // REVIEW: Consider automatic densification in some of the operations, where appropriate. // REVIEW: Once we do the conversions from Vector/WritableVector, review names of methods, @@ -405,7 +405,7 @@ public static void ApplyAt(ref VBuffer dst, int slot, SlotValueManipulator // we are modifying in the sparse vector, in which case the vector becomes // dense. Then there is no need to do anything with indices. bool needIndices = dstValuesCount + 1 < dst.Length; - editor = VBufferEditor.Create(ref dst, dst.Length, dstValuesCount + 1); + editor = VBufferEditor.Create(ref dst, dst.Length, dstValuesCount + 1, keepOldOnResize: true); if (idx != dstValuesCount) { // We have to do some sort of shift copy. @@ -1322,7 +1322,7 @@ public static void ApplyInto(in VBuffer a, in VBuffer // REVIEW: Worth optimizing the newCount == a.Length case? // Probably not... 
- editor = VBufferEditor.Create(ref dst, a.Length, newCount); + editor = VBufferEditor.Create(ref dst, a.Length, newCount, requireIndicesOnDense: true); Span indices = editor.Indices; if (newCount == bValues.Length) diff --git a/src/Microsoft.ML.CpuMath/AlignedArray.cs b/src/Microsoft.ML.CpuMath/AlignedArray.cs index 4602e884e4..6687022e8f 100644 --- a/src/Microsoft.ML.CpuMath/AlignedArray.cs +++ b/src/Microsoft.ML.CpuMath/AlignedArray.cs @@ -2,20 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.Internal.CpuMath.Core; using System; +using Microsoft.ML.Internal.CpuMath.Core; -namespace Microsoft.ML.Runtime.Internal.CpuMath +namespace Microsoft.ML.Internal.CpuMath { - using Float = System.Single; - /// - /// This implements a logical array of Floats that is automatically aligned for SSE/AVX operations. + /// This implements a logical array of floats that is automatically aligned for SSE/AVX operations. /// To pin and force alignment, call the GetPin method, typically wrapped in a using (since it /// returns a Pin struct that is IDisposable). From the pin, you can get the IntPtr to pass to /// native code. /// - /// The ctor takes an alignment value, which must be a power of two at least sizeof(Float). + /// The ctor takes an alignment value, which must be a power of two at least sizeof(float). /// [BestFriend] internal sealed class AlignedArray @@ -24,7 +22,7 @@ internal sealed class AlignedArray // items, also filled with NaN. Note that _size * sizeof(Float) is divisible by _cbAlign. // It is illegal to access any slot outsize [_base, _base + _size). This is internal so clients // can easily pin it. - public Float[] Items; + public float[] Items; private readonly int _size; // Must be divisible by (_cbAlign / sizeof(Float)). private readonly int _cbAlign; // The alignment in bytes, a power of two, divisible by sizeof(Float). 
@@ -40,12 +38,12 @@ public AlignedArray(int size, int cbAlign) { Contracts.Assert(0 < size); // cbAlign should be a power of two. - Contracts.Assert(sizeof(Float) <= cbAlign); + Contracts.Assert(sizeof(float) <= cbAlign); Contracts.Assert((cbAlign & (cbAlign - 1)) == 0); // cbAlign / sizeof(Float) should divide size. - Contracts.Assert((size * sizeof(Float)) % cbAlign == 0); + Contracts.Assert((size * sizeof(float)) % cbAlign == 0); - Items = new Float[size + cbAlign / sizeof(Float)]; + Items = new float[size + cbAlign / sizeof(float)]; _size = size; _cbAlign = cbAlign; _lock = new object(); @@ -54,15 +52,15 @@ public AlignedArray(int size, int cbAlign) public unsafe int GetBase(long addr) { #if DEBUG - fixed (Float* pv = Items) - Contracts.Assert((Float*)addr == pv); + fixed (float* pv = Items) + Contracts.Assert((float*)addr == pv); #endif int cbLow = (int)(addr & (_cbAlign - 1)); int ibMin = cbLow == 0 ? 0 : _cbAlign - cbLow; - Contracts.Assert(ibMin % sizeof(Float) == 0); + Contracts.Assert(ibMin % sizeof(float) == 0); - int ifltMin = ibMin / sizeof(Float); + int ifltMin = ibMin / sizeof(float); if (ifltMin == _base) return _base; @@ -71,9 +69,9 @@ public unsafe int GetBase(long addr) // Anything outsize [_base, _base + _size) should not be accessed, so // set them to NaN, for debug validation. 
for (int i = 0; i < _base; i++) - Items[i] = Float.NaN; + Items[i] = float.NaN; for (int i = _base + _size; i < Items.Length; i++) - Items[i] = Float.NaN; + Items[i] = float.NaN; #endif return _base; } @@ -96,7 +94,7 @@ private void MoveData(int newBase) public int CbAlign { get { return _cbAlign; } } - public Float this[int index] + public float this[int index] { get { @@ -110,30 +108,30 @@ public Float this[int index] } } - public void CopyTo(Float[] dst, int index, int count) + public void CopyTo(Span dst, int index, int count) { Contracts.Assert(0 <= count && count <= _size); Contracts.Assert(dst != null); Contracts.Assert(0 <= index && index <= dst.Length - count); - Array.Copy(Items, _base, dst, index, count); + Items.AsSpan(_base, count).CopyTo(dst.Slice(index)); } - public void CopyTo(int start, Float[] dst, int index, int count) + public void CopyTo(int start, Span dst, int index, int count) { Contracts.Assert(0 <= count); Contracts.Assert(0 <= start && start <= _size - count); Contracts.Assert(dst != null); Contracts.Assert(0 <= index && index <= dst.Length - count); - Array.Copy(Items, start + _base, dst, index, count); + Items.AsSpan(start + _base, count).CopyTo(dst.Slice(index)); } - public void CopyFrom(ReadOnlySpan src) + public void CopyFrom(ReadOnlySpan src) { Contracts.Assert(src.Length <= _size); src.CopyTo(Items.AsSpan(_base)); } - public void CopyFrom(int start, ReadOnlySpan src) + public void CopyFrom(int start, ReadOnlySpan src) { Contracts.Assert(0 <= start && start <= _size - src.Length); src.CopyTo(Items.AsSpan(start + _base)); @@ -143,7 +141,7 @@ public void CopyFrom(int start, ReadOnlySpan src) // valuesSrc contains only the non-zero entries. Those are copied into their logical positions in the dense array. // rgposSrc contains the logical positions + offset of the non-zero entries in the dense array. // rgposSrc runs parallel to the valuesSrc array. 
- public void CopyFrom(ReadOnlySpan rgposSrc, ReadOnlySpan valuesSrc, int posMin, int iposMin, int iposLim, bool zeroItems) + public void CopyFrom(ReadOnlySpan rgposSrc, ReadOnlySpan valuesSrc, int posMin, int iposMin, int iposLim, bool zeroItems) { Contracts.Assert(rgposSrc != null); Contracts.Assert(valuesSrc != null); @@ -202,7 +200,7 @@ public void ZeroItems(int[] rgposSrc, int posMin, int iposMin, int iposLim) // REVIEW: This is hackish and slightly dangerous. Perhaps we should wrap this in an // IDisposable that "locks" this, prohibiting GetBase from being called, while the buffer // is "checked out". - public void GetRawBuffer(out Float[] items, out int offset) + public void GetRawBuffer(out float[] items, out int offset) { items = Items; offset = _base; diff --git a/src/Microsoft.ML.CpuMath/AlignedMatrix.cs b/src/Microsoft.ML.CpuMath/AlignedMatrix.cs index 6d550fc3fc..d9dee9a868 100644 --- a/src/Microsoft.ML.CpuMath/AlignedMatrix.cs +++ b/src/Microsoft.ML.CpuMath/AlignedMatrix.cs @@ -2,14 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - -using Microsoft.ML.Runtime.Internal.CpuMath.Core; using System; using System.Collections; using System.Collections.Generic; +using Microsoft.ML.Internal.CpuMath.Core; +using Float = System.Single; -namespace Microsoft.ML.Runtime.Internal.CpuMath +namespace Microsoft.ML.Internal.CpuMath { using Conditional = System.Diagnostics.ConditionalAttribute; diff --git a/src/Microsoft.ML.CpuMath/AssemblyInfo.cs b/src/Microsoft.ML.CpuMath/AssemblyInfo.cs index 7710703c29..65557d90f9 100644 --- a/src/Microsoft.ML.CpuMath/AssemblyInfo.cs +++ b/src/Microsoft.ML.CpuMath/AssemblyInfo.cs @@ -2,8 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime.Internal.CpuMath.Core; using System.Runtime.CompilerServices; +using Microsoft.ML.Internal.CpuMath.Core; [assembly: InternalsVisibleTo("Microsoft.ML.CpuMath.UnitTests.netstandard" + PublicKey.TestValue)] [assembly: InternalsVisibleTo("Microsoft.ML.CpuMath.UnitTests.netcoreapp" + PublicKey.TestValue)] diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index 55e692eb63..ba86b9158f 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -14,9 +14,10 @@ using System.Runtime.InteropServices; using System.Runtime.Intrinsics; using System.Runtime.Intrinsics.X86; +using Microsoft.ML.Internal.CpuMath.Core; using nuint = System.UInt64; -namespace Microsoft.ML.Runtime.Internal.CpuMath +namespace Microsoft.ML.Internal.CpuMath { internal static class AvxIntrinsics { @@ -46,6 +47,25 @@ internal static class AvxIntrinsics private static readonly Vector256 _absMask256 = Avx.StaticCast(Avx.SetAllVector256(0x7FFFFFFF)); + private const int Vector256Alignment = 32; + + [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] + private static bool HasCompatibleAlignment(AlignedArray alignedArray) + { + Contracts.AssertValue(alignedArray); + Contracts.Assert(alignedArray.Size > 0); + return (alignedArray.CbAlign % Vector256Alignment) == 0; + } + + [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] + private static unsafe float* GetAlignedBase(AlignedArray alignedArray, float* unalignedBase) + { + Contracts.AssertValue(alignedArray); + float* alignedBase = unalignedBase + alignedArray.GetBase((long)unalignedBase); + Contracts.Assert(((long)alignedBase % Vector256Alignment) == 0); + return alignedBase; + } + [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] private static Vector128 GetHigh(in Vector256 x) => Avx.ExtractVector128(x, 1); @@ -153,20 +173,18 @@ private static Vector256 MultiplyAdd(Vector256 src1, Vector256 mat, ReadOnlySpan src, 
Span dst, int crow, int ccol) - { - Contracts.Assert(crow % 4 == 0); - Contracts.Assert(ccol % 4 == 0); + Contracts.Assert(HasCompatibleAlignment(mat)); + Contracts.Assert(HasCompatibleAlignment(src)); + Contracts.Assert(HasCompatibleAlignment(dst)); - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - fixed (float* pmat = &MemoryMarshal.GetReference(mat)) - fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) - fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) + fixed (float* pSrcStart = &src.Items[0]) + fixed (float* pDstStart = &dst.Items[0]) + fixed (float* pMatStart = &mat.Items[0]) { + float* psrc = GetAlignedBase(src, pSrcStart); + float* pdst = GetAlignedBase(dst, pDstStart); + float* pmat = GetAlignedBase(mat, pMatStart); + float* pSrcEnd = psrc + ccol; float* pDstEnd = pdst + crow; float* pDstCurrent = pdst; @@ -175,118 +193,36 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr while (pDstCurrent < pDstEnd) { Vector256 res0 = Avx.SetZeroVector256(); - Vector256 res1 = Avx.SetZeroVector256(); - Vector256 res2 = Avx.SetZeroVector256(); - Vector256 res3 = Avx.SetZeroVector256(); + Vector256 res1 = res0; + Vector256 res2 = res0; + Vector256 res3 = res0; - int length = ccol; float* pSrcCurrent = psrc; - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 32); - int remainder = 0; - - if ((misalignment & 3) != 0) + while (pSrcCurrent < pSrcEnd) { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pSrcCurrent < pSrcEnd) - { - Vector256 vector = Avx.LoadVector256(pSrcCurrent); - - float* pMatTemp = pMatCurrent; - res0 = MultiplyAdd(pMatTemp, vector, res0); - res1 = MultiplyAdd(pMatTemp += ccol, vector, res1); - res2 = MultiplyAdd(pMatTemp += ccol, vector, res2); - res3 = MultiplyAdd(pMatTemp += ccol, vector, res3); - - pSrcCurrent += 8; - pMatCurrent += 8; - } - } - else - { - if 
(misalignment != 0) - { - // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 8 - misalignment; - - Vector256 mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); - - // We only align pMat since it has significantly more reads. - float* pMatTemp = pMatCurrent; - Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); - Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent)); - - res0 = Avx.Multiply(x01, vector); - res1 = Avx.Multiply(x11, vector); - res2 = Avx.Multiply(x21, vector); - res3 = Avx.Multiply(x31, vector); - - pMatCurrent += misalignment; - pSrcCurrent += misalignment; - length -= misalignment; - } - - if (length > 7) - { - // Handle all the 256-bit blocks that we can now that we have offset to an aligned address - remainder = length % 8; - - while (pSrcCurrent + 8 <= pSrcEnd) - { - // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads - // (due to semantics of the legacy encoding). - // We don't need an assert, since the instruction will throw for unaligned inputs. - Vector256 vector = Avx.LoadVector256(pSrcCurrent); - - float* pMatTemp = pMatCurrent; - res0 = MultiplyAdd(pMatTemp, vector, res0); - res1 = MultiplyAdd(pMatTemp += ccol, vector, res1); - res2 = MultiplyAdd(pMatTemp += ccol, vector, res2); - res3 = MultiplyAdd(pMatTemp += ccol, vector, res3); - - pSrcCurrent += 8; - pMatCurrent += 8; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not - // 256-bit aligned. 
This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. - remainder = length; - } - - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - - pMatCurrent -= (8 - remainder); - pSrcCurrent -= (8 - remainder); - - Vector256 mask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); - - float* pMatTemp = pMatCurrent; - Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); - Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent)); - - res0 = MultiplyAdd(x01, vector, res0); - res1 = MultiplyAdd(x11, vector, res1); - res2 = MultiplyAdd(x21, vector, res2); - res3 = MultiplyAdd(x31, vector, res3); - - pMatCurrent += 8; - pSrcCurrent += 8; - } + float* pMatTemp = pMatCurrent; + Contracts.Assert(((nuint)(pMatTemp) % 32) == 0); + Contracts.Assert(((nuint)(pSrcCurrent) % 32) == 0); + + // The JIT will only fold away unaligned loads due to the semantics behind + // the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since + // modern hardware has unaligned loads that are as fast as aligned loads, + // when it doesn't cross a cache-line/page boundary, we will just assert + // that the alignment is correct and allow for the more-efficient codegen. 
+ Vector256 x01 = Avx.LoadVector256(pMatTemp); + Vector256 x11 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 x21 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 x31 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 x02 = Avx.LoadVector256(pSrcCurrent); + + res0 = MultiplyAdd(x01, x02, res0); + res1 = MultiplyAdd(x11, x02, res1); + res2 = MultiplyAdd(x21, x02, res2); + res3 = MultiplyAdd(x31, x02, res3); + + pSrcCurrent += 8; + pMatCurrent += 8; } // Add up the entries of each, with the 4 results in res0 @@ -295,7 +231,7 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr res0 = Avx.HorizontalAdd(res0, res2); Vector128 sum = Sse.Add(Avx.GetLowerHalf(res0), GetHigh(in res0)); - Sse.Store(pDstCurrent, sum); + Sse.StoreAligned(pDstCurrent, sum); pDstCurrent += 4; pMatCurrent += 3 * ccol; @@ -307,24 +243,21 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan rgposSrc, AlignedArray src, int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol) { - MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol); - } - - public static unsafe void MatMulP(ReadOnlySpan mat, ReadOnlySpan rgposSrc, ReadOnlySpan src, - int posMin, int iposMin, int iposEnd, Span dst, int crow, int ccol) - { - Contracts.Assert(crow % 8 == 0); - Contracts.Assert(ccol % 8 == 0); - // REVIEW: For extremely sparse inputs, interchanging the loops would // likely be more efficient. 
- fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - fixed (float* pmat = &MemoryMarshal.GetReference(mat)) - fixed (int* pposSrc = &MemoryMarshal.GetReference(rgposSrc)) - fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) - fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) + Contracts.Assert(HasCompatibleAlignment(mat)); + Contracts.Assert(HasCompatibleAlignment(src)); + Contracts.Assert(HasCompatibleAlignment(dst)); + + fixed (float* pSrcStart = &src.Items[0]) + fixed (float* pDstStart = &dst.Items[0]) + fixed (float* pMatStart = &mat.Items[0]) + fixed (int* pposSrc = &rgposSrc[0]) { + float* psrc = GetAlignedBase(src, pSrcStart); + float* pdst = GetAlignedBase(dst, pDstStart); + float* pmat = GetAlignedBase(mat, pMatStart); + int* pposMin = pposSrc + iposMin; int* pposEnd = pposSrc + iposEnd; float* pDstEnd = pdst + crow; @@ -332,116 +265,7 @@ public static unsafe void MatMulP(ReadOnlySpan mat, ReadOnlySpan rgp float* pSrcCurrent = psrc - posMin; float* pDstCurrent = pdst; - nuint address = (nuint)(pDstCurrent); - int misalignment = (int)(address % 32); - int length = crow; - int remainder = 0; - - if ((misalignment & 3) != 0) - { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pDstCurrent < pDstEnd) - { - Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow()); - pDstCurrent += 8; - pm0 += 8 * ccol; - } - } - else - { - if (misalignment != 0) - { - // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 8 - misalignment; - - Vector256 mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); - - float* pm1 = pm0 + ccol; - float* pm2 = pm1 + ccol; - float* pm3 = pm2 + ccol; - Vector256 result = Avx.SetZeroVector256(); - - int* ppos = pposMin; - - 
while (ppos < pposEnd) - { - int col1 = *ppos; - int col2 = col1 + 4 * ccol; - Vector256 x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2], - pm3[col1], pm2[col1], pm1[col1], pm0[col1]); - - x1 = Avx.And(mask, x1); - Vector256 x2 = Avx.SetAllVector256(pSrcCurrent[col1]); - result = MultiplyAdd(x2, x1, result); - ppos++; - } - - Avx.Store(pDstCurrent, result); - pDstCurrent += misalignment; - pm0 += misalignment * ccol; - length -= misalignment; - } - - if (length > 7) - { - // Handle all the 256-bit blocks that we can now that we have offset to an aligned address - remainder = length % 8; - while (pDstCurrent < pDstEnd) - { - Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow()); - pDstCurrent += 8; - pm0 += 8 * ccol; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not - // 256-bit aligned. This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. 
- remainder = length; - } - - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - - pDstCurrent -= (8 - remainder); - pm0 -= (8 - remainder) * ccol; - Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); - Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); - - float* pm1 = pm0 + ccol; - float* pm2 = pm1 + ccol; - float* pm3 = pm2 + ccol; - Vector256 result = Avx.SetZeroVector256(); - - int* ppos = pposMin; - - while (ppos < pposEnd) - { - int col1 = *ppos; - int col2 = col1 + 4 * ccol; - Vector256 x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2], - pm3[col1], pm2[col1], pm1[col1], pm0[col1]); - x1 = Avx.And(x1, trailingMask); - - Vector256 x2 = Avx.SetAllVector256(pSrcCurrent[col1]); - result = MultiplyAdd(x2, x1, result); - ppos++; - } - - result = Avx.Add(result, Avx.And(leadingMask, Avx.LoadVector256(pDstCurrent))); - - Avx.Store(pDstCurrent, result); - pDstCurrent += 8; - pm0 += 8 * ccol; - } - } - - Vector256 SparseMultiplicationAcrossRow() + while (pDstCurrent < pDstEnd) { float* pm1 = pm0 + ccol; float* pm2 = pm1 + ccol; @@ -455,329 +279,133 @@ Vector256 SparseMultiplicationAcrossRow() int col1 = *ppos; int col2 = col1 + 4 * ccol; Vector256 x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2], - pm3[col1], pm2[col1], pm1[col1], pm0[col1]); + pm3[col1], pm2[col1], pm1[col1], pm0[col1]); Vector256 x2 = Avx.SetAllVector256(pSrcCurrent[col1]); - result = MultiplyAdd(x2, x1, result); + x2 = Avx.Multiply(x2, x1); + result = Avx.Add(result, x2); + ppos++; } - return result; + Avx.StoreAligned(pDstCurrent, result); + pDstCurrent += 8; + pm0 += 8 * ccol; } } } public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) 
{ - MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol); - } - - public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan src, Span dst, int crow, int ccol) - { - Contracts.Assert(crow % 4 == 0); - Contracts.Assert(ccol % 4 == 0); + Contracts.Assert(HasCompatibleAlignment(mat)); + Contracts.Assert(HasCompatibleAlignment(src)); + Contracts.Assert(HasCompatibleAlignment(dst)); - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - fixed (float* pmat = &MemoryMarshal.GetReference(mat)) - fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) - fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) + fixed (float* pSrcStart = &src.Items[0]) + fixed (float* pDstStart = &dst.Items[0]) + fixed (float* pMatStart = &mat.Items[0]) { + float* psrc = GetAlignedBase(src, pSrcStart); + float* pdst = GetAlignedBase(dst, pDstStart); + float* pmat = GetAlignedBase(mat, pMatStart); + float* pSrcEnd = psrc + ccol; float* pDstEnd = pdst + crow; float* pSrcCurrent = psrc; float* pMatCurrent = pmat; - // The reason behind adding the if condtion instead of boolean flag - // is to avoid branching in codegen. - if (pSrcCurrent < pSrcEnd) - { - Vector128 h01 = Sse.LoadVector128(pSrcCurrent); - // Replicate each slot of h01 (ABCD) into its own register. - Vector128 h11 = Avx.Permute(h01, 0x55); // B - Vector128 h21 = Avx.Permute(h01, 0xAA); // C - Vector128 h31 = Avx.Permute(h01, 0xFF); // D - h01 = Avx.Permute(h01, 0x00); // A + // We do 4-way unrolling + Vector128 h01 = Sse.LoadAlignedVector128(pSrcCurrent); + // Replicate each slot of h01 (ABCD) into its own register. 
+ Vector128 h11 = Sse.Shuffle(h01, h01, 0x55); // B + Vector128 h21 = Sse.Shuffle(h01, h01, 0xAA); // C + Vector128 h31 = Sse.Shuffle(h01, h01, 0xFF); // D + h01 = Sse.Shuffle(h01, h01, 0x00); // A - Vector256 x01 = Avx.SetHighLow(h01, h01); - Vector256 x11 = Avx.SetHighLow(h11, h11); - Vector256 x21 = Avx.SetHighLow(h21, h21); - Vector256 x31 = Avx.SetHighLow(h31, h31); + Vector256 x01 = Avx.SetHighLow(h01, h01); + Vector256 x11 = Avx.SetHighLow(h11, h11); + Vector256 x21 = Avx.SetHighLow(h21, h21); + Vector256 x31 = Avx.SetHighLow(h31, h31); - int length = crow; - float* pDstCurrent = pdst; + pSrcCurrent += 4; - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 32); + float* pDstCurrent = pdst; - if ((misalignment & 3) != 0) - { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pDstCurrent < pDstEnd) - { - float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } - } - else - { - int remainder = 0; - if (misalignment != 0) - { - // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 8 - misalignment; - - Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); - - // We only align pMat since it has significantly more reads. 
- float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); - Vector256 x3 = Avx.LoadVector256(pDstCurrent); - x02 = Avx.Or(x02, Avx.And(x3, trailingMask)); - - Avx.Store(pDstCurrent, x02); - pMatCurrent += misalignment; - pDstCurrent += misalignment; - length -= misalignment; - } - if (length > 7) - { - // Handle all the 256-bit blocks that we can now that we have offset to an aligned address - remainder = length % 8; - - while (pDstCurrent + 8 <= pDstEnd) - { - // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads - // (due to semantics of the legacy encoding). - // We don't need an assert, since the instruction will throw for unaligned inputs. - float* pMatTemp = pMatCurrent; - - Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not - // 256-bit aligned. 
This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. - remainder = length; - } + while (pDstCurrent < pDstEnd) + { + float* pMatTemp = pMatCurrent; + Contracts.Assert(((nuint)(pMatTemp) % 32) == 0); - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - - pMatCurrent -= (8 - remainder); - pDstCurrent -= (8 - remainder); - Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); - - float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); - Vector256 x3 = Avx.LoadVector256(pDstCurrent); - x02 = Avx.Or(x02, Avx.And(x3, leadingMask)); - - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } - } + // The JIT will only fold away unaligned loads due to the semantics behind + // the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since + // modern hardware has unaligned loads that are as fast as aligned loads, + // when it doesn't cross a cache-line/page boundary, we will just assert + // that the alignment is correct and allow for the more-efficient codegen. 
+ Vector256 x02 = Avx.LoadVector256(pMatTemp); + Vector256 x12 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x22 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x32 = Avx.LoadVector256(pMatTemp += crow); - pMatCurrent += 3 * crow; - pSrcCurrent += 4; + x02 = Avx.Multiply(x01, x02); + x02 = MultiplyAdd(x11, x12, x02); + + x22 = Avx.Multiply(x21, x22); + x22 = MultiplyAdd(x31, x32, x22); + + x02 = Avx.Add(x02, x22); + Avx.StoreAligned(pDstCurrent, x02); + + pDstCurrent += 8; + pMatCurrent += 8; } - // We do 4-way unrolling + pMatCurrent += 3 * crow; + while (pSrcCurrent < pSrcEnd) { - Vector128 h01 = Sse.LoadVector128(pSrcCurrent); + h01 = Sse.LoadAlignedVector128(pSrcCurrent); // Replicate each slot of h01 (ABCD) into its own register. - Vector128 h11 = Avx.Permute(h01, 0x55); // B - Vector128 h21 = Avx.Permute(h01, 0xAA); // C - Vector128 h31 = Avx.Permute(h01, 0xFF); // D - h01 = Avx.Permute(h01, 0x00); // A - - Vector256 x01 = Avx.SetHighLow(h01, h01); - Vector256 x11 = Avx.SetHighLow(h11, h11); - Vector256 x21 = Avx.SetHighLow(h21, h21); - Vector256 x31 = Avx.SetHighLow(h31, h31); + h11 = Sse.Shuffle(h01, h01, 0x55); // B + h21 = Sse.Shuffle(h01, h01, 0xAA); // C + h31 = Sse.Shuffle(h01, h01, 0xFF); // D + h01 = Sse.Shuffle(h01, h01, 0x00); // A - int length = crow; - float* pDstCurrent = pdst; + x01 = Avx.SetHighLow(h01, h01); + x11 = Avx.SetHighLow(h11, h11); + x21 = Avx.SetHighLow(h21, h21); + x31 = Avx.SetHighLow(h31, h31); - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 32); + pDstCurrent = pdst; - if ((misalignment & 3) != 0) + while (pDstCurrent < pDstEnd) { - while (pDstCurrent < pDstEnd) - { - float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); + float* 
pMatTemp = pMatCurrent; - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); + Contracts.Assert(((nuint)(pMatTemp) % 32) == 0); + Contracts.Assert(((nuint)(pDstCurrent) % 32) == 0); - x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); + // The JIT will only fold away unaligned loads due to the semantics behind + // the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since + // modern hardware has unaligned loads that are as fast as aligned loads, + // when it doesn't cross a cache-line/page boundary, we will just assert + // that the alignment is correct and allow for the more-efficient codegen. + Vector256 x02 = Avx.LoadVector256(pMatTemp); + Vector256 x12 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x22 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x32 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x3 = Avx.LoadVector256(pDstCurrent); - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } - } - else - { - int remainder = 0; - if (misalignment != 0) - { - // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 8 - misalignment; - - Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); - - // We only align pMat since it has significantly more reads. 
- float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); - Vector256 x3 = Avx.LoadVector256(pDstCurrent); - x02 = Avx.Or(x02, Avx.And(x3, trailingMask)); - - x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); - - Avx.Store(pDstCurrent, x02); - pMatCurrent += misalignment; - pDstCurrent += misalignment; - length -= misalignment; - } - if (length > 7) - { - remainder = length % 8; - while (pDstCurrent + 8 <= pDstEnd) - { - float* pMatTemp = pMatCurrent; - - Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); - - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } - } - else - { - remainder = length; - } + x02 = Avx.Multiply(x01, x02); + x02 = MultiplyAdd(x11, x12, x02); - if (remainder != 0) - { - pMatCurrent -= (8 - remainder); - pDstCurrent -= (8 - remainder); - Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); - - float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = 
Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); - Vector256 x3 = Avx.LoadVector256(pDstCurrent); - x02 = Avx.Or(x02, Avx.And(x3, leadingMask)); - - x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); - - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } + x22 = Avx.Multiply(x21, x22); + x22 = MultiplyAdd(x31, x32, x22); + + x02 = Avx.Add(x02, x22); + x3 = Avx.Add(x02, x3); + Avx.StoreAligned(pDstCurrent, x3); + + pDstCurrent += 8; + pMatCurrent += 8; } pMatCurrent += 3 * crow; @@ -885,11 +513,12 @@ public static unsafe void Scale(float scale, Span dst) Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); - Vector256 temp = Avx.And(result, leadingMask); - result = Avx.And(result, trailingMask); + Vector256 temp = Avx.And(result, trailingMask); + result = Avx.Multiply(scaleVector256, result); - temp = Avx.Multiply(scaleVector256, temp); - result = Avx.Or(temp, result); + // Masking operation is done at the end to avoid doing an Or operation with negative Zero. 
+ result = Avx.And(result, leadingMask); + result = Avx.Or(result, temp); Avx.Store(pDstCurrent, result); @@ -938,21 +567,22 @@ public static unsafe void Scale(float scale, Span dst) Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); - Vector256 temp = Avx.And(result, trailingMask); - result = Avx.And(result, leadingMask); + Vector256 temp = Avx.And(result, leadingMask); + result = Avx.Multiply(scaleVector256, result); - temp = Avx.Multiply(scaleVector256, temp); - temp = Avx.Or(temp, result); + // Masking operation is done at the end to avoid doing an Or operation with negative Zero. + result = Avx.And(result, trailingMask); + result = Avx.Or(result, temp); - Avx.Store(pDstCurrent, temp); + Avx.Store(pDstCurrent, result); } } } public static unsafe void ScaleSrcU(float scale, ReadOnlySpan src, Span dst, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { @@ -1046,8 +676,8 @@ public static unsafe void ScaleAddU(float a, float b, Span dst) public static unsafe void AddScaleU(float scale, ReadOnlySpan src, Span dst, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { @@ -1100,8 +730,9 @@ public static unsafe void AddScaleU(float scale, ReadOnlySpan src, Span src, ReadOnlySpan dst, Span result, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); + Contracts.Assert(count <= result.Length); fixed (float* psrc = 
&MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) fixed (float* pres = &MemoryMarshal.GetReference(result)) @@ -1156,8 +787,9 @@ public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan src, Re public static unsafe void AddScaleSU(float scale, ReadOnlySpan src, ReadOnlySpan idx, Span dst, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); + Contracts.Assert(count <= idx.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (int* pidx = &MemoryMarshal.GetReference(idx)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) @@ -1206,8 +838,8 @@ public static unsafe void AddScaleSU(float scale, ReadOnlySpan src, ReadO public static unsafe void AddU(ReadOnlySpan src, Span dst, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { @@ -1255,8 +887,9 @@ public static unsafe void AddU(ReadOnlySpan src, Span dst, int cou public static unsafe void AddSU(ReadOnlySpan src, ReadOnlySpan idx, Span dst, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); + Contracts.Assert(count <= idx.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (int* pidx = &MemoryMarshal.GetReference(idx)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) @@ -1302,6 +935,9 @@ public static unsafe void AddSU(ReadOnlySpan src, ReadOnlySpan idx, public static unsafe void MulElementWiseU(ReadOnlySpan src1, ReadOnlySpan src2, Span dst, int count) { + Contracts.Assert(count <= src1.Length); + Contracts.Assert(count <= src2.Length); + Contracts.Assert(count <= dst.Length); fixed (float* psrc1 = &MemoryMarshal.GetReference(src1)) fixed (float* 
psrc2 = &MemoryMarshal.GetReference(src2)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) @@ -1738,8 +1374,8 @@ public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan src) public static unsafe float DotU(ReadOnlySpan src, ReadOnlySpan dst, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { @@ -1792,8 +1428,9 @@ public static unsafe float DotU(ReadOnlySpan src, ReadOnlySpan dst public static unsafe float DotSU(ReadOnlySpan src, ReadOnlySpan dst, ReadOnlySpan idx, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); + Contracts.Assert(count <= idx.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) fixed (int* pidx = &MemoryMarshal.GetReference(idx)) @@ -1848,8 +1485,8 @@ public static unsafe float DotSU(ReadOnlySpan src, ReadOnlySpan ds public static unsafe float Dist2(ReadOnlySpan src, ReadOnlySpan dst, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { diff --git a/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs b/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs index 33690055b2..c80cdec192 100644 --- a/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs +++ b/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs @@ -2,9 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime.Internal.CpuMath.Core; +using Microsoft.ML.Internal.CpuMath.Core; -namespace Microsoft.ML.Runtime.Internal.CpuMath +namespace Microsoft.ML.Internal.CpuMath { [BestFriend] internal static class CpuAligenedMathUtils diff --git a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs index d895e590a9..839b7d5471 100644 --- a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs +++ b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs @@ -2,12 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.Internal.CpuMath.Core; +using System; using System.Runtime.CompilerServices; using System.Runtime.Intrinsics.X86; -using System; +using Microsoft.ML.Internal.CpuMath.Core; -namespace Microsoft.ML.Runtime.Internal.CpuMath +namespace Microsoft.ML.Internal.CpuMath { internal static partial class CpuMathUtils { diff --git a/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs b/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs index 5ecbc62be1..b062898aec 100644 --- a/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs +++ b/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs @@ -2,11 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime.Internal.CpuMath.Core; using System; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Microsoft.ML.Internal.CpuMath.Core; -namespace Microsoft.ML.Runtime.Internal.CpuMath +namespace Microsoft.ML.Internal.CpuMath { [BestFriend] internal static partial class CpuMathUtils @@ -18,57 +19,411 @@ internal static partial class CpuMathUtils public static int GetVectorAlignment() => Vector128Alignment; - public static void MatrixTimesSource(bool transpose, AlignedArray matrix, AlignedArray source, AlignedArray destination, int stride) => SseUtils.MatTimesSrc(transpose, matrix, source, destination, stride); + private static bool Compat(AlignedArray a) + { + Contracts.AssertValue(a); + Contracts.Assert(a.Size > 0); + return a.CbAlign == Vector128Alignment; + } - public static void MatrixTimesSource(AlignedArray matrix, ReadOnlySpan rgposSrc, AlignedArray sourceValues, - int posMin, int iposMin, int iposLimit, AlignedArray destination, int stride) => SseUtils.MatTimesSrc(matrix, rgposSrc, sourceValues, posMin, iposMin, iposLimit, destination, stride); + private static unsafe float* Ptr(AlignedArray a, float* p) + { + Contracts.AssertValue(a); + float* q = p + a.GetBase((long)p); + Contracts.Assert(((long)q & (Vector128Alignment - 1)) == 0); + return q; + } - public static void Add(float value, Span destination) => SseUtils.Add(value, destination); + public static void MatrixTimesSource(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) + { + Contracts.Assert(Compat(mat)); + Contracts.Assert(Compat(src)); + Contracts.Assert(Compat(dst)); + Contracts.Assert(mat.Size == dst.Size * src.Size); - public static void Scale(float value, Span destination) => SseUtils.Scale(value, destination); + unsafe + { + fixed (float* pmat = &mat.Items[0]) + fixed (float* psrc = &src.Items[0]) + fixed (float* pdst = &dst.Items[0]) + { + if (!tran) + { + Contracts.Assert(0 <= crun && crun <= dst.Size); + 
Thunk.MatMul(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size); + } + else + { + Contracts.Assert(0 <= crun && crun <= src.Size); + Thunk.MatMulTran(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun); + } + } + } + } - public static void Scale(float value, ReadOnlySpan source, Span destination, int count) => SseUtils.Scale(value, source, destination, count); + public static void MatrixTimesSource(AlignedArray mat, ReadOnlySpan rgposSrc, AlignedArray srcValues, + int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) + { + Contracts.Assert(Compat(mat)); + Contracts.Assert(Compat(srcValues)); + Contracts.Assert(Compat(dst)); + Contracts.Assert(0 <= iposMin && iposMin <= iposLim && iposLim <= rgposSrc.Length); + Contracts.Assert(mat.Size == dst.Size * srcValues.Size); - public static void ScaleAdd(float value, float addend, Span destination) => SseUtils.ScaleAdd(value, addend, destination); + if (iposMin >= iposLim) + { + dst.ZeroItems(); + return; + } + Contracts.AssertNonEmpty(rgposSrc); + unsafe + { + fixed (float* pdst = &dst.Items[0]) + fixed (float* pmat = &mat.Items[0]) + fixed (float* psrc = &srcValues.Items[0]) + fixed (int* ppossrc = &rgposSrc[0]) + { + Contracts.Assert(0 <= crun && crun <= dst.Size); + Thunk.MatMulP(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); + } + } + } - public static void AddScale(float value, ReadOnlySpan source, Span destination, int count) => SseUtils.AddScale(value, source, destination, count); + // dst += a + public static void Add(float a, Span dst) + { + Contracts.AssertNonEmpty(dst); - public static void AddScale(float value, ReadOnlySpan source, ReadOnlySpan indices, Span destination, int count) => SseUtils.AddScale(value, source, indices, destination, count); + unsafe + { + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + Thunk.AddScalarU(a, pdst, dst.Length); + } + } - public static void AddScaleCopy(float value, 
ReadOnlySpan source, ReadOnlySpan destination, Span res, int count) => SseUtils.AddScaleCopy(value, source, destination, res, count); + public static void Scale(float a, Span dst) + { + Contracts.AssertNonEmpty(dst); - public static void Add(ReadOnlySpan source, Span destination, int count) => SseUtils.Add(source, destination, count); + unsafe + { + fixed (float* pd = &MemoryMarshal.GetReference(dst)) + Thunk.Scale(a, pd, dst.Length); + } + } - public static void Add(ReadOnlySpan source, ReadOnlySpan indices, Span destination, int count) => SseUtils.Add(source, indices, destination, count); + // dst = a * src + public static void Scale(float a, ReadOnlySpan src, Span dst, int count) + { + Contracts.AssertNonEmpty(src); + Contracts.Assert(0 < count && count <= src.Length); + Contracts.AssertNonEmpty(dst); + Contracts.Assert(count <= dst.Length); - public static void MulElementWise(ReadOnlySpan left, ReadOnlySpan right, Span destination, int count) => SseUtils.MulElementWise(left, right, destination, count); + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + { + Thunk.ScaleSrcU(a, psrc, pdst, count); + } + } + } - public static float Sum(ReadOnlySpan source) => SseUtils.Sum(source); + // dst[i] = a * (dst[i] + b) + public static void ScaleAdd(float a, float b, Span dst) + { + Contracts.AssertNonEmpty(dst); - public static float SumSq(ReadOnlySpan source) => SseUtils.SumSq(source); + unsafe + { + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + Thunk.ScaleAddU(a, b, pdst, dst.Length); + } + } - public static float SumSq(float mean, ReadOnlySpan source) => SseUtils.SumSq(mean, source); + public static void AddScale(float a, ReadOnlySpan src, Span dst, int count) + { + Contracts.AssertNonEmpty(src); + Contracts.Assert(0 < count && count <= src.Length); + Contracts.AssertNonEmpty(dst); + Contracts.Assert(count <= dst.Length); - public static float SumAbs(ReadOnlySpan source) => 
SseUtils.SumAbs(source); + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + Thunk.AddScaleU(a, psrc, pdst, count); + } + } - public static float SumAbs(float mean, ReadOnlySpan source) => SseUtils.SumAbs(mean, source); + public static void AddScale(float a, ReadOnlySpan src, ReadOnlySpan indices, Span dst, int count) + { + Contracts.AssertNonEmpty(src); + Contracts.Assert(0 < count && count <= src.Length); + Contracts.AssertNonEmpty(indices); + Contracts.Assert(count <= indices.Length); + Contracts.AssertNonEmpty(dst); + Contracts.Assert(count < dst.Length); - public static float MaxAbs(ReadOnlySpan source) => SseUtils.MaxAbs(source); + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + fixed (int* pi = &MemoryMarshal.GetReference(indices)) + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + Thunk.AddScaleSU(a, psrc, pi, pdst, count); + } + } - public static float MaxAbsDiff(float mean, ReadOnlySpan source) => SseUtils.MaxAbsDiff(mean, source); + public static void AddScaleCopy(float a, ReadOnlySpan src, ReadOnlySpan dst, Span res, int count) + { + Contracts.AssertNonEmpty(dst); + Contracts.Assert(0 < count && count <= dst.Length); + Contracts.AssertNonEmpty(src); + Contracts.Assert(count <= src.Length); + Contracts.AssertNonEmpty(res); + Contracts.Assert(count <= res.Length); - public static float DotProductDense(ReadOnlySpan left, ReadOnlySpan right, int count) => SseUtils.DotProductDense(left, right, count); + unsafe + { + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + fixed (float* pres = &MemoryMarshal.GetReference(res)) + Thunk.AddScaleCopyU(a, psrc, pdst, pres, count); + } + } - public static float DotProductSparse(ReadOnlySpan left, ReadOnlySpan right, ReadOnlySpan indices, int count) => SseUtils.DotProductSparse(left, right, indices, count); + public static void Add(ReadOnlySpan src, Span 
dst, int count) + { + Contracts.AssertNonEmpty(src); + Contracts.Assert(0 < count && count <= src.Length); + Contracts.AssertNonEmpty(dst); + Contracts.Assert(count <= dst.Length); - public static float L2DistSquared(ReadOnlySpan left, ReadOnlySpan right, int count) => SseUtils.L2DistSquared(left, right, count); + unsafe + { + fixed (float* ps = &MemoryMarshal.GetReference(src)) + fixed (float* pd = &MemoryMarshal.GetReference(dst)) + Thunk.AddU(ps, pd, count); + } + } - public static void ZeroMatrixItems(AlignedArray destination, int ccol, int cfltRow, int[] indices) => SseUtils.ZeroMatrixItems(destination, ccol, cfltRow, indices); + public static void Add(ReadOnlySpan src, ReadOnlySpan indices, Span dst, int count) + { + Contracts.AssertNonEmpty(src); + Contracts.Assert(0 < count && count <= src.Length); + Contracts.AssertNonEmpty(indices); + Contracts.Assert(count <= indices.Length); + Contracts.AssertNonEmpty(dst); + Contracts.Assert(count < dst.Length); - public static void SdcaL1UpdateDense(float primalUpdate, int count, ReadOnlySpan source, float threshold, Span v, Span w) - => SseUtils.SdcaL1UpdateDense(primalUpdate, count, source, threshold, v, w); + unsafe + { + fixed (float* ps = &MemoryMarshal.GetReference(src)) + fixed (int* pi = &MemoryMarshal.GetReference(indices)) + fixed (float* pd = &MemoryMarshal.GetReference(dst)) + Thunk.AddSU(ps, pi, pd, count); + } + } + + public static void MulElementWise(ReadOnlySpan src1, ReadOnlySpan src2, Span dst, int count) + { + Contracts.AssertNonEmpty(src1); + Contracts.Assert(0 < count && count <= src1.Length); + Contracts.AssertNonEmpty(src2); + Contracts.Assert(0 < count && count <= src2.Length); + Contracts.AssertNonEmpty(dst); + unsafe + { + fixed (float* ps1 = &MemoryMarshal.GetReference(src1)) + fixed (float* ps2 = &MemoryMarshal.GetReference(src2)) + fixed (float* pd = &MemoryMarshal.GetReference(dst)) + Thunk.MulElementWiseU(ps1, ps2, pd, count); + } + } + + public static float Sum(ReadOnlySpan src) + { + 
Contracts.AssertNonEmpty(src); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + return Thunk.Sum(psrc, src.Length); + } + } + + public static float SumSq(ReadOnlySpan src) + { + Contracts.AssertNonEmpty(src); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + return Thunk.SumSqU(psrc, src.Length); + } + } + + public static float SumSq(float mean, ReadOnlySpan src) + { + Contracts.AssertNonEmpty(src); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + return (mean == 0 ? Thunk.SumSqU(psrc, src.Length) : Thunk.SumSqDiffU(mean, psrc, src.Length)); + } + } + + public static float SumAbs(ReadOnlySpan src) + { + Contracts.AssertNonEmpty(src); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + return Thunk.SumAbsU(psrc, src.Length); + } + } + + public static float SumAbs(float mean, ReadOnlySpan src) + { + Contracts.AssertNonEmpty(src); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + return (mean == 0 ? 
Thunk.SumAbsU(psrc, src.Length) : Thunk.SumAbsDiffU(mean, psrc, src.Length)); + } + } + + public static float MaxAbs(ReadOnlySpan src) + { + Contracts.AssertNonEmpty(src); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + return Thunk.MaxAbsU(psrc, src.Length); + } + } + + public static float MaxAbsDiff(float mean, ReadOnlySpan src) + { + Contracts.AssertNonEmpty(src); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + return Thunk.MaxAbsDiffU(mean, psrc, src.Length); + } + } + + public static float DotProductDense(ReadOnlySpan a, ReadOnlySpan b, int count) + { + Contracts.AssertNonEmpty(a); + Contracts.AssertNonEmpty(b); + Contracts.Assert(0 < count); + Contracts.Assert(a.Length >= count); + Contracts.Assert(b.Length >= count); + + unsafe + { + fixed (float* pa = &MemoryMarshal.GetReference(a)) + fixed (float* pb = &MemoryMarshal.GetReference(b)) + return Thunk.DotU(pa, pb, count); + } + } + + public static float DotProductSparse(ReadOnlySpan a, ReadOnlySpan b, ReadOnlySpan indices, int count) + { + Contracts.AssertNonEmpty(a); + Contracts.AssertNonEmpty(b); + Contracts.Assert(0 < count); + Contracts.Assert(count < a.Length); + Contracts.Assert(count <= b.Length); + Contracts.Assert(count <= indices.Length); + + unsafe + { + fixed (float* pa = &MemoryMarshal.GetReference(a)) + fixed (float* pb = &MemoryMarshal.GetReference(b)) + fixed (int* pi = &MemoryMarshal.GetReference(indices)) + return Thunk.DotSU(pa, pb, pi, count); + } + } + + public static float L2DistSquared(ReadOnlySpan a, ReadOnlySpan b, int count) + { + Contracts.AssertNonEmpty(a); + Contracts.AssertNonEmpty(b); + Contracts.Assert(0 < count && count <= a.Length); + Contracts.Assert(count <= b.Length); + + unsafe + { + fixed (float* pa = &MemoryMarshal.GetReference(a)) + fixed (float* pb = &MemoryMarshal.GetReference(b)) + return Thunk.Dist2(pa, pb, count); + } + } + + public static void ZeroMatrixItems(AlignedArray dst, int ccol, int cfltRow, int[] 
indices) + { + Contracts.Assert(0 < ccol && ccol <= cfltRow); + + unsafe + { + fixed (float* pdst = &dst.Items[0]) + fixed (int* pi = &indices[0]) + { + if (ccol == cfltRow) + Thunk.ZeroItemsU(Ptr(dst, pdst), dst.Size, pi, indices.Length); + else + Thunk.ZeroMatrixItemsCore(Ptr(dst, pdst), dst.Size, ccol, cfltRow, pi, indices.Length); + } + } + } + + public static void SdcaL1UpdateDense(float primalUpdate, int count, ReadOnlySpan src, float threshold, Span v, Span w) + { + Contracts.AssertNonEmpty(src); + Contracts.Assert(count <= src.Length); + Contracts.AssertNonEmpty(v); + Contracts.Assert(count <= v.Length); + Contracts.AssertNonEmpty(w); + Contracts.Assert(count <= w.Length); + Contracts.Assert(count > 0); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + fixed (float* pd1 = &MemoryMarshal.GetReference(v)) + fixed (float* pd2 = &MemoryMarshal.GetReference(w)) + Thunk.SdcaL1UpdateU(primalUpdate, psrc, threshold, pd1, pd2, count); + } + } public static void SdcaL1UpdateSparse(float primalUpdate, int count, ReadOnlySpan source, ReadOnlySpan indices, float threshold, Span v, Span w) - => SseUtils.SdcaL1UpdateSparse(primalUpdate, count, source, indices, threshold, v, w); + { + Contracts.AssertNonEmpty(source); + Contracts.Assert(count <= source.Length); + Contracts.AssertNonEmpty(indices); + Contracts.Assert(count <= indices.Length); + Contracts.AssertNonEmpty(v); + Contracts.Assert(count <= v.Length); + Contracts.AssertNonEmpty(w); + Contracts.Assert(count <= w.Length); + Contracts.Assert(count > 0); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(source)) + fixed (int* pi = &MemoryMarshal.GetReference(indices)) + fixed (float* pd1 = &MemoryMarshal.GetReference(v)) + fixed (float* pd2 = &MemoryMarshal.GetReference(w)) + Thunk.SdcaL1UpdateSU(primalUpdate, psrc, pi, threshold, pd1, pd2, count); + } + } } -} +} \ No newline at end of file diff --git a/src/Microsoft.ML.CpuMath/EigenUtils.cs 
b/src/Microsoft.ML.CpuMath/EigenUtils.cs index cff9d1b32d..94d8ee1a88 100644 --- a/src/Microsoft.ML.CpuMath/EigenUtils.cs +++ b/src/Microsoft.ML.CpuMath/EigenUtils.cs @@ -2,11 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.Internal.CpuMath.Core; using System; +using Microsoft.ML.Internal.CpuMath.Core; using Float = System.Single; -namespace Microsoft.ML.Runtime.Internal.CpuMath +namespace Microsoft.ML.Internal.CpuMath { [BestFriend] // REVIEW: improve perf with SSE and Multithreading diff --git a/src/Microsoft.ML.CpuMath/ICpuBuffer.cs b/src/Microsoft.ML.CpuMath/ICpuBuffer.cs index a121351cff..3a3e76d6c6 100644 --- a/src/Microsoft.ML.CpuMath/ICpuBuffer.cs +++ b/src/Microsoft.ML.CpuMath/ICpuBuffer.cs @@ -2,15 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.Internal.CpuMath.Core; using System; using System.Collections.Generic; +using Microsoft.ML.Internal.CpuMath.Core; using Float = System.Single; -namespace Microsoft.ML.Runtime.Internal.CpuMath +namespace Microsoft.ML.Internal.CpuMath { - using Conditional = System.Diagnostics.ConditionalAttribute; - [BestFriend] internal interface ICpuBuffer : IEnumerable, IDisposable where T : struct diff --git a/src/Microsoft.ML.CpuMath/IntUtils.cs b/src/Microsoft.ML.CpuMath/IntUtils.cs index dbb07e31cb..1a27854df3 100644 --- a/src/Microsoft.ML.CpuMath/IntUtils.cs +++ b/src/Microsoft.ML.CpuMath/IntUtils.cs @@ -2,12 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime.Internal.CpuMath.Core; -using System.Runtime.InteropServices; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Security; +using Microsoft.ML.Internal.CpuMath.Core; -namespace Microsoft.ML.Runtime.Internal.CpuMath +namespace Microsoft.ML.Internal.CpuMath { [BestFriend] internal static class IntUtils diff --git a/src/Microsoft.ML.CpuMath/ProbabilityFunctions.cs b/src/Microsoft.ML.CpuMath/ProbabilityFunctions.cs index 64875bb8b6..db3e304379 100644 --- a/src/Microsoft.ML.CpuMath/ProbabilityFunctions.cs +++ b/src/Microsoft.ML.CpuMath/ProbabilityFunctions.cs @@ -2,10 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.Internal.CpuMath.Core; using System; +using Microsoft.ML.Internal.CpuMath.Core; -namespace Microsoft.ML.Runtime.Internal.CpuMath +namespace Microsoft.ML.Internal.CpuMath { /// /// Probability Functions. diff --git a/src/Microsoft.ML.CpuMath/Sse.cs b/src/Microsoft.ML.CpuMath/Sse.cs deleted file mode 100644 index 8b1c4da70f..0000000000 --- a/src/Microsoft.ML.CpuMath/Sse.cs +++ /dev/null @@ -1,427 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.ML.Runtime.Internal.CpuMath.Core; -using System; -using System.Runtime.InteropServices; - -namespace Microsoft.ML.Runtime.Internal.CpuMath -{ - /// - /// Keep Sse.cs in sync with Avx.cs. When making changes to one, use BeyondCompare or a similar tool - /// to view diffs and propagate appropriate changes to the other. 
- /// - [BestFriend] - internal static class SseUtils - { - public const int CbAlign = 16; - - private static bool Compat(AlignedArray a) - { - Contracts.AssertValue(a); - Contracts.Assert(a.Size > 0); - return a.CbAlign == CbAlign; - } - - private static unsafe float* Ptr(AlignedArray a, float* p) - { - Contracts.AssertValue(a); - float* q = p + a.GetBase((long)p); - Contracts.Assert(((long)q & (CbAlign - 1)) == 0); - return q; - } - - public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) - { - Contracts.Assert(Compat(mat)); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(mat.Size == dst.Size * src.Size); - - unsafe - { - fixed (float* pmat = &mat.Items[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - { - if (!tran) - { - Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMul(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size); - } - else - { - Contracts.Assert(0 <= crun && crun <= src.Size); - Thunk.MatMulTran(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun); - } - } - } - } - - public static void MatTimesSrc(AlignedArray mat, ReadOnlySpan rgposSrc, AlignedArray srcValues, - int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) - { - Contracts.Assert(Compat(mat)); - Contracts.Assert(Compat(srcValues)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(0 <= iposMin && iposMin <= iposLim && iposLim <= rgposSrc.Length); - Contracts.Assert(mat.Size == dst.Size * srcValues.Size); - - if (iposMin >= iposLim) - { - dst.ZeroItems(); - return; - } - Contracts.AssertNonEmpty(rgposSrc); - unsafe - { - fixed (float* pdst = &dst.Items[0]) - fixed (float* pmat = &mat.Items[0]) - fixed (float* psrc = &srcValues.Items[0]) - fixed (int* ppossrc = &rgposSrc[0]) - { - Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMulP(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, 
pdst), crun, srcValues.Size); - } - } - } - - // dst += a - public static void Add(float a, Span dst) - { - Contracts.AssertNonEmpty(dst); - - unsafe - { - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - Thunk.AddScalarU(a, pdst, dst.Length); - } - } - - public static void Scale(float a, Span dst) - { - Contracts.AssertNonEmpty(dst); - - unsafe - { - fixed (float* pd = &MemoryMarshal.GetReference(dst)) - Thunk.Scale(a, pd, dst.Length); - } - } - - // dst = a * src - public static void Scale(float a, ReadOnlySpan src, Span dst, int count) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(0 < count && count <= src.Length); - Contracts.AssertNonEmpty(dst); - Contracts.Assert(count <= dst.Length); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - { - Thunk.ScaleSrcU(a, psrc, pdst, count); - } - } - } - - // dst[i] = a * (dst[i] + b) - public static void ScaleAdd(float a, float b, Span dst) - { - Contracts.AssertNonEmpty(dst); - - unsafe - { - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - Thunk.ScaleAddU(a, b, pdst, dst.Length); - } - } - - public static void AddScale(float a, ReadOnlySpan src, Span dst, int count) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(0 < count && count <= src.Length); - Contracts.AssertNonEmpty(dst); - Contracts.Assert(count <= dst.Length); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - Thunk.AddScaleU(a, psrc, pdst, count); - } - } - - public static void AddScale(float a, ReadOnlySpan src, ReadOnlySpan indices, Span dst, int count) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(0 < count && count <= src.Length); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(count <= indices.Length); - Contracts.AssertNonEmpty(dst); - Contracts.Assert(count < dst.Length); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - 
fixed (int* pi = &MemoryMarshal.GetReference(indices)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - Thunk.AddScaleSU(a, psrc, pi, pdst, count); - } - } - - public static void AddScaleCopy(float a, ReadOnlySpan src, ReadOnlySpan dst, Span res, int count) - { - Contracts.AssertNonEmpty(dst); - Contracts.Assert(0 < count && count <= dst.Length); - Contracts.AssertNonEmpty(src); - Contracts.Assert(count <= src.Length); - Contracts.AssertNonEmpty(res); - Contracts.Assert(count <= res.Length); - - unsafe - { - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pres = &MemoryMarshal.GetReference(res)) - Thunk.AddScaleCopyU(a, psrc, pdst, pres, count); - } - } - - public static void Add(ReadOnlySpan src, Span dst, int count) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(0 < count && count <= src.Length); - Contracts.AssertNonEmpty(dst); - Contracts.Assert(count <= dst.Length); - - unsafe - { - fixed (float* ps = &MemoryMarshal.GetReference(src)) - fixed (float* pd = &MemoryMarshal.GetReference(dst)) - Thunk.AddU(ps, pd, count); - } - } - - public static void Add(ReadOnlySpan src, ReadOnlySpan indices, Span dst, int count) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(0 < count && count <= src.Length); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(count <= indices.Length); - Contracts.AssertNonEmpty(dst); - Contracts.Assert(count < dst.Length); - - unsafe - { - fixed (float* ps = &MemoryMarshal.GetReference(src)) - fixed (int* pi = &MemoryMarshal.GetReference(indices)) - fixed (float* pd = &MemoryMarshal.GetReference(dst)) - Thunk.AddSU(ps, pi, pd, count); - } - } - - public static void MulElementWise(ReadOnlySpan src1, ReadOnlySpan src2, Span dst, int count) - { - Contracts.AssertNonEmpty(src1); - Contracts.Assert(0 < count && count <= src1.Length); - Contracts.AssertNonEmpty(src2); - Contracts.Assert(0 < count && count <= src2.Length); - 
Contracts.AssertNonEmpty(dst); - unsafe - { - fixed (float* ps1 = &MemoryMarshal.GetReference(src1)) - fixed (float* ps2 = &MemoryMarshal.GetReference(src2)) - fixed (float* pd = &MemoryMarshal.GetReference(dst)) - Thunk.MulElementWiseU(ps1, ps2, pd, count); - } - } - - public static float Sum(ReadOnlySpan src) - { - Contracts.AssertNonEmpty(src); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - return Thunk.Sum(psrc, src.Length); - } - } - - public static float SumSq(ReadOnlySpan src) - { - Contracts.AssertNonEmpty(src); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - return Thunk.SumSqU(psrc, src.Length); - } - } - - public static float SumSq(float mean, ReadOnlySpan src) - { - Contracts.AssertNonEmpty(src); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - return (mean == 0 ? Thunk.SumSqU(psrc, src.Length) : Thunk.SumSqDiffU(mean, psrc, src.Length)); - } - } - - public static float SumAbs(ReadOnlySpan src) - { - Contracts.AssertNonEmpty(src); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - return Thunk.SumAbsU(psrc, src.Length); - } - } - - public static float SumAbs(float mean, ReadOnlySpan src) - { - Contracts.AssertNonEmpty(src); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - return (mean == 0 ? 
Thunk.SumAbsU(psrc, src.Length) : Thunk.SumAbsDiffU(mean, psrc, src.Length)); - } - } - - public static float MaxAbs(ReadOnlySpan src) - { - Contracts.AssertNonEmpty(src); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - return Thunk.MaxAbsU(psrc, src.Length); - } - } - - public static float MaxAbsDiff(float mean, ReadOnlySpan src) - { - Contracts.AssertNonEmpty(src); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - return Thunk.MaxAbsDiffU(mean, psrc, src.Length); - } - } - - public static float DotProductDense(ReadOnlySpan a, ReadOnlySpan b, int count) - { - Contracts.AssertNonEmpty(a); - Contracts.AssertNonEmpty(b); - Contracts.Assert(0 < count); - Contracts.Assert(a.Length >= count); - Contracts.Assert(b.Length >= count); - - unsafe - { - fixed (float* pa = &MemoryMarshal.GetReference(a)) - fixed (float* pb = &MemoryMarshal.GetReference(b)) - return Thunk.DotU(pa, pb, count); - } - } - - public static float DotProductSparse(ReadOnlySpan a, ReadOnlySpan b, ReadOnlySpan indices, int count) - { - Contracts.AssertNonEmpty(a); - Contracts.AssertNonEmpty(b); - Contracts.Assert(0 < count); - Contracts.Assert(count < a.Length); - Contracts.Assert(count <= b.Length); - Contracts.Assert(count <= indices.Length); - - unsafe - { - fixed (float* pa = &MemoryMarshal.GetReference(a)) - fixed (float* pb = &MemoryMarshal.GetReference(b)) - fixed (int* pi = &MemoryMarshal.GetReference(indices)) - return Thunk.DotSU(pa, pb, pi, count); - } - } - - public static float L2DistSquared(ReadOnlySpan a, ReadOnlySpan b, int count) - { - Contracts.AssertNonEmpty(a); - Contracts.AssertNonEmpty(b); - Contracts.Assert(0 < count && count <= a.Length); - Contracts.Assert(count <= b.Length); - - unsafe - { - fixed (float* pa = &MemoryMarshal.GetReference(a)) - fixed (float* pb = &MemoryMarshal.GetReference(b)) - return Thunk.Dist2(pa, pb, count); - } - } - - public static void ZeroMatrixItems(AlignedArray dst, int ccol, int cfltRow, int[] 
indices) - { - Contracts.Assert(0 < ccol && ccol <= cfltRow); - - unsafe - { - fixed (float* pdst = &dst.Items[0]) - fixed (int* pi = &indices[0]) - { - if (ccol == cfltRow) - Thunk.ZeroItemsU(Ptr(dst, pdst), dst.Size, pi, indices.Length); - else - Thunk.ZeroMatrixItemsCore(Ptr(dst, pdst), dst.Size, ccol, cfltRow, pi, indices.Length); - } - } - } - - public static void SdcaL1UpdateDense(float primalUpdate, int count, ReadOnlySpan src, float threshold, Span v, Span w) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(count <= src.Length); - Contracts.AssertNonEmpty(v); - Contracts.Assert(count <= v.Length); - Contracts.AssertNonEmpty(w); - Contracts.Assert(count <= w.Length); - Contracts.Assert(count > 0); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pd1 = &MemoryMarshal.GetReference(v)) - fixed (float* pd2 = &MemoryMarshal.GetReference(w)) - Thunk.SdcaL1UpdateU(primalUpdate, psrc, threshold, pd1, pd2, count); - } - } - - public static void SdcaL1UpdateSparse(float primalUpdate, int count, ReadOnlySpan source, ReadOnlySpan indices, float threshold, Span v, Span w) - { - Contracts.AssertNonEmpty(source); - Contracts.Assert(count <= source.Length); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(count <= indices.Length); - Contracts.AssertNonEmpty(v); - Contracts.Assert(count <= v.Length); - Contracts.AssertNonEmpty(w); - Contracts.Assert(count <= w.Length); - Contracts.Assert(count > 0); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(source)) - fixed (int* pi = &MemoryMarshal.GetReference(indices)) - fixed (float* pd1 = &MemoryMarshal.GetReference(v)) - fixed (float* pd2 = &MemoryMarshal.GetReference(w)) - Thunk.SdcaL1UpdateSU(primalUpdate, psrc, pi, threshold, pd1, pd2, count); - } - } - } -} \ No newline at end of file diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs index 5d6f2ed134..eabbf19508 100644 --- 
a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs @@ -18,9 +18,10 @@ using System.Runtime.InteropServices; using System.Runtime.Intrinsics; using System.Runtime.Intrinsics.X86; +using Microsoft.ML.Internal.CpuMath.Core; using nuint = System.UInt64; -namespace Microsoft.ML.Runtime.Internal.CpuMath +namespace Microsoft.ML.Internal.CpuMath { internal static class SseIntrinsics { @@ -40,6 +41,26 @@ internal static class SseIntrinsics 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, }; + // The count of bytes in Vector128, corresponding to _cbAlign in AlignedArray + private const int Vector128Alignment = 16; + + [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] + private static bool HasCompatibleAlignment(AlignedArray alignedArray) + { + Contracts.AssertValue(alignedArray); + Contracts.Assert(alignedArray.Size > 0); + return (alignedArray.CbAlign % Vector128Alignment) == 0; + } + + [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] + private static unsafe float* GetAlignedBase(AlignedArray alignedArray, float* unalignedBase) + { + Contracts.AssertValue(alignedArray); + float* alignedBase = unalignedBase + alignedArray.GetBase((long)unalignedBase); + Contracts.Assert(((long)alignedBase & (Vector128Alignment - 1)) == 0); + return alignedBase; + } + internal static readonly Vector128 AbsMask128 = Sse2.IsSupported ? Sse.StaticCast(Sse2.SetAllVector128(0x7FFFFFFF)) : Sse.SetAllVector128(BitConverter.Int32BitsToSingle(0x7FFFFFFF)); @@ -117,20 +138,18 @@ internal static Vector128 GetNewDst128(in Vector128 xDst1, in Vect // Multiply matrix times vector into vector. 
public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { - MatMul(mat.Items, src.Items, dst.Items, crow, ccol); - } - - public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan src, Span dst, int crow, int ccol) - { - Contracts.Assert(crow % 4 == 0); - Contracts.Assert(ccol % 4 == 0); + Contracts.Assert(HasCompatibleAlignment(mat)); + Contracts.Assert(HasCompatibleAlignment(src)); + Contracts.Assert(HasCompatibleAlignment(dst)); - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - fixed (float* pmat = &MemoryMarshal.GetReference(mat)) - fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) - fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) + fixed (float* pSrcStart = &src.Items[0]) + fixed (float* pDstStart = &dst.Items[0]) + fixed (float* pMatStart = &mat.Items[0]) { + float* psrc = GetAlignedBase(src, pSrcStart); + float* pdst = GetAlignedBase(dst, pDstStart); + float* pmat = GetAlignedBase(mat, pMatStart); + float* pSrcEnd = psrc + ccol; float* pDstEnd = pdst + crow; float* pDstCurrent = pdst; @@ -139,128 +158,29 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr while (pDstCurrent < pDstEnd) { Vector128 res0 = Sse.SetZeroVector128(); - Vector128 res1 = Sse.SetZeroVector128(); - Vector128 res2 = Sse.SetZeroVector128(); - Vector128 res3 = Sse.SetZeroVector128(); + Vector128 res1 = res0; + Vector128 res2 = res0; + Vector128 res3 = res0; - int length = ccol; float* pSrcCurrent = psrc; - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 16); - int remainder = 0; - - if ((misalignment & 3) != 0) - { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pSrcCurrent < pSrcEnd) - { - Vector128 vector = Sse.LoadVector128(pSrcCurrent); - - float* pMatTemp = pMatCurrent; - Vector128 x01 = Sse.Multiply(vector, 
Sse.LoadVector128(pMatTemp)); - Vector128 x11 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x21 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x31 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); - - res0 = Sse.Add(res0, x01); - res1 = Sse.Add(res1, x11); - res2 = Sse.Add(res2, x21); - res3 = Sse.Add(res3, x31); - - pSrcCurrent += 4; - pMatCurrent += 4; - } - } - else + while (pSrcCurrent < pSrcEnd) { - if (misalignment != 0) - { - // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 4 - misalignment; - - Vector128 mask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); - - // We only align pMat since it has significantly more reads. - float* pMatTemp = pMatCurrent; - Vector128 x01 = Sse.And(mask, Sse.LoadVector128(pMatTemp)); - Vector128 x11 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x21 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x31 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 vector = Sse.And(mask, Sse.LoadVector128(pSrcCurrent)); - - res0 = Sse.Multiply(x01, vector); - res1 = Sse.Multiply(x11, vector); - res2 = Sse.Multiply(x21, vector); - res3 = Sse.Multiply(x31, vector); - - pMatCurrent += misalignment; - pSrcCurrent += misalignment; - length -= misalignment; - } + float* pMatTemp = pMatCurrent; - if (length > 3) - { - // Handle all the 128-bit blocks that we can now that we have offset to an aligned address - remainder = length % 4; + Vector128 x01 = Sse.LoadAlignedVector128(pMatTemp); + Vector128 x11 = Sse.LoadAlignedVector128(pMatTemp += ccol); + Vector128 x21 = Sse.LoadAlignedVector128(pMatTemp += ccol); + Vector128 x31 = Sse.LoadAlignedVector128(pMatTemp += ccol); + Vector128 x02 = Sse.LoadAlignedVector128(pSrcCurrent); - // If we aren't using the 
VEX-encoding, the JIT will only fold away aligned loads - // (due to semantics of the legacy encoding). - // We don't need an assert, since the instruction will throw for unaligned inputs. - while (pSrcCurrent + 4 <= pSrcEnd) - { - Vector128 vector = Sse.LoadVector128(pSrcCurrent); - - float* pMatTemp = pMatCurrent; - Vector128 x01 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp)); - Vector128 x11 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); - Vector128 x21 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); - Vector128 x31 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); - - res0 = Sse.Add(res0, x01); - res1 = Sse.Add(res1, x11); - res2 = Sse.Add(res2, x21); - res3 = Sse.Add(res3, x31); - - pSrcCurrent += 4; - pMatCurrent += 4; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not - // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. 
- remainder = length; - } + res0 = Sse.Add(res0, Sse.Multiply(x01, x02)); + res1 = Sse.Add(res1, Sse.Multiply(x11, x02)); + res2 = Sse.Add(res2, Sse.Multiply(x21, x02)); + res3 = Sse.Add(res3, Sse.Multiply(x31, x02)); - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - - pMatCurrent -= (4 - remainder); - pSrcCurrent -= (4 - remainder); - - Vector128 mask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); - - float* pMatTemp = pMatCurrent; - Vector128 x01 = Sse.And(mask, Sse.LoadVector128(pMatTemp)); - Vector128 x11 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x21 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x31 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 vector = Sse.And(mask, Sse.LoadVector128(pSrcCurrent)); - - res0 = Sse.Add(res0, Sse.Multiply(x01, vector)); - res1 = Sse.Add(res1, Sse.Multiply(x11, vector)); - res2 = Sse.Add(res2, Sse.Multiply(x21, vector)); - res3 = Sse.Add(res3, Sse.Multiply(x31, vector)); - - pMatCurrent += 4; - pSrcCurrent += 4; - } + pSrcCurrent += 4; + pMatCurrent += 4; } // Add up the entries of each, with the 4 results in res0 @@ -268,7 +188,8 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr res2 = Sse3.HorizontalAdd(res2, res3); res0 = Sse3.HorizontalAdd(res0, res2); - Sse.Store(pDstCurrent, res0); + Sse.StoreAligned(pDstCurrent, res0); + pDstCurrent += 4; pMatCurrent += 3 * ccol; } @@ -279,24 +200,23 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan rgposSrc, AlignedArray src, int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol) { - MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol); - } - - public static 
unsafe void MatMulP(ReadOnlySpan mat, ReadOnlySpan rgposSrc, ReadOnlySpan src, - int posMin, int iposMin, int iposEnd, Span dst, int crow, int ccol) - { - Contracts.Assert(crow % 4 == 0); - Contracts.Assert(ccol % 4 == 0); + // REVIEW: For extremely sparse inputs, interchanging the loops would + // likely be more efficient. + Contracts.Assert(HasCompatibleAlignment(mat)); + Contracts.Assert(HasCompatibleAlignment(src)); + Contracts.Assert(HasCompatibleAlignment(dst)); // REVIEW: For extremely sparse inputs, interchanging the loops would // likely be more efficient. - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - fixed (float* pmat = &MemoryMarshal.GetReference(mat)) - fixed (int* pposSrc = &MemoryMarshal.GetReference(rgposSrc)) - fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) - fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) + fixed (float* pSrcStart = &src.Items[0]) + fixed (float* pDstStart = &dst.Items[0]) + fixed (float* pMatStart = &mat.Items[0]) + fixed (int* pposSrc = &rgposSrc[0]) { + float* psrc = GetAlignedBase(src, pSrcStart); + float* pdst = GetAlignedBase(dst, pDstStart); + float* pmat = GetAlignedBase(mat, pMatStart); + int* pposMin = pposSrc + iposMin; int* pposEnd = pposSrc + iposEnd; float* pDstEnd = pdst + crow; @@ -304,120 +224,7 @@ public static unsafe void MatMulP(ReadOnlySpan mat, ReadOnlySpan rgp float* pSrcCurrent = psrc - posMin; float* pDstCurrent = pdst; - nuint address = (nuint)(pDstCurrent); - int misalignment = (int)(address % 16); - - int length = crow; - int remainder = 0; - - if ((misalignment & 3) != 0) - { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pDstCurrent < pDstEnd) - { - Sse.Store(pDstCurrent, SparseMultiplicationAcrossRow()); - pDstCurrent += 4; - pm0 += 4 * ccol; - } - } - else - { - if (misalignment != 0) - { - // Handle cases where the data is not 128-bit 
aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - - misalignment >>= 2; - misalignment = 4 - misalignment; - - Vector128 mask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); - - float* pm1 = pm0 + ccol; - float* pm2 = pm1 + ccol; - float* pm3 = pm2 + ccol; - Vector128 result = Sse.SetZeroVector128(); - - int* ppos = pposMin; - - while (ppos < pposEnd) - { - int col = *ppos; - Vector128 x1 = Sse.SetVector128(pm3[col], pm2[col], pm1[col], pm0[col]); - - x1 = Sse.And(mask, x1); - Vector128 x2 = Sse.SetAllVector128(pSrcCurrent[col]); - x2 = Sse.Multiply(x2, x1); - result = Sse.Add(result, x2); - ppos++; - } - - Sse.Store(pDstCurrent, result); - pDstCurrent += misalignment; - pm0 += misalignment * ccol; - length -= misalignment; - } - - if (length > 3) - { - // Handle all the 128-bit blocks that we can now that we have offset to an aligned address - remainder = length % 4; - - // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads - // (due to semantics of the legacy encoding). - // We don't need an assert, since the instruction will throw for unaligned inputs. - while (pDstCurrent < pDstEnd) - { - Sse.Store(pDstCurrent, SparseMultiplicationAcrossRow()); - pDstCurrent += 4; - pm0 += 4 * ccol; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not - // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. 
- remainder = length; - } - - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - pDstCurrent -= (4 - remainder); - pm0 -= (4 - remainder) * ccol; - - Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); - Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); - - float* pm1 = pm0 + ccol; - float* pm2 = pm1 + ccol; - float* pm3 = pm2 + ccol; - Vector128 result = Sse.SetZeroVector128(); - - int* ppos = pposMin; - - while (ppos < pposEnd) - { - int col = *ppos; - Vector128 x1 = Sse.SetVector128(pm3[col], pm2[col], pm1[col], pm0[col]); - x1 = Sse.And(x1, trailingMask); - - Vector128 x2 = Sse.SetAllVector128(pSrcCurrent[col]); - x2 = Sse.Multiply(x2, x1); - result = Sse.Add(result, x2); - ppos++; - } - - result = Sse.Add(result, Sse.And(leadingMask, Sse.LoadVector128(pDstCurrent))); - - Sse.Store(pDstCurrent, result); - pDstCurrent += 4; - pm0 += 4 * ccol; - } - } - - Vector128 SparseMultiplicationAcrossRow() + while (pDstCurrent < pDstEnd) { float* pm1 = pm0 + ccol; float* pm2 = pm1 + ccol; @@ -433,313 +240,107 @@ Vector128 SparseMultiplicationAcrossRow() Vector128 x2 = Sse.SetAllVector128(pSrcCurrent[col]); x2 = Sse.Multiply(x2, x1); result = Sse.Add(result, x2); + ppos++; } - return result; + Sse.StoreAligned(pDstCurrent, result); + pDstCurrent += 4; + pm0 += 4 * ccol; } } } public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { - MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol); - } - - public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan src, Span dst, int crow, int ccol) - { - Contracts.Assert(crow % 4 == 0); - Contracts.Assert(ccol % 4 == 0); + Contracts.Assert(HasCompatibleAlignment(mat)); + 
Contracts.Assert(HasCompatibleAlignment(src)); + Contracts.Assert(HasCompatibleAlignment(dst)); - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - fixed (float* pmat = &MemoryMarshal.GetReference(mat)) - fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) - fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) + fixed (float* pSrcStart = &src.Items[0]) + fixed (float* pDstStart = &dst.Items[0]) + fixed (float* pMatStart = &mat.Items[0]) { + float* psrc = GetAlignedBase(src, pSrcStart); + float* pdst = GetAlignedBase(dst, pDstStart); + float* pmat = GetAlignedBase(mat, pMatStart); + float* pSrcEnd = psrc + ccol; float* pDstEnd = pdst + crow; float* pSrcCurrent = psrc; float* pMatCurrent = pmat; - // The reason behind adding the if condtion instead of boolean flag - // is to avoid branching in codegen. - if (pSrcCurrent < pSrcEnd) - { - Vector128 x01 = Sse.LoadVector128(pSrcCurrent); - // Replicate each 32-bit slot of x01 (ABCD) into its own register. - Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B - Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C - Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D - x01 = Sse.Shuffle(x01, x01, 0x00); // A + Vector128 x01 = Sse.LoadAlignedVector128(pSrcCurrent); + // Replicate each 32-bit slot of x01 (ABCD) into its own register. 
+ Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B + Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C + Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D + x01 = Sse.Shuffle(x01, x01, 0x00); // A - int length = crow; - float* pDstCurrent = pdst; + pSrcCurrent += 4; - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 16); + float* pDstCurrent = pdst; - if ((misalignment & 3) != 0) - { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pDstCurrent < pDstEnd) - { - float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.Multiply(x11, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.Multiply(x21, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.Multiply(x31, Sse.LoadVector128(pMatTemp += crow)); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - - Sse.Store(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } - } - else - { - int remainder = 0; - if (misalignment != 0) - { - // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 4 - misalignment; - - Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); - - // We only align pMat since it has significantly more reads. 
- float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - - Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + ((4 - misalignment) * 4)); - Vector128 x3 = Sse.LoadVector128(pDstCurrent); - x02 = Sse.Or(x02, Sse.And(x3, trailingMask)); - - Sse.Store(pDstCurrent, x02); - pMatCurrent += misalignment; - pDstCurrent += misalignment; - length -= misalignment; - } - if (length > 3) - { - // Handle all the 128-bit blocks that we can now that we have offset to an aligned address - remainder = length % 4; - while (pDstCurrent + 4 <= pDstEnd) - { - // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads - // (due to semantics of the legacy encoding). - // We don't need an assert, since the instruction will throw for unaligned inputs. - float* pMatTemp = pMatCurrent; - - Vector128 x02 = Sse.Multiply(x01, Sse.LoadAlignedVector128(pMatTemp)); - Vector128 x12 = Sse.Multiply(x11, Sse.LoadAlignedVector128(pMatTemp += crow)); - Vector128 x22 = Sse.Multiply(x21, Sse.LoadAlignedVector128(pMatTemp += crow)); - Vector128 x32 = Sse.Multiply(x31, Sse.LoadAlignedVector128(pMatTemp += crow)); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - - Sse.Store(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not - // 128-bit aligned. 
This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. - remainder = length; - } + while (pDstCurrent < pDstEnd) + { + float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.LoadAlignedVector128(pMatTemp); + Vector128 x12 = Sse.LoadAlignedVector128(pMatTemp += crow); + Vector128 x22 = Sse.LoadAlignedVector128(pMatTemp += crow); + Vector128 x32 = Sse.LoadAlignedVector128(pMatTemp += crow); - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - pMatCurrent -= (4 - remainder); - pDstCurrent -= (4 - remainder); - - Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); - - float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - - Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); - Vector128 x3 = Sse.LoadVector128(pDstCurrent); - x02 = Sse.Or(x02, Sse.And(x3, leadingMask)); - - Sse.Store(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } - } + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); - pMatCurrent += 3 * crow; - pSrcCurrent += 4; + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + Sse.StoreAligned(pDstCurrent, x02); + + 
pDstCurrent += 4; + pMatCurrent += 4; } - // We do 4-way unrolling + pMatCurrent += 3 * crow; + while (pSrcCurrent < pSrcEnd) { - Vector128 x01 = Sse.LoadVector128(pSrcCurrent); + x01 = Sse.LoadAlignedVector128(pSrcCurrent); // Replicate each 32-bit slot of x01 (ABCD) into its own register. - Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B - Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C - Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D + x11 = Sse.Shuffle(x01, x01, 0x55); // B + x21 = Sse.Shuffle(x01, x01, 0xAA); // C + x31 = Sse.Shuffle(x01, x01, 0xFF); // D x01 = Sse.Shuffle(x01, x01, 0x00); // A - int length = crow; - float* pDstCurrent = pdst; - - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 16); + pDstCurrent = pdst; - if ((misalignment & 3) != 0) + while (pDstCurrent < pDstEnd) { - while (pDstCurrent < pDstEnd) - { - float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.Multiply(x11, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.Multiply(x21, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.Multiply(x31, Sse.LoadVector128(pMatTemp += crow)); + float* pMatTemp = pMatCurrent; - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); + Vector128 x02 = Sse.LoadAlignedVector128(pMatTemp); + Vector128 x12 = Sse.LoadAlignedVector128(pMatTemp += crow); + Vector128 x22 = Sse.LoadAlignedVector128(pMatTemp += crow); + Vector128 x32 = Sse.LoadAlignedVector128(pMatTemp += crow); + Vector128 x3 = Sse.LoadAlignedVector128(pDstCurrent); - x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); - Sse.Store(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } - } - else - { - int remainder = 0; - if (misalignment != 0) - { - // Handle cases where the data is not 128-bit aligned by doing an 
unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 4 - misalignment; - - Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); - - // We only align pMat since it has significantly more reads. - float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - - Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + ((4 - misalignment) * 4)); - Vector128 x3 = Sse.LoadVector128(pDstCurrent); - x02 = Sse.Or(x02, Sse.And(x3, trailingMask)); - - x02 = Sse.Add(x02, Sse.And(x3, leadingMask)); - - Sse.Store(pDstCurrent, x02); - pMatCurrent += misalignment; - pDstCurrent += misalignment; - length -= misalignment; - } - if (length > 3) - { - remainder = length % 4; - while (pDstCurrent + 4 <= pDstEnd) - { - float* pMatTemp = pMatCurrent; - - Vector128 x02 = Sse.Multiply(x01, Sse.LoadAlignedVector128(pMatTemp)); - Vector128 x12 = Sse.Multiply(x11, Sse.LoadAlignedVector128(pMatTemp += crow)); - Vector128 x22 = Sse.Multiply(x21, Sse.LoadAlignedVector128(pMatTemp += crow)); - Vector128 x32 = Sse.Multiply(x31, Sse.LoadAlignedVector128(pMatTemp += crow)); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - - x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); - Sse.Store(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } - } - else - { - remainder = length; - } + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, 
x32); + x02 = Sse.Add(x02, x22); + x3 = Sse.Add(x02, x3); - if (remainder != 0) - { - pMatCurrent -= (4 - remainder); - pDstCurrent -= (4 - remainder); - Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); - - float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - - Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); - Vector128 x3 = Sse.LoadVector128(pDstCurrent); - x02 = Sse.Or(x02, Sse.And(x3, leadingMask)); - - x02 = Sse.Add(x02, Sse.And(x3, trailingMask)); - Sse.Store(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } + Sse.StoreAligned(pDstCurrent, x3); + + pDstCurrent += 4; + pMatCurrent += 4; } pMatCurrent += 3 * crow; @@ -832,11 +433,12 @@ public static unsafe void Scale(float scale, Span dst) Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + ((4 - misalignment) * 4)); - Vector128 temp = Sse.And(result, leadingMask); - result = Sse.And(result, trailingMask); + Vector128 temp = Sse.And(result, trailingMask); + result = Sse.Multiply(scaleVector128, result); - temp = Sse.Multiply(scaleVector128, temp); - result = Sse.Or(temp, result); + // Masking operation is done at the end to avoid doing an Or operation with negative Zero. 
+ result = Sse.And(result, leadingMask); + result = Sse.Or(result, temp); Sse.Store(pDstCurrent, result); @@ -880,21 +482,22 @@ public static unsafe void Scale(float scale, Span dst) Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); - Vector128 temp = Sse.And(result, trailingMask); - result = Sse.And(result, leadingMask); + Vector128 temp = Sse.And(result, leadingMask); + result = Sse.Multiply(scaleVector128, result); - temp = Sse.Multiply(scaleVector128, temp); - temp = Sse.Or(temp, result); + // Masking operation is done at the end to avoid doing an Or operation with negative Zero. + result = Sse.And(result, trailingMask); + result = Sse.Or(result, temp); - Sse.Store(pDstCurrent, temp); + Sse.Store(pDstCurrent, result); } } } public static unsafe void ScaleSrcU(float scale, ReadOnlySpan src, Span dst, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { @@ -963,8 +566,8 @@ public static unsafe void ScaleAddU(float a, float b, Span dst) public static unsafe void AddScaleU(float scale, ReadOnlySpan src, Span dst, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { @@ -1004,8 +607,9 @@ public static unsafe void AddScaleU(float scale, ReadOnlySpan src, Span src, ReadOnlySpan dst, Span result, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); + Contracts.Assert(count <= result.Length); fixed (float* psrc = 
&MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) fixed (float* pres = &MemoryMarshal.GetReference(result)) @@ -1047,8 +651,9 @@ public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan src, Re public static unsafe void AddScaleSU(float scale, ReadOnlySpan src, ReadOnlySpan idx, Span dst, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); + Contracts.Assert(count <= idx.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (int* pidx = &MemoryMarshal.GetReference(idx)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) @@ -1085,8 +690,8 @@ public static unsafe void AddScaleSU(float scale, ReadOnlySpan src, ReadO public static unsafe void AddU(ReadOnlySpan src, Span dst, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { @@ -1122,8 +727,9 @@ public static unsafe void AddU(ReadOnlySpan src, Span dst, int cou public static unsafe void AddSU(ReadOnlySpan src, ReadOnlySpan idx, Span dst, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); + Contracts.Assert(count <= idx.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (int* pidx = &MemoryMarshal.GetReference(idx)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) @@ -1157,9 +763,9 @@ public static unsafe void AddSU(ReadOnlySpan src, ReadOnlySpan idx, public static unsafe void MulElementWiseU(ReadOnlySpan src1, ReadOnlySpan src2, Span dst, int count) { - Contracts.Assert(src1.Length == dst.Length); - Contracts.Assert(src2.Length == dst.Length); - + Contracts.Assert(count <= src1.Length); + Contracts.Assert(count <= src2.Length); + 
Contracts.Assert(count <= dst.Length); fixed (float* psrc1 = &MemoryMarshal.GetReference(src1)) fixed (float* psrc2 = &MemoryMarshal.GetReference(src2)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) @@ -1494,8 +1100,8 @@ public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan src) public static unsafe float DotU(ReadOnlySpan src, ReadOnlySpan dst, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { @@ -1535,8 +1141,9 @@ public static unsafe float DotU(ReadOnlySpan src, ReadOnlySpan dst public static unsafe float DotSU(ReadOnlySpan src, ReadOnlySpan dst, ReadOnlySpan idx, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); + Contracts.Assert(count <= idx.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) fixed (int* pidx = &MemoryMarshal.GetReference(idx)) @@ -1578,8 +1185,8 @@ public static unsafe float DotSU(ReadOnlySpan src, ReadOnlySpan ds public static unsafe float Dist2(ReadOnlySpan src, ReadOnlySpan dst, int count) { - Contracts.Assert(src.Length == dst.Length); - + Contracts.Assert(count <= src.Length); + Contracts.Assert(count <= dst.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { diff --git a/src/Microsoft.ML.CpuMath/Thunk.cs b/src/Microsoft.ML.CpuMath/Thunk.cs index 8ff725b54a..bd0931bd55 100644 --- a/src/Microsoft.ML.CpuMath/Thunk.cs +++ b/src/Microsoft.ML.CpuMath/Thunk.cs @@ -2,11 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime.Internal.CpuMath.Core; using System.Runtime.InteropServices; using System.Security; +using Microsoft.ML.Internal.CpuMath.Core; -namespace Microsoft.ML.Runtime.Internal.CpuMath +namespace Microsoft.ML.Internal.CpuMath { [BestFriend] internal static unsafe class Thunk diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index c1138996d5..78bce9cec6 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -2,25 +2,24 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Calibration; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Transforms; -using Microsoft.ML.Transforms.Conversions; using System; using System.Collections.Generic; using System.IO; using System.Linq; using System.Threading.Tasks; +using Microsoft.ML; +using Microsoft.ML.Command; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Transforms; +using Microsoft.ML.Transforms.Conversions; [assembly: LoadableClass(typeof(CrossValidationCommand), typeof(CrossValidationCommand.Arguments), typeof(SignatureCommand), "Cross Validation", CrossValidationCommand.LoadName)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { [BestFriend] internal sealed class CrossValidationCommand : DataCommand.ImplBase @@ -298,7 +297,7 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output if (group != null && schema.TryGetColumnIndex(group, out index)) { // Check if group column key type with 
known cardinality. - var type = schema.GetColumnType(index); + var type = schema[index].Type; if (type.KeyCount > 0) stratificationColumn = group; } @@ -322,7 +321,7 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output int col; if (!input.Schema.TryGetColumnIndex(stratificationColumn, out col)) throw ch.ExceptUserArg(nameof(Arguments.StratificationColumn), "Column '{0}' does not exist", stratificationColumn); - var type = input.Schema.GetColumnType(col); + var type = input.Schema[col].Type; if (!RangeFilter.IsValidRangeFilterColumnType(ch, type)) { ch.Info("Hashing the stratification column"); diff --git a/src/Microsoft.ML.Data/Commands/DataCommand.cs b/src/Microsoft.ML.Data/Commands/DataCommand.cs index 61da007321..d74cd2cd03 100644 --- a/src/Microsoft.ML.Data/Commands/DataCommand.cs +++ b/src/Microsoft.ML.Data/Commands/DataCommand.cs @@ -6,13 +6,13 @@ using System.Collections.Generic; using System.IO; using System.Linq; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Command; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// This holds useful base classes for commands that ingest a primary dataset and deal with associated model files. 
@@ -165,8 +165,8 @@ protected void SendTelemetryMetric(Dictionary[] metricValues) { for (int currentIndex = 0; currentIndex < cursor.Schema.Count; currentIndex++) { - var nameOfMetric = "TLC_" + cursor.Schema.GetColumnName(currentIndex); - var type = cursor.Schema.GetColumnType(currentIndex); + var nameOfMetric = "TLC_" + cursor.Schema[currentIndex].Name; + var type = cursor.Schema[currentIndex].Type; if (type.IsNumber) { var getter = RowCursorUtils.GetGetterAs(NumberType.R8, cursor, currentIndex); diff --git a/src/Microsoft.ML.Data/Commands/DefaultColumnNames.cs b/src/Microsoft.ML.Data/Commands/DefaultColumnNames.cs index 85b6b0f1a5..d66e4d518b 100644 --- a/src/Microsoft.ML.Data/Commands/DefaultColumnNames.cs +++ b/src/Microsoft.ML.Data/Commands/DefaultColumnNames.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public static class DefaultColumnNames { diff --git a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs index 7ed4144262..06f81f4306 100644 --- a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs +++ b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs @@ -5,11 +5,11 @@ using System; using System.Collections.Generic; using System.Text.RegularExpressions; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML; +using Microsoft.ML.Command; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; [assembly: LoadableClass(EvaluateTransform.Summary, typeof(IDataTransform), typeof(EvaluateTransform), typeof(EvaluateTransform.Arguments), typeof(SignatureDataTransform), "Evaluate Predictor", "Evaluate")] @@ -17,7 +17,7 @@ [assembly: 
LoadableClass(EvaluateCommand.Summary, typeof(EvaluateCommand), typeof(EvaluateCommand.Arguments), typeof(SignatureCommand), "Evaluate Predictor", "Evaluate")] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { // REVIEW: For simplicity (since this is currently the case), // we assume that all metrics are either numeric, or numeric vectors. @@ -89,7 +89,8 @@ public string GetNameMatch(string input) /// Both take a as input. The is assumed to contain all the column /// roles needed for evaluation, including the score column. /// - public interface IEvaluator + [BestFriend] + internal interface IEvaluator { /// /// Compute the aggregate metrics. Return a dictionary from the metric kind @@ -109,12 +110,14 @@ public interface IEvaluator } /// - /// Signature for creating an IEvaluator. + /// Signature for creating an . /// - public delegate void SignatureEvaluator(); - public delegate void SignatureMamlEvaluator(); + [BestFriend] + internal delegate void SignatureEvaluator(); + [BestFriend] + internal delegate void SignatureMamlEvaluator(); - public static class EvaluateTransform + internal static class EvaluateTransform { public sealed class Arguments { @@ -143,7 +146,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV using (var ch = env.Register("EvaluateTransform").Start("Create Transform")) { ch.Trace("Binding columns"); - ISchema schema = input.Schema; + var schema = input.Schema; string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn), args.LabelColumn, DefaultColumnNames.Label); string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn), @@ -218,7 +221,7 @@ private void RunCore(IChannel ch) (env, source) => new IO.BinaryLoader(env, new IO.BinaryLoader.Arguments(), source)); ch.Trace("Binding columns"); - ISchema schema = view.Schema; + var schema = view.Schema; string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, 
nameof(Arguments.LabelColumn), Args.LabelColumn, DefaultColumnNames.Label); string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn), diff --git a/src/Microsoft.ML.Data/Commands/SaveDataCommand.cs b/src/Microsoft.ML.Data/Commands/SaveDataCommand.cs index 8be15fdac5..3a5d11085c 100644 --- a/src/Microsoft.ML.Data/Commands/SaveDataCommand.cs +++ b/src/Microsoft.ML.Data/Commands/SaveDataCommand.cs @@ -2,17 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Transforms; using System; using System.Collections.Generic; using System.IO; using System.Linq; +using Microsoft.ML; +using Microsoft.ML.Command; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Transforms; [assembly: LoadableClass(SaveDataCommand.Summary, typeof(SaveDataCommand), typeof(SaveDataCommand.Arguments), typeof(SignatureCommand), "Save Data", "SaveData", "save")] @@ -20,7 +20,7 @@ [assembly: LoadableClass(ShowDataCommand.Summary, typeof(ShowDataCommand), typeof(ShowDataCommand.Arguments), typeof(SignatureCommand), "Show Data", "ShowData", "show")] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { internal sealed class SaveDataCommand : DataCommand.ImplBase { @@ -144,13 +144,13 @@ private void RunCore(IChannel ch) var cols = new List(); for (int i = 0; i < data.Schema.Count; i++) { - if (!Args.KeepHidden && data.Schema.IsHidden(i)) + if (!Args.KeepHidden && data.Schema[i].IsHidden) continue; - var type = data.Schema.GetColumnType(i); + var type = data.Schema[i].Type; if (saver.IsColumnSavable(type)) cols.Add(i); else - 
ch.Info(MessageSensitivity.Schema, "The column '{0}' will not be written as it has unsavable column type.", data.Schema.GetColumnName(i)); + ch.Info(MessageSensitivity.Schema, "The column '{0}' will not be written as it has unsavable column type.", data.Schema[i].Name); } Host.NotSensitive().Check(cols.Count > 0, "No valid columns to save"); @@ -203,15 +203,15 @@ public static void SaveDataView(IChannel ch, IDataSaver saver, IDataView view, S ch.CheckValue(stream, nameof(stream)); var cols = new List(); - for (int i = 0; i < view.Schema.ColumnCount; i++) + for (int i = 0; i < view.Schema.Count; i++) { - if (!keepHidden && view.Schema.IsHidden(i)) + if (!keepHidden && view.Schema[i].IsHidden) continue; - var type = view.Schema.GetColumnType(i); + var type = view.Schema[i].Type; if (saver.IsColumnSavable(type)) cols.Add(i); else - ch.Info(MessageSensitivity.Schema, "The column '{0}' will not be written as it has unsavable column type.", view.Schema.GetColumnName(i)); + ch.Info(MessageSensitivity.Schema, "The column '{0}' will not be written as it has unsavable column type.", view.Schema[i].Name); } ch.Check(cols.Count > 0, "No valid columns to save"); diff --git a/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs b/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs index 633f871dd0..86ace0e0a0 100644 --- a/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs +++ b/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs @@ -4,21 +4,20 @@ using System.IO; using System.Text; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Tools; - +using Microsoft.ML; +using Microsoft.ML.Command; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; // REVIEW: Fix these namespaces. 
-using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Tools; [assembly: LoadableClass(SavePredictorCommand.Summary, typeof(SavePredictorCommand), typeof(SavePredictorCommand.Arguments), typeof(SignatureCommand), "Save Predictor As", "SavePredictorAs", "SavePredictor", "SaveAs", "SaveModel")] -namespace Microsoft.ML.Runtime.Tools +namespace Microsoft.ML.Tools { internal sealed class SavePredictorCommand : ICommand { @@ -114,7 +113,8 @@ private Stream CreateStrm(IFileHandle file) } } - public static class SavePredictorUtils + [BestFriend] + internal static class SavePredictorUtils { public static void SavePredictor(IHostEnvironment env, Stream modelStream, Stream binaryModelStream = null, Stream summaryModelStream = null, Stream textModelStream = null, Stream iniModelStream = null, Stream codeModelStream = null) diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs index ee3d05b546..ec08c6e49d 100644 --- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs +++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs @@ -7,18 +7,17 @@ using System; using System.Collections.Generic; using System.IO; +using Microsoft.ML; +using Microsoft.ML.Command; +using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Internal.Utilities; [assembly: LoadableClass(ScoreCommand.Summary, typeof(ScoreCommand), typeof(ScoreCommand.Arguments), typeof(SignatureCommand), "Score Predictor", "Score")] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { using TScorerFactory = IComponentFactory; @@ -34,7 +33,8 @@ public interface 
IDataScorerTransform : IDataTransform, ITransformTemplate /// This parameter holds a snapshot of the role mapped training schema as /// it existed at the point when was trained, or null if it not /// available for some reason - public delegate void SignatureDataScorer(IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema); + [BestFriend] + internal delegate void SignatureDataScorer(IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema); public delegate void SignatureBindableMapper(IPredictor predictor); @@ -181,17 +181,17 @@ private void RunCore(IChannel ch) var cols = new List(); for (int i = 0; i < loader.Schema.Count; i++) { - if (!Args.KeepHidden && loader.Schema.IsHidden(i)) + if (!Args.KeepHidden && loader.Schema[i].IsHidden) continue; if (!(outputAllColumns || ShouldAddColumn(loader.Schema, i, maxScoreId, outputNamesAndLabels))) continue; - var type = loader.Schema.GetColumnType(i); + var type = loader.Schema[i].Type; if (writer.IsColumnSavable(type)) cols.Add(i); else { ch.Warning("The column '{0}' will not be written as it has unsavable column type.", - loader.Schema.GetColumnName(i)); + loader.Schema[i].Name); } } @@ -210,14 +210,14 @@ private void RunCore(IChannel ch) private bool ShouldAddColumn(Schema schema, int i, uint scoreSet, bool outputNamesAndLabels) { uint scoreSetId = 0; - if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType.AsPrimitive, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId) + if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId) && scoreSetId == scoreSet) { return true; } if (outputNamesAndLabels) { - switch (schema.GetColumnName(i)) + switch (schema[i].Name) { case "Label": case "Name": @@ -227,7 +227,7 @@ private bool ShouldAddColumn(Schema schema, int i, uint scoreSet, bool outputNam break; } } - if (Args.OutputColumn != null && Array.FindIndex(Args.OutputColumn, schema.GetColumnName(i).Equals) >= 0) + if 
(Args.OutputColumn != null && Array.FindIndex(Args.OutputColumn, schema[i].Name.Equals) >= 0) return true; return false; } @@ -306,8 +306,8 @@ public static TScorerFactory GetScorerComponent( ComponentCatalog.LoadableClassInfo info = null; ReadOnlyMemory scoreKind = default; - if (mapper.Schema.Count > 0 && - mapper.Schema.TryGetMetadata(TextType.Instance, MetadataUtils.Kinds.ScoreColumnKind, 0, ref scoreKind) && + if (mapper.OutputSchema.Count > 0 && + mapper.OutputSchema.TryGetMetadata(TextType.Instance, MetadataUtils.Kinds.ScoreColumnKind, 0, ref scoreKind) && !scoreKind.IsEmpty) { var loadName = scoreKind.ToString(); diff --git a/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs b/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs index 4082eab1df..80c3249aaa 100644 --- a/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs +++ b/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs @@ -9,17 +9,17 @@ using System.Linq; using System.Reflection; using System.Text; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.Conversion; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML; +using Microsoft.ML.Command; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Data.Conversion; +using Microsoft.ML.Internal.Utilities; [assembly: LoadableClass(ShowSchemaCommand.Summary, typeof(ShowSchemaCommand), typeof(ShowSchemaCommand.Arguments), typeof(SignatureCommand), "Show Schema", ShowSchemaCommand.LoadName, "schema")] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { internal sealed class ShowSchemaCommand : DataCommand.ImplBase { @@ -113,7 +113,7 @@ private static IEnumerable GetViewChainReversed(IDataView data) } } - private static void PrintSchema(TextWriter writer, Arguments args, ISchema schema, ITransposeSchema tschema) + private static void PrintSchema(TextWriter writer, Arguments args, 
Schema schema, ITransposeSchema tschema) { Contracts.AssertValue(writer); Contracts.AssertValue(args); @@ -126,7 +126,7 @@ private static void PrintSchema(TextWriter writer, Arguments args, ISchema schem return; } #endif - int colLim = schema.ColumnCount; + int colLim = schema.Count; var itw = new IndentedTextWriter(writer, " "); itw.WriteLine("{0} columns:", colLim); @@ -135,8 +135,8 @@ private static void PrintSchema(TextWriter writer, Arguments args, ISchema schem var names = default(VBuffer>); for (int col = 0; col < colLim; col++) { - var name = schema.GetColumnName(col); - var type = schema.GetColumnType(col); + var name = schema[col].Name; + var type = schema[col].Type; var slotType = tschema == null ? null : tschema.GetSlotType(col); itw.WriteLine("{0}: {1}{2}", name, type, slotType == null ? "" : " (T)"); @@ -152,14 +152,14 @@ private static void PrintSchema(TextWriter writer, Arguments args, ISchema schem if (!type.IsKnownSizeVector) continue; ColumnType typeNames; - if ((typeNames = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, col)) == null) + if ((typeNames = schema[col].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type) == null) continue; if (typeNames.VectorSize != type.VectorSize || !typeNames.ItemType.IsText) { Contracts.Assert(false, "Unexpected slot names type"); continue; } - schema.GetMetadata(MetadataUtils.Kinds.SlotNames, col, ref names); + schema[col].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref names); if (names.Length != type.VectorSize) { Contracts.Assert(false, "Unexpected length of slot names vector"); @@ -179,37 +179,35 @@ private static void PrintSchema(TextWriter writer, Arguments args, ISchema schem } } - private static void ShowMetadata(IndentedTextWriter itw, ISchema schema, int col, bool showVals) + private static void ShowMetadata(IndentedTextWriter itw, Schema schema, int col, bool showVals) { Contracts.AssertValue(itw); Contracts.AssertValue(schema); - Contracts.Assert(0 <= col && col < 
schema.ColumnCount); + Contracts.Assert(0 <= col && col < schema.Count); using (itw.Nest()) { - foreach (var kvp in schema.GetMetadataTypes(col).OrderBy(p => p.Key)) + foreach (var metaColumn in schema[col].Metadata.Schema.OrderBy(mcol => mcol.Name)) { - Contracts.AssertNonEmpty(kvp.Key); - Contracts.AssertValue(kvp.Value); - var type = kvp.Value; - itw.Write("Metadata '{0}': {1}", kvp.Key, type); + var type = metaColumn.Type; + itw.Write("Metadata '{0}': {1}", metaColumn.Name, type); if (showVals) { if (!type.IsVector) - ShowMetadataValue(itw, schema, col, kvp.Key, type); + ShowMetadataValue(itw, schema, col, metaColumn.Name, type); else - ShowMetadataValueVec(itw, schema, col, kvp.Key, type); + ShowMetadataValueVec(itw, schema, col, metaColumn.Name, type); } itw.WriteLine(); } } } - private static void ShowMetadataValue(IndentedTextWriter itw, ISchema schema, int col, string kind, ColumnType type) + private static void ShowMetadataValue(IndentedTextWriter itw, Schema schema, int col, string kind, ColumnType type) { Contracts.AssertValue(itw); Contracts.AssertValue(schema); - Contracts.Assert(0 <= col && col < schema.ColumnCount); + Contracts.Assert(0 <= col && col < schema.Count); Contracts.AssertNonEmpty(kind); Contracts.AssertValue(type); Contracts.Assert(!type.IsVector); @@ -220,16 +218,16 @@ private static void ShowMetadataValue(IndentedTextWriter itw, ISchema schema, in return; } - Action del = ShowMetadataValue; + Action del = ShowMetadataValue; var meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.RawType); meth.Invoke(null, new object[] { itw, schema, col, kind, type }); } - private static void ShowMetadataValue(IndentedTextWriter itw, ISchema schema, int col, string kind, ColumnType type) + private static void ShowMetadataValue(IndentedTextWriter itw, Schema schema, int col, string kind, ColumnType type) { Contracts.AssertValue(itw); Contracts.AssertValue(schema); - Contracts.Assert(0 <= col && col < schema.ColumnCount); + 
Contracts.Assert(0 <= col && col < schema.Count); Contracts.AssertNonEmpty(kind); Contracts.AssertValue(type); Contracts.Assert(!type.IsVector); @@ -239,17 +237,17 @@ private static void ShowMetadataValue(IndentedTextWriter itw, ISchema schema, var value = default(T); var sb = default(StringBuilder); - schema.GetMetadata(kind, col, ref value); + schema[col].Metadata.GetValue(kind, ref value); conv(in value, ref sb); itw.Write(": '{0}'", sb); } - private static void ShowMetadataValueVec(IndentedTextWriter itw, ISchema schema, int col, string kind, ColumnType type) + private static void ShowMetadataValueVec(IndentedTextWriter itw, Schema schema, int col, string kind, ColumnType type) { Contracts.AssertValue(itw); Contracts.AssertValue(schema); - Contracts.Assert(0 <= col && col < schema.ColumnCount); + Contracts.Assert(0 <= col && col < schema.Count); Contracts.AssertNonEmpty(kind); Contracts.AssertValue(type); Contracts.Assert(type.IsVector); @@ -260,16 +258,16 @@ private static void ShowMetadataValueVec(IndentedTextWriter itw, ISchema schema, return; } - Action del = ShowMetadataValueVec; + Action del = ShowMetadataValueVec; var meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.ItemType.RawType); meth.Invoke(null, new object[] { itw, schema, col, kind, type }); } - private static void ShowMetadataValueVec(IndentedTextWriter itw, ISchema schema, int col, string kind, ColumnType type) + private static void ShowMetadataValueVec(IndentedTextWriter itw, Schema schema, int col, string kind, ColumnType type) { Contracts.AssertValue(itw); Contracts.AssertValue(schema); - Contracts.Assert(0 <= col && col < schema.ColumnCount); + Contracts.Assert(0 <= col && col < schema.Count); Contracts.AssertNonEmpty(kind); Contracts.AssertValue(type); Contracts.Assert(type.IsVector); @@ -278,7 +276,7 @@ private static void ShowMetadataValueVec(IndentedTextWriter itw, ISchema sche var conv = Conversions.Instance.GetStringConversion(type.ItemType); var value = 
default(VBuffer); - schema.GetMetadata(kind, col, ref value); + schema[col].Metadata.GetValue(kind, ref value); itw.Write(": Length={0}, Count={0}", value.Length, value.GetValues().Length); diff --git a/src/Microsoft.ML.Data/Commands/TestCommand.cs b/src/Microsoft.ML.Data/Commands/TestCommand.cs index 407f7e713d..c27e95aa29 100644 --- a/src/Microsoft.ML.Data/Commands/TestCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TestCommand.cs @@ -3,16 +3,16 @@ // See the LICENSE file in the project root for more information. using System.Collections.Generic; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML; +using Microsoft.ML.Command; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; [assembly: LoadableClass(TestCommand.Summary, typeof(TestCommand), typeof(TestCommand.Arguments), typeof(SignatureCommand), "Test Predictor", "Test")] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// This command is essentially chaining together and @@ -94,7 +94,7 @@ private void RunCore(IChannel ch) ch.AssertValue(loader); ch.Trace("Binding columns"); - ISchema schema = loader.Schema; + var schema = loader.Schema; string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.LabelColumn), Args.LabelColumn, DefaultColumnNames.Label); string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.FeatureColumn), diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs index ff709de143..652ba527bf 100644 --- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs @@ -2,25 +2,24 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Calibration; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms.Normalizers; using System; using System.Collections.Generic; using System.IO; using System.Linq; +using Microsoft.ML; +using Microsoft.ML.Command; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Transforms.Normalizers; [assembly: LoadableClass(TrainCommand.Summary, typeof(TrainCommand), typeof(TrainCommand.Arguments), typeof(SignatureCommand), "Train Predictor", "Train")] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { using ColumnRole = RoleMappedSchema.ColumnRole; @@ -146,7 +145,7 @@ private void RunCore(IChannel ch, string cmd) ch.Trace("Constructing data pipeline"); IDataView view = CreateLoader(); - ISchema schema = view.Schema; + var schema = view.Schema; var label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn), _labelColumn, DefaultColumnNames.Label); var feature = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn), _featureColumn, DefaultColumnNames.Features); var group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn), _groupColumn, DefaultColumnNames.GroupId); @@ -221,7 +220,7 @@ public static void CheckTrainer(IExceptionContext ectx, IComponentFactory - public static string MatchNameOrDefaultOrNull(IExceptionContext ectx, ISchema schema, string argName, string userName, string defaultName) + public static string MatchNameOrDefaultOrNull(IExceptionContext ectx, Schema schema, string argName, string userName, string 
defaultName) { Contracts.CheckValueOrNull(ectx); ectx.CheckValue(schema, nameof(schema)); @@ -462,7 +461,7 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra { if (autoNorm != NormalizeOption.Yes) { - if (!trainer.Info.NeedNormalization || schema.IsNormalized(featCol)) + if (!trainer.Info.NeedNormalization || schema[featCol].IsNormalized()) { ch.Info("Not adding a normalizer."); return false; diff --git a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs index 49f25375f5..40bac35c7b 100644 --- a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs @@ -4,18 +4,18 @@ using System.Collections.Generic; using System.IO; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Calibration; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML; +using Microsoft.ML.Command; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; [assembly: LoadableClass(TrainTestCommand.Summary, typeof(TrainTestCommand), typeof(TrainTestCommand.Arguments), typeof(SignatureCommand), "Train Test", TrainTestCommand.LoadName)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { [BestFriend] internal sealed class TrainTestCommand : DataCommand.ImplBase @@ -130,7 +130,7 @@ private void RunCore(IChannel ch, string cmd) ch.Trace("Constructing the training pipeline"); IDataView trainPipe = CreateLoader(); - ISchema schema = trainPipe.Schema; + var schema = trainPipe.Schema; string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn), Args.LabelColumn, DefaultColumnNames.Label); string features = TrainUtils.MatchNameOrDefaultOrNull(ch, 
schema, nameof(Arguments.FeatureColumn), diff --git a/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs b/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs index e9db784f1f..e7f0bb5d15 100644 --- a/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs @@ -5,12 +5,11 @@ using System; using System.Collections.Generic; using System.Linq; +using Microsoft.ML; +using Microsoft.ML.Command; using Microsoft.ML.Data.Commands; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.Conversion; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Data.Conversion; +using Microsoft.ML.Internal.Utilities; [assembly: LoadableClass(typeof(TypeInfoCommand), typeof(TypeInfoCommand.Arguments), typeof(SignatureCommand), "", TypeInfoCommand.LoadName)] diff --git a/src/Microsoft.ML.Data/Data/BufferBuilder.cs b/src/Microsoft.ML.Data/Data/BufferBuilder.cs index 2f37f4ea81..3527c50096 100644 --- a/src/Microsoft.ML.Data/Data/BufferBuilder.cs +++ b/src/Microsoft.ML.Data/Data/BufferBuilder.cs @@ -3,9 +3,9 @@ // See the LICENSE file in the project root for more information. using System; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { using Conditional = System.Diagnostics.ConditionalAttribute; diff --git a/src/Microsoft.ML.Data/Data/Combiner.cs b/src/Microsoft.ML.Data/Data/Combiner.cs index 6335620b8b..73957e8d67 100644 --- a/src/Microsoft.ML.Data/Data/Combiner.cs +++ b/src/Microsoft.ML.Data/Data/Combiner.cs @@ -4,12 +4,11 @@ #pragma warning disable 420 // volatile with Interlocked.CompareExchange -using Float = System.Single; - using System; using System.Threading; +using Float = System.Single; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { // REVIEW: Need better names for these and possibly a distinct namespace. 
These are too // specialized to have such prominent fully qualified names. diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs index 1cfe187f24..3b1a9a9628 100644 --- a/src/Microsoft.ML.Data/Data/Conversion.cs +++ b/src/Microsoft.ML.Data/Data/Conversion.cs @@ -10,29 +10,30 @@ using System.Reflection; using System.Text; using System.Threading; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data.Conversion +namespace Microsoft.ML.Data.Conversion { using BL = Boolean; using DT = DateTime; using DZ = DateTimeOffset; - using R4 = Single; - using R8 = Double; using I1 = SByte; using I2 = Int16; using I4 = Int32; using I8 = Int64; + using R4 = Single; + using R8 = Double; using SB = StringBuilder; - using TX = ReadOnlyMemory; using TS = TimeSpan; + using TX = ReadOnlyMemory; using U1 = Byte; using U2 = UInt16; using U4 = UInt32; using U8 = UInt64; - using UG = UInt128; + using UG = RowId; - public delegate bool TryParseMapper(in TX src, out T dst); + [BestFriend] + internal delegate bool TryParseMapper(in TX src, out T dst); /// /// This type exists to provide efficient delegates for conversion between standard ColumnTypes, @@ -44,7 +45,8 @@ namespace Microsoft.ML.Runtime.Data.Conversion /// text (and StringBuilder). These are needed by the standard TextSaver, which handles /// differences between sparse and dense inputs in a semantically invariant way. /// - public sealed class Conversions + [BestFriend] + internal sealed class Conversions { // REVIEW: Reconcile implementations with TypeUtils, and clarify the distinction. @@ -412,15 +414,12 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst, conv = null; identity = false; - if (typeSrc.IsKey) + if (typeSrc is KeyType keySrc) { - var keySrc = typeSrc.AsKey; - // Key types are only convertable to compatible key types or unsigned integer // types that are large enough. 
- if (typeDst.IsKey) + if (typeDst is KeyType keyDst) { - var keyDst = typeDst.AsKey; // We allow the Min value to shift. We currently don't allow the counts to vary. // REVIEW: Should we allow the counts to vary? Allowing the dst to be bigger is trivial. // Smaller dst means mapping values to NA. @@ -451,11 +450,11 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst, // REVIEW: Should we look for illegal values and force them to zero? If so, then // we'll need to set identity to false. } - else if (typeDst.IsKey) + else if (typeDst is KeyType keyDst) { if (!typeSrc.IsText) return false; - conv = GetKeyParse(typeDst.AsKey); + conv = GetKeyParse(keyDst); return true; } else if (!typeDst.IsStandardScalar) @@ -490,10 +489,10 @@ public bool TryGetStringConversion(ColumnType type, out ValueMapper(type.AsKey); + conv = GetKeyStringConversion(keyType); return true; } return TryGetStringConversion(out conv); @@ -572,8 +571,8 @@ public TryParseMapper GetTryParseConversion(ColumnType typeDst) "Parse conversion only supported for standard types"); Contracts.Check(typeDst.RawType == typeof(TDst), "Wrong TDst type parameter"); - if (typeDst.IsKey) - return GetKeyTryParse(typeDst.AsKey); + if (typeDst is KeyType keyType) + return GetKeyTryParse(keyType); Contracts.Assert(_tryParseDelegates.ContainsKey(typeDst.RawKind)); return (TryParseMapper)_tryParseDelegates[typeDst.RawKind]; @@ -892,7 +891,7 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) public void Convert(in U2 src, ref U1 dst) => dst = src <= U1.MaxValue ? (U1)src : (U1)0; public void Convert(in U4 src, ref U1 dst) => dst = src <= U1.MaxValue ? (U1)src : (U1)0; public void Convert(in U8 src, ref U1 dst) => dst = src <= U1.MaxValue ? (U1)src : (U1)0; - public void Convert(in UG src, ref U1 dst) => dst = src.Hi == 0 && src.Lo <= U1.MaxValue ? (U1)src.Lo : (U1)0; + public void Convert(in UG src, ref U1 dst) => dst = src.High == 0 && src.Low <= U1.MaxValue ? 
(U1)src.Low : (U1)0; #endregion ToU1 #region ToU2 @@ -900,7 +899,7 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) public void Convert(in U2 src, ref U2 dst) => dst = src; public void Convert(in U4 src, ref U2 dst) => dst = src <= U2.MaxValue ? (U2)src : (U2)0; public void Convert(in U8 src, ref U2 dst) => dst = src <= U2.MaxValue ? (U2)src : (U2)0; - public void Convert(in UG src, ref U2 dst) => dst = src.Hi == 0 && src.Lo <= U2.MaxValue ? (U2)src.Lo : (U2)0; + public void Convert(in UG src, ref U2 dst) => dst = src.High == 0 && src.Low <= U2.MaxValue ? (U2)src.Low : (U2)0; #endregion ToU2 #region ToU4 @@ -908,7 +907,7 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) public void Convert(in U2 src, ref U4 dst) => dst = src; public void Convert(in U4 src, ref U4 dst) => dst = src; public void Convert(in U8 src, ref U4 dst) => dst = src <= U4.MaxValue ? (U4)src : (U4)0; - public void Convert(in UG src, ref U4 dst) => dst = src.Hi == 0 && src.Lo <= U4.MaxValue ? (U4)src.Lo : (U4)0; + public void Convert(in UG src, ref U4 dst) => dst = src.High == 0 && src.Low <= U4.MaxValue ? (U4)src.Low : (U4)0; #endregion ToU4 #region ToU8 @@ -916,7 +915,7 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) public void Convert(in U2 src, ref U8 dst) => dst = src; public void Convert(in U4 src, ref U8 dst) => dst = src; public void Convert(in U8 src, ref U8 dst) => dst = src; - public void Convert(in UG src, ref U8 dst) => dst = src.Hi == 0 ? src.Lo : (U8)0; + public void Convert(in UG src, ref U8 dst) => dst = src.High == 0 ? 
src.Low : (U8)0; #endregion ToU8 #region ToUG @@ -972,7 +971,7 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) public void Convert(in U2 src, ref SB dst) => ClearDst(ref dst).Append(src); public void Convert(in U4 src, ref SB dst) => ClearDst(ref dst).Append(src); public void Convert(in U8 src, ref SB dst) => ClearDst(ref dst).Append(src); - public void Convert(in UG src, ref SB dst) { ClearDst(ref dst); dst.AppendFormat("0x{0:x16}{1:x16}", src.Hi, src.Lo); } + public void Convert(in UG src, ref SB dst) { ClearDst(ref dst); dst.AppendFormat("0x{0:x16}{1:x16}", src.High, src.Low); } public void Convert(in R4 src, ref SB dst) { ClearDst(ref dst); if (R4.IsNaN(src)) dst.AppendFormat(CultureInfo.InvariantCulture, "{0}", "?"); else dst.AppendFormat(CultureInfo.InvariantCulture, "{0:R}", src); } public void Convert(in R8 src, ref SB dst) { ClearDst(ref dst); if (R8.IsNaN(src)) dst.AppendFormat(CultureInfo.InvariantCulture, "{0}", "?"); else dst.AppendFormat(CultureInfo.InvariantCulture, "{0:G17}", src); } public void Convert(in BL src, ref SB dst) @@ -1060,7 +1059,7 @@ public bool TryParse(in TX src, out U8 dst) } /// - /// A parse method that transforms a 34-length string into a . + /// A parse method that transforms a 34-length string into a . 
/// /// What should be a 34-length hexadecimal representation, including a 0x prefix, /// of the 128-bit number diff --git a/src/Microsoft.ML.Data/Data/DataViewUtils.cs b/src/Microsoft.ML.Data/Data/DataViewUtils.cs index 30ba97bc67..1ee926378c 100644 --- a/src/Microsoft.ML.Data/Data/DataViewUtils.cs +++ b/src/Microsoft.ML.Data/Data/DataViewUtils.cs @@ -9,11 +9,10 @@ using System.Reflection; using System.Text; using System.Threading; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Data.Conversion; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Data.Conversion; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public static class DataViewUtils { @@ -22,7 +21,7 @@ public static class DataViewUtils /// Use tag to independently create multiple temporary, unique column /// names for a single transform. /// - public static string GetTempColumnName(this ISchema schema, string tag = null) + public static string GetTempColumnName(this Schema schema, string tag = null) { Contracts.CheckValue(schema, nameof(schema)); @@ -46,7 +45,7 @@ public static string GetTempColumnName(this ISchema schema, string tag = null) /// Use tag to independently create multiple temporary, unique column /// names for a single transform. /// - public static string[] GetTempColumnNames(this ISchema schema, int n, string tag = null) + public static string[] GetTempColumnNames(this Schema schema, int n, string tag = null) { Contracts.CheckValue(schema, nameof(schema)); Contracts.Check(n > 0, "n"); @@ -55,7 +54,7 @@ public static string[] GetTempColumnNames(this ISchema schema, int n, string tag int j = 0; for (int i = 0; i < n; i++) { - for (;;) + for (; ; ) { string name = string.IsNullOrWhiteSpace(tag) ? string.Format("temp_{0:000}", j) : @@ -114,8 +113,8 @@ public static int GetThreadCount(IHost host, int num = 0, bool preferOne = false /// Try to create a cursor set from upstream and consolidate it here. 
The host determines /// the target cardinality of the cursor set. /// - public static bool TryCreateConsolidatingCursor(out IRowCursor curs, - IDataView view, Func predicate, IHost host, IRandom rand) + public static bool TryCreateConsolidatingCursor(out RowCursor curs, + IDataView view, Func predicate, IHost host, Random rand) { Contracts.CheckValue(host, nameof(host)); host.CheckValue(view, nameof(view)); @@ -129,15 +128,20 @@ public static bool TryCreateConsolidatingCursor(out IRowCursor curs, return false; } - IRowCursorConsolidator consolidator; - var inputs = view.GetRowCursorSet(out consolidator, predicate, cthd, rand); + var inputs = view.GetRowCursorSet(predicate, cthd, rand); host.Check(Utils.Size(inputs) > 0); - host.Check(inputs.Length == 1 || consolidator != null); if (inputs.Length == 1) curs = inputs[0]; else - curs = consolidator.CreateCursor(host, inputs); + { + // We have a somewhat arbitrary batch size of about 64 for buffering results from the + // intermediate cursors, since that at least empirically for most datasets seems to + // strike a nice balance between a size large enough to benefit from parallelism but + // small enough so as to not be too onerous to keep in memory. + const int batchSize = 64; + curs = DataViewUtils.ConsolidateGeneric(host, inputs, batchSize); + } return true; } @@ -146,19 +150,17 @@ public static bool TryCreateConsolidatingCursor(out IRowCursor curs, /// cardinality. If not all the active columns are cachable, this will only /// produce the given input cursor. 
/// - public static IRowCursor[] CreateSplitCursors(out IRowCursorConsolidator consolidator, - IChannelProvider provider, IRowCursor input, int num) + public static RowCursor[] CreateSplitCursors(IChannelProvider provider, RowCursor input, int num) { Contracts.CheckValue(provider, nameof(provider)); provider.CheckValue(input, nameof(input)); - consolidator = null; if (num <= 1) - return new IRowCursor[1] { input }; + return new RowCursor[1] { input }; // If any active columns are not cachable, we can't split. if (!AllCachable(input.Schema, input.IsColumnActive)) - return new IRowCursor[1] { input }; + return new RowCursor[1] { input }; // REVIEW: Should we limit the cardinality to some reasonable size? @@ -168,23 +170,23 @@ public static IRowCursor[] CreateSplitCursors(out IRowCursorConsolidator consoli // REVIEW: Keep the utility method here, move this splitter stuff // to some other file. - return Splitter.Split(out consolidator, provider, input.Schema, input, num); + return Splitter.Split(provider, input.Schema, input, num); } /// /// Return whether all the active columns, as determined by the predicate, are /// cachable - either primitive types or vector types. /// - public static bool AllCachable(ISchema schema, Func predicate) + public static bool AllCachable(Schema schema, Func predicate) { Contracts.CheckValue(schema, nameof(schema)); Contracts.CheckValue(predicate, nameof(predicate)); - for (int col = 0; col < schema.ColumnCount; col++) + for (int col = 0; col < schema.Count; col++) { if (!predicate(col)) continue; - var type = schema.GetColumnType(col); + var type = schema[col].Type; if (!IsCachable(type)) return false; } @@ -205,7 +207,7 @@ public static bool IsCachable(this ColumnType type) /// that is, they all are non-null, have the same schemas, and the same /// set of columns are active. 
/// - public static bool SameSchemaAndActivity(IRowCursor[] cursors) + public static bool SameSchemaAndActivity(RowCursor[] cursors) { // There must be something to actually consolidate. if (Utils.Size(cursors) == 0) @@ -223,7 +225,7 @@ public static bool SameSchemaAndActivity(IRowCursor[] cursors) return false; } // All cursors must have the same columns active. - for (int c = 0; c < schema.ColumnCount; ++c) + for (int c = 0; c < schema.Count; ++c) { bool active = firstCursor.IsColumnActive(c); for (int i = 1; i < cursors.Length; ++i) @@ -239,7 +241,7 @@ public static bool SameSchemaAndActivity(IRowCursor[] cursors) /// Given a parallel cursor set, this consolidates them into a single cursor. The batchSize /// is a hint used for efficiency. /// - public static IRowCursor ConsolidateGeneric(IChannelProvider provider, IRowCursor[] inputs, int batchSize) + public static RowCursor ConsolidateGeneric(IChannelProvider provider, RowCursor[] inputs, int batchSize) { Contracts.CheckValue(provider, nameof(provider)); provider.CheckNonEmpty(inputs, nameof(inputs)); @@ -249,7 +251,7 @@ public static IRowCursor ConsolidateGeneric(IChannelProvider provider, IRowCurso return inputs[0]; object[] pools = null; - return Splitter.Consolidator.Consolidate(provider, inputs, batchSize, ref pools); + return Splitter.Consolidate(provider, inputs, batchSize, ref pools); } /// @@ -277,7 +279,6 @@ private sealed class Splitter { private readonly Schema _schema; private readonly object[] _cachePools; - private object[] _consolidateCachePools; /// /// Pipes, in addition to column values, will also communicate extra information @@ -296,217 +297,201 @@ private Splitter(Schema schema) { Contracts.AssertValue(schema); _schema = schema; - _cachePools = new object[_schema.ColumnCount + (int)ExtraIndex._Lim]; + _cachePools = new object[_schema.Count + (int)ExtraIndex._Lim]; } - public sealed class Consolidator : IRowCursorConsolidator + public static RowCursor Consolidate(IChannelProvider provider, 
RowCursor[] inputs, int batchSize, ref object[] ourPools) { - private readonly Splitter _splitter; - - public Consolidator(Splitter splitter) + Contracts.AssertValue(provider); + using (var ch = provider.Start("Consolidate")) { - Contracts.AssertValue(splitter); - _splitter = splitter; + return ConsolidateCore(provider, inputs, ref ourPools, ch); } + } - public IRowCursor CreateCursor(IChannelProvider provider, IRowCursor[] inputs) - { - return Consolidate(provider, inputs, 128, ref _splitter._consolidateCachePools); - } + private static RowCursor ConsolidateCore(IChannelProvider provider, RowCursor[] inputs, ref object[] ourPools, IChannel ch) + { + ch.CheckNonEmpty(inputs, nameof(inputs)); + if (inputs.Length == 1) + return inputs[0]; + ch.CheckParam(SameSchemaAndActivity(inputs), nameof(inputs), "Inputs not compatible for consolidation"); + + RowCursor cursor = inputs[0]; + var schema = cursor.Schema; + ch.CheckParam(AllCachable(schema, cursor.IsColumnActive), nameof(inputs), "Inputs had some uncachable input columns"); - public static IRowCursor Consolidate(IChannelProvider provider, IRowCursor[] inputs, int batchSize, ref object[] ourPools) + int[] activeToCol; + int[] colToActive; + Utils.BuildSubsetMaps(schema.Count, cursor.IsColumnActive, out activeToCol, out colToActive); + + // Because the schema of the consolidator is not necessary fixed, we are merely + // opportunistic about buffer sharing, from cursoring to cursoring. If we can do + // it easily, great, if not, no big deal. + if (Utils.Size(ourPools) != schema.Count) + ourPools = new object[schema.Count + (int)ExtraIndex._Lim]; + // Create the out pipes. 
+ OutPipe[] outPipes = new OutPipe[activeToCol.Length + (int)ExtraIndex._Lim]; + for (int i = 0; i < activeToCol.Length; ++i) { - Contracts.AssertValue(provider); - using (var ch = provider.Start("Consolidate")) - { - return ConsolidateCore(provider, inputs, ref ourPools, ch); - } + int c = activeToCol[i]; + ColumnType type = schema[c].Type; + var pool = GetPool(type, ourPools, c); + outPipes[i] = OutPipe.Create(type, pool); } + int idIdx = activeToCol.Length + (int)ExtraIndex.Id; + outPipes[idIdx] = OutPipe.Create(NumberType.UG, GetPool(NumberType.UG, ourPools, idIdx)); + + // Create the structures to synchronize between the workers and the consumer. + const int toConsumeBound = 4; + var toConsume = new BlockingCollection(toConsumeBound); + var batchColumnPool = new MadeObjectPool(() => new BatchColumn[outPipes.Length]); + Thread[] workers = new Thread[inputs.Length]; + MinWaiter waiter = new MinWaiter(workers.Length); + bool done = false; - private static IRowCursor ConsolidateCore(IChannelProvider provider, IRowCursor[] inputs, ref object[] ourPools, IChannel ch) + for (int t = 0; t < workers.Length; ++t) { - ch.CheckNonEmpty(inputs, nameof(inputs)); - if (inputs.Length == 1) - return inputs[0]; - ch.CheckParam(SameSchemaAndActivity(inputs), nameof(inputs), "Inputs not compatible for consolidation"); - - IRowCursor cursor = inputs[0]; - var schema = cursor.Schema; - ch.CheckParam(AllCachable(schema, cursor.IsColumnActive), nameof(inputs), "Inputs had some uncachable input columns"); - - int[] activeToCol; - int[] colToActive; - Utils.BuildSubsetMaps(schema.ColumnCount, cursor.IsColumnActive, out activeToCol, out colToActive); - - // Because the schema of the consolidator is not necessary fixed, we are merely - // opportunistic about buffer sharing, from cursoring to cursoring. If we can do - // it easily, great, if not, no big deal. 
- if (Utils.Size(ourPools) != schema.ColumnCount) - ourPools = new object[schema.ColumnCount + (int)ExtraIndex._Lim]; - // Create the out pipes. - OutPipe[] outPipes = new OutPipe[activeToCol.Length + (int)ExtraIndex._Lim]; - for (int i = 0; i < activeToCol.Length; ++i) + var localCursor = inputs[t]; + ch.Assert(localCursor.State == CursorState.NotStarted); + // Note that these all take ownership of their respective cursors, + // so they all handle their disposal internal to the thread. + workers[t] = Utils.CreateBackgroundThread(() => { - int c = activeToCol[i]; - ColumnType type = schema.GetColumnType(c); - var pool = GetPool(type, ourPools, c); - outPipes[i] = OutPipe.Create(type, pool); - } - int idIdx = activeToCol.Length + (int)ExtraIndex.Id; - outPipes[idIdx] = OutPipe.Create(NumberType.UG, GetPool(NumberType.UG, ourPools, idIdx)); - - // Create the structures to synchronize between the workers and the consumer. - const int toConsumeBound = 4; - var toConsume = new BlockingCollection(toConsumeBound); - var batchColumnPool = new MadeObjectPool(() => new BatchColumn[outPipes.Length]); - Thread[] workers = new Thread[inputs.Length]; - MinWaiter waiter = new MinWaiter(workers.Length); - bool done = false; - - for (int t = 0; t < workers.Length; ++t) - { - var localCursor = inputs[t]; - ch.Assert(localCursor.State == CursorState.NotStarted); - // Note that these all take ownership of their respective cursors, - // so they all handle their disposal internal to the thread. - workers[t] = Utils.CreateBackgroundThread(() => - { // This will be the last batch sent in the finally. If iteration procedes without // error, it will remain null, and be sent as a sentinel. If iteration results in // an exception that we catch, the exception catching block will set this to an // exception bearing block, and that will be passed along as the last block instead. 
Batch lastBatch = null; - try + try + { + using (localCursor) { - using (localCursor) - { - InPipe[] inPipes = new InPipe[outPipes.Length]; - for (int i = 0; i < activeToCol.Length; ++i) - inPipes[i] = outPipes[i].CreateInPipe(RowCursorUtils.GetGetterAsDelegate(localCursor, activeToCol[i])); - inPipes[idIdx] = outPipes[idIdx].CreateInPipe(localCursor.GetIdGetter()); + InPipe[] inPipes = new InPipe[outPipes.Length]; + for (int i = 0; i < activeToCol.Length; ++i) + inPipes[i] = outPipes[i].CreateInPipe(RowCursorUtils.GetGetterAsDelegate(localCursor, activeToCol[i])); + inPipes[idIdx] = outPipes[idIdx].CreateInPipe(localCursor.GetIdGetter()); - long oldBatch = 0; - int count = 0; + long oldBatch = 0; + int count = 0; // This event is used to synchronize ourselves using a MinWaiter // so that we add batches to the consumer queue at the appropriate time. ManualResetEventSlim waiterEvent = null; - Action pushBatch = () => + Action pushBatch = () => + { + if (count > 0) { - if (count > 0) - { - var batchColumns = batchColumnPool.Get(); - for (int i = 0; i < inPipes.Length; ++i) - batchColumns[i] = inPipes[i].GetBatchColumnAndReset(); + var batchColumns = batchColumnPool.Get(); + for (int i = 0; i < inPipes.Length; ++i) + batchColumns[i] = inPipes[i].GetBatchColumnAndReset(); // REVIEW: Is it worth not allocating new Batch object for each batch? var batch = new Batch(batchColumnPool, batchColumns, count, oldBatch); - count = 0; + count = 0; // The waiter event should never be null since this is only // called after a point where waiter.Register has been called. ch.AssertValue(waiterEvent); - waiterEvent.Wait(); - waiterEvent = null; - toConsume.Add(batch); - } - }; + waiterEvent.Wait(); + waiterEvent = null; + toConsume.Add(batch); + } + }; // Handle the first one separately, then go into the main loop. 
if (localCursor.MoveNext() && !done) - { - oldBatch = localCursor.Batch; - foreach (var pipe in inPipes) - pipe.Fill(); - count++; + { + oldBatch = localCursor.Batch; + foreach (var pipe in inPipes) + pipe.Fill(); + count++; // Register with the min waiter that we want to wait on this batch number. waiterEvent = waiter.Register(oldBatch); - while (localCursor.MoveNext() && !done) + while (localCursor.MoveNext() && !done) + { + if (oldBatch != localCursor.Batch) { - if (oldBatch != localCursor.Batch) - { - ch.Assert(count == 0 || localCursor.Batch > oldBatch); - pushBatch(); - oldBatch = localCursor.Batch; - waiterEvent = waiter.Register(oldBatch); - } - foreach (var pipe in inPipes) - pipe.Fill(); - count++; + ch.Assert(count == 0 || localCursor.Batch > oldBatch); + pushBatch(); + oldBatch = localCursor.Batch; + waiterEvent = waiter.Register(oldBatch); } - pushBatch(); + foreach (var pipe in inPipes) + pipe.Fill(); + count++; } + pushBatch(); } } - catch (Exception ex) - { + } + catch (Exception ex) + { // Whoops, we won't be sending null as the sentinel now. lastBatch = new Batch(ex); - toConsume.Add(new Batch(ex)); - } - finally + toConsume.Add(new Batch(ex)); + } + finally + { + if (waiter.Retire() == 0) { - if (waiter.Retire() == 0) + if (lastBatch == null) { - if (lastBatch == null) - { // If it wasn't null, this already sent along an exception bearing batch, in which // case sending the sentinel is unnecessary and unhelpful. 
toConsume.Add(null); - } - toConsume.CompleteAdding(); } + toConsume.CompleteAdding(); } - }); - workers[t].Start(); - } - - Action quitAction = () => - { - done = true; - var myOutPipes = outPipes; - foreach (var batch in toConsume.GetConsumingEnumerable()) - { - if (batch == null) - continue; - batch.SetAll(myOutPipes); - foreach (var outPipe in myOutPipes) - outPipe.Unset(); } - foreach (Thread thread in workers) - thread.Join(); - }; - - return new Cursor(provider, schema, activeToCol, colToActive, outPipes, toConsume, quitAction); + }); + workers[t].Start(); } - private static object GetPool(ColumnType type, object[] pools, int poolIdx) + Action quitAction = () => { - Func func = GetPoolCore; - var method = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.RawType); - return method.Invoke(null, new object[] { pools, poolIdx }); - } + done = true; + var myOutPipes = outPipes; + foreach (var batch in toConsume.GetConsumingEnumerable()) + { + if (batch == null) + continue; + batch.SetAll(myOutPipes); + foreach (var outPipe in myOutPipes) + outPipe.Unset(); + } + foreach (Thread thread in workers) + thread.Join(); + }; - private static MadeObjectPool GetPoolCore(object[] pools, int poolIdx) - { - var pool = pools[poolIdx] as MadeObjectPool; - if (pool == null) - pools[poolIdx] = pool = new MadeObjectPool(() => null); - return pool; - } + return new Cursor(provider, schema, activeToCol, colToActive, outPipes, toConsume, quitAction); + } + + private static object GetPool(ColumnType type, object[] pools, int poolIdx) + { + Func func = GetPoolCore; + var method = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.RawType); + return method.Invoke(null, new object[] { pools, poolIdx }); + } + + private static MadeObjectPool GetPoolCore(object[] pools, int poolIdx) + { + var pool = pools[poolIdx] as MadeObjectPool; + if (pool == null) + pools[poolIdx] = pool = new MadeObjectPool(() => null); + return pool; } - public static 
IRowCursor[] Split(out IRowCursorConsolidator consolidator, IChannelProvider provider, Schema schema, IRowCursor input, int cthd) + public static RowCursor[] Split(IChannelProvider provider, Schema schema, RowCursor input, int cthd) { Contracts.AssertValue(provider, "provider"); var splitter = new Splitter(schema); using (var ch = provider.Start("CursorSplitter")) { - var result = splitter.SplitCore(out consolidator, provider, input, cthd); + var result = splitter.SplitCore(provider, input, cthd); return result; } } - private IRowCursor[] SplitCore(out IRowCursorConsolidator consolidator, IChannelProvider ch, IRowCursor input, int cthd) + private RowCursor[] SplitCore(IChannelProvider ch, RowCursor input, int cthd) { Contracts.AssertValue(ch); ch.AssertValue(input); @@ -522,9 +507,9 @@ private IRowCursor[] SplitCore(out IRowCursorConsolidator consolidator, IChannel // Create the mappings between active column index, and column index. int[] activeToCol; int[] colToActive; - Utils.BuildSubsetMaps(_schema.ColumnCount, input.IsColumnActive, out activeToCol, out colToActive); + Utils.BuildSubsetMaps(_schema.Count, input.IsColumnActive, out activeToCol, out colToActive); - Func createFunc = CreateInPipe; + Func createFunc = CreateInPipe; var inGenMethod = createFunc.GetMethodInfo().GetGenericMethodDefinition(); object[] arguments = new object[] { input, 0 }; // Only one set of in-pipes, one per column, as well as for extra side information. @@ -537,10 +522,10 @@ private IRowCursor[] SplitCore(out IRowCursorConsolidator consolidator, IChannel // For each column, create the InPipe, and all OutPipes per output cursor. 
for (int c = 0; c < activeToCol.Length; ++c) { - ch.Assert(0 <= activeToCol[c] && activeToCol[c] < _schema.ColumnCount); + ch.Assert(0 <= activeToCol[c] && activeToCol[c] < _schema.Count); ch.Assert(c == 0 || activeToCol[c - 1] < activeToCol[c]); ch.Assert(input.IsColumnActive(activeToCol[c])); - var type = input.Schema.GetColumnType(activeToCol[c]); + var type = input.Schema[activeToCol[c]].Type; ch.Assert(type.IsCachable()); arguments[1] = activeToCol[c]; var inPipe = inPipes[c] = @@ -637,27 +622,26 @@ private IRowCursor[] SplitCore(out IRowCursorConsolidator consolidator, IChannel var cursors = new Cursor[cthd]; for (int i = 0; i < cthd; ++i) cursors[i] = new Cursor(ch, _schema, activeToCol, colToActive, outPipes[i], toConsume, quitAction); - consolidator = new Consolidator(this); return cursors; } /// /// An in pipe creator intended to be used from the splitter only. /// - private InPipe CreateInPipe(IRow input, int col) + private InPipe CreateInPipe(Row input, int col) { Contracts.AssertValue(input); - Contracts.Assert(0 <= col && col < _schema.ColumnCount); + Contracts.Assert(0 <= col && col < _schema.Count); return CreateInPipeCore(col, input.GetGetter(col)); } /// /// An in pipe creator intended to be used from the splitter only. /// - private InPipe CreateIdInPipe(IRow input) + private InPipe CreateIdInPipe(Row input) { Contracts.AssertValue(input); - return CreateInPipeCore(_schema.ColumnCount + (int)ExtraIndex.Id, input.GetIdGetter()); + return CreateInPipeCore(_schema.Count + (int)ExtraIndex.Id, input.GetIdGetter()); } private InPipe CreateInPipeCore(int poolIdx, ValueGetter getter) @@ -849,7 +833,7 @@ public void SetAll(OutPipe[] pipes) /// /// This helps a cursor present the results of a . Practically its role - /// really is to just provide a stable delegate for the . + /// really is to just provide a stable delegate for the . 
/// There is one of these created per column, per output cursor, i.e., in splitting /// there are n of these created per column, and when consolidating only one of these /// is created per column. @@ -999,14 +983,14 @@ protected override void Getter(ref T value) /// objects from the input blocking collection, and yields the /// values stored therein through the help of objects. /// - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private readonly Schema _schema; private readonly int[] _activeToCol; private readonly int[] _colToActive; private readonly OutPipe[] _pipes; private readonly Delegate[] _getters; - private readonly ValueGetter _idGetter; + private readonly ValueGetter _idGetter; private readonly BlockingCollection _batchInputs; private readonly Action _quitAction; @@ -1014,12 +998,9 @@ private sealed class Cursor : RootCursorBase, IRowCursor private long _batch; private bool _disposed; - public Schema Schema => _schema; + public override Schema Schema => _schema; - public override long Batch - { - get { return _batch; } - } + public override long Batch => _batch; /// /// Constructs one of the split cursors. @@ -1042,7 +1023,7 @@ public Cursor(IChannelProvider provider, Schema schema, int[] activeToCol, int[] Ch.AssertValue(pipes); Ch.AssertValue(batchInputs); Ch.AssertValueOrNull(quitAction); - Ch.Assert(colToActive.Length == schema.ColumnCount); + Ch.Assert(colToActive.Length == schema.Count); Ch.Assert(activeToCol.Length + (int)ExtraIndex._Lim == pipes.Length); Ch.Assert(pipes.All(p => p != null)); // Could also confirm the inverse mappiness of activeToCol/colToActive, but that seems like a bit much. 
@@ -1053,29 +1034,27 @@ public Cursor(IChannelProvider provider, Schema schema, int[] activeToCol, int[] _getters = new Delegate[pipes.Length]; for (int i = 0; i < activeToCol.Length; ++i) _getters[i] = _pipes[i].GetGetter(); - _idGetter = (ValueGetter)_pipes[activeToCol.Length + (int)ExtraIndex.Id].GetGetter(); + _idGetter = (ValueGetter)_pipes[activeToCol.Length + (int)ExtraIndex.Id].GetGetter(); _batchInputs = batchInputs; _batch = -1; _quitAction = quitAction; } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (!_disposed) + if (_disposed) + return; + if (disposing) { foreach (var pipe in _pipes) pipe.Unset(); - _disposed = true; - if (_quitAction != null) - _quitAction(); + _quitAction?.Invoke(); } - base.Dispose(); + _disposed = true; + base.Dispose(disposing); } - public override ValueGetter GetIdGetter() - { - return _idGetter; - } + public override ValueGetter GetIdGetter() => _idGetter; protected override bool MoveNextCore() { @@ -1114,13 +1093,13 @@ protected override bool MoveNextCore() return true; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.CheckParam(0 <= col && col < _colToActive.Length, nameof(col)); return _colToActive[col] >= 0; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.CheckParam(IsColumnActive(col), nameof(col), "requested column not active"); var getter = _getters[_colToActive[col]] as ValueGetter; @@ -1136,9 +1115,9 @@ public ValueGetter GetGetter(int col) /// at the cost of being totally synchronous, that is, there is no parallel benefit from /// having split the input cursors. 
/// - internal sealed class SynchronousConsolidatingCursor : RootCursorBase, IRowCursor + internal sealed class SynchronousConsolidatingCursor : RootCursorBase { - private readonly IRowCursor[] _cursors; + private readonly RowCursor[] _cursors; private readonly Delegate[] _getters; private readonly Schema _schema; @@ -1152,7 +1131,7 @@ internal sealed class SynchronousConsolidatingCursor : RootCursorBase, IRowCurso // Index into _cursors array pointing to the current cursor, or -1 if this cursor is not in Good state. private int _icursor; // If this cursor is in Good state then this should equal _cursors[_icursor], else null. - private IRowCursor _currentCursor; + private RowCursor _currentCursor; private bool _disposed; private readonly struct CursorStats @@ -1171,16 +1150,16 @@ public CursorStats(long batch, int idx) // input batch as our own batch. Should we suppress it? public override long Batch { get { return _batch; } } - public Schema Schema => _schema; + public override Schema Schema => _schema; - public SynchronousConsolidatingCursor(IChannelProvider provider, IRowCursor[] cursors) + public SynchronousConsolidatingCursor(IChannelProvider provider, RowCursor[] cursors) : base(provider) { Ch.CheckNonEmpty(cursors, nameof(cursors)); _cursors = cursors; _schema = _cursors[0].Schema; - Utils.BuildSubsetMaps(_schema.ColumnCount, _cursors[0].IsColumnActive, out _activeToCol, out _colToActive); + Utils.BuildSubsetMaps(_schema.Count, _cursors[0].IsColumnActive, out _activeToCol, out _colToActive); Func func = CreateGetter; _methInfo = func.GetMethodInfo().GetGenericMethodDefinition(); @@ -1199,18 +1178,19 @@ private void InitHeap() { for (int i = 0; i < _cursors.Length; ++i) { - IRowCursor cursor = _cursors[i]; + RowCursor cursor = _cursors[i]; Ch.Assert(cursor.State == CursorState.NotStarted); if (cursor.MoveNext()) _mins.Add(new CursorStats(cursor.Batch, i)); } } - public override void Dispose() + protected override void Dispose(bool disposing) { - if 
(!_disposed) + if (_disposed) + return; + if (disposing) { - _disposed = true; _batch = -1; _icursor = -1; _currentCursor = null; @@ -1218,16 +1198,17 @@ public override void Dispose() foreach (var cursor in _cursors) cursor.Dispose(); } - base.Dispose(); + _disposed = true; + base.Dispose(disposing); } - public override ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { - ValueGetter[] idGetters = new ValueGetter[_cursors.Length]; + ValueGetter[] idGetters = new ValueGetter[_cursors.Length]; for (int i = 0; i < _cursors.Length; ++i) idGetters[i] = _cursors[i].GetIdGetter(); return - (ref UInt128 val) => + (ref RowId val) => { Ch.Check(_icursor >= 0, "Cannot call ID getter in current state"); idGetters[_icursor](ref val); @@ -1236,21 +1217,21 @@ public override ValueGetter GetIdGetter() private Delegate CreateGetter(int col) { - var methInfo = _methInfo.MakeGenericMethod(Schema.GetColumnType(col).RawType); + var methInfo = _methInfo.MakeGenericMethod(Schema[col].Type.RawType); return (Delegate)methInfo.Invoke(this, new object[] { col }); } private Delegate CreateGetter(int col) { ValueGetter[] getters = new ValueGetter[_cursors.Length]; - var type = Schema.GetColumnType(col); + var type = Schema[col].Type; for (int i = 0; i < _cursors.Length; ++i) { var cursor = _cursors[i]; Ch.AssertValue(cursor); - Ch.Assert(col < cursor.Schema.ColumnCount); + Ch.Assert(col < cursor.Schema.Count); Ch.Assert(cursor.IsColumnActive(col)); - Ch.Assert(type.Equals(cursor.Schema.GetColumnType(col))); + Ch.Assert(type.Equals(cursor.Schema[col].Type)); getters[i] = _cursors[i].GetGetter(col); } ValueGetter mine = @@ -1291,13 +1272,13 @@ protected override bool MoveNextCore() return true; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.CheckParam(0 <= col && col < _colToActive.Length, nameof(col)); return _colToActive[col] >= 0; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { 
Ch.CheckParam(IsColumnActive(col), nameof(col), "requested column not active"); var getter = _getters[_colToActive[col]] as ValueGetter; @@ -1307,7 +1288,7 @@ public ValueGetter GetGetter(int col) } } - public static ValueGetter>[] PopulateGetterArray(IRowCursor cursor, List colIndices) + public static ValueGetter>[] PopulateGetterArray(RowCursor cursor, List colIndices) { var n = colIndices.Count; var getters = new ValueGetter>[n]; @@ -1317,7 +1298,7 @@ public static ValueGetter>[] PopulateGetterArray(IRowCursor ValueGetter> getter; var srcColIndex = colIndices[i]; - var colType = cursor.Schema.GetColumnType(srcColIndex); + var colType = cursor.Schema[srcColIndex].Type; if (colType.IsVector) { getter = Utils.MarshalInvoke(GetVectorFlatteningGetter, colType.ItemType.RawType, @@ -1335,7 +1316,7 @@ public static ValueGetter>[] PopulateGetterArray(IRowCursor return getters; } - public static ValueGetter> GetSingleValueGetter(IRow cursor, int i, ColumnType colType) + public static ValueGetter> GetSingleValueGetter(Row cursor, int i, ColumnType colType) { var floatGetter = cursor.GetGetter(i); T v = default(T); @@ -1365,7 +1346,7 @@ public static ValueGetter> GetSingleValueGetter(IRow cur return getter; } - public static ValueGetter> GetVectorFlatteningGetter(IRow cursor, int colIndex, ColumnType colType) + public static ValueGetter> GetVectorFlatteningGetter(Row cursor, int colIndex, ColumnType colType) { var vecGetter = cursor.GetGetter>(colIndex); var vbuf = default(VBuffer); diff --git a/src/Microsoft.ML.Data/Data/IColumn.cs b/src/Microsoft.ML.Data/Data/IColumn.cs deleted file mode 100644 index 9d2146ee37..0000000000 --- a/src/Microsoft.ML.Data/Data/IColumn.cs +++ /dev/null @@ -1,667 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. 
- -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Internal.Utilities; - -namespace Microsoft.ML.Runtime.Data -{ - /// - /// This interface is an analogy to that encapsulates the contents of a single - /// column. - /// - /// Note that in the same sense that is not thread safe, implementors of this interface - /// by similar token must not be considered thread safe by users of the interface, and by the same token - /// implementors should feel free to write their implementations with the expectation that only one thread - /// will be calling it at a time. - /// - /// Similarly, in the same sense that an can have its values "change under it" by having - /// the underlying cursor move, so too might this item have its values change under it, and they will if - /// they were directly instantiated from a row. - /// - /// Generally actual implementors of this interface should not implement this directly, but instead implement - /// . - /// - // REVIEW: It is possible we may want to make this ICounted, but let's not start with - // that assumption. The use cases I have in mind are that we'll still, on the side, have an - // IRow lying around. - public interface IColumn - { - /// - /// The name of a column. This string should always be non-empty. - /// - string Name { get; } - - /// - /// The type of the column. - /// - ColumnType Type { get; } - - // REVIEW: This property anticipates a time when we get away with metadata accessors - // altogether, and just have the metadata for a column be represented as a row. - /// - /// The metadata for a column, or null if this column has no metadata. - /// - IRow Metadata { get; } - - /// - /// Whether the column should be considered active or not. - /// - bool IsActive { get; } - - /// - /// The value getter, as a . Implementators should just pass through - /// . 
- /// - /// The generic getter delegate - Delegate GetGetter(); - } - - /// - /// The type specific interface for a . - /// - /// The type of values in this column. This should agree with the - /// field of . - public interface IValueColumn : IColumn - { - new ValueGetter GetGetter(); - } - - public static class RowColumnUtils - { - /// - /// Exposes a single column in a row. - /// - /// The row to wrap - /// The column to expose - /// A row column instance - public static IColumn GetColumn(IRow row, int col) - { - Contracts.CheckValue(row, nameof(row)); - Contracts.CheckParam(0 <= col && col < row.Schema.ColumnCount, nameof(col)); - - Func func = GetColumnCore; - return Utils.MarshalInvoke(func, row.Schema.GetColumnType(col).RawType, row, col); - } - - private static IColumn GetColumnCore(IRow row, int col) - { - Contracts.AssertValue(row); - Contracts.Assert(0 <= col && col < row.Schema.ColumnCount); - Contracts.Assert(row.Schema.GetColumnType(col).RawType == typeof(T)); - - return new RowWrap(row, col); - } - - /// - /// Exposes a single column in a schema. The column is considered inactive. - /// - /// The schema to get the data for - /// The column to get - /// A column with false - public static IColumn GetColumn(ISchema schema, int col) - { - Contracts.CheckValue(schema, nameof(schema)); - Contracts.CheckParam(0 <= col && col < schema.ColumnCount, nameof(col)); - - Func func = GetColumnCore; - return Utils.MarshalInvoke(func, schema.GetColumnType(col).RawType, schema, col); - } - - private static IColumn GetColumnCore(ISchema schema, int col) - { - Contracts.AssertValue(schema); - Contracts.Assert(0 <= col && col < schema.ColumnCount); - Contracts.Assert(schema.GetColumnType(col).RawType == typeof(T)); - - return new SchemaWrap(schema, col); - } - - /// - /// Constructs a column out of a value. This will store the input value, not make a copy. 
- /// - /// The type of the value - /// The column name, which must be non-empty - /// The type of the column, whose raw type must be - /// The value to store in the column - /// Optionally, metadata for the column - /// A column with this value - public static IColumn GetColumn(string name, ColumnType type, ref T value, IRow meta = null) - { - Contracts.CheckNonEmpty(name, nameof(name)); - Contracts.CheckValue(type, nameof(type)); - Contracts.CheckParam(type.RawType == typeof(T), nameof(type), "Mismatch on object type and column type"); - if (type.IsVector) - return Utils.MarshalInvoke(GetColumnVecCore, type.ItemType.RawType, name, type.AsVector, (object)value, meta); - Contracts.CheckParam(type.IsPrimitive, nameof(type), "Type must be either vector or primitive"); - Contracts.CheckValueOrNull(meta); - return Utils.MarshalInvoke(GetColumnOneCore, type.RawType, name, type, (object)value, meta); - } - - private static IColumn GetColumnVecCore(string name, VectorType type, object value, IRow meta) - { - // REVIEW: Ugh. Nasty. Any alternative to boxing? - Contracts.AssertNonEmpty(name); - Contracts.AssertValue(type); - Contracts.Assert(type.IsVector); - Contracts.Assert(type.ItemType.RawType == typeof(T)); - Contracts.Assert(value is VBuffer); - Contracts.AssertValueOrNull(meta); - VBuffer typedVal = (VBuffer)value; - return new ConstVecImpl(name, meta, type, typedVal); - } - - private static IColumn GetColumnOneCore(string name, ColumnType type, object value, IRow meta) - { - Contracts.AssertNonEmpty(name); - Contracts.AssertValue(type); - Contracts.Assert(type.IsPrimitive); - Contracts.Assert(type.RawType == typeof(T)); - Contracts.Assert(value is T); - Contracts.AssertValueOrNull(meta); - T typedVal = (T)value; - return new ConstOneImpl(name, meta, type, typedVal); - } - - /// - /// Constructs a column out of a getter. 
- /// - /// The type of the value - /// The column name, which must be non-empty - /// The type of the column, whose raw type must be - /// The getter for the column - /// Optionally, metadata for the column - /// A column with this getter - public static IColumn GetColumn(string name, ColumnType type, ValueGetter getter, IRow meta = null) - { - Contracts.CheckNonEmpty(name, nameof(name)); - Contracts.CheckValue(type, nameof(type)); - Contracts.CheckParam(type.RawType == typeof(T), nameof(type), "Mismatch on object type and column type"); - Contracts.CheckValue(getter, nameof(getter)); - Contracts.CheckValueOrNull(meta); - - return new GetterImpl(name, meta, type, getter); - } - - /// - /// Wraps a set of row columns as a row. - /// - /// The counted object that the output row will wrap for its own implementation of - /// , or if null, the output row will yield default values for those implementations, - /// that is, a totally static row - /// A set of row columns - /// A row with items derived from - public static IRow GetRow(ICounted counted, params IColumn[] columns) - { - Contracts.CheckValueOrNull(counted); - Contracts.CheckValue(columns, nameof(columns)); - return new RowColumnRow(counted, columns); - } - - /// - /// Given a column, returns a deep-copied memory-materialized version of it. Note that - /// it is acceptable for the column to be inactive: the returned column will likewise - /// be inactive. 
- /// - /// - /// A memory materialized version of which may be, - /// under appropriate circumstances, the input object itself - public static IColumn CloneColumn(IColumn column) - { - Contracts.CheckValue(column, nameof(column)); - return Utils.MarshalInvoke(CloneColumnCore, column.Type.RawType, column); - } - - private static IColumn CloneColumnCore(IColumn column) - { - Contracts.Assert(column is IValueColumn); - IRow meta = column.Metadata; - if (meta != null) - meta = RowCursorUtils.CloneRow(meta); - - var tcolumn = (IValueColumn)column; - if (!tcolumn.IsActive) - return new InactiveImpl(tcolumn.Name, meta, tcolumn.Type); - T val = default(T); - tcolumn.GetGetter()(ref val); - return GetColumn(tcolumn.Name, tcolumn.Type, ref val, meta); - } - - /// - /// The implementation for a simple wrapping of an . - /// - private sealed class RowWrap : IValueColumn - { - private readonly IRow _row; - private readonly int _col; - private MetadataRow _meta; - - public string Name => _row.Schema.GetColumnName(_col); - public ColumnType Type => _row.Schema.GetColumnType(_col); - public bool IsActive => _row.IsColumnActive(_col); - - public IRow Metadata - { - get - { - if (_meta == null) - Interlocked.CompareExchange(ref _meta, new MetadataRow(_row.Schema, _col, x => true), null); - return _meta; - } - } - - public RowWrap(IRow row, int col) - { - Contracts.AssertValue(row); - Contracts.Assert(0 <= col && col < row.Schema.ColumnCount); - Contracts.Assert(row.Schema.GetColumnType(col).RawType == typeof(T)); - - _row = row; - _col = col; - } - - Delegate IColumn.GetGetter() - => GetGetter(); - - public ValueGetter GetGetter() - => _row.GetGetter(_col); - } - - /// - /// The base class for a few implementations that do not "go" anywhere. 
- /// - private abstract class DefaultCounted : ICounted - { - public long Position => 0; - public long Batch => 0; - public ValueGetter GetIdGetter() - => IdGetter; - - private static void IdGetter(ref UInt128 id) - => id = default; - } - - /// - /// Simple wrapper for a schema column, considered inctive with no getter. - /// - /// The type of the getter - private sealed class SchemaWrap : IValueColumn - { - private readonly ISchema _schema; - private readonly int _col; - private MetadataRow _meta; - - public string Name => _schema.GetColumnName(_col); - public ColumnType Type => _schema.GetColumnType(_col); - public bool IsActive => false; - - public IRow Metadata - { - get - { - if (_meta == null) - Interlocked.CompareExchange(ref _meta, new MetadataRow(_schema, _col, x => true), null); - return _meta; - } - } - - public SchemaWrap(ISchema schema, int col) - { - Contracts.AssertValue(schema); - Contracts.Assert(0 <= col && col < schema.ColumnCount); - Contracts.Assert(schema.GetColumnType(col).RawType == typeof(T)); - - _schema = schema; - _col = col; - } - - Delegate IColumn.GetGetter() - => GetGetter(); - - public ValueGetter GetGetter() - => throw Contracts.Except("Column not active"); - } - - /// - /// This class exists to present metadata as stored in an for one particular - /// column as an . This class will cease to be necessary at the point when all - /// metadata implementations are just simple s. 
- /// - public sealed class MetadataRow : IRow - { - public Schema Schema => _schemaImpl.AsSchema; - - private readonly ISchema _metaSchema; - private readonly int _col; - private readonly SchemaImpl _schemaImpl; - - private readonly KeyValuePair[] _map; - - long ICounted.Position => 0; - long ICounted.Batch => 0; - ValueGetter ICounted.GetIdGetter() - => IdGetter; - - private static void IdGetter(ref UInt128 id) - => id = default; - - private sealed class SchemaImpl : ISchema - { - private readonly MetadataRow _parent; - private readonly Dictionary _nameToCol; - public Schema AsSchema { get; } - - public int ColumnCount { get { return _parent._map.Length; } } - - public SchemaImpl(MetadataRow parent) - { - Contracts.AssertValue(parent); - _parent = parent; - _nameToCol = new Dictionary(ColumnCount); - for (int i = 0; i < _parent._map.Length; ++i) - _nameToCol[_parent._map[i].Key] = i; - - AsSchema = Schema.Create(this); - } - - public string GetColumnName(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _parent._map[col].Key; - } - - public ColumnType GetColumnType(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _parent._map[col].Value; - } - - public bool TryGetColumnIndex(string name, out int col) - { - return _nameToCol.TryGetValue(name, out col); - } - - public IEnumerable> GetMetadataTypes(int col) - { - return Enumerable.Empty>(); - } - - public void GetMetadata(string kind, int col, ref TValue value) - { - throw MetadataUtils.ExceptGetMetadata(); - } - - public ColumnType GetMetadataTypeOrNull(string kind, int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return null; - } - } - - public MetadataRow(ISchema schema, int col, Func takeMetadata) - { - Contracts.CheckValue(schema, nameof(schema)); - Contracts.CheckParam(0 <= col && col < schema.ColumnCount, nameof(col)); - Contracts.CheckValue(takeMetadata, nameof(takeMetadata)); - - _metaSchema = 
schema; - _col = col; - _map = _metaSchema.GetMetadataTypes(_col).Where(x => takeMetadata(x.Key)).ToArray(); - _schemaImpl = new SchemaImpl(this); - } - - public bool IsColumnActive(int col) - { - Contracts.CheckParam(0 <= col && col < _map.Length, nameof(col)); - return true; - } - - public ValueGetter GetGetter(int col) - { - Contracts.CheckParam(0 <= col && col < _map.Length, nameof(col)); - // REVIEW: On type mismatch, this will throw a metadata exception, which is not really - // appropriate. However, since this meant to be a shim anyway, we will tolerate imperfection. - return (ref TValue dst) => _metaSchema.GetMetadata(_map[col].Key, _col, ref dst); - } - } - - /// - /// This is used for a few implementations that need to store their own name, - /// metadata, and type themselves. - /// - private abstract class SimpleColumnBase : IValueColumn - { - public string Name { get; } - public IRow Metadata { get; } - public ColumnType Type { get; } - public abstract bool IsActive { get; } - - public SimpleColumnBase(string name, IRow meta, ColumnType type) - { - Contracts.CheckNonEmpty(name, nameof(name)); - Contracts.CheckValueOrNull(meta); - Contracts.CheckValue(type, nameof(type)); - Contracts.CheckParam(type.RawType == typeof(T), nameof(type), "Mismatch between CLR type and column type"); - - Name = name; - Metadata = meta; - Type = type; - } - - Delegate IColumn.GetGetter() - { - return GetGetter(); - } - - public abstract ValueGetter GetGetter(); - } - - private sealed class InactiveImpl : SimpleColumnBase - { - public override bool IsActive { get { return false; } } - - public InactiveImpl(string name, IRow meta, ColumnType type) - : base(name, meta, type) - { - } - - public override ValueGetter GetGetter() - { - throw Contracts.Except("Can't get getter for inactive column"); - } - } - - private sealed class ConstOneImpl : SimpleColumnBase - { - private readonly T _value; - - public override bool IsActive => true; - - public ConstOneImpl(string name, IRow 
meta, ColumnType type, T value) - : base(name, meta, type) - { - Contracts.Assert(type.IsPrimitive); - _value = value; - } - - public override ValueGetter GetGetter() - { - return Getter; - } - - private void Getter(ref T val) - { - val = _value; - } - } - - private sealed class ConstVecImpl : SimpleColumnBase> - { - private readonly VBuffer _value; - - public override bool IsActive { get { return true; } } - - public ConstVecImpl(string name, IRow meta, ColumnType type, VBuffer value) - : base(name, meta, type) - { - _value = value; - } - - public override ValueGetter> GetGetter() - { - return Getter; - } - - private void Getter(ref VBuffer val) - { - _value.CopyTo(ref val); - } - } - - private sealed class GetterImpl : SimpleColumnBase - { - private readonly ValueGetter _getter; - - public override bool IsActive => _getter != null; - - public GetterImpl(string name, IRow meta, ColumnType type, ValueGetter getter) - : base(name, meta, type) - { - Contracts.CheckValueOrNull(getter); - _getter = getter; - } - - public override ValueGetter GetGetter() - { - Contracts.Check(IsActive, "column is not active"); - return _getter; - } - } - - /// - /// An that is an amalgation of multiple implementers. - /// - private sealed class RowColumnRow : IRow - { - private static readonly DefaultCountedImpl _defCount = new DefaultCountedImpl(); - private readonly ICounted _counted; - private readonly IColumn[] _columns; - private readonly SchemaImpl _schema; - - public Schema Schema => _schema.AsSchema; - public long Position => _counted.Position; - public long Batch => _counted.Batch; - - public RowColumnRow(ICounted counted, IColumn[] columns) - { - Contracts.AssertValueOrNull(counted); - Contracts.AssertValue(columns); - _counted = counted ?? 
_defCount; - _columns = columns; - _schema = new SchemaImpl(this); - } - - public ValueGetter GetGetter(int col) - { - Contracts.CheckParam(IsColumnActive(col), nameof(col), "requested column not active"); - var rowCol = _columns[col] as IValueColumn; - if (rowCol == null) - throw Contracts.Except("Invalid TValue: '{0}'", typeof(TValue)); - return rowCol.GetGetter(); - } - - public bool IsColumnActive(int col) - { - Contracts.CheckParam(0 <= col && col < _columns.Length, nameof(col)); - return _columns[col].IsActive; - } - - public ValueGetter GetIdGetter() - { - return _counted.GetIdGetter(); - } - - private sealed class SchemaImpl : ISchema - { - private readonly RowColumnRow _parent; - private readonly Dictionary _nameToIndex; - - public Schema AsSchema { get; } - - public int ColumnCount => _parent._columns.Length; - - public SchemaImpl(RowColumnRow parent) - { - Contracts.AssertValue(parent); - _parent = parent; - _nameToIndex = new Dictionary(); - for (int i = 0; i < _parent._columns.Length; ++i) - _nameToIndex[_parent._columns[i].Name] = i; - AsSchema = Schema.Create(this); - } - - public void GetMetadata(string kind, int col, ref TValue value) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - var meta = _parent._columns[col].Metadata; - int mcol; - if (meta == null || !meta.Schema.TryGetColumnIndex(kind, out mcol)) - throw MetadataUtils.ExceptGetMetadata(); - // REVIEW: Again, since this is a shim, not going to sweat the potential for inappropriate exception message. 
- meta.GetGetter(mcol)(ref value); - } - - public ColumnType GetMetadataTypeOrNull(string kind, int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - var meta = _parent._columns[col].Metadata; - int mcol; - if (meta == null || !meta.Schema.TryGetColumnIndex(kind, out mcol)) - return null; - return meta.Schema.GetColumnType(mcol); - } - - public IEnumerable> GetMetadataTypes(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - // REVIEW: An IRow can have collisions in names, whereas there is no notion of this in metadata types. - // Since I intend to remove this soon anyway and the number of usages of this will be very low, I am just going - // to tolerate the potential for strangeness here, since it will practically never arise until we reorganize - // the whole thing. - var meta = _parent._columns[col].Metadata; - if (meta == null) - yield break; - var schema = meta.Schema; - for (int i = 0; i < schema.ColumnCount; ++i) - yield return new KeyValuePair(schema.GetColumnName(i), schema.GetColumnType(i)); - } - - public ColumnType GetColumnType(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _parent._columns[col].Type; - } - - public string GetColumnName(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _parent._columns[col].Name; - } - - public bool TryGetColumnIndex(string name, out int col) - { - return _nameToIndex.TryGetValue(name, out col); - } - } - - private sealed class DefaultCountedImpl : DefaultCounted - { - } - } - } -} diff --git a/src/Microsoft.ML.Data/Data/IDataLoader.cs b/src/Microsoft.ML.Data/Data/IDataLoader.cs index 8aecbb70df..7112511dae 100644 --- a/src/Microsoft.ML.Data/Data/IDataLoader.cs +++ b/src/Microsoft.ML.Data/Data/IDataLoader.cs @@ -2,11 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; using System.IO; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// An interface for exposing some number of items that can be opened for reading. diff --git a/src/Microsoft.ML.Data/Data/IRowSeekable.cs b/src/Microsoft.ML.Data/Data/IRowSeekable.cs index 3c0bf0db08..b29a49343c 100644 --- a/src/Microsoft.ML.Data/Data/IRowSeekable.cs +++ b/src/Microsoft.ML.Data/Data/IRowSeekable.cs @@ -4,24 +4,26 @@ using System; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { // REVIEW: Would it be a better apporach to add something akin to CanSeek, // as we have a CanShuffle? The idea is trying to make IRowSeekable propagate along certain transforms. /// /// Represents a data view that supports random access to a specific row. /// - public interface IRowSeekable : ISchematized + public interface IRowSeekable { - IRowSeeker GetSeeker(Func predicate); + RowSeeker GetSeeker(Func predicate); + + Schema Schema { get; } } /// /// Represents a row seeker with random access that can retrieve a specific row by the row index. - /// For IRowSeeker, when the state is valid (that is when MoveTo() returns true), it returns the - /// current row index. Otherwise it's -1. + /// For , when the state is valid (that is when + /// returns ), it returns the current row index. Otherwise it's -1. /// - public interface IRowSeeker : IRow, IDisposable + public abstract class RowSeeker : Row { /// /// Moves the seeker to a row at a specific row index. @@ -30,6 +32,6 @@ public interface IRowSeeker : IRow, IDisposable /// /// The row index to move to. /// True if a row with specified index is found; false otherwise. 
- bool MoveTo(long rowIndex); + public abstract bool MoveTo(long rowIndex); } } diff --git a/src/Microsoft.ML.Data/Data/ITransposeDataView.cs b/src/Microsoft.ML.Data/Data/ITransposeDataView.cs index f247bc9859..a608a28712 100644 --- a/src/Microsoft.ML.Data/Data/ITransposeDataView.cs +++ b/src/Microsoft.ML.Data/Data/ITransposeDataView.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { // REVIEW: There are a couple problems. Firstly, what to do about cases where // the number of rows exceeds int.MaxValue? Right now we just fail. Practically this makes @@ -14,8 +14,8 @@ namespace Microsoft.ML.Runtime.Data /// /// A view of data where columns can optionally be accessed slot by slot, as opposed to row /// by row in a typical dataview. A slot-accessible column can be accessed with a slot-by-slot - /// cursor via an (naturally, as opposed to row-by-row through an - /// ). This interface is intended to be implemented by classes that + /// cursor via an (naturally, as opposed to row-by-row through an + /// ). This interface is intended to be implemented by classes that /// want to provide an option for an alternate way of accessing the data stored in a /// . /// @@ -23,12 +23,13 @@ namespace Microsoft.ML.Runtime.Data /// is accessible in this fashion iff 's /// returns a non-null value. /// - public interface ITransposeDataView : IDataView + [BestFriend] + internal interface ITransposeDataView : IDataView { /// /// An enhanced schema, containing information on the transposition properties, if any, /// of each column. Note that there is no contract or suggestion that this property - /// should be equal to . + /// should be equal to . 
/// ITransposeSchema TransposeSchema { get; } @@ -36,37 +37,19 @@ public interface ITransposeDataView : IDataView /// Presents a cursor over the slots of a transposable column, or throws if the column /// is not transposable. /// - ISlotCursor GetSlotCursor(int col); - } - - /// - /// A cursor that allows slot-by-slot access of data. - /// - public interface ISlotCursor : ICursor - { - /// - /// The slot type for this cursor. Note that this should equal the - /// for the column from which this slot cursor - /// was created. - /// - VectorType GetSlotType(); - - /// - /// A getter delegate for the slot values. The type must correspond - /// to the item type from . - /// - ValueGetter> GetGetter(); + SlotCursor GetSlotCursor(int col); } /// /// The transpose schema returns the schema information of the view we have transposed. /// - public interface ITransposeSchema : ISchema + [BestFriend] + internal interface ITransposeSchema : ISchema { /// /// Analogous to , except instead of returning the type of value - /// accessible through the , returns the item type of value accessible - /// through the . This will return null iff this particular + /// accessible through the , returns the item type of value accessible + /// through the . This will return null iff this particular /// column is not transposable, that is, it cannot be viewed in a slotwise fashion. Observe from /// the return type that this will always be a vector type. This vector type should be of fixed /// size and one dimension. 
diff --git a/src/Microsoft.ML.Data/Data/RowCursorUtils.cs b/src/Microsoft.ML.Data/Data/RowCursorUtils.cs index 1104a36fb1..7f6aeff35e 100644 --- a/src/Microsoft.ML.Data/Data/RowCursorUtils.cs +++ b/src/Microsoft.ML.Data/Data/RowCursorUtils.cs @@ -7,11 +7,10 @@ using System.Linq; using System.Reflection; using System.Text; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Data.Conversion; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Data.Conversion; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public static class RowCursorUtils { @@ -23,17 +22,17 @@ public static class RowCursorUtils /// The row to get the getter for /// The column index, which must be active on that row /// The getter as a delegate - public static Delegate GetGetterAsDelegate(IRow row, int col) + public static Delegate GetGetterAsDelegate(Row row, int col) { Contracts.CheckValue(row, nameof(row)); - Contracts.CheckParam(0 <= col && col < row.Schema.ColumnCount, nameof(col)); + Contracts.CheckParam(0 <= col && col < row.Schema.Count, nameof(col)); Contracts.CheckParam(row.IsColumnActive(col), nameof(col), "column was not active"); - Func getGetter = GetGetterAsDelegateCore; - return Utils.MarshalInvoke(getGetter, row.Schema.GetColumnType(col).RawType, row, col); + Func getGetter = GetGetterAsDelegateCore; + return Utils.MarshalInvoke(getGetter, row.Schema[col].Type.RawType, row, col); } - private static Delegate GetGetterAsDelegateCore(IRow row, int col) + private static Delegate GetGetterAsDelegateCore(Row row, int col) { return row.GetGetter(col); } @@ -44,18 +43,18 @@ private static Delegate GetGetterAsDelegateCore(IRow row, int col) /// . 
/// /// - public static Delegate GetGetterAs(ColumnType typeDst, IRow row, int col) + public static Delegate GetGetterAs(ColumnType typeDst, Row row, int col) { Contracts.CheckValue(typeDst, nameof(typeDst)); Contracts.CheckParam(typeDst.IsPrimitive, nameof(typeDst)); Contracts.CheckValue(row, nameof(row)); - Contracts.CheckParam(0 <= col && col < row.Schema.ColumnCount, nameof(col)); + Contracts.CheckParam(0 <= col && col < row.Schema.Count, nameof(col)); Contracts.CheckParam(row.IsColumnActive(col), nameof(col), "column was not active"); - var typeSrc = row.Schema.GetColumnType(col); + var typeSrc = row.Schema[col].Type; Contracts.Check(typeSrc.IsPrimitive, "Source column type must be primitive"); - Func> del = GetGetterAsCore; + Func> del = GetGetterAsCore; var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeSrc.RawType, typeDst.RawType); return (Delegate)methodInfo.Invoke(null, new object[] { typeSrc, typeDst, row, col }); } @@ -64,24 +63,24 @@ public static Delegate GetGetterAs(ColumnType typeDst, IRow row, int col) /// Given a destination type, IRow, and column index, return a ValueGetter{TDst} for the column /// with a conversion to typeDst, if needed. 
/// - public static ValueGetter GetGetterAs(ColumnType typeDst, IRow row, int col) + public static ValueGetter GetGetterAs(ColumnType typeDst, Row row, int col) { Contracts.CheckValue(typeDst, nameof(typeDst)); Contracts.CheckParam(typeDst.IsPrimitive, nameof(typeDst)); Contracts.CheckParam(typeDst.RawType == typeof(TDst), nameof(typeDst)); Contracts.CheckValue(row, nameof(row)); - Contracts.CheckParam(0 <= col && col < row.Schema.ColumnCount, nameof(col)); + Contracts.CheckParam(0 <= col && col < row.Schema.Count, nameof(col)); Contracts.CheckParam(row.IsColumnActive(col), nameof(col), "column was not active"); - var typeSrc = row.Schema.GetColumnType(col); + var typeSrc = row.Schema[col].Type; Contracts.Check(typeSrc.IsPrimitive, "Source column type must be primitive"); - Func> del = GetGetterAsCore; + Func> del = GetGetterAsCore; var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeSrc.RawType, typeof(TDst)); return (ValueGetter)methodInfo.Invoke(null, new object[] { typeSrc, typeDst, row, col }); } - private static ValueGetter GetGetterAsCore(ColumnType typeSrc, ColumnType typeDst, IRow row, int col) + private static ValueGetter GetGetterAsCore(ColumnType typeSrc, ColumnType typeDst, Row row, int col) { Contracts.Assert(typeof(TSrc) == typeSrc.RawType); Contracts.Assert(typeof(TDst) == typeDst.RawType); @@ -112,18 +111,18 @@ private static ValueGetter GetGetterAsCore(ColumnType typeSrc, /// into the required type. This method can be useful if you want to output a value /// as a string in a generic way, but don't really care how you do it. 
/// - public static ValueGetter GetGetterAsStringBuilder(IRow row, int col) + public static ValueGetter GetGetterAsStringBuilder(Row row, int col) { Contracts.CheckValue(row, nameof(row)); - Contracts.CheckParam(0 <= col && col < row.Schema.ColumnCount, nameof(col)); + Contracts.CheckParam(0 <= col && col < row.Schema.Count, nameof(col)); Contracts.CheckParam(row.IsColumnActive(col), nameof(col), "column was not active"); - var typeSrc = row.Schema.GetColumnType(col); + var typeSrc = row.Schema[col].Type; Contracts.Check(typeSrc.IsPrimitive, "Source column type must be primitive"); return Utils.MarshalInvoke(GetGetterAsStringBuilderCore, typeSrc.RawType, typeSrc, row, col); } - private static ValueGetter GetGetterAsStringBuilderCore(ColumnType typeSrc, IRow row, int col) + private static ValueGetter GetGetterAsStringBuilderCore(ColumnType typeSrc, Row row, int col) { Contracts.Assert(typeof(TSrc) == typeSrc.RawType); @@ -142,16 +141,16 @@ private static ValueGetter GetGetterAsStringBuilderCore(Col /// /// Given the item type, typeDst, a row, and column index, return a ValueGetter for the vector-valued /// column with a conversion to a vector of typeDst, if needed. This is the weakly typed version of - /// . + /// . 
/// - public static Delegate GetVecGetterAs(PrimitiveType typeDst, IRow row, int col) + public static Delegate GetVecGetterAs(PrimitiveType typeDst, Row row, int col) { Contracts.CheckValue(typeDst, nameof(typeDst)); Contracts.CheckValue(row, nameof(row)); - Contracts.CheckParam(0 <= col && col < row.Schema.ColumnCount, nameof(col)); + Contracts.CheckParam(0 <= col && col < row.Schema.Count, nameof(col)); Contracts.CheckParam(row.IsColumnActive(col), nameof(col), "column was not active"); - var typeSrc = row.Schema.GetColumnType(col); + var typeSrc = row.Schema[col].Type; Contracts.Check(typeSrc.IsVector, "Source column type must be vector"); Func>> del = GetVecGetterAsCore; @@ -163,15 +162,15 @@ public static Delegate GetVecGetterAs(PrimitiveType typeDst, IRow row, int col) /// Given the item type, typeDst, a row, and column index, return a ValueGetter{VBuffer{TDst}} for the /// vector-valued column with a conversion to a vector of typeDst, if needed. /// - public static ValueGetter> GetVecGetterAs(PrimitiveType typeDst, IRow row, int col) + public static ValueGetter> GetVecGetterAs(PrimitiveType typeDst, Row row, int col) { Contracts.CheckValue(typeDst, nameof(typeDst)); Contracts.CheckParam(typeDst.RawType == typeof(TDst), nameof(typeDst)); Contracts.CheckValue(row, nameof(row)); - Contracts.CheckParam(0 <= col && col < row.Schema.ColumnCount, nameof(col)); + Contracts.CheckParam(0 <= col && col < row.Schema.Count, nameof(col)); Contracts.CheckParam(row.IsColumnActive(col), nameof(col), "column was not active"); - var typeSrc = row.Schema.GetColumnType(col); + var typeSrc = row.Schema[col].Type; Contracts.Check(typeSrc.IsVector, "Source column type must be vector"); Func>> del = GetVecGetterAsCore; @@ -183,7 +182,7 @@ public static ValueGetter> GetVecGetterAs(PrimitiveType type /// Given the item type, typeDst, and a slot cursor, return a ValueGetter{VBuffer{TDst}} for the /// vector-valued column with a conversion to a vector of typeDst, if needed. 
/// - public static ValueGetter> GetVecGetterAs(PrimitiveType typeDst, ISlotCursor cursor) + public static ValueGetter> GetVecGetterAs(PrimitiveType typeDst, SlotCursor cursor) { Contracts.CheckValue(typeDst, nameof(typeDst)); Contracts.CheckParam(typeDst.RawType == typeof(TDst), nameof(typeDst)); @@ -200,12 +199,12 @@ public static ValueGetter> GetVecGetterAs(PrimitiveType type /// private abstract class GetterFactory { - public static GetterFactory Create(IRow row, int col) + public static GetterFactory Create(Row row, int col) { return new RowImpl(row, col); } - public static GetterFactory Create(ISlotCursor cursor) + public static GetterFactory Create(SlotCursor cursor) { return new SlotImpl(cursor); } @@ -214,10 +213,10 @@ public static GetterFactory Create(ISlotCursor cursor) private sealed class RowImpl : GetterFactory { - private readonly IRow _row; + private readonly Row _row; private readonly int _col; - public RowImpl(IRow row, int col) + public RowImpl(Row row, int col) { _row = row; _col = col; @@ -231,9 +230,9 @@ public override ValueGetter GetGetter() private sealed class SlotImpl : GetterFactory { - private readonly ISlotCursor _cursor; + private readonly SlotCursor _cursor; - public SlotImpl(ISlotCursor cursor) + public SlotImpl(SlotCursor cursor) { _cursor = cursor; } @@ -294,16 +293,16 @@ private static ValueGetter> GetVecGetterAsCore(VectorT /// is different than it was, in the last call. This is practically useful for determining /// group boundaries. Note that the delegate will return true on the first row. 
/// - public static Func GetIsNewGroupDelegate(IRow cursor, int col) + public static Func GetIsNewGroupDelegate(Row cursor, int col) { Contracts.CheckValue(cursor, nameof(cursor)); - Contracts.Check(0 <= col && col < cursor.Schema.ColumnCount); - ColumnType type = cursor.Schema.GetColumnType(col); + Contracts.Check(0 <= col && col < cursor.Schema.Count); + ColumnType type = cursor.Schema[col].Type; Contracts.Check(type.IsKey); return Utils.MarshalInvoke(GetIsNewGroupDelegateCore, type.RawType, cursor, col); } - private static Func GetIsNewGroupDelegateCore(IRow cursor, int col) + private static Func GetIsNewGroupDelegateCore(Row cursor, int col) { var getter = cursor.GetGetter(col); bool first = true; @@ -326,7 +325,10 @@ private static Func GetIsNewGroupDelegateCore(IRow cursor, int col) }; } - public static Func GetIsNewBatchDelegate(IRow cursor, int batchSize) + [Obsolete("The usages of this appear to be based on a total misunderstanding of what Batch actually is. It is a mechanism " + + "to enable sharding and recovery of parallelized data, and has nothing to do with actual data.")] + [BestFriend] + internal static Func GetIsNewBatchDelegate(Row cursor, int batchSize) { Contracts.CheckParam(batchSize > 0, nameof(batchSize), "Batch size must be > 0"); long lastNewBatchPosition = -1; @@ -363,9 +365,9 @@ public static string TestGetLabelGetter(ColumnType type, bool allowKeys) return allowKeys ? 
"Expected R4, R8, Bool or Key type" : "Expected R4, R8 or Bool type"; } - public static ValueGetter GetLabelGetter(IRow cursor, int labelIndex) + public static ValueGetter GetLabelGetter(Row cursor, int labelIndex) { - var type = cursor.Schema.GetColumnType(labelIndex); + var type = cursor.Schema[labelIndex].Type; if (type == NumberType.R4) return cursor.GetGetter(labelIndex); @@ -385,9 +387,9 @@ public static ValueGetter GetLabelGetter(IRow cursor, int labelIndex) return GetLabelGetterNotFloat(cursor, labelIndex); } - private static ValueGetter GetLabelGetterNotFloat(IRow cursor, int labelIndex) + private static ValueGetter GetLabelGetterNotFloat(Row cursor, int labelIndex) { - var type = cursor.Schema.GetColumnType(labelIndex); + var type = cursor.Schema[labelIndex].Type; Contracts.Assert(type != NumberType.R4 && type != NumberType.R8); @@ -422,7 +424,7 @@ private static ValueGetter GetLabelGetterNotFloat(IRow cursor, int label }; } - public static ValueGetter> GetLabelGetter(ISlotCursor cursor) + public static ValueGetter> GetLabelGetter(SlotCursor cursor) { var type = cursor.GetSlotType().ItemType; if (type == NumberType.R4) @@ -454,25 +456,11 @@ public static ValueGetter> GetLabelGetter(ISlotCursor cursor) }; } - /// - /// Returns a row that is a deep in-memory copy of an input row. Note that inactive - /// columns are allowed in this row, and their activity or inactivity will be reflected - /// in the output row. Note that the deep copy includes a copy of the metadata as well. - /// - /// The input row - /// A deep in-memory copy of the input row - public static IRow CloneRow(IRow row) - { - Contracts.CheckValue(row, nameof(row)); - return RowColumnUtils.GetRow(null, - Utils.BuildArray(row.Schema.ColumnCount, c => RowColumnUtils.GetColumn(row, c))); - } - /// /// Fetches the value of the column by name, in the given row. /// Used by the evaluators to retrieve the metrics from the results IDataView. 
/// - public static T Fetch(IExceptionContext ectx, IRow row, string name) + public static T Fetch(IExceptionContext ectx, Row row, string name) { if (!row.Schema.TryGetColumnIndex(name, out int col)) throw ectx.Except($"Could not find column '{name}'"); @@ -483,56 +471,54 @@ public static T Fetch(IExceptionContext ectx, IRow row, string name) /// /// Given a row, returns a one-row data view. This is useful for cases where you have a row, and you - /// wish to use some facility normally only exposed to dataviews. (For example, you have an - /// but want to save it somewhere using a .) + /// wish to use some facility normally only exposed to dataviews. (For example, you have an + /// but want to save it somewhere using a .) /// Note that it is not possible for this method to ensure that the input does not /// change, so users of this convenience must take care of what they do with the input row or the data - /// source it came from, while the returned dataview is potentially being used; if this is somehow - /// difficult it may be wise to use to first have a deep copy of the resulting row. + /// source it came from, while the returned dataview is potentially being used. 
/// /// An environment used to create the host for the resulting data view /// A row, whose columns must all be active /// A single-row data view incorporating that row - public static IDataView RowAsDataView(IHostEnvironment env, IRow row) + public static IDataView RowAsDataView(IHostEnvironment env, Row row) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(row, nameof(row)); - env.CheckParam(Enumerable.Range(0, row.Schema.ColumnCount).All(c => row.IsColumnActive(c)), nameof(row), "Some columns were inactive"); + env.CheckParam(Enumerable.Range(0, row.Schema.Count).All(c => row.IsColumnActive(c)), nameof(row), "Some columns were inactive"); return new OneRowDataView(env, row); } private sealed class OneRowDataView : IDataView { - private readonly IRow _row; + private readonly Row _row; private readonly IHost _host; // A channel provider is required for creating the cursor. public Schema Schema => _row.Schema; - public bool CanShuffle { get { return true; } } // The shuffling is even uniformly IID!! :) + public bool CanShuffle => true; // The shuffling is even uniformly IID!! 
:) - public OneRowDataView(IHostEnvironment env, IRow row) + public OneRowDataView(IHostEnvironment env, Row row) { Contracts.AssertValue(env); _host = env.Register("OneRowDataView"); _host.AssertValue(row); - _host.Assert(Enumerable.Range(0, row.Schema.ColumnCount).All(c => row.IsColumnActive(c))); + _host.Assert(Enumerable.Range(0, row.Schema.Count).All(c => row.IsColumnActive(c))); _row = row; } - public IRowCursor GetRowCursor(Func needCol, IRandom rand = null) + public RowCursor GetRowCursor(Func needCol, Random rand = null) { _host.CheckValue(needCol, nameof(needCol)); _host.CheckValueOrNull(rand); - bool[] active = Utils.BuildArray(Schema.ColumnCount, needCol); + bool[] active = Utils.BuildArray(Schema.Count, needCol); return new Cursor(_host, this, active); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func needCol, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func needCol, int n, Random rand = null) { _host.CheckValue(needCol, nameof(needCol)); _host.CheckValueOrNull(rand); - consolidator = null; - return new IRowCursor[] { GetRowCursor(needCol, rand) }; + return new RowCursor[] { GetRowCursor(needCol, rand) }; } public long? 
GetRowCount() @@ -540,12 +526,12 @@ public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Fun return 1; } - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private readonly OneRowDataView _parent; private readonly bool[] _active; - public Schema Schema => _parent.Schema; + public override Schema Schema => _parent.Schema; public override long Batch { get { return 0; } } public Cursor(IHost host, OneRowDataView parent, bool[] active) @@ -553,7 +539,7 @@ public Cursor(IHost host, OneRowDataView parent, bool[] active) { Ch.AssertValue(parent); Ch.AssertValue(active); - Ch.Assert(active.Length == parent.Schema.ColumnCount); + Ch.Assert(active.Length == parent.Schema.Count); _parent = parent; _active = active; } @@ -563,9 +549,9 @@ protected override bool MoveNextCore() return State == CursorState.NotStarted; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { - Ch.CheckParam(0 <= col && col < Schema.ColumnCount, nameof(col)); + Ch.CheckParam(0 <= col && col < Schema.Count, nameof(col)); Ch.CheckParam(IsColumnActive(col), nameof(col), "Requested column is not active"); var getter = _parent._row.GetGetter(col); return @@ -576,22 +562,22 @@ public ValueGetter GetGetter(int col) }; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { - Ch.CheckParam(0 <= col && col < Schema.ColumnCount, nameof(col)); + Ch.CheckParam(0 <= col && col < Schema.Count, nameof(col)); // We present the "illusion" that this column is not active, even though it must be // in the input row. 
Ch.Assert(_parent._row.IsColumnActive(col)); return _active[col]; } - public override ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { return - (ref UInt128 val) => + (ref RowId val) => { Ch.Check(IsGood, "Cannot call ID getter in current state"); - val = new UInt128((ulong)Position, 0); + val = new RowId((ulong)Position, 0); }; } } diff --git a/src/Microsoft.ML.Api/SchemaDefinition.cs b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs similarity index 96% rename from src/Microsoft.ML.Api/SchemaDefinition.cs rename to src/Microsoft.ML.Data/Data/SchemaDefinition.cs index 3ed4fbbbdd..b669783d4f 100644 --- a/src/Microsoft.ML.Api/SchemaDefinition.cs +++ b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs @@ -6,10 +6,9 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Api +namespace Microsoft.ML.Data { /// /// Attach to a member of a class to indicate that the item type should be of class key. @@ -72,25 +71,12 @@ public sealed class ColumnAttribute : Attribute public ColumnAttribute(string ordinal, string name = null) { Name = name; - Ordinal = ordinal; } /// /// Column name. /// public string Name { get; } - - /// - /// Contains positions of indices of source columns in the form - /// of ranges. Examples of range: if we want to include just column - /// with index 1 we can write the range as 1, if we want to include - /// columns 1 to 10 then we can write the range as 1-10 and we want to include all the - /// columns from column with index 1 until end then we can write 1-*. 
- /// - /// This takes sequence of ranges that are comma seperated, example: - /// 1,2-5,10-* - /// - public string Ordinal { get; } } /// diff --git a/src/Microsoft.ML.Data/Data/SlotCursor.cs b/src/Microsoft.ML.Data/Data/SlotCursor.cs new file mode 100644 index 0000000000..1b043a9273 --- /dev/null +++ b/src/Microsoft.ML.Data/Data/SlotCursor.cs @@ -0,0 +1,142 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.Data +{ + /// + /// A cursor that allows slot-by-slot access of data. This is to + /// what is to . + /// + public abstract class SlotCursor : IDisposable + { + [BestFriend] + private protected readonly IChannel Ch; + private CursorState _state; + + /// + /// Whether the cursor is in a state where it can serve up data, that is, + /// has been called and returned . + /// + [BestFriend] + private protected bool IsGood => _state == CursorState.Good; + + [BestFriend] + private protected SlotCursor(IChannelProvider provider) + { + Contracts.AssertValue(provider); + Ch = provider.Start("Slot Cursor"); + _state = CursorState.NotStarted; + } + + /// + /// The slot index. Incremented by one when is called and returns . + /// When initially created, or after returns , this will be -1. + /// + public abstract int SlotIndex { get; } + + /// + /// Advance to the next slot. When the cursor is first created, this method should be called to + /// move to the first slot. Returns if there are no more slots. + /// + public abstract bool MoveNext(); + + /// + /// The slot type for this cursor. Note that this should equal the + /// for the column from which this slot cursor + /// was created. + /// + public abstract VectorType GetSlotType(); + + /// + /// A getter delegate for the slot values. The type must correspond + /// to the item type from . 
+ /// + public abstract ValueGetter> GetGetter(); + + public virtual void Dispose() + { + if (_state != CursorState.Done) + { + Ch.Dispose(); + _state = CursorState.Done; + } + } + + /// + /// For wrapping another slot cursor from which we get and , + /// but not the data or type accesors. Somewhat analogous to the + /// for s. + /// + [BestFriend] + internal abstract class SynchronizedSlotCursor : SlotCursor + { + private readonly SlotCursor _root; + + public SynchronizedSlotCursor(IChannelProvider provider, SlotCursor cursor) + : base(provider) + { + Contracts.AssertValue(cursor); + // If the input is itself a sync-base, we can walk up the chain to get its root, + // thereby making things more efficient. + _root = cursor is SynchronizedSlotCursor sync ? sync._root : cursor; + } + + public override bool MoveNext() + => _root.MoveNext(); + + public override int SlotIndex => _root.SlotIndex; + } + + /// + /// A useful base class for common implementations, somewhat + /// analogous to the for s. + /// + [BestFriend] + internal abstract class RootSlotCursor : SlotCursor + { + private int _slotIndex; + + public RootSlotCursor(IChannelProvider provider) + : base(provider) + { + _slotIndex = -1; + } + + public override int SlotIndex => _slotIndex; + + public override void Dispose() + { + base.Dispose(); + _slotIndex = -1; + } + + public override bool MoveNext() + { + if (_state == CursorState.Done) + return false; + + Ch.Assert(_state == CursorState.NotStarted || _state == CursorState.Good); + if (MoveNextCore()) + { + Ch.Assert(_state == CursorState.NotStarted || _state == CursorState.Good); + + _slotIndex++; + _state = CursorState.Good; + return true; + } + + Dispose(); + return false; + } + + /// + /// Core implementation of . Called only if this method + /// has not yet previously returned . 
+ /// + protected abstract bool MoveNextCore(); + } + } +} diff --git a/src/Microsoft.ML.Data/DataDebuggerPreview.cs b/src/Microsoft.ML.Data/DataDebuggerPreview.cs index ba1049eb2b..f57a5d5697 100644 --- a/src/Microsoft.ML.Data/DataDebuggerPreview.cs +++ b/src/Microsoft.ML.Data/DataDebuggerPreview.cs @@ -2,13 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Data { @@ -32,7 +30,7 @@ internal DataDebuggerPreview(IDataView data, int maxRows = Defaults.MaxRows) Contracts.CheckParam(maxRows >= 0, nameof(maxRows)); Schema = data.Schema; - int n = data.Schema.ColumnCount; + int n = data.Schema.Count; var rows = new List(); var columns = new List[n]; @@ -63,7 +61,7 @@ internal DataDebuggerPreview(IDataView data, int maxRows = Defaults.MaxRows) public override string ToString() => $"{Schema.Count} columns, {RowView.Length} rows"; - private Action> MakeSetter(IRow row, int col) + private Action> MakeSetter(Row row, int col) { var getter = row.GetGetter(col); string name = row.Schema[col].Name; diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs index 473ee0d8d1..e8839a6efe 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs @@ -2,15 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms; using System; using System.Collections.Concurrent; using System.Collections.Generic; @@ -21,6 +12,14 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.ML; +using Microsoft.ML.Command; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Transforms; [assembly: LoadableClass(BinaryLoader.Summary, typeof(BinaryLoader), typeof(BinaryLoader.Arguments), typeof(SignatureDataLoader), "Binary Loader", @@ -34,7 +33,7 @@ [assembly: LoadableClass(typeof(BinaryLoader.InfoCommand), typeof(BinaryLoader.InfoCommand.Arguments), typeof(SignatureCommand), "", BinaryLoader.InfoCommand.LoadName, "idv")] -namespace Microsoft.ML.Runtime.Data.IO +namespace Microsoft.ML.Data.IO { public sealed class BinaryLoader : IDataLoader, IDisposable { @@ -754,7 +753,7 @@ public void GetMetadata(string kind, int col, ref TValue value) private const ulong SlotNamesVersion = 0x0001000100010003; /// - /// Lower inclusive bound of versions this reader can read. + /// Low inclusive bound of versions this reader can read. 
/// private const ulong ReaderFirstVersion = 0x0001000100010002; @@ -1016,7 +1015,7 @@ private static void SaveSchema(IHostEnvironment env, ModelSaveContext ctx, Schem var saver = new BinarySaver(env, saverArgs); var cols = Enumerable.Range(0, schema.Count) - .Select(x => new { col = x, isSavable = saver.IsColumnSavable(schema.GetColumnType(x)) }); + .Select(x => new { col = x, isSavable = saver.IsColumnSavable(schema[x].Type) }); int[] toSave = cols.Where(x => x.isSavable).Select(x => x.col).ToArray(); unsavableColIndices = cols.Where(x => !x.isSavable).Select(x => x.col).ToArray(); ctx.SaveBinaryStream("Schema.idv", w => saver.SaveData(w.BaseStream, noRows, toSave)); @@ -1234,7 +1233,7 @@ private TableOfContentsEntry CreateRowIndexEntry(string rowIndexName) return entry; } - private IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null) + private RowCursor GetRowCursorCore(Func predicate, Random rand = null) { if (rand != null && _randomShufflePoolRows > 0) { @@ -1247,23 +1246,21 @@ private IRowCursor GetRowCursorCore(Func predicate, IRandom rand = nu return new Cursor(this, predicate, rand); } - public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); return GetRowCursorCore(predicate, rand); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); - consolidator = null; - return new IRowCursor[] { GetRowCursorCore(predicate, rand) }; + return new RowCursor[] { GetRowCursorCore(predicate, rand) }; } - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private const string _badCursorState = "cursor is either not 
started or is ended, and cannot get values"; @@ -1284,8 +1281,9 @@ private sealed class Cursor : RootCursorBase, IRowCursor private readonly ExceptionMarshaller _exMarshaller; private volatile bool _disposed; + private volatile bool _done; - public Schema Schema => _parent.Schema; + public override Schema Schema => _parent.Schema; public override long Batch { @@ -1293,7 +1291,7 @@ public override long Batch get { return 0; } } - public Cursor(BinaryLoader parent, Func predicate, IRandom rand) + public Cursor(BinaryLoader parent, Func predicate, Random rand) : base(parent._host) { _parent = parent; @@ -1363,60 +1361,71 @@ public Cursor(BinaryLoader parent, Func predicate, IRandom rand) _pipeTask = SetupDecompressTask(); } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (!_disposed && _readerThread != null) + if (_disposed) + return; + if (_done) + { + base.Dispose(disposing); + return; + } + + if (disposing) { - // We should reach this block only in the event of a dispose - // before all rows have been iterated upon. + if (_readerThread != null) + { + // We should reach this block only in the event of a dispose + // before all rows have been iterated upon. - // First set the flag on the cursor. The stream-reader and the - // pipe-decompressor workers will detect this, stop their work, - // and do whatever "cleanup" is natural for them to perform. - _disposed = true; + // First set the flag on the cursor. The stream-reader and the + // pipe-decompressor workers will detect this, stop their work, + // and do whatever "cleanup" is natural for them to perform. + _disposed = true; - // In the disk read -> decompress -> codec read pipeline, we - // clean up in reverse order. - // 1. First we clear out any pending codec readers, for each pipe. - // 2. Then we join the pipe worker threads, which in turn should - // have cleared out all of the pending blocks to decompress. - // 3. Then finally we join against the reader thread. 
+ // In the disk read -> decompress -> codec read pipeline, we + // clean up in reverse order. + // 1. First we clear out any pending codec readers, for each pipe. + // 2. Then we join the pipe worker threads, which in turn should + // have cleared out all of the pending blocks to decompress. + // 3. Then finally we join against the reader thread. - // This code is analogous to the stuff in MoveNextCore, except - // nothing is actually done with the resulting blocks. + // This code is analogous to the stuff in MoveNextCore, except + // nothing is actually done with the resulting blocks. - try - { - for (; ; ) + try + { + for (; ; ) + { + // This cross-block-index access pattern is deliberate, as + // by having a consistent access pattern everywhere we can + // have much greater confidence this will never deadlock. + bool anyTrue = false; + for (int c = 0; c < _pipes.Length; ++c) + anyTrue |= _pipes[c].MoveNextCleanup(); + if (!anyTrue) + break; + } + } + catch (OperationCanceledException ex) { - // This cross-block-index access pattern is deliberate, as - // by having a consistent access pattern everywhere we can - // have much greater confidence this will never deadlock. - bool anyTrue = false; - for (int c = 0; c < _pipes.Length; ++c) - anyTrue |= _pipes[c].MoveNextCleanup(); - if (!anyTrue) - break; + // REVIEW: Encountering this here means that we did not encounter + // the exception during normal cursoring, but at some later point. I feel + // we should not be tolerant of this, and should throw, though it might be + // an ambiguous point. + Contracts.Assert(ex.CancellationToken == _exMarshaller.Token); + _exMarshaller.ThrowIfSet(Ch); + Contracts.Assert(false); + } + finally + { + _pipeTask.Wait(); + _readerThread.Join(); } - } - catch (OperationCanceledException ex) - { - // REVIEW: Encountering this here means that we did not encounter - // the exception during normal cursoring, but at some later point. 
I feel - // we should not be tolerant of this, and should throw, though it might be - // an ambiguous point. - Contracts.Assert(ex.CancellationToken == _exMarshaller.Token); - _exMarshaller.ThrowIfSet(Ch); - Contracts.Assert(false); - } - finally - { - _pipeTask.Wait(); - _readerThread.Join(); } } - - base.Dispose(); + _disposed = true; + base.Dispose(disposing); } private Task SetupDecompressTask() @@ -2009,7 +2018,7 @@ public override Delegate GetGetter() } } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.CheckParam(0 <= col && col < _colToActivesIndex.Length, nameof(col)); return _colToActivesIndex[col] >= 0; @@ -2032,7 +2041,7 @@ protected override bool MoveNextCore() // threads will exit if all potentially blocking operations are // waiting on the same cancellation token that we catch here. Contracts.Assert(ex.CancellationToken == _exMarshaller.Token); - _disposed = true; + _done = true; // Unlike the early non-error dispose case, we do not make any // effort to recycle buffers since it would be exceptionally difficult // to do so. All threads are already unblocked, one of them with the @@ -2060,7 +2069,7 @@ protected override bool MoveNextCore() // Set the _disposed flag, so that when the Dispose // method is called it does not trigger the "premature // exit" handling. - _disposed = true; + _done = true; // If we got to this point these threads must have already // completed their work, but for the sake of hygiene join // against them anyway. 
@@ -2070,7 +2079,7 @@ protected override bool MoveNextCore() return more; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.CheckParam(0 <= col && col < _colToActivesIndex.Length, nameof(col)); Ch.CheckParam(_colToActivesIndex[col] >= 0, nameof(col), "requested column not active"); @@ -2099,15 +2108,15 @@ private Delegate NoRowGetter() return del; } - public override ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { if (_blockShuffleOrder == null) { return - (ref UInt128 val) => + (ref RowId val) => { Ch.Check(IsGood, "Cannot call ID getter in current state"); - val = new UInt128((ulong)Position, 0); + val = new RowId((ulong)Position, 0); }; } // Find the index of the last block. Because the last block is unevenly sized, @@ -2123,7 +2132,7 @@ public override ValueGetter GetIdGetter() long firstPositionToCorrect = ((long)lastBlockIdx * _rowsPerBlock) + _rowsInLastBlock; return - (ref UInt128 val) => + (ref RowId val) => { Ch.Check(IsGood, "Cannot call ID getter in current state"); long pos = Position; @@ -2133,7 +2142,7 @@ public override ValueGetter GetIdGetter() long blockPos = (long)_rowsPerBlock * _blockShuffleOrder[(int)(pos / _rowsPerBlock)]; blockPos += (pos % _rowsPerBlock); Ch.Assert(0 <= blockPos && blockPos < _parent.RowCount); - val = new UInt128((ulong)blockPos, 0); + val = new RowId((ulong)blockPos, 0); }; } } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoaderSaverCatalog.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoaderSaverCatalog.cs index d697112719..b912a83a25 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoaderSaverCatalog.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoaderSaverCatalog.cs @@ -3,9 +3,8 @@ // See the LICENSE file in the project root for more information. 
using System.IO; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; +using Microsoft.ML.Data; +using Microsoft.ML.Data.IO; namespace Microsoft.ML { diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinarySaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinarySaver.cs index 2b434e8e3c..f044a9eb9c 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinarySaver.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinarySaver.cs @@ -12,16 +12,16 @@ using System.Text; using System.Threading; using System.Threading.Tasks; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Internal.Utilities; [assembly: LoadableClass(BinarySaver.Summary, typeof(BinarySaver), typeof(BinarySaver.Arguments), typeof(SignatureDataSaver), "Binary Saver", "BinarySaver", "Binary")] -namespace Microsoft.ML.Runtime.Data.IO +namespace Microsoft.ML.Data.IO { using Stopwatch = System.Diagnostics.Stopwatch; @@ -88,7 +88,7 @@ protected WritePipe(BinarySaver parent) /// /// Returns an appropriate generic WritePipe{T} for the given column. 
/// - public static WritePipe Create(BinarySaver parent, IRowCursor cursor, ColumnCodec col) + public static WritePipe Create(BinarySaver parent, RowCursor cursor, ColumnCodec col) { Type writePipeType = typeof(WritePipe<>).MakeGenericType(col.Codec.Type.RawType); return (WritePipe)Activator.CreateInstance(writePipeType, parent, cursor, col); @@ -109,7 +109,7 @@ private sealed class WritePipe : WritePipe private MemoryStream _currentStream; private T _value; - public WritePipe(BinarySaver parent, IRowCursor cursor, ColumnCodec col) + public WritePipe(BinarySaver parent, RowCursor cursor, ColumnCodec col) : base(parent) { var codec = col.Codec as IValueCodec; @@ -254,11 +254,11 @@ private void CompressionWorker(BlockingCollection toCompress, BlockingCol /// The channel to which we write any diagnostic information /// The offset of the metadata table of contents, or 0 if there was /// no metadata - private long WriteMetadata(BinaryWriter writer, ISchema schema, int col, IChannel ch) + private long WriteMetadata(BinaryWriter writer, Schema schema, int col, IChannel ch) { _host.AssertValue(writer); _host.AssertValue(schema); - _host.Assert(0 <= col && col < schema.ColumnCount); + _host.Assert(0 <= col && col < schema.Count); int count = 0; WriteMetadataCoreDelegate del = WriteMetadataCore; @@ -274,25 +274,25 @@ private long WriteMetadata(BinaryWriter writer, ISchema schema, int col, IChanne // track of the location and size of each for when we write the metadata table of contents. // (To be clear, this specific layout is not required by the format.) 
- foreach (var pair in schema.GetMetadataTypes(col)) + foreach (var metaColumn in schema[col].Metadata.Schema) { - _host.Check(!string.IsNullOrEmpty(pair.Key), "Metadata with null or empty kind detected, disallowed"); - _host.Check(pair.Value != null, "Metadata with null type detected, disallowed"); - if (!kinds.Add(pair.Key)) - throw _host.Except("Metadata with duplicate kind '{0}' encountered, disallowed", pair.Key, schema.GetColumnName(col)); - args[3] = pair.Key; - args[4] = pair.Value; - IValueCodec codec = (IValueCodec)methInfo.MakeGenericMethod(pair.Value.RawType).Invoke(this, args); + _host.Check(!string.IsNullOrEmpty(metaColumn.Name), "Metadata with null or empty kind detected, disallowed"); + _host.Check(metaColumn.Type != null, "Metadata with null type detected, disallowed"); + if (!kinds.Add(metaColumn.Name)) + throw _host.Except("Metadata with duplicate kind '{0}' encountered, disallowed", metaColumn.Name, schema[col].Name); + args[3] = metaColumn.Name; + args[4] = metaColumn.Type; + IValueCodec codec = (IValueCodec)methInfo.MakeGenericMethod(metaColumn.Type.RawType).Invoke(this, args); if (codec == null) { // Nothing was written. 
ch.Warning("Could not get codec for type {0}, dropping column '{1}' index {2} metadata kind '{3}'", - pair.Value, schema.GetColumnName(col), col, pair.Key); + metaColumn.Type, schema[col].Name, col, metaColumn.Name); continue; } offsets.Add(writer.BaseStream.Position); _host.CheckIO(offsets[offsets.Count - 1] > offsets[offsets.Count - 2], "Bad offsets detected during write"); - metadataInfos.Add(Tuple.Create(pair.Key, codec, (CompressionKind)args[5])); + metadataInfos.Add(Tuple.Create(metaColumn.Name, codec, (CompressionKind)args[5])); count++; } if (metadataInfos.Count == 0) @@ -341,9 +341,9 @@ private long WriteMetadata(BinaryWriter writer, ISchema schema, int col, IChanne return offsets[metadataInfos.Count]; } - private delegate IValueCodec WriteMetadataCoreDelegate(Stream stream, ISchema schema, int col, string kind, ColumnType type, out CompressionKind compression); + private delegate IValueCodec WriteMetadataCoreDelegate(Stream stream, Schema schema, int col, string kind, ColumnType type, out CompressionKind compression); - private IValueCodec WriteMetadataCore(Stream stream, ISchema schema, int col, string kind, ColumnType type, out CompressionKind compressionKind) + private IValueCodec WriteMetadataCore(Stream stream, Schema schema, int col, string kind, ColumnType type, out CompressionKind compressionKind) { _host.Assert(typeof(T) == type.RawType); IValueCodec generalCodec; @@ -354,7 +354,7 @@ private IValueCodec WriteMetadataCore(Stream stream, ISchema schema, int col, } IValueCodec codec = (IValueCodec)generalCodec; T value = default(T); - schema.GetMetadata(kind, col, ref value); + schema[col].Metadata.GetValue(kind, ref value); // Metadatas will often be pretty small, so that compression makes no sense. 
// We try both a compressed and uncompressed version of metadata and @@ -390,7 +390,7 @@ private IValueCodec WriteMetadataCore(Stream stream, ISchema schema, int col, } private void WriteWorker(Stream stream, BlockingCollection toWrite, ColumnCodec[] activeColumns, - ISchema sourceSchema, int rowsPerBlock, IChannelProvider cp, ExceptionMarshaller exMarshaller) + Schema sourceSchema, int rowsPerBlock, IChannelProvider cp, ExceptionMarshaller exMarshaller) { _host.AssertValue(exMarshaller); try @@ -508,7 +508,7 @@ private void WriteWorker(Stream stream, BlockingCollection toWrite, Colum // long: Offset to the start of the lookup table // long: Offset to the start of the metadata TOC entries, or 0 if this has no metadata - string name = sourceSchema.GetColumnName(active.SourceIndex); + string name = sourceSchema[active.SourceIndex].Name; writer.Write(name); int nameLen = Encoding.UTF8.GetByteCount(name); expectedPosition += Utils.Leb128IntLength((uint)nameLen) + nameLen; @@ -581,7 +581,7 @@ private void FetchWorker(BlockingCollection toCompress, IDataView data, HashSet activeSet = new HashSet(activeColumns.Select(col => col.SourceIndex)); long blockIndex = 0; int remainingInBlock = rowsPerBlock; - using (IRowCursor cursor = data.GetRowCursor(activeSet.Contains)) + using (RowCursor cursor = data.GetRowCursor(activeSet.Contains)) { WritePipe[] pipes = new WritePipe[activeColumns.Length]; for (int c = 0; c < activeColumns.Length; ++c) @@ -708,7 +708,7 @@ public void SaveData(Stream stream, IDataView data, params int[] colIndices) } } - private ColumnCodec[] GetActiveColumns(ISchema schema, int[] colIndices) + private ColumnCodec[] GetActiveColumns(Schema schema, int[] colIndices) { _host.AssertValue(schema); _host.AssertValueOrNull(colIndices); @@ -719,10 +719,10 @@ private ColumnCodec[] GetActiveColumns(ISchema schema, int[] colIndices) for (int c = 0; c < colIndices.Length; ++c) { - ColumnType type = schema.GetColumnType(colIndices[c]); + ColumnType type = 
schema[colIndices[c]].Type; IValueCodec codec; if (!_factory.TryGetCodec(type, out codec)) - throw _host.Except("Could not get codec for requested column {0} of type {1}", schema.GetColumnName(c), type); + throw _host.Except("Could not get codec for requested column {0} of type {1}", schema[c].Name, type); _host.Assert(type.Equals(codec.Type)); activeSourceColumns[c] = new ColumnCodec(colIndices[c], codec); } @@ -741,12 +741,12 @@ private int RowsPerBlockHeuristic(IDataView data, ColumnCodec[] actives) // First get the cursor. HashSet active = new HashSet(actives.Select(cc => cc.SourceIndex)); - IRandom rand = data.CanShuffle ? new TauswortheHybrid(_host.Rand) : null; + Random rand = data.CanShuffle ? new TauswortheHybrid(_host.Rand) : null; // Get the estimators. EstimatorDelegate del = EstimatorCore; MethodInfo methInfo = del.GetMethodInfo().GetGenericMethodDefinition(); - using (IRowCursor cursor = data.GetRowCursor(active.Contains, rand)) + using (RowCursor cursor = data.GetRowCursor(active.Contains, rand)) { object[] args = new object[] { cursor, null, null, null }; var writers = new IValueWriter[actives.Length]; @@ -776,10 +776,10 @@ private int RowsPerBlockHeuristic(IDataView data, ColumnCodec[] actives) } } - private delegate void EstimatorDelegate(IRowCursor cursor, ColumnCodec col, + private delegate void EstimatorDelegate(RowCursor cursor, ColumnCodec col, out Func fetchWriteEstimator, out IValueWriter writer); - private void EstimatorCore(IRowCursor cursor, ColumnCodec col, + private void EstimatorCore(RowCursor cursor, ColumnCodec col, out Func fetchWriteEstimator, out IValueWriter writer) { ValueGetter getter = cursor.GetGetter(col.SourceIndex); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BlockLookup.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BlockLookup.cs index c5b1571f5a..ff5edd2014 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BlockLookup.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BlockLookup.cs @@ -2,7 +2,7 @@ // 
The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Runtime.Data.IO +namespace Microsoft.ML.Data.IO { /// /// This structure is utilized by both the binary loader and binary saver to hold diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs index 39ecb66e30..66a3ceacc1 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs @@ -6,9 +6,9 @@ using System.Collections.Generic; using System.IO; using System.Text; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data.IO +namespace Microsoft.ML.Data.IO { internal sealed partial class CodecFactory { @@ -59,7 +59,7 @@ public CodecFactory(IHostEnvironment env, MemoryStreamPool memPool = null) RegisterSimpleCodec(new BoolCodec(this)); RegisterSimpleCodec(new DateTimeCodec(this)); RegisterSimpleCodec(new DateTimeOffsetCodec(this)); - RegisterSimpleCodec(new UnsafeTypeCodec(this)); + RegisterSimpleCodec(new UnsafeTypeCodec(this)); // Register the old type system reading codec. 
RegisterOtherCodec("DvBool", new OldBoolCodec(this).GetCodec); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs index 11963c3700..7bef51d52f 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs @@ -5,13 +5,12 @@ using System; using System.Collections.Generic; using System.IO; -using System.Linq; using System.Runtime.InteropServices; using System.Text; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data.IO +namespace Microsoft.ML.Data.IO { internal sealed partial class CodecFactory { diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/CompressionKind.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/CompressionKind.cs index 438177807a..1779ea4349 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/CompressionKind.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/CompressionKind.cs @@ -5,13 +5,11 @@ using System; using System.IO; using System.IO.Compression; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Data.IO.Zlib; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data.IO.Zlib; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data.IO +namespace Microsoft.ML.Data.IO { /// /// A code indicating the kind of compression. 
It is supposed that each kind of compression is totally diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Header.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Header.cs index b552ab6523..0395094700 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Header.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Header.cs @@ -4,7 +4,7 @@ using System.Runtime.InteropServices; -namespace Microsoft.ML.Runtime.Data.IO +namespace Microsoft.ML.Data.IO { [StructLayout(LayoutKind.Explicit, Size = HeaderSize)] public struct Header diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/IValueCodec.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/IValueCodec.cs index 2f81e90056..5419e67ab8 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/IValueCodec.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/IValueCodec.cs @@ -5,7 +5,7 @@ using System; using System.IO; -namespace Microsoft.ML.Runtime.Data.IO +namespace Microsoft.ML.Data.IO { /// /// A value codec encapsulates implementations capable of writing and reading data of some diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/MemoryStreamPool.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/MemoryStreamPool.cs index 105b0cf8a0..5fbb5c1318 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/MemoryStreamPool.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/MemoryStreamPool.cs @@ -5,9 +5,9 @@ using System; using System.IO; using System.Threading; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data.IO +namespace Microsoft.ML.Data.IO { internal sealed class MemoryStreamPool { diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs index 7de0a60e13..d1a47f252a 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs @@ -2,16 +2,13 @@ // The .NET Foundation licenses this file to you under the 
MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - using System; using System.Collections.Generic; using System.IO; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; using System.Runtime.InteropServices; +using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime.Internal.Internallearn +namespace Microsoft.ML.Internal.Internallearn { /// /// Represents some common global operations over a type @@ -44,7 +41,7 @@ static UnsafeTypeOpsFactory() _type2ops[typeof(Single)] = new SingleUnsafeTypeOps(); _type2ops[typeof(Double)] = new DoubleUnsafeTypeOps(); _type2ops[typeof(TimeSpan)] = new TimeSpanUnsafeTypeOps(); - _type2ops[typeof(UInt128)] = new UgUnsafeTypeOps(); + _type2ops[typeof(RowId)] = new UgUnsafeTypeOps(); } public static UnsafeTypeOps Get() @@ -189,21 +186,21 @@ public override TimeSpan Read(BinaryReader reader) } } - private sealed class UgUnsafeTypeOps : UnsafeTypeOps + private sealed class UgUnsafeTypeOps : UnsafeTypeOps { public override int Size { get { return 2 * sizeof(ulong); } } - public override unsafe void Apply(ReadOnlySpan array, Action func) + public override unsafe void Apply(ReadOnlySpan array, Action func) { - fixed (UInt128* pArray = &MemoryMarshal.GetReference(array)) + fixed (RowId* pArray = &MemoryMarshal.GetReference(array)) func(new IntPtr(pArray)); } - public override void Write(UInt128 a, BinaryWriter writer) { writer.Write(a.Lo); writer.Write(a.Hi); } - public override UInt128 Read(BinaryReader reader) + public override void Write(RowId a, BinaryWriter writer) { writer.Write(a.Low); writer.Write(a.High); } + public override RowId Read(BinaryReader reader) { ulong lo = reader.ReadUInt64(); ulong hi = reader.ReadUInt64(); - return new UInt128(lo, hi); + return new RowId(lo, hi); } } } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/Constants.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/Constants.cs index 51f3869ee8..4aaec0db62 100644 --- 
a/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/Constants.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/Constants.cs @@ -2,13 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; - -namespace Microsoft.ML.Runtime.Data.IO.Zlib +namespace Microsoft.ML.Data.IO.Zlib { /// /// See zlib.h diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/ZDeflateStream.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/ZDeflateStream.cs index 9b46743ab5..b7d2a98ac9 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/ZDeflateStream.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/ZDeflateStream.cs @@ -5,7 +5,7 @@ using System; using System.IO; -namespace Microsoft.ML.Runtime.Data.IO.Zlib +namespace Microsoft.ML.Data.IO.Zlib { public sealed class ZDeflateStream : Stream { diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/ZInflateStream.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/ZInflateStream.cs index cb8bef5360..871627858f 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/ZInflateStream.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/ZInflateStream.cs @@ -4,9 +4,8 @@ using System; using System.IO; -using Microsoft.ML.Runtime; -namespace Microsoft.ML.Runtime.Data.IO.Zlib +namespace Microsoft.ML.Data.IO.Zlib { public sealed class ZInflateStream : Stream { diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/Zlib.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/Zlib.cs index 7b2ae812a8..879394acfe 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/Zlib.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/Zlib.cs @@ -6,7 +6,7 @@ using System.Runtime.InteropServices; using System.Security; -namespace Microsoft.ML.Runtime.Data.IO.Zlib +namespace Microsoft.ML.Data.IO.Zlib { internal static 
class Zlib { diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs index 222d5c994a..607d3c27c9 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs @@ -2,16 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; using System; using System.Collections.Generic; using System.IO; using System.Linq; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; using Float = System.Single; [assembly: LoadableClass(typeof(IDataLoader), typeof(CompositeDataLoader), typeof(CompositeDataLoader.Arguments), typeof(SignatureDataLoader), @@ -20,7 +19,7 @@ [assembly: LoadableClass(typeof(IDataLoader), typeof(CompositeDataLoader), null, typeof(SignatureLoadDataLoader), "Pipe DataL Loader", CompositeDataLoader.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// A data loader that wraps an underlying loader plus a sequence of transforms. @@ -409,7 +408,7 @@ private CompositeDataLoader(IHost host, TransformEx[] transforms) View = transforms[transforms.Length - 1].Transform; _tview = View as ITransposeDataView; - TransposeSchema = _tview?.TransposeSchema ?? new TransposerUtils.SimpleTransposeSchema(View.Schema); + _transposeSchema = _tview?.TransposeSchema ?? 
new TransposerUtils.SimpleTransposeSchema(View.Schema); var srcLoader = transforms[0].Transform.Source as IDataLoader; @@ -566,30 +565,30 @@ private static string GenerateTag(int index) public Schema Schema => View.Schema; - public ITransposeSchema TransposeSchema { get; } + private readonly ITransposeSchema _transposeSchema; + ITransposeSchema ITransposeDataView.TransposeSchema => _transposeSchema; - public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); return View.GetRowCursor(predicate, rand); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); - return View.GetRowCursorSet(out consolidator, predicate, n, rand); + return View.GetRowCursorSet(predicate, n, rand); } - public ISlotCursor GetSlotCursor(int col) + public SlotCursor GetSlotCursor(int col) { - _host.CheckParam(0 <= col && col < Schema.ColumnCount, nameof(col)); - if (TransposeSchema?.GetSlotType(col) == null) + _host.CheckParam(0 <= col && col < Schema.Count, nameof(col)); + if (_transposeSchema?.GetSlotType(col) == null) { throw _host.ExceptParam(nameof(col), "Bad call to GetSlotCursor on untransposable column '{0}'", - Schema.GetColumnName(col)); + Schema[col].Name); } _host.AssertValue(_tview); return _tview.GetSlotCursor(col); diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs index 830684ae0d..895eb155f7 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs @@ -2,11 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. 
// See the LICENSE file in the project root for more information. -using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Model; using System.IO; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Model; namespace Microsoft.ML.Data { diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeReaderEstimator.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeReaderEstimator.cs index 61f1246563..c4157caab9 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeReaderEstimator.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeReaderEstimator.cs @@ -3,9 +3,8 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// An estimator class for composite data reader. diff --git a/src/Microsoft.ML.Data/DataLoadSave/DataOperations.cs b/src/Microsoft.ML.Data/DataLoadSave/DataOperations.cs index cc0998da03..c07bd3b4de 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/DataOperations.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/DataOperations.cs @@ -3,11 +3,10 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Transforms; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { /// /// A catalog of operations over data that are not transformers or estimators. diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs index 7d20d22761..e31621dfe9 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs @@ -2,12 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Internal.Utilities; using System.Linq; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Represents a chain (potentially empty) of estimators that end with a . diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs index f64ecb4ef2..6e557561b0 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs @@ -2,12 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using System; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML { diff --git a/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs index d94219a453..f1cb7bf3ae 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs @@ -2,109 +2,68 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using System.Collections.Generic; -using System.Linq; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Data.DataLoadSave { - /// /// A fake schema that is manufactured out of a SchemaShape. /// It will pretend that all vector sizes are equal to 10, all key value counts are equal to 10, /// and all values are defaults (for metadata). 
/// - internal sealed class FakeSchema : ISchema + internal static class FakeSchemaFactory { private const int AllVectorSizes = 10; private const int AllKeySizes = 10; - private readonly IHostEnvironment _env; - private readonly SchemaShape _shape; - private readonly Dictionary _colMap; - - public FakeSchema(IHostEnvironment env, SchemaShape inputShape) + public static Schema Create(SchemaShape shape) { - _env = env; - _shape = inputShape; - _colMap = Enumerable.Range(0, _shape.Columns.Length) - .ToDictionary(idx => _shape.Columns[idx].Name, idx => idx); - } + var builder = new SchemaBuilder(); - public int ColumnCount => _shape.Columns.Length; - - public string GetColumnName(int col) - { - _env.Check(0 <= col && col < ColumnCount); - return _shape.Columns[col].Name; + for (int i = 0; i < shape.Count; ++i) + { + var metaBuilder = new MetadataBuilder(); + var partialMetadata = shape[i].Metadata; + for (int j = 0; j < partialMetadata.Count; ++j) + { + var metaColumnType = MakeColumnType(partialMetadata[i]); + Delegate del; + if (metaColumnType.IsVector) + del = Utils.MarshalInvoke(GetDefaultVectorGetter, metaColumnType.ItemType.RawType); + else + del = Utils.MarshalInvoke(GetDefaultGetter, metaColumnType.RawType); + metaBuilder.Add(partialMetadata[j].Name, metaColumnType, del); + } + builder.AddColumn(shape[i].Name, MakeColumnType(shape[i])); + } + return builder.GetSchema(); } - public ColumnType GetColumnType(int col) + private static ColumnType MakeColumnType(SchemaShape.Column column) { - _env.Check(0 <= col && col < ColumnCount); - var inputCol = _shape.Columns[col]; - return MakeColumnType(inputCol); - } - - public bool TryGetColumnIndex(string name, out int col) => _colMap.TryGetValue(name, out col); - - private static ColumnType MakeColumnType(SchemaShape.Column inputCol) - { - ColumnType curType = inputCol.ItemType; - if (inputCol.IsKey) - curType = new KeyType(curType.AsPrimitive.RawKind, 0, AllKeySizes); - if (inputCol.Kind == 
SchemaShape.Column.VectorKind.VariableVector) - curType = new VectorType(curType.AsPrimitive, 0); - else if (inputCol.Kind == SchemaShape.Column.VectorKind.Vector) - curType = new VectorType(curType.AsPrimitive, AllVectorSizes); + ColumnType curType = column.ItemType; + if (column.IsKey) + curType = new KeyType(((PrimitiveType)curType).RawKind, 0, AllKeySizes); + if (column.Kind == SchemaShape.Column.VectorKind.VariableVector) + curType = new VectorType((PrimitiveType)curType, 0); + else if (column.Kind == SchemaShape.Column.VectorKind.Vector) + curType = new VectorType((PrimitiveType)curType, AllVectorSizes); return curType; } - public void GetMetadata(string kind, int col, ref TValue value) + private static Delegate GetDefaultVectorGetter() { - _env.Check(0 <= col && col < ColumnCount); - var inputCol = _shape.Columns[col]; - var metaShape = inputCol.Metadata; - if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn)) - throw _env.ExceptGetMetadata(); - - var colType = MakeColumnType(metaColumn); - _env.Check(colType.RawType.Equals(typeof(TValue))); - - if (colType.IsVector) - { - // This as an atypical use of VBuffer: we create it in GetMetadataVec, and then pass through - // via boxing to be returned out of this method. This is intentional. 
- value = (TValue)Utils.MarshalInvoke(GetMetadataVec, colType.ItemType.RawType); - } - else - value = default; + ValueGetter> getter = (ref VBuffer value) => value = new VBuffer(AllVectorSizes, 0, null, null); + return getter; } - private object GetMetadataVec() => new VBuffer(AllVectorSizes, 0, null, null); - - public ColumnType GetMetadataTypeOrNull(string kind, int col) + private static Delegate GetDefaultGetter() { - _env.Check(0 <= col && col < ColumnCount); - var inputCol = _shape.Columns[col]; - var metaShape = inputCol.Metadata; - if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn)) - return null; - return MakeColumnType(metaColumn); + ValueGetter getter = (ref TValue value) => value = default; + return getter; } - public IEnumerable> GetMetadataTypes(int col) - { - _env.Check(0 <= col && col < ColumnCount); - var inputCol = _shape.Columns[col]; - var metaShape = inputCol.Metadata; - if (metaShape == null) - return Enumerable.Empty>(); - - return metaShape.Columns.Select(c => new KeyValuePair(c.Name, MakeColumnType(c))); - } } } diff --git a/src/Microsoft.ML.Data/DataLoadSave/MultiFileSource.cs b/src/Microsoft.ML.Data/DataLoadSave/MultiFileSource.cs index 74fe0cb544..616587d558 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/MultiFileSource.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/MultiFileSource.cs @@ -4,9 +4,9 @@ using System; using System.IO; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Wraps a potentially compound path as an IMultiStreamSource. @@ -20,7 +20,7 @@ public sealed class MultiFileSource : IMultiStreamSource /// /// Initializes a new instance of . - /// In case of in case of usage from Maml, the paths would be wildcard concatenated in the first string of . + /// In case of usage from Maml, the paths would be wildcard concatenated in the first string of . /// /// The paths of the files to load. 
public MultiFileSource(params string[] paths) diff --git a/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs index 4ccda7c147..e93c5fdd8a 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs @@ -7,15 +7,14 @@ using System.IO; using System.Linq; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.Conversion; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Data.Utilities; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Data.Conversion; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Data.Utilities; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; [assembly: LoadableClass(PartitionedFileLoader.Summary, typeof(PartitionedFileLoader), typeof(PartitionedFileLoader.Arguments), typeof(SignatureDataLoader), PartitionedFileLoader.UserName, PartitionedFileLoader.LoadName, PartitionedFileLoader.ShortName)] @@ -23,7 +22,7 @@ [assembly: LoadableClass(PartitionedFileLoader.Summary, typeof(PartitionedFileLoader), null, typeof(SignatureLoadDataLoader), PartitionedFileLoader.UserName, PartitionedFileLoader.LoadName, PartitionedFileLoader.ShortName)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Loads a set of directory partitioned files into an IDataView. 
@@ -293,16 +292,15 @@ public void Save(ModelSaveContext ctx) return null; } - public IRowCursor GetRowCursor(Func needCol, IRandom rand = null) + public RowCursor GetRowCursor(Func needCol, Random rand = null) { return new Cursor(_host, this, _files, needCol, rand); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func needCol, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func needCol, int n, Random rand = null) { - consolidator = null; var cursor = new Cursor(_host, this, _files, needCol, rand); - return new IRowCursor[] { cursor }; + return new RowCursor[] { cursor }; } /// @@ -329,7 +327,7 @@ private Schema CreateSchema(IExceptionContext ectx, Column[] cols, IDataLoader s } else { - var schemas = new ISchema[] + var schemas = new Schema[] { subSchema, colSchema @@ -362,7 +360,7 @@ private IDataLoader CreateLoaderFromBytes(byte[] loaderBytes, IMultiStreamSource } } - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private PartitionedFileLoader _parent; @@ -372,11 +370,11 @@ private sealed class Cursor : RootCursorBase, IRowCursor private Delegate[] _subGetters; // Cached getters of the sub-cursor. private ReadOnlyMemory[] _colValues; // Column values cached from the file path. - private IRowCursor _subCursor; // Sub cursor of the current file. + private RowCursor _subCursor; // Sub cursor of the current file. 
private IEnumerator _fileOrder; - public Cursor(IChannelProvider provider, PartitionedFileLoader parent, IMultiStreamSource files, Func predicate, IRandom rand) + public Cursor(IChannelProvider provider, PartitionedFileLoader parent, IMultiStreamSource files, Func predicate, Random rand) : base(provider) { Contracts.AssertValue(parent); @@ -397,9 +395,9 @@ public Cursor(IChannelProvider provider, PartitionedFileLoader parent, IMultiStr public override long Batch => 0; - public Schema Schema => _parent.Schema; + public override Schema Schema => _parent.Schema; - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); @@ -412,18 +410,18 @@ public ValueGetter GetGetter(int col) return getter; } - public override ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { return - (ref UInt128 val) => + (ref RowId val) => { Ch.Check(IsGood, "Cannot call ID getter in current state"); - val = new UInt128(0, (ulong)Position); + val = new RowId(0, (ulong)Position); }; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < Schema.Count); return _active[col]; @@ -526,7 +524,7 @@ private void UpdateSubGetters() { if (_subActive[i]) { - var type = _subCursor.Schema.GetColumnType(i); + var type = _subCursor.Schema[i].Type; _subGetters[i] = MarshalGetter(_subCursor.GetGetter, type.RawType, i); } } @@ -561,7 +559,7 @@ private Delegate[] CreateGetters() continue; } - var type = Schema.GetColumnType(i); + var type = Schema[i].Type; // Use sub-cursor for all sub-columns. 
if (IsSubColumn(i)) @@ -624,7 +622,7 @@ private bool IsSubColumn(int col) private int SubColumnCount => Schema.Count - _parent._srcDirIndex.Length; - private IEnumerable CreateFileOrder(IRandom rand) + private IEnumerable CreateFileOrder(Random rand) { if (rand == null) { @@ -636,18 +634,18 @@ private IEnumerable CreateFileOrder(IRandom rand) } } - private bool SchemasMatch(ISchema schema1, ISchema schema2) + private bool SchemasMatch(Schema schema1, Schema schema2) { - if (schema1.ColumnCount != schema2.ColumnCount) + if (schema1.Count != schema2.Count) { return false; } - int colLim = schema1.ColumnCount; + int colLim = schema1.Count; for (int col = 0; col < colLim; col++) { - var type1 = schema1.GetColumnType(col); - var type2 = schema2.GetColumnType(col); + var type1 = schema1[col].Type; + var type2 = schema2[col].Type; if (!type1.Equals(type2)) { return false; diff --git a/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs b/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs index 33f0a5b5f4..5b7f3db452 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs @@ -7,12 +7,12 @@ using System.Linq; using System.Text; using System.Web; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.Utilities; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Data.Utilities; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Model; [assembly: LoadableClass(SimplePartitionedPathParser.Summary, typeof(SimplePartitionedPathParser), typeof(SimplePartitionedPathParser.Arguments), typeof(PartitionedPathParser), SimplePartitionedPathParser.UserName, SimplePartitionedPathParser.LoadName, SimplePartitionedPathParser.ShortName)] @@ -28,7 +28,7 @@ [assembly: 
EntryPointModule(typeof(SimplePartitionedPathParser.Arguments))] [assembly: EntryPointModule(typeof(ParquetPartitionedPathParserFactory))] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Delegate signature for a partitioned path parser. diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/LoadColumnAttribute.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/LoadColumnAttribute.cs new file mode 100644 index 0000000000..8aefb527ea --- /dev/null +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/LoadColumnAttribute.cs @@ -0,0 +1,72 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace Microsoft.ML.Data +{ + // REVIEW: The Start field is decorated with [Obsolete], and this warning disables using Obsolete for this class. + // The Start field should get deleted together with the Legacy API. +#pragma warning disable 618 + /// + /// Describes column information such as name and the source columns indices that this + /// column encapsulates. + /// + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] + public sealed class LoadColumnAttribute : Attribute + { + /// + /// Initializes new instance of . + /// + /// The index of the column in the text file. + // REVIEW: Remove calling the private constructor with just the start parameter, + // when the Legacy API's TextLoader gets deleted, and with it the Start field here. + public LoadColumnAttribute(int columnIndex) + : this(columnIndex.ToString()) + { + Sources.Add(new TextLoader.Range(columnIndex)); + } + + /// + /// Initializes new instance of . + /// + /// The starting column index, for the range. + /// The ending column index, for the range. 
+ // REVIEW: Calling the private constructor with just the start parameter, is incorrect, + // but it is just temporary there, until the Legacy API's TextLoader gets deleted, together with the Start field. + public LoadColumnAttribute(int start, int end) + : this(start.ToString()) + { + Sources.Add(new TextLoader.Range(start, end)); + } + + /// + /// Initializes new instance of . + /// + /// Distinct text file column indices to load as part of this column. + // REVIEW: Calling the private constructor with just the columnIndexes[0] parameter, is incorrect, + // but it is just temporary there, until the Legacy API's TextLoader gets deleted together with the Start field. + public LoadColumnAttribute(int[] columnIndexes) + : this(columnIndexes[0].ToString()) // REVIEW: this is incorrect, but it is just temporary there, until the Legacy API's TextLoader gets deleted. + { + foreach (var col in columnIndexes) + Sources.Add(new TextLoader.Range(col)); + } + + [Obsolete("Should be deleted together with the Legacy project.")] + private LoadColumnAttribute(string start) + { + Sources = new List(); + Start = start; + } + + internal List Sources; + + [Obsolete("Should be deleted together with the Legacy project.")] + [BestFriend] + internal string Start { get; } + } +#pragma warning restore 618 +} diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index a2fa37d08f..9619e4d0d5 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -2,17 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; using System; using System.Collections.Generic; using System.Linq; +using System.Reflection; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; using Float = System.Single; [assembly: LoadableClass(TextLoader.Summary, typeof(IDataLoader), typeof(TextLoader), typeof(TextLoader.Arguments), typeof(SignatureDataLoader), @@ -21,11 +21,10 @@ [assembly: LoadableClass(TextLoader.Summary, typeof(IDataLoader), typeof(TextLoader), null, typeof(SignatureLoadDataLoader), "Text Loader", TextLoader.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Loads a text file into an IDataView. Supports basic mapping from input columns to IDataView columns. - /// Should accept any file that TlcTextInstances accepts. /// public sealed partial class TextLoader : IDataReader, ICanSaveModel { @@ -344,10 +343,10 @@ public class ArgumentsCore " missing value and an empty value is denoted by \"\". When false, consecutive separators" + " denote an empty value.", ShortName = "quote")] - public bool AllowQuoting = true; + public bool AllowQuoting = DefaultArguments.AllowQuoting; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether the input may include sparse representations", ShortName = "sparse")] - public bool AllowSparse = true; + public bool AllowSparse = DefaultArguments.AllowSparse; [Argument(ArgumentType.AtMostOnce, HelpText = "Number of source columns in the text data. Default is that sparse rows contain their size information.", @@ -355,17 +354,17 @@ public class ArgumentsCore public int? 
InputSize; [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Source column separator. Options: tab, space, comma, single character", ShortName = "sep")] - public string Separator = "tab"; + public string Separator = DefaultArguments.Separator.ToString(); [Argument(ArgumentType.AtMostOnce, Name = nameof(Separator), Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Source column separator.", ShortName = "sep")] - public char[] SeparatorChars = new[] { '\t' }; + public char[] SeparatorChars = new[] { DefaultArguments.Separator }; [Argument(ArgumentType.Multiple, HelpText = "Column groups. Each group is specified as name:type:numeric-ranges, eg, col=Features:R4:1-17,26,35-40", ShortName = "col", SortOrder = 1)] public Column[] Column; [Argument(ArgumentType.AtMostOnce, HelpText = "Remove trailing whitespace from lines", ShortName = "trim")] - public bool TrimWhitespace; + public bool TrimWhitespace = DefaultArguments.TrimWhitespace; [Argument(ArgumentType.AtMostOnce, ShortName = "header", HelpText = "Data file has header with feature names. Header is read only if options 'hs' and 'hf' are not specified.")] @@ -393,6 +392,15 @@ public sealed class Arguments : ArgumentsCore public long? MaxRows; } + internal static class DefaultArguments + { + internal const bool AllowQuoting = true; + internal const bool AllowSparse = true; + internal const char Separator = '\t'; + internal const bool HasHeader = false; + internal const bool TrimWhitespace = false; + } + /// /// Used as an input column range. /// A variable length segment (extending to the end of the input line) is represented by Lim == SrcLim. 
@@ -841,9 +849,8 @@ public void Save(ModelSaveContext ctx) Contracts.Assert((DataKind)(byte)type.RawKind == type.RawKind); ctx.Writer.Write((byte)type.RawKind); ctx.Writer.WriteBoolByte(type.IsKey); - if (type.IsKey) + if (type is KeyType key) { - var key = type.AsKey; ctx.Writer.WriteBoolByte(key.Contiguous); ctx.Writer.Write(key.Min); ctx.Writer.Write(key.Count); @@ -1008,23 +1015,38 @@ private bool HasHeader private readonly IHost _host; private const string RegistrationName = "TextLoader"; - public TextLoader(IHostEnvironment env, Column[] columns, Action advancedSettings, IMultiStreamSource dataSample = null) - : this(env, MakeArgs(columns, advancedSettings), dataSample) + /// + /// Loads a text file into an . Supports basic mapping from input columns to IDataView columns. + /// + /// The environment to use. + /// Defines a mapping between input columns in the file and IDataView columns. + /// Whether the file has a header. + /// The character used as separator between data points in a row. By default the tab character is used as separator. + /// Allows to expose items that can be used for reading. + public TextLoader(IHostEnvironment env, Column[] columns, bool hasHeader = false, char separatorChar = '\t', IMultiStreamSource dataSample = null) + : this(env, MakeArgs(columns, hasHeader, new[] { separatorChar }), dataSample) { } - private static Arguments MakeArgs(Column[] columns, Action advancedSettings) + private static Arguments MakeArgs(Column[] columns, bool hasHeader, char[] separatorChars) { - var result = new Arguments { Column = columns }; - advancedSettings?.Invoke(result); + Contracts.AssertValue(separatorChars); + var result = new Arguments { Column = columns, HasHeader = hasHeader, SeparatorChars = separatorChars}; return result; } - public TextLoader(IHostEnvironment env, Arguments args, IMultiStreamSource dataSample = null) + /// + /// Loads a text file into an . Supports basic mapping from input columns to IDataView columns. 
+ /// + /// The environment to use. + /// Defines the settings of the load operation. + /// Allows to expose items that can be used for reading. + public TextLoader(IHostEnvironment env, Arguments args = null, IMultiStreamSource dataSample = null) { + args = args ?? new Arguments(); + Contracts.CheckValue(env, nameof(env)); _host = env.Register(RegistrationName); - _host.CheckValue(args, nameof(args)); _host.CheckValueOrNull(dataSample); @@ -1285,7 +1307,7 @@ private TextLoader(IHost host, ModelLoadContext ctx) _parser = new Parser(this); } - public static TextLoader Create(IHostEnvironment env, ModelLoadContext ctx) + internal static TextLoader Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); IHost h = env.Register(RegistrationName); @@ -1297,15 +1319,15 @@ public static TextLoader Create(IHostEnvironment env, ModelLoadContext ctx) } // These are legacy constructors needed for ComponentCatalog. - public static IDataLoader Create(IHostEnvironment env, ModelLoadContext ctx, IMultiStreamSource files) + internal static IDataLoader Create(IHostEnvironment env, ModelLoadContext ctx, IMultiStreamSource files) => (IDataLoader)Create(env, ctx).Read(files); - public static IDataLoader Create(IHostEnvironment env, Arguments args, IMultiStreamSource files) + internal static IDataLoader Create(IHostEnvironment env, Arguments args, IMultiStreamSource files) => (IDataLoader)new TextLoader(env, args, files).Read(files); /// /// Convenience method to create a and use it to read a specified file. 
/// - public static IDataView ReadFile(IHostEnvironment env, Arguments args, IMultiStreamSource fileSource) + internal static IDataView ReadFile(IHostEnvironment env, Arguments args, IMultiStreamSource fileSource) => new TextLoader(env, args, fileSource).Read(fileSource); public void Save(ModelSaveContext ctx) @@ -1339,6 +1361,75 @@ public void Save(ModelSaveContext ctx) public IDataView Read(string path) => Read(new MultiFileSource(path)); + internal static TextLoader CreateTextReader(IHostEnvironment host, + bool hasHeader = DefaultArguments.HasHeader, + char separator = DefaultArguments.Separator, + bool allowQuotedStrings = DefaultArguments.AllowQuoting, + bool supportSparse = DefaultArguments.AllowSparse, + bool trimWhitespace = DefaultArguments.TrimWhitespace) + { + var userType = typeof(TInput); + + var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance); + + var propertyInfos = + userType + .GetProperties(BindingFlags.Public | BindingFlags.Instance) + .Where(x => x.CanRead && x.CanWrite && x.GetGetMethod() != null && x.GetSetMethod() != null && x.GetIndexParameters().Length == 0); + + var memberInfos = (fieldInfos as IEnumerable).Concat(propertyInfos).ToArray(); + + var columns = new List(); + + for (int index = 0; index < memberInfos.Length; index++) + { + var memberInfo = memberInfos[index]; + var mappingAttr = memberInfo.GetCustomAttribute(); + + host.Assert(mappingAttr != null, $"Field or property {memberInfo.Name} is missing the {nameof(LoadColumnAttribute)} attribute"); + + var mappingAttrName = memberInfo.GetCustomAttribute(); + + var column = new Column(); + column.Name = mappingAttrName?.Name ?? memberInfo.Name; + column.Source = mappingAttr.Sources.ToArray(); + DataKind dk; + switch (memberInfo) + { + case FieldInfo field: + if (!DataKindExtensions.TryGetDataKind(field.FieldType.IsArray ? 
field.FieldType.GetElementType() : field.FieldType, out dk)) + throw Contracts.Except($"Field {memberInfo.Name} is of unsupported type."); + + break; + + case PropertyInfo property: + if (!DataKindExtensions.TryGetDataKind(property.PropertyType.IsArray ? property.PropertyType.GetElementType() : property.PropertyType, out dk)) + throw Contracts.Except($"Property {memberInfo.Name} is of unsupported type."); + break; + + default: + Contracts.Assert(false); + throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo"); + } + + column.Type = dk; + + columns.Add(column); + } + + Arguments args = new Arguments + { + HasHeader = hasHeader, + SeparatorChars = new[] { separator }, + AllowQuoting = allowQuotedStrings, + AllowSparse = supportSparse, + TrimWhitespace = trimWhitespace, + Column = columns.ToArray() + }; + + return new TextLoader(host, args); + } + private sealed class BoundLoader : IDataLoader { private readonly TextLoader _reader; @@ -1364,7 +1455,7 @@ public BoundLoader(TextLoader reader, IMultiStreamSource files) public Schema Schema => _reader._bindings.AsSchema; - public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); @@ -1372,13 +1463,12 @@ public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) return Cursor.Create(_reader, _files, active); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); var active = Utils.BuildArray(_reader._bindings.ColumnCount, predicate); - return Cursor.CreateSet(out consolidator, _reader, _files, active, n); + return Cursor.CreateSet(_reader, _files, active, n); } public void Save(ModelSaveContext ctx) => 
_reader.Save(ctx); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs index ac70417f6b..a786865063 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs @@ -2,19 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.Text; using System.Threading; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public sealed partial class TextLoader { - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { // Lines are divided into batches and processed a batch at a time. This enables // parallel parsing. @@ -133,7 +132,7 @@ private Cursor(TextLoader parent, ParseStats stats, bool[] active, LineReader re } } - public static IRowCursor Create(TextLoader parent, IMultiStreamSource files, bool[] active) + public static RowCursor Create(TextLoader parent, IMultiStreamSource files, bool[] active) { // Note that files is allowed to be empty. Contracts.AssertValue(parent); @@ -150,8 +149,7 @@ public static IRowCursor Create(TextLoader parent, IMultiStreamSource files, boo return new Cursor(parent, stats, active, reader, srcNeeded, cthd); } - public static IRowCursor[] CreateSet(out IRowCursorConsolidator consolidator, - TextLoader parent, IMultiStreamSource files, bool[] active, int n) + public static RowCursor[] CreateSet(TextLoader parent, IMultiStreamSource files, bool[] active, int n) { // Note that files is allowed to be empty. 
Contracts.AssertValue(parent); @@ -166,13 +164,9 @@ public static IRowCursor[] CreateSet(out IRowCursorConsolidator consolidator, var reader = new LineReader(files, BatchSize, 100, parent.HasHeader, parent._maxRows, cthd); var stats = new ParseStats(parent._host, cthd); if (cthd <= 1) - { - consolidator = null; - return new IRowCursor[1] { new Cursor(parent, stats, active, reader, srcNeeded, 1) }; - } + return new RowCursor[1] { new Cursor(parent, stats, active, reader, srcNeeded, 1) }; - consolidator = new Consolidator(cthd); - var cursors = new IRowCursor[cthd]; + var cursors = new RowCursor[cthd]; try { for (int i = 0; i < cursors.Length; i++) @@ -199,13 +193,13 @@ public static IRowCursor[] CreateSet(out IRowCursorConsolidator consolidator, } } - public override ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { return - (ref UInt128 val) => + (ref RowId val) => { Ch.Check(IsGood, "Cannot call ID getter in current state"); - val = new UInt128((ulong)_total, 0); + val = new RowId((ulong)_total, 0); }; } @@ -273,18 +267,20 @@ public static string GetEmbeddedArgs(IMultiStreamSource files) return sb.ToString(); } - public Schema Schema => _bindings.AsSchema; + public override Schema Schema => _bindings.AsSchema; - public override void Dispose() + protected override void Dispose(bool disposing) { if (_disposed) return; - + if (disposing) + { + _ator.Dispose(); + _reader.Release(); + _stats.Release(); + } _disposed = true; - _ator.Dispose(); - _reader.Release(); - _stats.Release(); - base.Dispose(); + base.Dispose(disposing); } protected override bool MoveNextCore() @@ -301,13 +297,13 @@ protected override bool MoveNextCore() return false; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < _bindings.Infos.Length); return _active == null || _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); var fn 
= _getters[col] as ValueGetter; @@ -819,37 +815,6 @@ private void Parse(int tid) } } } - - /// - /// The consolidator object. This simply records the number of threads and checks - /// that they match at the end. - /// - private sealed class Consolidator : IRowCursorConsolidator - { - private int _cthd; - - public Consolidator(int cthd) - { - Contracts.Assert(cthd > 1); - _cthd = cthd; - } - - public IRowCursor CreateCursor(IChannelProvider provider, IRowCursor[] inputs) - { - Contracts.AssertValue(provider); - int cthd = Interlocked.Exchange(ref _cthd, 0); - provider.Check(cthd > 1, "Consolidator can only be used once"); - provider.Check(Utils.Size(inputs) == cthd, "Unexpected number of cursors"); - - // ConsolidateGeneric does all the standard validity checks: all cursors non-null, - // all have the same schema, all have the same active columns, and all active - // column types are cachable. - using (var ch = provider.Start("Consolidator")) - { - return DataViewUtils.ConsolidateGeneric(provider, inputs, BatchSize); - } - } - } } } } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs index cdcb507f51..5b71bcc95e 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs @@ -4,8 +4,6 @@ #pragma warning disable 420 // volatile with Interlocked.CompareExchange -using Float = System.Single; - using System; using System.Collections.Generic; using System.Linq; @@ -13,10 +11,10 @@ using System.Runtime.CompilerServices; using System.Text; using System.Threading; -using Microsoft.ML.Runtime.Data.Conversion; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Data.Conversion; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { using Conditional = System.Diagnostics.ConditionalAttribute; @@ -663,12 +661,14 @@ public Parser(TextLoader parent) { 
var info = _infos[i]; - if (info.ColType.ItemType.IsKey) + if (info.ColType is KeyType keyType) { - if (!info.ColType.IsVector) - _creator[i] = cache.GetCreatorOne(info.ColType.AsKey); - else - _creator[i] = cache.GetCreatorVec(info.ColType.ItemType.AsKey); + _creator[i] = cache.GetCreatorOne(keyType); + continue; + } + else if (info.ColType is VectorType vectorType && vectorType.ItemType is KeyType vectorKeyType) + { + _creator[i] = cache.GetCreatorVec(vectorKeyType); continue; } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs index e5de3573ee..882c6d5998 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs @@ -1,54 +1,78 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.StaticPipe; -using System; -using System.Collections.Generic; using System.IO; -using System.Linq; -using System.Text; -using static Microsoft.ML.Runtime.Data.TextLoader; +using Microsoft.ML.Data; +using Microsoft.ML.Data.IO; namespace Microsoft.ML { public static class TextLoaderSaverCatalog { /// - /// Create a text reader. + /// Create a text reader . /// - /// The catalog. - /// The arguments to text reader, describing the data schema. + /// The catalog. + /// The columns of the schema. + /// Whether the file has a header. + /// The character used as separator between data points in a row. By default the tab character is used as separator. /// The optional location of a data sample. 
- public static TextLoader TextReader(this DataOperations catalog, - TextLoader.Arguments args, IMultiStreamSource dataSample = null) + public static TextLoader CreateTextReader(this DataOperations catalog, + TextLoader.Column[] columns, + bool hasHeader = TextLoader.DefaultArguments.HasHeader, + char separatorChar = TextLoader.DefaultArguments.Separator, + IMultiStreamSource dataSample = null) + => new TextLoader(CatalogUtils.GetEnvironment(catalog), columns, hasHeader, separatorChar, dataSample); + + /// + /// Create a text reader . + /// + /// The catalog. + /// Defines the settings of the load operation. + /// Allows to expose items that can be used for reading. + public static TextLoader CreateTextReader(this DataOperations catalog, + TextLoader.Arguments args, + IMultiStreamSource dataSample = null) => new TextLoader(CatalogUtils.GetEnvironment(catalog), args, dataSample); /// - /// Create a text reader. + /// Create a text reader by inferring the dataset schema from a data model type. /// - /// The catalog. - /// The columns of the schema. - /// The delegate to set additional settings. - /// The optional location of a data sample. - public static TextLoader TextReader(this DataOperations catalog, - TextLoader.Column[] columns, Action advancedSettings = null, IMultiStreamSource dataSample = null) - => new TextLoader(CatalogUtils.GetEnvironment(catalog), columns, advancedSettings, dataSample); + /// The catalog. + /// Does the file contain a header? + /// Column separator character. Default is '\t'. + /// Whether the input may include quoted values, + /// which can contain separator characters, colons, + /// and distinguish empty values from missing values. When true, consecutive separators + /// denote a missing value and an empty value is denoted by \"\". + /// When false, consecutive separators denote an empty value.
+ /// Whether the input may include sparse representations; for example, + /// if one of the rows contains "5 2:6 4:3", that means there are 5 columns, all zero + /// except for the 3rd and 5th columns, which have values 6 and 3 + /// Remove trailing whitespace from lines + public static TextLoader CreateTextReader(this DataOperations catalog, + bool hasHeader = TextLoader.DefaultArguments.HasHeader, + char separatorChar = TextLoader.DefaultArguments.Separator, + bool allowQuotedStrings = TextLoader.DefaultArguments.AllowQuoting, + bool supportSparse = TextLoader.DefaultArguments.AllowSparse, + bool trimWhitespace = TextLoader.DefaultArguments.TrimWhitespace) + => TextLoader.CreateTextReader(CatalogUtils.GetEnvironment(catalog), hasHeader, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace); /// /// Read a data view from a text file using . /// - /// The catalog. + /// The catalog. /// The columns of the schema. - /// The delegate to set additional settings - /// The path to the file + /// Whether the file has a header. + /// The character used as separator between data points in a row. By default the tab character is used as separator. + /// The path to the file. /// The data view. public static IDataView ReadFromTextFile(this DataOperations catalog, - TextLoader.Column[] columns, string path, Action advancedSettings = null) + string path, + TextLoader.Column[] columns, + bool hasHeader = TextLoader.DefaultArguments.HasHeader, + char separatorChar = TextLoader.DefaultArguments.Separator) { Contracts.CheckNonEmpty(path, nameof(path)); @@ -56,29 +80,83 @@ public static IDataView ReadFromTextFile(this DataOperations catalog, // REVIEW: it is almost always a mistake to have a 'trainable' text loader here. // Therefore, we are going to disallow data sample.
- var reader = new TextLoader(env, columns, advancedSettings, dataSample: null); + var reader = new TextLoader(env, columns, hasHeader, separatorChar, dataSample: null); return reader.Read(new MultiFileSource(path)); } + /// + /// Read a data view from a text file using . + /// + /// The catalog. + /// Does the file contain a header? + /// Column separator character. Default is '\t'. + /// Whether the input may include quoted values, + /// which can contain separator characters, colons, + /// and distinguish empty values from missing values. When true, consecutive separators + /// denote a missing value and an empty value is denoted by \"\". + /// When false, consecutive separators denote an empty value. + /// Whether the input may include sparse representations; for example, + /// if one of the rows contains "5 2:6 4:3", that means there are 5 columns, all zero + /// except for the 3rd and 5th columns, which have values 6 and 3 + /// Remove trailing whitespace from lines + /// The path to the file. + /// The data view. + public static IDataView ReadFromTextFile(this DataOperations catalog, + string path, + bool hasHeader = TextLoader.DefaultArguments.HasHeader, + char separatorChar = TextLoader.DefaultArguments.Separator, + bool allowQuotedStrings = TextLoader.DefaultArguments.AllowQuoting, + bool supportSparse = TextLoader.DefaultArguments.AllowSparse, + bool trimWhitespace = TextLoader.DefaultArguments.TrimWhitespace) + { + Contracts.CheckNonEmpty(path, nameof(path)); + + // REVIEW: it is almost always a mistake to have a 'trainable' text loader here. + // Therefore, we are going to disallow data sample. + return TextLoader.CreateTextReader(CatalogUtils.GetEnvironment(catalog), hasHeader, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace) + .Read(new MultiFileSource(path)); + } + + /// + /// Read a data view from a text file using . + /// + /// The catalog. + /// Specifies a file from which to read. + /// Defines the settings of the load operation.
+ public static IDataView ReadFromTextFile(this DataOperations catalog, string path, TextLoader.Arguments args = null) + { + Contracts.CheckNonEmpty(path, nameof(path)); + + var env = catalog.GetEnvironment(); + var source = new MultiFileSource(path); + + return new TextLoader(env, args, source).Read(source); + } + /// /// Save the data view as text. /// - /// The catalog. + /// The catalog. /// The data view to save. /// The stream to write to. - /// The column separator. + /// The column separator. /// Whether to write the header row. /// Whether to write the header comment with the schema. /// Whether to keep hidden columns in the dataset. - public static void SaveAsText(this DataOperations catalog, IDataView data, Stream stream, - char separator = '\t', bool headerRow = true, bool schema = true, bool keepHidden = false) + public static void SaveAsText(this DataOperations catalog, + IDataView data, + Stream stream, + char separatorChar = TextLoader.DefaultArguments.Separator, + bool headerRow = TextLoader.DefaultArguments.HasHeader, + bool schema = true, + bool keepHidden = false) { Contracts.CheckValue(catalog, nameof(catalog)); Contracts.CheckValue(data, nameof(data)); Contracts.CheckValue(stream, nameof(stream)); var env = catalog.GetEnvironment(); - var saver = new TextSaver(env, new TextSaver.Arguments { Separator = separator.ToString(), OutputHeader = headerRow, OutputSchema = schema }); + var saver = new TextSaver(env, new TextSaver.Arguments { Separator = separatorChar.ToString(), OutputHeader = headerRow, OutputSchema = schema }); using (var ch = env.Start("Saving data")) DataSaverUtils.SaveDataView(ch, saver, data, stream, keepHidden); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderStatic.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderStatic.cs index c70a1158d9..708cd30647 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderStatic.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderStatic.cs @@ -8,7 +8,7 @@ 
using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public sealed partial class TextLoader { diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs index 9474605079..e03816c715 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs @@ -5,18 +5,17 @@ using System; using System.IO; using System.Text; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.Conversion; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Data.Conversion; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Internal.Utilities; [assembly: LoadableClass(TextSaver.Summary, typeof(TextSaver), typeof(TextSaver.Arguments), typeof(SignatureDataSaver), "Text Saver", "TextSaver", "Text", DocName = "saver/TextSaver.md")] -namespace Microsoft.ML.Runtime.Data.IO +namespace Microsoft.ML.Data.IO { public sealed class TextSaver : IDataSaver { @@ -47,11 +46,11 @@ private abstract class ValueWriter { public readonly int Source; - public static ValueWriter Create(IRowCursor cursor, int col, char sep) + public static ValueWriter Create(RowCursor cursor, int col, char sep) { Contracts.AssertValue(cursor); - ColumnType type = cursor.Schema.GetColumnType(col); + ColumnType type = cursor.Schema[col].Type; Type writePipeType; if (type.IsVector) writePipeType = typeof(VecValueWriter<>).MakeGenericType(type.ItemType.RawType); @@ -148,16 +147,16 @@ private sealed class VecValueWriter : ValueWriterBase private readonly VBuffer> _slotNames; private readonly int _slotCount; - public VecValueWriter(IRowCursor cursor, VectorType type, int 
source, char sep) + public VecValueWriter(RowCursor cursor, VectorType type, int source, char sep) : base(type.ItemType, source, sep) { _getSrc = cursor.GetGetter>(source); ColumnType typeNames; if (type.IsKnownSizeVector && - (typeNames = cursor.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, source)) != null && + (typeNames = cursor.Schema[source].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type) != null && typeNames.VectorSize == type.VectorSize && typeNames.ItemType.IsText) { - cursor.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, source, ref _slotNames); + cursor.Schema[source].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref _slotNames); Contracts.Check(_slotNames.Length == typeNames.VectorSize, "Unexpected slot names length"); } _slotCount = type.VectorSize; @@ -213,11 +212,11 @@ private sealed class ValueWriter : ValueWriterBase private T _src; private string _columnName; - public ValueWriter(IRowCursor cursor, PrimitiveType type, int source, char sep) + public ValueWriter(RowCursor cursor, PrimitiveType type, int source, char sep) : base(type, source, sep) { _getSrc = cursor.GetGetter(source); - _columnName = cursor.Schema.GetColumnName(source); + _columnName = cursor.Schema[source].Name; } public override void WriteData(Action appendItem, out int length) @@ -384,11 +383,11 @@ private void WriteDataCore(IChannel ch, TextWriter writer, IDataView data, ch.AssertNonEmpty(cols); // Determine the active columns and whether there is header information. 
- bool[] active = new bool[data.Schema.ColumnCount]; + bool[] active = new bool[data.Schema.Count]; for (int i = 0; i < cols.Length; i++) { ch.Check(0 <= cols[i] && cols[i] < active.Length); - ch.Check(data.Schema.GetColumnType(cols[i]).ItemType.RawKind != 0); + ch.Check(data.Schema[cols[i]].Type.ItemType.RawKind != 0); active[cols[i]] = true; } @@ -399,7 +398,7 @@ private void WriteDataCore(IChannel ch, TextWriter writer, IDataView data, { if (hasHeader) continue; - var type = data.Schema.GetColumnType(cols[i]); + var type = data.Schema[cols[i]].Type; if (!type.IsVector) { hasHeader = true; @@ -407,7 +406,7 @@ private void WriteDataCore(IChannel ch, TextWriter writer, IDataView data, } if (!type.IsKnownSizeVector) continue; - var typeNames = data.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, cols[i]); + var typeNames = data.Schema[cols[i]].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type; if (typeNames != null && typeNames.VectorSize == type.VectorSize && typeNames.ItemType.IsText) hasHeader = true; } @@ -449,7 +448,7 @@ private void WriteSchemaAsComment(TextWriter writer, string str) writer.WriteLine("#@ }"); } - private string CreateLoaderArguments(ISchema schema, ValueWriter[] pipes, bool hasHeader, IChannel ch) + private string CreateLoaderArguments(Schema schema, ValueWriter[] pipes, bool hasHeader, IChannel ch) { StringBuilder sb = new StringBuilder(); if (hasHeader) @@ -462,8 +461,8 @@ private string CreateLoaderArguments(ISchema schema, ValueWriter[] pipes, bool h for (int i = 0; i < pipes.Length; i++) { int src = pipes[i].Source; - string name = schema.GetColumnName(src); - var type = schema.GetColumnType(src); + string name = schema[src].Name; + var type = schema[src].Type; var column = GetColumn(name, type, index); sb.Append(" col="); @@ -488,9 +487,8 @@ private TextLoader.Column GetColumn(string name, ColumnType type, int? start) { DataKind? 
kind; KeyRange keyRange = null; - if (type.ItemType.IsKey) + if (type.ItemType is KeyType key) { - var key = type.ItemType.AsKey; if (!key.Contiguous) keyRange = new KeyRange(key.Min, contiguous: false); else if (key.Count == 0) @@ -573,7 +571,7 @@ public State(TextSaver parent, TextWriter writer, ValueWriter[] pipes, bool hasH _mpslotichLim = new int[128]; } - public void Run(IRowCursor cursor, ref long count, out int minLen, out int maxLen) + public void Run(RowCursor cursor, ref long count, out int minLen, out int maxLen) { minLen = int.MaxValue; maxLen = 0; diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs index 863594ce5a..ab6cc7dad6 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs @@ -2,19 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Collections.Generic; +using Microsoft.ML; using Microsoft.ML.Core.Data; -using Microsoft.ML.Data.DataLoadSave; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Model; -using System.Collections.Generic; +using Microsoft.ML.Data.DataLoadSave; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Model; [assembly: LoadableClass(typeof(TransformWrapper), null, typeof(SignatureLoadModel), "Transform wrapper", TransformWrapper.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { // REVIEW: this class is public, as long as the Wrappers.cs in tests still rely on it. // It needs to become internal. 
@@ -160,7 +159,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var fakeSchema = Schema.Create(new FakeSchema(Host, inputSchema)); + var fakeSchema = FakeSchemaFactory.Create(inputSchema); var transformer = Fit(new EmptyDataView(Host, fakeSchema)); return SchemaShape.Create(transformer.GetOutputSchema(fakeSchema)); } @@ -179,7 +178,7 @@ protected TrivialWrapperEstimator(IHost host, TransformWrapper transformer) public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var fakeSchema = Schema.Create(new FakeSchema(Host, inputSchema)); + var fakeSchema = FakeSchemaFactory.Create(inputSchema); return SchemaShape.Create(Transformer.GetOutputSchema(fakeSchema)); } } diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs index 2f355a04a1..6a2d907b91 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs @@ -2,17 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; using System; using System.Collections; using System.Collections.Generic; using System.IO; using System.Linq; +using Microsoft.ML; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; [assembly: LoadableClass(typeof(TransformerChain), typeof(TransformerChain), null, typeof(SignatureLoadModel), "Transformer chain", TransformerChain.LoaderSignature)] @@ -227,7 +226,7 @@ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) for (int i = 0; i < mappers.Length; ++i) { mappers[i] = _transformers[i].GetRowToRowMapper(schema); - schema = mappers[i].Schema; + schema = mappers[i].OutputSchema; } return new CompositeRowToRowMapper(inputSchema, mappers); } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs index 6490b91483..eb39bb0e5c 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs @@ -9,13 +9,12 @@ using System.Reflection; using System.Runtime.InteropServices; using System.Threading; +using Microsoft.ML; +using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; [assembly: LoadableClass(TransposeLoader.Summary, typeof(TransposeLoader), typeof(TransposeLoader.Arguments), typeof(SignatureDataLoader), "Transpose Loader", TransposeLoader.LoadName, "Transpose", "trans")] @@ -23,7 +22,7 @@ [assembly: 
LoadableClass(TransposeLoader.Summary, typeof(TransposeLoader), null, typeof(SignatureLoadDataLoader), "Transpose Data View Loader", TransposeLoader.LoadName)] -namespace Microsoft.ML.Runtime.Data.IO +namespace Microsoft.ML.Data.IO { /// /// The transposed loader reads the transposed binary format. This binary format, at a high level, is nothing more @@ -251,7 +250,7 @@ protected override void VerifyView(IDataView view) Host.CheckDecode(rowCount == 0 || _parent._header.RowCount == rowCount); var schema = view.Schema; - Host.CheckDecode(schema.ColumnCount == _parent._header.ColumnCount); + Host.CheckDecode(schema.Count == _parent._header.ColumnCount); } } @@ -270,7 +269,7 @@ public TransposedSubIdv(TransposeLoader parent, BinaryReader reader, int col) { // The correctness of this relies upon the schema entry being read first. Host.AssertValue(parent._schemaEntry); - Host.Assert(0 <= col && col < parent.Schema.ColumnCount); + Host.Assert(0 <= col && col < parent.Schema.Count); _col = col; // Either we have to have data, or the parent has to have explicit row data. @@ -293,8 +292,8 @@ protected override void VerifyView(IDataView view) Host.AssertValue(view); // This must have precisely one column, of type vector. var schema = view.Schema; - Host.CheckDecode(schema.ColumnCount == 1); - var ttype = schema.GetColumnType(0); + Host.CheckDecode(schema.Count == 1); + var ttype = schema[0].Type; Host.CheckDecode(ttype.IsVector); // We have no way to encode a type of zero length vectors per se in the case // when there are no rows in the original dataset, but accept that if the vector @@ -307,7 +306,7 @@ protected override void VerifyView(IDataView view) long rowCount = rowCountNull.Value; // There must be one "row" per "slot" on the column this is a transpose of. // Check that. - var type = _parent.Schema.GetColumnType(_col); + var type = _parent.Schema[_col].Type; Host.CheckDecode(type.ValueCount == rowCount); // The item types should be the same. 
Host.CheckDecode(type.ItemType.Equals(ttype.ItemType)); @@ -336,7 +335,7 @@ protected override void VerifyView(IDataView view) private readonly object _colTransposersLock; /// - /// Lower inclusive bound of versions this reader can read. + /// Low inclusive bound of versions this reader can read. /// private const ulong ReaderFirstVersion = 0x0001000100010001; @@ -367,7 +366,7 @@ private static VersionInfo GetVersionInfo() // to use the cursors from the schema view if convenient to do so. public Schema Schema { get { return _schemaEntry.GetView().Schema; } } - public ITransposeSchema TransposeSchema { get { return _schema; } } + ITransposeSchema ITransposeDataView.TransposeSchema { get { return _schema; } } /// /// Whether the master schema sub-IDV has the actual data. @@ -470,7 +469,7 @@ private TransposeLoader(IHost host, ModelLoadContext ctx, IDataView schemaView) _header = new Header() { - ColumnCount = schemaView.Schema.ColumnCount + ColumnCount = schemaView.Schema.Count }; _schemaEntry = new SubIdvEntry.SchemaSubIdv(this, schemaView); _host.Assert(_schemaEntry.GetViewOrNull() == schemaView); @@ -551,8 +550,8 @@ private static void SaveSchema(IHostEnvironment env, ModelSaveContext ctx, Schem var saver = new BinarySaver(env, saverArgs); // We load our schema from what amounts to a binary loader, so all columns should likewise be savable. 
- env.Assert(Enumerable.Range(0, schema.ColumnCount).All(c => saver.IsColumnSavable(schema.GetColumnType(c)))); - ctx.SaveBinaryStream("Schema.idv", w => saver.SaveData(w.BaseStream, noRows, Utils.GetIdentityPermutation(schema.ColumnCount))); + env.Assert(Enumerable.Range(0, schema.Count).All(c => saver.IsColumnSavable(schema[c].Type))); + ctx.SaveBinaryStream("Schema.idv", w => saver.SaveData(w.BaseStream, noRows, Utils.GetIdentityPermutation(schema.Count))); } private unsafe Header InitHeader(BinaryReader reader) @@ -612,9 +611,9 @@ private unsafe Header InitHeader(BinaryReader reader) private sealed class SchemaImpl : ITransposeSchema { private readonly TransposeLoader _parent; - private ISchema Schema { get { return _parent.Schema; } } + private Schema Schema { get { return _parent.Schema; } } private IHost Host { get { return _parent._host; } } - public int ColumnCount { get { return Schema.ColumnCount; } } + public int ColumnCount { get { return Schema.Count; } } public SchemaImpl(TransposeLoader parent) { @@ -625,7 +624,7 @@ public SchemaImpl(TransposeLoader parent) public string GetColumnName(int col) { - return Schema.GetColumnName(col); + return Schema[col].Name; } public bool TryGetColumnIndex(string name, out int col) @@ -635,22 +634,22 @@ public bool TryGetColumnIndex(string name, out int col) public ColumnType GetColumnType(int col) { - return Schema.GetColumnType(col); + return Schema[col].Type; } public ColumnType GetMetadataTypeOrNull(string kind, int col) { - return Schema.GetMetadataTypeOrNull(kind, col); + return Schema[col].Metadata.Schema.GetColumnOrNull(kind)?.Type; } public IEnumerable> GetMetadataTypes(int col) { - return Schema.GetMetadataTypes(col); + return Schema[col].Metadata.Schema.Select(c => new KeyValuePair(c.Name, c.Type)); } public void GetMetadata(string kind, int col, ref TValue value) { - Schema.GetMetadata(kind, col, ref value); + Schema[col].Metadata.GetValue(kind, ref value); } public VectorType GetSlotType(int col) @@ 
-659,7 +658,7 @@ public VectorType GetSlotType(int col) var view = _parent._entries[col].GetViewOrNull(); if (view == null) return null; - return view.Schema.GetColumnType(0).AsVector; + return view.Schema[0].Type as VectorType; } } @@ -668,7 +667,7 @@ public VectorType GetSlotType(int col) return _header.RowCount; } - public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); @@ -677,29 +676,28 @@ public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) return new Cursor(this, predicate); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); if (HasRowData) - return _schemaEntry.GetView().GetRowCursorSet(out consolidator, predicate, n, rand); - consolidator = null; - return new IRowCursor[] { GetRowCursor(predicate, rand) }; + return _schemaEntry.GetView().GetRowCursorSet(predicate, n, rand); + return new RowCursor[] { GetRowCursor(predicate, rand) }; } - public ISlotCursor GetSlotCursor(int col) + public SlotCursor GetSlotCursor(int col) { _host.CheckParam(0 <= col && col < _header.ColumnCount, nameof(col)); var view = _entries[col].GetViewOrNull(); if (view == null) { throw _host.ExceptParam(nameof(col), "Bad call to GetSlotCursor on untransposable column '{0}'", - Schema.GetColumnName(col)); + Schema[col].Name); } _host.CheckParam(0 <= col && col < _header.ColumnCount, nameof(col)); // We don't want the type error, if there is one, to be handled by the get-getter, because // at the point we've gotten the interior cursor, but not yet constructed the slot cursor. 
- ColumnType cursorType = TransposeSchema.GetSlotType(col).ItemType; - IRowCursor inputCursor = view.GetRowCursor(c => true); + ColumnType cursorType = _schema.GetSlotType(col).ItemType; + RowCursor inputCursor = view.GetRowCursor(c => true); try { return Utils.MarshalInvoke(GetSlotCursorCore, cursorType.RawType, inputCursor); @@ -714,41 +712,56 @@ public ISlotCursor GetSlotCursor(int col) } } - private ISlotCursor GetSlotCursorCore(IRowCursor inputCursor) + private SlotCursor GetSlotCursorCore(RowCursor inputCursor) { return new SlotCursor(this, inputCursor); } - private sealed class SlotCursor : SynchronizedCursorBase, ISlotCursor + private sealed class SlotCursor : SlotCursor { private readonly TransposeLoader _parent; private readonly ValueGetter> _getter; + private readonly RowCursor _rowCursor; - private IHost Host { get { return _parent._host; } } - - public SlotCursor(TransposeLoader parent, IRowCursor cursor) - : base(parent._host, cursor) + public SlotCursor(TransposeLoader parent, RowCursor cursor) + : base(parent._host) { _parent = parent; - Ch.Assert(cursor.Schema.ColumnCount == 1); - Ch.Assert(cursor.Schema.GetColumnType(0).RawType == typeof(VBuffer)); - _getter = Input.GetGetter>(0); - } + Ch.AssertValue(cursor); + Ch.Assert(cursor.Schema.Count == 1); + Ch.Assert(cursor.Schema[0].Type.RawType == typeof(VBuffer)); + Ch.Assert(cursor.Schema[0].Type is VectorType); + _rowCursor = cursor; - public VectorType GetSlotType() - { - var type = Input.Schema.GetColumnType(0).AsVector; - Ch.AssertValue(type); - return type; + _getter = _rowCursor.GetGetter>(0); } - public ValueGetter> GetGetter() + public override VectorType GetSlotType() + => (VectorType)_rowCursor.Schema[0].Type; + + public override ValueGetter> GetGetter() { ValueGetter> getter = _getter as ValueGetter>; if (getter == null) throw Ch.Except("Invalid TValue: '{0}'", typeof(TValue)); return getter; } + + public override bool MoveNext() + { + return _rowCursor.MoveNext(); + } + + public override 
int SlotIndex + { + get + { + long pos = _rowCursor.Position; + Contracts.Assert(pos <= int.MaxValue); + return (int)pos; + } + } + } private Transposer EnsureAndGetTransposer(int col) @@ -766,10 +779,10 @@ private Transposer EnsureAndGetTransposer(int col) var view = _entries[col].GetViewOrNull(); // Since we don't have row-wise data, this view must exist. _host.AssertValue(view); - _host.Assert(view.Schema.ColumnCount == 1); + _host.Assert(view.Schema.Count == 1); var trans = _colTransposers[col] = Transposer.Create(_host, view, false, new int[] { 0 }); - _host.Assert(trans.TransposeSchema.ColumnCount == 1); - _host.Assert(trans.TransposeSchema.GetSlotType(0).ValueCount == Schema.GetColumnType(col).ValueCount); + _host.Assert(((ITransposeDataView)trans).TransposeSchema.ColumnCount == 1); + _host.Assert(((ITransposeDataView)trans).TransposeSchema.GetSlotType(0).ValueCount == Schema[col].Type.ValueCount); } } } @@ -777,16 +790,16 @@ private Transposer EnsureAndGetTransposer(int col) return _colTransposers[col]; } - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private readonly TransposeLoader _parent; private readonly int[] _actives; private readonly int[] _colToActivesIndex; - private readonly ICursor[] _transCursors; + private readonly SlotCursor[] _transCursors; private readonly Delegate[] _getters; private bool _disposed; - public Schema Schema => _parent.Schema; + public override Schema Schema => _parent.Schema; public override long Batch { get { return 0; } } @@ -802,22 +815,24 @@ public Cursor(TransposeLoader parent, Func pred) Ch.Assert(!_parent.HasRowData); Utils.BuildSubsetMaps(_parent._header.ColumnCount, pred, out _actives, out _colToActivesIndex); - _transCursors = new ICursor[_actives.Length]; + _transCursors = new SlotCursor[_actives.Length]; _getters = new Delegate[_actives.Length]; // The following will fill in both the _transCursors and _getters arrays. 
for (int i = 0; i < _actives.Length; ++i) Init(_actives[i]); } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (!_disposed) + if (_disposed) + return; + if (disposing) { - _disposed = true; for (int i = 0; i < _transCursors.Length; ++i) _transCursors[i].Dispose(); - base.Dispose(); } + _disposed = true; + base.Dispose(disposing); } /// @@ -825,10 +840,10 @@ public override void Dispose() /// private void Init(int col) { - Ch.Assert(0 <= col && col < Schema.ColumnCount); + Ch.Assert(0 <= col && col < Schema.Count); Ch.Assert(_colToActivesIndex[col] >= 0); - var type = Schema.GetColumnType(col); - Ch.Assert(_parent.TransposeSchema.GetSlotType(col).ValueCount == _parent._header.RowCount); + var type = Schema[col].Type; + Ch.Assert(((ITransposeDataView)_parent).TransposeSchema.GetSlotType(col).ValueCount == _parent._header.RowCount); Action func = InitOne; if (type.IsVector) func = InitVec; @@ -838,10 +853,10 @@ private void Init(int col) private void InitOne(int col) { - var type = Schema.GetColumnType(col); + var type = Schema[col].Type; Ch.Assert(typeof(T) == type.RawType); var trans = _parent.EnsureAndGetTransposer(col); - ISlotCursor cursor = trans.GetSlotCursor(0); + SlotCursor cursor = trans.GetSlotCursor(0); ValueGetter> getter = cursor.GetGetter(); VBuffer buff = default(VBuffer); ValueGetter oneGetter = @@ -858,24 +873,24 @@ private void InitOne(int col) private void InitVec(int col) { - var type = Schema.GetColumnType(col); + var type = Schema[col].Type; Ch.Assert(type.IsVector); Ch.Assert(typeof(T) == type.ItemType.RawType); var trans = _parent.EnsureAndGetTransposer(col); - ISlotCursor cursor = trans.GetSlotCursor(0); + SlotCursor cursor = trans.GetSlotCursor(0); ValueGetter> getter = cursor.GetGetter(); int i = _colToActivesIndex[col]; _getters[i] = getter; _transCursors[i] = cursor; } - public override ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { return - (ref UInt128 val) => + (ref 
RowId val) => { Ch.Check(IsGood, "Cannot call ID getter in current state"); - val = new UInt128((ulong)Position, 0); + val = new RowId((ulong)Position, 0); }; } @@ -892,25 +907,13 @@ protected override bool MoveNextCore() return more; } - protected override bool MoveManyCore(long count) - { - Ch.Assert(State != CursorState.Done); - bool more = Position < _parent._header.RowCount - count; - for (int i = 0; i < _transCursors.Length; ++i) - { - bool cMore = _transCursors[i].MoveMany(count); - Ch.Assert(cMore == more); - } - return more; - } - - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.CheckParam(0 <= col && col <= _colToActivesIndex.Length, nameof(col)); return _colToActivesIndex[col] >= 0; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.CheckParam(0 <= col && col <= _colToActivesIndex.Length, nameof(col)); Ch.CheckParam(IsColumnActive(col), nameof(col), "requested column not active"); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeSaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeSaver.cs index 9ce4ab3872..51c9d00b8f 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeSaver.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeSaver.cs @@ -2,22 +2,22 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Transforms; using System; using System.Collections.Generic; using System.IO; using System.Runtime.InteropServices; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Transforms; [assembly: LoadableClass(TransposeSaver.Summary, typeof(TransposeSaver), typeof(TransposeSaver.Arguments), typeof(SignatureDataSaver), "Transpose Saver", TransposeSaver.LoadName, "TransposedSaver", "Transpose", "Transposed", "trans")] -namespace Microsoft.ML.Runtime.Data.IO +namespace Microsoft.ML.Data.IO { /// /// Saver for a format that can be loaded using the . @@ -70,7 +70,7 @@ public bool IsColumnSavable(ColumnType type) // an artificial vector type out of this. Obviously if you can't make a vector // out of the items, then you could not save each slot's values. var itemType = type.ItemType; - var primitiveType = itemType.AsPrimitive; + var primitiveType = itemType as PrimitiveType; if (primitiveType == null) return false; var vectorType = new VectorType(primitiveType, size: 2); @@ -131,7 +131,7 @@ private void SaveTransposedData(IChannel ch, Stream stream, ITransposeDataView d { using (var substream = new SubsetStream(stream)) { - _internalSaver.SaveData(substream, view, Utils.GetIdentityPermutation(view.Schema.ColumnCount)); + _internalSaver.SaveData(substream, view, Utils.GetIdentityPermutation(view.Schema.Count)); substream.Seek(0, SeekOrigin.End); ch.Info("Wrote {0} data view in {1} bytes", name, substream.Length); } @@ -148,7 +148,7 @@ private void SaveTransposedData(IChannel ch, Stream stream, ITransposeDataView d string msg = _writeRowData ? 
"row-wise data, schema, and metadata" : "schema and metadata"; viewAction(msg, subdata); foreach (var col in cols) - viewAction(data.Schema.GetColumnName(col), new TransposerUtils.SlotDataView(_host, data, col)); + viewAction(data.Schema[col].Name, new TransposerUtils.SlotDataView(_host, data, col)); // Wrote out the dataview. Write out the table offset. using (var writer = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true)) diff --git a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs index 34eb0874c6..e1ac2a2936 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs @@ -4,7 +4,7 @@ using Microsoft.ML.Core.Data; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// The trivial implementation of that already has diff --git a/src/Microsoft.ML.Data/DataLoadSave/TrivialReaderEstimator.cs b/src/Microsoft.ML.Data/DataLoadSave/TrivialReaderEstimator.cs index 506ea8cf73..957a20ee55 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TrivialReaderEstimator.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TrivialReaderEstimator.cs @@ -4,7 +4,7 @@ using Microsoft.ML.Core.Data; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// The trivial wrapper for a that acts as an estimator and ignores the source. diff --git a/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs b/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs index 64f158a603..1b32f35cc5 100644 --- a/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs +++ b/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs @@ -3,14 +3,11 @@ // See the LICENSE file in the project root for more information. 
using System; -using System.Collections.Generic; using System.Linq; using System.Reflection; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { // REVIEW: Currently, to enable shuffling, we require the row counts of the sources to be known. // We can think of the shuffling in AppendRowsDataView as a two-stage process: @@ -109,10 +106,10 @@ private void CheckSchemaConsistency() const string errMsg = "Inconsistent schema: all source dataviews must have identical column names, sizes, and item types."; int startingSchemaIndex = _schema == _sources[0].Schema ? 1 : 0; - int colCount = _schema.ColumnCount; + int colCount = _schema.Count; // Check if the column counts are identical. - _host.Check(_sources.All(source => source.Schema.ColumnCount == colCount), errMsg); + _host.Check(_sources.All(source => source.Schema.Count == colCount), errMsg); for (int c = 0; c < colCount; c++) { @@ -121,9 +118,9 @@ private void CheckSchemaConsistency() for (int i = startingSchemaIndex; i < _sources.Length; i++) { - ISchema schema = _sources[i].Schema; - _host.Check(schema.GetColumnName(c) == name, errMsg); - _host.Check(schema.GetColumnType(c).SameSizeAndItemType(type), errMsg); + var schema = _sources[i].Schema; + _host.Check(schema[c].Name == name, errMsg); + _host.Check(schema[c].Type.SameSizeAndItemType(type), errMsg); } } } @@ -146,7 +143,7 @@ private void CheckSchemaConsistency() return sum; } - public IRowCursor GetRowCursor(Func needCol, IRandom rand = null) + public RowCursor GetRowCursor(Func needCol, Random rand = null) { _host.CheckValue(needCol, nameof(needCol)); if (rand == null || !_canShuffle) @@ -154,20 +151,19 @@ public IRowCursor GetRowCursor(Func needCol, IRandom rand = null) return new RandCursor(this, needCol, rand, _counts); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator 
consolidator, Func predicate, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { - consolidator = null; - return new IRowCursor[] { GetRowCursor(predicate, rand) }; + return new RowCursor[] { GetRowCursor(predicate, rand) }; } - private abstract class CursorBase : RootCursorBase, IRowCursor + private abstract class CursorBase : RootCursorBase { protected readonly IDataView[] Sources; protected readonly Delegate[] Getters; public override long Batch => 0; - public Schema Schema { get; } + public sealed override Schema Schema { get; } public CursorBase(AppendRowsDataView parent) : base(parent._host) @@ -175,7 +171,7 @@ public CursorBase(AppendRowsDataView parent) Sources = parent._sources; Ch.AssertNonEmpty(Sources); Schema = parent._schema; - Getters = new Delegate[Schema.ColumnCount]; + Getters = new Delegate[Schema.Count]; } protected Delegate CreateGetter(int col) @@ -189,7 +185,7 @@ protected Delegate CreateGetter(int col) protected abstract ValueGetter CreateTypedGetter(int col); - public ValueGetter GetGetter(int col) + public sealed override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col), "The column must be active against the defined predicate."); if (!(Getters[col] is ValueGetter)) @@ -197,9 +193,9 @@ public ValueGetter GetGetter(int col) return Getters[col] as ValueGetter; } - public bool IsColumnActive(int col) + public sealed override bool IsColumnActive(int col) { - Ch.Check(0 <= col && col < Schema.ColumnCount, "Column index is out of range"); + Ch.Check(0 <= col && col < Schema.Count, "Column index is out of range"); return Getters[col] != null; } } @@ -209,8 +205,8 @@ public bool IsColumnActive(int col) /// private sealed class Cursor : CursorBase { - private IRowCursor _currentCursor; - private ValueGetter _currentIdGetter; + private RowCursor _currentCursor; + private ValueGetter _currentIdGetter; private int _currentSourceIndex; public Cursor(AppendRowsDataView parent, Func 
needCol) @@ -228,17 +224,17 @@ public Cursor(AppendRowsDataView parent, Func needCol) } } - public override ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { return - (ref UInt128 val) => + (ref RowId val) => { _currentIdGetter(ref val); // While the union of all IDs may not be acceptable, by taking each // data views IDs and combining them against their source index, the // union of these IDs becomes acceptable. - // REVIEW: Convenience UInt128 constructor for this scenario? - val = val.Combine(new UInt128((ulong)_currentSourceIndex, 0)); + // REVIEW: Convenience RowId constructor for this scenario? + val = val.Combine(new RowId((ulong)_currentSourceIndex, 0)); }; } @@ -280,15 +276,16 @@ protected override bool MoveNextCore() return true; } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (State != CursorState.Done) + if (State == CursorState.Done) + return; + if (disposing) { Ch.Dispose(); - if (_currentCursor != null) - _currentCursor.Dispose(); - base.Dispose(); + _currentCursor?.Dispose(); } + base.Dispose(disposing); } } @@ -299,12 +296,12 @@ public override void Dispose() /// private sealed class RandCursor : CursorBase { - private readonly IRowCursor[] _cursorSet; + private readonly RowCursor[] _cursorSet; private readonly MultinomialWithoutReplacementSampler _sampler; - private readonly IRandom _rand; + private readonly Random _rand; private int _currentSourceIndex; - public RandCursor(AppendRowsDataView parent, Func needCol, IRandom rand, int[] counts) + public RandCursor(AppendRowsDataView parent, Func needCol, Random rand, int[] counts) : base(parent) { Ch.AssertValue(needCol); @@ -313,7 +310,7 @@ public RandCursor(AppendRowsDataView parent, Func needCol, IRandom ra _rand = rand; Ch.AssertValue(counts); Ch.Assert(Sources.Length == counts.Length); - _cursorSet = new IRowCursor[counts.Length]; + _cursorSet = new RowCursor[counts.Length]; for (int i = 0; i < counts.Length; i++) { 
Ch.Assert(counts[i] >= 0); @@ -328,17 +325,17 @@ public RandCursor(AppendRowsDataView parent, Func needCol, IRandom ra } } - public override ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { - ValueGetter[] idGetters = new ValueGetter[_cursorSet.Length]; + ValueGetter[] idGetters = new ValueGetter[_cursorSet.Length]; for (int i = 0; i < _cursorSet.Length; ++i) idGetters[i] = _cursorSet[i].GetIdGetter(); return - (ref UInt128 val) => + (ref RowId val) => { Ch.Check(IsGood, "Cannot call ID getter in current state"); idGetters[_currentSourceIndex](ref val); - val = val.Combine(new UInt128((ulong)_currentSourceIndex, 0)); + val = val.Combine(new RowId((ulong)_currentSourceIndex, 0)); }; } @@ -369,15 +366,17 @@ protected override bool MoveNextCore() return true; } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (State != CursorState.Done) + if (State == CursorState.Done) + return; + if (disposing) { Ch.Dispose(); - foreach (IRowCursor c in _cursorSet) + foreach (RowCursor c in _cursorSet) c.Dispose(); - base.Dispose(); } + base.Dispose(disposing); } } @@ -397,7 +396,7 @@ private sealed class MultinomialWithoutReplacementSampler private const int BatchSize = 1000; private readonly int[] _rowsLeft; - private readonly IRandom _rand; + private readonly Random _rand; private readonly int[] _batch; private readonly IExceptionContext _ectx; @@ -405,7 +404,7 @@ private sealed class MultinomialWithoutReplacementSampler private int _batchPos; private int _totalLeft; - public MultinomialWithoutReplacementSampler(IExceptionContext context, int[] counts, IRandom rand) + public MultinomialWithoutReplacementSampler(IExceptionContext context, int[] counts, Random rand) { Contracts.AssertValue(context); _ectx = context; @@ -444,7 +443,7 @@ private void GenerateNextBatch() _batchEnd = newEnd; } _totalLeft -= _batchEnd; - Utils.Shuffle(_rand, _batch, 0, _batchEnd); + Utils.Shuffle(_rand, _batch.AsSpan(0, _batchEnd)); } public 
int Next() diff --git a/src/Microsoft.ML.Data/DataView/ArrayDataViewBuilder.cs b/src/Microsoft.ML.Data/DataView/ArrayDataViewBuilder.cs index acd69db40d..24c5cb6796 100644 --- a/src/Microsoft.ML.Data/DataView/ArrayDataViewBuilder.cs +++ b/src/Microsoft.ML.Data/DataView/ArrayDataViewBuilder.cs @@ -2,12 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Collections.Generic; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { using BitArray = System.Collections.BitArray; @@ -75,12 +74,18 @@ public void AddColumn(string name, PrimitiveType type, params T[] values) /// Constructs a new key column from an array where values are copied to output simply /// by being assigned. /// - public void AddColumn(string name, ValueGetter>> getKeyValues, ulong keyMin, int keyCount, params uint[] values) + /// The name of the column. + /// The delegate that does a reverse lookup based upon the given key. This is for metadata creation + /// The minimum to use. + /// The count of unique keys specified in values + /// The values to add to the column. Note that since this is creating a column, the values will be offset by 1. 
+ public void AddColumn(string name, ValueGetter>> getKeyValues, ulong keyMin, int keyCount, params T1[] values) { _host.CheckValue(getKeyValues, nameof(getKeyValues)); _host.CheckParam(keyCount > 0, nameof(keyCount)); CheckLength(name, values); - _columns.Add(new AssignmentColumn(new KeyType(DataKind.U4, keyMin, keyCount), values)); + values.GetType().GetElementType().TryGetDataKind(out DataKind kind); + _columns.Add(new AssignmentColumn(new KeyType(kind, keyMin, keyCount), values)); _getKeyValues.Add(name, getKeyValues); _names.Add(name); } @@ -226,29 +231,27 @@ public DataView(IHostEnvironment env, ArrayDataViewBuilder builder, int rowCount _rowCount = rowCount; } - public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); - return new RowCursor(_host, this, predicate, rand); + return new Cursor(_host, this, predicate, rand); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); - consolidator = null; - return new IRowCursor[] { new RowCursor(_host, this, predicate, rand) }; + return new RowCursor[] { new Cursor(_host, this, predicate, rand) }; } - private sealed class RowCursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private readonly DataView _view; private readonly BitArray _active; private readonly int[] _indices; - public Schema Schema => _view.Schema; + public override Schema Schema => _view.Schema; public override long Batch { @@ -256,57 +259,57 @@ public override long Batch get { return 0; } } - public RowCursor(IChannelProvider provider, DataView view, Func predicate, IRandom rand) + public Cursor(IChannelProvider provider, DataView 
view, Func predicate, Random rand) : base(provider) { Ch.AssertValue(view); Ch.AssertValueOrNull(rand); - Ch.Assert(view.Schema.ColumnCount >= 0); + Ch.Assert(view.Schema.Count >= 0); _view = view; - _active = new BitArray(view.Schema.ColumnCount); + _active = new BitArray(view.Schema.Count); if (predicate == null) _active.SetAll(true); else { - for (int i = 0; i < view.Schema.ColumnCount; ++i) + for (int i = 0; i < view.Schema.Count; ++i) _active[i] = predicate(i); } if (rand != null) _indices = Utils.GetRandomPermutation(rand, view._rowCount); } - public override ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { if (_indices == null) { return - (ref UInt128 val) => + (ref RowId val) => { Ch.Check(IsGood, "Cannot call ID getter in current state"); - val = new UInt128((ulong)Position, 0); + val = new RowId((ulong)Position, 0); }; } else { return - (ref UInt128 val) => + (ref RowId val) => { Ch.Check(IsGood, "Cannot call ID getter in current state"); - val = new UInt128((ulong)MappedIndex(), 0); + val = new RowId((ulong)MappedIndex(), 0); }; } } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { - Ch.Check(0 <= col & col < Schema.ColumnCount); + Ch.Check(0 <= col & col < Schema.Count); return _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { - Ch.Check(0 <= col & col < Schema.ColumnCount); + Ch.Check(0 <= col & col < Schema.Count); Ch.Check(_active[col], "column is not active"); var column = _view._columns[col] as Column; if (column == null) diff --git a/src/Microsoft.ML.Data/DataView/CacheDataView.cs b/src/Microsoft.ML.Data/DataView/CacheDataView.cs index 3674ec40ca..2563fe5cc1 100644 --- a/src/Microsoft.ML.Data/DataView/CacheDataView.cs +++ b/src/Microsoft.ML.Data/DataView/CacheDataView.cs @@ -4,15 +4,13 @@ #pragma warning disable 420 // volatile with Interlocked.CompareExchange -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using 
Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; using System.Threading; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Data { @@ -98,7 +96,7 @@ public CacheDataView(IHostEnvironment env, IDataView input, int[] prefetch) _cacheLock = new object(); _cacheFillerThreads = new ConcurrentBag(); - _caches = new ColumnCache[_subsetInput.Schema.ColumnCount]; + _caches = new ColumnCache[_subsetInput.Schema.Count]; if (Utils.Size(prefetch) > 0) KickoffFiller(prefetch); @@ -124,21 +122,21 @@ private static IDataView SelectCachableColumns(IDataView data, IHostEnvironment Array.Copy(prefetch, tmp, prefetch.Length); Array.Sort(tmp); prefetch = tmp; - if (prefetch.Length > 0 && (prefetch[0] < 0 || prefetch[prefetch.Length - 1] >= schema.ColumnCount)) + if (prefetch.Length > 0 && (prefetch[0] < 0 || prefetch[prefetch.Length - 1] >= schema.Count)) throw env.Except("Prefetch array had column indices out of range"); } int ip = 0; inputToSubset = null; - for (int c = 0; c < schema.ColumnCount; ++c) + for (int c = 0; c < schema.Count; ++c) { - var type = schema.GetColumnType(c); + var type = schema[c].Type; env.Assert(ip == prefetch.Length || c <= prefetch[ip]); if (!type.IsCachable()) { if (inputToSubset == null) { - inputToSubset = new int[schema.ColumnCount]; + inputToSubset = new int[schema.Count]; for (int cc = 0; cc < c; ++cc) inputToSubset[cc] = cc; } @@ -149,7 +147,7 @@ private static IDataView SelectCachableColumns(IDataView data, IHostEnvironment { throw env.Except( "Asked to prefetch column '{0}' into cache, but it is of unhandled type '{1}'", - schema.GetColumnName(c), type); + schema[c].Name, type); } } else @@ -182,10 +180,10 @@ private static IDataView SelectCachableColumns(IDataView data, IHostEnvironment /// if this was cachable, or else -1 if the column was not cachable public int MapInputToCacheColumnIndex(int 
inputIndex) { - int inputIndexLim = _inputToSubsetColIndex == null ? _subsetInput.Schema.ColumnCount : _inputToSubsetColIndex.Length; + int inputIndexLim = _inputToSubsetColIndex == null ? _subsetInput.Schema.Count : _inputToSubsetColIndex.Length; _host.CheckParam(0 <= inputIndex && inputIndex < inputIndexLim, nameof(inputIndex), "Input column index not in range"); var result = _inputToSubsetColIndex == null ? inputIndex : _inputToSubsetColIndex[inputIndex]; - _host.Assert(-1 <= result && result < _subsetInput.Schema.ColumnCount); + _host.Assert(-1 <= result && result < _subsetInput.Schema.Count); return result; } @@ -203,7 +201,7 @@ public int MapInputToCacheColumnIndex(int inputIndex) return _rowCount; } - public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); @@ -221,7 +219,7 @@ public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) /// Returns a permutation or null. This function will return null if either /// is null, or if the row count of this cache exceeds the maximum array size. 
/// - private int[] GetPermutationOrNull(IRandom rand) + private int[] GetPermutationOrNull(Random rand) { if (rand == null) return null; @@ -235,7 +233,7 @@ private int[] GetPermutationOrNull(IRandom rand) return Utils.GetRandomPermutation(rand, (int)_rowCount); } - private IRowCursor GetRowCursorWaiterCore(TWaiter waiter, Func predicate, IRandom rand) + private RowCursor GetRowCursorWaiterCore(TWaiter waiter, Func predicate, Random rand) where TWaiter : struct, IWaiter { _host.AssertValue(predicate); @@ -247,8 +245,7 @@ private IRowCursor GetRowCursorWaiterCore(TWaiter waiter, Func.Create(waiter, perm)); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); @@ -256,30 +253,15 @@ public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, n = DataViewUtils.GetThreadCount(_host, n); if (n <= 1) - { - consolidator = null; - return new IRowCursor[] { GetRowCursor(predicate, rand) }; - } + return new RowCursor[] { GetRowCursor(predicate, rand) }; - consolidator = new Consolidator(); var waiter = WaiterWaiter.Create(this, predicate); if (waiter.IsTrivial) return GetRowCursorSetWaiterCore(TrivialWaiter.Create(this), predicate, n, rand); return GetRowCursorSetWaiterCore(waiter, predicate, n, rand); } - /// - /// Minimal consolidator. 
- /// - private sealed class Consolidator : IRowCursorConsolidator - { - public IRowCursor CreateCursor(IChannelProvider provider, IRowCursor[] inputs) - { - return DataViewUtils.ConsolidateGeneric(provider, inputs, _batchSize); - } - } - - private IRowCursor[] GetRowCursorSetWaiterCore(TWaiter waiter, Func predicate, int n, IRandom rand) + private RowCursor[] GetRowCursorSetWaiterCore(TWaiter waiter, Func predicate, int n, Random rand) where TWaiter : struct, IWaiter { _host.AssertValue(predicate); @@ -287,7 +269,7 @@ private IRowCursor[] GetRowCursorSetWaiterCore(TWaiter waiter, Func(TWaiter waiter, Func(Func predicate, TIndex index) + private RowCursor CreateCursor(Func predicate, TIndex index) where TIndex : struct, IIndex { Contracts.AssertValue(predicate); return new RowCursor(this, predicate, index); } - public IRowSeeker GetSeeker(Func predicate) + public RowSeeker GetSeeker(Func predicate) { _host.CheckValue(predicate, nameof(predicate)); // The seeker needs to know the row count when it validates the row index to move to. 
@@ -320,11 +302,11 @@ public IRowSeeker GetSeeker(Func predicate) return GetSeeker(predicate, waiter); } - private IRowSeeker GetSeeker(Func predicate, TWaiter waiter) + private RowSeeker GetSeeker(Func predicate, TWaiter waiter) where TWaiter : struct, IWaiter { _host.AssertValue(predicate); - return new RowSeeker(this, predicate, waiter); + return new RowSeeker(new RowSeekerCore(this, predicate, waiter)); } /// @@ -339,7 +321,7 @@ private void KickoffFiller(int[] columns) _host.AssertValue(columns); HashSet taskColumns = null; - IRowCursor cursor; + RowCursor cursor; ColumnCache[] caches; OrderedWaiter waiter; lock (_cacheLock) @@ -390,7 +372,7 @@ private void KickoffFiller(int[] columns) /// The caches we must fill and, at the end of the cursor, freeze /// The waiter to increment as we cache each additional row /// - private void Filler(IRowCursor cursor, ColumnCache[] caches, OrderedWaiter waiter) + private void Filler(RowCursor cursor, ColumnCache[] caches, OrderedWaiter waiter) { _host.AssertValue(cursor); _host.AssertValue(caches); @@ -464,15 +446,15 @@ internal void Wait() } } - private sealed class RowCursor : RowCursorSeekerBase, IRowCursor + private sealed class RowCursor : RowCursorSeekerBase where TIndex : struct, IIndex { private CursorState _state; private readonly TIndex _index; - public CursorState State { get { return _state; } } + public override CursorState State => _state; - public long Batch { get { return _index.Batch; } } + public override long Batch => _index.Batch; public RowCursor(CacheDataView parent, Func predicate, TIndex index) : base(parent, predicate) @@ -481,17 +463,11 @@ public RowCursor(CacheDataView parent, Func predicate, TIndex index) _index = index; } - public ValueGetter GetIdGetter() - { - return _index.GetIdGetter(); - } + public override ValueGetter GetIdGetter() => _index.GetIdGetter(); - public ICursor GetRootCursor() - { - return this; - } + public override RowCursor GetRootCursor() => this; - public bool MoveNext() + 
public override bool MoveNext() { if (_state == CursorState.Done) { @@ -502,7 +478,7 @@ public bool MoveNext() Ch.Assert(_state == CursorState.NotStarted || _state == CursorState.Good); if (_index.MoveNext()) { - Position++; + PositionCore++; Ch.Assert(Position >= 0); _state = CursorState.Good; return true; @@ -513,7 +489,7 @@ public bool MoveNext() return false; } - public bool MoveMany(long count) + public override bool MoveMany(long count) { // Note: If we decide to allow count == 0, then we need to special case // that MoveNext() has never been called. It's not entirely clear what the return @@ -529,7 +505,7 @@ public bool MoveMany(long count) Ch.Assert(_state == CursorState.NotStarted || _state == CursorState.Good); if (_index.MoveMany(count)) { - Position += count; + PositionCore += count; _state = CursorState.Good; Ch.Assert(Position >= 0); return true; @@ -556,24 +532,47 @@ protected override ValueGetter CreateGetterDelegateCore(ColumnCa } } - private sealed class RowSeeker : RowCursorSeekerBase, IRowSeeker - where TWaiter : struct, IWaiter + private sealed class RowSeeker : RowSeeker + where TWaiter : struct, IWaiter + { + private readonly RowSeekerCore _internal; + + public RowSeeker(RowSeekerCore toWrap) + { + Contracts.AssertValue(toWrap); + _internal = toWrap; + } + + public override long Position => _internal.Position; + public override long Batch => _internal.Batch; + public override Schema Schema => _internal.Schema; + + public override ValueGetter GetGetter(int col) => _internal.GetGetter(col); + public override ValueGetter GetIdGetter() => _internal.GetIdGetter(); + public override bool IsColumnActive(int col) => _internal.IsColumnActive(col); + public override bool MoveTo(long rowIndex) => _internal.MoveTo(rowIndex); + } + + private sealed class RowSeekerCore : RowCursorSeekerBase + where TWaiter : struct, IWaiter { private readonly TWaiter _waiter; - public long Batch { get { return 0; } } + public override long Batch => 0; + + public override 
CursorState State => throw new NotImplementedException(); - public ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { return - (ref UInt128 val) => + (ref RowId val) => { Ch.Check(Position >= 0, "Cannot call ID getter in current state"); - val = new UInt128((ulong)Position, 0); + val = new RowId((ulong)Position, 0); }; } - public RowSeeker(CacheDataView parent, Func predicate, TWaiter waiter) + public RowSeekerCore(CacheDataView parent, Func predicate, TWaiter waiter) : base(parent, predicate) { _waiter = waiter; @@ -585,11 +584,11 @@ public bool MoveTo(long rowIndex) { // If requested row index is out of range, the row seeker // returns false and sets its position to -1. - Position = -1; + PositionCore = -1; return false; } - Position = rowIndex; + PositionCore = rowIndex; return true; } @@ -601,6 +600,10 @@ protected override ValueGetter CreateGetterDelegateCore(ColumnCa { return (ref TValue value) => cache.Fetch((int)Position, ref value); } + + public override bool MoveNext() => throw Ch.ExceptNotSupp(); + public override bool MoveMany(long count) => throw Ch.ExceptNotSupp(); + public override RowCursor GetRootCursor() => throw Ch.ExceptNotSupp(); } private interface IWaiter @@ -675,7 +678,7 @@ private sealed class WaiterWaiter : IWaiter /// /// If this is true, then a could be used instead. /// - public bool IsTrivial { get { return _waiters.Length == 0; } } + public bool IsTrivial => _waiters.Length == 0; private WaiterWaiter(CacheDataView parent, Func pred) { @@ -683,7 +686,7 @@ private WaiterWaiter(CacheDataView parent, Func pred) Contracts.AssertValue(pred); _parent = parent; - int[] actives = Enumerable.Range(0, _parent.Schema.ColumnCount).Where(pred).ToArray(); + int[] actives = Enumerable.Range(0, _parent.Schema.Count).Where(pred).ToArray(); // Kick off the thread to fill in any requested columns. 
_parent.KickoffFiller(actives); @@ -722,7 +725,7 @@ public static Wrapper Create(CacheDataView parent, Func pred) { private readonly WaiterWaiter _waiter; - public bool IsTrivial { get { return _waiter.IsTrivial; } } + public bool IsTrivial => _waiter.IsTrivial; public Wrapper(WaiterWaiter waiter) { @@ -730,7 +733,7 @@ public Wrapper(WaiterWaiter waiter) _waiter = waiter; } - public bool Wait(long pos) { return _waiter.Wait(pos); } + public bool Wait(long pos) => _waiter.Wait(pos); } } @@ -753,12 +756,12 @@ private interface IIndex /// An ID getter, which should be based on the value that would be returned /// from , if valid, and otherwise have undefined behavior. /// - ValueGetter GetIdGetter(); + ValueGetter GetIdGetter(); /// /// Moves to the next index. Once this or has returned /// false, it should never be called again. (This in constrast to public - /// objects, whose move methods are robust to that usage.) + /// objects, whose move methods are robust to that usage.) /// /// Whether the next index is available. 
bool MoveNext(); @@ -798,13 +801,13 @@ public long GetIndex() return _curr; } - public ValueGetter GetIdGetter() + public ValueGetter GetIdGetter() { return - (ref UInt128 val) => + (ref RowId val) => { Contracts.Check(_curr >= 0, "Cannot call ID getter in current state"); - val = new UInt128((ulong)_curr, 0); + val = new RowId((ulong)_curr, 0); }; } @@ -842,11 +845,11 @@ public Wrapper(SequenceIndex index) _index = index; } - public long Batch { get { return _index.Batch; } } - public long GetIndex() { return _index.GetIndex(); } - public ValueGetter GetIdGetter() { return _index.GetIdGetter(); } - public bool MoveNext() { return _index.MoveNext(); } - public bool MoveMany(long count) { return _index.MoveMany(count); } + public long Batch => _index.Batch; + public long GetIndex() => _index.GetIndex(); + public ValueGetter GetIdGetter() => _index.GetIdGetter(); + public bool MoveNext() => _index.MoveNext(); + public bool MoveMany(long count) => _index.MoveMany(count); } } @@ -873,13 +876,13 @@ public long GetIndex() return _perm[_curr]; } - public ValueGetter GetIdGetter() + public ValueGetter GetIdGetter() { return - (ref UInt128 val) => + (ref RowId val) => { Contracts.Check(_curr >= 0, "Cannot call ID getter in current state"); - val = new UInt128((ulong)_perm[_curr], 0); + val = new RowId((ulong)_perm[_curr], 0); }; } @@ -933,11 +936,11 @@ public Wrapper(RandomIndex index) _index = index; } - public long Batch { get { return _index.Batch; } } - public long GetIndex() { return _index.GetIndex(); } - public ValueGetter GetIdGetter() { return _index.GetIdGetter(); } - public bool MoveNext() { return _index.MoveNext(); } - public bool MoveMany(long count) { return _index.MoveMany(count); } + public long Batch => _index.Batch; + public long GetIndex() => _index.GetIndex(); + public ValueGetter GetIdGetter() => _index.GetIdGetter(); + public bool MoveNext() => _index.MoveNext(); + public bool MoveMany(long count) => _index.MoveMany(count); } } @@ -1029,13 +1032,13 @@ 
public long GetIndex() return _curr; } - public ValueGetter GetIdGetter() + public ValueGetter GetIdGetter() { return - (ref UInt128 val) => + (ref RowId val) => { Contracts.Check(_curr >= 0, "Cannot call ID getter in current state"); - val = new UInt128((ulong)_curr, 0); + val = new RowId((ulong)_curr, 0); }; } @@ -1105,7 +1108,7 @@ public Wrapper(BlockSequenceIndex index) public long Batch { get { return _index.Batch; } } public long GetIndex() { return _index.GetIndex(); } - public ValueGetter GetIdGetter() { return _index.GetIdGetter(); } + public ValueGetter GetIdGetter() { return _index.GetIdGetter(); } public bool MoveNext() { return _index.MoveNext(); } public bool MoveMany(long count) { return _index.MoveMany(count); } } @@ -1148,13 +1151,13 @@ public long GetIndex() return _perm[_curr]; } - public ValueGetter GetIdGetter() + public ValueGetter GetIdGetter() { return - (ref UInt128 val) => + (ref RowId val) => { Contracts.Check(_curr >= 0, "Cannot call ID getter in current state"); - val = new UInt128((ulong)_perm[_curr], 0); + val = new RowId((ulong)_perm[_curr], 0); }; } @@ -1213,35 +1216,36 @@ public Wrapper(BlockRandomIndex index) public long Batch { get { return _index.Batch; } } public long GetIndex() { return _index.GetIndex(); } - public ValueGetter GetIdGetter() { return _index.GetIdGetter(); } + public ValueGetter GetIdGetter() { return _index.GetIdGetter(); } public bool MoveNext() { return _index.MoveNext(); } public bool MoveMany(long count) { return _index.MoveMany(count); } } } - private abstract class RowCursorSeekerBase : IDisposable + private abstract class RowCursorSeekerBase : RowCursor { protected readonly CacheDataView Parent; protected readonly IChannel Ch; + protected long PositionCore; private readonly int[] _colToActivesIndex; private readonly Delegate[] _getters; private bool _disposed; - public Schema Schema => Parent.Schema; + public sealed override Schema Schema => Parent.Schema; - public long Position { get; protected set; } 
+ public sealed override long Position => PositionCore; protected RowCursorSeekerBase(CacheDataView parent, Func predicate) { Contracts.AssertValue(parent); Parent = parent; Ch = parent._host.Start("Cursor"); - Position = -1; + PositionCore = -1; // Set up the mapping from active columns. - int colLim = Schema.ColumnCount; + int colLim = Schema.Count; int[] actives; Utils.BuildSubsetMaps(colLim, predicate, out actives, out _colToActivesIndex); // Construct the getters. Simultaneously collect whatever "waiters" @@ -1259,24 +1263,27 @@ protected RowCursorSeekerBase(CacheDataView parent, Func predicate) } } - public bool IsColumnActive(int col) + public sealed override bool IsColumnActive(int col) { Ch.CheckParam(0 <= col && col < _colToActivesIndex.Length, nameof(col)); return _colToActivesIndex[col] >= 0; } - public void Dispose() + protected sealed override void Dispose(bool disposing) { - if (!_disposed) + if (_disposed) + return; + if (disposing) { DisposeCore(); - Position = -1; + PositionCore = -1; Ch.Dispose(); - _disposed = true; } + base.Dispose(disposing); + _disposed = true; } - public ValueGetter GetGetter(int col) + public sealed override ValueGetter GetGetter(int col) { if (!IsColumnActive(col)) throw Ch.Except("Column #{0} is requested but not active in the cursor", col); @@ -1290,14 +1297,14 @@ private Delegate CreateGetterDelegate(int col) { Ch.Assert(0 <= col && col < _colToActivesIndex.Length); Ch.Assert(_colToActivesIndex[col] >= 0); - return Utils.MarshalInvoke(CreateGetterDelegate, Schema.GetColumnType(col).RawType, col); + return Utils.MarshalInvoke(CreateGetterDelegate, Schema[col].Type.RawType, col); } private Delegate CreateGetterDelegate(int col) { Ch.Assert(0 <= col && col < _colToActivesIndex.Length); Ch.Assert(_colToActivesIndex[col] >= 0); - Ch.Assert(Schema.GetColumnType(col).RawType == typeof(TValue)); + Ch.Assert(Schema[col].Type.RawType == typeof(TValue)); var cache = (ColumnCache)Parent._caches[col]; return 
CreateGetterDelegateCore(cache); @@ -1348,15 +1355,15 @@ protected ColumnCache(IExceptionContext ctx, OrderedWaiter waiter) /// The column of the cursor we are wrapping. /// The waiter for the filler associated with this column /// - public static ColumnCache Create(CacheDataView parent, IRowCursor input, int srcCol, OrderedWaiter waiter) + public static ColumnCache Create(CacheDataView parent, RowCursor input, int srcCol, OrderedWaiter waiter) { Contracts.AssertValue(parent); var host = parent._host; host.AssertValue(input); - host.Assert(0 <= srcCol & srcCol < input.Schema.ColumnCount); + host.Assert(0 <= srcCol & srcCol < input.Schema.Count); host.Assert(input.IsColumnActive(srcCol)); - var type = input.Schema.GetColumnType(srcCol); + var type = input.Schema[srcCol].Type; Type pipeType; if (type.IsVector) pipeType = typeof(ImplVec<>).MakeGenericType(type.ItemType.RawType); @@ -1368,7 +1375,7 @@ public static ColumnCache Create(CacheDataView parent, IRowCursor input, int src if (_pipeConstructorTypes == null) { Interlocked.CompareExchange(ref _pipeConstructorTypes, - new Type[] { typeof(CacheDataView), typeof(IRowCursor), typeof(int), typeof(OrderedWaiter) }, null); + new Type[] { typeof(CacheDataView), typeof(RowCursor), typeof(int), typeof(OrderedWaiter) }, null); } var constructor = pipeType.GetConstructor(_pipeConstructorTypes); return (ColumnCache)constructor.Invoke(new object[] { parent, input, srcCol, waiter }); @@ -1416,10 +1423,10 @@ private sealed class ImplVec : ColumnCache> // Temporary working reusable storage for caching the source data. 
private VBuffer _temp; - public ImplVec(CacheDataView parent, IRowCursor input, int srcCol, OrderedWaiter waiter) + public ImplVec(CacheDataView parent, RowCursor input, int srcCol, OrderedWaiter waiter) : base(parent, input, srcCol, waiter) { - var type = input.Schema.GetColumnType(srcCol); + var type = input.Schema[srcCol].Type; Ctx.Assert(type.IsVector); _uniformLength = type.VectorSize; _indices = new BigArray(); @@ -1499,7 +1506,7 @@ private sealed class ImplOne : ColumnCache private T[] _values; private ValueGetter _getter; - public ImplOne(CacheDataView parent, IRowCursor input, int srcCol, OrderedWaiter waiter) + public ImplOne(CacheDataView parent, RowCursor input, int srcCol, OrderedWaiter waiter) : base(parent, input, srcCol, waiter) { _getter = input.GetGetter(srcCol); @@ -1534,12 +1541,12 @@ public override void Freeze() private abstract class ColumnCache : ColumnCache { - public ColumnCache(CacheDataView parent, IRowCursor input, int srcCol, OrderedWaiter waiter) + public ColumnCache(CacheDataView parent, RowCursor input, int srcCol, OrderedWaiter waiter) : base(parent._host, waiter) { Contracts.AssertValue(input); - Contracts.Assert(0 <= srcCol & srcCol < input.Schema.ColumnCount); - Contracts.Assert(input.Schema.GetColumnType(srcCol).RawType == typeof(T)); + Contracts.Assert(0 <= srcCol & srcCol < input.Schema.Count); + Contracts.Assert(input.Schema[srcCol].Type.RawType == typeof(T)); } /// diff --git a/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs b/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs index 4d79c0d6f5..a04b226a0d 100644 --- a/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs +++ b/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs @@ -3,10 +3,9 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// A row-to-row mapper that is the result of a chained application of multiple mappers. @@ -18,7 +17,7 @@ public sealed class CompositeRowToRowMapper : IRowToRowMapper private static readonly IRowToRowMapper[] _empty = new IRowToRowMapper[0]; public Schema InputSchema { get; } - public Schema Schema { get; } + public Schema OutputSchema { get; } /// /// Out of a series of mappers, construct a seemingly unitary mapper that is able to apply them in sequence. @@ -32,7 +31,7 @@ public CompositeRowToRowMapper(Schema inputSchema, IRowToRowMapper[] mappers) Contracts.CheckValueOrNull(mappers); InnerMappers = Utils.Size(mappers) > 0 ? mappers : _empty; InputSchema = inputSchema; - Schema = Utils.Size(mappers) > 0 ? mappers[mappers.Length - 1].Schema : inputSchema; + OutputSchema = Utils.Size(mappers) > 0 ? 
mappers[mappers.Length - 1].OutputSchema : inputSchema; } public Func GetDependencies(Func predicate) @@ -43,24 +42,23 @@ public Func GetDependencies(Func predicate) return toReturn; } - public IRow GetRow(IRow input, Func active, out Action disposer) + public Row GetRow(Row input, Func active) { Contracts.CheckValue(input, nameof(input)); Contracts.CheckValue(active, nameof(active)); Contracts.CheckParam(input.Schema == InputSchema, nameof(input), "Schema did not match original schema"); - disposer = null; if (InnerMappers.Length == 0) { bool differentActive = false; - for (int c = 0; c < input.Schema.ColumnCount; ++c) + for (int c = 0; c < input.Schema.Count; ++c) { bool wantsActive = active(c); bool isActive = input.IsColumnActive(c); differentActive |= wantsActive != isActive; if (wantsActive && !isActive) - throw Contracts.ExceptParam(nameof(input), $"Mapper required column '{input.Schema.GetColumnName(c)}' active but it was not."); + throw Contracts.ExceptParam(nameof(input), $"Mapper required column '{input.Schema[c].Name}' active but it was not."); } return input; } @@ -73,29 +71,19 @@ public IRow GetRow(IRow input, Func active, out Action disposer) for (int i = deps.Length - 1; i >= 1; --i) deps[i - 1] = InnerMappers[i].GetDependencies(deps[i]); - IRow result = input; + Row result = input; for (int i = 0; i < InnerMappers.Length; ++i) - { - result = InnerMappers[i].GetRow(result, deps[i], out var localDisp); - if (localDisp != null) - { - if (disposer == null) - disposer = localDisp; - else - disposer = localDisp + disposer; - // We want the last disposer to be called first, so the order of the addition here is important. 
- } - } + result = InnerMappers[i].GetRow(result, deps[i]); return result; } - private sealed class SubsetActive : IRow + private sealed class SubsetActive : Row { - private readonly IRow _row; + private readonly Row _row; private Func _pred; - public SubsetActive(IRow row, Func pred) + public SubsetActive(Row row, Func pred) { Contracts.AssertValue(row); Contracts.AssertValue(pred); @@ -103,12 +91,12 @@ public SubsetActive(IRow row, Func pred) _pred = pred; } - public Schema Schema => _row.Schema; - public long Position => _row.Position; - public long Batch => _row.Batch; - public ValueGetter GetGetter(int col) => _row.GetGetter(col); - public ValueGetter GetIdGetter() => _row.GetIdGetter(); - public bool IsColumnActive(int col) => _pred(col); + public override Schema Schema => _row.Schema; + public override long Position => _row.Position; + public override long Batch => _row.Batch; + public override ValueGetter GetGetter(int col) => _row.GetGetter(col); + public override ValueGetter GetIdGetter() => _row.GetIdGetter(); + public override bool IsColumnActive(int col) => _pred(col); } } } diff --git a/src/Microsoft.ML.Data/DataView/CompositeSchema.cs b/src/Microsoft.ML.Data/DataView/CompositeSchema.cs index 9739112c22..2d4ca1a49c 100644 --- a/src/Microsoft.ML.Data/DataView/CompositeSchema.cs +++ b/src/Microsoft.ML.Data/DataView/CompositeSchema.cs @@ -4,10 +4,10 @@ using System; using System.Collections.Generic; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using System.Linq; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// A convenience class for concatenating several schemas together. @@ -15,14 +15,14 @@ namespace Microsoft.ML.Runtime.Data /// internal sealed class CompositeSchema : ISchema { - private readonly ISchema[] _sources; + private readonly Schema[] _sources; public Schema AsSchema { get; } // Zero followed by cumulative column counts. 
Zero being used for the empty case. private readonly int[] _cumulativeColCounts; - public CompositeSchema(ISchema[] sources) + public CompositeSchema(Schema[] sources) { Contracts.AssertNonEmpty(sources); _sources = sources; @@ -32,7 +32,7 @@ public CompositeSchema(ISchema[] sources) for (int i = 0; i < sources.Length; i++) { var schema = sources[i]; - _cumulativeColCounts[i + 1] = _cumulativeColCounts[i] + schema.ColumnCount; + _cumulativeColCounts[i + 1] = _cumulativeColCounts[i] + schema.Count; } AsSchema = Schema.Create(this); } @@ -72,7 +72,7 @@ public void GetColumnSource(int col, out int srcIndex, out int srcCol) srcIndex--; Contracts.Assert(0 <= srcIndex && srcIndex < _cumulativeColCounts.Length); srcCol = col - _cumulativeColCounts[srcIndex]; - Contracts.Assert(0 <= srcCol && srcCol < _sources[srcIndex].ColumnCount); + Contracts.Assert(0 <= srcCol && srcCol < _sources[srcIndex].Count); } public bool TryGetColumnIndex(string name, out int col) @@ -93,31 +93,31 @@ public bool TryGetColumnIndex(string name, out int col) public string GetColumnName(int col) { GetColumnSource(col, out int dv, out int srcCol); - return _sources[dv].GetColumnName(srcCol); + return _sources[dv][srcCol].Name; } public ColumnType GetColumnType(int col) { GetColumnSource(col, out int dv, out int srcCol); - return _sources[dv].GetColumnType(srcCol); + return _sources[dv][srcCol].Type; } public IEnumerable> GetMetadataTypes(int col) { GetColumnSource(col, out int dv, out int srcCol); - return _sources[dv].GetMetadataTypes(srcCol); + return _sources[dv][srcCol].Metadata.Schema.Select(c => new KeyValuePair(c.Name, c.Type)); } public ColumnType GetMetadataTypeOrNull(string kind, int col) { GetColumnSource(col, out int dv, out int srcCol); - return _sources[dv].GetMetadataTypeOrNull(kind, srcCol); + return _sources[dv][srcCol].Metadata.Schema.GetColumnOrNull(kind)?.Type; } public void GetMetadata(string kind, int col, ref TValue value) { GetColumnSource(col, out int dv, out int srcCol); - 
_sources[dv].GetMetadata(kind, srcCol, ref value); + _sources[dv][srcCol].Metadata.GetValue(kind, ref value); } } } diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs similarity index 88% rename from src/Microsoft.ML.Api/DataViewConstructionUtils.cs rename to src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs index 253c86aa5b..4fb1cc7f30 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs @@ -2,17 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; using System; using System.Collections.Generic; using System.IO; -using System.Linq; using System.Reflection; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Api +namespace Microsoft.ML.Data { /// /// A helper class to create data views based on the user-provided types. @@ -76,7 +73,7 @@ public static IDataView LoadPipeWithPredictor(IHostEnvironment env, Stream model return pipe; } - public sealed class InputRow : InputRowBase, IRowBackedBy + public sealed class InputRow : InputRowBase where TRow : class { private TRow _value; @@ -110,12 +107,12 @@ public void ExtractValues(TRow row) _position++; } - public override ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { return IdGetter; } - private void IdGetter(ref UInt128 val) => val = new UInt128((ulong)Position, 0); + private void IdGetter(ref RowId val) => val = new RowId((ulong)Position, 0); protected override TRow GetCurrentRowObject() { @@ -125,22 +122,20 @@ protected override TRow GetCurrentRowObject() } /// - /// A row that consumes items of type , and provides an . 
This + /// A row that consumes items of type , and provides an . This /// is in contrast to which consumes a data view row and publishes them as the output type. /// /// The input data type. - public abstract class InputRowBase : IRow + public abstract class InputRowBase : Row where TRow : class { private readonly int _colCount; private readonly Delegate[] _getters; protected readonly IHost Host; - public long Batch => 0; + public override long Batch => 0; - public Schema Schema { get; } - - public abstract long Position { get; } + public override Schema Schema { get; } public InputRowBase(IHostEnvironment env, Schema schema, InternalSchemaDefinition schemaDef, Delegate[] peeks, Func predicate) { @@ -150,14 +145,14 @@ public InputRowBase(IHostEnvironment env, Schema schema, InternalSchemaDefinitio Host.AssertValue(schemaDef); Host.AssertValue(peeks); Host.AssertValue(predicate); - Host.Assert(schema.ColumnCount == schemaDef.Columns.Length); - Host.Assert(schema.ColumnCount == peeks.Length); + Host.Assert(schema.Count == schemaDef.Columns.Length); + Host.Assert(schema.Count == peeks.Length); - _colCount = schema.ColumnCount; + _colCount = schema.Count; Schema = schema; _getters = new Delegate[_colCount]; for (int c = 0; c < _colCount; c++) - _getters[c] = predicate(c) ? CreateGetter(schema.GetColumnType(c), schemaDef.Columns[c], peeks[c]) : null; + _getters[c] = predicate(c) ? 
CreateGetter(schema[c].Type, schemaDef.Columns[c], peeks[c]) : null; } //private Delegate CreateGetter(SchemaProxy schema, int index, Delegate peek) @@ -210,12 +205,12 @@ private Delegate CreateGetter(ColumnType colType, InternalSchemaDefinition.Colum else Host.Assert(colType.RawType == outputType); - if (!colType.IsKey) + if (!(colType is KeyType keyType)) del = CreateDirectGetterDelegate; else { var keyRawType = colType.RawType; - Host.Assert(colType.AsKey.Contiguous); + Host.Assert(keyType.Contiguous); Func delForKey = CreateKeyGetterDelegate; return Utils.MarshalInvoke(delForKey, keyRawType, peek, colType); } @@ -301,9 +296,10 @@ private Delegate CreateDirectGetterDelegate(Delegate peekDel) private Delegate CreateKeyGetterDelegate(Delegate peekDel, ColumnType colType) { // Make sure the function is dealing with key. - Host.Check(colType.IsKey); + KeyType keyType = colType as KeyType; + Host.Check(keyType != null); // Following equations work only with contiguous key type. - Host.Check(colType.AsKey.Contiguous); + Host.Check(keyType.Contiguous); // Following equations work only with unsigned integers. 
Host.Check(typeof(TDst) == typeof(ulong) || typeof(TDst) == typeof(uint) || typeof(TDst) == typeof(byte) || typeof(TDst) == typeof(bool)); @@ -314,8 +310,8 @@ private Delegate CreateKeyGetterDelegate(Delegate peekDel, ColumnType colT TDst rawKeyValue = default; ulong key = 0; // the raw key value as ulong - ulong min = colType.AsKey.Min; - ulong max = min + (ulong)colType.AsKey.Count - 1; + ulong min = keyType.Min; + ulong max = min + (ulong)keyType.Count - 1; ulong result = 0; // the result as ulong ValueGetter getter = (ref TDst dst) => { @@ -332,7 +328,7 @@ private Delegate CreateKeyGetterDelegate(Delegate peekDel, ColumnType colT protected abstract TRow GetCurrentRowObject(); - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { CheckColumnInRange(col); return _getters[col] != null; @@ -344,7 +340,7 @@ private void CheckColumnInRange(int columnIndex) throw Host.Except("Column index must be between 0 and {0}", _colCount); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { if (!IsColumnActive(col)) throw Host.Except("Column {0} is not active in the cursor", col); @@ -355,8 +351,6 @@ public ValueGetter GetGetter(int col) throw Host.Except("Invalid TValue in GetGetter for column #{0}: '{1}'", col, typeof(TValue)); return fn; } - - public abstract ValueGetter GetIdGetter(); } /// @@ -400,16 +394,40 @@ protected DataViewBase(IHostEnvironment env, string name, InternalSchemaDefiniti public abstract long? 
GetRowCount(); - public abstract IRowCursor GetRowCursor(Func predicate, IRandom rand = null); + public abstract RowCursor GetRowCursor(Func predicate, Random rand = null); - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, - int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { - consolidator = null; return new[] { GetRowCursor(predicate, rand) }; } - public abstract class DataViewCursorBase : InputRowBase, IRowCursor + public sealed class WrappedCursor : RowCursor + { + private readonly DataViewCursorBase _toWrap; + + public WrappedCursor(DataViewCursorBase toWrap) => _toWrap = toWrap; + + public override CursorState State => _toWrap.State; + public override long Position => _toWrap.Position; + public override long Batch => _toWrap.Batch; + public override Schema Schema => _toWrap.Schema; + + protected override void Dispose(bool disposing) + { + if (disposing) + _toWrap.Dispose(); + } + + public override ValueGetter GetGetter(int col) + => _toWrap.GetGetter(col); + public override ValueGetter GetIdGetter() => _toWrap.GetIdGetter(); + public override RowCursor GetRootCursor() => this; + public override bool IsColumnActive(int col) => _toWrap.IsColumnActive(col); + public override bool MoveMany(long count) => _toWrap.MoveMany(count); + public override bool MoveNext() => _toWrap.MoveNext(); + } + + public abstract class DataViewCursorBase : InputRowBase { // There is no real concept of multiple inheritance and for various reasons it was better to // descend from the row class as opposed to wrapping it, so much of this class is regrettably @@ -417,8 +435,8 @@ public abstract class DataViewCursorBase : InputRowBase, IRowCursor protected readonly DataViewBase DataView; protected readonly IChannel Ch; - private long _position; + /// /// Zero-based position of the cursor. 
/// @@ -445,14 +463,14 @@ protected DataViewCursorBase(IHostEnvironment env, DataViewBase dataView, /// protected bool IsGood => State == CursorState.Good; - public virtual void Dispose() + protected sealed override void Dispose(bool disposing) { - if (State != CursorState.Done) - { - Ch.Dispose(); - _position = -1; - State = CursorState.Done; - } + if (State == CursorState.Done) + return; + Ch.Dispose(); + _position = -1; + base.Dispose(disposing); + State = CursorState.Done; } public bool MoveNext() @@ -524,14 +542,6 @@ protected virtual bool MoveManyCore(long count) /// . /// protected abstract bool MoveNextCore(); - - /// - /// Returns a cursor that can be used for invoking , , - /// , and , with results identical to calling - /// those on this cursor. Generally, if the root cursor is not the same as this cursor, using - /// the root cursor will be faster. - /// - public ICursor GetRootCursor() => this; } } @@ -561,10 +571,10 @@ public override bool CanShuffle return _data.Count; } - public override IRowCursor GetRowCursor(Func predicate, IRandom rand = null) + public override RowCursor GetRowCursor(Func predicate, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); - return new Cursor(Host, "ListDataView", this, predicate, rand); + return new WrappedCursor(new Cursor(Host, "ListDataView", this, predicate, rand)); } private sealed class Cursor : DataViewCursorBase @@ -578,7 +588,7 @@ private int Index } public Cursor(IHostEnvironment env, string name, ListDataView dataView, - Func predicate, IRandom rand) + Func predicate, Random rand) : base(env, dataView, predicate) { Ch.AssertValueOrNull(rand); @@ -587,24 +597,24 @@ public Cursor(IHostEnvironment env, string name, ListDataView dataView, _permutation = Utils.GetRandomPermutation(rand, dataView._data.Count); } - public override ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { if (_permutation == null) { return - (ref UInt128 val) => + (ref RowId val) => { 
Ch.Check(IsGood, "Cannot call ID getter in current state"); - val = new UInt128((ulong)Position, 0); + val = new RowId((ulong)Position, 0); }; } else { return - (ref UInt128 val) => + (ref RowId val) => { Ch.Check(IsGood, "Cannot call ID getter in current state"); - val = new UInt128((ulong)Index, 0); + val = new RowId((ulong)Index, 0); }; } } @@ -660,9 +670,9 @@ public override bool CanShuffle return (_data as ICollection)?.Count; } - public override IRowCursor GetRowCursor(Func predicate, IRandom rand = null) + public override RowCursor GetRowCursor(Func predicate, Random rand = null) { - return new Cursor(Host, this, predicate); + return new WrappedCursor (new Cursor(Host, this, predicate)); } /// @@ -677,7 +687,7 @@ public void SetData(IEnumerable data) _data = data; } - private class Cursor : DataViewCursorBase + private sealed class Cursor : DataViewCursorBase { private readonly IEnumerator _enumerator; private TRow _currentRow; @@ -689,13 +699,13 @@ public Cursor(IHostEnvironment env, StreamingDataView dataView, Func GetIdGetter() + public override ValueGetter GetIdGetter() { return - (ref UInt128 val) => + (ref RowId val) => { Ch.Check(IsGood, "Cannot call ID getter in current state"); - val = new UInt128((ulong)Position, 0); + val = new RowId((ulong)Position, 0); }; } @@ -731,15 +741,9 @@ public SingleRowLoopDataView(IHostEnvironment env, InternalSchemaDefinition sche { } - public override bool CanShuffle - { - get { return false; } - } + public override bool CanShuffle => false; - public override long? GetRowCount() - { - return null; - } + public override long? 
GetRowCount() => null; public void SetCurrentRowObject(TRow value) { @@ -747,10 +751,10 @@ public void SetCurrentRowObject(TRow value) _current = value; } - public override IRowCursor GetRowCursor(Func predicate, IRandom rand = null) + public override RowCursor GetRowCursor(Func predicate, Random rand = null) { Contracts.Assert(_current != null, "The current object must be set prior to cursoring"); - return new Cursor(Host, this, predicate); + return new WrappedCursor (new Cursor(Host, this, predicate)); } private sealed class Cursor : DataViewCursorBase @@ -763,20 +767,17 @@ public Cursor(IHostEnvironment env, SingleRowLoopDataView dataView, Func GetIdGetter() + public override ValueGetter GetIdGetter() { return - (ref UInt128 val) => + (ref RowId val) => { Ch.Check(IsGood, "Cannot call ID getter in current state"); - val = new UInt128((ulong)Position, 0); + val = new RowId((ulong)Position, 0); }; } - protected override TRow GetCurrentRowObject() - { - return _currentRow; - } + protected override TRow GetCurrentRowObject() => _currentRow; protected override bool MoveNextCore() { @@ -792,6 +793,7 @@ protected override bool MoveManyCore(long count) } } + [BestFriend] internal static Schema.DetachedColumn[] GetSchemaColumns(InternalSchemaDefinition schemaDefn) { Contracts.AssertValue(schemaDefn); @@ -946,7 +948,16 @@ public override ValueGetter GetGetter() throw Contracts.ExceptNotImpl("Type '{0}' is not yet supported.", typeT.FullName); } - internal override Delegate GetGetterDelegate() => Utils.MarshalInvoke(GetGetter, MetadataType.RawType); + // We want to use MarshalInvoke instead of adding custom Reflection logic for calling GetGetter + private Delegate GetGetterCore() + { + return GetGetter(); + } + + internal override Delegate GetGetterDelegate() + { + return Utils.MarshalInvoke(GetGetterCore, MetadataType.RawType); + } public class TElement { diff --git a/src/Microsoft.ML.Data/DataView/EmptyDataView.cs b/src/Microsoft.ML.Data/DataView/EmptyDataView.cs index 
6e0d779958..275c0ca390 100644 --- a/src/Microsoft.ML.Data/DataView/EmptyDataView.cs +++ b/src/Microsoft.ML.Data/DataView/EmptyDataView.cs @@ -3,10 +3,9 @@ // See the LICENSE file in the project root for more information. using System; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// This implements a data view that has a schema, but no rows. @@ -28,26 +27,25 @@ public EmptyDataView(IHostEnvironment env, Schema schema) public long? GetRowCount() => 0; - public IRowCursor GetRowCursor(Func needCol, IRandom rand = null) + public RowCursor GetRowCursor(Func needCol, Random rand = null) { _host.CheckValue(needCol, nameof(needCol)); _host.CheckValueOrNull(rand); return new Cursor(_host, Schema, needCol); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func needCol, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func needCol, int n, Random rand = null) { _host.CheckValue(needCol, nameof(needCol)); _host.CheckValueOrNull(rand); - consolidator = null; return new[] { new Cursor(_host, Schema, needCol) }; } - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private readonly bool[] _active; - public Schema Schema { get; } + public override Schema Schema { get; } public override long Batch => 0; public Cursor(IChannelProvider provider, Schema schema, Func needCol) @@ -56,13 +54,13 @@ public Cursor(IChannelProvider provider, Schema schema, Func needCol) Ch.AssertValue(schema); Ch.AssertValue(needCol); Schema = schema; - _active = Utils.BuildArray(Schema.ColumnCount, needCol); + _active = Utils.BuildArray(Schema.Count, needCol); } - public override ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { return - (ref UInt128 val) => + (ref RowId val) => { Ch.Assert(!IsGood); throw Ch.Except("Cannot call ID getter in 
current state"); @@ -71,9 +69,9 @@ public override ValueGetter GetIdGetter() protected override bool MoveNextCore() => false; - public bool IsColumnActive(int col) => 0 <= col && col < _active.Length && _active[col]; + public override bool IsColumnActive(int col) => 0 <= col && col < _active.Length && _active[col]; - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col), "Can't get getter for inactive column"); return diff --git a/src/Microsoft.ML.Api/InternalSchemaDefinition.cs b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs similarity index 98% rename from src/Microsoft.ML.Api/InternalSchemaDefinition.cs rename to src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs index f39c274df0..9c400cefcc 100644 --- a/src/Microsoft.ML.Api/InternalSchemaDefinition.cs +++ b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs @@ -6,16 +6,14 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; -using Microsoft.ML.Runtime.Data; -using static Microsoft.ML.Runtime.Api.SchemaDefinition; -namespace Microsoft.ML.Runtime.Api +namespace Microsoft.ML.Data { using Conditional = System.Diagnostics.ConditionalAttribute; - /// /// An internal class that holds the (already validated) mapping between a custom type and an IDataView schema. 
/// + [BestFriend] internal sealed class InternalSchemaDefinition { public readonly Column[] Columns; @@ -208,7 +206,7 @@ public static void GetVectorAndKind(Type rawType, string name, out bool isVector throw Contracts.ExceptParam(nameof(rawType), "Could not determine an IDataView type for member {0}", name); } - public static InternalSchemaDefinition Create(Type userType, Direction direction) + public static InternalSchemaDefinition Create(Type userType, SchemaDefinition.Direction direction) { var userSchemaDefinition = SchemaDefinition.Create(userType, direction); return Create(userType, userSchemaDefinition); diff --git a/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs b/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs index 817ba98ed1..a49318a63c 100644 --- a/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs +++ b/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs @@ -4,16 +4,17 @@ using System; using System.Reflection; -using Microsoft.ML.Runtime.Data.Conversion; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Data.Conversion; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// This applies the user provided ValueMapper to a column to produce a new column. It automatically /// injects a standard conversion from the actual type of the source column to typeSrc (if needed). /// - public static class LambdaColumnMapper + [BestFriend] + internal static class LambdaColumnMapper { // REVIEW: It would be nice to support propagation of select metadata. 
public static IDataView Create(IHostEnvironment env, string name, IDataView input, @@ -45,7 +46,7 @@ public static IDataView Create(IHostEnvironment env, string name, ID bool tmp = input.Schema.TryGetColumnIndex(src, out int colSrc); if (!tmp) throw env.ExceptParam(nameof(src), "The input data doesn't have a column named '{0}'", src); - var typeOrig = input.Schema.GetColumnType(colSrc); + var typeOrig = input.Schema[colSrc].Type; // REVIEW: Ideally this should support vector-type conversion. It currently doesn't. bool ident; @@ -150,7 +151,7 @@ protected override ColumnType GetColumnTypeCore(int iinfo) return _typeDst; } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + protected override Delegate GetGetterCore(IChannel ch, Row input, int iinfo, out Action disposer) { Host.AssertValueOrNull(ch); Host.AssertValue(input); diff --git a/src/Microsoft.ML.Data/DataView/LambdaFilter.cs b/src/Microsoft.ML.Data/DataView/LambdaFilter.cs index d778ddbc11..848a650fa9 100644 --- a/src/Microsoft.ML.Data/DataView/LambdaFilter.cs +++ b/src/Microsoft.ML.Data/DataView/LambdaFilter.cs @@ -4,10 +4,10 @@ using System; using System.Reflection; -using Microsoft.ML.Runtime.Data.Conversion; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Data.Conversion; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// This applies the user provided RefPredicate to a column and drops rows that map to false. It automatically @@ -35,7 +35,7 @@ public static IDataView Create(IHostEnvironment env, string name, IDataVie bool tmp = input.Schema.TryGetColumnIndex(src, out colSrc); if (!tmp) throw env.ExceptParam(nameof(src), "The input data doesn't have a column named '{0}'", src); - var typeOrig = input.Schema.GetColumnType(colSrc); + var typeOrig = input.Schema[colSrc].Type; // REVIEW: Ideally this should support vector-type conversion. It currently doesn't. 
bool ident; @@ -86,7 +86,7 @@ public Impl(IHostEnvironment env, string name, IDataView input, { Host.AssertValue(pred); Host.Assert(conv != null | typeof(T1) == typeof(T2)); - Host.Assert(0 <= colSrc & colSrc < Source.Schema.ColumnCount); + Host.Assert(0 <= colSrc & colSrc < Source.Schema.Count); _colSrc = colSrc; _pred = pred; @@ -106,7 +106,7 @@ public override void Save(ModelSaveContext ctx) return null; } - protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate, "predicate"); Host.AssertValueOrNull(rand); @@ -114,32 +114,31 @@ protected override IRowCursor GetRowCursorCore(Func predicate, IRando bool[] active; Func inputPred = GetActive(predicate, out active); var input = Source.GetRowCursor(inputPred, rand); - return new RowCursor(this, input, active); + return new Cursor(this, input, active); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null) + public override RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); bool[] active; Func inputPred = GetActive(predicate, out active); - var inputs = Source.GetRowCursorSet(out consolidator, inputPred, n, rand); + var inputs = Source.GetRowCursorSet(inputPred, n, rand); Host.AssertNonEmpty(inputs); // No need to split if this is given 1 input cursor. 
- var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(this, inputs[i], active); + cursors[i] = new Cursor(this, inputs[i], active); return cursors; } private Func GetActive(Func predicate, out bool[] active) { Host.AssertValue(predicate); - active = new bool[Source.Schema.ColumnCount]; - bool[] activeInput = new bool[Source.Schema.ColumnCount]; + active = new bool[Source.Schema.Count]; + bool[] activeInput = new bool[Source.Schema.Count]; for (int i = 0; i < active.Length; i++) activeInput[i] = active[i] = predicate(i); activeInput[_colSrc] = true; @@ -147,14 +146,14 @@ private Func GetActive(Func predicate, out bool[] active) } // REVIEW: Should this cache the source value like MissingValueFilter does? - private sealed class RowCursor : LinkedRowFilterCursorBase + private sealed class Cursor : LinkedRowFilterCursorBase { private readonly ValueGetter _getSrc; private readonly InPredicate _pred; private T1 _src; - public RowCursor(Impl parent, IRowCursor input, bool[] active) - : base(parent.Host, input, parent.Schema, active) + public Cursor(Impl parent, RowCursor input, bool[] active) + : base(parent.Host, input, parent.OutputSchema, active) { _getSrc = Input.GetGetter(parent._colSrc); if (parent._conv == null) diff --git a/src/Microsoft.ML.Data/DataView/OpaqueDataView.cs b/src/Microsoft.ML.Data/DataView/OpaqueDataView.cs index 1172142972..456b98d9fe 100644 --- a/src/Microsoft.ML.Data/DataView/OpaqueDataView.cs +++ b/src/Microsoft.ML.Data/DataView/OpaqueDataView.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; using System; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Opaque IDataView implementation to provide a barrier for data pipe optimizations. 
@@ -27,15 +26,14 @@ public OpaqueDataView(IDataView source) return _source.GetRowCount(); } - public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { return _source.GetRowCursor(predicate, rand); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { - return _source.GetRowCursorSet(out consolidator, predicate, n, rand); + return _source.GetRowCursorSet(predicate, n, rand); } } } diff --git a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs index 2de71f0a78..bc697376d8 100644 --- a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs +++ b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs @@ -3,30 +3,28 @@ // See the LICENSE file in the project root for more information. using System; -using System.Collections.Generic; using System.IO; using System.Linq; -using System.Reflection; +using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Model.Pfa; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Model.Pfa; [assembly: LoadableClass(typeof(RowToRowMapperTransform), null, typeof(SignatureLoadDataTransform), "", RowToRowMapperTransform.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// This interface is used to create a . - /// Implementations should be given an in their constructor, and should have a + /// Implementations should be given an in their constructor, and should have a /// ctor or Create method with , along with a corresponding /// . 
/// - public interface IRowMapper : ICanSaveModel + [BestFriend] + internal interface IRowMapper : ICanSaveModel { /// /// Returns the input columns needed for the requested output columns. @@ -36,9 +34,11 @@ public interface IRowMapper : ICanSaveModel /// /// Returns the getters for the output columns given an active set of output columns. The length of the getters /// array should be equal to the number of columns added by the IRowMapper. It should contain the getter for the - /// i'th output column if activeOutput(i) is true, and null otherwise. + /// i'th output column if activeOutput(i) is true, and null otherwise. If creating a or + /// out of this, the delegate (if non-null) should be called + /// from the dispose of either of those instances. /// - Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer); + Delegate[] CreateGetters(Row input, Func activeOutput, out Action disposer); /// /// Returns information about the output columns, including their name, type and any metadata information. @@ -46,7 +46,7 @@ public interface IRowMapper : ICanSaveModel Schema.DetachedColumn[] GetOutputColumns(); } - public delegate void SignatureLoadRowMapper(ModelLoadContext ctx, ISchema schema); + public delegate void SignatureLoadRowMapper(ModelLoadContext ctx, Schema schema); /// /// This class is a transform that can add any number of output columns, that depend on any number of input columns. @@ -75,27 +75,29 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(RowToRowMapperTransform).Assembly.FullName); } - public override Schema Schema => _bindings.Schema; + public override Schema OutputSchema => _bindings.Schema; bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => _mapper is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false; bool ICanSavePfa.CanSavePfa => _mapper is ICanSavePfa pfaMapper ? 
pfaMapper.CanSavePfa : false; - public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper mapper, Func mapperFactory) + [BestFriend] + internal RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper mapper, Func mapperFactory) : base(env, RegistrationName, input) { Contracts.CheckValue(mapper, nameof(mapper)); Contracts.CheckValueOrNull(mapperFactory); _mapper = mapper; _mapperFactory = mapperFactory; - _bindings = new ColumnBindings(Schema.Create(input.Schema), mapper.GetOutputColumns()); + _bindings = new ColumnBindings(input.Schema, mapper.GetOutputColumns()); } - public static Schema GetOutputSchema(ISchema inputSchema, IRowMapper mapper) + [BestFriend] + internal static Schema GetOutputSchema(Schema inputSchema, IRowMapper mapper) { Contracts.CheckValue(inputSchema, nameof(inputSchema)); Contracts.CheckValue(mapper, nameof(mapper)); - return new ColumnBindings(Schema.Create(inputSchema), mapper.GetOutputColumns()).Schema; + return new ColumnBindings(inputSchema, mapper.GetOutputColumns()).Schema; } private RowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView input) @@ -105,7 +107,7 @@ private RowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView inpu // _mapper ctx.LoadModel(host, out _mapper, "Mapper", input.Schema); - _bindings = new ColumnBindings(Schema.Create(input.Schema), _mapper.GetOutputColumns()); + _bindings = new ColumnBindings(input.Schema, _mapper.GetOutputColumns()); } public static RowToRowMapperTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) @@ -137,12 +139,12 @@ public override void Save(ModelSaveContext ctx) /// private bool[] GetActive(Func predicate, out Func predicateInput) { - int n = _bindings.Schema.ColumnCount; + int n = _bindings.Schema.Count; var active = Utils.BuildArray(n, predicate); Contracts.Assert(active.Length == n); var activeInput = _bindings.GetActiveInput(predicate); - Contracts.Assert(activeInput.Length == 
_bindings.InputSchema.ColumnCount); + Contracts.Assert(activeInput.Length == _bindings.InputSchema.Count); // Get a predicate that determines which outputs are active. var predicateOut = GetActiveOutputColumns(active); @@ -160,7 +162,7 @@ private bool[] GetActive(Func predicate, out Func predicat private Func GetActiveOutputColumns(bool[] active) { Contracts.AssertValue(active); - Contracts.Assert(active.Length == _bindings.Schema.ColumnCount); + Contracts.Assert(active.Length == _bindings.Schema.Count); return col => @@ -178,14 +180,14 @@ private Func GetActiveOutputColumns(bool[] active) return null; } - protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Func predicateInput; var active = GetActive(predicate, out predicateInput); - return new RowCursor(Host, Source.GetRowCursor(predicateInput, rand), this, active); + return new Cursor(Host, Source.GetRowCursor(predicateInput, rand), this, active); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, IRandom rand = null) + public override RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); @@ -193,16 +195,16 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid Func predicateInput; var active = GetActive(predicate, out predicateInput); - var inputs = Source.GetRowCursorSet(out consolidator, predicateInput, n, rand); + var inputs = Source.GetRowCursorSet(predicateInput, n, rand); Host.AssertNonEmpty(inputs); if (inputs.Length == 1 && n > 1 && _bindings.AddedColumnIndices.Any(predicate)) - inputs = DataViewUtils.CreateSplitCursors(out consolidator, Host, inputs[0], n); + inputs = DataViewUtils.CreateSplitCursors(Host, inputs[0], n); Host.AssertNonEmpty(inputs); - var cursors = new IRowCursor[inputs.Length]; + var 
cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(Host, inputs[i], this, active); + cursors[i] = new Cursor(Host, inputs[i], this, active); return cursors; } @@ -233,25 +235,22 @@ public Func GetDependencies(Func predicate) return predicateInput; } - Schema IRowToRowMapper.InputSchema => Source.Schema; + public Schema InputSchema => Source.Schema; - public IRow GetRow(IRow input, Func active, out Action disposer) + public Row GetRow(Row input, Func active) { Host.CheckValue(input, nameof(input)); Host.CheckValue(active, nameof(active)); Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to"); - disposer = null; using (var ch = Host.Start("GetEntireRow")) { - Action disp; - var activeArr = new bool[Schema.ColumnCount]; - for (int i = 0; i < Schema.ColumnCount; i++) + var activeArr = new bool[OutputSchema.Count]; + for (int i = 0; i < OutputSchema.Count; i++) activeArr[i] = active(i); var pred = GetActiveOutputColumns(activeArr); - var getters = _mapper.CreateGetters(input, pred, out disp); - disposer += disp; - return new Row(input, this, Schema, getters); + var getters = _mapper.CreateGetters(input, pred, out Action disp); + return new RowImpl(input, this, OutputSchema, getters, disp); } } @@ -285,33 +284,35 @@ public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource) } } - private sealed class Row : IRow + private sealed class RowImpl : WrappingRow { - private readonly IRow _input; private readonly Delegate[] _getters; - private readonly RowToRowMapperTransform _parent; + private readonly Action _disposer; - public long Batch { get { return _input.Batch; } } - - public long Position { get { return _input.Position; } } - - public Schema Schema { get; } + public override Schema Schema { get; } - public Row(IRow input, RowToRowMapperTransform parent, Schema schema, Delegate[] getters) + public RowImpl(Row input, 
RowToRowMapperTransform parent, Schema schema, Delegate[] getters, Action disposer) + : base(input) { - _input = input; _parent = parent; Schema = schema; _getters = getters; + _disposer = disposer; } - public ValueGetter GetGetter(int col) + protected override void DisposeCore(bool disposing) + { + if (disposing) + _disposer?.Invoke(); + } + + public override ValueGetter GetGetter(int col) { bool isSrc; int index = _parent._bindings.MapColumnIndex(out isSrc, col); if (isSrc) - return _input.GetGetter(index); + return Input.GetGetter(index); Contracts.Assert(_getters[index] != null); var fn = _getters[index] as ValueGetter; @@ -320,28 +321,27 @@ public ValueGetter GetGetter(int col) return fn; } - public ValueGetter GetIdGetter() => _input.GetIdGetter(); - - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { bool isSrc; int index = _parent._bindings.MapColumnIndex(out isSrc, col); if (isSrc) - return _input.IsColumnActive((index)); + return Input.IsColumnActive((index)); return _getters[index] != null; } } - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly Delegate[] _getters; private readonly bool[] _active; private readonly ColumnBindings _bindings; private readonly Action _disposer; + private bool _disposed; - public Schema Schema => _bindings.Schema; + public override Schema Schema => _bindings.Schema; - public RowCursor(IChannelProvider provider, IRowCursor input, RowToRowMapperTransform parent, bool[] active) + public Cursor(IChannelProvider provider, RowCursor input, RowToRowMapperTransform parent, bool[] active) : base(provider, input) { var pred = parent.GetActiveOutputColumns(active); @@ -350,13 +350,13 @@ public RowCursor(IChannelProvider provider, IRowCursor input, RowToRowMapperTran _bindings = parent._bindings; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { - Ch.Check(0 <= col && col < 
_bindings.Schema.ColumnCount); + Ch.Check(0 <= col && col < _bindings.Schema.Count); return _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); @@ -374,10 +374,14 @@ public ValueGetter GetGetter(int col) return fn; } - public override void Dispose() + protected override void Dispose(bool disposing) { - _disposer?.Invoke(); - base.Dispose(); + if (_disposed) + return; + if (disposing) + _disposer?.Invoke(); + _disposed = true; + base.Dispose(disposing); } } } diff --git a/src/Microsoft.ML.Data/DataView/SimpleRow.cs b/src/Microsoft.ML.Data/DataView/SimpleRow.cs index 7cc8ec7059..0e37b2f0f1 100644 --- a/src/Microsoft.ML.Data/DataView/SimpleRow.cs +++ b/src/Microsoft.ML.Data/DataView/SimpleRow.cs @@ -5,57 +5,64 @@ using System; using System.Collections.Generic; using System.Linq; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// - /// An implementation of that gets its , , - /// and from an input row. The constructor requires a schema and array of getter - /// delegates. A null delegate indicates an inactive column. The delegates are assumed to be of the appropriate type - /// (this does not validate the type). + /// An implementation of that gets its , , + /// and from an input row. The constructor requires a schema and array of getter + /// delegates. A delegate indicates an inactive column. The delegates are assumed to be + /// of the appropriate type (this does not validate the type). /// REVIEW: Should this validate that the delegates are of the appropriate type? It wouldn't be difficult /// to do so. 
/// - public sealed class SimpleRow : IRow + [BestFriend] + internal sealed class SimpleRow : WrappingRow { - private readonly Schema _schema; - private readonly IRow _input; private readonly Delegate[] _getters; - - public Schema Schema { get { return _schema; } } - - public long Position { get { return _input.Position; } } - - public long Batch { get { return _input.Batch; } } - - public SimpleRow(Schema schema, IRow input, Delegate[] getters) + private readonly Action _disposer; + + public override Schema Schema { get; } + + /// + /// Constructor. + /// + /// The schema for the row. + /// The row that is being wrapped by this row, where our , + /// , . + /// The collection of getter delegates, whose types should map those in a schema. + /// If one of these is , the corresponding column is considered inactive. + /// A method that, if non-null, will be called exactly once during + /// , prior to disposing . + public SimpleRow(Schema schema, Row input, Delegate[] getters, Action disposer = null) + : base(input) { Contracts.CheckValue(schema, nameof(schema)); Contracts.CheckValue(input, nameof(input)); - Contracts.Check(Utils.Size(getters) == schema.ColumnCount); - _schema = schema; - _input = input; + Contracts.Check(Utils.Size(getters) == schema.Count); + Contracts.CheckValueOrNull(disposer); + Schema = schema; _getters = getters ?? 
new Delegate[0]; + _disposer = disposer; } - public ValueGetter GetIdGetter() + protected override void DisposeCore(bool disposing) { - return _input.GetIdGetter(); + if (disposing) + _disposer?.Invoke(); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Contracts.CheckParam(0 <= col && col < _getters.Length, nameof(col), "Invalid col value in GetGetter"); Contracts.Check(IsColumnActive(col)); - var fn = _getters[col] as ValueGetter; - if (fn == null) - throw Contracts.Except("Unexpected TValue in GetGetter"); - return fn; + if (_getters[col] is ValueGetter fn) + return fn; + throw Contracts.Except("Unexpected TValue in GetGetter"); } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Contracts.Check(0 <= col && col < _getters.Length); return _getters[col] != null; @@ -136,68 +143,6 @@ public void GetMetadata(string kind, int col, ref TValue value) protected abstract void GetMetadataCore(string kind, int col, ref TValue value); } - /// - /// An that takes all column names and types as constructor parameters. - /// The columns can optionally have text metadata. 
- /// - public sealed class SimpleSchema : SimpleSchemaBase - { - private readonly MetadataUtils.MetadataGetter>>[] _keyValueGetters; - - public SimpleSchema(IExceptionContext ectx, params KeyValuePair[] columns) - : base(ectx, columns) - { - _keyValueGetters = new MetadataUtils.MetadataGetter>>[ColumnCount]; - } - - public SimpleSchema(IExceptionContext ectx, KeyValuePair[] columns, - Dictionary>>> keyValues) - : this(ectx, columns) - { - foreach (var kvp in keyValues) - { - var name = kvp.Key; - var getter = kvp.Value; - if (!ColumnNameMap.TryGetValue(name, out int col)) - throw Ectx.ExceptParam(nameof(keyValues), $"Output schema does not contain column '{name}'"); - if (!Types[col].ItemType.IsKey) - throw Ectx.ExceptParam(nameof(keyValues), $"Column '{name}' is not a key column, so it cannot have key value metadata"); - _keyValueGetters[col] = getter; - } - } - - protected override IEnumerable> GetMetadataTypesCore(int col) - { - Ectx.Assert(0 <= col && col < ColumnCount); - if (_keyValueGetters[col] != null) - { - Ectx.Assert(Types[col].ItemType.IsKey); - yield return new KeyValuePair(MetadataUtils.Kinds.KeyValues, - new VectorType(TextType.Instance, Types[col].ItemType.KeyCount)); - } - } - - protected override ColumnType GetMetadataTypeOrNullCore(string kind, int col) - { - Ectx.Assert(0 <= col && col < ColumnCount); - if (kind == MetadataUtils.Kinds.KeyValues && _keyValueGetters[col] != null) - { - Ectx.Assert(Types[col].ItemType.IsKey); - return new VectorType(TextType.Instance, Types[col].ItemType.KeyCount); - } - return null; - } - - protected override void GetMetadataCore(string kind, int col, ref TValue value) - { - Ectx.Assert(0 <= col && col < ColumnCount); - if (kind == MetadataUtils.Kinds.KeyValues && _keyValueGetters[col] != null) - _keyValueGetters[col].Marshal(col, ref value); - else - throw Ectx.ExceptGetMetadata(); - } - } - public static class SimpleSchemaUtils { public static Schema Create(IExceptionContext ectx, params KeyValuePair[] 
columns) @@ -210,5 +155,4 @@ public static Schema Create(IExceptionContext ectx, params KeyValuePair /// This provides a scalable method of getting a "transposed" view of a subset of columns from an @@ -21,7 +19,8 @@ namespace Microsoft.ML.Runtime.Data /// were not transposable before. Note that transposition is a somewhat slow and resource intensive /// operation. /// - public sealed class Transposer : ITransposeDataView, IDisposable + [BestFriend] + internal sealed class Transposer : ITransposeDataView, IDisposable { private readonly IHost _host; // The input view. @@ -35,7 +34,7 @@ public sealed class Transposer : ITransposeDataView, IDisposable public readonly int RowCount; // -1 for input columns that were not transposed, a non-negative index into _cols for those that were. private readonly int[] _inputToTransposed; - private readonly ColumnInfo[] _cols; + private readonly Schema.Column[] _cols; private readonly int[] _splitLim; private readonly SchemaImpl _tschema; private bool _disposed; @@ -103,13 +102,13 @@ private Transposer(IHost host, IDataView view, bool forceSave, int[] columns) columnSet = columnSet.Where(c => ttschema.GetSlotType(c) == null); } columns = columnSet.ToArray(); - _cols = new ColumnInfo[columns.Length]; + _cols = new Schema.Column[columns.Length]; var schema = _view.Schema; _nameToICol = new Dictionary(); - _inputToTransposed = Utils.CreateArray(schema.ColumnCount, -1); + _inputToTransposed = Utils.CreateArray(schema.Count, -1); for (int c = 0; c < columns.Length; ++c) { - _nameToICol[(_cols[c] = ColumnInfo.CreateFromIndex(schema, columns[c])).Name] = c; + _nameToICol[(_cols[c] = schema[columns[c]]).Name] = c; _inputToTransposed[columns[c]] = c; } @@ -134,7 +133,7 @@ private Transposer(IHost host, IDataView view, bool forceSave, int[] columns) // since it would be strange if the same type failed or not in the // transposer depending on the size. At least as a user, that would // surprise me. Also I expect this to never happen... 
- var type = schema.GetColumnType(_cols[c].Index); + var type = schema[_cols[c].Index].Type; if (!saver.IsColumnSavable(type)) throw ch.ExceptParam(nameof(view), "Column named '{0}' is not serializable by the transposer", _cols[c].Name); if (type.IsVector && !type.IsKnownSizeVector) @@ -143,7 +142,7 @@ private Transposer(IHost host, IDataView view, bool forceSave, int[] columns) var slicer = new DataViewSlicer(_host, view, columns); var slicerSchema = slicer.Schema; - ch.Assert(Enumerable.Range(0, slicerSchema.ColumnCount).All(c => saver.IsColumnSavable(slicerSchema.GetColumnType(c)))); + ch.Assert(Enumerable.Range(0, slicerSchema.Count).All(c => saver.IsColumnSavable(slicerSchema[c].Type))); _splitLim = new int[_cols.Length]; List toSave = new List(); int offset = 0; @@ -223,13 +222,13 @@ private static int[] CheckIndices(IHost host, IDataView view, int[] columns) var schema = view.Schema; for (int c = 0; c < columns.Length; ++c) { - if (!(0 <= columns[c] && columns[c] < schema.ColumnCount)) - throw host.ExceptParam(nameof(columns), "Column index {0} illegal for data with {1} column", columns[c], schema.ColumnCount); + if (!(0 <= columns[c] && columns[c] < schema.Count)) + throw host.ExceptParam(nameof(columns), "Column index {0} illegal for data with {1} column", columns[c], schema.Count); } return columns; } - public ISlotCursor GetSlotCursor(int col) + public SlotCursor GetSlotCursor(int col) { _host.CheckParam(0 <= col && col < _tschema.ColumnCount, nameof(col)); if (_inputToTransposed[col] == -1) @@ -249,7 +248,7 @@ public ISlotCursor GetSlotCursor(int col) return Utils.MarshalInvoke(GetSlotCursorCore, type, col); } - private ISlotCursor GetSlotCursorCore(int col) + private SlotCursor GetSlotCursorCore(int col) { if (_tschema.GetColumnType(col).IsVector) return new SlotCursorVec(this, col); @@ -265,14 +264,14 @@ private ISlotCursor GetSlotCursorCore(int col) public bool CanShuffle { get { return _view.CanShuffle; } } - public IRowCursor GetRowCursor(Func 
predicate, IRandom rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { return _view.GetRowCursor(predicate, rand); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { - return _view.GetRowCursorSet(out consolidator, predicate, n, rand); + return _view.GetRowCursorSet(predicate, n, rand); } public long? GetRowCount() @@ -292,7 +291,7 @@ private sealed class SchemaImpl : ITransposeSchema public Schema AsSchema { get; } - public int ColumnCount { get { return InputSchema.ColumnCount; } } + public int ColumnCount { get { return InputSchema.Count; } } public SchemaImpl(Transposer parent) { @@ -304,10 +303,11 @@ public SchemaImpl(Transposer parent) _slotTypes = new VectorType[_parent._cols.Length]; for (int c = 0; c < _slotTypes.Length; ++c) { - ColumnInfo srcInfo = _parent._cols[c]; + var srcInfo = _parent._cols[c]; var ctype = srcInfo.Type.ItemType; - _ectx.Assert(ctype.IsPrimitive); - _slotTypes[c] = new VectorType(ctype.AsPrimitive, _parent.RowCount); + var primitiveType = ctype as PrimitiveType; + _ectx.Assert(primitiveType != null); + _slotTypes[c] = new VectorType(primitiveType, _parent.RowCount); } AsSchema = Schema.Create(this); @@ -321,27 +321,27 @@ public bool TryGetColumnIndex(string name, out int col) public string GetColumnName(int col) { - return InputSchema.GetColumnName(col); + return InputSchema[col].Name; } public ColumnType GetColumnType(int col) { - return InputSchema.GetColumnType(col); + return InputSchema[col].Type; } public IEnumerable> GetMetadataTypes(int col) { - return InputSchema.GetMetadataTypes(col); + return InputSchema[col].Metadata.Schema.Select(c => new KeyValuePair(c.Name, c.Type)); } public ColumnType GetMetadataTypeOrNull(string kind, int col) { - return InputSchema.GetMetadataTypeOrNull(kind, col); + return 
InputSchema[col].Metadata.Schema.GetColumnOrNull(kind)?.Type; } public void GetMetadata(string kind, int col, ref TValue value) { - InputSchema.GetMetadata(kind, col, ref value); + InputSchema[col].Metadata.GetValue(kind, ref value); } public VectorType GetSlotType(int col) @@ -357,33 +357,21 @@ public VectorType GetSlotType(int col) } } - private abstract class SlotCursor : RootCursorBase, ISlotCursor + private abstract class SlotCursor : SlotCursor.RootSlotCursor { private readonly Transposer _parent; private readonly int _col; private ValueGetter> _getter; - public override long Batch { get { return 0; } } - protected SlotCursor(Transposer parent, int col) : base(parent._host) { - Ch.Assert(0 <= col && col < parent.Schema.ColumnCount); + Ch.Assert(0 <= col && col < parent.Schema.Count); _parent = parent; _col = col; } - public override ValueGetter GetIdGetter() - { - return - (ref UInt128 val) => - { - Ch.Check(IsGood, "Cannot call ID getter in current state"); - val = new UInt128((ulong)Position, 0); - }; - } - - public ValueGetter> GetGetter() + public override ValueGetter> GetGetter() { if (_getter == null) _getter = GetGetterCore(); @@ -393,7 +381,7 @@ public ValueGetter> GetGetter() return getter; } - public VectorType GetSlotType() + public override VectorType GetSlotType() { return _parent.TransposeSchema.GetSlotType(_col); } @@ -406,11 +394,12 @@ private sealed class SlotCursorOne : SlotCursor private readonly IDataView _view; private readonly int _col; private readonly int _len; + private bool _moved; public SlotCursorOne(Transposer parent, int col) : base(parent, col) { - Ch.Assert(0 <= col && col < parent.Schema.ColumnCount); + Ch.Assert(0 <= col && col < parent.Schema.Count); int iinfo = parent._inputToTransposed[col]; Ch.Assert(iinfo >= 0); int smin = iinfo == 0 ? 
0 : parent._splitLim[iinfo - 1]; @@ -427,20 +416,20 @@ public SlotCursorOne(Transposer parent, int col) Ch.Assert(parent._splitLim[iinfo] - _col == 1); } Ch.AssertValue(_view); - Ch.Assert(_view.Schema.GetColumnType(_col).IsPrimitive); - Ch.Assert(_view.Schema.GetColumnType(_col).RawType == typeof(T)); + Ch.Assert(_view.Schema[_col].Type.IsPrimitive); + Ch.Assert(_view.Schema[_col].Type.RawType == typeof(T)); _len = parent.RowCount; } protected override bool MoveNextCore() { // We only can move next on one slot, since this is a scalar column. - return State == CursorState.NotStarted; + return _moved = !_moved; } protected override ValueGetter> GetGetterCore() { - var isDefault = Conversion.Conversions.Instance.GetIsDefaultPredicate(_view.Schema.GetColumnType(_col)); + var isDefault = Conversion.Conversions.Instance.GetIsDefaultPredicate(_view.Schema[_col].Type); bool valid = false; VBuffer cached = default(VBuffer); return @@ -577,12 +566,12 @@ public SlotCursorVec(Transposer parent, int col) /// private void EnsureValid() { - Ch.Check(State == CursorState.Good, "Cursor is not in good state, cannot get values"); + Ch.Check(IsGood, "Cursor is not in good state, cannot get values"); Ch.Assert(_slotCurr >= 0); if (_colStored == _colCurr) return; - var type = _view.Schema.GetColumnType(_colCurr); + var type = _view.Schema[_colCurr].Type; Ch.Assert(type.ItemType.RawType == typeof(T)); Ch.Assert(type.ValueCount > 0); InPredicate isDefault = Conversion.Conversions.Instance.GetIsDefaultPredicate(type.ItemType); @@ -753,7 +742,7 @@ protected override bool MoveNextCore() _slotCurr = 0; if (++_colCurr == _colLim) return false; - _slotLim = _view.Schema.GetColumnType(_colCurr).ValueCount; + _slotLim = _view.Schema[_colCurr].Type.ValueCount; Ch.Assert(_slotLim > 0); return true; } @@ -822,7 +811,7 @@ public DataViewSlicer(IHost host, IDataView input, int[] toSlice) var splitter = _splitters[c] = Splitter.Create(_input, toSlice[c]); _host.Assert(splitter.ColumnCount >= 1); 
_incolToLim[c] = outputColumnCount += splitter.ColumnCount; - nameToCol[_input.Schema.GetColumnName(toSlice[c])] = outputColumnCount - 1; + nameToCol[_input.Schema[toSlice[c]].Name] = outputColumnCount - 1; } _colToSplitIndex = new int[outputColumnCount]; _colToSplitCol = new int[outputColumnCount]; @@ -867,7 +856,7 @@ private void OutputColumnToSplitterIndices(int col, out int splitInd, out int sp splitCol = _colToSplitCol[col]; } - public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); bool[] activeSplitters; @@ -875,13 +864,13 @@ public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) return new Cursor(_host, this, _input.GetRowCursor(srcPred, rand), predicate, activeSplitters); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); bool[] activeSplitters; var srcPred = CreateInputPredicate(predicate, out activeSplitters); - var result = _input.GetRowCursorSet(out consolidator, srcPred, n, rand); + var result = _input.GetRowCursorSet(srcPred, n, rand); for (int i = 0; i < result.Length; ++i) result[i] = new Cursor(_host, this, result[i], predicate, activeSplitters); return result; @@ -891,7 +880,7 @@ public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Fun /// Given a possibly null predicate for this data view, produce the dependency predicate for the sources, /// as well as a list of all the splitters for which we should produce rowsets. /// - /// The predicate input into the method. + /// The predicate input into the method. 
/// A boolean indicator array of length equal to the number of splitters, /// indicating whether that splitter has any active columns in its outputs or not /// The predicate to use when constructing the row cursor from the source @@ -906,7 +895,7 @@ private Func CreateInputPredicate(Func pred, out bool[] ac { var splitter = _splitters[i]; // Don't activate input source columns if none of the resulting columns were selected. - bool isActive = pred == null || Enumerable.Range(offset, splitter.ColumnCount).Any(c => pred(c)); + bool isActive = pred == null || Enumerable.Range(offset, splitter.AsSchema.Count).Any(c => pred(c)); if (isActive) { activeSplitters[i] = isActive; @@ -1000,7 +989,7 @@ public void GetMetadata(string kind, int col, ref TValue value) /// There is one instance of these per column, implementing the possible splitting /// of one column from a into multiple columns. The instance /// describes the resulting split columns through its implementation of - /// , and then can be bound to an to provide + /// , and then can be bound to an to provide /// that splitting functionality. 
/// private abstract class Splitter : NoMetadataSchema @@ -1015,7 +1004,7 @@ private abstract class Splitter : NoMetadataSchema protected Splitter(IDataView view, int col) { Contracts.AssertValue(view); - Contracts.Assert(0 <= col && col < view.Schema.ColumnCount); + Contracts.Assert(0 <= col && col < view.Schema.Count); _view = view; _col = col; } @@ -1025,7 +1014,7 @@ protected Splitter(IDataView view, int col) /// public static Splitter Create(IDataView view, int col) { - var type = view.Schema.GetColumnType(col); + var type = view.Schema[col].Type; Contracts.Assert(type.IsPrimitive || type.VectorSize > 0); const int defaultSplitThreshold = 16; if (type.VectorSize <= defaultSplitThreshold) @@ -1059,10 +1048,10 @@ public static Splitter Create(IDataView view, int col) } /// - /// Given an input , create the containing the split + /// Given an input , create the containing the split /// version of the columns. /// - public abstract IRow Bind(IRow row, Func pred); + public abstract Row Bind(Row row, Func pred); private static Splitter CreateCore(IDataView view, int col) { @@ -1079,7 +1068,7 @@ private static Splitter CreateCore(IDataView view, int col, int[] ends) public override bool TryGetColumnIndex(string name, out int col) { Contracts.CheckNonEmpty(name, nameof(name)); - if (name != _view.Schema.GetColumnName(SrcCol)) + if (name != _view.Schema[SrcCol].Name) { col = default(int); return false; @@ -1093,37 +1082,25 @@ public override bool TryGetColumnIndex(string name, out int col) public override string GetColumnName(int col) { Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _view.Schema.GetColumnName(SrcCol); + return _view.Schema[SrcCol].Name; } #endregion - private abstract class RowBase : IRow + private abstract class RowBase : WrappingRow where TSplitter : Splitter { protected readonly TSplitter Parent; - protected readonly IRow Input; - public Schema Schema => Parent.AsSchema; - public long Position => Input.Position; - public 
long Batch => Input.Batch; + public sealed override Schema Schema => Parent.AsSchema; - public RowBase(TSplitter parent, IRow input) + public RowBase(TSplitter parent, Row input) + : base(input) { Contracts.AssertValue(parent); Contracts.AssertValue(input); Contracts.Assert(input.IsColumnActive(parent.SrcCol)); Parent = parent; - Input = input; } - - public ValueGetter GetIdGetter() - { - return Input.GetIdGetter(); - } - - public abstract bool IsColumnActive(int col); - - public abstract ValueGetter GetGetter(int col); } /// @@ -1140,30 +1117,30 @@ private sealed class NoSplitter : Splitter public NoSplitter(IDataView view, int col) : base(view, col) { - Contracts.Assert(_view.Schema.GetColumnType(col).RawType == typeof(T)); + Contracts.Assert(_view.Schema[col].Type.RawType == typeof(T)); AsSchema = Schema.Create(this); } public override ColumnType GetColumnType(int col) { Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _view.Schema.GetColumnType(SrcCol); + return _view.Schema[SrcCol].Type; } - public override IRow Bind(IRow row, Func pred) + public override Row Bind(Row row, Func pred) { Contracts.AssertValue(row); Contracts.Assert(row.Schema == _view.Schema); Contracts.AssertValue(pred); Contracts.Assert(row.IsColumnActive(SrcCol)); - return new Row(this, row, pred(0)); + return new RowImpl(this, row, pred(0)); } - private sealed class Row : RowBase> + private sealed class RowImpl : RowBase> { private readonly bool _isActive; - public Row(NoSplitter parent, IRow input, bool isActive) + public RowImpl(NoSplitter parent, Row input, bool isActive) : base(parent, input) { Contracts.Assert(Parent.ColumnCount == 1); @@ -1212,7 +1189,7 @@ private sealed class ColumnSplitter : Splitter public ColumnSplitter(IDataView view, int col, int[] lims) : base(view, col) { - var type = _view.Schema.GetColumnType(SrcCol).AsVector; + var type = _view.Schema[SrcCol].Type as VectorType; // Only valid use is for two or more slices. 
Contracts.Assert(Utils.Size(lims) >= 2); Contracts.AssertValue(type); @@ -1236,16 +1213,16 @@ public override ColumnType GetColumnType(int col) return _types[col]; } - public override IRow Bind(IRow row, Func pred) + public override Row Bind(Row row, Func pred) { Contracts.AssertValue(row); Contracts.Assert(row.Schema == _view.Schema); Contracts.AssertValue(pred); Contracts.Assert(row.IsColumnActive(SrcCol)); - return new Row(this, row, pred); + return new RowImpl(this, row, pred); } - private sealed class Row : RowBase> + private sealed class RowImpl : RowBase> { // Counter of the last valid input, updated by EnsureValid. private long _lastValid; @@ -1260,7 +1237,7 @@ private sealed class Row : RowBase> // Getters. private readonly ValueGetter>[] _getters; - public Row(ColumnSplitter parent, IRow input, Func pred) + public RowImpl(ColumnSplitter parent, Row input, Func pred) : base(parent, input) { _inputGetter = input.GetGetter>(Parent.SrcCol); @@ -1367,17 +1344,17 @@ private void EnsureValid() } /// - /// The cursor implementation creates the s using , + /// The cursor implementation creates the s using , /// then collates the results from those rows as effectively one big row. 
/// - private sealed class Cursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly DataViewSlicer _slicer; - private readonly IRow[] _sliceRows; + private readonly Row[] _sliceRows; - public Schema Schema => _slicer.Schema; + public override Schema Schema => _slicer.Schema; - public Cursor(IChannelProvider provider, DataViewSlicer slicer, IRowCursor input, Func pred, bool[] activeSplitters) + public Cursor(IChannelProvider provider, DataViewSlicer slicer, RowCursor input, Func pred, bool[] activeSplitters) : base(provider, input) { Ch.AssertValue(slicer); @@ -1385,7 +1362,7 @@ public Cursor(IChannelProvider provider, DataViewSlicer slicer, IRowCursor input Ch.Assert(Utils.Size(activeSplitters) == slicer._splitters.Length); _slicer = slicer; - _sliceRows = new IRow[_slicer._splitters.Length]; + _sliceRows = new Row[_slicer._splitters.Length]; var activeSrc = new bool[slicer._splitters.Length]; var activeSrcSet = new HashSet(); int offset = 0; @@ -1403,16 +1380,16 @@ public Cursor(IChannelProvider provider, DataViewSlicer slicer, IRowCursor input } } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { - Ch.Check(0 <= col && col < Schema.ColumnCount, "col"); + Ch.Check(0 <= col && col < Schema.Count, "col"); int splitInd; int splitCol; _slicer.OutputColumnToSplitterIndices(col, out splitInd, out splitCol); return _sliceRows[splitInd] != null && _sliceRows[splitInd].IsColumnActive(splitCol); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); int splitInd; @@ -1424,7 +1401,7 @@ public ValueGetter GetGetter(int col) } } - public static class TransposerUtils + internal static class TransposerUtils { /// /// This is a convenience method that extracts a single slot value's vector, @@ -1433,20 +1410,20 @@ public static class TransposerUtils public static void GetSingleSlotValue(this ITransposeDataView view, 
int col, ref VBuffer dst) { Contracts.CheckValue(view, nameof(view)); - Contracts.CheckParam(0 <= col && col < view.Schema.ColumnCount, nameof(col)); + Contracts.CheckParam(0 <= col && col < view.Schema.Count, nameof(col)); using (var cursor = view.GetSlotCursor(col)) { var getter = cursor.GetGetter(); if (!cursor.MoveNext()) - throw Contracts.Except("Could not get single value on column '{0}' because there are no slots", view.Schema.GetColumnName(col)); + throw Contracts.Except("Could not get single value on column '{0}' because there are no slots", view.Schema[col].Name); getter(ref dst); if (cursor.MoveNext()) - throw Contracts.Except("Could not get single value on column '{0}' because there is more than one slot", view.Schema.GetColumnName(col)); + throw Contracts.Except("Could not get single value on column '{0}' because there is more than one slot", view.Schema[col].Name); } } /// - /// The is parameterized by a type that becomes the + /// The is parameterized by a type that becomes the /// type parameter for a , and this is generally preferable and more /// sensible but for various reasons it's often a lot simpler to have a get-getter be over /// the actual type returned by the getter, that is, parameterize this by the actual @@ -1457,7 +1434,7 @@ public static void GetSingleSlotValue(this ITransposeDataView view, int col, /// The cursor to get the getter for /// The exception contxt /// The value getter - public static ValueGetter GetGetterWithVectorType(this ISlotCursor cursor, IExceptionContext ctx = null) + public static ValueGetter GetGetterWithVectorType(this SlotCursor cursor, IExceptionContext ctx = null) { Contracts.CheckValueOrNull(ctx); ctx.CheckValue(cursor, nameof(cursor)); @@ -1478,15 +1455,15 @@ public static ValueGetter GetGetterWithVectorType(this ISlotCurs /// /// Given a slot cursor, construct a single-column equivalent row cursor, with the single column /// active and having the same type. 
This is useful to exploit the many utility methods that exist - /// to handle and but that know nothing about - /// , without having to rewrite all of them. This is, however, rather + /// to handle and but that know nothing about + /// , without having to rewrite all of them. This is, however, rather /// something of a hack; whenever possible or reasonable the slot cursor should be used directly. /// The name of this column is always "Waffles". /// /// The channel provider used in creating the wrapping row cursor /// The slot cursor to wrap /// A row cursor with a single active column with the same type as the slot type - public static IRowCursor GetRowCursorShim(IChannelProvider provider, ISlotCursor cursor) + public static RowCursor GetRowCursorShim(IChannelProvider provider, SlotCursor cursor) { Contracts.CheckValue(provider, nameof(provider)); provider.CheckValue(cursor, nameof(cursor)); @@ -1494,7 +1471,7 @@ public static IRowCursor GetRowCursorShim(IChannelProvider provider, ISlotCursor return Utils.MarshalInvoke(GetRowCursorShimCore, cursor.GetSlotType().ItemType.RawType, provider, cursor); } - private static IRowCursor GetRowCursorShimCore(IChannelProvider provider, ISlotCursor cursor) + private static RowCursor GetRowCursorShimCore(IChannelProvider provider, SlotCursor cursor) { return new SlotRowCursorShim(provider, cursor); } @@ -1508,141 +1485,79 @@ public sealed class SlotDataView : IDataView private readonly ITransposeDataView _data; private readonly int _col; private readonly ColumnType _type; - private readonly SchemaImpl _schemaImpl; - public Schema Schema => _schemaImpl.AsSchema; + public Schema Schema { get; } - public bool CanShuffle { get { return false; } } + public bool CanShuffle => false; public SlotDataView(IHostEnvironment env, ITransposeDataView data, int col) { Contracts.CheckValue(env, nameof(env)); _host = env.Register("SlotDataView"); _host.CheckValue(data, nameof(data)); - _host.CheckParam(0 <= col && col < data.Schema.ColumnCount, 
nameof(col)); + _host.CheckParam(0 <= col && col < data.Schema.Count, nameof(col)); _type = data.TransposeSchema.GetSlotType(col); _host.AssertValue(_type); _data = data; _col = col; - _schemaImpl = new SchemaImpl(this); + + var builder = new SchemaBuilder(); + builder.AddColumn(_data.Schema[_col].Name, _type); + Schema = builder.GetSchema(); } public long? GetRowCount() { - var type = _data.Schema.GetColumnType(_col); + var type = _data.Schema[_col].Type; int valueCount = type.ValueCount; _host.Assert(valueCount > 0); return valueCount; } - public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); return Utils.MarshalInvoke(GetRowCursor, _type.ItemType.RawType, predicate(0)); } - private IRowCursor GetRowCursor(bool active) + private RowCursor GetRowCursor(bool active) { return new Cursor(this, active); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); - consolidator = null; - return new IRowCursor[] { GetRowCursor(predicate, rand) }; + return new RowCursor[] { GetRowCursor(predicate, rand) }; } - private sealed class SchemaImpl : ISchema - { - private readonly SlotDataView _parent; - - private IHost Host { get { return _parent._host; } } - - public Schema AsSchema { get; } - - public int ColumnCount { get { return 1; } } - - public SchemaImpl(SlotDataView parent) - { - Contracts.AssertValue(parent); - _parent = parent; - AsSchema = Schema.Create(this); - } - - public ColumnType GetColumnType(int col) - { - Host.CheckParam(col == 0, nameof(col)); - return _parent._type; - } - - public string GetColumnName(int col) - { - Host.CheckParam(col == 0, nameof(col)); - // There is no real need for this to have the real name as the internal IDV - 
// substream does not have its name accessed, but we'll save it just the same. - // I am tempted though to just have this thing always claim its name is 'Pancakes'. - return _parent._data.Schema.GetColumnName(_parent._col); - } - - public bool TryGetColumnIndex(string name, out int col) - { - if (name == GetColumnName(0)) - { - col = 0; - return true; - } - col = -1; - return false; - } - - // No metadata. The top level IDV will hold the schema information, including metadata. - // This per-column dataview schema information is just minimally functional. - - public IEnumerable> GetMetadataTypes(int col) - { - Host.CheckParam(col == 0, nameof(col)); - return Enumerable.Empty>(); - } - - public ColumnType GetMetadataTypeOrNull(string kind, int col) - { - Host.CheckNonEmpty(kind, nameof(kind)); - Host.CheckParam(col == 0, nameof(col)); - return null; - } - - public void GetMetadata(string kind, int col, ref TValue value) - { - Host.CheckNonEmpty(kind, nameof(kind)); - Host.CheckParam(col == 0, nameof(col)); - throw MetadataUtils.ExceptGetMetadata(); - } - } - - private sealed class Cursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private readonly SlotDataView _parent; + private readonly SlotCursor _slotCursor; private readonly Delegate _getter; - public Schema Schema => _parent.Schema; + public override Schema Schema => _parent.Schema; + + public override long Batch => 0; public Cursor(SlotDataView parent, bool active) - : base(parent._host, parent._data.GetSlotCursor(parent._col)) + : base(parent._host) { _parent = parent; + _slotCursor = _parent._data.GetSlotCursor(parent._col); if (active) - _getter = Input.GetGetter(); + _getter = _slotCursor.GetGetter(); } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.CheckParam(col == 0, nameof(col)); return _getter != null; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.CheckParam(col == 0, 
nameof(col)); Ch.CheckParam(_getter != null, nameof(col), "requested column not active"); @@ -1652,98 +1567,61 @@ public ValueGetter GetGetter(int col) throw Ch.Except("Invalid TValue: '{0}'", typeof(TValue)); return getter; } - } - } - - // REVIEW: This shim class is very similar to the above shim class, except at the - // cursor level, not the cursorable level. Is there some non-horrifying way to unify both, somehow? - private sealed class SlotRowCursorShim : SynchronizedCursorBase, IRowCursor - { - private readonly SchemaImpl _schema; - - public Schema Schema => _schema.AsSchema; - - private sealed class SchemaImpl : ISchema - { - private readonly SlotRowCursorShim _parent; - private readonly VectorType _type; - - private IChannel Ch { get { return _parent.Ch; } } - - public Schema AsSchema { get; } - public int ColumnCount { get { return 1; } } + public override ValueGetter GetIdGetter() => GetId; - public SchemaImpl(SlotRowCursorShim parent, VectorType slotType) + private void GetId(ref RowId id) { - Contracts.AssertValue(parent); - _parent = parent; - Ch.AssertValue(slotType); - _type = slotType; - AsSchema = Schema.Create(this); + Ch.Check(_slotCursor.SlotIndex >= 0, "Cannot get ID with cursor in current state."); + id = new RowId((ulong)_slotCursor.SlotIndex, 0); } - public ColumnType GetColumnType(int col) - { - Ch.CheckParam(col == 0, nameof(col)); - return _type; - } + protected override bool MoveNextCore() => _slotCursor.MoveNext(); + } + } - public string GetColumnName(int col) - { - Ch.CheckParam(col == 0, nameof(col)); - return "Waffles"; - } + // REVIEW: This shim class is very similar to the above shim class, except at the + // cursor level, not the cursorable level. Is there some non-horrifying way to unify both, somehow? 
+ private sealed class SlotRowCursorShim : RootCursorBase + { + private readonly SlotCursor _slotCursor; - public bool TryGetColumnIndex(string name, out int col) - { - if (name == GetColumnName(0)) - { - col = 0; - return true; - } - col = -1; - return false; - } + public override Schema Schema { get; } - public IEnumerable> GetMetadataTypes(int col) - { - Ch.CheckParam(col == 0, nameof(col)); - return Enumerable.Empty>(); - } + public override long Batch => 0; - public ColumnType GetMetadataTypeOrNull(string kind, int col) - { - Ch.CheckNonEmpty(kind, nameof(kind)); - Ch.CheckParam(col == 0, nameof(col)); - return null; - } + public SlotRowCursorShim(IChannelProvider provider, SlotCursor cursor) + : base(provider) + { + Contracts.AssertValue(cursor); - public void GetMetadata(string kind, int col, ref TValue value) - { - Ch.CheckNonEmpty(kind, nameof(kind)); - Ch.CheckParam(col == 0, nameof(col)); - throw MetadataUtils.ExceptGetMetadata(); - } + _slotCursor = cursor; + var builder = new SchemaBuilder(); + builder.AddColumn("Waffles", cursor.GetSlotType()); + Schema = builder.GetSchema(); } - public SlotRowCursorShim(IChannelProvider provider, ISlotCursor cursor) - : base(provider, cursor) + public override bool IsColumnActive(int col) { - _schema = new SchemaImpl(this, Input.GetSlotType()); + Ch.CheckParam(col == 0, nameof(col)); + return true; } - public bool IsColumnActive(int col) + public override ValueGetter GetGetter(int col) { Ch.CheckParam(col == 0, nameof(col)); - return true; + return _slotCursor.GetGetterWithVectorType(Ch); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetIdGetter() => GetId; + + private void GetId(ref RowId id) { - Ch.CheckParam(col == 0, nameof(col)); - return Input.GetGetterWithVectorType(Ch); + Ch.Check(_slotCursor.SlotIndex >= 0, "Cannot get ID with cursor in current state."); + id = new RowId((ulong)_slotCursor.SlotIndex, 0); } + + protected override bool MoveNextCore() => _slotCursor.MoveNext(); } /// 
@@ -1754,11 +1632,11 @@ public ValueGetter GetGetter(int col) /// internal sealed class SimpleTransposeSchema : ITransposeSchema { - private readonly ISchema _schema; + private readonly Schema _schema; - public int ColumnCount { get { return _schema.ColumnCount; } } + public int ColumnCount { get { return _schema.Count; } } - public SimpleTransposeSchema(ISchema schema) + public SimpleTransposeSchema(Schema schema) { Contracts.CheckValue(schema, nameof(schema)); _schema = schema; @@ -1766,7 +1644,7 @@ public SimpleTransposeSchema(ISchema schema) public string GetColumnName(int col) { - return _schema.GetColumnName(col); + return _schema[col].Name; } public bool TryGetColumnIndex(string name, out int col) @@ -1776,7 +1654,7 @@ public bool TryGetColumnIndex(string name, out int col) public ColumnType GetColumnType(int col) { - return _schema.GetColumnType(col); + return _schema[col].Type; } public VectorType GetSlotType(int col) @@ -1787,17 +1665,17 @@ public VectorType GetSlotType(int col) public IEnumerable> GetMetadataTypes(int col) { - return _schema.GetMetadataTypes(col); + return _schema[col].Metadata.Schema.Select(c => new KeyValuePair(c.Name, c.Type)); } public ColumnType GetMetadataTypeOrNull(string kind, int col) { - return _schema.GetMetadataTypeOrNull(kind, col); + return _schema[col].Metadata.Schema.GetColumnOrNull(kind)?.Type; } public void GetMetadata(string kind, int col, ref TValue value) { - _schema.GetMetadata(kind, col, ref value); + _schema[col].Metadata.GetValue(kind, ref value); } } } diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Data/DataView/TypedCursor.cs similarity index 79% rename from src/Microsoft.ML.Api/TypedCursor.cs rename to src/Microsoft.ML.Data/DataView/TypedCursor.cs index c7944c7704..724a19ff29 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Data/DataView/TypedCursor.cs @@ -2,22 +2,21 @@ // The .NET Foundation licenses this file to you under the MIT license. 
// See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Collections.Generic; using System.Linq; using System.Reflection; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Api +namespace Microsoft.ML.Data { /// - /// This interface is an with 'strongly typed' binding. + /// This interface is an with 'strongly typed' binding. /// It can populate the user-supplied object's fields with the values of the current row. /// /// The user-defined type that is being populated while cursoring. - public interface IRowReadableAs : IRow + [BestFriend] + internal interface IRowReadableAs : IDisposable where TRow : class { /// @@ -27,30 +26,15 @@ public interface IRowReadableAs : IRow void FillValues(TRow row); } - /// - /// This interface is an with 'strongly typed' binding. - /// It can accept values of type and present the value as a row. - /// - /// The user-defined type that provides the values while cursoring. - public interface IRowBackedBy : IRow - where TRow : class - { - /// - /// Accepts the fields of the user-supplied object and publishes the instance as a row. - /// If the row is accessed prior to any object being set, then the data accessors on the row should throw. - /// - /// The row object. Cannot be null. - void ExtractValues(TRow row); - } - /// /// This interface provides cursoring through a via a 'strongly typed' binding. /// It can populate the user-supplied object's fields with the values of the current row. /// /// The user-defined type that is being populated while cursoring. - public interface IRowCursor : IRowReadableAs, ICursor + public abstract class RowCursor : RowCursor, IRowReadableAs where TRow : class { + public abstract void FillValues(TRow row); } /// @@ -63,13 +47,13 @@ public interface ICursorable /// /// Get a new cursor. 
/// - IRowCursor GetCursor(); + RowCursor GetCursor(); /// /// Get a new randomized cursor. /// /// The random seed to use. - IRowCursor GetRandomizedCursor(int randomSeed); + RowCursor GetRandomizedCursor(int randomSeed); } /// @@ -120,7 +104,7 @@ private TypedCursorable(IHostEnvironment env, IDataView data, bool ignoreMissing continue; throw _host.Except("Column '{0}' not found in the data view", col.ColumnName); } - var realColType = _data.Schema.GetColumnType(colIndex); + var realColType = _data.Schema[colIndex].Type; if (!IsCompatibleType(realColType, col.MemberInfo)) { throw _host.Except( @@ -163,7 +147,7 @@ private static bool IsCompatibleType(ColumnType colType, MemberInfo memberInfo) /// /// Create and return a new cursor. /// - public IRowCursor GetCursor() + public RowCursor GetCursor() { return GetCursor(x => false); } @@ -172,14 +156,14 @@ public IRowCursor GetCursor() /// Create and return a new randomized cursor. /// /// The random seed to use. - public IRowCursor GetRandomizedCursor(int randomSeed) + public RowCursor GetRandomizedCursor(int randomSeed) { return GetCursor(x => false, randomSeed); } - public IRowReadableAs GetRow(IRow input) + public IRowReadableAs GetRow(Row input) { - return new TypedRow(this, input); + return new RowImplementation(new TypedRow(this, input)); } /// @@ -188,14 +172,14 @@ public IRowReadableAs GetRow(IRow input) /// Predicate that denotes which additional columns to include in the cursor, /// in addition to the columns that are needed for populating the object. /// The random seed to use. If null, the cursor will be non-randomized. - public IRowCursor GetCursor(Func additionalColumnsPredicate, int? randomSeed = null) + public RowCursor GetCursor(Func additionalColumnsPredicate, int? randomSeed = null) { _host.CheckValue(additionalColumnsPredicate, nameof(additionalColumnsPredicate)); - IRandom rand = randomSeed.HasValue ? RandomUtils.Create(randomSeed.Value) : null; + Random rand = randomSeed.HasValue ? 
RandomUtils.Create(randomSeed.Value) : null; var cursor = _data.GetRowCursor(GetDependencies(additionalColumnsPredicate), rand); - return new TypedCursor(this, cursor); + return new RowCursorImplementation(new TypedCursor(this, cursor)); } public Func GetDependencies(Func additionalColumnsPredicate) @@ -206,28 +190,26 @@ public Func GetDependencies(Func additionalColumnsPredicat /// /// Create a set of cursors with additional active columns. /// - /// The consolidator for the original row cursors /// Predicate that denotes which additional columns to include in the cursor, /// in addition to the columns that are needed for populating the object. /// Number of cursors to create /// Random generator to use - public IRowCursor[] GetCursorSet(out IRowCursorConsolidator consolidator, - Func additionalColumnsPredicate, int n, IRandom rand) + public RowCursor[] GetCursorSet(Func additionalColumnsPredicate, int n, Random rand) { _host.CheckValue(additionalColumnsPredicate, nameof(additionalColumnsPredicate)); _host.CheckValueOrNull(rand); Func inputPredicate = col => _columnIndices.Contains(col) || additionalColumnsPredicate(col); - var inputs = _data.GetRowCursorSet(out consolidator, inputPredicate, n, rand); + var inputs = _data.GetRowCursorSet(inputPredicate, n, rand); _host.AssertNonEmpty(inputs); if (inputs.Length == 1 && n > 1) - inputs = DataViewUtils.CreateSplitCursors(out consolidator, _host, inputs[0], n); + inputs = DataViewUtils.CreateSplitCursors(_host, inputs[0], n); _host.AssertNonEmpty(inputs); return inputs - .Select(rc => (IRowCursor)(new TypedCursor(this, rc))) - .ToArray(); + .Select(rc => (RowCursor)(new RowCursorImplementation(new TypedCursor(this, rc)))) + .ToArray(); } /// @@ -251,49 +233,44 @@ public static TypedCursorable Create(IHostEnvironment env, IDataView data, return new TypedCursorable(env, data, ignoreMissingColumns, outSchema); } - private abstract class TypedRowBase : IRowReadableAs + private abstract class TypedRowBase : WrappingRow { 
protected readonly IChannel Ch; - private readonly IRow _input; private readonly Action[] _setters; - public long Batch => _input.Batch; - - public long Position => _input.Position; - - public Schema Schema => _input.Schema; + public override Schema Schema => base.Input.Schema; - public TypedRowBase(TypedCursorable parent, IRow input, string channelMessage) + public TypedRowBase(TypedCursorable parent, Row input, string channelMessage) + : base(input) { Contracts.AssertValue(parent); Contracts.AssertValue(parent._host); Ch = parent._host.Start(channelMessage); Ch.AssertValue(input); - _input = input; - int n = parent._pokes.Length; Ch.Assert(n == parent._columns.Length); Ch.Assert(n == parent._columnIndices.Length); _setters = new Action[n]; for (int i = 0; i < n; i++) - _setters[i] = GenerateSetter(_input, parent._columnIndices[i], parent._columns[i], parent._pokes[i], parent._peeks[i]); + _setters[i] = GenerateSetter(Input, parent._columnIndices[i], parent._columns[i], parent._pokes[i], parent._peeks[i]); } - public ValueGetter GetIdGetter() + protected override void DisposeCore(bool disposing) { - return _input.GetIdGetter(); + if (disposing) + Ch.Dispose(); } - private Action GenerateSetter(IRow input, int index, InternalSchemaDefinition.Column column, Delegate poke, Delegate peek) + private Action GenerateSetter(Row input, int index, InternalSchemaDefinition.Column column, Delegate poke, Delegate peek) { - var colType = input.Schema.GetColumnType(index); + var colType = input.Schema[index].Type; var fieldType = column.OutputType; var genericType = fieldType; - Func> del; + Func> del; if (fieldType.IsArray) { - Ch.Assert(colType.IsVector); + Ch.Assert(colType is VectorType); // VBuffer> -> String[] if (fieldType.GetElementType() == typeof(string)) { @@ -349,7 +326,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit // REVIEW: The converting getter invokes a type conversion delegate on every call, so it's inherently slower // than the 
'direct' getter. We don't have good indication of this to the user, and the selection // of affected types is pretty arbitrary (signed integers and bools, but not uints and floats). - private Action CreateConvertingVBufferSetter(IRow input, int col, Delegate poke, Delegate peek, Func convert) + private Action CreateConvertingVBufferSetter(Row input, int col, Delegate poke, Delegate peek, Func convert) { var getter = input.GetGetter>(col); var typedPoke = poke as Poke; @@ -371,7 +348,7 @@ private Action CreateConvertingVBufferSetter(IRow input, int c }; } - private Action CreateDirectVBufferSetter(IRow input, int col, Delegate poke, Delegate peek) + private Action CreateDirectVBufferSetter(Row input, int col, Delegate poke, Delegate peek) { var getter = input.GetGetter>(col); var typedPoke = poke as Poke; @@ -409,7 +386,7 @@ private Action CreateDirectVBufferSetter(IRow input, int col, Delega }; } - private static Action CreateConvertingActionSetter(IRow input, int col, Delegate poke, Func convert) + private static Action CreateConvertingActionSetter(Row input, int col, Delegate poke, Func convert) { var getter = input.GetGetter(col); var typedPoke = poke as Poke; @@ -423,7 +400,7 @@ private static Action CreateConvertingActionSetter(IRow input, }; } - private static Action CreateDirectSetter(IRow input, int col, Delegate poke, Delegate peek) + private static Action CreateDirectSetter(Row input, int col, Delegate poke, Delegate peek) { // Awkward to have a parameter that's always null, but slightly more convenient for generalizing the setter. 
Contracts.Assert(peek == null); @@ -438,7 +415,7 @@ private static Action CreateDirectSetter(IRow input, int col, Delega }; } - private Action CreateVBufferToVBufferSetter(IRow input, int col, Delegate poke, Delegate peek) + private Action CreateVBufferToVBufferSetter(Row input, int col, Delegate poke, Delegate peek) { var getter = input.GetGetter>(col); var typedPoke = poke as Poke>; @@ -460,68 +437,101 @@ public virtual void FillValues(TRow row) setter(row); } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { - return _input.IsColumnActive(col); + return Input.IsColumnActive(col); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { - return _input.GetGetter(col); + return Input.GetGetter(col); } } private sealed class TypedRow : TypedRowBase { - public TypedRow(TypedCursorable parent, IRow input) + public TypedRow(TypedCursorable parent, Row input) : base(parent, input, "Row") { } } - private sealed class TypedCursor : TypedRowBase, IRowCursor + private sealed class RowImplementation : IRowReadableAs { - private readonly IRowCursor _input; + private readonly TypedRow _row; private bool _disposed; - public TypedCursor(TypedCursorable parent, IRowCursor input) - : base(parent, input, "Cursor") + public void Dispose() { - _input = input; + if (_disposed) + return; + _row.Dispose(); + _disposed = true; } - public override void FillValues(TRow row) - { - Ch.Check(_input.State == CursorState.Good, "Can't fill values: the cursor is not active."); - base.FillValues(row); - } + public RowImplementation(TypedRow row) => _row = row; - public CursorState State { get { return _input.State; } } + public long Position => _row.Position; + public long Batch => _row.Batch; + public Schema Schema => _row.Schema; + public void FillValues(TRow row) => _row.FillValues(row); + public ValueGetter GetGetter(int col) => _row.GetGetter(col); + public ValueGetter GetIdGetter() => _row.GetIdGetter(); + public bool 
IsColumnActive(int col) => _row.IsColumnActive(col); + } - public void Dispose() - { - if (!_disposed) - { - _input.Dispose(); - Ch.Dispose(); - _disposed = true; - } - } + private sealed class RowCursorImplementation : RowCursor + { + private readonly TypedCursor _cursor; + private bool _disposed; + + public RowCursorImplementation(TypedCursor cursor) => _cursor = cursor; + + public override CursorState State => _cursor.State; + public override long Position => _cursor.Position; + public override long Batch => _cursor.Batch; + public override Schema Schema => _cursor.Schema; - public bool MoveNext() + protected override void Dispose(bool disposing) { - return _input.MoveNext(); + if (_disposed) + return; + if (disposing) + _cursor.Dispose(); + _disposed = true; + base.Dispose(disposing); } - public bool MoveMany(long count) + public override void FillValues(TRow row) => _cursor.FillValues(row); + public override ValueGetter GetGetter(int col) => _cursor.GetGetter(col); + public override ValueGetter GetIdGetter() => _cursor.GetIdGetter(); + public override RowCursor GetRootCursor() => _cursor.GetRootCursor(); + public override bool IsColumnActive(int col) => _cursor.IsColumnActive(col); + public override bool MoveMany(long count) => _cursor.MoveMany(count); + public override bool MoveNext() => _cursor.MoveNext(); + } + + private sealed class TypedCursor : TypedRowBase + { + private readonly RowCursor _input; + + public TypedCursor(TypedCursorable parent, RowCursor input) + : base(parent, input, "Cursor") { - return _input.MoveMany(count); + _input = input; } - public ICursor GetRootCursor() + public override void FillValues(TRow row) { - return _input.GetRootCursor(); + Ch.Check(_input.State == CursorState.Good, "Can't fill values: the cursor is not active."); + base.FillValues(row); } + + public CursorState State => _input.State; + + public bool MoveNext() => _input.MoveNext(); + public bool MoveMany(long count) => _input.MoveMany(count); + public RowCursor 
GetRootCursor() => _input.GetRootCursor(); } } diff --git a/src/Microsoft.ML.Data/DataView/ZipDataView.cs b/src/Microsoft.ML.Data/DataView/ZipDataView.cs index e481a96504..827b47f724 100644 --- a/src/Microsoft.ML.Data/DataView/ZipDataView.cs +++ b/src/Microsoft.ML.Data/DataView/ZipDataView.cs @@ -5,10 +5,9 @@ using System; using System.Collections.Generic; using System.Linq; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// This is a data view that is a 'zip' of several data views. @@ -71,7 +70,7 @@ private ZipDataView(IHost host, IDataView[] sources) return min; } - public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); @@ -89,31 +88,31 @@ public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) } /// - /// Create an with no requested columns on a data view. + /// Create an with no requested columns on a data view. /// Potentially, this can be optimized by calling GetRowCount(lazy:true) first, and if the count is not known, /// wrapping around GetCursor(). 
/// - private IRowCursor GetMinimumCursor(IDataView dv) + private RowCursor GetMinimumCursor(IDataView dv) { _host.AssertValue(dv); return dv.GetRowCursor(x => false); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { - consolidator = null; - return new IRowCursor[] { GetRowCursor(predicate, rand) }; + return new RowCursor[] { GetRowCursor(predicate, rand) }; } - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { - private readonly IRowCursor[] _cursors; + private readonly RowCursor[] _cursors; private readonly CompositeSchema _compositeSchema; private readonly bool[] _isColumnActive; + private bool _disposed; public override long Batch { get { return 0; } } - public Cursor(ZipDataView parent, IRowCursor[] srcCursors, Func predicate) + public Cursor(ZipDataView parent, RowCursor[] srcCursors, Func predicate) : base(parent._host) { Ch.AssertNonEmpty(srcCursors); @@ -124,20 +123,26 @@ public Cursor(ZipDataView parent, IRowCursor[] srcCursors, Func predi _isColumnActive = Utils.BuildArray(_compositeSchema.ColumnCount, predicate); } - public override void Dispose() + protected override void Dispose(bool disposing) { - for (int i = _cursors.Length - 1; i >= 0; i--) - _cursors[i].Dispose(); - base.Dispose(); + if (_disposed) + return; + if (disposing) + { + for (int i = _cursors.Length - 1; i >= 0; i--) + _cursors[i].Dispose(); + } + _disposed = true; + base.Dispose(disposing); } - public override ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { return - (ref UInt128 val) => + (ref RowId val) => { Ch.Check(IsGood, "Cannot call ID getter in current state"); - val = new UInt128((ulong)Position, 0); + val = new RowId((ulong)Position, 0); }; } @@ -167,15 +172,15 @@ protected override bool MoveManyCore(long count) return true; } - public Schema 
Schema => _compositeSchema.AsSchema; + public override Schema Schema => _compositeSchema.AsSchema; - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { _compositeSchema.CheckColumnInRange(col); return _isColumnActive[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { int dv; int srcCol; diff --git a/src/Microsoft.ML.Data/DebuggerExtensions.cs b/src/Microsoft.ML.Data/DebuggerExtensions.cs index 1a7c6e4920..0f479b83b1 100644 --- a/src/Microsoft.ML.Data/DebuggerExtensions.cs +++ b/src/Microsoft.ML.Data/DebuggerExtensions.cs @@ -4,8 +4,6 @@ using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; using Microsoft.ML.Transforms; namespace Microsoft.ML diff --git a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs index 84c5115532..9b81f5dbd1 100644 --- a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs +++ b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs @@ -4,21 +4,19 @@ #pragma warning disable 420 // volatile with Interlocked.CompareExchange -using Float = System.Single; - using System; using System.Collections; using System.Collections.Generic; using System.Linq; using System.Threading; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Internal.Internallearn +namespace Microsoft.ML.Internal.Internallearn { - public abstract class FeatureNameCollection : IEnumerable + [BestFriend] + internal abstract class FeatureNameCollection : IEnumerable { private sealed class FeatureNameCollectionSchema : ISchema { @@ -186,13 +184,14 @@ public static FeatureNameCollection Create(RoleMappedSchema schema) { // REVIEW: This shim should 
be deleted as soon as is convenient. Contracts.CheckValue(schema, nameof(schema)); - Contracts.CheckParam(schema.Feature != null, nameof(schema), "Cannot create feature name collection if we have no features"); - Contracts.CheckParam(schema.Feature.Type.ValueCount > 0, nameof(schema), "Cannot create feature name collection if our features are not of known size"); + Contracts.CheckParam(schema.Feature.HasValue, nameof(schema), "Cannot create feature name collection if we have no features"); + var featureCol = schema.Feature.Value; + Contracts.CheckParam(schema.Feature.Value.Type.ValueCount > 0, nameof(schema), "Cannot create feature name collection if our features are not of known size"); VBuffer> slotNames = default; - int len = schema.Feature.Type.ValueCount; - if (schema.Schema.HasSlotNames(schema.Feature.Index, len)) - schema.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, schema.Feature.Index, ref slotNames); + int len = featureCol.Type.ValueCount; + if (featureCol.HasSlotNames(len)) + featureCol.Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref slotNames); else slotNames = VBufferUtils.CreateEmpty>(len); var slotNameValues = slotNames.GetValues(); diff --git a/src/Microsoft.ML.Data/Depricated/TGUIAttribute.cs b/src/Microsoft.ML.Data/Depricated/TGUIAttribute.cs index 5f09c604bb..fbf044c946 100644 --- a/src/Microsoft.ML.Data/Depricated/TGUIAttribute.cs +++ b/src/Microsoft.ML.Data/Depricated/TGUIAttribute.cs @@ -3,9 +3,8 @@ // See the LICENSE file in the project root for more information. using System; -using Microsoft.ML.Runtime; -namespace Microsoft.ML.Runtime.Internal.Internallearn +namespace Microsoft.ML.Internal.Internallearn { #pragma warning disable MSML_GeneralName // This structure should be deprecated anyway. // REVIEW: Get rid of this. 
Everything should be in the ArgumentAttribute (or a class diff --git a/src/Microsoft.ML.Data/Depricated/Vector/GenericSpanSortHelper.cs b/src/Microsoft.ML.Data/Depricated/Vector/GenericSpanSortHelper.cs index f94993e419..f11fbce53b 100644 --- a/src/Microsoft.ML.Data/Depricated/Vector/GenericSpanSortHelper.cs +++ b/src/Microsoft.ML.Data/Depricated/Vector/GenericSpanSortHelper.cs @@ -20,7 +20,7 @@ using System; -namespace Microsoft.ML.Runtime.Numeric +namespace Microsoft.ML.Numeric { internal static class IntrospectiveSortUtilities { @@ -41,10 +41,10 @@ internal static int FloorLog2PlusOne(int n) } } - internal partial class GenericSpanSortHelper + internal partial class GenericSpanSortHelper where TKey : IComparable { - public static void Sort(Span keys, Span values, int index, int length) + public static void Sort(Span keys, Span values, int index, int length) { Contracts.Assert(keys != null, "Check the arguments in the caller!"); Contracts.Assert(index >= 0 && length >= 0 && (keys.Length - index >= length), "Check the arguments in the caller!"); @@ -52,38 +52,41 @@ public static void Sort(Span keys, Span values, int index, int len IntrospectiveSort(keys, values, index, length); } - private static void SwapIfGreaterWithItems(Span keys, Span values, int a, int b) + public static void Sort(Span keys, int index, int length) + { + Sort(keys, keys, index, length); + } + + private static void SwapIfGreaterWithItems(Span keys, Span values, int a, int b) { if (a != b) { if (keys[a] != null && keys[a].CompareTo(keys[b]) > 0) { TKey key = keys[a]; - keys[a] = keys[b]; - keys[b] = key; - TValue value = values[a]; + keys[a] = keys[b]; values[a] = values[b]; + keys[b] = key; values[b] = value; } } } - private static void Swap(Span keys, Span values, int i, int j) + private static void Swap(Span keys, Span values, int i, int j) { if (i != j) { TKey k = keys[i]; - keys[i] = keys[j]; - keys[j] = k; - TValue v = values[i]; + keys[i] = keys[j]; values[i] = values[j]; + keys[j] = k; 
values[j] = v; } } - internal static void IntrospectiveSort(Span keys, Span values, int left, int length) + internal static void IntrospectiveSort(Span keys, Span values, int left, int length) { Contracts.Assert(keys != null); Contracts.Assert(values != null); @@ -99,7 +102,7 @@ internal static void IntrospectiveSort(Span keys, Span values, int IntroSort(keys, values, left, length + left - 1, 2 * IntrospectiveSortUtilities.FloorLog2PlusOne(length)); } - private static void IntroSort(Span keys, Span values, int lo, int hi, int depthLimit) + private static void IntroSort(Span keys, Span values, int lo, int hi, int depthLimit) { Contracts.Assert(keys != null); Contracts.Assert(values != null); @@ -146,7 +149,7 @@ private static void IntroSort(Span keys, Span values, int lo, int } } - private static int PickPivotAndPartition(Span keys, Span values, int lo, int hi) + private static int PickPivotAndPartition(Span keys, Span values, int lo, int hi) { Contracts.Assert(keys != null); Contracts.Assert(values != null); @@ -191,7 +194,7 @@ private static int PickPivotAndPartition(Span keys, Span values, i return left; } - private static void Heapsort(Span keys, Span values, int lo, int hi) + private static void Heapsort(Span keys, Span values, int lo, int hi) { Contracts.Assert(keys != null); Contracts.Assert(values != null); @@ -211,7 +214,7 @@ private static void Heapsort(Span keys, Span values, int lo, int h } } - private static void DownHeap(Span keys, Span values, int i, int n, int lo) + private static void DownHeap(Span keys, Span values, int i, int n, int lo) { Contracts.Assert(keys != null); Contracts.Assert(lo >= 0); @@ -237,7 +240,7 @@ private static void DownHeap(Span keys, Span values, int i, int n, values[lo + i - 1] = dValue; } - private static void InsertionSort(Span keys, Span values, int lo, int hi) + private static void InsertionSort(Span keys, Span values, int lo, int hi) { Contracts.Assert(keys != null); Contracts.Assert(values != null); @@ -265,5 +268,4 @@ 
private static void InsertionSort(Span keys, Span values, int lo, } } } - } diff --git a/src/Microsoft.ML.Data/Depricated/Vector/VBufferMathUtils.cs b/src/Microsoft.ML.Data/Depricated/Vector/VBufferMathUtils.cs index 9be89b4f83..89336aa064 100644 --- a/src/Microsoft.ML.Data/Depricated/Vector/VBufferMathUtils.cs +++ b/src/Microsoft.ML.Data/Depricated/Vector/VBufferMathUtils.cs @@ -3,11 +3,11 @@ // See the LICENSE file in the project root for more information. using System; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.CpuMath; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.CpuMath; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Numeric +namespace Microsoft.ML.Numeric { // REVIEW: Once we do the conversions from Vector/WritableVector, review names of methods, // parameters, parameter order, etc. @@ -298,7 +298,8 @@ public static void AddMultWithOffset(in VBuffer src, Float c, ref VBuffer editor = VBufferEditor.Create(ref dst, dst.Length, dstValues.Length + gapCount, - keepOldOnResize: true); + keepOldOnResize: true, + requireIndicesOnDense: true); var indices = editor.Indices; values = editor.Values; if (gapCount > 0) diff --git a/src/Microsoft.ML.Data/Depricated/Vector/VectorUtils.cs b/src/Microsoft.ML.Data/Depricated/Vector/VectorUtils.cs index 58655063d2..564156fd1b 100644 --- a/src/Microsoft.ML.Data/Depricated/Vector/VectorUtils.cs +++ b/src/Microsoft.ML.Data/Depricated/Vector/VectorUtils.cs @@ -4,13 +4,12 @@ using System; using System.Collections.Generic; -using System.Linq; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.CpuMath; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.CpuMath; +using Microsoft.ML.Internal.Utilities; using Float = System.Single; -namespace Microsoft.ML.Runtime.Numeric +namespace Microsoft.ML.Numeric { /// /// A series of vector utility functions, generally 
operating over arrays or @@ -150,7 +149,7 @@ public static void SparsifyNormalize(ref VBuffer a, int top, int bottom, } if (!aEditor.Indices.IsEmpty) - GenericSpanSortHelper.Sort(aEditor.Indices, aEditor.Values, 0, newCount); + GenericSpanSortHelper.Sort(aEditor.Indices, aEditor.Values, 0, newCount); a = aEditor.Commit(); } diff --git a/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs b/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs index ec12f9bc52..a01695a9c4 100644 --- a/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs +++ b/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs @@ -3,14 +3,12 @@ // See the LICENSE file in the project root for more information. using System; -using System.Collections.Generic; using System.Linq; +using Microsoft.ML; +using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; [assembly: LoadableClass(typeof(ChooseColumnsByIndexTransform), typeof(ChooseColumnsByIndexTransform.Arguments), typeof(SignatureDataTransform), "", "ChooseColumnsByIndexTransform", "ChooseColumnsByIndex")] @@ -18,7 +16,7 @@ [assembly: LoadableClass(typeof(ChooseColumnsByIndexTransform), null, typeof(SignatureLoadDataTransform), "", ChooseColumnsByIndexTransform.LoaderSignature, ChooseColumnsByIndexTransform.LoaderSignatureOld)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public sealed class ChooseColumnsByIndexTransform : RowToRowTransformBase { @@ -31,158 +29,146 @@ public sealed class Arguments public bool Drop; } - private sealed class Bindings : ISchema + private sealed class Bindings { - public readonly int[] Sources; - - private readonly ISchema _input; - private readonly Dictionary _nameToIndex; - - // The following argument is used only 
to inform serialization. - private readonly int[] _dropped; - - public Schema AsSchema { get; } - - public Bindings(Arguments args, ISchema schemaInput) + /// + /// A collection of source column indexes after removing those we want to drop. Specifically, j=_sources[i] means + /// that the i-th output column in the output schema is the j-th column in the input schema. + /// + private readonly int[] _sources; + + /// + /// Input schema of this transform. It's useful when determining column dependencies and other + /// relations between input and output schemas. + /// + private readonly Schema _sourceSchema; + + /// + /// Some column indexes in the input schema. is computed from + /// and . + /// + private readonly int[] _selectedColumnIndexes; + + /// + /// True, if this transform drops selected columns indexed by . + /// + private readonly bool _drop; + + // This transform's output schema. + internal Schema OutputSchema { get; } + + internal Bindings(Arguments args, Schema sourceSchema) { Contracts.AssertValue(args); - Contracts.AssertValue(schemaInput); + Contracts.AssertValue(sourceSchema); + + _sourceSchema = sourceSchema; - _input = schemaInput; + // Store user-specified arguments as the major state of this transform. Only the major states will + // be saved and all other attributes can be reconstructed from them. + _drop = args.Drop; + _selectedColumnIndexes = args.Index; - int[] indexCopy = args.Index == null ? new int[0] : args.Index.ToArray(); - BuildNameDict(indexCopy, args.Drop, out Sources, out _dropped, out _nameToIndex, user: true); + // Compute actually used attributes in runtime from those major states. + ComputeSources(_drop, _selectedColumnIndexes, _sourceSchema, out _sources); - AsSchema = Schema.Create(this); + // All necessary fields in this class are set, so we can compute output schema now. 
+ OutputSchema = ComputeOutputSchema(); } - private void BuildNameDict(int[] indexCopy, bool drop, out int[] sources, out int[] dropped, out Dictionary nameToCol, bool user) + /// + /// Common method of computing from necessary parameters. This function is used in constructors. + /// + private static void ComputeSources(bool drop, int[] selectedColumnIndexes, Schema sourceSchema, out int[] sources) { - Contracts.AssertValue(indexCopy); - foreach (int col in indexCopy) - { - if (col < 0 || _input.ColumnCount <= col) - { - const string fmt = "Column index {0} invalid for input with {1} columns"; - if (user) - throw Contracts.ExceptUserArg(nameof(Arguments.Index), fmt, col, _input.ColumnCount); - else - throw Contracts.ExceptDecode(fmt, col, _input.ColumnCount); - } - } + // Compute the mapping, , from output column index to input column index. if (drop) - { - sources = Enumerable.Range(0, _input.ColumnCount).Except(indexCopy).ToArray(); - dropped = indexCopy; - } - else - { - sources = indexCopy; - dropped = null; - } - if (user) - Contracts.CheckUserArg(sources.Length > 0, nameof(Arguments.Index), "Choose columns by index has no output columns"); + // Drop columns indexed by args.Index + sources = Enumerable.Range(0, sourceSchema.Count).Except(selectedColumnIndexes).ToArray(); else - Contracts.CheckDecode(sources.Length > 0, "Choose columns by index has no output columns"); - nameToCol = new Dictionary(); - for (int c = 0; c < sources.Length; ++c) - nameToCol[_input.GetColumnName(sources[c])] = c; - } - - public Bindings(ModelLoadContext ctx, ISchema schemaInput) - { - Contracts.AssertValue(ctx); - Contracts.AssertValue(schemaInput); - - _input = schemaInput; + // Keep columns indexed by args.Index + sources = selectedColumnIndexes; - // *** Binary format *** - // bool(as byte): whether the indicated source columns are columns to keep, or drop - // int: number of source column indices - // int[]: source column indices - - bool isDrop = ctx.Reader.ReadBoolByte(); - 
BuildNameDict(ctx.Reader.ReadIntArray() ?? new int[0], isDrop, out Sources, out _dropped, out _nameToIndex, user: false); - AsSchema = Schema.Create(this); + // Make sure the output of this transform is meaningful. + Contracts.Check(sources.Length > 0, "Choose columns by index has no output column."); } - public void Save(ModelSaveContext ctx) + /// + /// After and are set, pick up selected columns from to create + /// Note that tells us what columns in are put into . + /// + private Schema ComputeOutputSchema() { - Contracts.AssertValue(ctx); + var schemaBuilder = new SchemaBuilder(); + for (int i = 0; i < _sources.Length; ++i) + { + // selectedIndex is an column index of input schema. Note that the input column indexed by _sources[i] in _sourceSchema is sent + // to the i-th column in the output schema. + var selectedIndex = _sources[i]; - // *** Binary format *** - // bool(as byte): whether the indicated columns are columns to keep, or drop - // int: number of source column indices - // int[]: source column indices + // The dropped/kept columns are determined by user-specified arguments, so we throw if a bad configuration is provided. + string fmt = string.Format("Column index {0} invalid for input with {1} columns", selectedIndex, _sourceSchema.Count); + Contracts.Check(selectedIndex < _sourceSchema.Count, fmt); - ctx.Writer.WriteBoolByte(_dropped != null); - ctx.Writer.WriteIntArray(_dropped ?? Sources); + // Copy the selected column into output schema. 
+ var selectedColumn = _sourceSchema[selectedIndex]; + schemaBuilder.AddColumn(selectedColumn.Name, selectedColumn.Type, selectedColumn.Metadata); + } + return schemaBuilder.GetSchema(); } - public int ColumnCount + internal Bindings(ModelLoadContext ctx, Schema sourceSchema) { - get { return Sources.Length; } - } + Contracts.AssertValue(ctx); + Contracts.AssertValue(sourceSchema); - public bool TryGetColumnIndex(string name, out int col) - { - Contracts.CheckValueOrNull(name); - if (name == null) - { - col = default(int); - return false; - } - return _nameToIndex.TryGetValue(name, out col); - } + _sourceSchema = sourceSchema; - public string GetColumnName(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _input.GetColumnName(Sources[col]); - } + // *** Binary format *** + // bool (as byte): operation mode + // int[]: selected source column indices + _drop = ctx.Reader.ReadBoolByte(); + _selectedColumnIndexes = ctx.Reader.ReadIntArray(); - public ColumnType GetColumnType(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _input.GetColumnType(Sources[col]); - } + // Compute actually used attributes in runtime from those major states. 
+ ComputeSources(_drop, _selectedColumnIndexes, _sourceSchema, out _sources); - public IEnumerable> GetMetadataTypes(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _input.GetMetadataTypes(Sources[col]); + _sourceSchema = sourceSchema; + OutputSchema = ComputeOutputSchema(); } - public ColumnType GetMetadataTypeOrNull(string kind, int col) + internal void Save(ModelSaveContext ctx) { - Contracts.CheckNonEmpty(kind, nameof(kind)); - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _input.GetMetadataTypeOrNull(kind, Sources[col]); - } + Contracts.AssertValue(ctx); - public void GetMetadata(string kind, int col, ref TValue value) - { - Contracts.CheckNonEmpty(kind, nameof(kind)); - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - _input.GetMetadata(kind, Sources[col], ref value); + // *** Binary format *** + // bool (as byte): operation mode + // int[]: selected source column indices + ctx.Writer.WriteBoolByte(_drop); + ctx.Writer.WriteIntArray(_selectedColumnIndexes); } internal bool[] GetActive(Func predicate) { - return Utils.BuildArray(ColumnCount, predicate); + return Utils.BuildArray(OutputSchema.Count, predicate); } internal Func GetDependencies(Func predicate) { Contracts.AssertValue(predicate); - var active = new bool[_input.ColumnCount]; - for (int i = 0; i < Sources.Length; i++) + var active = new bool[_sourceSchema.Count]; + for (int i = 0; i < _sources.Length; i++) { if (predicate(i)) - active[Sources[i]] = true; + active[_sources[i]] = true; } return col => 0 <= col && col < active.Length && active[col]; } + + /// + /// Given the column index in the output schema, this function returns its source column's index in the input schema. 
+ /// + internal int GetSourceColumnIndex(int outputColumnIndex) => _sources[outputColumnIndex]; } public const string LoaderSignature = "ChooseColumnsIdxTrans"; @@ -245,7 +231,7 @@ public override void Save(ModelSaveContext ctx) _bindings.Save(ctx); } - public override Schema Schema => _bindings.AsSchema; + public override Schema OutputSchema => _bindings.OutputSchema; protected override bool? ShouldUseParallelCursors(Func predicate) { @@ -254,7 +240,7 @@ public override void Save(ModelSaveContext ctx) return null; } - protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate, "predicate"); Host.AssertValueOrNull(rand); @@ -262,55 +248,54 @@ protected override IRowCursor GetRowCursorCore(Func predicate, IRando var inputPred = _bindings.GetDependencies(predicate); var active = _bindings.GetActive(predicate); var input = Source.GetRowCursor(inputPred, rand); - return new RowCursor(Host, _bindings, input, active); + return new Cursor(Host, _bindings, input, active); } - public sealed override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null) + public sealed override RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); var inputPred = _bindings.GetDependencies(predicate); var active = _bindings.GetActive(predicate); - var inputs = Source.GetRowCursorSet(out consolidator, inputPred, n, rand); + var inputs = Source.GetRowCursorSet(inputPred, n, rand); Host.AssertNonEmpty(inputs); // No need to split if this is given 1 input cursor. 
- var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(Host, _bindings, inputs[i], active); + cursors[i] = new Cursor(Host, _bindings, inputs[i], active); return cursors; } - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly Bindings _bindings; private readonly bool[] _active; - public RowCursor(IChannelProvider provider, Bindings bindings, IRowCursor input, bool[] active) + public Cursor(IChannelProvider provider, Bindings bindings, RowCursor input, bool[] active) : base(provider, input) { Ch.AssertValue(bindings); - Ch.Assert(active == null || active.Length == bindings.ColumnCount); + Ch.Assert(active == null || active.Length == bindings.OutputSchema.Count); _bindings = bindings; _active = active; } - public Schema Schema => _bindings.AsSchema; + public override Schema Schema => _bindings.OutputSchema; - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { - Ch.Check(0 <= col && col < _bindings.ColumnCount); + Ch.Check(0 <= col && col < _bindings.OutputSchema.Count); return _active == null || _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); - var src = _bindings.Sources[col]; + var src = _bindings.GetSourceColumnIndex(col); return Input.GetGetter(src); } } diff --git a/src/Microsoft.ML.Data/Dirty/ILoss.cs b/src/Microsoft.ML.Data/Dirty/ILoss.cs index 69c394c76e..1bb1350e0b 100644 --- a/src/Microsoft.ML.Data/Dirty/ILoss.cs +++ b/src/Microsoft.ML.Data/Dirty/ILoss.cs @@ -2,12 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Float = System.Single; - using System; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.EntryPoints; +using Float = System.Single; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { public interface ILossFunction { diff --git a/src/Microsoft.ML.Data/Dirty/IniFileUtils.cs b/src/Microsoft.ML.Data/Dirty/IniFileUtils.cs index 782c07af01..f704aca7a0 100644 --- a/src/Microsoft.ML.Data/Dirty/IniFileUtils.cs +++ b/src/Microsoft.ML.Data/Dirty/IniFileUtils.cs @@ -4,9 +4,9 @@ using System.Text; using System.Text.RegularExpressions; -using Microsoft.ML.Runtime.Internal.Calibration; +using Microsoft.ML.Internal.Calibration; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { [BestFriend] internal static class IniFileUtils @@ -52,10 +52,10 @@ public static string GetCalibratorEvaluatorIni(string originalIni, PlattCalibrat StringBuilder newEvaluator = new StringBuilder(); newEvaluator.AppendLine("EvaluatorType=Aggregator"); newEvaluator.AppendLine("Type=Sigmoid"); - newEvaluator.AppendLine("Bias=" + -calibrator.ParamB); + newEvaluator.AppendLine("Bias=" + -calibrator.Offset); newEvaluator.AppendLine("NumNodes=1"); newEvaluator.AppendLine("Nodes=E:" + NumEvaluators(originalIni)); - newEvaluator.AppendLine("Weights=" + -calibrator.ParamA); + newEvaluator.AppendLine("Weights=" + -calibrator.Slope); return newEvaluator.ToString(); } } diff --git a/src/Microsoft.ML.Data/Dirty/PredictorBase.cs b/src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs similarity index 78% rename from src/Microsoft.ML.Data/Dirty/PredictorBase.cs rename to src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs index 0d73db05fc..0366772768 100644 --- a/src/Microsoft.ML.Data/Dirty/PredictorBase.cs +++ b/src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs @@ -2,20 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Float = System.Single; - using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Internal.Internallearn +namespace Microsoft.ML.Internal.Internallearn { /// /// A base class for predictors producing . /// Note: This provides essentially no value going forward. New predictors should just /// derive from the interfaces they need. /// - public abstract class PredictorBase : IPredictorProducing + public abstract class ModelParametersBase : ICanSaveModel, IPredictorProducing { public const string NormalizerWarningFormat = "Ignoring integrated normalizer while loading a predictor of type {0}.{1}" + @@ -23,14 +20,14 @@ public abstract class PredictorBase : IPredictorProducing protected readonly IHost Host; - protected PredictorBase(IHostEnvironment env, string name) + protected ModelParametersBase(IHostEnvironment env, string name) { Contracts.CheckValue(env, nameof(env)); env.CheckNonWhiteSpace(name, nameof(name)); Host = env.Register(name); } - protected PredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx) + protected ModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckNonWhiteSpace(name, nameof(name)); @@ -42,25 +39,29 @@ protected PredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx) // Verify that the Float type matches. int cbFloat = ctx.Reader.ReadInt32(); #pragma warning disable MSML_NoMessagesForLoadContext // This one is actually useful. 
- Host.CheckDecode(cbFloat == sizeof(Float), "This file was saved by an incompatible version"); + Host.CheckDecode(cbFloat == sizeof(float), "This file was saved by an incompatible version"); #pragma warning restore MSML_NoMessagesForLoadContext } - public virtual void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) => Save(ctx); + + [BestFriend] + private protected virtual void Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); SaveCore(ctx); } - protected virtual void SaveCore(ModelSaveContext ctx) + [BestFriend] + private protected virtual void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); // *** Binary format *** // int: sizeof(Float) // - ctx.Writer.Write(sizeof(Float)); + ctx.Writer.Write(sizeof(float)); } public abstract PredictionKind PredictionKind { get; } diff --git a/src/Microsoft.ML.Data/Dirty/PredictionUtils.cs b/src/Microsoft.ML.Data/Dirty/PredictionUtils.cs index 9a78defd82..653a05fbb7 100644 --- a/src/Microsoft.ML.Data/Dirty/PredictionUtils.cs +++ b/src/Microsoft.ML.Data/Dirty/PredictionUtils.cs @@ -6,11 +6,10 @@ using System.Collections.Generic; using System.Linq; using System.Text; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Internal.Internallearn +namespace Microsoft.ML.Internal.Internallearn { using Float = System.Single; diff --git a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs index 153f442d32..905d91a89e 100644 --- a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs +++ b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs @@ -2,15 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Float = System.Single; - using System; using System.Collections.Generic; using System.IO; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Calibration; +using Microsoft.ML.Calibrator; +using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime.Internal.Internallearn +namespace Microsoft.ML.Internal.Internallearn { /// @@ -21,7 +19,8 @@ namespace Microsoft.ML.Runtime.Internal.Internallearn /// /// A generic interface for models that can average parameters from multiple instance of self /// - public interface IParameterMixer + [BestFriend] + internal interface IParameterMixer { IParameterMixer CombineParameters(IList models); } @@ -29,7 +28,8 @@ public interface IParameterMixer /// /// A generic interface for models that can average parameters from multiple instance of self /// - public interface IParameterMixer + [BestFriend] + internal interface IParameterMixer { IParameterMixer CombineParameters(IList> models); } @@ -38,7 +38,8 @@ public interface IParameterMixer /// Predictor that can specialize for quantile regression. It will produce a , given /// an array of quantiles. /// - public interface IQuantileRegressionPredictor + [BestFriend] + internal interface IQuantileRegressionPredictor { ISchemaBindableMapper CreateMapper(Double[] quantiles); } @@ -47,7 +48,8 @@ public interface IQuantileRegressionPredictor /// A generic interface for probability distributions /// /// Type of statistics result - public interface IDistribution + [BestFriend] + internal interface IDistribution { TResult Minimum { get; } @@ -59,16 +61,18 @@ public interface IDistribution } // REVIEW: How should this quantile stuff work? 
- public interface IQuantileValueMapper + [BestFriend] + internal interface IQuantileValueMapper { - ValueMapper, VBuffer> GetMapper(Float[] quantiles); + ValueMapper, VBuffer> GetMapper(float[] quantiles); } /// /// Interface for quantile distribution /// /// Type of statistics result - public interface IQuantileDistribution : IDistribution, ISampleableDistribution + [BestFriend] + internal interface IQuantileDistribution : IDistribution, ISampleableDistribution { TResult Median { get; } @@ -76,10 +80,11 @@ public interface IQuantileDistribution : IDistribution, ISampl /// Returns an estimate of the p-th quantile, the data value where proportionately p of the data has value /// less than or equal to the returned value. /// - TResult GetQuantile(Float p); + TResult GetQuantile(float p); } - public interface ISampleableDistribution : IDistribution + [BestFriend] + internal interface ISampleableDistribution : IDistribution { /// /// Returns Support sample for the distribution. @@ -92,7 +97,8 @@ public interface ISampleableDistribution : IDistribution /// /// Predictors that can output themselves in a human-readable text format /// - public interface ICanSaveInTextFormat + [BestFriend] + internal interface ICanSaveInTextFormat { void SaveAsText(TextWriter writer, RoleMappedSchema schema); } @@ -100,7 +106,8 @@ public interface ICanSaveInTextFormat /// /// Predictors that can output themselves in the Bing ini format. /// - public interface ICanSaveInIniFormat + [BestFriend] + internal interface ICanSaveInIniFormat { void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator = null); } @@ -108,7 +115,8 @@ public interface ICanSaveInIniFormat /// /// Predictors that can output Summary. 
/// - public interface ICanSaveSummary + [BestFriend] + internal interface ICanSaveSummary { void SaveSummary(TextWriter writer, RoleMappedSchema schema); } @@ -118,7 +126,8 @@ public interface ICanSaveSummary /// The content of value 'object' can be any type such as integer, float, string or an array of them. /// It is up the caller to check and decide how to consume the values. /// - public interface ICanGetSummaryInKeyValuePairs + [BestFriend] + internal interface ICanGetSummaryInKeyValuePairs { /// /// Gets model summary including model statistics (if exists) in key value pairs. @@ -126,14 +135,16 @@ public interface ICanGetSummaryInKeyValuePairs IList> GetSummaryInKeyValuePairs(RoleMappedSchema schema); } - public interface ICanGetSummaryAsIRow + [BestFriend] + internal interface ICanGetSummaryAsIRow { - IRow GetSummaryIRowOrNull(RoleMappedSchema schema); + Row GetSummaryIRowOrNull(RoleMappedSchema schema); - IRow GetStatsIRowOrNull(RoleMappedSchema schema); + Row GetStatsIRowOrNull(RoleMappedSchema schema); } - public interface ICanGetSummaryAsIDataView + [BestFriend] + internal interface ICanGetSummaryAsIDataView { IDataView GetSummaryDataView(RoleMappedSchema schema); } @@ -141,7 +152,8 @@ public interface ICanGetSummaryAsIDataView /// /// Predictors that can output themselves in C#/C++ code. /// - public interface ICanSaveInSourceCode + [BestFriend] + internal interface ICanSaveInSourceCode { void SaveAsCode(TextWriter writer, RoleMappedSchema schema); } @@ -163,7 +175,7 @@ public interface IHaveFeatureWeights /// The larger the absolute value of a weights, the more informative/important the feature. /// A weights of zero signifies that the feature is not used by the model. /// - void GetFeatureWeights(ref VBuffer weights); + void GetFeatureWeights(ref VBuffer weights); } /// @@ -177,7 +189,8 @@ public interface IPredictorWithFeatureWeights : IHaveFeatureWeights /// Interface for mapping input values to corresponding feature contributions. 
/// This interface is commonly implemented by predictors. /// - public interface IFeatureContributionMapper : IPredictor + [BestFriend] + internal interface IFeatureContributionMapper : IPredictor { /// /// Get a delegate for mapping Contributions to Features. @@ -187,7 +200,24 @@ public interface IFeatureContributionMapper : IPredictor /// For trees we will not have negative contributions, so bottom param will be ignored. /// If normalization is requested that resulting values will be normalized to [-1, 1]. /// - ValueMapper> GetFeatureContributionMapper(int top, int bottom, bool normalize); + ValueMapper> GetFeatureContributionMapper(int top, int bottom, bool normalize); + } + + /// + /// Allows support for feature contribution calculation. + /// + public interface ICalculateFeatureContribution : IPredictor + { + FeatureContributionCalculator FeatureContributionClaculator { get; } + } + + /// + /// Support for feature contribution calculation. + /// + public sealed class FeatureContributionCalculator + { + internal IFeatureContributionMapper ContributionMapper { get; } + internal FeatureContributionCalculator(IFeatureContributionMapper contributionMapper) => ContributionMapper = contributionMapper; } /// diff --git a/src/Microsoft.ML.Data/Dirty/PredictorUtils.cs b/src/Microsoft.ML.Data/Dirty/PredictorUtils.cs index bcd5492f59..9c7adcf095 100644 --- a/src/Microsoft.ML.Data/Dirty/PredictorUtils.cs +++ b/src/Microsoft.ML.Data/Dirty/PredictorUtils.cs @@ -3,16 +3,16 @@ // See the LICENSE file in the project root for more information. using System.IO; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Data; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Internal.Internallearn +namespace Microsoft.ML.Internal.Internallearn { - public static class PredictorUtils + [BestFriend] + internal static class PredictorUtils { /// - /// Save the model summary + /// Save the model summary. 
/// public static void SaveSummary(IChannel ch, IPredictor predictor, RoleMappedSchema schema, TextWriter writer) { diff --git a/src/Microsoft.ML.Data/EntryPoints/Cache.cs b/src/Microsoft.ML.Data/EntryPoints/Cache.cs index aa2693f4aa..27d739e444 100644 --- a/src/Microsoft.ML.Data/EntryPoints/Cache.cs +++ b/src/Microsoft.ML.Data/EntryPoints/Cache.cs @@ -2,19 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.Collections.Generic; -using System.IO; -using System.Linq; +using Microsoft.ML; +using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Data.IO; +using Microsoft.ML.EntryPoints; [assembly: LoadableClass(typeof(void), typeof(Cache), null, typeof(SignatureEntryPointModule), "Cache")] -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { public static class Cache { @@ -61,9 +57,9 @@ public static CacheOutput CacheData(IHostEnvironment env, CacheInput input) var schema = input.Data.Schema; var cols = new List(); - for (int i = 0; i < schema.ColumnCount; i++) + for (int i = 0; i < schema.Count; i++) { - var type = schema.GetColumnType(i); + var type = schema[i].Type; if (saver.IsColumnSavable(type)) cols.Add(i); } diff --git a/src/Microsoft.ML.Data/EntryPoints/CommonOutputs.cs b/src/Microsoft.ML.Data/EntryPoints/CommonOutputs.cs index 37f37f6c64..8d593d132c 100644 --- a/src/Microsoft.ML.Data/EntryPoints/CommonOutputs.cs +++ b/src/Microsoft.ML.Data/EntryPoints/CommonOutputs.cs @@ -2,11 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; using System.Collections.Generic; -using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { /// /// Common output classes for trainers and transforms. @@ -24,7 +23,7 @@ public sealed class TransformOutput public IDataView OutputData; [TlcModule.Output(Desc = "Transform model", SortOrder = 2)] - public ITransformModel Model; + public TransformModel Model; } /// @@ -33,7 +32,7 @@ public sealed class TransformOutput public interface ITransformOutput { Var OutputData { get; } - Var Model { get; } + Var Model { get; } } /// @@ -44,7 +43,7 @@ public interface ITransformOutput public abstract class TrainerOutput { [TlcModule.Output(Desc = "The trained model", SortOrder = 1)] - public IPredictorModel PredictorModel; + public PredictorModel PredictorModel; } /// @@ -187,7 +186,7 @@ public interface ISequencePredictionOutput /// public interface ITrainerOutput { - Var PredictorModel { get; } + Var PredictorModel { get; } } /// diff --git a/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs b/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs index 371a58fcae..ac1da11a49 100644 --- a/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs +++ b/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs @@ -7,13 +7,13 @@ using System.Diagnostics; using System.Linq; using System.Text.RegularExpressions; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints.JsonUtils; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints.JsonUtils; +using Microsoft.ML.Internal.Utilities; using Newtonsoft.Json; using Newtonsoft.Json.Linq; -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { public class VarSerializer : JsonConverter { @@ -77,11 +77,10 @@ public static bool CheckType(Type type) return type == typeof(IDataView) || type == typeof(IFileHandle) || - type == typeof(IPredictorModel) || - type == 
typeof(ITransformModel) || + type == typeof(PredictorModel) || + type == typeof(TransformModel) || type == typeof(CommonInputs.IEvaluatorInput) || - type == typeof(CommonOutputs.IEvaluatorOutput) || - type == typeof(IMlState); + type == typeof(CommonOutputs.IEvaluatorOutput); } } @@ -186,8 +185,6 @@ public static bool IsValidType(Type variableType) return true; if (variableType == typeof(CommonOutputs.IEvaluatorOutput)) return true; - if (variableType == typeof(IMlState)) - return true; var kind = TlcModule.GetDataType(variableType); if (kind == TlcModule.DataKind.Array) diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs index 3254289f12..b895fcbf91 100644 --- a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs +++ b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs @@ -4,12 +4,12 @@ using System; using System.Collections.Generic; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Calibration; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Internal.Calibration; -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { /// /// The base class for all transform inputs. 
@@ -102,7 +102,7 @@ public abstract class LearnerInputBaseWithGroupId : LearnerInputBaseWithWeight [BestFriend] internal static class LearnerEntryPointsUtils { - public static string FindColumn(IExceptionContext ectx, ISchema schema, Optional value) + public static string FindColumn(IExceptionContext ectx, Schema schema, Optional value) { Contracts.CheckValueOrNull(ectx); ectx.CheckValue(schema, nameof(schema)); @@ -133,7 +133,7 @@ public static TOut Train(IHost host, TArg input, { using (var ch = host.Start("Training")) { - ISchema schema = input.TrainingData.Schema; + var schema = input.TrainingData.Schema; var feature = FindColumn(ch, schema, input.FeatureColumn); var label = getLabel?.Invoke(); var weight = getWeight?.Invoke(); @@ -188,7 +188,7 @@ public static TOut Train(IHost host, TArg input, } var predictor = TrainUtils.Train(host, ch, cachedRoleMappedData, trainer, calibrator, maxCalibrationExamples); - return new TOut() { PredictorModel = new PredictorModel(host, roleMappedData, input.TrainingData, predictor) }; + return new TOut() { PredictorModel = new PredictorModelImpl(host, roleMappedData, input.TrainingData, predictor) }; } } } @@ -211,7 +211,7 @@ public interface ITransformInput /// public interface IFeaturizerInput : ITransformInput { - Var PredictorModel { get; set; } + Var PredictorModel { get; set; } } /// @@ -260,7 +260,7 @@ public interface ITrainerInputWithGroupId : ITrainerInputWithWeight /// public interface ICalibratorInput : ITransformInput { - Var UncalibratedPredictorModel { get; } + Var UncalibratedPredictorModel { get; } int MaxRows { get; } } diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs b/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs index 0c8dae6438..40a9e2a5b5 100644 --- a/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs +++ b/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs @@ -6,12 +6,12 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; -using 
Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; using Newtonsoft.Json.Linq; -namespace Microsoft.ML.Runtime.EntryPoints.JsonUtils +namespace Microsoft.ML.EntryPoints.JsonUtils { /// /// The class that creates and wraps around an instance of an input object and gradually populates all fields, keeping track of missing diff --git a/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs b/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs similarity index 58% rename from src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs rename to src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs index f450c5e39a..80ecd4a62d 100644 --- a/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs @@ -6,66 +6,68 @@ using System.Collections.Generic; using System.IO; using System.Linq; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Calibration; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { /// - /// This class encapsulates the predictor and a preceding transform model. + /// This class encapsulates the predictor and a preceding transform model, as the concrete and hidden + /// implementation of . 
/// - public sealed class PredictorModel : IPredictorModel + [BestFriend] + internal sealed class PredictorModelImpl : PredictorModel { - private readonly IPredictor _predictor; - private readonly ITransformModel _transformModel; private readonly KeyValuePair[] _roleMappings; - public ITransformModel TransformModel { get { return _transformModel; } } + internal override TransformModel TransformModel { get; } - public IPredictor Predictor { get { return _predictor; } } + internal override IPredictor Predictor { get; } - public PredictorModel(IHostEnvironment env, RoleMappedData trainingData, IDataView startingData, IPredictor predictor) + [BestFriend] + internal PredictorModelImpl(IHostEnvironment env, RoleMappedData trainingData, IDataView startingData, IPredictor predictor) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(trainingData, nameof(trainingData)); env.CheckValue(predictor, nameof(predictor)); - _transformModel = new TransformModel(env, trainingData.Data, startingData); + TransformModel = new TransformModelImpl(env, trainingData.Data, startingData); _roleMappings = trainingData.Schema.GetColumnRoleNames().ToArray(); - _predictor = predictor; + Predictor = predictor; } - public PredictorModel(IHostEnvironment env, Stream stream) + [BestFriend] + internal PredictorModelImpl(IHostEnvironment env, Stream stream) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(stream, nameof(stream)); using (var ch = env.Start("Loading predictor model")) { // REVIEW: address the asymmetry in the way we're loading and saving the model. 
- _transformModel = new TransformModel(env, stream); + TransformModel = new TransformModelImpl(env, stream); var roles = ModelFileUtils.LoadRoleMappingsOrNull(env, stream); env.CheckDecode(roles != null, "Predictor model must contain role mappings"); _roleMappings = roles.ToArray(); - _predictor = ModelFileUtils.LoadPredictorOrNull(env, stream); - env.CheckDecode(_predictor != null, "Predictor model must contain a predictor"); + Predictor = ModelFileUtils.LoadPredictorOrNull(env, stream); + env.CheckDecode(Predictor != null, "Predictor model must contain a predictor"); } } - private PredictorModel(ITransformModel transformModel, IPredictor predictor, KeyValuePair[] roleMappings) + private PredictorModelImpl(TransformModel transformModel, IPredictor predictor, KeyValuePair[] roleMappings) { Contracts.AssertValue(transformModel); Contracts.AssertValue(predictor); Contracts.AssertValue(roleMappings); - _transformModel = transformModel; - _predictor = predictor; + TransformModel = transformModel; + Predictor = predictor; _roleMappings = roleMappings; } - public void Save(IHostEnvironment env, Stream stream) + internal override void Save(IHostEnvironment env, Stream stream) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(stream, nameof(stream)); @@ -77,37 +79,37 @@ public void Save(IHostEnvironment env, Stream stream) // (we use the TrainUtils.SaveModel that does all three). // Create the chain of transforms for saving. 
- IDataView data = new EmptyDataView(env, _transformModel.InputSchema); - data = _transformModel.Apply(env, data); + IDataView data = new EmptyDataView(env, TransformModel.InputSchema); + data = TransformModel.Apply(env, data); var roleMappedData = new RoleMappedData(data, _roleMappings, opt: true); - TrainUtils.SaveModel(env, ch, stream, _predictor, roleMappedData); + TrainUtils.SaveModel(env, ch, stream, Predictor, roleMappedData); } } - public IPredictorModel Apply(IHostEnvironment env, ITransformModel transformModel) + internal override PredictorModel Apply(IHostEnvironment env, TransformModel transformModel) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(transformModel, nameof(transformModel)); - ITransformModel newTransformModel = _transformModel.Apply(env, transformModel); + TransformModel newTransformModel = TransformModel.Apply(env, transformModel); Contracts.AssertValue(newTransformModel); - return new PredictorModel(newTransformModel, _predictor, _roleMappings); + return new PredictorModelImpl(newTransformModel, Predictor, _roleMappings); } - public void PrepareData(IHostEnvironment env, IDataView input, out RoleMappedData roleMappedData, out IPredictor predictor) + internal override void PrepareData(IHostEnvironment env, IDataView input, out RoleMappedData roleMappedData, out IPredictor predictor) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(input, nameof(input)); - input = _transformModel.Apply(env, input); + input = TransformModel.Apply(env, input); roleMappedData = new RoleMappedData(input, _roleMappings, opt: true); - predictor = _predictor; + predictor = Predictor; } - public string[] GetLabelInfo(IHostEnvironment env, out ColumnType labelType) + internal override string[] GetLabelInfo(IHostEnvironment env, out ColumnType labelType) { Contracts.CheckValue(env, nameof(env)); - var predictor = _predictor; + var predictor = Predictor; var calibrated = predictor as CalibratedPredictorBase; while (calibrated != null) { @@ -122,23 
+124,21 @@ public string[] GetLabelInfo(IHostEnvironment env, out ColumnType labelType) labelType = null; if (trainRms.Label != null) { - labelType = trainRms.Label.Type; - if (labelType.IsKey && - trainRms.Schema.HasKeyValues(trainRms.Label.Index, labelType.KeyCount)) + labelType = trainRms.Label.Value.Type; + if (labelType is KeyType && trainRms.Label.Value.HasKeyValues(labelType.KeyCount)) { VBuffer> keyValues = default; - trainRms.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, trainRms.Label.Index, - ref keyValues); + trainRms.Label.Value.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref keyValues); return keyValues.DenseValues().Select(v => v.ToString()).ToArray(); } } return null; } - public RoleMappedSchema GetTrainingSchema(IHostEnvironment env) + internal override RoleMappedSchema GetTrainingSchema(IHostEnvironment env) { Contracts.CheckValue(env, nameof(env)); - var predInput = _transformModel.Apply(env, new EmptyDataView(env, _transformModel.InputSchema)); + var predInput = TransformModel.Apply(env, new EmptyDataView(env, TransformModel.InputSchema)); var trainRms = new RoleMappedSchema(predInput.Schema, _roleMappings, opt: true); return trainRms; } diff --git a/src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs b/src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs index cccc1da6e7..d062fcac46 100644 --- a/src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs +++ b/src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs @@ -2,14 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; using Microsoft.ML.Transforms; [assembly: LoadableClass(typeof(void), typeof(SchemaManipulation), null, typeof(SignatureEntryPointModule), "SchemaManipulation")] -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { public static class SchemaManipulation { @@ -22,7 +22,7 @@ public static CommonOutputs.TransformOutput ConcatColumns(IHostEnvironment env, EntryPointUtils.CheckInputArgs(host, input); var xf = ColumnConcatenatingTransformer.Create(env, input, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } [TlcModule.EntryPoint(Name = "Transforms.ColumnSelector", Desc = "Selects a set of columns, dropping all others", UserName = "Select Columns")] @@ -34,7 +34,7 @@ public static CommonOutputs.TransformOutput SelectColumns(IHostEnvironment env, EntryPointUtils.CheckInputArgs(host, input); var xf = new ColumnSelectingTransformer(env, input.KeepColumns, input.DropColumns, input.KeepHidden, input.IgnoreMissing).Transform(input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } [TlcModule.EntryPoint(Name = "Transforms.ColumnCopier", Desc = "Duplicates columns from the dataset", UserName = ColumnCopyingTransformer.UserName, ShortName = ColumnCopyingTransformer.ShortName)] @@ -45,7 +45,7 @@ public static CommonOutputs.TransformOutput CopyColumns(IHostEnvironment env, Co host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); var xf = 
ColumnCopyingTransformer.Create(env, input, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } } } diff --git a/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs b/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs index b2d2eb8450..c6619887cc 100644 --- a/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs +++ b/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs @@ -2,15 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Transforms; using System; using System.Collections.Generic; using System.Linq; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Transforms; -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { public static partial class ScoreModel { @@ -29,27 +28,27 @@ public static CommonOutputs.TransformOutput SelectColumns(IHostEnvironment env, var view = input.Data; var maxScoreId = view.Schema.GetMaxMetadataKind(out int colMax, MetadataUtils.Kinds.ScoreColumnSetId); List indices = new List(); - for (int i = 0; i < view.Schema.ColumnCount; i++) + for (int i = 0; i < view.Schema.Count; i++) { - if (view.Schema.IsHidden(i)) + if (view.Schema[i].IsHidden) continue; if (!ShouldAddColumn(view.Schema, i, input.ExtraColumns, maxScoreId)) continue; indices.Add(i); } var newView = new ChooseColumnsByIndexTransform(env, new ChooseColumnsByIndexTransform.Arguments() { Index = indices.ToArray() }, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, newView, input.Data), OutputData = newView }; + return new CommonOutputs.TransformOutput { Model = new 
TransformModelImpl(env, newView, input.Data), OutputData = newView }; } private static bool ShouldAddColumn(Schema schema, int i, string[] extraColumns, uint scoreSet) { uint scoreSetId = 0; - if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType.AsPrimitive, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId) + if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId) && scoreSetId == scoreSet) { return true; } - var columnName = schema.GetColumnName(i); + var columnName = schema[i].Name; if (extraColumns != null && Array.FindIndex(extraColumns, columnName.Equals) >= 0) return true; return false; @@ -58,7 +57,7 @@ private static bool ShouldAddColumn(Schema schema, int i, string[] extraColumns, public sealed class RenameBinaryPredictionScoreColumnsInput : TransformInputBase { [Argument(ArgumentType.Required, HelpText = "The predictor model used in scoring", SortOrder = 2)] - public IPredictorModel PredictorModel; + public PredictorModel PredictorModel; } [TlcModule.EntryPoint(Name = "Transforms.BinaryPredictionScoreColumnsRenamer", Desc = "For binary prediction, it renames the PredictedLabel and Score columns to include the name of the positive class.", UserName = "Rename Binary Prediction Score Columns")] @@ -82,9 +81,9 @@ public static CommonOutputs.TransformOutput RenameBinaryPredictionScoreColumns(I int colMax; var maxScoreId = input.Data.Schema.GetMaxMetadataKind(out colMax, MetadataUtils.Kinds.ScoreColumnSetId); var copyCols = new List<(string Source, string Name)>(); - for (int i = 0; i < input.Data.Schema.ColumnCount; i++) + for (int i = 0; i < input.Data.Schema.Count; i++) { - if (input.Data.Schema.IsHidden(i)) + if (input.Data.Schema[i].IsHidden) continue; if (!ShouldAddColumn(input.Data.Schema, i, null, maxScoreId)) continue; @@ -96,19 +95,19 @@ public static CommonOutputs.TransformOutput RenameBinaryPredictionScoreColumns(I { continue; } - var source = 
input.Data.Schema.GetColumnName(i); + var source = input.Data.Schema[i].Name; var name = source + "." + positiveClass; copyCols.Add((source, name)); } var copyColumn = new ColumnCopyingTransformer(env, copyCols.ToArray()).Transform(input.Data); var dropColumn = ColumnSelectingTransformer.CreateDrop(env, copyColumn, copyCols.Select(c => c.Source).ToArray()); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, dropColumn, input.Data), OutputData = dropColumn }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, dropColumn, input.Data), OutputData = dropColumn }; } } var newView = NopTransform.CreateIfNeeded(env, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, newView, input.Data), OutputData = newView }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, newView, input.Data), OutputData = newView }; } } } diff --git a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs index b5c49c7fdb..387dad5713 100644 --- a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs @@ -2,19 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; [assembly: LoadableClass(typeof(void), typeof(ScoreModel), null, typeof(SignatureEntryPointModule), "ScoreModel")] -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { /// - /// This module handles scoring a against a new dataset. - /// As a result, we return both the scored data and the scoring transform as a . 
+ /// This module handles scoring a against a new dataset. + /// As a result, we return both the scored data and the scoring transform as a . /// /// REVIEW: This module does not support 'exotic' scoring scenarios, like recommendation and quantile regression /// (those where the user-defined scorer settings are necessary to identify the scorer). We could resolve this by @@ -28,7 +27,7 @@ public sealed class Input public IDataView Data; [Argument(ArgumentType.Required, HelpText = "The predictor model to apply to data", SortOrder = 2)] - public IPredictorModel PredictorModel; + public PredictorModel PredictorModel; [Argument(ArgumentType.AtMostOnce, HelpText = "Suffix to append to the score columns", SortOrder = 3)] public string Suffix; @@ -40,7 +39,7 @@ public sealed class InputTransformScorer public IDataView Data; [Argument(ArgumentType.Required, HelpText = "The transform model to apply to data", SortOrder = 2)] - public ITransformModel TransformModel; + public TransformModel TransformModel; } public sealed class Output @@ -49,19 +48,19 @@ public sealed class Output public IDataView ScoredData; [TlcModule.Output(Desc = "The scoring transform", SortOrder = 2)] - public ITransformModel ScoringTransform; + public TransformModel ScoringTransform; } public sealed class ModelInput { [Argument(ArgumentType.Required, HelpText = "The predictor model to turn into a transform", SortOrder = 1)] - public IPredictorModel PredictorModel; + public PredictorModel PredictorModel; } public sealed class ModelOutput { [TlcModule.Output(Desc = "The scoring transform", SortOrder = 1)] - public ITransformModel ScoringTransform; + public TransformModel ScoringTransform; } [TlcModule.EntryPoint(Name = "Transforms.DatasetScorer", Desc = "Score a dataset with a predictor model")] @@ -91,7 +90,7 @@ public static Output Score(IHostEnvironment env, Input input) new Output { ScoredData = scoredPipe, - ScoringTransform = new TransformModel(host, scoredPipe, inputData) + ScoringTransform = new 
TransformModelImpl(host, scoredPipe, inputData) }; } @@ -140,7 +139,7 @@ public static Output MakeScoringTransform(IHostEnvironment env, ModelInput input return new Output { ScoredData = scoredPipe, - ScoringTransform = new TransformModel(host, scoredPipe, emptyData) + ScoringTransform = new TransformModelImpl(host, scoredPipe, emptyData) }; } } diff --git a/src/Microsoft.ML.Data/EntryPoints/SelectRows.cs b/src/Microsoft.ML.Data/EntryPoints/SelectRows.cs index aa805eeac8..ddee42fa1f 100644 --- a/src/Microsoft.ML.Data/EntryPoints/SelectRows.cs +++ b/src/Microsoft.ML.Data/EntryPoints/SelectRows.cs @@ -2,12 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.EntryPoints; using Microsoft.ML.Transforms; [assembly: EntryPointModule(typeof(SelectRows))] -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { public static class SelectRows { @@ -20,7 +20,7 @@ public static CommonOutputs.TransformOutput FilterByRange(IHostEnvironment env, EntryPointUtils.CheckInputArgs(host, input); var xf = new RangeFilter(host, input, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } [TlcModule.EntryPoint(Name = "Transforms.RowSkipFilter", Desc = SkipTakeFilter.SkipFilterSummary, UserName = SkipTakeFilter.SkipFilterUserName, @@ -32,7 +32,7 @@ public static CommonOutputs.TransformOutput SkipFilter(IHostEnvironment env, Ski host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); var xf = SkipTakeFilter.Create(host, input, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; + return new CommonOutputs.TransformOutput { 
Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } [TlcModule.EntryPoint(Name = "Transforms.RowTakeFilter", Desc = SkipTakeFilter.TakeFilterSummary, UserName = SkipTakeFilter.TakeFilterUserName, @@ -44,7 +44,7 @@ public static CommonOutputs.TransformOutput TakeFilter(IHostEnvironment env, Ski host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); var xf = SkipTakeFilter.Create(host, input, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } [TlcModule.EntryPoint(Name = "Transforms.RowSkipAndTakeFilter", Desc = SkipTakeFilter.SkipTakeFilterSummary, @@ -56,7 +56,7 @@ public static CommonOutputs.TransformOutput SkipAndTakeFilter(IHostEnvironment e host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); var xf = SkipTakeFilter.Create(host, input, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } } } diff --git a/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs b/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs index 0603d28e84..80919b5024 100644 --- a/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs +++ b/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs @@ -4,22 +4,22 @@ using System.IO; using System.Text; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Calibration; -using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Calibration; +using 
Microsoft.ML.Internal.Internallearn; [assembly: EntryPointModule(typeof(SummarizePredictor))] -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { public static class SummarizePredictor { public abstract class InputBase { [Argument(ArgumentType.Required, ShortName = "predictorModel", HelpText = "The predictor to summarize")] - public IPredictorModel PredictorModel; + public PredictorModel PredictorModel; } public sealed class Input : InputBase @@ -43,7 +43,8 @@ public static CommonOutputs.SummaryOutput Summarize(IHostEnvironment env, Summar return output; } - public static IDataView GetSummaryAndStats(IHostEnvironment env, IPredictor predictor, RoleMappedSchema schema, out IDataView stats) + [BestFriend] + internal static IDataView GetSummaryAndStats(IHostEnvironment env, IPredictor predictor, RoleMappedSchema schema, out IDataView stats) { var calibrated = predictor as CalibratedPredictorBase; while (calibrated != null) diff --git a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs b/src/Microsoft.ML.Data/EntryPoints/TransformModelImpl.cs similarity index 80% rename from src/Microsoft.ML.Data/EntryPoints/TransformModel.cs rename to src/Microsoft.ML.Data/EntryPoints/TransformModelImpl.cs index 4d94dc9760..091d3d6d6f 100644 --- a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/TransformModelImpl.cs @@ -6,53 +6,56 @@ using System.Collections.Generic; using System.IO; using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { /// /// This encapsulates zero or more transform models. It does this by recording /// the initial schema, together with the sequence of transforms applied to that /// schema. 
/// - public sealed class TransformModel : ITransformModel + [BestFriend] + internal sealed class TransformModelImpl : TransformModel { // The cached schema of the root of the _chain. private readonly Schema _schemaRoot; - // This contains the transforms to save instantiated on an IDataView with - // appropriate initial schema. Note that the "root" of this is typically either - // an empty IDataView or a BinaryLoader with no rows. However, other root - // types are possible, since we don't insist on this when loading a model - // from a zip file. However, whenever we save, we force a BinaryLoader to - // be serialized for the root. + /// + /// This contains the transforms to save instantiated on an with + /// appropriate initial schema. Note that the "root" of this is typically either + /// an empty or a with no rows. However, other root + /// types are possible, since we don't insist on this when loading a model + /// from a zip file. However, whenever we save, we force a to + /// be serialized for the root. + /// private readonly IDataView _chain; /// /// The input schema that this transform model was originally instantiated on. /// Note that the schema may have columns that aren't needed by this transform model. - /// If an IDataView exists with this schema, then applying this transform model to it + /// If an exists with this schema, then applying this transform model to it /// shouldn't fail because of column type issues. /// REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note /// however that doing so may cause issues for composing transform models. For example, /// if transform model A needs column X and model B needs Y, that is NOT produced by A, /// then trimming A's input schema would cause composition to fail. /// - public Schema InputSchema => _schemaRoot; + internal override Schema InputSchema => _schemaRoot; /// /// The resulting schema once applied to this model. 
The might have /// columns that are not needed by this transform and these columns will be seen in the /// produced by this transform. /// - public Schema OutputSchema => _chain.Schema; + internal override Schema OutputSchema => _chain.Schema; /// /// Create a TransformModel containing the transforms from "result" back to "input". /// - public TransformModel(IHostEnvironment env, IDataView result, IDataView input) + public TransformModelImpl(IHostEnvironment env, IDataView result, IDataView input) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(result, nameof(result)); @@ -63,7 +66,7 @@ public TransformModel(IHostEnvironment env, IDataView result, IDataView input) _chain = ApplyTransformUtils.ApplyAllTransformsToData(env, result, root, input); } - private TransformModel(IHostEnvironment env, Schema schemaRoot, IDataView chain) + private TransformModelImpl(IHostEnvironment env, Schema schemaRoot, IDataView chain) { Contracts.AssertValue(env); env.AssertValue(schemaRoot); @@ -76,7 +79,7 @@ private TransformModel(IHostEnvironment env, Schema schemaRoot, IDataView chain) /// Create a TransformModel containing the given (optional) transforms applied to the /// given root schema. /// - public TransformModel(IHostEnvironment env, Schema schemaRoot, IDataTransform[] xfs) + public TransformModelImpl(IHostEnvironment env, Schema schemaRoot, IDataTransform[] xfs) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(schemaRoot, nameof(schemaRoot)); @@ -100,7 +103,7 @@ public TransformModel(IHostEnvironment env, Schema schemaRoot, IDataTransform[] /// /// Load a transform model. /// - public TransformModel(IHostEnvironment env, Stream stream) + public TransformModelImpl(IHostEnvironment env, Stream stream) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(stream, nameof(stream)); @@ -128,7 +131,7 @@ public TransformModel(IHostEnvironment env, Stream stream) /// /// Apply this transform model to the given input data. 
/// - public IDataView Apply(IHostEnvironment env, IDataView input) + internal override IDataView Apply(IHostEnvironment env, IDataView input) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(input, nameof(input)); @@ -138,14 +141,14 @@ public IDataView Apply(IHostEnvironment env, IDataView input) /// /// Apply this transform model to the given input transform model to produce a composite transform model. /// - public ITransformModel Apply(IHostEnvironment env, ITransformModel input) + internal override TransformModel Apply(IHostEnvironment env, TransformModel input) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(input, nameof(input)); IDataView view; Schema schemaRoot = input.InputSchema; - var mod = input as TransformModel; + var mod = input as TransformModelImpl; if (mod != null) view = ApplyTransformUtils.ApplyAllTransformsToData(env, _chain, mod._chain); else @@ -155,13 +158,13 @@ public ITransformModel Apply(IHostEnvironment env, ITransformModel input) view = Apply(env, view); } - return new TransformModel(env, schemaRoot, view); + return new TransformModelImpl(env, schemaRoot, view); } /// /// Save this transform model. 
/// - public void Save(IHostEnvironment env, Stream stream) + internal override void Save(IHostEnvironment env, Stream stream) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(stream, nameof(stream)); @@ -177,7 +180,7 @@ public void Save(IHostEnvironment env, Stream stream) } } - public IRowToRowMapper AsRowToRowMapper(IExceptionContext ectx) + internal override IRowToRowMapper AsRowToRowMapper(IExceptionContext ectx) { return CompositeRowToRowMapper.IsCompositeRowToRowMapper(_chain) @@ -193,7 +196,9 @@ private sealed class CompositeRowToRowMapper : IRowToRowMapper public Schema Schema => _chain.Schema; - public CompositeRowToRowMapper(IExceptionContext ectx, IDataView chain, ISchema rootSchema) + public Schema OutputSchema => Schema; + + public CompositeRowToRowMapper(IExceptionContext ectx, IDataView chain, Schema rootSchema) { Contracts.CheckValue(ectx, nameof(ectx)); _ectx = ectx; @@ -201,7 +206,7 @@ public CompositeRowToRowMapper(IExceptionContext ectx, IDataView chain, ISchema _ectx.CheckValue(rootSchema, nameof(rootSchema)); _chain = chain; - _rootSchema = Schema.Create(rootSchema); + _rootSchema = rootSchema; } public static bool IsCompositeRowToRowMapper(IDataView chain) @@ -234,7 +239,7 @@ public Func GetDependencies(Func predicate) public Schema InputSchema => _rootSchema; - public IRow GetRow(IRow input, Func active, out Action disposer) + public Row GetRow(Row input, Func active) { _ectx.Assert(IsCompositeRowToRowMapper(_chain)); _ectx.AssertValue(input); @@ -242,7 +247,6 @@ public IRow GetRow(IRow input, Func active, out Action disposer) _ectx.Check(input.Schema == InputSchema, "Schema of input row must be the same as the schema the mapper is bound to"); - disposer = null; var mappers = new List(); var actives = new List>(); var transform = _chain as IDataTransform; @@ -260,11 +264,7 @@ public IRow GetRow(IRow input, Func active, out Action disposer) actives.Reverse(); var row = input; for (int i = 0; i < mappers.Count; i++) - { - Action disp; 
- row = mappers[i].GetRow(row, actives[i], out disp); - disposer += disp; - } + row = mappers[i].GetRow(row, actives[i]); return row; } diff --git a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs index 7b8bd9651e..666b2d7109 100644 --- a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs @@ -2,16 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Transforms; using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Transforms; [assembly: LoadableClass(typeof(AnomalyDetectionEvaluator), typeof(AnomalyDetectionEvaluator), typeof(AnomalyDetectionEvaluator.Arguments), typeof(SignatureEvaluator), "Anomaly Detection Evaluator", AnomalyDetectionEvaluator.LoadName, "AnomalyDetection", "Anomaly")] @@ -19,7 +19,7 @@ [assembly: LoadableClass(typeof(AnomalyDetectionMamlEvaluator), typeof(AnomalyDetectionMamlEvaluator), typeof(AnomalyDetectionMamlEvaluator.Arguments), typeof(SignatureMamlEvaluator), "Anomaly Detection Evaluator", AnomalyDetectionEvaluator.LoadName, "AnomalyDetection", "Anomaly")] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public sealed class AnomalyDetectionEvaluator : EvaluatorBase { @@ -90,24 +90,24 @@ public AnomalyDetectionEvaluator(IHostEnvironment env, Arguments args) _aucCount = args.MaxAucExamples; } - protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) 
+ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var t = score.Type; if (t != NumberType.Float) throw Host.Except("Score column '{0}' has type '{1}' but must be R4", score, t).MarkSensitive(MessageSensitivity.Schema); - Host.Check(schema.Label != null, "Could not find the label column"); - t = schema.Label.Type; + Host.Check(schema.Label.HasValue, "Could not find the label column"); + t = schema.Label.Value.Type; if (t != NumberType.Float && t.KeyCount != 2) - throw Host.Except("Label column '{0}' has type '{1}' but must be R4 or a 2-value key", schema.Label.Name, t).MarkSensitive(MessageSensitivity.Schema); + throw Host.Except("Label column '{0}' has type '{1}' but must be R4 or a 2-value key", schema.Label.Value.Name, t).MarkSensitive(MessageSensitivity.Schema); } - protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) + private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { - return new Aggregator(Host, _aucCount, _numTopResults, _k, _p, _streaming, schema.Name == null ? -1 : schema.Name.Index, stratName); + return new Aggregator(Host, _aucCount, _numTopResults, _k, _p, _streaming, schema.Name == null ? 
-1 : schema.Name.Value.Index, stratName); } - public override IDataTransform GetPerInstanceMetrics(RoleMappedData data) + internal override IDataTransform GetPerInstanceMetricsCore(RoleMappedData data) { return NopTransform.CreateIfNeeded(Host, data.Data); } @@ -123,7 +123,7 @@ public override IEnumerable GetOverallMetricColumns() yield return new MetricColumn("ThreshAtNumPos", OverallMetrics.ThreshAtNumPos, MetricColumn.Objective.Info, canBeWeighted: false); } - protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, + private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, out Action, Aggregator> addAgg, out Func> consolidate) { var stratCol = new List(); @@ -497,14 +497,14 @@ private void FinishOtherMetrics() } } - public override void InitializeNextPass(IRow row, RoleMappedSchema schema) + internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { Host.Assert(!_streaming && PassNum < 2 || PassNum < 1); - Host.AssertValue(schema.Label); + Host.Assert(schema.Label.HasValue); var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index); + _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index); _scoreGetter = row.GetGetter(score.Index); Host.AssertValue(_labelGetter); Host.AssertValue(_scoreGetter); @@ -605,7 +605,7 @@ public sealed class Arguments : ArgumentsBase private readonly int _k; private readonly Double _p; - protected override IEvaluator Evaluator => _evaluator; + private protected override IEvaluator Evaluator => _evaluator; public AnomalyDetectionMamlEvaluator(IHostEnvironment env, Arguments args) : base(args, env, MetadataUtils.Const.ScoreColumnKind.AnomalyDetection, "AnomalyDetectionMamlEvaluator") @@ -620,7 +620,7 @@ public AnomalyDetectionMamlEvaluator(IHostEnvironment env, Arguments args) 
_evaluator = new AnomalyDetectionEvaluator(Host, evalArgs); } - protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) + private protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) { IDataView top; if (!metrics.TryGetValue(AnomalyDetectionEvaluator.TopKResults, out top)) @@ -685,7 +685,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary stratGetter = null; if (hasStrat) { - var type = cursor.Schema.GetColumnType(stratCol); + var type = cursor.Schema[stratCol].Type; stratGetter = RowCursorUtils.GetGetterAs(type, cursor, stratCol); } bool foundRow = false; @@ -731,7 +731,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary GetPerInstanceColumnsToSave(RoleMappedSchema schema) + private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); - Host.CheckValue(schema.Label, nameof(schema), "Data must contain a label column"); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Data must contain a label column"); // The anomaly detection evaluator outputs the label and the score. - yield return schema.Label.Name; - var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), + yield return schema.Label.Value.Name; + var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.AnomalyDetection); - yield return scoreInfo.Name; + yield return scoreCol.Name; // No additional output columns. 
} @@ -777,7 +777,7 @@ public static CommonOutputs.CommonEvaluateOutput AnomalyDetection(IHostEnvironme string weight; string name; MatchColumns(host, input, out label, out weight, out name); - var evaluator = new AnomalyDetectionMamlEvaluator(host, input); + IMamlEvaluator evaluator = new AnomalyDetectionMamlEvaluator(host, input); var data = new RoleMappedData(input.Data, label, null, null, weight, name); var metrics = evaluator.Evaluate(data); diff --git a/src/Microsoft.ML.Data/Evaluators/AucAggregator.cs b/src/Microsoft.ML.Data/Evaluators/AucAggregator.cs index 87bbf897d6..73c64070d0 100644 --- a/src/Microsoft.ML.Data/Evaluators/AucAggregator.cs +++ b/src/Microsoft.ML.Data/Evaluators/AucAggregator.cs @@ -5,9 +5,9 @@ using System; using System.Collections.Generic; using System.Linq; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public abstract partial class EvaluatorBase { @@ -41,7 +41,7 @@ internal abstract class AucAggregatorBase : AucAggregatorBase protected IEnumerable PosSample; protected IEnumerable NegSample; - protected AucAggregatorBase(IRandom rand, int reservoirSize) + protected AucAggregatorBase(Random rand, int reservoirSize) { Contracts.Assert(reservoirSize >= -1); @@ -119,7 +119,7 @@ public override Double ComputeWeightedAuc(out Double unweighted) internal sealed class UnweightedAucAggregator : AucAggregatorBase { - public UnweightedAucAggregator(IRandom rand, int reservoirSize) + public UnweightedAucAggregator(Random rand, int reservoirSize) : base(rand, reservoirSize) { } @@ -220,7 +220,7 @@ public struct AucInfo private Single _weight; - public WeightedAucAggregator(IRandom rand, int reservoirSize) + public WeightedAucAggregator(Random rand, int reservoirSize) : base(rand, reservoirSize) { } @@ -368,7 +368,7 @@ private protected abstract class AuPrcAggregatorBase : AuPrcAggregatorBase { protected readonly 
ReservoirSamplerWithoutReplacement Reservoir; - protected AuPrcAggregatorBase(IRandom rand, int reservoirSize) + protected AuPrcAggregatorBase(Random rand, int reservoirSize) { Contracts.Assert(reservoirSize > 0); @@ -401,7 +401,7 @@ public struct Info public Single Label; } - public UnweightedAuPrcAggregator(IRandom rand, int reservoirSize) + public UnweightedAuPrcAggregator(Random rand, int reservoirSize) : base(rand, reservoirSize) { } @@ -475,7 +475,7 @@ public struct Info public Single Weight; } - public WeightedAuPrcAggregator(IRandom rand, int reservoirSize) + public WeightedAuPrcAggregator(Random rand, int reservoirSize) : base(rand, reservoirSize) { } diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs index a06e038571..ae05dc53a3 100644 --- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs @@ -2,17 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms; using System; using System.Collections.Generic; using System.Linq; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Transforms; [assembly: LoadableClass(typeof(BinaryClassifierEvaluator), typeof(BinaryClassifierEvaluator), typeof(BinaryClassifierEvaluator.Arguments), typeof(SignatureEvaluator), "Binary Classifier Evaluator", BinaryClassifierEvaluator.LoadName, "BinaryClassifier", "Binary", "bin")] @@ -26,7 +25,7 @@ [assembly: LoadableClass(typeof(void), typeof(Evaluate), null, typeof(SignatureEntryPointModule), "Evaluators")] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public sealed class BinaryClassifierEvaluator : RowToRowEvaluatorBase { @@ -123,39 +122,39 @@ public BinaryClassifierEvaluator(IHostEnvironment env, Arguments args) _auPrcCount = args.NumAuPrcExamples; } - protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) + private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var host = Host.SchemaSensitive(); var t = score.Type; if (t.IsVector || t.ItemType != NumberType.Float) - throw host.SchemaSensitive().Except("Score column '{0}' has type '{1}' but must be R4", score, t); - host.Check(schema.Label != null, "Could not find the label column"); - t = schema.Label.Type; + throw host.ExceptSchemaMismatch(nameof(schema), "score", score.Name, "R4", t.ToString()); + host.Check(schema.Label.HasValue, "Could not find the label column"); + t = schema.Label.Value.Type; if (t != NumberType.R4 
&& t != NumberType.R8 && t != BoolType.Instance && t.KeyCount != 2) - throw host.SchemaSensitive().Except("Label column '{0}' has type '{1}' but must be R4, R8, BL or a 2-value key", schema.Label.Name, t); + throw host.ExceptSchemaMismatch(nameof(schema), "label", schema.Label.Value.Name, "R4, R8, BL or a 2-value key", t.ToString()); } - protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema) + private protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema) { var prob = schema.GetColumns(MetadataUtils.Const.ScoreValueKind.Probability); var host = Host.SchemaSensitive(); if (prob != null) { - host.Check(prob.Count == 1, "Cannot have multiple probability columns"); + host.CheckParam(prob.Count == 1, nameof(schema), "Cannot have multiple probability columns"); var probType = prob[0].Type; if (probType != NumberType.Float) - throw host.SchemaSensitive().Except("Probability column '{0}' has type '{1}' but must be R4", prob[0].Name, probType); + throw host.ExceptSchemaMismatch(nameof(schema), "probability", prob[0].Name, "R4", probType.ToString()); } else if (!_useRaw) { - throw host.Except( + throw host.ExceptParam(nameof(schema), "Cannot compute the predicted label from the probability column because it does not exist"); } } // Add also the probability column. 
- protected override Func GetActiveColsCore(RoleMappedSchema schema) + private protected override Func GetActiveColsCore(RoleMappedSchema schema) { var pred = base.GetActiveColsCore(schema); var prob = schema.GetColumns(MetadataUtils.Const.ScoreValueKind.Probability); @@ -163,7 +162,7 @@ protected override Func GetActiveColsCore(RoleMappedSchema schema) return i => Utils.Size(prob) > 0 && i == prob[0].Index || pred(i); } - protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) + private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { var classNames = GetClassNames(schema); return new Aggregator(Host, classNames, schema.Weight != null, _aucCount, _auPrcCount, _threshold, _useRaw, _prCount, stratName); @@ -172,13 +171,13 @@ protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string private ReadOnlyMemory[] GetClassNames(RoleMappedSchema schema) { // Get the label names if they exist, or use the default names. 
- ColumnType type; var labelNames = default(VBuffer>); - if (schema.Label.Type.IsKey && - (type = schema.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, schema.Label.Index)) != null && - type.ItemType.IsKnownSizeVector && type.ItemType.IsText) + var labelCol = schema.Label.Value; + if (labelCol.Type is KeyType && + labelCol.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type is VectorType vecType && + vecType.Size > 0 && vecType.ItemType == TextType.Instance) { - schema.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, schema.Label.Index, ref labelNames); + labelCol.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref labelNames); } else labelNames = new VBuffer>(2, new[] { "positive".AsMemory(), "negative".AsMemory() }); @@ -188,16 +187,15 @@ private ReadOnlyMemory[] GetClassNames(RoleMappedSchema schema) return names; } - protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) + private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) { Contracts.CheckValue(schema, nameof(schema)); Contracts.CheckParam(schema.Label != null, nameof(schema), "Could not find the label column"); var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - Contracts.AssertValue(scoreInfo); var probInfos = schema.GetColumns(MetadataUtils.Const.ScoreValueKind.Probability); var probCol = Utils.Size(probInfos) > 0 ? 
probInfos[0].Name : null; - return new BinaryPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, probCol, schema.Label.Name, _threshold, _useRaw); + return new BinaryPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, probCol, schema.Label.Value.Name, _threshold, _useRaw); } public override IEnumerable GetOverallMetricColumns() @@ -215,7 +213,7 @@ public override IEnumerable GetOverallMetricColumns() yield return new MetricColumn("AUPRC", AuPrc); } - protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, + private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, out Action, Aggregator> addAgg, out Func> consolidate) { var stratCol = new List(); @@ -609,14 +607,14 @@ public Aggregator(IHostEnvironment env, ReadOnlyMemory[] classNames, bool } } - public override void InitializeNextPass(IRow row, RoleMappedSchema schema) + internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { - Host.AssertValue(schema.Label); + Host.Assert(schema.Label.HasValue); Host.Assert(PassNum < 1); var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index); + _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index); _scoreGetter = row.GetGetter(score.Index); Host.AssertValue(_labelGetter); Host.AssertValue(_scoreGetter); @@ -631,7 +629,7 @@ public override void InitializeNextPass(IRow row, RoleMappedSchema schema) Host.Assert((schema.Weight != null) == Weighted); if (Weighted) - _weightGetter = row.GetGetter(schema.Weight.Index); + _weightGetter = row.GetGetter(schema.Weight.Value.Index); } public override void ProcessRow() @@ -790,155 +788,6 @@ private void ComputePrCurves() } } - /// - /// Evaluation results for binary classifiers, excluding probabilistic metrics. 
- /// - public class Result - { - /// - /// Gets the area under the ROC curve. - /// - /// - /// The area under the ROC curve is equal to the probability that the classifier ranks - /// a randomly chosen positive instance higher than a randomly chosen negative one - /// (assuming 'positive' ranks higher than 'negative'). - /// - public double Auc { get; } - - /// - /// Gets the accuracy of a classifier which is the proportion of correct predictions in the test set. - /// - public double Accuracy { get; } - - /// - /// Gets the positive precision of a classifier which is the proportion of correctly predicted - /// positive instances among all the positive predictions (i.e., the number of positive instances - /// predicted as positive, divided by the total number of instances predicted as positive). - /// - public double PositivePrecision { get; } - - /// - /// Gets the positive recall of a classifier which is the proportion of correctly predicted - /// positive instances among all the positive instances (i.e., the number of positive instances - /// predicted as positive, divided by the total number of positive instances). - /// - public double PositiveRecall { get; private set; } - - /// - /// Gets the negative precision of a classifier which is the proportion of correctly predicted - /// negative instances among all the negative predictions (i.e., the number of negative instances - /// predicted as negative, divided by the total number of instances predicted as negative). - /// - public double NegativePrecision { get; } - - /// - /// Gets the negative recall of a classifier which is the proportion of correctly predicted - /// negative instances among all the negative instances (i.e., the number of negative instances - /// predicted as negative, divided by the total number of negative instances). - /// - public double NegativeRecall { get; } - - /// - /// Gets the F1 score of the classifier. 
- /// - /// - /// F1 score is the harmonic mean of precision and recall: 2 * precision * recall / (precision + recall). - /// - public double F1Score { get; } - - /// - /// Gets the area under the precision/recall curve of the classifier. - /// - /// - /// The area under the precision/recall curve is a single number summary of the information in the - /// precision/recall curve. It is increasingly used in the machine learning community, particularly - /// for imbalanced datasets where one class is observed more frequently than the other. On these - /// datasets, AUPRC can highlight performance differences that are lost with AUC. - /// - public double Auprc { get; } - - protected private static T Fetch(IExceptionContext ectx, IRow row, string name) - { - if (!row.Schema.TryGetColumnIndex(name, out int col)) - throw ectx.Except($"Could not find column '{name}'"); - T val = default; - row.GetGetter(col)(ref val); - return val; - } - - internal Result(IExceptionContext ectx, IRow overallResult) - { - double Fetch(string name) => Fetch(ectx, overallResult, name); - Auc = Fetch(BinaryClassifierEvaluator.Auc); - Accuracy = Fetch(BinaryClassifierEvaluator.Accuracy); - PositivePrecision = Fetch(BinaryClassifierEvaluator.PosPrecName); - PositiveRecall = Fetch(BinaryClassifierEvaluator.PosRecallName); - NegativePrecision = Fetch(BinaryClassifierEvaluator.NegPrecName); - NegativeRecall = Fetch(BinaryClassifierEvaluator.NegRecallName); - F1Score = Fetch(BinaryClassifierEvaluator.F1); - Auprc = Fetch(BinaryClassifierEvaluator.AuPrc); - } - - [BestFriend] - internal Result(double auc, double accuracy, double positivePrecision, double positiveRecall, - double negativePrecision, double negativeRecall, double f1Score, double auprc) - { - Auc = auc; - Accuracy = accuracy; - PositivePrecision = positivePrecision; - PositiveRecall = positiveRecall; - NegativePrecision = negativePrecision; - NegativeRecall = negativeRecall; - F1Score = f1Score; - Auprc = auprc; - } - } - - /// - /// 
Evaluation results for binary classifiers, including probabilistic metrics. - /// - public sealed class CalibratedResult : Result - { - /// - /// Gets the log-loss of the classifier. - /// - /// - /// The log-loss metric, is computed as follows: - /// LL = - (1/m) * sum( log(p[i])) - /// where m is the number of instances in the test set. - /// p[i] is the probability returned by the classifier if the instance belongs to class 1, - /// and 1 minus the probability returned by the classifier if the instance belongs to class 0. - /// - public double LogLoss { get; } - - /// - /// Gets the log-loss reduction (also known as relative log-loss, or reduction in information gain - RIG) - /// of the classifier. - /// - /// - /// The log-loss reduction is scaled relative to a classifier that predicts the prior for every example: - /// (LL(prior) - LL(classifier)) / LL(prior) - /// This metric can be interpreted as the advantage of the classifier over a random prediction. - /// For example, if the RIG equals 20, it can be interpreted as "the probability of a correct prediction is - /// 20% better than random guessing." - /// - public double LogLossReduction { get; } - - /// - /// Gets the test-set entropy (prior Log-Loss/instance) of the classifier. - /// - public double Entropy { get; } - - internal CalibratedResult(IExceptionContext ectx, IRow overallResult) - : base(ectx, overallResult) - { - double Fetch(string name) => Fetch(ectx, overallResult, name); - LogLoss = Fetch(BinaryClassifierEvaluator.LogLoss); - LogLossReduction = Fetch(BinaryClassifierEvaluator.LogLossReduction); - Entropy = Fetch(BinaryClassifierEvaluator.Entropy); - } - } - /// /// Evaluates scored binary classification data. /// @@ -948,7 +797,7 @@ internal CalibratedResult(IExceptionContext ectx, IRow overallResult) /// The name of the probability column in , the calibrated version of . /// The name of the predicted label column in . /// The evaluation results for these calibrated outputs. 
- public CalibratedResult Evaluate(IDataView data, string label, string score, string probability, string predictedLabel) + public CalibratedBinaryClassificationMetrics Evaluate(IDataView data, string label, string score, string probability, string predictedLabel) { Host.CheckValue(data, nameof(data)); Host.CheckNonEmpty(label, nameof(label)); @@ -962,16 +811,16 @@ public CalibratedResult Evaluate(IDataView data, string label, string score, str RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probability), RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel)); - var resultDict = Evaluate(roles); + var resultDict = ((IEvaluator)this).Evaluate(roles); Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics)); var overall = resultDict[MetricKinds.OverallMetrics]; - CalibratedResult result; + CalibratedBinaryClassificationMetrics result; using (var cursor = overall.GetRowCursor(i => true)) { var moved = cursor.MoveNext(); Host.Assert(moved); - result = new CalibratedResult(Host, cursor); + result = new CalibratedBinaryClassificationMetrics(Host, cursor); moved = cursor.MoveNext(); Host.Assert(!moved); } @@ -987,7 +836,7 @@ public CalibratedResult Evaluate(IDataView data, string label, string score, str /// The name of the predicted label column in . /// The evaluation results for these uncalibrated outputs. 
/// - public Result Evaluate(IDataView data, string label, string score, string predictedLabel) + public BinaryClassificationMetrics Evaluate(IDataView data, string label, string score, string predictedLabel) { Host.CheckValue(data, nameof(data)); Host.CheckNonEmpty(label, nameof(label)); @@ -999,16 +848,16 @@ public Result Evaluate(IDataView data, string label, string score, string predic RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, score), RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel)); - var resultDict = Evaluate(roles); + var resultDict = ((IEvaluator)this).Evaluate(roles); Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics)); var overall = resultDict[MetricKinds.OverallMetrics]; - Result result; + BinaryClassificationMetrics result; using (var cursor = overall.GetRowCursor(i => true)) { var moved = cursor.MoveNext(); Host.Assert(moved); - result = new Result(Host, cursor); + result = new BinaryClassificationMetrics(Host, cursor); moved = cursor.MoveNext(); Host.Assert(!moved); } @@ -1042,7 +891,7 @@ private static VersionInfo GetVersionInfo() private readonly bool _useRaw; private readonly ColumnType[] _types; - public BinaryPerInstanceEvaluator(IHostEnvironment env, ISchema schema, string scoreCol, string probCol, string labelCol, Single threshold, bool useRaw) + public BinaryPerInstanceEvaluator(IHostEnvironment env, Schema schema, string scoreCol, string probCol, string labelCol, Single threshold, bool useRaw) : base(env, schema, scoreCol, labelCol) { _threshold = threshold; @@ -1062,7 +911,7 @@ public BinaryPerInstanceEvaluator(IHostEnvironment env, ISchema schema, string s _types[AssignedCol] = BoolType.Instance; } - private BinaryPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, ISchema schema) + private BinaryPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, Schema schema) : base(env, ctx, schema) { // *** Binary format ** @@ -1088,7 +937,7 
@@ private BinaryPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, I _types[AssignedCol] = BoolType.Instance; } - public static BinaryPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadContext ctx, ISchema schema) + public static BinaryPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadContext ctx, Schema schema) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); @@ -1117,7 +966,7 @@ public override void Save(ModelSaveContext ctx) ctx.Writer.WriteBoolByte(_useRaw); } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { if (_probIndex >= 0) { @@ -1130,7 +979,7 @@ public override Func GetDependencies(Func activeOutput) return col => activeOutput(AssignedCol) && col == ScoreIndex; } - public override Delegate[] CreateGetters(IRow input, Func activeCols, out Action disposer) + private protected override Delegate[] CreateGettersCore(Row input, Func activeCols, out Action disposer) { Host.Assert(LabelIndex >= 0); Host.Assert(ScoreIndex >= 0); @@ -1228,7 +1077,7 @@ private bool GetPredictedLabel(Single val) return Single.IsNaN(val) ? 
false : val > _threshold; } - public override Schema.DetachedColumn[] GetOutputColumns() + private protected override Schema.DetachedColumn[] GetOutputColumnsCore() { if (_probIndex >= 0) { @@ -1240,24 +1089,24 @@ public override Schema.DetachedColumn[] GetOutputColumns() return new[] { new Schema.DetachedColumn(Assigned, _types[AssignedCol], null), }; } - private void CheckInputColumnTypes(ISchema schema) + private void CheckInputColumnTypes(Schema schema) { Host.AssertNonEmpty(ScoreCol); Host.AssertValueOrNull(_probCol); Host.AssertNonEmpty(LabelCol); - var t = schema.GetColumnType(LabelIndex); + var t = schema[(int) LabelIndex].Type; if (t != NumberType.R4 && t != NumberType.R8 && t != BoolType.Instance && t.KeyCount != 2) throw Host.Except("Label column '{0}' has type '{1}' but must be R4, R8, BL or a 2-value key", LabelCol, t); - t = schema.GetColumnType(ScoreIndex); + t = schema[ScoreIndex].Type; if (t.IsVector || t.ItemType != NumberType.Float) throw Host.Except("Score column '{0}' has type '{1}' but must be R4", ScoreCol, t); if (_probIndex >= 0) { Host.Assert(!string.IsNullOrEmpty(_probCol)); - t = schema.GetColumnType(_probIndex); + t = schema[_probIndex].Type; if (t.IsVector || t.ItemType != NumberType.Float) throw Host.Except("Probability column '{0}' has type '{1}' but must be R4", _probCol, t); } @@ -1301,7 +1150,7 @@ public class Arguments : ArgumentsBase private readonly string _prFileName; private readonly string _probCol; - protected override IEvaluator Evaluator { get { return _evaluator; } } + private protected override IEvaluator Evaluator => _evaluator; public BinaryClassifierMamlEvaluator(IHostEnvironment env, Arguments args) : base(args, env, MetadataUtils.Const.ScoreColumnKind.BinaryClassification, "BinaryClassifierMamlEvaluator") @@ -1321,22 +1170,22 @@ public BinaryClassifierMamlEvaluator(IHostEnvironment env, Arguments args) _evaluator = new BinaryClassifierEvaluator(Host, evalArgs); } - protected override IEnumerable> 
GetInputColumnRolesCore(RoleMappedSchema schema) + private protected override IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema) { var cols = base.GetInputColumnRolesCore(schema); - var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), + var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.BinaryClassification); // Get the optional probability column. - var probInfo = EvaluateUtils.GetOptAuxScoreColumnInfo(Host, schema.Schema, _probCol, nameof(Arguments.ProbabilityColumn), - scoreInfo.Index, MetadataUtils.Const.ScoreValueKind.Probability, t => t == NumberType.Float); - if (probInfo != null) - cols = cols.Prepend(RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probInfo.Name)); + var probCol = EvaluateUtils.GetOptAuxScoreColumn(Host, schema.Schema, _probCol, nameof(Arguments.ProbabilityColumn), + scoreCol.Index, MetadataUtils.Const.ScoreValueKind.Probability, NumberType.Float.Equals); + if (probCol.HasValue) + cols = MetadataUtils.Prepend(cols, RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probCol.Value.Name)); return cols; } - protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) + private protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) { ch.AssertValue(metrics); @@ -1389,12 +1238,12 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary[] metrics) + private protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics) { ch.AssertNonEmpty(metrics); @@ -1628,22 +1477,22 @@ private void SavePrPlots(List prList) return avgPoints; } #endif - protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) + private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); - 
Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Schema must contain a label column"); // The binary classifier evaluator outputs the label, score and probability columns. - yield return schema.Label.Name; - var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), + yield return schema.Label.Value.Name; + var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.BinaryClassification); - yield return scoreInfo.Name; - var probInfo = EvaluateUtils.GetOptAuxScoreColumnInfo(Host, schema.Schema, _probCol, nameof(Arguments.ProbabilityColumn), - scoreInfo.Index, MetadataUtils.Const.ScoreValueKind.Probability, t => t == NumberType.Float); + yield return scoreCol.Name; + var probCol = EvaluateUtils.GetOptAuxScoreColumn(Host, schema.Schema, _probCol, nameof(Arguments.ProbabilityColumn), + scoreCol.Index, MetadataUtils.Const.ScoreValueKind.Probability, NumberType.Float.Equals); // Return the output columns. The LogLoss column is returned only if the probability column exists. 
- if (probInfo != null) + if (probCol.HasValue) { - yield return probInfo.Name; + yield return probCol.Value.Name; yield return BinaryPerInstanceEvaluator.LogLoss; } @@ -1670,7 +1519,7 @@ public static CommonOutputs.ClassificationEvaluateOutput Binary(IHostEnvironment string weight; string name; MatchColumns(host, input, out label, out weight, out name); - var evaluator = new BinaryClassifierMamlEvaluator(host, input); + IMamlEvaluator evaluator = new BinaryClassifierMamlEvaluator(host, input); var data = new RoleMappedData(input.Data, label, null, null, weight, name); var metrics = evaluator.Evaluate(data); @@ -1690,7 +1539,7 @@ public static CommonOutputs.ClassificationEvaluateOutput Binary(IHostEnvironment private static void MatchColumns(IHost host, MamlEvaluatorBase.ArgumentsBase input, out string label, out string weight, out string name) { - ISchema schema = input.Data.Schema; + var schema = input.Data.Schema; label = TrainUtils.MatchNameOrDefaultOrNull(host, schema, nameof(BinaryClassifierMamlEvaluator.Arguments.LabelColumn), input.LabelColumn, DefaultColumnNames.Label); diff --git a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs index 7e3a39b175..2c989b87d1 100644 --- a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs @@ -2,18 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Numeric; -using Microsoft.ML.Transforms.FeatureSelection; using System; using System.Collections.Generic; using System.Linq; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Numeric; +using Microsoft.ML.Transforms.FeatureSelection; [assembly: LoadableClass(typeof(ClusteringEvaluator), typeof(ClusteringEvaluator), typeof(ClusteringEvaluator.Arguments), typeof(SignatureEvaluator), "Clustering Evaluator", ClusteringEvaluator.LoadName, "Clustering")] @@ -25,7 +24,7 @@ [assembly: LoadableClass(typeof(ClusteringPerInstanceEvaluator), null, typeof(SignatureLoadRowMapper), "", ClusteringPerInstanceEvaluator.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { using Conditional = System.Diagnostics.ConditionalAttribute; @@ -62,7 +61,7 @@ public ClusteringEvaluator(IHostEnvironment env, Arguments args) /// The name of the optional label column in . /// The name of the optional feature column in . /// The evaluation results. 
- public Result Evaluate(IDataView data, string score, string label = null, string features = null) + public ClusteringMetrics Evaluate(IDataView data, string score, string label = null, string features = null) { Host.CheckValue(data, nameof(data)); Host.CheckNonEmpty(score, nameof(score)); @@ -78,71 +77,70 @@ public Result Evaluate(IDataView data, string score, string label = null, string var rolesMappedData = new RoleMappedData(data, opt: false, roles.ToArray()); - var resultDict = Evaluate(rolesMappedData); + var resultDict = ((IEvaluator)this).Evaluate(rolesMappedData); Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics)); var overall = resultDict[MetricKinds.OverallMetrics]; - Result result; + ClusteringMetrics result; using (var cursor = overall.GetRowCursor(i => true)) { var moved = cursor.MoveNext(); Host.Assert(moved); - result = new Result(Host, cursor, _calculateDbi); + result = new ClusteringMetrics(Host, cursor, _calculateDbi); moved = cursor.MoveNext(); Host.Assert(!moved); } return result; } - protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) + private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { - ColumnType type; - if (schema.Label != null && (type = schema.Label.Type) != NumberType.Float && type.KeyCount == 0) + ColumnType type = schema.Label?.Type; + if (type != null && type != NumberType.Float && !(type is KeyType keyType && keyType.Count > 0)) { - throw Host.Except("Clustering evaluator: label column '{0}' type must be {1} or Key of known cardinality." 
+ - " Provide a correct label column, or none: it is optional.", - schema.Label.Name, NumberType.Float); + throw Host.ExceptSchemaMismatch(nameof(schema), "label", schema.Label.Value.Name, + "R4 or key of known cardinality", type.ToString()); } var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); type = score.Type; if (!type.IsKnownSizeVector || type.ItemType != NumberType.Float) - throw Host.Except("Scores column '{0}' type must be a float vector of known size", score.Name); + throw Host.ExceptSchemaMismatch(nameof(schema), "score", score.Name, "R4 vector of known size", type.ToString()); } - protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema) + private protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema) { if (_calculateDbi) { - Host.AssertValue(schema.Feature); - var t = schema.Feature.Type; + Host.Assert(schema.Feature.HasValue); + var t = schema.Feature.Value.Type; if (!t.IsKnownSizeVector || t.ItemType != NumberType.Float) { - throw Host.Except("Features column '{0}' type must be {1} vector of known-size", - schema.Feature.Name, NumberType.Float); + throw Host.ExceptSchemaMismatch(nameof(schema), "features", schema.Feature.Value.Name, + "R4 vector of known size", t.ToString()); } } } - protected override Func GetActiveColsCore(RoleMappedSchema schema) + private protected override Func GetActiveColsCore(RoleMappedSchema schema) { var pred = base.GetActiveColsCore(schema); // We also need the features column for dbi calculation. 
Host.Assert(!_calculateDbi || schema.Feature != null); - return i => _calculateDbi && i == schema.Feature.Index || pred(i); + return i => _calculateDbi && i == schema.Feature.Value.Index || pred(i); } - protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) + private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { Host.AssertValue(schema); - Host.Assert(!_calculateDbi || (schema.Feature != null && schema.Feature.Type.IsKnownSizeVector)); + Host.Assert(!_calculateDbi || schema.Feature?.Type.IsKnownSizeVector == true); var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); Host.Assert(score.Type.VectorSize > 0); int numClusters = score.Type.VectorSize; return new Aggregator(Host, schema.Feature, numClusters, _calculateDbi, schema.Weight != null, stratName); } - protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) + private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) { var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); int numClusters = scoreInfo.Type.VectorSize; @@ -156,7 +154,7 @@ public override IEnumerable GetOverallMetricColumns() yield return new MetricColumn("DBI", Dbi, MetricColumn.Objective.Minimize); } - protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, + private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, out Action, Aggregator> addAgg, out Func> consolidate) { var stratCol = new List(); @@ -316,7 +314,7 @@ public Double Dbi } } - public Counters(int numClusters, bool calculateDbi, ColumnInfo features) + public Counters(int numClusters, bool calculateDbi, Schema.Column? 
features) { _numClusters = numClusters; CalculateDbi = calculateDbi; @@ -326,10 +324,10 @@ public Counters(int numClusters, bool calculateDbi, ColumnInfo features) _confusionMatrix = new List(); if (CalculateDbi) { - Contracts.AssertValue(features); + Contracts.Assert(features.HasValue); _clusterCentroids = new VBuffer[_numClusters]; for (int i = 0; i < _numClusters; i++) - _clusterCentroids[i] = VBufferUtils.CreateEmpty(features.Type.VectorSize); + _clusterCentroids[i] = VBufferUtils.CreateEmpty(features.Value.Type.VectorSize); _distancesToCentroids = new Double[_numClusters]; } } @@ -396,7 +394,7 @@ public void UpdateSecondPass(in VBuffer features, int[] indices) private readonly bool _calculateDbi; - public Aggregator(IHostEnvironment env, ColumnInfo features, int scoreVectorSize, bool calculateDbi, bool weighted, string stratName) + internal Aggregator(IHostEnvironment env, Schema.Column? features, int scoreVectorSize, bool calculateDbi, bool weighted, string stratName) : base(env, stratName) { _calculateDbi = calculateDbi; @@ -407,10 +405,10 @@ public Aggregator(IHostEnvironment env, ColumnInfo features, int scoreVectorSize WeightedCounters = Weighted ? 
new Counters(scoreVectorSize, _calculateDbi, features) : null; if (_calculateDbi) { - Host.AssertValue(features); + Host.Assert(features.HasValue); _clusterCentroids = new VBuffer[scoreVectorSize]; for (int i = 0; i < scoreVectorSize; i++) - _clusterCentroids[i] = VBufferUtils.CreateEmpty(features.Type.VectorSize); + _clusterCentroids[i] = VBufferUtils.CreateEmpty(features.Value.Type.VectorSize); } } @@ -484,7 +482,7 @@ private void ProcessRowSecondPass() WeightedCounters.UpdateSecondPass(in _features, _indicesArr); } - public override void InitializeNextPass(IRow row, RoleMappedSchema schema) + internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { AssertValid(assertGetters: false); @@ -493,8 +491,8 @@ public override void InitializeNextPass(IRow row, RoleMappedSchema schema) if (_calculateDbi) { - Host.AssertValue(schema.Feature); - _featGetter = row.GetGetter>(schema.Feature.Index); + Host.Assert(schema.Feature.HasValue); + _featGetter = row.GetGetter>(schema.Feature.Value.Index); } var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); Host.Assert(score.Type.VectorSize == _scoresArr.Length); @@ -502,12 +500,12 @@ public override void InitializeNextPass(IRow row, RoleMappedSchema schema) if (PassNum == 0) { - if (schema.Label != null) - _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index); + if (schema.Label.HasValue) + _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index); else _labelGetter = (ref Single value) => value = Single.NaN; - if (schema.Weight != null) - _weightGetter = row.GetGetter(schema.Weight.Index); + if (schema.Weight.HasValue) + _weightGetter = row.GetGetter(schema.Weight.Value.Index); } else { @@ -559,46 +557,6 @@ private void AssertValid(bool assertGetters) } } } - - /// - /// The metrics generated after evaluating the clustering predictions. 
- /// - public sealed class Result - { - /// - /// Normalized Mutual Information - /// NMI is a measure of the mutual dependence of the variables. - /// Normalized variants work on data that already has cluster labels. - /// Its value ranged from 0 to 1, where higher numbers are better. - /// - public double Nmi { get; } - - /// - /// Average Score. For the K-Means algorithm, the 'score' is the distance from the centroid to the example. - /// The average score is, therefore, a measure of proximity of the examples to cluster centroids. - /// In other words, it's the 'cluster tightness' measure. - /// Note however, that this metric will only decrease if the number of clusters is increased, - /// and in the extreme case (where each distinct example is its own cluster) it will be equal to zero. - /// - public double AvgMinScore { get; } - - /// - /// Davies-Bouldin Index - /// DBI is a measure of the how much scatter is in the cluster and the cluster separation. - /// - public double Dbi { get; } - - internal Result(IExceptionContext ectx, IRow overallResult, bool calculateDbi) - { - double Fetch(string name) => RowCursorUtils.Fetch(ectx, overallResult, name); - - Nmi = Fetch(ClusteringEvaluator.Nmi); - AvgMinScore = Fetch(ClusteringEvaluator.AvgMinScore); - - if (calculateDbi) - Dbi = Fetch(ClusteringEvaluator.Dbi); - } - } } public sealed class ClusteringPerInstanceEvaluator : PerInstanceEvaluatorBase @@ -626,7 +584,7 @@ private static VersionInfo GetVersionInfo() private readonly int _numClusters; private readonly ColumnType[] _types; - public ClusteringPerInstanceEvaluator(IHostEnvironment env, ISchema schema, string scoreCol, int numClusters) + public ClusteringPerInstanceEvaluator(IHostEnvironment env, Schema schema, string scoreCol, int numClusters) : base(env, schema, scoreCol, null) { CheckInputColumnTypes(schema); @@ -639,7 +597,7 @@ public ClusteringPerInstanceEvaluator(IHostEnvironment env, ISchema schema, stri _types[SortedClusterScoreCol] = new 
VectorType(NumberType.R4, _numClusters); } - private ClusteringPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, ISchema schema) + private ClusteringPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, Schema schema) : base(env, ctx, schema) { CheckInputColumnTypes(schema); @@ -658,7 +616,7 @@ private ClusteringPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ct _types[SortedClusterScoreCol] = new VectorType(NumberType.R4, _numClusters); } - public static ClusteringPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadContext ctx, ISchema schema) + public static ClusteringPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadContext ctx, Schema schema) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); @@ -678,7 +636,7 @@ public override void Save(ModelSaveContext ctx) ctx.Writer.Write(_numClusters); } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { return col => @@ -686,13 +644,13 @@ public override Func GetDependencies(Func activeOutput) (activeOutput(ClusterIdCol) || activeOutput(SortedClusterCol) || activeOutput(SortedClusterScoreCol)); } - public override Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) + private protected override Delegate[] CreateGettersCore(Row input, Func activeCols, out Action disposer) { disposer = null; var getters = new Delegate[3]; - if (!activeOutput(ClusterIdCol) && !activeOutput(SortedClusterCol) && !activeOutput(SortedClusterScoreCol)) + if (!activeCols(ClusterIdCol) && !activeCols(SortedClusterCol) && !activeCols(SortedClusterScoreCol)) return getters; long cachedPosition = -1; @@ -715,7 +673,7 @@ public override Delegate[] CreateGetters(IRow input, Func activeOutpu } }; - if (activeOutput(ClusterIdCol)) + if (activeCols(ClusterIdCol)) { ValueGetter assignedFn = (ref uint dst) => @@ -726,7 +684,7 @@ public override Delegate[] CreateGetters(IRow 
input, Func activeOutpu getters[ClusterIdCol] = assignedFn; } - if (activeOutput(SortedClusterScoreCol)) + if (activeCols(SortedClusterScoreCol)) { ValueGetter> topKScoresFn = (ref VBuffer dst) => @@ -740,7 +698,7 @@ public override Delegate[] CreateGetters(IRow input, Func activeOutpu getters[SortedClusterScoreCol] = topKScoresFn; } - if (activeOutput(SortedClusterCol)) + if (activeCols(SortedClusterCol)) { ValueGetter> topKClassesFn = (ref VBuffer dst) => @@ -756,7 +714,7 @@ public override Delegate[] CreateGetters(IRow input, Func activeOutpu return getters; } - public override Schema.DetachedColumn[] GetOutputColumns() + private protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var infos = new Schema.DetachedColumn[3]; infos[ClusterIdCol] = new Schema.DetachedColumn(ClusterId, _types[ClusterIdCol], null); @@ -787,11 +745,11 @@ private ValueGetter>> CreateSlotNamesGetter(int num }; } - private void CheckInputColumnTypes(ISchema schema) + private void CheckInputColumnTypes(Schema schema) { Host.AssertNonEmpty(ScoreCol); - var type = schema.GetColumnType(ScoreIndex); + var type = schema[(int) ScoreIndex].Type; if (!type.IsKnownSizeVector || type.ItemType != NumberType.Float) throw Host.Except("Score column '{0}' has type {1}, but must be a float vector of known-size", ScoreCol, type); } @@ -817,7 +775,7 @@ public class Arguments : ArgumentsBase private readonly string _featureCol; private readonly bool _calculateDbi; - protected override IEvaluator Evaluator { get { return _evaluator; } } + private protected override IEvaluator Evaluator => _evaluator; public ClusteringMamlEvaluator(IHostEnvironment env, Arguments args) : base(args, env, MetadataUtils.Const.ScoreColumnKind.Clustering, "ClusteringMamlEvaluator") @@ -836,7 +794,7 @@ public ClusteringMamlEvaluator(IHostEnvironment env, Arguments args) _evaluator = new ClusteringEvaluator(Host, evalArgs); } - protected override IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema) + private 
protected override IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema) { foreach (var col in base.GetInputColumnRolesCore(schema)) { @@ -856,13 +814,13 @@ public ClusteringMamlEvaluator(IHostEnvironment env, Arguments args) } // Clustering evaluator adds three per-instance columns: "ClusterId", "Top clusters" and "Top cluster scores". - protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) + private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); // Output the label column if it exists. - if (schema.Label != null) - yield return schema.Label.Name; + if (schema.Label.HasValue) + yield return schema.Label.Value.Name; // Return the output columns. yield return ClusteringPerInstanceEvaluator.ClusterId; @@ -870,19 +828,19 @@ protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSch yield return ClusteringPerInstanceEvaluator.SortedClusterScores; } - protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema) + private protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema) { // Wrap with a DropSlots transform to pick only the first _numTopClusters slots. 
if (perInst.Schema.TryGetColumnIndex(ClusteringPerInstanceEvaluator.SortedClusters, out int index)) { - var type = perInst.Schema.GetColumnType(index); + var type = perInst.Schema[index].Type; if (_numTopClusters < type.VectorSize) perInst = new SlotsDroppingTransformer(Host, ClusteringPerInstanceEvaluator.SortedClusters, min: _numTopClusters).Transform(perInst); } if (perInst.Schema.TryGetColumnIndex(ClusteringPerInstanceEvaluator.SortedClusterScores, out index)) { - var type = perInst.Schema.GetColumnType(index); + var type = perInst.Schema[index].Type; if (_numTopClusters < type.VectorSize) perInst = new SlotsDroppingTransformer(Host, ClusteringPerInstanceEvaluator.SortedClusterScores, min: _numTopClusters).Transform(perInst); } @@ -901,11 +859,11 @@ public static CommonOutputs.CommonEvaluateOutput Clustering(IHostEnvironment env EntryPointUtils.CheckInputArgs(host, input); MatchColumns(host, input, out string label, out string weight, out string name); - ISchema schema = input.Data.Schema; + var schema = input.Data.Schema; string features = TrainUtils.MatchNameOrDefaultOrNull(host, schema, nameof(ClusteringMamlEvaluator.Arguments.FeatureColumn), input.FeatureColumn, DefaultColumnNames.Features); - var evaluator = new ClusteringMamlEvaluator(host, input); + IMamlEvaluator evaluator = new ClusteringMamlEvaluator(host, input); var data = new RoleMappedData(input.Data, label, features, null, weight, name); var metrics = evaluator.Evaluate(data); diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs index 14c8d58649..217f6751cc 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs @@ -5,16 +5,14 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Reflection; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Internal.Utilities; +using 
Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// This is a base class for TLC evaluators. It implements both of the methods: and - /// . Note that the input is assumed to contain all the column + /// . Note that the input is assumed to contain all the column /// roles needed for evaluation, including the score column. /// public abstract partial class EvaluatorBase : IEvaluator @@ -22,13 +20,14 @@ public abstract partial class EvaluatorBase : IEvaluator { protected readonly IHost Host; - protected EvaluatorBase(IHostEnvironment env, string registrationName) + [BestFriend] + private protected EvaluatorBase(IHostEnvironment env, string registrationName) { Contracts.CheckValue(env, nameof(env)); Host = env.Register(registrationName); } - public Dictionary Evaluate(RoleMappedData data) + Dictionary IEvaluator.Evaluate(RoleMappedData data) { CheckColumnTypes(data.Schema); Func activeCols = GetActiveCols(data.Schema); @@ -44,11 +43,12 @@ public Dictionary Evaluate(RoleMappedData data) /// Checks the column types of the evaluator's input columns. The base class implementation checks only the type /// of the weight column, and all other columns should be checked by the deriving classes in . /// - protected void CheckColumnTypes(RoleMappedSchema schema) + [BestFriend] + private protected void CheckColumnTypes(RoleMappedSchema schema) { // Check the weight column type. - if (schema.Weight != null) - EvaluateUtils.CheckWeightType(Host, schema.Weight.Type); + if (schema.Weight.HasValue) + EvaluateUtils.CheckWeightType(Host, schema.Weight.Value.Type); CheckScoreAndLabelTypes(schema); // Check the other column types. CheckCustomColumnTypesCore(schema); @@ -60,13 +60,15 @@ protected void CheckColumnTypes(RoleMappedSchema schema) /// Access the label column with the property, and the score column with the /// or methods. 
/// - protected abstract void CheckScoreAndLabelTypes(RoleMappedSchema schema); + [BestFriend] + private protected abstract void CheckScoreAndLabelTypes(RoleMappedSchema schema); /// /// Check the types of any other columns needed by the evaluator. Only override if the evaluator uses /// columns other than label, score and weight. /// - protected virtual void CheckCustomColumnTypesCore(RoleMappedSchema schema) + [BestFriend] + private protected virtual void CheckCustomColumnTypesCore(RoleMappedSchema schema) { } @@ -84,11 +86,12 @@ private Func GetActiveCols(RoleMappedSchema schema) /// and the stratification columns. /// Override if other input columns need to be activated. /// - protected virtual Func GetActiveColsCore(RoleMappedSchema schema) + [BestFriend] + private protected virtual Func GetActiveColsCore(RoleMappedSchema schema) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - var label = schema.Label == null ? -1 : schema.Label.Index; - var weight = schema.Weight == null ? -1 : schema.Weight.Index; + int label = schema.Label?.Index ?? -1; + int weight = schema.Weight?.Index ?? -1; return i => i == score.Index || i == label || i == weight; } @@ -116,7 +119,8 @@ private AggregatorDictionaryBase[] GetAggregatorDictionaries(RoleMappedSchema sc return list.ToArray(); } - protected abstract TAgg GetAggregatorCore(RoleMappedSchema schema, string stratName); + [BestFriend] + private protected abstract TAgg GetAggregatorCore(RoleMappedSchema schema, string stratName); // This method does as many passes over the data as needed by the evaluator, and computes the metrics, outputting the // results in a dictionary from the metric kind (overal/per-fold/confusion matrix/PR-curves etc.), to a data view containing @@ -192,10 +196,12 @@ private Dictionary ProcessData(IDataView data, RoleMappedSche /// is called after has been called on all the aggregators, and it returns /// the dictionary of metric data views. 
/// - protected abstract void GetAggregatorConsolidationFuncs(TAgg aggregator, AggregatorDictionaryBase[] dictionaries, + [BestFriend] + private protected abstract void GetAggregatorConsolidationFuncs(TAgg aggregator, AggregatorDictionaryBase[] dictionaries, out Action, TAgg> addAgg, out Func> consolidate); - protected ValueGetter>> GetKeyValueGetter(AggregatorDictionaryBase[] dictionaries) + [BestFriend] + private protected ValueGetter>> GetKeyValueGetter(AggregatorDictionaryBase[] dictionaries) { if (Utils.Size(dictionaries) == 0) return null; @@ -209,7 +215,10 @@ protected ValueGetter>> GetKeyValueGetter(Aggregato }; } - public abstract IDataTransform GetPerInstanceMetrics(RoleMappedData data); + IDataTransform IEvaluator.GetPerInstanceMetrics(RoleMappedData data) => GetPerInstanceMetricsCore(data); + + [BestFriend] + internal abstract IDataTransform GetPerInstanceMetricsCore(RoleMappedData data); public abstract IEnumerable GetOverallMetricColumns(); @@ -233,7 +242,8 @@ public abstract class AggregatorBase protected int PassNum; - protected AggregatorBase(IHostEnvironment env, string stratName) + [BestFriend] + private protected AggregatorBase(IHostEnvironment env, string stratName) { Contracts.AssertValue(env); Host = env.Register("Aggregator"); @@ -253,7 +263,8 @@ public bool Start() /// /// This method should get the getters of the new IRow that are needed for the next pass. /// - public abstract void InitializeNextPass(IRow row, RoleMappedSchema schema); + [BestFriend] + internal abstract void InitializeNextPass(Row row, RoleMappedSchema schema); /// /// Call the getters once, and process the input as necessary. @@ -324,15 +335,15 @@ protected virtual List GetWarningsCore() // When a new value is encountered, it uses a callback for creating a new aggregator. 
protected abstract class AggregatorDictionaryBase { - protected IRow Row; - protected readonly Func CreateAgg; - protected readonly RoleMappedSchema Schema; + private protected Row Row; + private protected readonly Func CreateAgg; + private protected readonly RoleMappedSchema Schema; public string ColName { get; } public abstract int Count { get; } - protected AggregatorDictionaryBase(RoleMappedSchema schema, string stratCol, Func createAgg) + private protected AggregatorDictionaryBase(RoleMappedSchema schema, string stratCol, Func createAgg) { Contracts.AssertValue(schema); Contracts.AssertNonWhiteSpace(stratCol); @@ -346,9 +357,9 @@ protected AggregatorDictionaryBase(RoleMappedSchema schema, string stratCol, Fun /// /// Gets the stratification column getter for the new IRow. /// - public abstract void Reset(IRow row); + public abstract void Reset(Row row); - public static AggregatorDictionaryBase Create(RoleMappedSchema schema, string stratCol, ColumnType stratType, + internal static AggregatorDictionaryBase Create(RoleMappedSchema schema, string stratCol, ColumnType stratType, Func createAgg) { Contracts.AssertNonWhiteSpace(stratCol); @@ -397,7 +408,7 @@ public GenericAggregatorDictionary(RoleMappedSchema schema, string stratCol, Col _dict = new Dictionary(); } - public override void Reset(IRow row) + public override void Reset(Row row) { Row = row; int col; @@ -435,18 +446,20 @@ public override IEnumerable GetAll() public abstract class RowToRowEvaluatorBase : EvaluatorBase where TAgg : EvaluatorBase.AggregatorBase { - protected RowToRowEvaluatorBase(IHostEnvironment env, string registrationName) + [BestFriend] + private protected RowToRowEvaluatorBase(IHostEnvironment env, string registrationName) : base(env, registrationName) { } - public override IDataTransform GetPerInstanceMetrics(RoleMappedData data) + internal override IDataTransform GetPerInstanceMetricsCore(RoleMappedData data) { var mapper = CreatePerInstanceRowMapper(data.Schema); return new 
RowToRowMapperTransform(Host, data.Data, mapper, null); } - protected abstract IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema); + [BestFriend] + private protected abstract IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema); } /// @@ -460,7 +473,7 @@ public abstract class PerInstanceEvaluatorBase : IRowMapper protected readonly int ScoreIndex; protected readonly int LabelIndex; - protected PerInstanceEvaluatorBase(IHostEnvironment env, ISchema schema, string scoreCol, string labelCol) + protected PerInstanceEvaluatorBase(IHostEnvironment env, Schema schema, string scoreCol, string labelCol) { Contracts.AssertValue(env); Contracts.AssertNonEmpty(scoreCol); @@ -475,7 +488,7 @@ protected PerInstanceEvaluatorBase(IHostEnvironment env, ISchema schema, string throw Host.Except("Did not find column '{0}'", ScoreCol); } - protected PerInstanceEvaluatorBase(IHostEnvironment env, ModelLoadContext ctx, ISchema schema) + protected PerInstanceEvaluatorBase(IHostEnvironment env, ModelLoadContext ctx, Schema schema) { Host = env.Register("PerInstanceRowMapper"); @@ -501,10 +514,22 @@ public virtual void Save(ModelSaveContext ctx) ctx.SaveStringOrNull(LabelCol); } - public abstract Func GetDependencies(Func activeOutput); + Func IRowMapper.GetDependencies(Func activeOutput) + => GetDependenciesCore(activeOutput); + + [BestFriend] + private protected abstract Func GetDependenciesCore(Func activeOutput); + + Schema.DetachedColumn[] IRowMapper.GetOutputColumns() + => GetOutputColumnsCore(); + + [BestFriend] + private protected abstract Schema.DetachedColumn[] GetOutputColumnsCore(); - public abstract Schema.DetachedColumn[] GetOutputColumns(); + Delegate[] IRowMapper.CreateGetters(Row input, Func activeCols, out Action disposer) + => CreateGettersCore(input, activeCols, out disposer); - public abstract Delegate[] CreateGetters(IRow input, Func activeCols, out Action disposer); + [BestFriend] + private protected abstract Delegate[] CreateGettersCore(Row 
input, Func activeCols, out Action disposer); } } diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs index df58e4d46c..8d66a8a53c 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs @@ -3,11 +3,10 @@ // See the LICENSE file in the project root for more information. using System; -using Microsoft.ML.Runtime.Training; using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Extension methods for evaluation. @@ -24,7 +23,7 @@ public static class EvaluatorStaticExtensions /// The index delegate for columns from calibrated prediction of a binary classifier. /// Under typical scenarios, this will just be the same tuple of results returned from the trainer. /// The evaluation results for these calibrated outputs. - public static BinaryClassifierEvaluator.CalibratedResult Evaluate( + public static CalibratedBinaryClassificationMetrics Evaluate( this BinaryClassificationContext ctx, DataView data, Func> label, @@ -60,7 +59,7 @@ public static BinaryClassifierEvaluator.CalibratedResult Evaluate( /// The index delegate for columns from uncalibrated prediction of a binary classifier. /// Under typical scenarios, this will just be the same tuple of results returned from the trainer. /// The evaluation results for these uncalibrated outputs. - public static BinaryClassifierEvaluator.Result Evaluate( + public static BinaryClassificationMetrics Evaluate( this BinaryClassificationContext ctx, DataView data, Func> label, @@ -94,7 +93,7 @@ public static BinaryClassifierEvaluator.Result Evaluate( /// The optional index delegate for the label column. /// The optional index delegate for the features column. /// The evaluation metrics. 
- public static ClusteringEvaluator.Result Evaluate( + public static ClusteringMetrics Evaluate( this ClusteringContext ctx, DataView data, Func> score, @@ -127,11 +126,11 @@ public static ClusteringEvaluator.Result Evaluate( /// The index delegate for the label column. /// The index delegate for columns from the prediction of a multiclass classifier. /// Under typical scenarios, this will just be the same tuple of results returned from the trainer. - /// If given a positive value, the will be filled with + /// If given a positive value, the will be filled with /// the top-K accuracy, that is, the accuracy assuming we consider an example with the correct class within /// the top-K values as being stored "correctly." /// The evaluation metrics. - public static MultiClassClassifierEvaluator.Result Evaluate( + public static MultiClassClassifierMetrics Evaluate( this MulticlassClassificationContext ctx, DataView data, Func> label, @@ -178,7 +177,7 @@ private sealed class TrivialRegressionLossFactory : ISupportRegressionLossFactor /// The index delegate for predicted score column. /// Potentially custom loss function. If left unspecified defaults to . /// The evaluation metrics. - public static RegressionEvaluator.Result Evaluate( + public static RegressionMetrics Evaluate( this RegressionContext ctx, DataView data, Func> label, @@ -212,7 +211,7 @@ public static RegressionEvaluator.Result Evaluate( /// The index delegate for the groupId column. /// The index delegate for predicted score column. /// The evaluation metrics. 
- public static RankerEvaluator.Result Evaluate( + public static RankerMetrics Evaluate( this RankingContext ctx, DataView data, Func> label, diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs index 8889ccd6cf..8a3c61705d 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs @@ -3,21 +3,21 @@ // See the LICENSE file in the project root for more information. #pragma warning disable 420 // volatile with Interlocked.CompareExchange -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Transforms; -using Microsoft.ML.Transforms.Conversions; using System; using System.Collections.Generic; using System.Globalization; using System.Linq; using System.Text; using System.Threading; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Transforms; +using Microsoft.ML.Transforms.Conversions; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { - public static class EvaluateUtils + [BestFriend] + internal static class EvaluateUtils { public struct AggregatedMetric { @@ -62,7 +62,7 @@ public static IMamlEvaluator GetEvaluator(IHostEnvironment env, Schema schema) schema.GetMaxMetadataKind(out int col, MetadataUtils.Kinds.ScoreColumnSetId, CheckScoreColumnKindIsKnown); if (col >= 0) { - schema.GetMetadata(MetadataUtils.Kinds.ScoreColumnKind, col, ref tmp); + schema[col].Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnKind, ref tmp); var kind = tmp.ToString(); var map = DefaultEvaluatorTable.Instance; // The next assert is guaranteed because it is checked in CheckScoreColumnKindIsKnown which is the lambda passed to GetMaxMetadataKind. 
@@ -73,7 +73,7 @@ public static IMamlEvaluator GetEvaluator(IHostEnvironment env, Schema schema) schema.GetMaxMetadataKind(out col, MetadataUtils.Kinds.ScoreColumnSetId, CheckScoreColumnKind); if (col >= 0) { - schema.GetMetadata(MetadataUtils.Kinds.ScoreColumnKind, col, ref tmp); + schema[col].Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnKind, ref tmp); throw env.ExceptUserArg(nameof(EvaluateCommand.Arguments.Evaluator), "No default evaluator found for score column kind '{0}'.", tmp.ToString()); } @@ -81,13 +81,13 @@ public static IMamlEvaluator GetEvaluator(IHostEnvironment env, Schema schema) } // Lambda used as validator/filter in calls to GetMaxMetadataKind. - private static bool CheckScoreColumnKindIsKnown(ISchema schema, int col) + private static bool CheckScoreColumnKindIsKnown(Schema schema, int col) { - var columnType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.ScoreColumnKind, col); + var columnType = schema[col].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.ScoreColumnKind)?.Type; if (columnType == null || !columnType.IsText) return false; ReadOnlyMemory tmp = default; - schema.GetMetadata(MetadataUtils.Kinds.ScoreColumnKind, col, ref tmp); + schema[col].Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnKind, ref tmp); var map = DefaultEvaluatorTable.Instance; return map.ContainsKey(tmp.ToString()); } @@ -95,16 +95,16 @@ private static bool CheckScoreColumnKindIsKnown(ISchema schema, int col) // Lambda used as validator/filter in calls to GetMaxMetadataKind. private static bool CheckScoreColumnKind(Schema schema, int col) { - var columnType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.ScoreColumnKind, col); + var columnType = schema[col].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.ScoreColumnKind)?.Type; return columnType != null && columnType.IsText; } /// - /// Find the score column to use. If name is specified, that is used. Otherwise, this searches for the - /// most recent score set of the given kind. 
If there is no such score set and defName is specifed it - /// uses defName. Otherwise, it throws. + /// Find the score column to use. If is specified, that is used. Otherwise, this searches + /// for the most recent score set of the given . If there is no such score set and + /// is specifed it uses . Otherwise, it throws. /// - public static ColumnInfo GetScoreColumnInfo(IExceptionContext ectx, Schema schema, string name, string argName, string kind, + public static Schema.Column GetScoreColumn(IExceptionContext ectx, Schema schema, string name, string argName, string kind, string valueKind = MetadataUtils.Const.ScoreValueKind.Score, string defName = null) { Contracts.CheckValueOrNull(ectx); @@ -114,39 +114,40 @@ public static ColumnInfo GetScoreColumnInfo(IExceptionContext ectx, Schema schem ectx.CheckNonEmpty(kind, nameof(kind)); ectx.CheckNonEmpty(valueKind, nameof(valueKind)); - int colTmp; - ColumnInfo info; if (!string.IsNullOrWhiteSpace(name)) { -#pragma warning disable MSML_ContractsNameUsesNameof - if (!ColumnInfo.TryCreateFromName(schema, name, out info)) +#pragma warning disable MSML_ContractsNameUsesNameof // This utility method is meant to reflect the argument name of whatever is calling it, so we take that as a parameter, rather than using nameof directly as in most cases. 
+ var col = schema.GetColumnOrNull(name); + if (!col.HasValue) throw ectx.ExceptUserArg(argName, "Score column is missing"); #pragma warning restore MSML_ContractsNameUsesNameof - return info; + return col.Value; } - var maxSetNum = schema.GetMaxMetadataKind(out colTmp, MetadataUtils.Kinds.ScoreColumnSetId, + var maxSetNum = schema.GetMaxMetadataKind(out int colTmp, MetadataUtils.Kinds.ScoreColumnSetId, (s, c) => IsScoreColumnKind(ectx, s, c, kind)); ReadOnlyMemory tmp = default; - foreach (var col in schema.GetColumnSet(MetadataUtils.Kinds.ScoreColumnSetId, maxSetNum)) + foreach (var colIdx in schema.GetColumnSet(MetadataUtils.Kinds.ScoreColumnSetId, maxSetNum)) { + var col = schema[colIdx]; #if DEBUG - schema.GetMetadata(MetadataUtils.Kinds.ScoreColumnKind, col, ref tmp); + col.Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnKind, ref tmp); ectx.Assert(ReadOnlyMemoryUtils.EqualsStr(kind, tmp)); #endif // REVIEW: What should this do about hidden columns? Currently we ignore them. - if (schema.IsHidden(col)) + if (col.IsHidden) continue; - if (schema.TryGetMetadata(TextType.Instance, MetadataUtils.Kinds.ScoreValueKind, col, ref tmp) && - ReadOnlyMemoryUtils.EqualsStr(valueKind, tmp)) + if (col.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.ScoreValueKind)?.Type == TextType.Instance) { - return ColumnInfo.CreateFromIndex(schema, col); + col.Metadata.GetValue(MetadataUtils.Kinds.ScoreValueKind, ref tmp); + if (ReadOnlyMemoryUtils.EqualsStr(valueKind, tmp)) + return col; } } - if (!string.IsNullOrWhiteSpace(defName) && ColumnInfo.TryCreateFromName(schema, defName, out info)) - return info; + if (!string.IsNullOrWhiteSpace(defName) && schema.GetColumnOrNull(defName) is Schema.Column defCol) + return defCol; #pragma warning disable MSML_ContractsNameUsesNameof throw ectx.ExceptUserArg(argName, "Score column is missing"); @@ -154,54 +155,55 @@ public static ColumnInfo GetScoreColumnInfo(IExceptionContext ectx, Schema schem } /// - /// Find the optional auxilliary 
score column to use. If name is specified, that is used. - /// Otherwise, if colScore is part of a score set, this looks in the score set for a column - /// with the given valueKind. If none is found, it returns null. + /// Find the optional auxilliary score column to use. If is specified, that is used. + /// Otherwise, if is part of a score set, this looks in the score set for a column + /// with the given . If none is found, it returns . /// - public static ColumnInfo GetOptAuxScoreColumnInfo(IExceptionContext ectx, Schema schema, string name, string argName, + public static Schema.Column? GetOptAuxScoreColumn(IExceptionContext ectx, Schema schema, string name, string argName, int colScore, string valueKind, Func testType) { Contracts.CheckValueOrNull(ectx); ectx.CheckValue(schema, nameof(schema)); ectx.CheckValueOrNull(name); ectx.CheckNonEmpty(argName, nameof(argName)); - ectx.CheckParam(0 <= colScore && colScore < schema.ColumnCount, nameof(colScore)); + ectx.CheckParam(0 <= colScore && colScore < schema.Count, nameof(colScore)); ectx.CheckNonEmpty(valueKind, nameof(valueKind)); if (!string.IsNullOrWhiteSpace(name)) { - ColumnInfo info; #pragma warning disable MSML_ContractsNameUsesNameof - if (!ColumnInfo.TryCreateFromName(schema, name, out info)) + var col = schema.GetColumnOrNull(name); + if (!col.HasValue) throw ectx.ExceptUserArg(argName, "{0} column is missing", valueKind); - if (!testType(info.Type)) + if (!testType(col.Value.Type)) throw ectx.ExceptUserArg(argName, "{0} column has incompatible type", valueKind); #pragma warning restore MSML_ContractsNameUsesNameof - return info; + return col.Value; } // Get the score column set id from colScore. 
- var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.ScoreColumnSetId, colScore); + var type = schema[colScore].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.ScoreColumnSetId)?.Type; if (type == null || !type.IsKey || type.RawKind != DataKind.U4) { // scoreCol is not part of a score column set, so can't determine an aux column. return null; } uint setId = 0; - schema.GetMetadata(MetadataUtils.Kinds.ScoreColumnSetId, colScore, ref setId); + schema[colScore].Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnSetId, ref setId); ReadOnlyMemory tmp = default; - foreach (var col in schema.GetColumnSet(MetadataUtils.Kinds.ScoreColumnSetId, setId)) + foreach (var colIdx in schema.GetColumnSet(MetadataUtils.Kinds.ScoreColumnSetId, setId)) { // REVIEW: What should this do about hidden columns? Currently we ignore them. - if (schema.IsHidden(col)) + var col = schema[colIdx]; + if (col.IsHidden) continue; - if (schema.TryGetMetadata(TextType.Instance, MetadataUtils.Kinds.ScoreValueKind, col, ref tmp) && - ReadOnlyMemoryUtils.EqualsStr(valueKind, tmp)) + + if (col.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.ScoreValueKind)?.Type == TextType.Instance) { - var res = ColumnInfo.CreateFromIndex(schema, col); - if (testType(res.Type)) - return res; + col.Metadata.GetValue(MetadataUtils.Kinds.ScoreValueKind, ref tmp); + if (ReadOnlyMemoryUtils.EqualsStr(valueKind, tmp) && testType(col.Type)) + return col; } } @@ -209,36 +211,33 @@ public static ColumnInfo GetOptAuxScoreColumnInfo(IExceptionContext ectx, Schema return null; } - private static bool IsScoreColumnKind(IExceptionContext ectx, ISchema schema, int col, string kind) + private static bool IsScoreColumnKind(IExceptionContext ectx, Schema schema, int col, string kind) { Contracts.CheckValueOrNull(ectx); ectx.CheckValue(schema, nameof(schema)); - ectx.CheckParam(0 <= col && col < schema.ColumnCount, nameof(col)); + ectx.CheckParam(0 <= col && col < schema.Count, nameof(col)); ectx.CheckNonEmpty(kind, 
nameof(kind)); - var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.ScoreColumnKind, col); + var type = schema[col].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.ScoreColumnKind)?.Type; if (type == null || !type.IsText) return false; var tmp = default(ReadOnlyMemory); - schema.GetMetadata(MetadataUtils.Kinds.ScoreColumnKind, col, ref tmp); + schema[col].Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnKind, ref tmp); return ReadOnlyMemoryUtils.EqualsStr(kind, tmp); } /// - /// If str is non-empty, returns it. Otherwise if info is non-null, returns info.Name. - /// Otherwise, returns def. + /// If is non-empty, returns it. Otherwise if is non-, + /// returns its . Otherwise, returns . /// - public static string GetColName(string str, ColumnInfo info, string def) + public static string GetColName(string str, Schema.Column? info, string def) { Contracts.CheckValueOrNull(str); - Contracts.CheckValueOrNull(info); Contracts.CheckValueOrNull(def); if (!string.IsNullOrEmpty(str)) return str; - if (info != null) - return info.Name; - return def; + return info?.Name ?? def; } public static void CheckWeightType(IExceptionContext ectx, ColumnType type) @@ -285,24 +284,24 @@ public static IEnumerable> GetMetrics(IDataView met ValueGetter stratColGetter; if (hasStrats) { - var type = cursor.Schema.GetColumnType(stratCol); + var type = cursor.Schema[stratCol].Type; stratColGetter = RowCursorUtils.GetGetterAs(type, cursor, stratCol); } else stratColGetter = (ref uint dst) => dst = 0; // We currently have only double valued or vector of double valued metrics. - var colCount = schema.ColumnCount; + var colCount = schema.Count; var getters = new ValueGetter[colCount]; var vBufferGetters = getVectorMetrics ? 
new ValueGetter>[colCount] : null; - for (int i = 0; i < schema.ColumnCount; i++) + for (int i = 0; i < schema.Count; i++) { - if (schema.IsHidden(i) || hasWeighted && i == isWeightedCol || + if (schema[i].IsHidden || hasWeighted && i == isWeightedCol || hasStrats && (i == stratCol || i == stratVal)) continue; - var type = schema.GetColumnType(i); + var type = schema[i].Type; if (type == NumberType.R8 || type == NumberType.R4) getters[i] = RowCursorUtils.GetGetterAs(NumberType.R8, cursor, i); else if (type.IsKnownSizeVector && type.ItemType == NumberType.R8 && getVectorMetrics) @@ -336,7 +335,7 @@ public static IEnumerable> GetMetrics(IDataView met { getters[i](ref metricVal); // For R8 valued columns the metric name is the column name. - yield return new KeyValuePair(schema.GetColumnName(i), metricVal); + yield return new KeyValuePair(schema[i].Name, metricVal); } else if (getVectorMetrics && vBufferGetters[i] != null) { @@ -345,10 +344,10 @@ public static IEnumerable> GetMetrics(IDataView met // For R8 vector valued columns the names of the metrics are the column name, // followed by the slot name if it exists, or "Label_i" if it doesn't. 
VBuffer> names = default; - var size = schema.GetColumnType(i).VectorSize; - var slotNamesType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, i); + var size = schema[i].Type.VectorSize; + var slotNamesType = schema[i].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type; if (slotNamesType != null && slotNamesType.VectorSize == size && slotNamesType.ItemType.IsText) - schema.GetMetadata(MetadataUtils.Kinds.SlotNames, i, ref names); + schema[i].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref names); else { var namesArray = new ReadOnlyMemory[size]; @@ -356,7 +355,7 @@ public static IEnumerable> GetMetrics(IDataView met namesArray[j] = string.Format("({0})", j).AsMemory(); names = new VBuffer>(size, namesArray); } - var colName = schema.GetColumnName(i); + var colName = schema[i].Name; foreach (var metric in metricVals.Items(all: true)) { yield return new KeyValuePair( @@ -392,12 +391,12 @@ public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int // We use the first column in the data view as an input column to the LambdaColumnMapper, // because it must have an input. 
int inputCol = 0; - while (inputCol < input.Schema.ColumnCount && input.Schema.IsHidden(inputCol)) + while (inputCol < input.Schema.Count && input.Schema[inputCol].IsHidden) inputCol++; - env.Assert(inputCol < input.Schema.ColumnCount); + env.Assert(inputCol < input.Schema.Count); - var inputColName = input.Schema.GetColumnName(0); - var inputColType = input.Schema.GetColumnType(0); + var inputColName = input.Schema[0].Name; + var inputColType = input.Schema[0].Type; return Utils.MarshalInvoke(AddTextColumn, inputColType.RawType, env, input, inputColName, MetricKinds.ColumnNames.FoldIndex, inputColType, $"Fold {curFold}", "FoldName"); } @@ -434,12 +433,12 @@ public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int // We use the first column in the data view as an input column to the LambdaColumnMapper, // because it must have an input. int inputCol = 0; - while (inputCol < input.Schema.ColumnCount && input.Schema.IsHidden(inputCol)) + while (inputCol < input.Schema.Count && input.Schema[inputCol].IsHidden) inputCol++; - env.Assert(inputCol < input.Schema.ColumnCount); + env.Assert(inputCol < input.Schema.Count); - var inputColName = input.Schema.GetColumnName(inputCol); - var inputColType = input.Schema.GetColumnType(inputCol); + var inputColName = input.Schema[inputCol].Name; + var inputColType = input.Schema[inputCol].Type; return Utils.MarshalInvoke(AddKeyColumn, inputColType.RawType, env, input, inputColName, MetricKinds.ColumnNames.FoldIndex, inputColType, numFolds, curFold + 1, "FoldIndex", default(ValueGetter>>)); @@ -471,10 +470,10 @@ public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int if (!idv.Schema.TryGetColumnIndex(columnName, out col)) throw env.Except("Data view number {0} does not contain column '{1}'", i, columnName); - var type = typeSrc[i] = idv.Schema.GetColumnType(col); - if (!idv.Schema.HasSlotNames(col, type.VectorSize)) + var type = typeSrc[i] = idv.Schema[col].Type; + if 
(!idv.Schema[col].HasSlotNames(type.VectorSize)) throw env.Except("Column '{0}' in data view number {1} did not contain slot names metadata", columnName, i); - idv.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, col, ref slotNamesCur); + idv.Schema[col].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref slotNamesCur); var map = maps[i] = new int[slotNamesCur.Length]; foreach (var kvp in slotNamesCur.Items(true)) @@ -546,7 +545,7 @@ public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int } } - private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isVec, + private static int[][] MapKeys(Schema[] schemas, string columnName, bool isVec, int[] indices, Dictionary, int> reconciledKeyNames) { Contracts.AssertValue(indices); @@ -561,8 +560,8 @@ private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isV if (!schema.TryGetColumnIndex(columnName, out indices[i])) throw Contracts.Except($"Schema number {i} does not contain column '{columnName}'"); - var type = schema.GetColumnType(indices[i]); - var keyValueType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, indices[i]); + var type = schema[indices[i]].Type; + var keyValueType = schema[indices[i]].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; if (type.IsVector != isVec) throw Contracts.Except($"Column '{columnName}' in schema number {i} does not have the correct type"); if (keyValueType == null || keyValueType.ItemType.RawType != typeof(T)) @@ -570,7 +569,7 @@ private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isV if (!type.ItemType.IsKey || type.ItemType.RawKind != DataKind.U4) throw Contracts.Except($"Column '{columnName}' must be a U4 key type, but is '{type.ItemType}'"); - schema.GetMetadata(MetadataUtils.Kinds.KeyValues, indices[i], ref keyNamesCur); + schema[indices[i]].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref keyNamesCur); keyValueMappers[i] = new int[type.ItemType.KeyCount]; 
foreach (var kvp in keyNamesCur.Items(true)) @@ -622,7 +621,7 @@ public static void ReconcileKeyValues(IHostEnvironment env, IDataView[] views, s dst = (uint)keyMapperCur[src - 1] + 1; }; views[i] = LambdaColumnMapper.Create(env, "ReconcileKeyValues", views[i], columnName, columnName, - views[i].Schema.GetColumnType(indices[i]), keyType, mapper, keyValueGetter); + views[i].Schema[indices[i]].Type, keyType, mapper, keyValueGetter); } } @@ -653,7 +652,7 @@ public static void ReconcileKeyValuesWithNoNames(IHostEnvironment env, IDataView dst = src; }; views[i] = LambdaColumnMapper.Create(env, "ReconcileKeyValues", views[i], columnName, columnName, - views[i].Schema.GetColumnType(index), keyType, mapper); + views[i].Schema[index].Type, keyType, mapper); } } @@ -714,16 +713,16 @@ public static void ReconcileVectorKeyValues(IHostEnvironment env, IDataView[] vi }; ValueGetter>> slotNamesGetter = null; - var type = views[i].Schema.GetColumnType(columnIndices[i]); - if (views[i].Schema.HasSlotNames(columnIndices[i], type.VectorSize)) + var type = views[i].Schema[columnIndices[i]].Type; + if (views[i].Schema[columnIndices[i]].HasSlotNames(type.VectorSize)) { var schema = views[i].Schema; int index = columnIndices[i]; slotNamesGetter = - (ref VBuffer> dst) => schema.GetMetadata(MetadataUtils.Kinds.SlotNames, index, ref dst); + (ref VBuffer> dst) => schema[index].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref dst); } views[i] = LambdaColumnMapper.Create(env, "ReconcileKeyValues", views[i], columnName, columnName, - type, new VectorType(keyType, type.AsVector), mapper, keyValueGetter, slotNamesGetter); + type, new VectorType(keyType, type as VectorType), mapper, keyValueGetter, slotNamesGetter); } } @@ -814,27 +813,27 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, string foreach (var dv in foldDataViews) { var hidden = new List(); - for (int i = 0; i < dv.Schema.ColumnCount; i++) + for (int i = 0; i < dv.Schema.Count; i++) { - if 
(dv.Schema.IsHidden(i)) + if (dv.Schema[i].IsHidden) { hidden.Add(i); continue; } - var type = dv.Schema.GetColumnType(i); - var name = dv.Schema.GetColumnName(i); + var type = dv.Schema[i].Type; + var name = dv.Schema[i].Name; if (type.IsVector) { if (dvNumber == 0) { - if (dv.Schema.HasKeyValues(i, type.ItemType.KeyCount)) + if (dv.Schema[i].HasKeyValues(type.ItemType.KeyCount)) firstDvVectorKeyColumns.Add(name); // Store the slot names of the 1st idv and use them as baseline. - if (dv.Schema.HasSlotNames(i, type.VectorSize)) + if (dv.Schema[i].HasSlotNames(type.VectorSize)) { VBuffer> slotNames = default; - dv.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, i, ref slotNames); + dv.Schema[i].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref slotNames); firstDvSlotNames.Add(name, slotNames); } } @@ -855,11 +854,11 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, string else if (dvNumber == 0 && name == labelColName) { // The label column can be a key. Reconcile the key values, and wrap with a KeyToValue transform. - labelColKeyValuesType = dv.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, i); + labelColKeyValuesType = dv.Schema[i].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; } - else if (dvNumber == 0 && dv.Schema.HasKeyValues(i, type.KeyCount)) + else if (dvNumber == 0 && dv.Schema[i].HasKeyValues(type.KeyCount)) firstDvKeyWithNamesColumns.Add(name); - else if (type.KeyCount > 0 && name != labelColName && !dv.Schema.HasKeyValues(i, type.KeyCount)) + else if (type.KeyCount > 0 && name != labelColName && !dv.Schema[i].HasKeyValues(type.KeyCount)) { // For any other key column (such as GroupId) we do not reconcile the key values, we only convert to U4. 
if (!firstDvKeyNoNamesColumns.ContainsKey(name)) @@ -894,7 +893,7 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, string Func keyToValue = (idv, i) => { - foreach (var keyCol in firstDvVectorKeyColumns.Concat(firstDvKeyWithNamesColumns).Prepend(labelColName)) + foreach (var keyCol in MetadataUtils.Prepend(firstDvVectorKeyColumns.Concat(firstDvKeyWithNamesColumns), labelColName)) { if (keyCol == labelColName && labelColKeyValuesType == null) continue; @@ -918,7 +917,7 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, string { int index; idv.Schema.TryGetColumnIndex(variableSizeVectorColumnName, out index); - var type = idv.Schema.GetColumnType(index); + var type = idv.Schema[index].Type; idv = Utils.MarshalInvoke(AddVarLengthColumn, type.ItemType.RawType, env, idv, variableSizeVectorColumnName, type); @@ -934,9 +933,9 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, string private static IEnumerable FindHiddenColumns(Schema schema, string colName) { - for (int i = 0; i < schema.ColumnCount; i++) + for (int i = 0; i < schema.Count; i++) { - if (schema.IsHidden(i) && schema.GetColumnName(i) == colName) + if (schema[i].IsHidden && schema[i].Name == colName) yield return i; } } @@ -948,11 +947,11 @@ private static bool VerifyVectorColumnsMatch(int cachedSize, int col, IDataView return false; // If we detect mismatch it a sign that slots reshuffling has happened. - if (dv.Schema.HasSlotNames(col, type.VectorSize)) + if (dv.Schema[col].HasSlotNames(type.VectorSize)) { // Verify that slots match with slots from 1st idv. 
VBuffer> currSlotNames = default; - dv.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, col, ref currSlotNames); + dv.Schema[col].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref currSlotNames); if (currSlotNames.Length != firstDvSlotNames.Length) return false; @@ -974,30 +973,30 @@ private static bool VerifyVectorColumnsMatch(int cachedSize, int col, IDataView private static IDataView AddVarLengthColumn(IHostEnvironment env, IDataView idv, string variableSizeVectorColumnName, ColumnType typeSrc) { return LambdaColumnMapper.Create(env, "ChangeToVarLength", idv, variableSizeVectorColumnName, - variableSizeVectorColumnName + "_VarLength", typeSrc, new VectorType(typeSrc.ItemType.AsPrimitive), + variableSizeVectorColumnName + "_VarLength", typeSrc, new VectorType((PrimitiveType)typeSrc.ItemType), (in VBuffer src, ref VBuffer dst) => src.CopyTo(ref dst)); } - private static List GetMetricNames(IChannel ch, Schema schema, IRow row, Func ignoreCol, + private static List GetMetricNames(IChannel ch, Schema schema, Row row, Func ignoreCol, ValueGetter[] getters, ValueGetter>[] vBufferGetters) { ch.AssertValue(schema); ch.AssertValue(row); - ch.Assert(Utils.Size(getters) == schema.ColumnCount); - ch.Assert(Utils.Size(vBufferGetters) == schema.ColumnCount); + ch.Assert(Utils.Size(getters) == schema.Count); + ch.Assert(Utils.Size(vBufferGetters) == schema.Count); // Get the names of the metrics. For R8 valued columns the metric name is the column name. For R8 vector valued columns // the names of the metrics are the column name, followed by the slot name if it exists, or "Label_i" if it doesn't. 
VBuffer> names = default; int metricCount = 0; var metricNames = new List(); - for (int i = 0; i < schema.ColumnCount; i++) + for (int i = 0; i < schema.Count; i++) { - if (schema.IsHidden(i) || ignoreCol(i)) + if (schema[i].IsHidden || ignoreCol(i)) continue; - var type = schema.GetColumnType(i); - var metricName = row.Schema.GetColumnName(i); + var type = schema[i].Type; + var metricName = row.Schema[i].Name; if (type.IsNumber) { getters[i] = RowCursorUtils.GetGetterAs(NumberType.R8, row, i); @@ -1014,9 +1013,9 @@ private static List GetMetricNames(IChannel ch, Schema schema, IRow row, vBufferGetters[i] = row.GetGetter>(i); metricCount += type.VectorSize; - var slotNamesType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, i); + var slotNamesType = schema[i].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type; if (slotNamesType != null && slotNamesType.VectorSize == type.VectorSize && slotNamesType.ItemType.IsText) - schema.GetMetadata(MetadataUtils.Kinds.SlotNames, i, ref names); + schema[i].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref names); else { var editor = VBufferEditor.Create(ref names, type.VectorSize); @@ -1077,7 +1076,7 @@ internal static AggregatedMetric[] ComputeMetricsSum(IHostEnvironment env, IData foldCol = hasFoldCol ? fcol : -1; // We currently have only double valued or vector of double valued metrics. 
- int colCount = data.Schema.ColumnCount; + int colCount = data.Schema.Count; var getters = new ValueGetter[colCount]; var vBufferGetters = new ValueGetter>[colCount]; int numResults = 0; @@ -1095,7 +1094,7 @@ internal static AggregatedMetric[] ComputeMetricsSum(IHostEnvironment env, IData ValueGetter stratColGetter; if (hasStrats) { - var type = cursor.Schema.GetColumnType(stratCol); + var type = cursor.Schema[stratCol].Type; stratColGetter = RowCursorUtils.GetGetterAs(type, cursor, stratCol); } else @@ -1199,7 +1198,7 @@ internal static IDataView GetAverageToDataView(IHostEnvironment env, Schema sche { Contracts.AssertValue(env); - int colCount = schema.ColumnCount; + int colCount = schema.Count; var dvBldr = new ArrayDataViewBuilder(env); var weightedDvBldr = isWeightedCol >= 0 ? new ArrayDataViewBuilder(env) : null; @@ -1207,14 +1206,14 @@ internal static IDataView GetAverageToDataView(IHostEnvironment env, Schema sche int iMetric = 0; for (int i = 0; i < colCount; i++) { - if (schema.IsHidden(i)) + if (schema[i].IsHidden) continue; - var type = schema.GetColumnType(i); - var name = schema.GetColumnName(i); + var type = schema[i].Type; + var name = schema[i].Name; if (i == stratCol) { - var keyValuesType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, i); + var keyValuesType = schema[i].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; if (keyValuesType == null || !keyValuesType.ItemType.IsText || keyValuesType.VectorSize != type.KeyCount) { @@ -1225,7 +1224,7 @@ internal static IDataView GetAverageToDataView(IHostEnvironment env, Schema sche ValueGetter>> getKeyValues = (ref VBuffer> dst) => { - schema.GetMetadata(MetadataUtils.Kinds.KeyValues, stratCol, ref dst); + schema[stratCol].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref dst); Contracts.Assert(dst.IsDense); }; @@ -1273,7 +1272,7 @@ internal static IDataView GetAverageToDataView(IHostEnvironment env, Schema sche return idv; } - private static void 
AddVectorColumn(this ArrayDataViewBuilder dvBldr, IHostEnvironment env, ISchema schema, + private static void AddVectorColumn(this ArrayDataViewBuilder dvBldr, IHostEnvironment env, Schema schema, AggregatedMetric[] agg, bool hasStdev, int numFolds, int iMetric, int i, ColumnType type, string columnName) { var vectorMetrics = new double[type.VectorSize]; @@ -1301,7 +1300,7 @@ private static void AddVectorColumn(this ArrayDataViewBuilder dvBldr, IHostEnvir dvBldr.AddColumn(columnName, getSlotNames, NumberType.R8, new[] { vectorMetrics }); } - private static void AddScalarColumn(this ArrayDataViewBuilder dvBldr, ISchema schema, AggregatedMetric[] agg, bool hasStdev, int numFolds, int iMetric) + private static void AddScalarColumn(this ArrayDataViewBuilder dvBldr, Schema schema, AggregatedMetric[] agg, bool hasStdev, int numFolds, int iMetric) { Contracts.AssertValue(dvBldr); @@ -1343,11 +1342,11 @@ public static string GetConfusionTable(IHost host, IDataView confusionDataView, // Get the class names. int countCol; host.Check(confusionDataView.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.Count, out countCol), "Did not find the count column"); - var type = confusionDataView.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, countCol); + var type = confusionDataView.Schema[countCol].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type; host.Check(type != null && type.IsKnownSizeVector && type.ItemType.IsText, "The Count column does not have a text vector metadata of kind SlotNames."); var labelNames = default(VBuffer>); - confusionDataView.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, countCol, ref labelNames); + confusionDataView.Schema[countCol].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref labelNames); host.Check(labelNames.IsDense, "Slot names vector must be dense"); int numConfusionTableLabels = sample < 0 ? 
labelNames.Length : Math.Min(labelNames.Length, sample); @@ -1427,13 +1426,13 @@ private static double[][] GetConfusionTableAsArray(IDataView confusionDataView, var hasStrat = confusionDataView.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out stratCol); using (var cursor = confusionDataView.GetRowCursor(col => col == countIndex || hasStrat && col == stratCol)) { - var type = cursor.Schema.GetColumnType(countIndex); + var type = cursor.Schema[countIndex].Type; Contracts.Check(type.IsKnownSizeVector && type.ItemType == NumberType.R8); var countGetter = cursor.GetGetter>(countIndex); ValueGetter stratGetter = null; if (hasStrat) { - type = cursor.Schema.GetColumnType(stratCol); + type = cursor.Schema[stratCol].Type; stratGetter = RowCursorUtils.GetGetterAs(type, cursor, stratCol); } @@ -1688,7 +1687,7 @@ public static void PrintWarnings(IChannel ch, Dictionary metr if (metrics.TryGetValue(MetricKinds.Warnings, out warnings)) { int col; - if (warnings.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.WarningText, out col) && warnings.Schema.GetColumnType(col).IsText) + if (warnings.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.WarningText, out col) && warnings.Schema[col].Type.IsText) { using (var cursor = warnings.GetRowCursor(c => c == col)) { @@ -1726,7 +1725,7 @@ public static IDataView GetNonStratifiedMetrics(IHostEnvironment env, IDataView int stratCol; if (!data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out stratCol)) return data; - var type = data.Schema.GetColumnType(stratCol); + var type = data.Schema[stratCol].Type; env.Check(type.KeyCount > 0, "Expected a known count key type stratification column"); var filterArgs = new NAFilter.Arguments(); filterArgs.Column = new[] { MetricKinds.ColumnNames.StratCol }; @@ -1737,7 +1736,7 @@ public static IDataView GetNonStratifiedMetrics(IHostEnvironment env, IDataView var found = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out stratVal); env.Check(found, "If 
stratification column exist, data view must also contain a StratVal column"); - data = ColumnSelectingTransformer.CreateDrop(env, data, data.Schema.GetColumnName(stratCol), data.Schema.GetColumnName(stratVal)); + data = ColumnSelectingTransformer.CreateDrop(env, data, data.Schema[stratCol].Name, data.Schema[stratVal].Name); return data; } } diff --git a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs index ccef4706f4..346346c429 100644 --- a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs @@ -2,22 +2,23 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Transforms; using System.Collections.Generic; using System.Linq; +using Microsoft.ML.CommandLine; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Transforms; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// This interface is used by Maml components (the , the /// and the to evaluate, print and save the results. /// The input to the and the methods /// should be assumed to contain only the following column roles: label, group, weight and name. Any other columns needed for - /// evaluation should be searched for by name in the . + /// evaluation should be searched for by name in the . /// - public interface IMamlEvaluator : IEvaluator + [BestFriend] + internal interface IMamlEvaluator : IEvaluator { /// /// Print the aggregate metrics to the console. @@ -45,7 +46,7 @@ public interface IMamlEvaluator : IEvaluator } /// - /// A base class implementation of . The and + /// A base class implementation of . The and /// methods create a new containing all the columns needed for evaluation, and call the corresponding /// methods on an of the appropriate type. 
/// @@ -72,18 +73,26 @@ public abstract class ArgumentsBase : EvaluateInputBase public string[] StratColumn; } - public static RoleMappedSchema.ColumnRole Strat = "Strat"; - protected readonly IHost Host; - - protected readonly string ScoreColumnKind; - protected readonly string ScoreCol; - protected readonly string LabelCol; - protected readonly string WeightCol; - protected readonly string[] StratCols; - - protected abstract IEvaluator Evaluator { get; } - - protected MamlEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string scoreColumnKind, string registrationName) + internal static RoleMappedSchema.ColumnRole Strat = "Strat"; + [BestFriend] + private protected readonly IHost Host; + + [BestFriend] + private protected readonly string ScoreColumnKind; + [BestFriend] + private protected readonly string ScoreCol; + [BestFriend] + private protected readonly string LabelCol; + [BestFriend] + private protected readonly string WeightCol; + [BestFriend] + private protected readonly string[] StratCols; + + [BestFriend] + private protected abstract IEvaluator Evaluator { get; } + + [BestFriend] + private protected MamlEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string scoreColumnKind, string registrationName) { Contracts.CheckValue(env, nameof(env)); Host = env.Register(registrationName); @@ -94,13 +103,14 @@ protected MamlEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string sco StratCols = args.StratColumn; } - public Dictionary Evaluate(RoleMappedData data) + Dictionary IEvaluator.Evaluate(RoleMappedData data) { data = new RoleMappedData(data.Data, GetInputColumnRoles(data.Schema, needStrat: true)); return Evaluator.Evaluate(data); } - protected IEnumerable> GetInputColumnRoles(RoleMappedSchema schema, bool needStrat = false, bool needName = false) + [BestFriend] + private protected IEnumerable> GetInputColumnRoles(RoleMappedSchema schema, bool needStrat = false, bool needName = false) { Host.CheckValue(schema, nameof(schema)); @@ -108,8 
+118,8 @@ public Dictionary Evaluate(RoleMappedData data) ? Enumerable.Empty>() : StratCols.Select(col => RoleMappedSchema.CreatePair(Strat, col)); - if (needName && schema.Name != null) - roles = roles.Prepend(RoleMappedSchema.ColumnRole.Name.Bind(schema.Name.Name)); + if (needName && schema.Name.HasValue) + roles = MetadataUtils.Prepend(roles, RoleMappedSchema.ColumnRole.Name.Bind(schema.Name.Value.Name)); return roles.Concat(GetInputColumnRolesCore(schema)); } @@ -119,12 +129,13 @@ public Dictionary Evaluate(RoleMappedData data) /// The base class ipmlementation gets the score column, the label column (if exists) and the weight column (if exists). /// Override if additional columns are needed. /// - protected virtual IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema) + [BestFriend] + private protected virtual IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema) { // Get the score column information. - var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(ArgumentsBase.ScoreColumn), + var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(ArgumentsBase.ScoreColumn), ScoreColumnKind); - yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreInfo.Name); + yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreCol.Name); // Get the label column information. 
string label = EvaluateUtils.GetColName(LabelCol, schema.Label, DefaultColumnNames.Label); @@ -140,7 +151,7 @@ public virtual IEnumerable GetOverallMetricColumns() return Evaluator.GetOverallMetricColumns(); } - public void PrintFoldResults(IChannel ch, Dictionary metrics) + void IMamlEvaluator.PrintFoldResults(IChannel ch, Dictionary metrics) { Host.CheckValue(ch, nameof(ch)); Host.CheckValue(metrics, nameof(metrics)); @@ -151,7 +162,8 @@ public void PrintFoldResults(IChannel ch, Dictionary metrics) /// This method simply prints the overall metrics using EvaluateUtils.PrintConfusionMatrixAndPerFoldResults. /// Override if something else is needed. /// - protected virtual void PrintFoldResultsCore(IChannel ch, Dictionary metrics) + [BestFriend] + private protected virtual void PrintFoldResultsCore(IChannel ch, Dictionary metrics) { ch.AssertValue(ch); ch.AssertValue(metrics); @@ -167,24 +179,26 @@ protected virtual void PrintFoldResultsCore(IChannel ch, Dictionary[] metrics) + void IMamlEvaluator.PrintAdditionalMetrics(IChannel ch, params Dictionary[] metrics) { Host.CheckValue(ch, nameof(ch)); Host.CheckNonEmpty(metrics, nameof(metrics)); @@ -195,11 +209,12 @@ public void PrintAdditionalMetrics(IChannel ch, params Dictionary - protected virtual void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics) + [BestFriend] + private protected virtual void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics) { } - public IDataTransform GetPerInstanceMetrics(RoleMappedData scoredData) + IDataTransform IEvaluator.GetPerInstanceMetrics(RoleMappedData scoredData) { Host.AssertValue(scoredData); @@ -223,23 +238,23 @@ private IDataView WrapPerInstance(RoleMappedData perInst) colsToKeep.Add(MetricKinds.ColumnNames.FoldIndex); // Maml always outputs a name column, if it doesn't exist add a GenerateNumberTransform. 
- if (perInst.Schema.Name == null) + if (perInst.Schema.Name?.Name is string nameName) { - var args = new GenerateNumberTransform.Arguments(); - args.Column = new[] { new GenerateNumberTransform.Column() { Name = "Instance" } }; - args.UseCounter = true; - idv = new GenerateNumberTransform(Host, args, idv); + cols.Add((nameName, "Instance")); colsToKeep.Add("Instance"); } else { - cols.Add((perInst.Schema.Name.Name, "Instance")); + var args = new GenerateNumberTransform.Arguments(); + args.Column = new[] { new GenerateNumberTransform.Column() { Name = "Instance" } }; + args.UseCounter = true; + idv = new GenerateNumberTransform(Host, args, idv); colsToKeep.Add("Instance"); } // Maml outputs the weight column if it exists. - if (perInst.Schema.Weight != null) - colsToKeep.Add(perInst.Schema.Weight.Name); + if (perInst.Schema.Weight?.Name is string weightName) + colsToKeep.Add(weightName); // Get the other columns from the evaluator. foreach (var col in GetPerInstanceColumnsToSave(perInst.Schema)) @@ -256,12 +271,13 @@ private IDataView WrapPerInstance(RoleMappedData perInst) /// It should be overridden only if additional processing is needed, such as dropping slots in the "top k scores" column /// in the multi-class case. /// - protected virtual IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema) + [BestFriend] + private protected virtual IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema) { return perInst; } - public IDataView GetPerInstanceDataViewToSave(RoleMappedData perInstance) + IDataView IMamlEvaluator.GetPerInstanceDataViewToSave(RoleMappedData perInstance) { Host.CheckValue(perInstance, nameof(perInstance)); var data = new RoleMappedData(perInstance.Data, GetInputColumnRoles(perInstance.Schema, needName: true)); @@ -273,6 +289,7 @@ public IDataView GetPerInstanceDataViewToSave(RoleMappedData perInstance) /// the columns generated by the corresponding , or any of the input columns used by /// it. 
The Name and Weight columns should not be included, since the base class includes them automatically. /// - protected abstract IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema); + [BestFriend] + private protected abstract IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema); } } diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs new file mode 100644 index 0000000000..1fa0e1c68f --- /dev/null +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs @@ -0,0 +1,110 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.ML.Data +{ + /// + /// Evaluation results for binary classifiers, excluding probabilistic metrics. + /// + public class BinaryClassificationMetrics + { + /// + /// Gets the area under the ROC curve. + /// + /// + /// The area under the ROC curve is equal to the probability that the classifier ranks + /// a randomly chosen positive instance higher than a randomly chosen negative one + /// (assuming 'positive' ranks higher than 'negative'). + /// + public double Auc { get; } + + /// + /// Gets the accuracy of a classifier which is the proportion of correct predictions in the test set. + /// + public double Accuracy { get; } + + /// + /// Gets the positive precision of a classifier which is the proportion of correctly predicted + /// positive instances among all the positive predictions (i.e., the number of positive instances + /// predicted as positive, divided by the total number of instances predicted as positive). 
+ /// + public double PositivePrecision { get; } + + /// + /// Gets the positive recall of a classifier which is the proportion of correctly predicted + /// positive instances among all the positive instances (i.e., the number of positive instances + /// predicted as positive, divided by the total number of positive instances). + /// + public double PositiveRecall { get; private set; } + + /// + /// Gets the negative precision of a classifier which is the proportion of correctly predicted + /// negative instances among all the negative predictions (i.e., the number of negative instances + /// predicted as negative, divided by the total number of instances predicted as negative). + /// + public double NegativePrecision { get; } + + /// + /// Gets the negative recall of a classifier which is the proportion of correctly predicted + /// negative instances among all the negative instances (i.e., the number of negative instances + /// predicted as negative, divided by the total number of negative instances). + /// + public double NegativeRecall { get; } + + /// + /// Gets the F1 score of the classifier. + /// + /// + /// F1 score is the harmonic mean of precision and recall: 2 * precision * recall / (precision + recall). + /// + public double F1Score { get; } + + /// + /// Gets the area under the precision/recall curve of the classifier. + /// + /// + /// The area under the precision/recall curve is a single number summary of the information in the + /// precision/recall curve. It is increasingly used in the machine learning community, particularly + /// for imbalanced datasets where one class is observed more frequently than the other. On these + /// datasets, AUPRC can highlight performance differences that are lost with AUC. 
+ /// + public double Auprc { get; } + + protected private static T Fetch(IExceptionContext ectx, Row row, string name) + { + if (!row.Schema.TryGetColumnIndex(name, out int col)) + throw ectx.Except($"Could not find column '{name}'"); + T val = default; + row.GetGetter(col)(ref val); + return val; + } + + internal BinaryClassificationMetrics(IExceptionContext ectx, Row overallResult) + { + double Fetch(string name) => Fetch(ectx, overallResult, name); + Auc = Fetch(BinaryClassifierEvaluator.Auc); + Accuracy = Fetch(BinaryClassifierEvaluator.Accuracy); + PositivePrecision = Fetch(BinaryClassifierEvaluator.PosPrecName); + PositiveRecall = Fetch(BinaryClassifierEvaluator.PosRecallName); + NegativePrecision = Fetch(BinaryClassifierEvaluator.NegPrecName); + NegativeRecall = Fetch(BinaryClassifierEvaluator.NegRecallName); + F1Score = Fetch(BinaryClassifierEvaluator.F1); + Auprc = Fetch(BinaryClassifierEvaluator.AuPrc); + } + + [BestFriend] + internal BinaryClassificationMetrics(double auc, double accuracy, double positivePrecision, double positiveRecall, + double negativePrecision, double negativeRecall, double f1Score, double auprc) + { + Auc = auc; + Accuracy = accuracy; + PositivePrecision = positivePrecision; + PositiveRecall = positiveRecall; + NegativePrecision = negativePrecision; + NegativeRecall = negativeRecall; + F1Score = f1Score; + Auprc = auprc; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/CalibratedBinaryClassificationMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/CalibratedBinaryClassificationMetrics.cs new file mode 100644 index 0000000000..902daff204 --- /dev/null +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/CalibratedBinaryClassificationMetrics.cs @@ -0,0 +1,51 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +namespace Microsoft.ML.Data +{ + /// + /// Evaluation results for binary classifiers, including probabilistic metrics. + /// + public sealed class CalibratedBinaryClassificationMetrics : BinaryClassificationMetrics + { + /// + /// Gets the log-loss of the classifier. + /// + /// + /// The log-loss metric, is computed as follows: + /// LL = - (1/m) * sum( log(p[i])) + /// where m is the number of instances in the test set. + /// p[i] is the probability returned by the classifier if the instance belongs to class 1, + /// and 1 minus the probability returned by the classifier if the instance belongs to class 0. + /// + public double LogLoss { get; } + + /// + /// Gets the log-loss reduction (also known as relative log-loss, or reduction in information gain - RIG) + /// of the classifier. + /// + /// + /// The log-loss reduction is scaled relative to a classifier that predicts the prior for every example: + /// (LL(prior) - LL(classifier)) / LL(prior) + /// This metric can be interpreted as the advantage of the classifier over a random prediction. + /// For example, if the RIG equals 20, it can be interpreted as "the probability of a correct prediction is + /// 20% better than random guessing." + /// + public double LogLossReduction { get; } + + /// + /// Gets the test-set entropy (prior Log-Loss/instance) of the classifier. 
+ /// + public double Entropy { get; } + + internal CalibratedBinaryClassificationMetrics(IExceptionContext ectx, Row overallResult) + : base(ectx, overallResult) + { + double Fetch(string name) => Fetch(ectx, overallResult, name); + LogLoss = Fetch(BinaryClassifierEvaluator.LogLoss); + LogLossReduction = Fetch(BinaryClassifierEvaluator.LogLossReduction); + Entropy = Fetch(BinaryClassifierEvaluator.Entropy); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/ClusteringMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/ClusteringMetrics.cs new file mode 100644 index 0000000000..9b464f2728 --- /dev/null +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/ClusteringMetrics.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.ML.Data +{ + /// + /// The metrics generated after evaluating the clustering predictions. + /// + public sealed class ClusteringMetrics + { + /// + /// Normalized Mutual Information + /// NMI is a measure of the mutual dependence of the variables. + /// Normalized variants work on data that already has cluster labels. + /// Its value ranged from 0 to 1, where higher numbers are better. + /// + public double Nmi { get; } + + /// + /// Average Score. For the K-Means algorithm, the 'score' is the distance from the centroid to the example. + /// The average score is, therefore, a measure of proximity of the examples to cluster centroids. + /// In other words, it's the 'cluster tightness' measure. + /// Note however, that this metric will only decrease if the number of clusters is increased, + /// and in the extreme case (where each distinct example is its own cluster) it will be equal to zero. 
+ /// + public double AvgMinScore { get; } + + /// + /// Davies-Bouldin Index + /// DBI is a measure of the how much scatter is in the cluster and the cluster separation. + /// + public double Dbi { get; } + + internal ClusteringMetrics(IExceptionContext ectx, Row overallResult, bool calculateDbi) + { + double Fetch(string name) => RowCursorUtils.Fetch(ectx, overallResult, name); + + Nmi = Fetch(ClusteringEvaluator.Nmi); + AvgMinScore = Fetch(ClusteringEvaluator.AvgMinScore); + + if (calculateDbi) + Dbi = Fetch(ClusteringEvaluator.Dbi); + } + + internal ClusteringMetrics(double nmi, double avgMinScore, double dbi) + { + Nmi = nmi; + AvgMinScore = avgMinScore; + Dbi = dbi; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/MultiClassClassifierMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/MultiClassClassifierMetrics.cs new file mode 100644 index 0000000000..b4c9598533 --- /dev/null +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/MultiClassClassifierMetrics.cs @@ -0,0 +1,110 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.ML.Data +{ + public sealed class MultiClassClassifierMetrics + { + /// + /// Gets the micro-average accuracy of the model. + /// + /// + /// The micro-average is the fraction of instances predicted correctly. + /// + /// The micro-average metric weighs each class according to the number of instances that belong + /// to it in the dataset. + /// + public double AccuracyMicro { get; } + + /// + /// Gets the macro-average accuracy of the model. + /// + /// + /// The macro-average is computed by taking the average over all the classes of the fraction + /// of correct predictions in this class (the number of correctly predicted instances in the class, + /// divided by the total number of instances in the class). 
+ /// + /// The macro-average metric gives the same weight to each class, no matter how many instances from + /// that class the dataset contains. + /// + public double AccuracyMacro { get; } + + /// + /// Gets the average log-loss of the classifier. + /// + /// + /// The log-loss metric, is computed as follows: + /// LL = - (1/m) * sum( log(p[i])) + /// where m is the number of instances in the test set. + /// p[i] is the probability returned by the classifier if the instance belongs to class 1, + /// and 1 minus the probability returned by the classifier if the instance belongs to class 0. + /// + public double LogLoss { get; } + + /// + /// Gets the log-loss reduction (also known as relative log-loss, or reduction in information gain - RIG) + /// of the classifier. + /// + /// + /// The log-loss reduction is scaled relative to a classifier that predicts the prior for every example: + /// (LL(prior) - LL(classifier)) / LL(prior) + /// This metric can be interpreted as the advantage of the classifier over a random prediction. + /// For example, if the RIG equals 20, it can be interpreted as "the probability of a correct prediction is + /// 20% better than random guessing". + /// + public double LogLossReduction { get; private set; } + + /// + /// If positive, this is the top-K for which the is calculated. + /// + public int TopK { get; } + + /// + /// If is positive, this is the relative number of examples where + /// the true label is one of the top k predicted labels by the predictor. + /// + public double TopKAccuracy { get; } + + /// + /// Gets the log-loss of the classifier for each class. + /// + /// + /// The log-loss metric, is computed as follows: + /// LL = - (1/m) * sum( log(p[i])) + /// where m is the number of instances in the test set. + /// p[i] is the probability returned by the classifier if the instance belongs to the class, + /// and 1 minus the probability returned by the classifier if the instance does not belong to the class. 
+ /// + public double[] PerClassLogLoss { get; } + + internal MultiClassClassifierMetrics(IExceptionContext ectx, Row overallResult, int topK) + { + double FetchDouble(string name) => RowCursorUtils.Fetch(ectx, overallResult, name); + AccuracyMicro = FetchDouble(MultiClassClassifierEvaluator.AccuracyMicro); + AccuracyMacro = FetchDouble(MultiClassClassifierEvaluator.AccuracyMacro); + LogLoss = FetchDouble(MultiClassClassifierEvaluator.LogLoss); + LogLossReduction = FetchDouble(MultiClassClassifierEvaluator.LogLossReduction); + TopK = topK; + if (topK > 0) + TopKAccuracy = FetchDouble(MultiClassClassifierEvaluator.TopKAccuracy); + + var perClassLogLoss = RowCursorUtils.Fetch>(ectx, overallResult, MultiClassClassifierEvaluator.PerClassLogLoss); + PerClassLogLoss = new double[perClassLogLoss.Length]; + perClassLogLoss.CopyTo(PerClassLogLoss); + } + + internal MultiClassClassifierMetrics(double accuracyMicro, double accuracyMacro, double logLoss, double logLossReduction, + int topK, double topKAccuracy, double[] perClassLogLoss) + { + AccuracyMicro = accuracyMicro; + AccuracyMacro = accuracyMacro; + LogLoss = logLoss; + LogLossReduction = logLossReduction; + TopK = topK; + TopKAccuracy = topKAccuracy; + PerClassLogLoss = new double[perClassLogLoss.Length]; + perClassLogLoss.CopyTo(PerClassLogLoss, 0); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/RankerMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/RankerMetrics.cs new file mode 100644 index 0000000000..6a801d1007 --- /dev/null +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/RankerMetrics.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +namespace Microsoft.ML.Data +{ + public sealed class RankerMetrics + { + /// + /// Array of normalized discounted cumulative gains where i-th element represent NDCG@i. + /// + /// + public double[] Ndcg { get; } + + /// + ///Array of discounted cumulative gains where i-th element represent DCG@i. + /// Discounted Cumulative gain + /// is the sum of the gains, for all the instances i, normalized by the natural logarithm of the instance + 1. + /// Note that unline the Wikipedia article, ML.Net uses the natural logarithm. + /// + /// + public double[] Dcg { get; } + + private static T Fetch(IExceptionContext ectx, Row row, string name) + { + if (!row.Schema.TryGetColumnIndex(name, out int col)) + throw ectx.Except($"Could not find column '{name}'"); + T val = default; + row.GetGetter(col)(ref val); + return val; + } + + internal RankerMetrics(IExceptionContext ectx, Row overallResult) + { + VBuffer Fetch(string name) => Fetch>(ectx, overallResult, name); + + Dcg = Fetch(RankerEvaluator.Dcg).GetValues().ToArray(); + Ndcg = Fetch(RankerEvaluator.Ndcg).GetValues().ToArray(); + } + + internal RankerMetrics(double[] dcg, double[] ndcg) + { + Dcg = new double[dcg.Length]; + dcg.CopyTo(Dcg, 0); + Ndcg = new double[ndcg.Length]; + ndcg.CopyTo(Ndcg, 0); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/RegressionMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/RegressionMetrics.cs new file mode 100644 index 0000000000..4d58876019 --- /dev/null +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/RegressionMetrics.cs @@ -0,0 +1,73 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.ML.Data +{ + public sealed class RegressionMetrics + { + /// + /// Gets the absolute loss of the model. 
+ /// + /// + /// The absolute loss is defined as + /// L1 = (1/m) * sum( abs( yi - y'i)) + /// where m is the number of instances in the test set. + /// y'i are the predicted labels for each instance. + /// yi are the correct labels of each instance. + /// + public double L1 { get; } + + /// + /// Gets the squared loss of the model. + /// + /// + /// The squared loss is defined as + /// L2 = (1/m) * sum(( yi - y'i)^2) + /// where m is the number of instances in the test set. + /// y'i are the predicted labels for each instance. + /// yi are the correct labels of each instance. + /// + public double L2 { get; } + + /// + /// Gets the root mean square loss (or RMS) which is the square root of the L2 loss. + /// + public double Rms { get; } + + /// + /// Gets the result of the user-defined loss function. + /// + /// + /// This is the average of a loss function defined by the user, + /// computed over all the instances in the test set. + /// + public double LossFn { get; } + + /// + /// Gets the R squared value of the model, which is also known as + /// the coefficient of determination.
+ /// + public double RSquared { get; } + + internal RegressionMetrics(IExceptionContext ectx, Row overallResult) + { + double Fetch(string name) => RowCursorUtils.Fetch(ectx, overallResult, name); + L1 = Fetch(RegressionEvaluator.L1); + L2 = Fetch(RegressionEvaluator.L2); + Rms = Fetch(RegressionEvaluator.Rms); + LossFn = Fetch(RegressionEvaluator.Loss); + RSquared = Fetch(RegressionEvaluator.RSquared); + } + + [BestFriend] + internal RegressionMetrics(double l1, double l2, double rms, double lossFunction, double rSquared) + { + L1 = l1; + L2 = l2; + Rms = rms; + LossFn = lossFunction; + RSquared = rSquared; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs similarity index 82% rename from src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs rename to src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs index 93515e5315..a59d785aa2 100644 --- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs @@ -2,19 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms; -using Microsoft.ML.Transforms.FeatureSelection; using System; using System.Collections.Generic; using System.Linq; using System.Text.RegularExpressions; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Transforms; +using Microsoft.ML.Transforms.FeatureSelection; [assembly: LoadableClass(typeof(MultiClassClassifierEvaluator), typeof(MultiClassClassifierEvaluator), typeof(MultiClassClassifierEvaluator.Arguments), typeof(SignatureEvaluator), "Multi-Class Classifier Evaluator", MultiClassClassifierEvaluator.LoadName, "MultiClassClassifier", "MultiClass")] @@ -26,7 +25,7 @@ [assembly: LoadableClass(typeof(MultiClassPerInstanceEvaluator), null, typeof(SignatureLoadRowMapper), "", MultiClassPerInstanceEvaluator.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public sealed class MultiClassClassifierEvaluator : RowToRowEvaluatorBase { @@ -58,7 +57,7 @@ public enum Metrics LogLossReduction, } - public const string LoadName = "MultiClassClassifierEvaluator"; + internal const string LoadName = "MultiClassClassifierEvaluator"; private readonly int? 
_outputTopKAcc; private readonly bool _names; @@ -72,19 +71,19 @@ public MultiClassClassifierEvaluator(IHostEnvironment env, Arguments args) _names = args.Names; } - protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) + private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var t = score.Type; if (t.VectorSize < 2 || t.ItemType != NumberType.Float) - throw Host.Except("Score column '{0}' has type {1} but must be a vector of two or more items of type R4", score.Name, t); - Host.Check(schema.Label != null, "Could not find the label column"); - t = schema.Label.Type; + throw Host.ExceptSchemaMismatch(nameof(schema), "score", score.Name, "vector of two or more items of type R4", t.ToString()); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Could not find the label column"); + t = schema.Label.Value.Type; if (t != NumberType.Float && t.KeyCount <= 0) - throw Host.Except("Label column '{0}' has type {1} but must be a float or a known-cardinality key", schema.Label.Name, t); + throw Host.ExceptSchemaMismatch(nameof(schema), "label", schema.Label.Value.Name, "float or a known-cardinality key", t.ToString()); } - protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) + private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); Host.Assert(score.Type.VectorSize > 0); @@ -98,11 +97,11 @@ private ReadOnlyMemory[] GetClassNames(RoleMappedSchema schema) ReadOnlyMemory[] names; // Get the label names from the score column if they exist, or use the default names. 
var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - var mdType = schema.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, scoreInfo.Index); + var mdType = schema.Schema[scoreInfo.Index].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type; var labelNames = default(VBuffer>); if (mdType != null && mdType.IsKnownSizeVector && mdType.ItemType.IsText) { - schema.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, scoreInfo.Index, ref labelNames); + schema.Schema[scoreInfo.Index].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref labelNames); names = new ReadOnlyMemory[labelNames.Length]; labelNames.CopyTo(names); } @@ -117,12 +116,12 @@ private ReadOnlyMemory[] GetClassNames(RoleMappedSchema schema) return names; } - protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) + private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) { - Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Schema must contain a label column"); var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); int numClasses = scoreInfo.Type.VectorSize; - return new MultiClassPerInstanceEvaluator(Host, schema.Schema, scoreInfo, schema.Label.Name); + return new MultiClassPerInstanceEvaluator(Host, schema.Schema, scoreInfo, schema.Label.Value.Name); } public override IEnumerable GetOverallMetricColumns() @@ -137,7 +136,7 @@ public override IEnumerable GetOverallMetricColumns() yield return new MetricColumn("LogLossReduction", LogLossReduction); } - protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, + private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, out Action, Aggregator> addAgg, out Func> consolidate) { var stratCol = 
new List(); @@ -387,20 +386,20 @@ public Aggregator(IHostEnvironment env, ReadOnlyMemory[] classNames, int s ClassNames = classNames; } - public override void InitializeNextPass(IRow row, RoleMappedSchema schema) + internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { Host.Assert(PassNum < 1); - Host.AssertValue(schema.Label); + Host.Assert(schema.Label.HasValue); var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); Host.Assert(score.Type.VectorSize == _scoresArr.Length); - _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index); + _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index); _scoreGetter = row.GetGetter>(score.Index); Host.AssertValue(_labelGetter); Host.AssertValue(_scoreGetter); - if (schema.Weight != null) - _weightGetter = row.GetGetter(schema.Weight.Index); + if (schema.Weight.HasValue) + _weightGetter = row.GetGetter(schema.Weight.Value.Index); } public override void ProcessRow() @@ -496,97 +495,6 @@ public void GetSlotNames(ref VBuffer> slotNames) } } - public sealed class Result - { - /// - /// Gets the micro-average accuracy of the model. - /// - /// - /// The micro-average is the fraction of instances predicted correctly. - /// - /// The micro-average metric weighs each class according to the number of instances that belong - /// to it in the dataset. - /// - public double AccuracyMicro { get; } - - /// - /// Gets the macro-average accuracy of the model. - /// - /// - /// The macro-average is computed by taking the average over all the classes of the fraction - /// of correct predictions in this class (the number of correctly predicted instances in the class, - /// divided by the total number of instances in the class). - /// - /// The macro-average metric gives the same weight to each class, no matter how many instances from - /// that class the dataset contains. 
- /// - public double AccuracyMacro { get; } - - /// - /// Gets the average log-loss of the classifier. - /// - /// - /// The log-loss metric, is computed as follows: - /// LL = - (1/m) * sum( log(p[i])) - /// where m is the number of instances in the test set. - /// p[i] is the probability returned by the classifier if the instance belongs to class 1, - /// and 1 minus the probability returned by the classifier if the instance belongs to class 0. - /// - public double LogLoss { get; } - - /// - /// Gets the log-loss reduction (also known as relative log-loss, or reduction in information gain - RIG) - /// of the classifier. - /// - /// - /// The log-loss reduction is scaled relative to a classifier that predicts the prior for every example: - /// (LL(prior) - LL(classifier)) / LL(prior) - /// This metric can be interpreted as the advantage of the classifier over a random prediction. - /// For example, if the RIG equals 20, it can be interpreted as "the probability of a correct prediction is - /// 20% better than random guessing". - /// - public double LogLossReduction { get; private set; } - - /// - /// If positive, this is the top-K for which the is calculated. - /// - public int TopK { get; } - - /// - /// If is positive, this is the relative number of examples where - /// the true label is one of the top k predicted labels by the predictor. - /// - public double TopKAccuracy { get; } - - /// - /// Gets the log-loss of the classifier for each class. - /// - /// - /// The log-loss metric, is computed as follows: - /// LL = - (1/m) * sum( log(p[i])) - /// where m is the number of instances in the test set. - /// p[i] is the probability returned by the classifier if the instance belongs to the class, - /// and 1 minus the probability returned by the classifier if the instance does not belong to the class. 
- /// - public double[] PerClassLogLoss { get; } - - internal Result(IExceptionContext ectx, IRow overallResult, int topK) - { - double FetchDouble(string name) => RowCursorUtils.Fetch(ectx, overallResult, name); - AccuracyMicro = FetchDouble(MultiClassClassifierEvaluator.AccuracyMicro); - AccuracyMacro = FetchDouble(MultiClassClassifierEvaluator.AccuracyMacro); - LogLoss = FetchDouble(MultiClassClassifierEvaluator.LogLoss); - LogLossReduction = FetchDouble(MultiClassClassifierEvaluator.LogLossReduction); - TopK = topK; - if (topK > 0) - TopKAccuracy = FetchDouble(MultiClassClassifierEvaluator.TopKAccuracy); - - var perClassLogLoss = RowCursorUtils.Fetch>(ectx, overallResult, MultiClassClassifierEvaluator.PerClassLogLoss); - PerClassLogLoss = new double[perClassLogLoss.Length]; - perClassLogLoss.CopyTo(PerClassLogLoss); - } - } - /// /// Evaluates scored multiclass classification data. /// @@ -595,7 +503,7 @@ internal Result(IExceptionContext ectx, IRow overallResult, int topK) /// The name of the score column in . /// The name of the predicted label column in . /// The evaluation results for these outputs. 
- public Result Evaluate(IDataView data, string label, string score, string predictedLabel) + public MultiClassClassifierMetrics Evaluate(IDataView data, string label, string score, string predictedLabel) { Host.CheckValue(data, nameof(data)); Host.CheckNonEmpty(label, nameof(label)); @@ -607,16 +515,16 @@ public Result Evaluate(IDataView data, string label, string score, string predic RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, score), RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel)); - var resultDict = Evaluate(roles); + var resultDict = ((IEvaluator)this).Evaluate(roles); Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics)); var overall = resultDict[MetricKinds.OverallMetrics]; - Result result; + MultiClassClassifierMetrics result; using (var cursor = overall.GetRowCursor(i => true)) { var moved = cursor.MoveNext(); Host.Assert(moved); - result = new Result(Host, cursor, _outputTopKAcc ?? 0); + result = new MultiClassClassifierMetrics(Host, cursor, _outputTopKAcc ?? 
0); moved = cursor.MoveNext(); Host.Assert(!moved); } @@ -658,18 +566,18 @@ private static VersionInfo GetVersionInfo() private readonly ReadOnlyMemory[] _classNames; private readonly ColumnType[] _types; - public MultiClassPerInstanceEvaluator(IHostEnvironment env, Schema schema, ColumnInfo scoreInfo, string labelCol) - : base(env, schema, Contracts.CheckRef(scoreInfo, nameof(scoreInfo)).Name, labelCol) + public MultiClassPerInstanceEvaluator(IHostEnvironment env, Schema schema, Schema.Column scoreColumn, string labelCol) + : base(env, schema, scoreColumn.Name, labelCol) { CheckInputColumnTypes(schema); - _numClasses = scoreInfo.Type.VectorSize; + _numClasses = scoreColumn.Type.VectorSize; _types = new ColumnType[4]; - if (schema.HasSlotNames(ScoreIndex, _numClasses)) + if (schema[ScoreIndex].HasSlotNames(_numClasses)) { var classNames = default(VBuffer>); - schema.GetMetadata(MetadataUtils.Kinds.SlotNames, ScoreIndex, ref classNames); + schema[(int) ScoreIndex].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref classNames); _classNames = new ReadOnlyMemory[_numClasses]; classNames.CopyTo(_classNames); } @@ -683,7 +591,7 @@ public MultiClassPerInstanceEvaluator(IHostEnvironment env, Schema schema, Colum _types[SortedClassesCol] = new VectorType(key, _numClasses); } - private MultiClassPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, ISchema schema) + private MultiClassPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, Schema schema) : base(env, ctx, schema) { CheckInputColumnTypes(schema); @@ -712,7 +620,7 @@ private MultiClassPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ct _types[SortedClassesCol] = new VectorType(key, _numClasses); } - public static MultiClassPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadContext ctx, ISchema schema) + public static MultiClassPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadContext ctx, Schema schema) { Contracts.CheckValue(env, nameof(env)); 
env.CheckValue(ctx, nameof(ctx)); @@ -739,7 +647,7 @@ public override void Save(ModelSaveContext ctx) ctx.SaveNonEmptyString(_classNames[i].ToString()); } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { Host.Assert(ScoreIndex >= 0); Host.Assert(LabelIndex >= 0); @@ -753,13 +661,13 @@ public override Func GetDependencies(Func activeOutput) activeOutput(SortedClassesCol) || activeOutput(LogLossCol)); } - public override Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) + private protected override Delegate[] CreateGettersCore(Row input, Func activeCols, out Action disposer) { disposer = null; var getters = new Delegate[4]; - if (!activeOutput(AssignedCol) && !activeOutput(SortedClassesCol) && !activeOutput(SortedScoresCol) && !activeOutput(LogLossCol)) + if (!activeCols(AssignedCol) && !activeCols(SortedClassesCol) && !activeCols(SortedScoresCol) && !activeCols(LogLossCol)) return getters; long cachedPosition = -1; @@ -768,7 +676,7 @@ public override Delegate[] CreateGetters(IRow input, Func activeOutpu var scoresArr = new float[_numClasses]; int[] sortedIndices = new int[_numClasses]; - var labelGetter = activeOutput(LogLossCol) ? RowCursorUtils.GetLabelGetter(input, LabelIndex) : + var labelGetter = activeCols(LogLossCol) ? 
RowCursorUtils.GetLabelGetter(input, LabelIndex) : (ref float dst) => dst = float.NaN; var scoreGetter = input.GetGetter>(ScoreIndex); Action updateCacheIfNeeded = @@ -786,7 +694,7 @@ public override Delegate[] CreateGetters(IRow input, Func activeOutpu } }; - if (activeOutput(AssignedCol)) + if (activeCols(AssignedCol)) { ValueGetter assignedFn = (ref uint dst) => @@ -797,7 +705,7 @@ public override Delegate[] CreateGetters(IRow input, Func activeOutpu getters[AssignedCol] = assignedFn; } - if (activeOutput(SortedScoresCol)) + if (activeCols(SortedScoresCol)) { ValueGetter> topKScoresFn = (ref VBuffer dst) => @@ -811,7 +719,7 @@ public override Delegate[] CreateGetters(IRow input, Func activeOutpu getters[SortedScoresCol] = topKScoresFn; } - if (activeOutput(SortedClassesCol)) + if (activeCols(SortedClassesCol)) { ValueGetter> topKClassesFn = (ref VBuffer dst) => @@ -825,7 +733,7 @@ public override Delegate[] CreateGetters(IRow input, Func activeOutpu getters[SortedClassesCol] = topKClassesFn; } - if (activeOutput(LogLossCol)) + if (activeCols(LogLossCol)) { ValueGetter logLossFn = (ref double dst) => @@ -852,7 +760,7 @@ public override Delegate[] CreateGetters(IRow input, Func activeOutpu return getters; } - public override Schema.DetachedColumn[] GetOutputColumns() + private protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var infos = new Schema.DetachedColumn[4]; @@ -899,15 +807,15 @@ private ValueGetter>> CreateKeyValueGetter() }; } - private void CheckInputColumnTypes(ISchema schema) + private void CheckInputColumnTypes(Schema schema) { Host.AssertNonEmpty(ScoreCol); Host.AssertNonEmpty(LabelCol); - var t = schema.GetColumnType(ScoreIndex); + var t = schema[(int) ScoreIndex].Type; if (t.VectorSize < 2 || t.ItemType != NumberType.Float) throw Host.Except("Score column '{0}' has type '{1}' but must be a vector of two or more items of type R4", ScoreCol, t); - t = schema.GetColumnType(LabelIndex); + t = schema[LabelIndex].Type; if (t != 
NumberType.Float && t.KeyCount <= 0) throw Host.Except("Label column '{0}' has type '{1}' but must be a float or a known-cardinality key", LabelCol, t); } @@ -938,7 +846,7 @@ public class Arguments : ArgumentsBase private readonly int? _outputTopKAcc; private readonly MultiClassClassifierEvaluator _evaluator; - protected override IEvaluator Evaluator { get { return _evaluator; } } + private protected override IEvaluator Evaluator => _evaluator; public MultiClassMamlEvaluator(IHostEnvironment env, Arguments args) : base(args, env, MetadataUtils.Const.ScoreColumnKind.MultiClassClassification, "MultiClassMamlEvaluator") @@ -962,7 +870,7 @@ public MultiClassMamlEvaluator(IHostEnvironment env, Arguments args) _evaluator = new MultiClassClassifierEvaluator(Host, evalArgs); } - protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) + private protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) { Host.AssertValue(metrics); @@ -992,7 +900,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary(); @@ -1015,10 +923,10 @@ protected override IDataView CombineOverallMetricsCore(IDataView[] metrics) var idv = views[i]; // Find the old per-class log-loss column and drop it. 
- for (int col = 0; col < idv.Schema.ColumnCount; col++) + for (int col = 0; col < idv.Schema.Count; col++) { - if (idv.Schema.IsHidden(col) && - idv.Schema.GetColumnName(col).Equals(MultiClassClassifierEvaluator.PerClassLogLoss)) + if (idv.Schema[col].IsHidden && + idv.Schema[col].Name.Equals(MultiClassClassifierEvaluator.PerClassLogLoss)) { idv = new ChooseColumnsByIndexTransform(Host, new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = new[] { col } }, idv); @@ -1031,7 +939,7 @@ protected override IDataView CombineOverallMetricsCore(IDataView[] metrics) return base.CombineOverallMetricsCore(views); } - protected override IDataView GetOverallResultsCore(IDataView overall) + private protected override IDataView GetOverallResultsCore(IDataView overall) { // Change the name of the Top-k-accuracy column. if (_outputTopKAcc != null) @@ -1069,13 +977,13 @@ public override IEnumerable GetOverallMetricColumns() yield return new MetricColumn("LogLossReduction", MultiClassClassifierEvaluator.LogLossReduction); } - protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) + private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); - Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Schema must contain a label column"); // Output the label column. - yield return schema.Label.Name; + yield return schema.Label.Value.Name; // Return the output columns. yield return MultiClassPerInstanceEvaluator.Assigned; @@ -1085,24 +993,25 @@ protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSch } // Multi-class evaluator adds four per-instance columns: "Assigned", "Top scores", "Top classes" and "Log-loss". 
- protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema) + private protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema) { // If the label column is a key without key values, convert it to I8, just for saving the per-instance // text file, since if there are different key counts the columns cannot be appended. - if (!perInst.Schema.TryGetColumnIndex(schema.Label.Name, out int labelCol)) - throw Host.Except("Could not find column '{0}'", schema.Label.Name); - var labelType = perInst.Schema.GetColumnType(labelCol); - if (labelType.IsKey && (!perInst.Schema.HasKeyValues(labelCol, labelType.KeyCount) || labelType.RawKind != DataKind.U4)) + string labelName = schema.Label.Value.Name; + if (!perInst.Schema.TryGetColumnIndex(labelName, out int labelCol)) + throw Host.Except("Could not find column '{0}'", labelName); + var labelType = perInst.Schema[labelCol].Type; + if (labelType is KeyType keyType && (!(bool)perInst.Schema[labelCol].HasKeyValues(keyType.KeyCount) || labelType.RawKind != DataKind.U4)) { - perInst = LambdaColumnMapper.Create(Host, "ConvertToDouble", perInst, schema.Label.Name, - schema.Label.Name, perInst.Schema.GetColumnType(labelCol), NumberType.R8, - (in uint src, ref double dst) => dst = src == 0 ? double.NaN : src - 1 + (double)labelType.AsKey.Min); + perInst = LambdaColumnMapper.Create(Host, "ConvertToDouble", perInst, labelName, + labelName, perInst.Schema[labelCol].Type, NumberType.R8, + (in uint src, ref double dst) => dst = src == 0 ? double.NaN : src - 1 + (double)keyType.Min); } var perInstSchema = perInst.Schema; if (perInstSchema.TryGetColumnIndex(MultiClassPerInstanceEvaluator.SortedClasses, out int sortedClassesIndex)) { - var type = perInstSchema.GetColumnType(sortedClassesIndex); + var type = perInstSchema[sortedClassesIndex].Type; // Wrap with a DropSlots transform to pick only the first _numTopClasses slots. 
if (_numTopClasses < type.VectorSize) perInst = new SlotsDroppingTransformer(Host, MultiClassPerInstanceEvaluator.SortedClasses, min: _numTopClasses).Transform(perInst); @@ -1111,9 +1020,9 @@ protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMa // Wrap with a DropSlots transform to pick only the first _numTopClasses slots. if (perInst.Schema.TryGetColumnIndex(MultiClassPerInstanceEvaluator.SortedScores, out int sortedScoresIndex)) { - var type = perInst.Schema.GetColumnType(sortedScoresIndex); + var type = perInst.Schema[sortedScoresIndex].Type; if (_numTopClasses < type.VectorSize) - perInst = new SlotsDroppingTransformer(Host, MultiClassPerInstanceEvaluator.SortedScores, min: _numTopClasses).Transform(perInst); + perInst = new SlotsDroppingTransformer(Host, MultiClassPerInstanceEvaluator.SortedScores, min: _numTopClasses).Transform(perInst); } return perInst; } @@ -1130,7 +1039,7 @@ public static CommonOutputs.ClassificationEvaluateOutput MultiClass(IHostEnviron EntryPointUtils.CheckInputArgs(host, input); MatchColumns(host, input, out string label, out string weight, out string name); - var evaluator = new MultiClassMamlEvaluator(host, input); + IMamlEvaluator evaluator = new MultiClassMamlEvaluator(host, input); var data = new RoleMappedData(input.Data, label, null, null, weight, name); var metrics = evaluator.Evaluate(data); diff --git a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs index 1b1f8f6f05..57e0629418 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs @@ -2,18 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Numeric; using System; using System.Collections.Generic; using System.Text; using System.Text.RegularExpressions; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Numeric; using Float = System.Single; [assembly: LoadableClass(typeof(MultiOutputRegressionEvaluator), typeof(MultiOutputRegressionEvaluator), typeof(MultiOutputRegressionEvaluator.Arguments), typeof(SignatureEvaluator), @@ -26,7 +25,7 @@ [assembly: LoadableClass(typeof(MultiOutputRegressionPerInstanceEvaluator), null, typeof(SignatureLoadRowMapper), "", MultiOutputRegressionPerInstanceEvaluator.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public sealed class MultiOutputRegressionEvaluator : RegressionLossEvaluatorBase { @@ -47,28 +46,27 @@ public MultiOutputRegressionEvaluator(IHostEnvironment env, Arguments args) { } - protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) + private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) { - Host.CheckParam(schema.Label != null, nameof(schema), "Could not find the label column"); - var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - Host.AssertValue(scoreInfo); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Could not find the label column"); + var scoreCol = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - return new MultiOutputRegressionPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, schema.Label.Name); + return new MultiOutputRegressionPerInstanceEvaluator(Host, 
schema.Schema, scoreCol.Name, schema.Label.Value.Name); } - protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) + private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var t = score.Type; if (t.VectorSize == 0 || t.ItemType != NumberType.Float) - throw Host.Except("Score column '{0}' has type '{1}' but must be a known length vector of type R4", score.Name, t); - Host.Check(schema.Label != null, "Could not find the label column"); - t = schema.Label.Type; + throw Host.ExceptSchemaMismatch(nameof(schema), "score", score.Name, "known size vector of R4", t.ToString()); + Host.Check(schema.Label.HasValue, "Could not find the label column"); + t = schema.Label.Value.Type; if (!t.IsKnownSizeVector || (t.ItemType != NumberType.R4 && t.ItemType != NumberType.R8)) - throw Host.Except("Label column '{0}' has type '{1}' but must be a known-size vector of R4 or R8", schema.Label.Name, t); + throw Host.ExceptSchemaMismatch(nameof(schema), "label", schema.Label.Value.Name, "known size vector of R4 or R8", t.ToString()); } - protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) + private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); Host.Assert(score.Type.VectorSize > 0); @@ -92,7 +90,7 @@ public override IEnumerable GetOverallMetricColumns() groupName: "label", nameFormat: string.Format("{0} (Label_{{0}}", PerLabelLoss)); } - protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, + private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, out Action, Aggregator> addAgg, out Func> consolidate) { var stratCol = new List(); @@ -299,20 +297,20 @@ public 
Aggregator(IHostEnvironment env, IRegressionLoss lossFunction, int size, WeightedCounters = Weighted ? new Counters(lossFunction, _size) : null; } - public override void InitializeNextPass(IRow row, RoleMappedSchema schema) + internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { Contracts.Assert(PassNum < 1); - Contracts.AssertValue(schema.Label); + Contracts.Assert(schema.Label.HasValue); var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - _labelGetter = RowCursorUtils.GetVecGetterAs(NumberType.Float, row, schema.Label.Index); + _labelGetter = RowCursorUtils.GetVecGetterAs(NumberType.Float, row, schema.Label.Value.Index); _scoreGetter = row.GetGetter>(score.Index); Contracts.AssertValue(_labelGetter); Contracts.AssertValue(_scoreGetter); - if (schema.Weight != null) - _weightGetter = row.GetGetter(schema.Weight.Index); + if (schema.Weight.HasValue) + _weightGetter = row.GetGetter(schema.Weight.Value.Index); } public override void ProcessRow() @@ -416,13 +414,13 @@ private MultiOutputRegressionPerInstanceEvaluator(IHostEnvironment env, ModelLoa // base } - public static MultiOutputRegressionPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadContext ctx, ISchema schema) + public static MultiOutputRegressionPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadContext ctx, Schema schema) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new MultiOutputRegressionPerInstanceEvaluator(env, ctx, Schema.Create(schema)); + return new MultiOutputRegressionPerInstanceEvaluator(env, ctx, schema); } public override void Save(ModelSaveContext ctx) @@ -436,7 +434,7 @@ public override void Save(ModelSaveContext ctx) base.Save(ctx); } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { return col => @@ -446,7 +444,7 @@ public override Func GetDependencies(Func 
activeOutput) (col == ScoreIndex || col == LabelIndex); } - public override Schema.DetachedColumn[] GetOutputColumns() + private protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var infos = new Schema.DetachedColumn[5]; infos[LabelOutput] = new Schema.DetachedColumn(LabelCol, _labelType, _labelMetadata); @@ -457,7 +455,7 @@ public override Schema.DetachedColumn[] GetOutputColumns() return infos; } - public override Delegate[] CreateGetters(IRow input, Func activeCols, out Action disposer) + private protected override Delegate[] CreateGettersCore(Row input, Func activeCols, out Action disposer) { Host.Assert(LabelIndex >= 0); Host.Assert(ScoreIndex >= 0); @@ -546,19 +544,19 @@ private void CheckInputColumnTypes(Schema schema, out ColumnType labelType, out Host.AssertNonEmpty(ScoreCol); Host.AssertNonEmpty(LabelCol); - var t = schema.GetColumnType(LabelIndex); + var t = schema[(int) LabelIndex].Type; if (!t.IsKnownSizeVector || (t.ItemType != NumberType.R4 && t.ItemType != NumberType.R8)) throw Host.Except("Label column '{0}' has type '{1}' but must be a known-size vector of R4 or R8", LabelCol, t); - labelType = new VectorType(t.ItemType.AsPrimitive, t.VectorSize); + labelType = new VectorType((PrimitiveType)t.ItemType, t.VectorSize); var slotNamesType = new VectorType(TextType.Instance, t.VectorSize); var builder = new MetadataBuilder(); builder.AddSlotNames(t.VectorSize, CreateSlotNamesGetter(schema, LabelIndex, labelType.VectorSize, "True")); labelMetadata = builder.GetMetadata(); - t = schema.GetColumnType(ScoreIndex); + t = schema[ScoreIndex].Type; if (t.VectorSize == 0 || t.ItemType != NumberType.Float) throw Host.Except("Score column '{0}' has type '{1}' but must be a known length vector of type R4", ScoreCol, t); - scoreType = new VectorType(t.ItemType.AsPrimitive, t.VectorSize); + scoreType = new VectorType((PrimitiveType)t.ItemType, t.VectorSize); builder = new MetadataBuilder(); builder.AddSlotNames(t.VectorSize, 
CreateSlotNamesGetter(schema, ScoreIndex, scoreType.VectorSize, "Predicted")); @@ -590,13 +588,13 @@ private void GetScoreValueKind(ref ReadOnlyMemory dst) dst = MetadataUtils.Const.ScoreValueKind.Score.AsMemory(); } - private ValueGetter>> CreateSlotNamesGetter(ISchema schema, int column, int length, string prefix) + private ValueGetter>> CreateSlotNamesGetter(Schema schema, int column, int length, string prefix) { - var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, column); + var type = schema[column].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type; if (type != null && type.IsText) { return - (ref VBuffer> dst) => schema.GetMetadata(MetadataUtils.Kinds.SlotNames, column, ref dst); + (ref VBuffer> dst) => schema[column].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref dst); } return (ref VBuffer> dst) => @@ -623,7 +621,7 @@ public sealed class Arguments : ArgumentsBase private readonly MultiOutputRegressionEvaluator _evaluator; private readonly bool _supressScoresAndLabels; - protected override IEvaluator Evaluator { get { return _evaluator; } } + private protected override IEvaluator Evaluator => _evaluator; public MultiOutputRegressionMamlEvaluator(IHostEnvironment env, Arguments args) : base(args, env, MetadataUtils.Const.ScoreColumnKind.MultiOutputRegression, "RegressionMamlEvaluator") @@ -636,7 +634,7 @@ public MultiOutputRegressionMamlEvaluator(IHostEnvironment env, Arguments args) _evaluator = new MultiOutputRegressionEvaluator(Host, evalArgs); } - protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) + private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); @@ -644,11 +642,11 @@ protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSch // The multi output regression evaluator outputs the label and 
score column if requested by the user. if (!_supressScoresAndLabels) { - yield return schema.Label.Name; + yield return schema.Label.Value.Name; - var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), + var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.MultiOutputRegression); - yield return scoreInfo.Name; + yield return scoreCol.Name; } // Return the output columns. @@ -658,7 +656,7 @@ protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSch } // The multi-output regression evaluator prints only the per-label metrics for each fold. - protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) + private protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) { IDataView fold; if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out fold)) @@ -673,7 +671,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary>[colCount]; using (var cursor = fold.GetRowCursor(col => true)) @@ -688,22 +686,22 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary stratGetter; if (hasStrats) { - var type = cursor.Schema.GetColumnType(stratCol); + var type = cursor.Schema[stratCol].Type; stratGetter = RowCursorUtils.GetGetterAs(type, cursor, stratCol); } else stratGetter = (ref uint dst) => dst = 0; int labelCount = 0; - for (int i = 0; i < fold.Schema.ColumnCount; i++) + for (int i = 0; i < fold.Schema.Count; i++) { - if (fold.Schema.IsHidden(i) || (needWeighted && i == isWeightedCol) || + if (fold.Schema[i].IsHidden || (needWeighted && i == isWeightedCol) || (hasStrats && (i == stratCol || i == stratVal))) { continue; } - var type = fold.Schema.GetColumnType(i); + var type = fold.Schema[i].Type; if (type.IsKnownSizeVector && type.ItemType == NumberType.R8) { vBufferGetters[i] = cursor.GetGetter>(i); @@ -752,7 +750,7 @@ protected override void 
PrintFoldResultsCore(IChannel ch, Dictionary, VBuffer> @@ -39,44 +38,41 @@ public QuantileRegressionEvaluator(IHostEnvironment env, Arguments args) { } - protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) + private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) { - Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Must contain a label column"); var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); int scoreSize = scoreInfo.Type.VectorSize; - var type = schema.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, scoreInfo.Index); + var type = schema.Schema[scoreInfo.Index].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type; Host.Check(type != null && type.IsKnownSizeVector && type.ItemType.IsText, "Quantile regression score column must have slot names"); var quantiles = default(VBuffer>); - schema.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, scoreInfo.Index, ref quantiles); + schema.Schema[scoreInfo.Index].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref quantiles); Host.Assert(quantiles.IsDense && quantiles.Length == scoreSize); - return new QuantileRegressionPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, schema.Label.Name, scoreSize, quantiles); + return new QuantileRegressionPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, schema.Label.Value.Name, scoreSize, quantiles); } - protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) + private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var t = score.Type; if (t.VectorSize == 0 || (t.ItemType != NumberType.R4 && t.ItemType != NumberType.R8)) - { - throw Host.Except( - "Score column '{0}' has type '{1}' but must be a known length vector of type R4 or 
R8", score.Name, t); - } - Host.Check(schema.Label != null, "Could not find the label column"); - t = schema.Label.Type; + throw Host.ExceptSchemaMismatch(nameof(schema), "score", score.Name, "vector of type R4 or R8", t.ToString()); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Must contain a label column"); + t = schema.Label.Value.Type; if (t != NumberType.R4) - throw Host.Except("Label column '{0}' has type '{1}' but must be R4", schema.Label.Name, t); + throw Host.ExceptSchemaMismatch(nameof(schema), "label", schema.Label.Value.Name, "R4", t.ToString()); } - protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) + private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var t = scoreInfo.Type; Host.Assert(t.VectorSize > 0 && (t.ItemType == NumberType.R4 || t.ItemType == NumberType.R8)); var slotNames = default(VBuffer>); - t = schema.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, scoreInfo.Index); + t = schema.Schema[scoreInfo.Index].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type; if (t != null && t.VectorSize == scoreInfo.Type.VectorSize && t.ItemType.IsText) - schema.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, scoreInfo.Index, ref slotNames); + schema.Schema[scoreInfo.Index].GetSlotNames(ref slotNames); return new Aggregator(Host, LossFunction, schema.Weight != null, scoreInfo.Type.VectorSize, in slotNames, stratName); } @@ -284,7 +280,7 @@ private static VersionInfo GetVersionInfo() private readonly VBuffer> _quantiles; private readonly ColumnType _outputType; - public QuantileRegressionPerInstanceEvaluator(IHostEnvironment env, ISchema schema, string scoreCol, string labelCol, int scoreSize, VBuffer> quantiles) + public QuantileRegressionPerInstanceEvaluator(IHostEnvironment env, Schema schema, string scoreCol, string labelCol, int scoreSize, 
VBuffer> quantiles) : base(env, schema, scoreCol, labelCol) { Host.CheckParam(scoreSize > 0, nameof(scoreSize), "must be greater than 0"); @@ -298,7 +294,7 @@ public QuantileRegressionPerInstanceEvaluator(IHostEnvironment env, ISchema sche _outputType = new VectorType(NumberType.R8, _scoreSize); } - private QuantileRegressionPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, ISchema schema) + private QuantileRegressionPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, Schema schema) : base(env, ctx, schema) { CheckInputColumnTypes(schema); @@ -317,7 +313,7 @@ private QuantileRegressionPerInstanceEvaluator(IHostEnvironment env, ModelLoadCo _outputType = new VectorType(NumberType.R8, _scoreSize); } - public static QuantileRegressionPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadContext ctx, ISchema schema) + public static QuantileRegressionPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadContext ctx, Schema schema) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); @@ -345,13 +341,13 @@ public override void Save(ModelSaveContext ctx) ctx.SaveNonEmptyString(quantiles[i].ToString()); } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { return col => (activeOutput(L1Col) || activeOutput(L2Col)) && (col == ScoreIndex || col == LabelIndex); } - public override Schema.DetachedColumn[] GetOutputColumns() + private protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var infos = new Schema.DetachedColumn[2]; @@ -380,7 +376,7 @@ private ValueGetter>> CreateSlotNamesGetter(string }; } - public override Delegate[] CreateGetters(IRow input, Func activeCols, out Action disposer) + private protected override Delegate[] CreateGettersCore(Row input, Func activeCols, out Action disposer) { Host.Assert(LabelIndex >= 0); Host.Assert(ScoreIndex >= 0); @@ -442,16 +438,16 @@ public override Delegate[] CreateGetters(IRow 
input, Func activeCols, return getters; } - private void CheckInputColumnTypes(ISchema schema) + private void CheckInputColumnTypes(Schema schema) { Host.AssertNonEmpty(ScoreCol); Host.AssertNonEmpty(LabelCol); - var t = schema.GetColumnType(LabelIndex); + var t = schema[(int)LabelIndex].Type; if (t != NumberType.R4) throw Host.Except("Label column '{0}' has type '{1}' but must be R4", LabelCol, t); - t = schema.GetColumnType(ScoreIndex); + t = schema[ScoreIndex].Type; if (t.VectorSize == 0 || (t.ItemType != NumberType.R4 && t.ItemType != NumberType.R8)) { throw Host.Except( @@ -474,7 +470,7 @@ public sealed class Arguments : ArgumentsBase private readonly int? _index; private readonly QuantileRegressionEvaluator _evaluator; - protected override IEvaluator Evaluator => _evaluator; + private protected override IEvaluator Evaluator => _evaluator; public QuantileRegressionMamlEvaluator(IHostEnvironment env, Arguments args) : base(args, env, MetadataUtils.Const.ScoreColumnKind.QuantileRegression, "QuantilsRegressionMamlEvaluator") @@ -487,7 +483,7 @@ public QuantileRegressionMamlEvaluator(IHostEnvironment env, Arguments args) _evaluator = new QuantileRegressionEvaluator(Host, evalArgs); } - protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) + private protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) { ch.AssertValue(metrics); @@ -505,7 +501,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary src, ref Double dst) => dst = src.GetItemOrDefault(index)); @@ -538,16 +534,16 @@ public override IEnumerable GetOverallMetricColumns() yield return new MetricColumn("RSquared", QuantileRegressionEvaluator.RSquared); } - protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) + private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); - Host.CheckParam(schema.Label != null, nameof(schema), "Schema must 
contain a label column"); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Must contain a label column"); // The quantile regression evaluator outputs the label and score columns. - yield return schema.Label.Name; - var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), + yield return schema.Label.Value.Name; + var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.QuantileRegression); - yield return scoreInfo.Name; + yield return scoreCol.Name; // Return the output columns. yield return RegressionPerInstanceEvaluator.L1; @@ -569,7 +565,7 @@ public static CommonOutputs.CommonEvaluateOutput QuantileRegression(IHostEnviron string weight; string name; MatchColumns(host, input, out label, out weight, out name); - var evaluator = new QuantileRegressionMamlEvaluator(host, input); + IMamlEvaluator evaluator = new QuantileRegressionMamlEvaluator(host, input); var data = new RoleMappedData(input.Data, label, null, null, weight, name); var metrics = evaluator.Evaluate(data); diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs index bffe37be58..bcc2b0985e 100644 --- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs @@ -9,13 +9,12 @@ using System.Text; using System.Text.RegularExpressions; using System.Threading; +using Microsoft.ML; +using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; [assembly: LoadableClass(typeof(RankerEvaluator), typeof(RankerEvaluator), 
typeof(RankerEvaluator.Arguments), typeof(SignatureEvaluator), "Ranking Evaluator", RankerEvaluator.LoadName, "Ranking", "rank")] @@ -26,7 +25,7 @@ [assembly: LoadableClass(typeof(RankerPerInstanceTransform), null, typeof(SignatureLoadDataTransform), "", RankerPerInstanceTransform.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public sealed class RankerEvaluator : EvaluatorBase { @@ -84,54 +83,53 @@ public RankerEvaluator(IHostEnvironment env, Arguments args) _labelGains = labelGains.ToArray(); } - protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) + private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { - var t = schema.Label.Type; - if (t != NumberType.Float && !t.IsKey) + var t = schema.Label.Value.Type; + if (t != NumberType.Float && !(t is KeyType)) { - throw Host.ExceptUserArg(nameof(RankerMamlEvaluator.Arguments.LabelColumn), "Label column '{0}' has type '{1}' but must be R4 or a key", - schema.Label.Name, t); + throw Host.ExceptSchemaMismatch(nameof(RankerMamlEvaluator.Arguments.LabelColumn), + "label", schema.Label.Value.Name, "R4 or a key", t.ToString()); } - var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - if (scoreInfo.Type != NumberType.Float) + var scoreCol = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); + if (scoreCol.Type != NumberType.Float) { - throw Host.ExceptUserArg(nameof(RankerMamlEvaluator.Arguments.ScoreColumn), "Score column '{0}' has type '{1}' but must be R4", - scoreInfo.Name, t); + throw Host.ExceptSchemaMismatch(nameof(RankerMamlEvaluator.Arguments.ScoreColumn), + "score", scoreCol.Name, "R4", t.ToString()); } } - protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema) + private protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema) { - var t = schema.Group.Type; - if (!t.IsKey) + var t = schema.Group.Value.Type; + if (!(t is KeyType)) { - throw 
Host.ExceptUserArg(nameof(RankerMamlEvaluator.Arguments.GroupIdColumn), - "Group column '{0}' has type '{1}' but must be a key", - schema.Group.Name, t); + throw Host.ExceptSchemaMismatch(nameof(RankerMamlEvaluator.Arguments.GroupIdColumn), + "group", schema.Group.Value.Name, "key", t.ToString()); } } // Add also the group column. - protected override Func GetActiveColsCore(RoleMappedSchema schema) + private protected override Func GetActiveColsCore(RoleMappedSchema schema) { var pred = base.GetActiveColsCore(schema); - return i => i == schema.Group.Index || pred(i); + return i => i == schema.Group.Value.Index || pred(i); } - protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) + private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { return new Aggregator(Host, _labelGains, _truncationLevel, _groupSummary, schema.Weight != null, stratName); } - public override IDataTransform GetPerInstanceMetrics(RoleMappedData data) + internal override IDataTransform GetPerInstanceMetricsCore(RoleMappedData data) { Host.CheckValue(data, nameof(data)); - Host.CheckParam(data.Schema.Label != null, nameof(data), "Schema must contain a label column"); + Host.CheckParam(data.Schema.Label.HasValue, nameof(data), "Schema must contain a label column"); var scoreInfo = data.Schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - Host.CheckParam(data.Schema.Group != null, nameof(data), "Schema must contain a group column"); + Host.CheckParam(data.Schema.Group.HasValue, nameof(data), "Schema must contain a group column"); return new RankerPerInstanceTransform(Host, data.Data, - data.Schema.Label.Name, scoreInfo.Name, data.Schema.Group.Name, _truncationLevel, _labelGains); + data.Schema.Label.Value.Name, scoreInfo.Name, data.Schema.Group.Value.Name, _truncationLevel, _labelGains); } public override IEnumerable GetOverallMetricColumns() @@ -147,7 +145,7 @@ public override IEnumerable 
GetOverallMetricColumns() groupName: "at", nameFormat: string.Format("{0} @{{0}}", MaxDcg)); } - protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, + private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, out Action, Aggregator> addAgg, out Func> consolidate) { var stratCol = new List(); @@ -243,7 +241,7 @@ protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, A /// The name of the groupId column. /// The name of the predicted score column. /// The evaluation metrics for these outputs. - public Result Evaluate(IDataView data, string label, string groupId, string score) + public RankerMetrics Evaluate(IDataView data, string label, string groupId, string score) { Host.CheckValue(data, nameof(data)); Host.CheckNonEmpty(label, nameof(label)); @@ -253,16 +251,16 @@ public Result Evaluate(IDataView data, string label, string groupId, string scor RoleMappedSchema.ColumnRole.Group.Bind(groupId), RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, score)); - var resultDict = Evaluate(roles); + var resultDict = ((IEvaluator)this).Evaluate(roles); Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics)); var overall = resultDict[MetricKinds.OverallMetrics]; - Result result; + RankerMetrics result; using (var cursor = overall.GetRowCursor(i => true)) { var moved = cursor.MoveNext(); Host.Assert(moved); - result = new Result(Host, cursor); + result = new RankerMetrics(Host, cursor); moved = cursor.MoveNext(); Host.Assert(!moved); } @@ -440,23 +438,23 @@ public Aggregator(IHostEnvironment env, Double[] labelGains, int truncationLevel GroupId = new List>(); } - public override void InitializeNextPass(IRow row, RoleMappedSchema schema) + internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { Contracts.Assert(PassNum < 1); - Contracts.AssertValue(schema.Label); - 
Contracts.AssertValue(schema.Group); + Contracts.Assert(schema.Label.HasValue); + Contracts.Assert(schema.Group.HasValue); var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index); + _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index); _scoreGetter = row.GetGetter(score.Index); - _newGroupDel = RowCursorUtils.GetIsNewGroupDelegate(row, schema.Group.Index); - if (schema.Weight != null) - _weightGetter = row.GetGetter(schema.Weight.Index); + _newGroupDel = RowCursorUtils.GetIsNewGroupDelegate(row, schema.Group.Value.Index); + if (schema.Weight.HasValue) + _weightGetter = row.GetGetter(schema.Weight.Value.Index); if (UnweightedCounters.GroupSummary) { - ValueGetter groupIdBuilder = RowCursorUtils.GetGetterAsStringBuilder(row, schema.Group.Index); + ValueGetter groupIdBuilder = RowCursorUtils.GetGetterAsStringBuilder(row, schema.Group.Value.Index); _groupSbUpdate = () => groupIdBuilder(ref _groupSb); } else @@ -538,40 +536,6 @@ public void GetSlotNames(ref VBuffer> slotNames) slotNames = editor.Commit(); } } - - public sealed class Result - { - /// - /// Normalized Discounted Cumulative Gain - /// - /// - public double[] Ndcg { get; } - - /// - /// Discounted Cumulative gain - /// is the sum of the gains, for all the instances i, normalized by the natural logarithm of the instance + 1. - /// Note that unline the Wikipedia article, ML.Net uses the natural logarithm. 
- /// - /// - public double[] Dcg { get; } - - private static T Fetch(IExceptionContext ectx, IRow row, string name) - { - if (!row.Schema.TryGetColumnIndex(name, out int col)) - throw ectx.Except($"Could not find column '{name}'"); - T val = default; - row.GetGetter(col)(ref val); - return val; - } - - internal Result(IExceptionContext ectx, IRow overallResult) - { - VBuffer Fetch(string name) => Fetch>(ectx, overallResult, name); - - Dcg = Fetch(RankerEvaluator.Dcg).GetValues().ToArray(); - Ndcg = Fetch(RankerEvaluator.Ndcg).GetValues().ToArray(); - } - } } public sealed class RankerPerInstanceTransform : IDataTransform @@ -600,7 +564,16 @@ private static VersionInfo GetVersionInfo() public bool CanShuffle { get { return _transform.CanShuffle; } } - public Schema Schema => _transform.Schema; + /// + /// Explicit implementation prevents Schema from being accessed from derived classes. + /// It's our first step to separate data produced by transform from transform. + /// + Schema IDataView.Schema => OutputSchema; + + /// + /// Shape information of the produced output. Note that the input and the output of this transform (and their types) are identical. 
+ /// + public Schema OutputSchema => _transform.OutputSchema; public RankerPerInstanceTransform(IHostEnvironment env, IDataView input, string labelCol, string scoreCol, string groupCol, int truncationLevel, Double[] labelGains) @@ -635,14 +608,14 @@ public void Save(ModelSaveContext ctx) return _transform.GetRowCount(); } - public IRowCursor GetRowCursor(Func needCol, IRandom rand = null) + public RowCursor GetRowCursor(Func needCol, Random rand = null) { return _transform.GetRowCursor(needCol, rand); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func needCol, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func needCol, int n, Random rand = null) { - return _transform.GetRowCursorSet(out consolidator, needCol, n, rand); + return _transform.GetRowCursorSet(needCol, n, rand); } private sealed class Transform : PerGroupTransformBase @@ -654,7 +627,7 @@ private sealed class Bindings : BindingsBase private readonly int _truncationLevel; private readonly MetadataUtils.MetadataGetter>> _slotNamesGetter; - public Bindings(IExceptionContext ectx, ISchema input, bool user, string labelCol, string scoreCol, string groupCol, + public Bindings(IExceptionContext ectx, Schema input, bool user, string labelCol, string scoreCol, string groupCol, int truncationLevel) : base(ectx, input, labelCol, scoreCol, groupCol, user, Ndcg, Dcg, MaxDcg) { @@ -761,7 +734,7 @@ public override void Save(ModelSaveContext ctx) ctx.Writer.WriteDoubleArray(_labelGains); } - protected override BindingsBase GetBindings() + private protected override BindingsBase GetBindings() { return _bindings; } @@ -800,7 +773,7 @@ private void Copy(Double[] src, ref VBuffer dst) dst = editor.Commit(); } - protected override ValueGetter GetLabelGetter(IRow row) + protected override ValueGetter GetLabelGetter(Row row) { var lb = RowCursorUtils.GetLabelGetter(row, _bindings.LabelIndex); return @@ -812,12 +785,12 @@ protected override ValueGetter GetLabelGetter(IRow 
row) }; } - protected override ValueGetter GetScoreGetter(IRow row) + protected override ValueGetter GetScoreGetter(Row row) { return row.GetGetter(_bindings.ScoreIndex); } - protected override RowCursorState InitializeState(IRow input) + protected override RowCursorState InitializeState(Row input) { return new RowCursorState(_truncationLevel); } @@ -889,7 +862,7 @@ public sealed class Arguments : ArgumentsBase private readonly string _groupSummaryFilename; - protected override IEvaluator Evaluator { get { return _evaluator; } } + private protected override IEvaluator Evaluator => _evaluator; public RankerMamlEvaluator(IHostEnvironment env, Arguments args) : base(args, env, MetadataUtils.Const.ScoreColumnKind.Ranking, "RankerMamlEvaluator") @@ -907,14 +880,14 @@ public RankerMamlEvaluator(IHostEnvironment env, Arguments args) _groupIdCol = args.GroupIdColumn; } - protected override IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema) + private protected override IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema) { var cols = base.GetInputColumnRolesCore(schema); var groupIdCol = EvaluateUtils.GetColName(_groupIdCol, schema.Group, DefaultColumnNames.GroupId); return cols.Prepend(RoleMappedSchema.ColumnRole.Group.Bind(groupIdCol)); } - protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics) + private protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics) { ch.AssertNonEmpty(metrics); @@ -954,18 +927,18 @@ private bool TryGetGroupSummaryMetrics(Dictionary[] metrics, return true; } - protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) + private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); - Host.CheckValue(schema.Label, nameof(schema), "Data must contain a label column"); - Host.CheckValue(schema.Group, nameof(schema), "Data must contain a group column"); + 
Host.CheckParam(schema.Label.HasValue, nameof(schema), "Data must contain a label column"); + Host.CheckParam(schema.Group.HasValue, nameof(schema), "Data must contain a group column"); // The ranking evaluator outputs the label, group key and score columns. - yield return schema.Group.Name; - yield return schema.Label.Name; - var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), + yield return schema.Group.Value.Name; + yield return schema.Label.Value.Name; + var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.Ranking); - yield return scoreInfo.Name; + yield return scoreCol.Name; // Return the output columns. yield return RankerPerInstanceTransform.Ndcg; @@ -1093,11 +1066,11 @@ public static CommonOutputs.CommonEvaluateOutput Ranking(IHostEnvironment env, R string weight; string name; MatchColumns(host, input, out label, out weight, out name); - ISchema schema = input.Data.Schema; + var schema = input.Data.Schema; string groupId = TrainUtils.MatchNameOrDefaultOrNull(host, schema, nameof(RankerMamlEvaluator.Arguments.GroupIdColumn), input.GroupIdColumn, DefaultColumnNames.GroupId); - var evaluator = new RankerMamlEvaluator(host, input); + IMamlEvaluator evaluator = new RankerMamlEvaluator(host, input); var data = new RoleMappedData(input.Data, label, null, groupId, weight, name); var metrics = evaluator.Evaluate(data); diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs index df3c4a12fa..b26e151847 100644 --- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs @@ -2,14 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Model; using System; using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Model; using Float = System.Single; [assembly: LoadableClass(typeof(RegressionEvaluator), typeof(RegressionEvaluator), typeof(RegressionEvaluator.Arguments), typeof(SignatureEvaluator), @@ -22,7 +21,7 @@ [assembly: LoadableClass(typeof(RegressionPerInstanceEvaluator), null, typeof(SignatureLoadRowMapper), "", RegressionPerInstanceEvaluator.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public sealed class RegressionEvaluator : RegressionEvaluatorBase @@ -52,30 +51,29 @@ public RegressionEvaluator(IHostEnvironment env, Arguments args) { } - protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) + private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var t = score.Type; - if (t.IsVector || t.ItemType != NumberType.Float) - throw Host.Except("Score column '{0}' has type '{1}' but must be R4", score, t); - Host.Check(schema.Label != null, "Could not find the label column"); - t = schema.Label.Type; + if (t != NumberType.Float) + throw Host.ExceptSchemaMismatch(nameof(schema), "score", score.Name, "R4", t.ToString()); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Could not find the label column"); + t = schema.Label.Value.Type; if (t != NumberType.R4) - throw Host.Except("Label column '{0}' has type '{1}' but must be R4", schema.Label.Name, t); + throw Host.ExceptSchemaMismatch(nameof(schema), "label", schema.Label.Value.Name, "R4", t.ToString()); } - protected override Aggregator GetAggregatorCore(RoleMappedSchema 
schema, string stratName) + private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { return new Aggregator(Host, LossFunction, schema.Weight != null, stratName); } - protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) + private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) { - Contracts.CheckParam(schema.Label != null, nameof(schema), "Could not find the label column"); + Contracts.CheckParam(schema.Label.HasValue, nameof(schema), "Could not find the label column"); var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - Contracts.AssertValue(scoreInfo); - return new RegressionPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, schema.Label.Name); + return new RegressionPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, schema.Label.Value.Name); } public override IEnumerable GetOverallMetricColumns() @@ -157,73 +155,6 @@ public override void AddColumn(ArrayDataViewBuilder dvBldr, string metricName, p } } - public sealed class Result - { - /// - /// Gets the absolute loss of the model. - /// - /// - /// The absolute loss is defined as - /// L1 = (1/m) * sum( abs( yi - y'i)) - /// where m is the number of instances in the test set. - /// y'i are the predicted labels for each instance. - /// yi are the correct labels of each instance. - /// - public double L1 { get; } - - /// - /// Gets the squared loss of the model. - /// - /// - /// The squared loss is defined as - /// L2 = (1/m) * sum(( yi - y'i)^2) - /// where m is the number of instances in the test set. - /// y'i are the predicted labels for each instance. - /// yi are the correct labels of each instance. - /// - public double L2 { get; } - - /// - /// Gets the root mean square loss (or RMS) which is the square root of the L2 loss. - /// - public double Rms { get; } - - /// - /// Gets the user defined loss function. 
- /// - /// - /// This is the average of a loss function defined by the user, - /// computed over all the instances in the test set. - /// - public double LossFn { get; } - - /// - /// Gets the R squared value of the model, which is also known as - /// the coefficient of determination​. - /// - public double RSquared { get; } - - internal Result(IExceptionContext ectx, IRow overallResult) - { - double Fetch(string name) => RowCursorUtils.Fetch(ectx, overallResult, name); - L1 = Fetch(RegressionEvaluator.L1); - L2 = Fetch(RegressionEvaluator.L2); - Rms = Fetch(RegressionEvaluator.Rms); - LossFn = Fetch(RegressionEvaluator.Loss); - RSquared = Fetch(RegressionEvaluator.RSquared); - } - - [BestFriend] - internal Result(double l1, double l2, double rms, double lossFunction, double rSquared) - { - L1 = l1; - L2 = l2; - Rms = rms; - LossFn = lossFunction; - RSquared = rSquared; - } - } - /// /// Evaluates scored regression data. /// @@ -231,7 +162,7 @@ internal Result(double l1, double l2, double rms, double lossFunction, double rS /// The name of the label column. /// The name of the predicted score column. /// The evaluation metrics for these outputs. 
- public Result Evaluate(IDataView data, string label, string score) + public RegressionMetrics Evaluate(IDataView data, string label, string score) { Host.CheckValue(data, nameof(data)); Host.CheckNonEmpty(label, nameof(label)); @@ -240,16 +171,16 @@ public Result Evaluate(IDataView data, string label, string score) RoleMappedSchema.ColumnRole.Label.Bind(label), RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, score)); - var resultDict = Evaluate(roles); + var resultDict = ((IEvaluator)this).Evaluate(roles); Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics)); var overall = resultDict[MetricKinds.OverallMetrics]; - Result result; + RegressionMetrics result; using (var cursor = overall.GetRowCursor(i => true)) { var moved = cursor.MoveNext(); Host.Assert(moved); - result = new Result(Host, cursor); + result = new RegressionMetrics(Host, cursor); moved = cursor.MoveNext(); Host.Assert(!moved); } @@ -277,13 +208,13 @@ private static VersionInfo GetVersionInfo() public const string L1 = "L1-loss"; public const string L2 = "L2-loss"; - public RegressionPerInstanceEvaluator(IHostEnvironment env, ISchema schema, string scoreCol, string labelCol) + public RegressionPerInstanceEvaluator(IHostEnvironment env, Schema schema, string scoreCol, string labelCol) : base(env, schema, scoreCol, labelCol) { CheckInputColumnTypes(schema); } - private RegressionPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, ISchema schema) + private RegressionPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, Schema schema) : base(env, ctx, schema) { CheckInputColumnTypes(schema); @@ -292,7 +223,7 @@ private RegressionPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ct // base } - public static RegressionPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadContext ctx, ISchema schema) + public static RegressionPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadContext ctx, Schema schema) { 
Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); @@ -312,13 +243,13 @@ public override void Save(ModelSaveContext ctx) base.Save(ctx); } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { return col => (activeOutput(L1Col) || activeOutput(L2Col)) && (col == ScoreIndex || col == LabelIndex); } - public override Schema.DetachedColumn[] GetOutputColumns() + private protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var infos = new Schema.DetachedColumn[2]; infos[L1Col] = new Schema.DetachedColumn(L1, NumberType.R8, null); @@ -326,7 +257,7 @@ public override Schema.DetachedColumn[] GetOutputColumns() return infos; } - public override Delegate[] CreateGetters(IRow input, Func activeCols, out Action disposer) + private protected override Delegate[] CreateGettersCore(Row input, Func activeCols, out Action disposer) { Host.Assert(LabelIndex >= 0); Host.Assert(ScoreIndex >= 0); @@ -380,16 +311,16 @@ public override Delegate[] CreateGetters(IRow input, Func activeCols, return getters; } - private void CheckInputColumnTypes(ISchema schema) + private void CheckInputColumnTypes(Schema schema) { Host.AssertNonEmpty(ScoreCol); Host.AssertNonEmpty(LabelCol); - var t = schema.GetColumnType(LabelIndex); + var t = schema[(int) LabelIndex].Type; if (t != NumberType.R4) throw Host.Except("Label column '{0}' has type '{1}' but must be R4", LabelCol, t); - t = schema.GetColumnType(ScoreIndex); + t = schema[ScoreIndex].Type; if (t.IsVector || t.ItemType != NumberType.Float) throw Host.Except("Score column '{0}' has type '{1}' but must be R4", ScoreCol, t); } @@ -405,7 +336,7 @@ public sealed class Arguments : ArgumentsBase private readonly RegressionEvaluator _evaluator; - protected override IEvaluator Evaluator { get { return _evaluator; } } + private protected override IEvaluator Evaluator => _evaluator; public RegressionMamlEvaluator(IHostEnvironment env, 
Arguments args) : base(args, env, MetadataUtils.Const.ScoreColumnKind.Regression, "RegressionMamlEvaluator") @@ -417,16 +348,16 @@ public RegressionMamlEvaluator(IHostEnvironment env, Arguments args) _evaluator = new RegressionEvaluator(Host, evalArgs); } - protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) + private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); - Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Schema must contain a label column"); // The regression evaluator outputs the label and score columns. - yield return schema.Label.Name; - var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), + yield return schema.Label.Value.Name; + var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.Regression); - yield return scoreInfo.Name; + yield return scoreCol.Name; // Return the output columns. 
yield return RegressionPerInstanceEvaluator.L1; @@ -453,7 +384,7 @@ public static CommonOutputs.CommonEvaluateOutput Regression(IHostEnvironment env string weight; string name; MatchColumns(host, input, out label, out weight, out name); - var evaluator = new RegressionMamlEvaluator(host, input); + IMamlEvaluator evaluator = new RegressionMamlEvaluator(host, input); var data = new RoleMappedData(input.Data, label, null, null, weight, name); var metrics = evaluator.Evaluate(data); diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs index dcd2ce618e..3b42fdbd05 100644 --- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs +++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs @@ -4,10 +4,10 @@ using System; using System.Collections.Generic; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public abstract class RegressionLossEvaluatorBase : RowToRowEvaluatorBase where TAgg : EvaluatorBase.AggregatorBase @@ -37,12 +37,13 @@ protected RegressionLossEvaluatorBase(ArgumentsBase args, IHostEnvironment env, public abstract class RegressionEvaluatorBase : RegressionLossEvaluatorBase where TAgg : RegressionEvaluatorBase.RegressionAggregatorBase { - protected RegressionEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string registrationName) + [BestFriend] + private protected RegressionEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string registrationName) : base(args, env, registrationName) { } - protected override void GetAggregatorConsolidationFuncs(TAgg aggregator, AggregatorDictionaryBase[] dictionaries, + private protected override void GetAggregatorConsolidationFuncs(TAgg aggregator, AggregatorDictionaryBase[] dictionaries, out Action, TAgg> addAgg, out Func> consolidate) { 
var stratCol = new List(); @@ -183,7 +184,8 @@ public void Update(ref TScore score, float label, float weight, ref TMetrics los public abstract CountersBase UnweightedCounters { get; } public abstract CountersBase WeightedCounters { get; } - protected RegressionAggregatorBase(IHostEnvironment env, IRegressionLoss lossFunction, bool weighted, string stratName) + [BestFriend] + private protected RegressionAggregatorBase(IHostEnvironment env, IRegressionLoss lossFunction, bool weighted, string stratName) : base(env, stratName) { Host.AssertValue(lossFunction); @@ -191,20 +193,20 @@ protected RegressionAggregatorBase(IHostEnvironment env, IRegressionLoss lossFun Weighted = weighted; } - public override void InitializeNextPass(IRow row, RoleMappedSchema schema) + internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { Contracts.Assert(PassNum < 1); - Contracts.AssertValue(schema.Label); + Contracts.Assert(schema.Label.HasValue); var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index); + _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index); _scoreGetter = row.GetGetter(score.Index); Contracts.AssertValue(_labelGetter); Contracts.AssertValue(_scoreGetter); - if (schema.Weight != null) - _weightGetter = row.GetGetter(schema.Weight.Index); + if (schema.Weight.HasValue) + _weightGetter = row.GetGetter(schema.Weight.Value.Index); } public override void ProcessRow() diff --git a/src/Microsoft.ML.Data/MLContext.cs b/src/Microsoft.ML.Data/MLContext.cs index 326562a978..73727a113b 100644 --- a/src/Microsoft.ML.Data/MLContext.cs +++ b/src/Microsoft.ML.Data/MLContext.cs @@ -2,11 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; using System; using System.ComponentModel.Composition; using System.ComponentModel.Composition.Hosting; +using Microsoft.ML.Data; namespace Microsoft.ML { diff --git a/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj b/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj index 2ebb859178..f3e7d96b59 100644 --- a/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj +++ b/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj @@ -7,18 +7,11 @@ CORECLR - - - True - True - TermStaticExtensions.tt - - - + @@ -26,34 +19,8 @@ - - - TextTemplatingFileGenerator - ConvertStaticExtensions.cs - - - TextTemplatingFileGenerator - TermStaticExtensions.cs - - - - - - True - True - ConvertStaticExtensions.tt - - - True - True - TermStaticExtensions.tt - TextTemplatingFileGenerator - TermStaticExtensions.cs - - - \ No newline at end of file diff --git a/src/Microsoft.ML.Data/Model/ModelHeader.cs b/src/Microsoft.ML.Data/Model/ModelHeader.cs index b74d0fdca6..88fcf9bdad 100644 --- a/src/Microsoft.ML.Data/Model/ModelHeader.cs +++ b/src/Microsoft.ML.Data/Model/ModelHeader.cs @@ -4,12 +4,11 @@ using System; using System.IO; -using System.Reflection; using System.Runtime.InteropServices; using System.Text; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Model +namespace Microsoft.ML.Model { [StructLayout(LayoutKind.Explicit, Size = ModelHeader.Size)] internal struct ModelHeader diff --git a/src/Microsoft.ML.Data/Model/ModelLoadContext.cs b/src/Microsoft.ML.Data/Model/ModelLoadContext.cs index 749a226012..8838b031de 100644 --- a/src/Microsoft.ML.Data/Model/ModelLoadContext.cs +++ b/src/Microsoft.ML.Data/Model/ModelLoadContext.cs @@ -5,9 +5,9 @@ using System; using System.IO; using System.Text; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Model +namespace Microsoft.ML.Model { /// /// This is a convenience 
context object for loading models from a repository, for diff --git a/src/Microsoft.ML.Data/Model/ModelLoading.cs b/src/Microsoft.ML.Data/Model/ModelLoading.cs index 981da9e797..06ebc0bf06 100644 --- a/src/Microsoft.ML.Data/Model/ModelLoading.cs +++ b/src/Microsoft.ML.Data/Model/ModelLoading.cs @@ -6,9 +6,9 @@ using System.IO; using System.Reflection; using System.Text; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Model +namespace Microsoft.ML.Model { public sealed partial class ModelLoadContext : IDisposable { diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs index 554795b388..aa8b722147 100644 --- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs @@ -2,12 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.IO; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Data; -using System.IO; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { /// /// An object serving as a 'catalog' of available model operations. @@ -16,10 +15,25 @@ public sealed class ModelOperationsCatalog { internal IHostEnvironment Environment { get; } + public ExplainabilityTransforms Explainability { get; } + internal ModelOperationsCatalog(IHostEnvironment env) { Contracts.AssertValue(env); Environment = env; + + Explainability = new ExplainabilityTransforms(this); + } + + public abstract class SubCatalogBase + { + internal IHostEnvironment Environment { get; } + + protected SubCatalogBase(ModelOperationsCatalog owner) + { + Environment = owner.Environment; + } + } /// @@ -35,5 +49,31 @@ internal ModelOperationsCatalog(IHostEnvironment env) /// A readable, seekable stream to load from. /// The loaded model. 
public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(Environment, stream); + + /// + /// The catalog of model explainability operations. + /// + public sealed class ExplainabilityTransforms : SubCatalogBase + { + internal ExplainabilityTransforms(ModelOperationsCatalog owner) : base(owner) + { + } + } + + /// + /// Create a prediction engine for one-time prediction. + /// + /// The class that defines the input data. + /// The class that defines the output data. + /// The transformer to use for prediction. + /// Additional settings of the input schema. + /// Additional settings of the output schema. + public PredictionEngine CreatePredictionEngine(ITransformer transformer, + SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) + where TSrc : class + where TDst : class, new() + { + return new PredictionEngine(Environment, transformer, false, inputSchemaDefinition, outputSchemaDefinition); + } } } diff --git a/src/Microsoft.ML.Data/Model/ModelSaveContext.cs b/src/Microsoft.ML.Data/Model/ModelSaveContext.cs index eee8d23df7..d8da3f4edb 100644 --- a/src/Microsoft.ML.Data/Model/ModelSaveContext.cs +++ b/src/Microsoft.ML.Data/Model/ModelSaveContext.cs @@ -5,9 +5,9 @@ using System; using System.IO; using System.Text; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Model +namespace Microsoft.ML.Model { /// /// This is a convenience context object for saving models to a repository, for diff --git a/src/Microsoft.ML.Data/Model/ModelSaving.cs b/src/Microsoft.ML.Data/Model/ModelSaving.cs index da15986d9d..a73e631bbe 100644 --- a/src/Microsoft.ML.Data/Model/ModelSaving.cs +++ b/src/Microsoft.ML.Data/Model/ModelSaving.cs @@ -6,7 +6,7 @@ using System.IO; using System.Text; -namespace Microsoft.ML.Runtime.Model +namespace Microsoft.ML.Model { public sealed partial class ModelSaveContext : IDisposable { diff --git 
a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs index dfc8756303..6366c7a159 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs @@ -2,9 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Calibrator; +using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime.Model.Onnx +namespace Microsoft.ML.Model.Onnx { [BestFriend] internal interface ICanSaveOnnx @@ -68,7 +69,7 @@ internal interface IBindableCanSaveOnnx : ICanSaveOnnx, ISchemaBindableMapper /// /// For simple mappers. Intended to be used for and - /// instances. + /// instances. /// [BestFriend] internal interface ISingleCanSaveOnnx : ICanSaveOnnx diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs index 38d9f77915..d8685af52c 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs @@ -3,9 +3,9 @@ // See the LICENSE file in the project root for more information. 
using System.Collections.Generic; -using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime.Model.Onnx +namespace Microsoft.ML.Model.Onnx { [BestFriend] internal enum OnnxVersion { Stable = 0, Experimental = 1 } diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs index 79df068b9b..804d5d90a9 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs @@ -4,9 +4,8 @@ using System; using System.Collections.Generic; -using Microsoft.ML.Runtime.Data; -namespace Microsoft.ML.Runtime.Model.Onnx +namespace Microsoft.ML.Model.Onnx { /// /// An abstraction for an ONNX node as created by diff --git a/src/Microsoft.ML.Data/Model/Pfa/BoundPfaContext.cs b/src/Microsoft.ML.Data/Model/Pfa/BoundPfaContext.cs index db07fb69cc..997b4d0dc8 100644 --- a/src/Microsoft.ML.Data/Model/Pfa/BoundPfaContext.cs +++ b/src/Microsoft.ML.Data/Model/Pfa/BoundPfaContext.cs @@ -5,10 +5,9 @@ using System.Collections.Generic; using System.Linq; using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Data; using Newtonsoft.Json.Linq; -namespace Microsoft.ML.Runtime.Model.Pfa +namespace Microsoft.ML.Model.Pfa { using T = PfaUtils.Type; @@ -63,11 +62,11 @@ private void SetInput(Schema schema, HashSet toDrop) recordType["name"] = "DataInput"; var fields = new JArray(); var fieldNames = new HashSet(); - for (int c = 0; c < schema.ColumnCount; ++c) + for (int c = 0; c < schema.Count; ++c) { - if (schema.IsHidden(c)) + if (schema[c].IsHidden) continue; - string name = schema.GetColumnName(c); + string name = schema[c].Name; if (toDrop.Contains(name)) continue; JToken pfaType = PfaTypeOrNullForColumn(schema, c); @@ -104,7 +103,7 @@ private void SetInput(Schema schema, HashSet toDrop) /// The columns to output /// Returns a complete PFA program, where the output will correspond to the subset /// of columns from . 
- public JObject Finalize(ISchema schema, params string[] toOutput) + public JObject Finalize(Schema schema, params string[] toOutput) { _host.CheckValue(schema, nameof(schema)); _host.CheckValue(toOutput, nameof(toOutput)); @@ -162,12 +161,12 @@ public JObject Finalize(ISchema schema, params string[] toOutput) return Pfa.Finalize(); } - private JToken PfaTypeOrNullForColumn(ISchema schema, int col) + private JToken PfaTypeOrNullForColumn(Schema schema, int col) { _host.AssertValue(schema); - _host.Assert(0 <= col && col < schema.ColumnCount); + _host.Assert(0 <= col && col < schema.Count); - ColumnType type = schema.GetColumnType(col); + ColumnType type = schema[col].Type; return T.PfaTypeOrNullForColumnType(type); } diff --git a/src/Microsoft.ML.Data/Model/Pfa/ICanSavePfa.cs b/src/Microsoft.ML.Data/Model/Pfa/ICanSavePfa.cs index bcda835fe8..cc07e75a3c 100644 --- a/src/Microsoft.ML.Data/Model/Pfa/ICanSavePfa.cs +++ b/src/Microsoft.ML.Data/Model/Pfa/ICanSavePfa.cs @@ -2,10 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Calibrator; +using Microsoft.ML.Data; using Newtonsoft.Json.Linq; -namespace Microsoft.ML.Runtime.Model.Pfa +namespace Microsoft.ML.Model.Pfa { [BestFriend] internal interface ICanSavePfa @@ -72,7 +73,7 @@ internal interface IBindableCanSavePfa : ICanSavePfa, ISchemaBindableMapper /// /// For simple mappers. Intended to be used for and - /// instances. + /// instances. 
/// [BestFriend] internal interface ISingleCanSavePfa : ICanSavePfa diff --git a/src/Microsoft.ML.Data/Model/Pfa/ModelUtils.cs b/src/Microsoft.ML.Data/Model/Pfa/ModelUtils.cs index 110296e1d0..e33a9ae4a6 100644 --- a/src/Microsoft.ML.Data/Model/Pfa/ModelUtils.cs +++ b/src/Microsoft.ML.Data/Model/Pfa/ModelUtils.cs @@ -4,7 +4,7 @@ using System; -namespace Microsoft.ML.Runtime.Model +namespace Microsoft.ML.Model { internal static class ModelUtils { diff --git a/src/Microsoft.ML.Data/Model/Pfa/PfaContext.cs b/src/Microsoft.ML.Data/Model/Pfa/PfaContext.cs index 2d89916028..9619e195ca 100644 --- a/src/Microsoft.ML.Data/Model/Pfa/PfaContext.cs +++ b/src/Microsoft.ML.Data/Model/Pfa/PfaContext.cs @@ -6,7 +6,7 @@ using System.Linq; using Newtonsoft.Json.Linq; -namespace Microsoft.ML.Runtime.Model.Pfa +namespace Microsoft.ML.Model.Pfa { /// /// A context for defining a restricted sort of PFA output. diff --git a/src/Microsoft.ML.Data/Model/Pfa/PfaUtils.cs b/src/Microsoft.ML.Data/Model/Pfa/PfaUtils.cs index 9b4f6bcf3f..9215e233e0 100644 --- a/src/Microsoft.ML.Data/Model/Pfa/PfaUtils.cs +++ b/src/Microsoft.ML.Data/Model/Pfa/PfaUtils.cs @@ -2,11 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; using Newtonsoft.Json.Linq; -namespace Microsoft.ML.Runtime.Model.Pfa +namespace Microsoft.ML.Model.Pfa { [BestFriend] internal static class PfaUtils diff --git a/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs b/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs index 57a1f782c7..685fdacab2 100644 --- a/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs +++ b/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs @@ -5,19 +5,19 @@ using System; using System.Collections.Generic; using System.IO; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model.Pfa; +using Microsoft.ML; +using Microsoft.ML.Command; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model.Pfa; using Newtonsoft.Json; using Newtonsoft.Json.Linq; [assembly: LoadableClass(SavePfaCommand.Summary, typeof(SavePfaCommand), typeof(SavePfaCommand.Arguments), typeof(SignatureCommand), "Save PFA", "SavePfa", DocName = "command/SavePfa.md")] -namespace Microsoft.ML.Runtime.Model.Pfa +namespace Microsoft.ML.Model.Pfa { internal sealed class SavePfaCommand : DataCommand.ImplBase { @@ -184,11 +184,11 @@ private void Run(IChannel ch) } var toExport = new List(); - for (int i = 0; i < end.Schema.ColumnCount; ++i) + for (int i = 0; i < end.Schema.Count; ++i) { - if (end.Schema.IsHidden(i)) + if (end.Schema[i].IsHidden) continue; - var name = end.Schema.GetColumnName(i); + var name = end.Schema[i].Name; if (_outputsToDrop.Contains(name)) continue; if (!ctx.IsInput(name) || _keepInput) diff --git a/src/Microsoft.ML.Data/Model/PredictionEngineExtensions.cs b/src/Microsoft.ML.Data/Model/PredictionEngineExtensions.cs new file mode 100644 index 
0000000000..bc288ef864 --- /dev/null +++ b/src/Microsoft.ML.Data/Model/PredictionEngineExtensions.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; + +namespace Microsoft.ML +{ + /// + /// Extension methods to create a prediction engine. + /// + public static class PredictionEngineExtensions + { + /// + /// Create a prediction engine for one-time prediction. + /// + /// The class that defines the input data. + /// The class that defines the output data. + /// The transformer to use for prediction. + /// The environment to use. + /// Additional settings of the input schema. + /// Additional settings of the output schema. + public static PredictionEngine CreatePredictionEngine(this ITransformer transformer, + IHostEnvironment env, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) + where TSrc : class + where TDst : class, new() + => new PredictionEngine(env, transformer, true, inputSchemaDefinition, outputSchemaDefinition); + } +} diff --git a/src/Microsoft.ML.Data/Model/Repository.cs b/src/Microsoft.ML.Data/Model/Repository.cs index c93e6d00b6..4e3b8d89c1 100644 --- a/src/Microsoft.ML.Data/Model/Repository.cs +++ b/src/Microsoft.ML.Data/Model/Repository.cs @@ -6,9 +6,9 @@ using System.Collections.Generic; using System.IO; using System.IO.Compression; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Model +namespace Microsoft.ML.Model { /// /// Signature for a repository based model loader. This is the dual of ICanSaveModel. 
diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index ea2126b91e..d8b1de1cc8 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -2,24 +2,24 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Calibration; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Model.Pfa; -using Newtonsoft.Json.Linq; using System; using System.Collections; using System.Collections.Generic; +using System.Collections.Immutable; using System.IO; using System.Linq; -using Float = System.Single; +using Microsoft.ML; +using Microsoft.ML.Calibrator; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Model.Pfa; +using Newtonsoft.Json.Linq; [assembly: LoadableClass(PlattCalibratorTrainer.Summary, typeof(PlattCalibratorTrainer), null, typeof(SignatureCalibrator), PlattCalibratorTrainer.UserName, @@ -79,7 +79,7 @@ [assembly: EntryPointModule(typeof(PavCalibratorTrainerFactory))] [assembly: EntryPointModule(typeof(PlattCalibratorTrainerFactory))] -namespace Microsoft.ML.Runtime.Internal.Calibration +namespace Microsoft.ML.Internal.Calibration { /// /// Signature for the loaders of calibrators. 
@@ -100,34 +100,24 @@ public interface ICalibratorTrainer /// Training calibrators: provide the output and the class label /// True if it needs more examples, false otherwise - bool ProcessTrainingExample(Float output, bool labelIs1, Float weight); + bool ProcessTrainingExample(float output, bool labelIs1, float weight); /// Finish up training after seeing all examples ICalibrator FinishTraining(IChannel ch); } - /// - /// An interface for probability calibrators. - /// - public interface ICalibrator - { - /// Given a classifier output, produce the probability - Float PredictProbability(Float output); - - /// Get the summary of current calibrator settings - string GetSummary(); - } - /// /// An interface for predictors that take care of their own calibration given an input data view. /// - public interface ISelfCalibratingPredictor + [BestFriend] + internal interface ISelfCalibratingPredictor { IPredictor Calibrate(IChannel ch, IDataView data, ICalibratorTrainer caliTrainer, int maxRows); } + [BestFriend] public abstract class CalibratedPredictorBase : - IDistPredictorProducing, + IDistPredictorProducing, ICanSaveInIniFormat, ICanSaveInTextFormat, ICanSaveInSourceCode, @@ -136,11 +126,11 @@ public abstract class CalibratedPredictorBase : { protected readonly IHost Host; - public IPredictorProducing SubPredictor { get; } + public IPredictorProducing SubPredictor { get; } public ICalibrator Calibrator { get; } public PredictionKind PredictionKind => SubPredictor.PredictionKind; - protected CalibratedPredictorBase(IHostEnvironment env, string name, IPredictorProducing predictor, ICalibrator calibrator) + protected CalibratedPredictorBase(IHostEnvironment env, string name, IPredictorProducing predictor, ICalibrator calibrator) { Contracts.CheckValue(env, nameof(env)); env.CheckNonWhiteSpace(name, nameof(name)); @@ -152,14 +142,14 @@ protected CalibratedPredictorBase(IHostEnvironment env, string name, IPredictorP Calibrator = calibrator; } - public void 
SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator = null) + void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator) { Host.Check(calibrator == null, "Too many calibrators."); var saver = SubPredictor as ICanSaveInIniFormat; saver?.SaveAsIni(writer, schema, Calibrator); } - public void SaveAsText(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { // REVIEW: What about the calibrator? var saver = SubPredictor as ICanSaveInTextFormat; @@ -167,7 +157,7 @@ public void SaveAsText(TextWriter writer, RoleMappedSchema schema) saver.SaveAsText(writer, schema); } - public void SaveAsCode(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema) { // REVIEW: What about the calibrator? var saver = SubPredictor as ICanSaveInSourceCode; @@ -175,7 +165,7 @@ public void SaveAsCode(TextWriter writer, RoleMappedSchema schema) saver.SaveAsCode(writer, schema); } - public void SaveSummary(TextWriter writer, RoleMappedSchema schema) + void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) { // REVIEW: What about the calibrator? var saver = SubPredictor as ICanSaveSummary; @@ -184,7 +174,7 @@ public void SaveSummary(TextWriter writer, RoleMappedSchema schema) } /// - public IList> GetSummaryInKeyValuePairs(RoleMappedSchema schema) + IList> ICanGetSummaryInKeyValuePairs.GetSummaryInKeyValuePairs(RoleMappedSchema schema) { // REVIEW: What about the calibrator? 
var saver = SubPredictor as ICanGetSummaryInKeyValuePairs; @@ -200,10 +190,10 @@ protected void SaveCore(ModelSaveContext ctx) ctx.SaveModel(Calibrator, @"Calibrator"); } - protected static IPredictorProducing GetPredictor(IHostEnvironment env, ModelLoadContext ctx) + protected static IPredictorProducing GetPredictor(IHostEnvironment env, ModelLoadContext ctx) { - IPredictorProducing predictor; - ctx.LoadModel, SignatureLoadModel>(env, out predictor, ModelFileUtils.DirPredictor); + IPredictorProducing predictor; + ctx.LoadModel, SignatureLoadModel>(env, out predictor, ModelFileUtils.DirPredictor); return predictor; } @@ -215,42 +205,45 @@ protected static ICalibrator GetCalibrator(IHostEnvironment env, ModelLoadContex } } - public abstract class ValueMapperCalibratedPredictorBase : CalibratedPredictorBase, IValueMapperDist, IFeatureContributionMapper, + public abstract class ValueMapperCalibratedPredictorBase : CalibratedPredictorBase, IValueMapperDist, IFeatureContributionMapper, ICalculateFeatureContribution, IDistCanSavePfa, IDistCanSaveOnnx { private readonly IValueMapper _mapper; private readonly IFeatureContributionMapper _featureContribution; - public ColumnType InputType => _mapper.InputType; - public ColumnType OutputType => _mapper.OutputType; - public ColumnType DistType => NumberType.Float; + ColumnType IValueMapper.InputType => _mapper.InputType; + ColumnType IValueMapper.OutputType => _mapper.OutputType; + ColumnType IValueMapperDist.DistType => NumberType.Float; bool ICanSavePfa.CanSavePfa => (_mapper as ICanSavePfa)?.CanSavePfa == true; + + public FeatureContributionCalculator FeatureContributionClaculator => new FeatureContributionCalculator(this); + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (_mapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; - protected ValueMapperCalibratedPredictorBase(IHostEnvironment env, string name, IPredictorProducing predictor, ICalibrator calibrator) + protected ValueMapperCalibratedPredictorBase(IHostEnvironment 
env, string name, IPredictorProducing predictor, ICalibrator calibrator) : base(env, name, predictor, calibrator) { Contracts.AssertValue(Host); _mapper = SubPredictor as IValueMapper; Host.Check(_mapper != null, "The predictor does not implement IValueMapper"); - Host.Check(_mapper.OutputType == NumberType.Float, "The output type of the predictor is expected to be Float"); + Host.Check(_mapper.OutputType == NumberType.Float, "The output type of the predictor is expected to be float"); _featureContribution = predictor as IFeatureContributionMapper; } - public ValueMapper GetMapper() + ValueMapper IValueMapper.GetMapper() { return _mapper.GetMapper(); } - public ValueMapper GetMapper() + ValueMapper IValueMapperDist.GetMapper() { - Host.Check(typeof(TOut) == typeof(Float)); - Host.Check(typeof(TDist) == typeof(Float)); - var map = GetMapper(); - ValueMapper del = - (in TIn src, ref Float score, ref Float prob) => + Host.Check(typeof(TOut) == typeof(float)); + Host.Check(typeof(TDist) == typeof(float)); + var map = ((IValueMapper)this).GetMapper(); + ValueMapper del = + (in TIn src, ref float score, ref float prob) => { map(in src, ref score); prob = Calibrator.PredictProbability(score); @@ -258,7 +251,7 @@ public ValueMapper GetMapper() return (ValueMapper)(Delegate)del; } - public ValueMapper> GetFeatureContributionMapper(int top, int bottom, bool normalize) + ValueMapper> IFeatureContributionMapper.GetFeatureContributionMapper(int top, int bottom, bool normalize) { // REVIEW: checking this a bit too late. 
Host.Check(_featureContribution != null, "Predictor does not implement IFeatureContributionMapper"); @@ -318,15 +311,17 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string } } - public sealed class CalibratedPredictor : ValueMapperCalibratedPredictorBase, ICanSaveModel + + [BestFriend] + internal sealed class CalibratedPredictor : ValueMapperCalibratedPredictorBase, ICanSaveModel { - public CalibratedPredictor(IHostEnvironment env, IPredictorProducing predictor, ICalibrator calibrator) + internal CalibratedPredictor(IHostEnvironment env, IPredictorProducing predictor, ICalibrator calibrator) : base(env, RegistrationName, predictor, calibrator) { } - public const string LoaderSignature = "CaliPredExec"; - public const string RegistrationName = "CalibratedPredictor"; + internal const string LoaderSignature = "CaliPredExec"; + internal const string RegistrationName = "CalibratedPredictor"; private static VersionInfo GetVersionInfo() { @@ -354,7 +349,7 @@ private CalibratedPredictor(IHostEnvironment env, ModelLoadContext ctx) { } - public static CalibratedPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static CalibratedPredictor Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); // Can load either the old "bulk" model or standard "cali". The two formats are identical. 
@@ -374,22 +369,23 @@ public void Save(ModelSaveContext ctx) } } - public sealed class FeatureWeightsCalibratedPredictor : + [BestFriend] + internal sealed class FeatureWeightsCalibratedPredictor : ValueMapperCalibratedPredictorBase, - IPredictorWithFeatureWeights, + IPredictorWithFeatureWeights, ICanSaveModel { - private readonly IPredictorWithFeatureWeights _featureWeights; + private readonly IPredictorWithFeatureWeights _featureWeights; - public FeatureWeightsCalibratedPredictor(IHostEnvironment env, IPredictorWithFeatureWeights predictor, + internal FeatureWeightsCalibratedPredictor(IHostEnvironment env, IPredictorWithFeatureWeights predictor, ICalibrator calibrator) : base(env, RegistrationName, predictor, calibrator) { _featureWeights = predictor; } - public const string LoaderSignature = "FeatWCaliPredExec"; - public const string RegistrationName = "FeatureWeightsCalibratedPredictor"; + internal const string LoaderSignature = "FeatWCaliPredExec"; + internal const string RegistrationName = "FeatureWeightsCalibratedPredictor"; private static VersionInfo GetVersionInfo() { @@ -405,11 +401,11 @@ private static VersionInfo GetVersionInfo() private FeatureWeightsCalibratedPredictor(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, GetPredictor(env, ctx), GetCalibrator(env, ctx)) { - Host.Check(SubPredictor is IPredictorWithFeatureWeights, "Predictor does not implement " + nameof(IPredictorWithFeatureWeights)); - _featureWeights = (IPredictorWithFeatureWeights)SubPredictor; + Host.Check(SubPredictor is IPredictorWithFeatureWeights, "Predictor does not implement " + nameof(IPredictorWithFeatureWeights)); + _featureWeights = (IPredictorWithFeatureWeights)SubPredictor; } - public static FeatureWeightsCalibratedPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static FeatureWeightsCalibratedPredictor Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); 
@@ -425,7 +421,7 @@ public void Save(ModelSaveContext ctx) SaveCore(ctx); } - public void GetFeatureWeights(ref VBuffer weights) + public void GetFeatureWeights(ref VBuffer weights) { _featureWeights.GetFeatureWeights(ref weights); } @@ -437,22 +433,22 @@ public void GetFeatureWeights(ref VBuffer weights) /// public sealed class ParameterMixingCalibratedPredictor : ValueMapperCalibratedPredictorBase, - IParameterMixer, - IPredictorWithFeatureWeights, + IParameterMixer, + IPredictorWithFeatureWeights, ICanSaveModel { - private readonly IPredictorWithFeatureWeights _featureWeights; + private readonly IPredictorWithFeatureWeights _featureWeights; - public ParameterMixingCalibratedPredictor(IHostEnvironment env, IPredictorWithFeatureWeights predictor, ICalibrator calibrator) + internal ParameterMixingCalibratedPredictor(IHostEnvironment env, IPredictorWithFeatureWeights predictor, ICalibrator calibrator) : base(env, RegistrationName, predictor, calibrator) { - Host.Check(predictor is IParameterMixer, "Predictor does not implement " + nameof(IParameterMixer)); + Host.Check(predictor is IParameterMixer, "Predictor does not implement " + nameof(IParameterMixer)); Host.Check(calibrator is IParameterMixer, "Calibrator does not implement " + nameof(IParameterMixer)); _featureWeights = predictor; } - public const string LoaderSignature = "PMixCaliPredExec"; - public const string RegistrationName = "ParameterMixingCalibratedPredictor"; + internal const string LoaderSignature = "PMixCaliPredExec"; + internal const string RegistrationName = "ParameterMixingCalibratedPredictor"; private static VersionInfo GetVersionInfo() { @@ -468,12 +464,12 @@ private static VersionInfo GetVersionInfo() private ParameterMixingCalibratedPredictor(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, GetPredictor(env, ctx), GetCalibrator(env, ctx)) { - Host.Check(SubPredictor is IParameterMixer, "Predictor does not implement " + nameof(IParameterMixer)); - 
Host.Check(SubPredictor is IPredictorWithFeatureWeights, "Predictor does not implement " + nameof(IPredictorWithFeatureWeights)); - _featureWeights = (IPredictorWithFeatureWeights)SubPredictor; + Host.Check(SubPredictor is IParameterMixer, "Predictor does not implement " + nameof(IParameterMixer)); + Host.Check(SubPredictor is IPredictorWithFeatureWeights, "Predictor does not implement " + nameof(IPredictorWithFeatureWeights)); + _featureWeights = (IPredictorWithFeatureWeights)SubPredictor; } - public static ParameterMixingCalibratedPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static ParameterMixingCalibratedPredictor Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); @@ -489,19 +485,19 @@ public void Save(ModelSaveContext ctx) SaveCore(ctx); } - public void GetFeatureWeights(ref VBuffer weights) + public void GetFeatureWeights(ref VBuffer weights) { _featureWeights.GetFeatureWeights(ref weights); } - public IParameterMixer CombineParameters(IList> models) + IParameterMixer IParameterMixer.CombineParameters(IList> models) { var predictors = models.Select( m => { var model = m as ParameterMixingCalibratedPredictor; Contracts.Assert(model != null); - return (IParameterMixer)model.SubPredictor; + return (IParameterMixer)model.SubPredictor; }).ToArray(); var calibrators = models.Select( m => @@ -512,11 +508,12 @@ public IParameterMixer CombineParameters(IList> mo }).ToArray(); var combinedPredictor = predictors[0].CombineParameters(predictors); var combinedCalibrator = calibrators[0].CombineParameters(calibrators); - return new ParameterMixingCalibratedPredictor(Host, (IPredictorWithFeatureWeights)combinedPredictor, (ICalibrator)combinedCalibrator); + return new ParameterMixingCalibratedPredictor(Host, (IPredictorWithFeatureWeights)combinedPredictor, (ICalibrator)combinedCalibrator); } } - public sealed class SchemaBindableCalibratedPredictor : CalibratedPredictorBase, 
ISchemaBindableMapper, ICanSaveModel, + [BestFriend] + internal sealed class SchemaBindableCalibratedPredictor : CalibratedPredictorBase, ISchemaBindableMapper, ICanSaveModel, IBindableCanSavePfa, IBindableCanSaveOnnx, IFeatureContributionMapper { private sealed class Bound : ISchemaBoundRowMapper @@ -528,7 +525,7 @@ private sealed class Bound : ISchemaBoundRowMapper public ISchemaBindableMapper Bindable => _parent; public RoleMappedSchema InputRoleMappedSchema => _predictor.InputRoleMappedSchema; public Schema InputSchema => _predictor.InputSchema; - public Schema Schema { get; } + public Schema OutputSchema { get; } public Bound(IHostEnvironment env, SchemaBindableCalibratedPredictor parent, RoleMappedSchema schema) { @@ -537,16 +534,16 @@ public Bound(IHostEnvironment env, SchemaBindableCalibratedPredictor parent, Rol _parent = parent; _predictor = _parent._bindable.Bind(env, schema) as ISchemaBoundRowMapper; env.Check(_predictor != null, "Predictor is not a row-to-row mapper"); - if (!_predictor.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out _scoreCol)) + if (!_predictor.OutputSchema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out _scoreCol)) throw env.Except("Predictor does not output a score"); - var scoreType = _predictor.Schema.GetColumnType(_scoreCol); + var scoreType = _predictor.OutputSchema[_scoreCol].Type; env.Check(!scoreType.IsVector && scoreType.IsNumber); - Schema = Schema.Create(new BinaryClassifierSchema()); + OutputSchema = Schema.Create(new BinaryClassifierSchema()); } public Func GetDependencies(Func predicate) { - for (int i = 0; i < Schema.ColumnCount; i++) + for (int i = 0; i < OutputSchema.Count; i++) { if (predicate(i)) return _predictor.GetDependencies(col => true); @@ -559,10 +556,10 @@ public Func GetDependencies(Func predicate) return _predictor.GetInputColumnRoles(); } - public IRow GetRow(IRow input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate) { Func 
predictorPredicate = col => false; - for (int i = 0; i < Schema.ColumnCount; i++) + for (int i = 0; i < OutputSchema.Count; i++) { if (predicate(i)) { @@ -570,26 +567,26 @@ public IRow GetRow(IRow input, Func predicate, out Action disposer) break; } } - var predictorRow = _predictor.GetRow(input, predictorPredicate, out disposer); - var getters = new Delegate[Schema.ColumnCount]; - for (int i = 0; i < Schema.ColumnCount - 1; i++) + var predictorRow = _predictor.GetRow(input, predictorPredicate); + var getters = new Delegate[OutputSchema.Count]; + for (int i = 0; i < OutputSchema.Count - 1; i++) { - var type = predictorRow.Schema.GetColumnType(i); + var type = predictorRow.Schema[i].Type; if (!predicate(i)) continue; getters[i] = Utils.MarshalInvoke(GetPredictorGetter, type.RawType, predictorRow, i); } - if (predicate(Schema.ColumnCount - 1)) - getters[Schema.ColumnCount - 1] = GetProbGetter(predictorRow); - return new SimpleRow(Schema, predictorRow, getters); + if (predicate(OutputSchema.Count - 1)) + getters[OutputSchema.Count - 1] = GetProbGetter(predictorRow); + return new SimpleRow(OutputSchema, predictorRow, getters); } - private Delegate GetPredictorGetter(IRow input, int col) + private Delegate GetPredictorGetter(Row input, int col) { return input.GetGetter(col); } - private Delegate GetProbGetter(IRow input) + private Delegate GetProbGetter(Row input) { var scoreGetter = RowCursorUtils.GetGetterAs(NumberType.R4, input, _scoreCol); ValueGetter probGetter = @@ -606,7 +603,7 @@ private Delegate GetProbGetter(IRow input) private readonly ISchemaBindableMapper _bindable; private readonly IFeatureContributionMapper _featureContribution; - public const string LoaderSignature = "SchemaBindableCalibrated"; + internal const string LoaderSignature = "SchemaBindableCalibrated"; private static VersionInfo GetVersionInfo() { @@ -627,7 +624,7 @@ private static VersionInfo GetVersionInfo() bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (_bindable as 
ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; - public SchemaBindableCalibratedPredictor(IHostEnvironment env, IPredictorProducing predictor, ICalibrator calibrator) + internal SchemaBindableCalibratedPredictor(IHostEnvironment env, IPredictorProducing predictor, ICalibrator calibrator) : base(env, LoaderSignature, predictor, calibrator) { _bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubPredictor); @@ -641,7 +638,7 @@ private SchemaBindableCalibratedPredictor(IHostEnvironment env, ModelLoadContext _featureContribution = SubPredictor as IFeatureContributionMapper; } - public static SchemaBindableCalibratedPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static SchemaBindableCalibratedPredictor Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); @@ -682,7 +679,7 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) return new Bound(Host, this, schema); } - public ValueMapper> GetFeatureContributionMapper(int top, int bottom, bool normalize) + ValueMapper> IFeatureContributionMapper.GetFeatureContributionMapper(int top, int bottom, bool normalize) { // REVIEW: checking this a bit too late. 
Host.Check(_featureContribution != null, "Predictor does not implement " + nameof(IFeatureContributionMapper)); @@ -693,6 +690,9 @@ public ValueMapper> GetFeatureContributionMapper)) + if (!(predictor is IPredictorProducing)) { ch.Info("Not training a calibrator because the predictor does not implement IPredictorProducing."); return false; @@ -728,14 +728,14 @@ private static bool NeedCalibration(IHostEnvironment env, IChannel ch, ICalibrat var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor); var bound = bindable.Bind(env, schema); - var outputSchema = bound.Schema; + var outputSchema = bound.OutputSchema; int scoreCol; if (!outputSchema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreCol)) { ch.Info("Not training a calibrator because the predictor does not output a score column."); return false; } - var type = outputSchema.GetColumnType(scoreCol); + var type = outputSchema[scoreCol].Type; if (type != NumberType.Float) { ch.Info("Not training a calibrator because the predictor output is {0}, but expected to be {1}.", type, NumberType.R4); @@ -768,7 +768,7 @@ public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel if (!NeedCalibration(env, ch, calibrator, trainer, predictor, data.Schema)) return predictor; - return TrainCalibrator(env, ch, calibrator, maxRows, predictor, data); + return GetCalibratedPredictor(env, ch, calibrator, predictor, data, maxRows); } /// @@ -777,33 +777,50 @@ public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel /// The environment to use. /// The channel. /// The calibrator trainer. + /// The predictor that needs calibration. + /// The examples to used for calibrator training. /// The maximum rows to use for calibrator training. + /// The original predictor, if no calibration is needed, + /// or a metapredictor that wraps the original predictor and the newly trained calibrator. 
+ public static IPredictor GetCalibratedPredictor(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, + IPredictor predictor, RoleMappedData data, int maxRows = _maxCalibrationExamples) + { + var trainedCalibrator = TrainCalibrator(env, ch, caliTrainer, predictor, data, maxRows); + return CreateCalibratedPredictor(env, (IPredictorProducing)predictor, trainedCalibrator); + } + + /// + /// Trains a calibrator. + /// + /// The environment to use. + /// The channel. + /// The calibrator trainer. /// The predictor that needs calibration. /// The examples to used for calibrator training. + /// The maximum rows to use for calibrator training. /// The original predictor, if no calibration is needed, /// or a metapredictor that wraps the original predictor and the newly trained calibrator. - public static IPredictor TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, - int maxRows, IPredictor predictor, RoleMappedData data) + public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, IPredictor predictor, RoleMappedData data, int maxRows = _maxCalibrationExamples) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ch, nameof(ch)); ch.CheckValue(predictor, nameof(predictor)); ch.CheckValue(data, nameof(data)); - ch.CheckParam(data.Schema.Label != null, nameof(data), "data must have a Label column"); + ch.CheckParam(data.Schema.Label.HasValue, nameof(data), "data must have a Label column"); var scored = ScoreUtils.GetScorer(predictor, data, env, null); if (caliTrainer.NeedsTraining) { int labelCol; - if (!scored.Schema.TryGetColumnIndex(data.Schema.Label.Name, out labelCol)) + if (!scored.Schema.TryGetColumnIndex(data.Schema.Label.Value.Name, out labelCol)) throw ch.Except("No label column found"); int scoreCol; if (!scored.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreCol)) throw ch.Except("No score column found"); - int weightCol; - if 
(data.Schema.Weight == null || !scored.Schema.TryGetColumnIndex(data.Schema.Weight.Name, out weightCol)) - weightCol = -1; + int weightCol = -1; + if (data.Schema.Weight?.Name is string weightName && scored.Schema.GetColumnOrNull(weightName)?.Index is int weightIdx) + weightCol = weightIdx; ch.Info("Training calibrator."); using (var cursor = scored.GetRowCursor(col => col == labelCol || col == scoreCol || col == weightCol)) { @@ -835,11 +852,10 @@ public static IPredictor TrainCalibrator(IHostEnvironment env, IChannel ch, ICal } } } - var cali = caliTrainer.FinishTraining(ch); - return CreateCalibratedPredictor(env, (IPredictorProducing)predictor, cali); + return caliTrainer.FinishTraining(ch); } - public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironment env, IPredictorProducing predictor, ICalibrator cali) + public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironment env, IPredictorProducing predictor, ICalibrator cali) { Contracts.Assert(predictor != null); if (cali == null) @@ -852,8 +868,8 @@ public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironm predictor = p.SubPredictor; } // REVIEW: Split the requirement for IPredictorWithFeatureWeights into a different class. 
- var predWithFeatureScores = predictor as IPredictorWithFeatureWeights; - if (predWithFeatureScores != null && predictor is IParameterMixer && cali is IParameterMixer) + var predWithFeatureScores = predictor as IPredictorWithFeatureWeights; + if (predWithFeatureScores != null && predictor is IParameterMixer && cali is IParameterMixer) return new ParameterMixingCalibratedPredictor(env, predWithFeatureScores, cali); if (predictor is IValueMapper) return new CalibratedPredictor(env, predictor, cali); @@ -862,7 +878,7 @@ public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironm } [TlcModule.Component(Name = "NaiveCalibrator", FriendlyName = "Naive Calibrator", Alias = "Naive")] - public sealed class NaiveCalibratorTrainerFactory : ICalibratorTrainerFactory + internal sealed class NaiveCalibratorTrainerFactory : ICalibratorTrainerFactory { public ICalibratorTrainer CreateComponent(IHostEnvironment env) { @@ -870,40 +886,49 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env) } } + /// + /// Trains a by dividing the range of the outputs into equally sized bins. + /// The probability of belonging to a particular class, for example class 1, is the number of class 1 instances in the bin, divided by the total number + /// of instances in that bin. + /// public sealed class NaiveCalibratorTrainer : ICalibratorTrainer { private readonly IHost _host; - private List _cMargins; - private List _ncMargins; + private List _cMargins; + private List _ncMargins; - private int _numBins; - private Float _binSize; - private Float _min; - private Float _max; - private Float[] _binProbs; + public int NumBins; + public float BinSize; + public float Min; + public float Max; + public float[] BinProbs; // REVIEW: The others have user/load names of calibraTION, but this has calibratOR. 
- public const string UserName = "Naive Calibrator"; - public const string LoadName = "NaiveCalibrator"; + internal const string UserName = "Naive Calibrator"; + internal const string LoadName = "NaiveCalibrator"; internal const string Summary = "Naive calibrator divides the range of the outputs into equally sized bins. In each bin, " + "the probability of belonging to class 1 is the number of class 1 instances in the bin, divided by the total number " + "of instances in the bin."; + // REVIEW: does this need a ctor that initialized the parameters to given values? + /// + /// Initializes a new instance of . + /// public NaiveCalibratorTrainer(IHostEnvironment env) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(LoadName); - _cMargins = new List(); - _ncMargins = new List(); - _numBins = 200; - _min = Float.MaxValue; - _max = Float.MinValue; + _cMargins = new List(); + _ncMargins = new List(); + NumBins = 200; + Min = float.MaxValue; + Max = float.MinValue; } - public bool NeedsTraining => true; + bool ICalibratorTrainer.NeedsTraining => true; - public bool ProcessTrainingExample(Float output, bool labelIs1, Float weight) + public bool ProcessTrainingExample(float output, bool labelIs1, float weight) { //AP todo proper weighting here if (labelIs1) @@ -917,57 +942,57 @@ public bool ProcessTrainingExample(Float output, bool labelIs1, Float weight) return true; } - public ICalibrator FinishTraining(IChannel ch) + ICalibrator ICalibratorTrainer.FinishTraining(IChannel ch) { - Float[] cOutputs = _cMargins.ToArray(); + float[] cOutputs = _cMargins.ToArray(); ch.Check(cOutputs.Length > 0, "Calibrator trained on zero instances."); - Float minC = MathUtils.Min(cOutputs); - Float maxC = MathUtils.Max(cOutputs); + float minC = MathUtils.Min(cOutputs); + float maxC = MathUtils.Max(cOutputs); - Float[] ncOutputs = _ncMargins.ToArray(); - Float minNC = MathUtils.Min(ncOutputs); - Float maxNC = MathUtils.Max(ncOutputs); + float[] ncOutputs = _ncMargins.ToArray(); + 
float minNC = MathUtils.Min(ncOutputs); + float maxNC = MathUtils.Max(ncOutputs); - _min = (minC < minNC) ? minC : minNC; - _max = (maxC > maxNC) ? maxC : maxNC; - _binSize = (_max - _min) / _numBins; + Min = (minC < minNC) ? minC : minNC; + Max = (maxC > maxNC) ? maxC : maxNC; + BinSize = (Max - Min) / NumBins; - Float[] cBins = new Float[_numBins]; - Float[] ncBins = new Float[_numBins]; + float[] cBins = new float[NumBins]; + float[] ncBins = new float[NumBins]; - foreach (Float xi in cOutputs) + foreach (float xi in cOutputs) { - int binIdx = NaiveCalibrator.GetBinIdx(xi, _min, _binSize, _numBins); + int binIdx = NaiveCalibrator.GetBinIdx(xi, Min, BinSize, NumBins); cBins[binIdx]++; } - foreach (Float xi in ncOutputs) + foreach (float xi in ncOutputs) { - int binIdx = NaiveCalibrator.GetBinIdx(xi, _min, _binSize, _numBins); + int binIdx = NaiveCalibrator.GetBinIdx(xi, Min, BinSize, NumBins); ncBins[binIdx]++; } - _binProbs = new Float[_numBins]; - for (int i = 0; i < _numBins; i++) + BinProbs = new float[NumBins]; + for (int i = 0; i < NumBins; i++) { if (cBins[i] + ncBins[i] == 0) - _binProbs[i] = 0; + BinProbs[i] = 0; else - _binProbs[i] = cBins[i] / (cBins[i] + ncBins[i]); + BinProbs[i] = cBins[i] / (cBins[i] + ncBins[i]); } - return new NaiveCalibrator(_host, _min, _binSize, _binProbs); + return new NaiveCalibrator(_host, Min, BinSize, BinProbs); } } /// - /// The naive binning-based calibrator + /// The naive binning-based calibrator. 
/// public sealed class NaiveCalibrator : ICalibrator, ICanSaveInBinaryFormat { - public const string LoaderSignature = "NaiveCaliExec"; - public const string RegistrationName = "NaiveCalibrator"; + internal const string LoaderSignature = "NaiveCaliExec"; + internal const string RegistrationName = "NaiveCalibrator"; private static VersionInfo GetVersionInfo() { @@ -981,18 +1006,28 @@ private static VersionInfo GetVersionInfo() } private readonly IHost _host; - private readonly Float _binSize; - private readonly Float _min; - private readonly Float[] _binProbs; - /// Create a default calibrator - public NaiveCalibrator(IHostEnvironment env, Float min, Float binSize, Float[] binProbs) + /// The bin size. + public readonly float BinSize; + + /// The minimum value in the first bin. + public readonly float Min; + + /// The value of probability in each bin. + public readonly float[] BinProbs; + + /// Initializes a new instance of . + /// The to use. + /// The minimum value in the first bin. + /// The values of the probability in each bin. + /// The bin size. 
+ public NaiveCalibrator(IHostEnvironment env, float min, float binSize, float[] binProbs) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(RegistrationName); - _min = min; - _binSize = binSize; - _binProbs = binProbs; + Min = min; + BinSize = binSize; + BinProbs = binProbs; } private NaiveCalibrator(IHostEnvironment env, ModelLoadContext ctx) @@ -1002,26 +1037,26 @@ private NaiveCalibrator(IHostEnvironment env, ModelLoadContext ctx) _host.CheckValue(ctx, nameof(ctx)); // *** Binary format *** - // int: sizeof(Float) - // Float: bin size - // Float: minimum value of first bin + // int: sizeof(float) + // float: bin size + // float: minimum value of first bin // int: number of bins - // Float[]: probability in each bin + // float[]: probability in each bin int cbFloat = ctx.Reader.ReadInt32(); - _host.CheckDecode(cbFloat == sizeof(Float)); + _host.CheckDecode(cbFloat == sizeof(float)); - _binSize = ctx.Reader.ReadFloat(); - _host.CheckDecode(0 < _binSize && _binSize < Float.PositiveInfinity); + BinSize = ctx.Reader.ReadFloat(); + _host.CheckDecode(0 < BinSize && BinSize < float.PositiveInfinity); - _min = ctx.Reader.ReadFloat(); - _host.CheckDecode(FloatUtils.IsFinite(_min)); + Min = ctx.Reader.ReadFloat(); + _host.CheckDecode(FloatUtils.IsFinite(Min)); - _binProbs = ctx.Reader.ReadFloatArray(); - _host.CheckDecode(Utils.Size(_binProbs) > 0); - _host.CheckDecode(_binProbs.All(x => (0 <= x && x <= 1))); + BinProbs = ctx.Reader.ReadFloatArray(); + _host.CheckDecode(Utils.Size(BinProbs) > 0); + _host.CheckDecode(BinProbs.All(x => (0 <= x && x <= 1))); } - public static NaiveCalibrator Create(IHostEnvironment env, ModelLoadContext ctx) + private static NaiveCalibrator Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); @@ -1041,30 +1076,30 @@ private void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** - // int: sizeof(Float) - // Float: 
bin size - // Float: minimum value of first bin + // int: sizeof(float) + // float: bin size + // float: minimum value of first bin // int: number of bins - // Float[]: probability in each bin - ctx.Writer.Write(sizeof(Float)); - ctx.Writer.Write(_binSize); - ctx.Writer.Write(_min); - ctx.Writer.WriteSingleArray(_binProbs); + // float[]: probability in each bin + ctx.Writer.Write(sizeof(float)); + ctx.Writer.Write(BinSize); + ctx.Writer.Write(Min); + ctx.Writer.WriteSingleArray(BinProbs); } /// /// Given a classifier output, produce the probability /// - public Float PredictProbability(Float output) + public float PredictProbability(float output) { - if (Float.IsNaN(output)) + if (float.IsNaN(output)) return output; - int binIdx = GetBinIdx(output, _min, _binSize, _binProbs.Length); - return _binProbs[binIdx]; + int binIdx = GetBinIdx(output, Min, BinSize, BinProbs.Length); + return BinProbs[binIdx]; } // get the bin for a given output - internal static int GetBinIdx(Float output, Float min, Float binSize, int numBins) + internal static int GetBinIdx(float output, float min, float binSize, int numBins) { int binIdx = (int)((output - min) / binSize); if (binIdx >= numBins) @@ -1077,10 +1112,13 @@ internal static int GetBinIdx(Float output, Float min, Float binSize, int numBin /// Get the summary of current calibrator settings public string GetSummary() { - return string.Format("Naive Calibrator has {0} bins, starting at {1}, with bin size of {2}", _binProbs.Length, _min, _binSize); + return string.Format("Naive Calibrator has {0} bins, starting at {1}, with bin size of {2}", BinProbs.Length, Min, BinSize); } } + /// + /// Base class for calibrator trainers. 
+ /// public abstract class CalibratorTrainerBase : ICalibratorTrainer { protected readonly IHost Host; @@ -1096,12 +1134,12 @@ protected CalibratorTrainerBase(IHostEnvironment env, string name) MaxNumSamples = DefaultMaxNumSamples; } - public bool NeedsTraining { get { return true; } } + bool ICalibratorTrainer.NeedsTraining => true; /// /// Training calibrators: provide the classifier output and the class label /// - public bool ProcessTrainingExample(Float output, bool labelIs1, Float weight) + bool ICalibratorTrainer.ProcessTrainingExample(float output, bool labelIs1, float weight) { if (Data == null) Data = new CalibrationDataStore(MaxNumSamples); @@ -1109,17 +1147,20 @@ public bool ProcessTrainingExample(Float output, bool labelIs1, Float weight) return true; } - public ICalibrator FinishTraining(IChannel ch) + ICalibrator ICalibratorTrainer.FinishTraining(IChannel ch) { ch.Check(Data != null, "Calibrator trained on zero instances."); - return CreateCalibrator(ch); + var calibrator = CreateCalibrator(ch); + Data = null; + return calibrator; } public abstract ICalibrator CreateCalibrator(IChannel ch); } [TlcModule.Component(Name = "PlattCalibrator", FriendlyName = "Platt Calibrator", Aliases = new[] { "Platt", "Sigmoid" }, Desc = "Platt calibration.")] - public sealed class PlattCalibratorTrainerFactory : ICalibratorTrainerFactory + [BestFriend] + internal sealed class PlattCalibratorTrainerFactory : ICalibratorTrainerFactory { public ICalibratorTrainer CreateComponent(IHostEnvironment env) { @@ -1129,9 +1170,6 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env) public sealed class PlattCalibratorTrainer : CalibratorTrainerBase { - private Double _paramA; - private Double _paramB; - internal const string UserName = "Sigmoid Calibration"; internal const string LoadName = "PlattCalibration"; internal const string Summary = "This model was introduced by Platt in the paper Probabilistic Outputs for Support Vector Machines " @@ -1140,13 +1178,13 @@ 
public sealed class PlattCalibratorTrainer : CalibratorTrainerBase public PlattCalibratorTrainer(IHostEnvironment env) : base(env, LoadName) { - } public override ICalibrator CreateCalibrator(IChannel ch) { - _paramA = 0; - _paramB = 0; + Double slope = 0; + Double offset = 0; + Double prior0 = 0; Double prior1 = 0; long n = 0; @@ -1160,12 +1198,12 @@ public override ICalibrator CreateCalibrator(IChannel ch) n++; } if (n == 0) - return new PlattCalibrator(Host, _paramA, _paramB); + return new PlattCalibrator(Host, slope, offset); - _paramA = 0; + slope = 0; // Initialize B to be the marginal probability of class // smoothed i.e. P(+ | x) = (N+ + 1) / (N + 2) - _paramB = Math.Log((prior0 + 1) / (prior1 + 1)); + offset = Math.Log((prior0 + 1) / (prior1 + 1)); // OK. We're going to maximize the likelihood of the output by // minimizing the cross-entropy of the output. Here's a @@ -1176,8 +1214,8 @@ public override ICalibrator CreateCalibrator(IChannel ch) Double lambda = 0.001; Double olderr = Double.MaxValue / 2; // array to store current estimate of probability of training points - Float[] pp = new Float[n]; - Float defValue = (Float)((prior1 + 1) / (prior0 + prior1 + 2)); + float[] pp = new float[n]; + float defValue = (float)((prior1 + 1) / (prior0 + prior1 + 2)); for (int i = 0; i < n; i++) pp[i] = defValue; @@ -1219,8 +1257,8 @@ public override ICalibrator CreateCalibrator(IChannel ch) break; } Double err = 0; - Double oldA = _paramA; - Double oldB = _paramB; + Double oldA = slope; + Double oldB = offset; // Loop until you get a increase in the goodness of fit for (; ; ) { @@ -1232,8 +1270,8 @@ public override ICalibrator CreateCalibrator(IChannel ch) continue; } // This is the Newton-Raphson step (with lambda as stabilizer) - _paramA = oldA + ((b + lambda) * d - c * e) / det; - _paramB = oldB + ((a + lambda) * e - c * d) / det; + slope = oldA + ((b + lambda) * d - c * e) / det; + offset = oldB + ((a + lambda) * e - c * d) / det; // Now, compute goodness of fit 
err = 0; @@ -1241,7 +1279,7 @@ public override ICalibrator CreateCalibrator(IChannel ch) foreach (var d_i in Data) { var y = d_i.Target ? d_i.Score : -d_i.Score; - var p = PlattCalibrator.PredictProbability(d_i.Score, _paramA, _paramB); + var p = PlattCalibrator.PredictProbability(d_i.Score, slope, offset); var t = d_i.Target ? hiTarget : loTarget; var weight = d_i.Weight; pp[i] = p; @@ -1285,7 +1323,7 @@ public override ICalibrator CreateCalibrator(IChannel ch) break; } - return new PlattCalibrator(Host, _paramA, _paramB); + return new PlattCalibrator(Host, slope, offset); } } @@ -1306,15 +1344,15 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env) } } - public const string UserName = "Fixed Sigmoid Calibration"; - public const string LoadName = "FixedPlattCalibration"; + internal const string UserName = "Fixed Sigmoid Calibration"; + internal const string LoadName = "FixedPlattCalibration"; internal const string Summary = "Sigmoid calibrator with configurable slope and offset."; private readonly IHost _host; private readonly Double _slope; private readonly Double _offset; - public FixedPlattCalibratorTrainer(IHostEnvironment env, Arguments args) + internal FixedPlattCalibratorTrainer(IHostEnvironment env, Arguments args) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(LoadName); @@ -1322,21 +1360,19 @@ public FixedPlattCalibratorTrainer(IHostEnvironment env, Arguments args) _offset = args.Offset; } - public bool NeedsTraining => false; + bool ICalibratorTrainer.NeedsTraining => false; - public bool ProcessTrainingExample(Float output, bool labelIs1, Float weight) - => false; + bool ICalibratorTrainer.ProcessTrainingExample(float output, bool labelIs1, float weight) => false; - public ICalibrator FinishTraining(IChannel ch) - { - return new PlattCalibrator(_host, _slope, _offset); - } + ICalibrator ICalibratorTrainer.FinishTraining(IChannel ch) => new PlattCalibrator(_host, _slope, _offset); } + /// The Platt calibrator calculates the 
probability following: + /// P(x) = 1 / (1 + exp(- * x + ) . public sealed class PlattCalibrator : ICalibrator, IParameterMixer, ICanSaveModel, ISingleCanSavePfa, ISingleCanSaveOnnx { - public const string LoaderSignature = "PlattCaliExec"; - public const string RegistrationName = "PlattCalibrator"; + internal const string LoaderSignature = "PlattCaliExec"; + internal const string RegistrationName = "PlattCalibrator"; private static VersionInfo GetVersionInfo() { return new VersionInfo( @@ -1350,17 +1386,20 @@ private static VersionInfo GetVersionInfo() private readonly IHost _host; - public Double ParamA { get; } - public Double ParamB { get; } + public Double Slope { get; } + public Double Offset { get; } bool ICanSavePfa.CanSavePfa => true; bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; - public PlattCalibrator(IHostEnvironment env, Double paramA, Double paramB) + /// + /// Initializes a new instance of . + /// + public PlattCalibrator(IHostEnvironment env, Double slope, Double offset) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(RegistrationName); - ParamA = paramA; - ParamB = paramB; + Slope = slope; + Offset = offset; } private PlattCalibrator(IHostEnvironment env, ModelLoadContext ctx) @@ -1372,14 +1411,14 @@ private PlattCalibrator(IHostEnvironment env, ModelLoadContext ctx) // *** Binary format *** // Double: A // Double: B - ParamA = ctx.Reader.ReadDouble(); - _host.CheckDecode(FloatUtils.IsFinite(ParamA)); + Slope = ctx.Reader.ReadDouble(); + _host.CheckDecode(FloatUtils.IsFinite(Slope)); - ParamB = ctx.Reader.ReadDouble(); - _host.CheckDecode(FloatUtils.IsFinite(ParamB)); + Offset = ctx.Reader.ReadDouble(); + _host.CheckDecode(FloatUtils.IsFinite(Offset)); } - public static PlattCalibrator Create(IHostEnvironment env, ModelLoadContext ctx) + private static PlattCalibrator Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); @@ -1403,8 +1442,8 @@ private 
void SaveCore(ModelSaveContext ctx) // *** Binary format *** // Double: A // Double: B - ctx.Writer.Write(ParamA); - ctx.Writer.Write(ParamB); + ctx.Writer.Write(Slope); + ctx.Writer.Write(Offset); if (ctx.InRepository) { @@ -1412,22 +1451,22 @@ private void SaveCore(ModelSaveContext ctx) { writer.WriteLine("Platt calibrator"); writer.WriteLine("P(y=1|x) = 1/1+exp(A*x + B)"); - writer.WriteLine("A={0:R}", (object)ParamA); - writer.WriteLine("B={0:R}", ParamB); + writer.WriteLine("A={0:R}", (object)Slope); + writer.WriteLine("B={0:R}", Offset); })); } } - public Float PredictProbability(Float output) + public float PredictProbability(float output) { - if (Float.IsNaN(output)) + if (float.IsNaN(output)) return output; - return PredictProbability(output, ParamA, ParamB); + return PredictProbability(output, Slope, Offset); } - public static Float PredictProbability(Float output, Double a, Double b) + public static float PredictProbability(float output, Double a, Double b) { - return (Float)(1 / (1 + Math.Exp(a * output + b))); + return (float)(1 / (1 + Math.Exp(a * output + b))); } JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) @@ -1436,7 +1475,7 @@ JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) _host.CheckValue(input, nameof(input)); return PfaUtils.Call("m.link.logit", - PfaUtils.Call("+", -ParamB, PfaUtils.Call("*", -ParamA, input))); + PfaUtils.Call("+", -Offset, PfaUtils.Call("*", -Slope, input))); } bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColumnNames, string featureColumnName) @@ -1449,7 +1488,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColu string linearOutput = ctx.AddIntermediateVariable(null, "linearOutput", true); var node = ctx.CreateNode(opType, new[] { scoreProbablityColumnNames[0] }, new[] { linearOutput }, ctx.GetNodeName(opType), ""); - node.AddAttribute("alpha", ParamA * -1); + node.AddAttribute("alpha", Slope * -1); 
node.AddAttribute("beta", -0.0000001); opType = "Sigmoid"; @@ -1461,10 +1500,10 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColu public string GetSummary() { - return string.Format("Platt calibrator parameters: A={0}, B={1}", ParamA, ParamB); + return string.Format("Platt calibrator parameters: A={0}, B={1}", Slope, Offset); } - public IParameterMixer CombineParameters(IList calibrators) + IParameterMixer IParameterMixer.CombineParameters(IList calibrators) { Double a = 0; Double b = 0; @@ -1472,8 +1511,8 @@ public IParameterMixer CombineParameters(IList calibrators) { PlattCalibrator cal = calibrator as PlattCalibrator; - a += cal.ParamA; - b += cal.ParamB; + a += cal.Slope; + b += cal.Offset; } PlattCalibrator newCal = new PlattCalibrator(_host, a / calibrators.Count, b / calibrators.Count); @@ -1482,7 +1521,7 @@ public IParameterMixer CombineParameters(IList calibrators) } [TlcModule.Component(Name = "PavCalibrator", FriendlyName = "PAV Calibrator", Alias = "Pav")] - public sealed class PavCalibratorTrainerFactory : ICalibratorTrainerFactory + internal sealed class PavCalibratorTrainerFactory : ICalibratorTrainerFactory { public ICalibratorTrainer CreateComponent(IHostEnvironment env) { @@ -1495,12 +1534,12 @@ public class PavCalibratorTrainer : CalibratorTrainerBase // a piece of the piecwise function private readonly struct Piece { - public readonly Float MinX; // end of interval. - public readonly Float MaxX; // beginning of interval. - public readonly Float Value; // value of function in interval. - public readonly Float N; // number of points/sum of weights of interval. + public readonly float MinX; // end of interval. + public readonly float MaxX; // beginning of interval. + public readonly float Value; // value of function in interval. + public readonly float N; // number of points/sum of weights of interval. 
- public Piece(Float minX, Float maxX, Float value, Float n) + public Piece(float minX, float maxX, float value, float n) { Contracts.Assert(minX <= maxX); // REVIEW: Can this fail due to more innocent imprecision issues? @@ -1513,10 +1552,11 @@ public Piece(Float minX, Float maxX, Float value, Float n) } } - public const string UserName = "PAV Calibration"; - public const string LoadName = "PAVCalibration"; + internal const string UserName = "PAV Calibration"; + internal const string LoadName = "PAVCalibration"; internal const string Summary = "Piecewise linear calibrator."; + // REVIEW: Do we need a ctor that initializes min, max, value, n? public PavCalibratorTrainer(IHostEnvironment env) : base(env, LoadName) { @@ -1533,7 +1573,7 @@ public override ICalibrator CreateCalibrator(IChannel ch) Piece curr = new Piece(di.Score, di.Score, di.Target ? 1 : 0, di.Weight); for (; stack.Count > 0 && ((top.MaxX >= curr.MinX) || curr.Value <= top.Value);) { - Float newN = top.N + curr.N; + float newN = top.N + curr.N; curr = new Piece(top.MinX, curr.MaxX, (top.Value * top.N + curr.Value * curr.N) / newN, newN); stack.Pop(); if (stack.Count > 0) @@ -1545,9 +1585,9 @@ public override ICalibrator CreateCalibrator(IChannel ch) } ch.Info("PAV calibrator: piecewise function approximation has {0} components.", stack.Count); - Float[] mins = new Float[stack.Count]; - Float[] maxes = new Float[stack.Count]; - Float[] values = new Float[stack.Count]; + float[] mins = new float[stack.Count]; + float[] maxes = new float[stack.Count]; + float[] values = new float[stack.Count]; for (int i = stack.Count - 1; stack.Count > 0; --i) { @@ -1557,7 +1597,7 @@ public override ICalibrator CreateCalibrator(IChannel ch) values[i] = top.Value; } - return new PavCalibrator(Host, mins, maxes, values); + return new PavCalibrator(Host, mins.ToImmutableArray(), maxes.ToImmutableArray(), values.ToImmutableArray()); } } @@ -1570,8 +1610,8 @@ public override ICalibrator CreateCalibrator(IChannel ch) /// 
public sealed class PavCalibrator : ICalibrator, ICanSaveInBinaryFormat { - public const string LoaderSignature = "PAVCaliExec"; - public const string RegistrationName = "PAVCalibrator"; + internal const string LoaderSignature = "PAVCaliExec"; + internal const string RegistrationName = "PAVCalibrator"; private static VersionInfo GetVersionInfo() { @@ -1585,31 +1625,38 @@ private static VersionInfo GetVersionInfo() } // Epsilon for 0-comparisons - private const Float Epsilon = (Float)1e-15; - private const Float MinToReturn = Epsilon; // max predicted is 1 - min; - private const Float MaxToReturn = 1 - Epsilon; // max predicted is 1 - min; + private const float Epsilon = (float)1e-15; + private const float MinToReturn = Epsilon; // max predicted is 1 - min; + private const float MaxToReturn = 1 - Epsilon; // max predicted is 1 - min; private readonly IHost _host; - private readonly Float[] _mins; - private readonly Float[] _maxes; - private readonly Float[] _values; + public readonly ImmutableArray Mins; + public readonly ImmutableArray Maxes; + public readonly ImmutableArray Values; - internal PavCalibrator(IHostEnvironment env, Float[] mins, Float[] maxes, Float[] values) + /// + /// Initializes a new instance of . + /// + /// The to use. + /// The minimum values for each piece. + /// The maximum values for each piece. + /// The actual values for each piece. 
+ public PavCalibrator(IHostEnvironment env, ImmutableArray mins, ImmutableArray maxes, ImmutableArray values) { Contracts.AssertValue(env); _host = env.Register(RegistrationName); - _host.AssertValue(mins); - _host.AssertValue(maxes); - _host.AssertValue(values); + _host.AssertNonEmpty(mins); + _host.AssertNonEmpty(maxes); + _host.AssertNonEmpty(values); _host.Assert(Utils.IsSorted(mins)); _host.Assert(Utils.IsSorted(maxes)); _host.Assert(Utils.IsSorted(values)); _host.Assert(values.Length == 0 || (0 <= values[0] && values[values.Length - 1] <= 1)); _host.Assert(mins.Zip(maxes, (min, max) => min <= max).All(x => x)); - _mins = mins; - _maxes = maxes; - _values = values; + Mins = mins; + Maxes = maxes; + Values = values; } private PavCalibrator(IHostEnvironment env, ModelLoadContext ctx) @@ -1619,40 +1666,44 @@ private PavCalibrator(IHostEnvironment env, ModelLoadContext ctx) _host.AssertValue(ctx); // *** Binary format *** - // int: sizeof(Float) + // int: sizeof(float) // int: number of pieces // for each piece: - // Float: MinX - // Float: MaxX - // Float: Value + // float: MinX + // float: MaxX + // float: Value int cbFloat = ctx.Reader.ReadInt32(); - _host.CheckDecode(cbFloat == sizeof(Float)); + _host.CheckDecode(cbFloat == sizeof(float)); int numPieces = ctx.Reader.ReadInt32(); _host.CheckDecode(numPieces >= 0); - _mins = new Float[numPieces]; - _maxes = new Float[numPieces]; - _values = new Float[numPieces]; - Float valuePrev = 0; - Float maxPrev = Float.NegativeInfinity; + var mins = new float[numPieces]; + var maxes = new float[numPieces]; + var values = new float[numPieces]; + float valuePrev = 0; + float maxPrev = float.NegativeInfinity; for (int i = 0; i < numPieces; ++i) { - Float minX = ctx.Reader.ReadFloat(); - Float maxX = ctx.Reader.ReadFloat(); - Float val = ctx.Reader.ReadFloat(); + float minX = ctx.Reader.ReadFloat(); + float maxX = ctx.Reader.ReadFloat(); + float val = ctx.Reader.ReadFloat(); _host.CheckDecode(minX <= maxX); 
_host.CheckDecode(minX > maxPrev); _host.CheckDecode(val > valuePrev || val == valuePrev && i == 0); valuePrev = val; maxPrev = maxX; - _mins[i] = minX; - _maxes[i] = maxX; - _values[i] = val; + mins[i] = minX; + maxes[i] = maxX; + values[i] = val; } + + Mins = mins.ToImmutableArray(); + Maxes = maxes.ToImmutableArray(); + Values = values.ToImmutableArray(); _host.CheckDecode(valuePrev <= 1); } - public static PavCalibrator Create(IHostEnvironment env, ModelLoadContext ctx) + private static PavCalibrator Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); @@ -1672,38 +1723,38 @@ private void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** - // int: sizeof(Float) + // int: sizeof(float) // int: number of pieces // for each piece: - // Float: MinX - // Float: MaxX - // Float: Value - ctx.Writer.Write(sizeof(Float)); - - _host.Assert(_mins.Length == _maxes.Length); - _host.Assert(_mins.Length == _values.Length); - ctx.Writer.Write(_mins.Length); - Float valuePrev = 0; - Float maxPrev = Float.NegativeInfinity; - for (int i = 0; i < _mins.Length; i++) + // float: MinX + // float: MaxX + // float: Value + ctx.Writer.Write(sizeof(float)); + + _host.Assert(Mins.Length == Maxes.Length); + _host.Assert(Mins.Length == Values.Length); + ctx.Writer.Write(Mins.Length); + float valuePrev = 0; + float maxPrev = float.NegativeInfinity; + for (int i = 0; i < Mins.Length; i++) { - _host.Assert(_mins[i] <= _maxes[i]); - _host.Assert(_mins[i] > maxPrev); - _host.Assert(_values[i] > valuePrev || _values[i] == valuePrev && i == 0); - valuePrev = _values[i]; - maxPrev = _maxes[i]; - ctx.Writer.Write(_mins[i]); - ctx.Writer.Write(_maxes[i]); - ctx.Writer.Write(_values[i]); + _host.Assert(Mins[i] <= Maxes[i]); + _host.Assert(Mins[i] > maxPrev); + _host.Assert(Values[i] > valuePrev || Values[i] == valuePrev && i == 0); + valuePrev = Values[i]; + maxPrev = Maxes[i]; + 
ctx.Writer.Write(Mins[i]); + ctx.Writer.Write(Maxes[i]); + ctx.Writer.Write(Values[i]); } _host.CheckDecode(valuePrev <= 1); } - public Float PredictProbability(Float output) + public float PredictProbability(float output) { - if (Float.IsNaN(output)) + if (float.IsNaN(output)) return output; - Float prob = FindValue(output); + float prob = FindValue(output); if (prob < MinToReturn) return MinToReturn; if (prob > MaxToReturn) @@ -1711,37 +1762,37 @@ public Float PredictProbability(Float output) return prob; } - private Float FindValue(Float score) + private float FindValue(float score) { - int p = _mins.Length; + int p = Mins.Length; if (p == 0) return 0; - if (score < _mins[0]) + if (score < Mins[0]) { - return _values[0]; + return Values[0]; // tail off to zero exponentially // return Math.Exp(-(piecewise[0].MinX-score)) * piecewise[0].Value; } - if (score > _maxes[p - 1]) + if (score > Maxes[p - 1]) { - return _values[p - 1]; + return Values[p - 1]; // tail off to one exponentially // return (1-Math.Exp(-(score - piecewise[P - 1].MaxX))) * (1 - piecewise[P - 1].Value) + piecewise[P - 1].Value; } - int pos = _maxes.FindIndexSorted(score); + int pos = Maxes.FindIndexSorted(score); _host.Assert(pos < p); // inside the piece, the value is constant - if (score >= _mins[pos]) - return _values[pos]; + if (score >= Mins[pos]) + return Values[pos]; // between pieces, interpolate - Float t = (score - _maxes[pos - 1]) / (_mins[pos] - _maxes[pos - 1]); - return _values[pos - 1] + t * (_values[pos] - _values[pos - 1]); + float t = (score - Maxes[pos - 1]) / (Mins[pos] - Maxes[pos - 1]); + return Values[pos - 1] + t * (Values[pos] - Values[pos - 1]); } public string GetSummary() { - return string.Format("PAV calibrator with {0} intervals", _mins.Length); + return string.Format("PAV calibrator with {0} intervals", Mins.Length); } } @@ -1752,11 +1803,11 @@ public readonly struct DataItem // The actual binary label of this example. 
public readonly bool Target; // The weight associated with this example. - public readonly Float Weight; + public readonly float Weight; // The output of the example. - public readonly Float Score; + public readonly float Score; - public DataItem(bool target, Float weight, Float score) + public DataItem(bool target, float weight, float score) { Target = target; Weight = weight; @@ -1811,10 +1862,10 @@ IEnumerator IEnumerable.GetEnumerator() return GetEnumerator(); } - public void AddToStore(Float score, bool isPositive, Float weight) + public void AddToStore(float score, bool isPositive, float weight) { // Can't calibrate NaN scores. - if (weight == 0 || Float.IsNaN(score)) + if (weight == 0 || float.IsNaN(score)) return; int index = _itemsSeen++; if (_itemsSeen <= _capacity) @@ -1829,13 +1880,13 @@ public void AddToStore(Float score, bool isPositive, Float weight) } } - public static class Calibrate + internal static class Calibrate { [TlcModule.EntryPointKind(typeof(CommonInputs.ICalibratorInput))] public abstract class CalibrateInputBase : TransformInputBase { [Argument(ArgumentType.Required, ShortName = "uncalibratedPredictorModel", HelpText = "The predictor to calibrate", SortOrder = 2)] - public IPredictorModel UncalibratedPredictorModel; + public PredictorModel UncalibratedPredictorModel; [Argument(ArgumentType.Required, ShortName = "maxRows", HelpText = "The maximum number of examples to train the calibrator on", SortOrder = 3)] [TlcModule.Range(Inf = 0, Max = int.MaxValue)] @@ -1908,8 +1959,8 @@ public static CommonOutputs.CalibratorOutput FixedPlatt(IHostEnvironment env, Fi /// The input object, containing the predictor, the data and an integer indicating the maximum number /// of examples to use for training the calibrator. /// The kind of calibrator to use. - /// A object, containing an . - public static TOut CalibratePredictor(IHost host, CalibrateInputBase input, + /// A object, containing an . 
+ internal static TOut CalibratePredictor(IHost host, CalibrateInputBase input, ICalibratorTrainer calibratorTrainer) where TOut : CommonOutputs.TrainerOutput, new() { @@ -1931,10 +1982,10 @@ public static TOut CalibratePredictor(IHost host, CalibrateInputBase input else { calibratedPredictor = - CalibratorUtils.TrainCalibrator(host, ch, calibratorTrainer, input.MaxRows, predictor, data); + CalibratorUtils.GetCalibratedPredictor(host, ch, calibratorTrainer, predictor, data, input.MaxRows); } - return new TOut() { PredictorModel = new PredictorModel(host, data, input.Data, calibratedPredictor) }; + return new TOut() { PredictorModel = new PredictorModelImpl(host, data, input.Data, calibratedPredictor) }; } } } -} +} \ No newline at end of file diff --git a/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs b/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs new file mode 100644 index 0000000000..8728f5216b --- /dev/null +++ b/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs @@ -0,0 +1,461 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML; +using Microsoft.ML.Calibrator; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Model; +using Microsoft.ML.Training; + +[assembly: LoadableClass(typeof(CalibratorTransformer), typeof(PlattCalibratorTransformer), null, + typeof(SignatureLoadModel), "", PlattCalibratorTransformer.LoadName)] + +[assembly: LoadableClass(typeof(CalibratorTransformer), typeof(NaiveCalibratorTransformer), null, + typeof(SignatureLoadModel), "", NaiveCalibratorTransformer.LoadName)] + +[assembly: LoadableClass(typeof(CalibratorTransformer), typeof(PavCalibratorTransformer), null, + typeof(SignatureLoadModel), "", PavCalibratorTransformer.LoadName)] + +namespace Microsoft.ML.Calibrator +{ + + /// + /// An interface for probability calibrators. + /// + public interface ICalibrator + { + /// Given a classifier output, produce the probability. + float PredictProbability(float output); + } + + /// + /// Base class for CalibratorEstimators. + /// + /// + /// CalibratorEstimators take an (the output of a ) + /// that contains a "Score" column, and converts the scores to probabilities(through binning, interpolation etc.), based on the type. + /// They are used in pipelines where the binary classifier produces non-calibrated scores. 
+ /// + /// + /// + /// + /// + public abstract class CalibratorEstimatorBase : IEstimator> + where TCalibratorTrainer : ICalibratorTrainer + where TICalibrator : class, ICalibrator + { + protected readonly IHostEnvironment Host; + protected readonly TCalibratorTrainer CalibratorTrainer; + + protected readonly IPredictor Predictor; + protected readonly SchemaShape.Column ScoreColumn; + protected readonly SchemaShape.Column FeatureColumn; + protected readonly SchemaShape.Column LabelColumn; + protected readonly SchemaShape.Column WeightColumn; + protected readonly SchemaShape.Column PredictedLabel; + + protected CalibratorEstimatorBase(IHostEnvironment env, + TCalibratorTrainer calibratorTrainer, + IPredictor predictor = null, + string labelColumn = DefaultColumnNames.Label, + string featureColumn = DefaultColumnNames.Features, + string weightColumn = null) + { + Host = env; + Predictor = predictor; + CalibratorTrainer = calibratorTrainer; + + ScoreColumn = TrainerUtils.MakeR4ScalarColumn(DefaultColumnNames.Score); // Do we fantom this being named anything else (renaming column)? Complete metadata? + LabelColumn = TrainerUtils.MakeBoolScalarLabel(labelColumn); + FeatureColumn = TrainerUtils.MakeR4VecFeature(featureColumn); + PredictedLabel = new SchemaShape.Column(DefaultColumnNames.PredictedLabel, + SchemaShape.Column.VectorKind.Scalar, + BoolType.Instance, + false, + new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())); + + if (weightColumn != null) + WeightColumn = TrainerUtils.MakeR4ScalarWeightColumn(weightColumn); + } + + /// + /// Gets the output of the after fitting the calibrator. + /// Fitting the calibrator will add a column named "Probability" to the schema. If you already had such a column, a new one will be added. + /// + /// The input . 
+ SchemaShape IEstimator>.GetOutputSchema(SchemaShape inputSchema) + { + Action checkColumnValid = (SchemaShape.Column column, string expected) => + { + if (column.IsValid) + { + if (!inputSchema.TryFindColumn(column.Name, out var outCol)) + throw Host.Except($"{expected} column '{column.Name}' is not found"); + if (!column.IsCompatibleWith(outCol)) + throw Host.Except($"{expected} column '{column.Name}' is not compatible"); + } + }; + + // check the input schema + checkColumnValid(ScoreColumn, DefaultColumnNames.Score); + checkColumnValid(WeightColumn, DefaultColumnNames.Weight); + checkColumnValid(LabelColumn, DefaultColumnNames.Label); + checkColumnValid(FeatureColumn, DefaultColumnNames.Features); + checkColumnValid(PredictedLabel, DefaultColumnNames.PredictedLabel); + + //create the new Probability column + var outColumns = inputSchema.ToDictionary(x => x.Name); + outColumns[DefaultColumnNames.Probability] = new SchemaShape.Column(DefaultColumnNames.Probability, + SchemaShape.Column.VectorKind.Scalar, + NumberType.R4, + false, + new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))); + + return new SchemaShape(outColumns.Values); + } + + public CalibratorTransformer Fit(IDataView input) + { + TICalibrator calibrator = null; + + var roles = new List>(); + roles.Add(RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, DefaultColumnNames.Score)); + roles.Add(RoleMappedSchema.ColumnRole.Label.Bind(LabelColumn.Name)); + roles.Add(RoleMappedSchema.ColumnRole.Feature.Bind(FeatureColumn.Name)); + if (WeightColumn.IsValid) + roles.Add(RoleMappedSchema.ColumnRole.Weight.Bind(WeightColumn.Name)); + + var roleMappedData = new RoleMappedData(input, opt: false, roles.ToArray()); + + using (var ch = Host.Start("Creating calibrator.")) + calibrator = (TICalibrator)CalibratorUtils.TrainCalibrator(Host, ch, CalibratorTrainer, Predictor, roleMappedData); + + return Create(Host, calibrator); + } + + /// + /// Implemented by deriving classes that create a 
concrete calibrator. + /// + protected abstract CalibratorTransformer Create(IHostEnvironment env, TICalibrator calibrator); + } + + /// + /// CalibratorTransformers, the artifact of calling Fit on a . + /// If you pass scored data to the Transform method, it will add the Probability column + /// to the dataset. The Probability column is the value of the Score normalized to be a valid probability. + /// The CalibratorTransformer is an instance of where score can be viewed as a feature + /// while probability is treated as the label. + /// + /// The used to transform the data. + public abstract class CalibratorTransformer : RowToRowTransformerBase, ISingleFeaturePredictionTransformer + where TICalibrator : class, ICalibrator + { + private TICalibrator _calibrator; + private readonly string _loaderSignature; + + internal CalibratorTransformer(IHostEnvironment env, TICalibrator calibrator, string loaderSignature) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CalibratorTransformer))) + { + Host.CheckRef(calibrator, nameof(calibrator)); + + _loaderSignature = loaderSignature; + _calibrator = calibrator; + } + + // Factory method for SignatureLoadModel.
+ internal CalibratorTransformer(IHostEnvironment env, ModelLoadContext ctx, string loaderSignature) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CalibratorTransformer))) + { + Contracts.AssertValue(ctx); + + _loaderSignature = loaderSignature; + ctx.CheckAtModel(GetVersionInfo()); + + // *** Binary format *** + // model: _calibrator + ctx.LoadModel(env, out _calibrator, @"Calibrator"); + } + + string ISingleFeaturePredictionTransformer.FeatureColumn => DefaultColumnNames.Score; + + ColumnType ISingleFeaturePredictionTransformer.FeatureColumnType => NumberType.Float; + + TICalibrator IPredictionTransformer.Model => _calibrator; + + bool ITransformer.IsRowToRowMapper => true; + + public override void Save(ModelSaveContext ctx) + { + Contracts.AssertValue(ctx); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + // *** Binary format *** + // model: _calibrator + ctx.SaveModel(_calibrator, @"Calibrator"); + } + + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, _calibrator, schema); + + protected VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "CALTRANS", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: _loaderSignature, + loaderAssemblyName: typeof(CalibratorTransformer<>).Assembly.FullName); + } + + private sealed class Mapper : MapperBase + where TCalibrator : class, ICalibrator + { + private TCalibrator _calibrator; + private int _scoreColIndex; + private CalibratorTransformer _parent; + + internal Mapper(CalibratorTransformer parent, TCalibrator calibrator, Schema inputSchema) : + base(parent.Host, inputSchema) + { + _calibrator = calibrator; + _parent = parent; + + _scoreColIndex = inputSchema.GetColumnOrNull(DefaultColumnNames.Score)?.Index ?? 
-1; + + parent.Host.Check(_scoreColIndex > 0, "The data to calibrate contains no 'Score' column"); + } + + private protected override Func GetDependenciesCore(Func activeOutput) + => col => col == _scoreColIndex; + + public override void Save(ModelSaveContext ctx) => _parent.Save(ctx); + + protected override Schema.DetachedColumn[] GetOutputColumnsCore() + { + return new[] + { + new Schema.DetachedColumn(DefaultColumnNames.Probability, NumberType.Float, null) + }; + } + + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) + { + Host.AssertValue(input); + disposer = null; + + Host.Assert(input.IsColumnActive(_scoreColIndex)); + var getScore = input.GetGetter(_scoreColIndex); + + float score = default; + + ValueGetter probability = (ref float dst) => + { + getScore(ref score); + dst = _calibrator.PredictProbability(score); + }; + + return probability; + } + } + } + + /// + /// The PlattCalibratorEstimator. + /// + /// + /// For the usage pattern see the example in . + /// + public sealed class PlattCalibratorEstimator : CalibratorEstimatorBase + { + /// + /// Initializes a new instance of + /// + /// The environment to use. + /// The predictor used to train the data. + /// The label column name. + /// The feature column name. + /// The weight column name. + public PlattCalibratorEstimator(IHostEnvironment env, + IPredictor predictor, + string labelColumn = DefaultColumnNames.Label, + string featureColumn = DefaultColumnNames.Features, + string weightColumn = null) : base(env, new PlattCalibratorTrainer(env), predictor, labelColumn, featureColumn, weightColumn) + { + + } + + protected override CalibratorTransformer Create(IHostEnvironment env, PlattCalibrator calibrator) + => new PlattCalibratorTransformer(env, calibrator); + } + + /// + /// Obtains the probability values by fitting the sigmoid: f(x) = 1 / (1 + exp(-slope * x + offset). + /// + /// + /// For the usage pattern see the example in . 
+ /// + public sealed class FixedPlattCalibratorEstimator : CalibratorEstimatorBase + { + /// + /// Initializes a new instance of + /// + /// The environment to use. + /// The predictor used to train the data. + /// The slope in the function of the exponent of the sigmoid. + /// The offset in the function of the exponent of the sigmoid. + /// The label column name. + /// The feature column name. + /// The weight column name. + public FixedPlattCalibratorEstimator(IHostEnvironment env, + IPredictor predictor, + double slope = 1, + double offset = 0, + string labelColumn = DefaultColumnNames.Label, + string featureColumn = DefaultColumnNames.Features, + string weightColumn = null) : base(env, new FixedPlattCalibratorTrainer(env, new FixedPlattCalibratorTrainer.Arguments() + { + Slope = slope, + Offset = offset + }), predictor, labelColumn, featureColumn, weightColumn) + { + + } + + protected override CalibratorTransformer Create(IHostEnvironment env, PlattCalibrator calibrator) + => new PlattCalibratorTransformer(env, calibrator); + } + + /// + /// The implementation obtained by training a or a . + /// + public sealed class PlattCalibratorTransformer : CalibratorTransformer + { + internal const string LoadName = "PlattCalibratTransf"; + + internal PlattCalibratorTransformer(IHostEnvironment env, PlattCalibrator calibrator) + : base(env, calibrator, LoadName) + { + + } + + // Factory method for SignatureLoadModel. + internal PlattCalibratorTransformer(IHostEnvironment env, ModelLoadContext ctx) + :base(env, ctx, LoadName) + { + + } + } + + /// + /// The naive binning-based calibratorEstimator. + /// + /// + /// It divides the range of the outputs into equally sized bins. In each bin, + /// the probability of belonging to class 1, is the number of class 1 instances in the bin, divided by the total number + /// of instances in the bin. + /// For the usage pattern see the example in . 
+ /// + public sealed class NaiveCalibratorEstimator : CalibratorEstimatorBase + { + /// + /// Initializes a new instance of + /// + /// The environment to use. + /// The predictor used to train the data. + /// The label column name. + /// The feature column name. + /// The weight column name. + public NaiveCalibratorEstimator(IHostEnvironment env, + IPredictor predictor, + string labelColumn = DefaultColumnNames.Label, + string featureColumn = DefaultColumnNames.Features, + string weightColumn = null) : base(env, new NaiveCalibratorTrainer(env), predictor, labelColumn, featureColumn, weightColumn) + { + + } + + protected override CalibratorTransformer Create(IHostEnvironment env, NaiveCalibrator calibrator) + => new NaiveCalibratorTransformer(env, calibrator); + } + + /// + /// The implementation obtained by training a + /// + public sealed class NaiveCalibratorTransformer : CalibratorTransformer + { + internal const string LoadName = "NaiveCalibratTransf"; + + internal NaiveCalibratorTransformer(IHostEnvironment env, NaiveCalibrator calibrator) + : base(env, calibrator, LoadName) + { + + } + + // Factory method for SignatureLoadModel. + internal NaiveCalibratorTransformer(IHostEnvironment env, ModelLoadContext ctx) + : base(env, ctx, LoadName) + { + + } + } + + /// + /// The PavCalibratorEstimator. + /// + /// + /// For the usage pattern see the example in . + /// + public sealed class PavCalibratorEstimator : CalibratorEstimatorBase + { + /// + /// Initializes a new instance of + /// + /// The environment to use. + /// The predictor used to train the data. + /// The label column name. + /// The feature column name. + /// The weight column name. 
+ public PavCalibratorEstimator(IHostEnvironment env, + IPredictor predictor, + string labelColumn = DefaultColumnNames.Label, + string featureColumn = DefaultColumnNames.Features, + string weightColumn = null) : base(env, new PavCalibratorTrainer(env), predictor, labelColumn, featureColumn, weightColumn) + { + + } + + protected override CalibratorTransformer Create(IHostEnvironment env, PavCalibrator calibrator) + => new PavCalibratorTransformer(env, calibrator); + + } + + /// + /// The implementation obtained by training a + /// + public sealed class PavCalibratorTransformer : CalibratorTransformer + { + internal const string LoadName = "PavCalibratTransf"; + + internal PavCalibratorTransformer(IHostEnvironment env, PavCalibrator calibrator) + : base(env, calibrator, LoadName) + { + + } + + // Factory method for SignatureLoadModel. + private PavCalibratorTransformer(IHostEnvironment env, ModelLoadContext ctx) + : base(env, ctx, LoadName) + { + + } + } +} diff --git a/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs b/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs index 06c1894f0a..9292edbe9b 100644 --- a/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs @@ -2,20 +2,19 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.Collections.Generic; +using Microsoft.ML.Calibrator; using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { /// /// An interface for all the transformer that can transform data based on the field. /// The implemendations of this interface either have no feature column, or have more than one feature column, and cannot implement the /// , which most of the ML.Net tranformer implement. /// - /// The used for the data transformation. 
+ /// The or used for the data transformation. public interface IPredictionTransformer : ITransformer - where TModel : IPredictor { TModel Model { get; } } @@ -25,9 +24,8 @@ public interface IPredictionTransformer : ITransformer /// and its type, . Implementations of this interface, have the ability /// to score the data of an input through the /// - /// The used for the data transformation. + /// The or used for the data transformation. public interface ISingleFeaturePredictionTransformer : IPredictionTransformer - where TModel : IPredictor { /// The name of the feature column. string FeatureColumn { get; } @@ -35,4 +33,4 @@ public interface ISingleFeaturePredictionTransformer : IPredictionTr /// Holds information about the type of the feature column. ColumnType FeatureColumnType { get; } } -} +} \ No newline at end of file diff --git a/src/Microsoft.ML.Api/PredictionEngine.cs b/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs similarity index 90% rename from src/Microsoft.ML.Api/PredictionEngine.cs rename to src/Microsoft.ML.Data/Prediction/PredictionEngine.cs index 67734d0950..d4fede17f7 100644 --- a/src/Microsoft.ML.Api/PredictionEngine.cs +++ b/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs @@ -2,15 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Data; using System; using System.Collections.Generic; using System.IO; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime.Api +namespace Microsoft.ML { + // REVIEW: Temporarly moving here since it is used by the Legacy project. Remove when removing the legacy project. /// /// A class that runs the previously trained model (and the preceding transform pipeline) on the /// in-memory data in batch mode. 
@@ -19,7 +19,8 @@ namespace Microsoft.ML.Runtime.Api /// /// The user-defined type that holds the example. /// The user-defined type that holds the prediction. - public sealed class BatchPredictionEngine + [BestFriend] + internal sealed class BatchPredictionEngine where TSrc : class where TDst : class, new() { @@ -157,13 +158,15 @@ public override void Predict(TSrc example, ref TDst prediction) /// /// The user-defined type that holds the example. /// The user-defined type that holds the prediction. - public abstract class PredictionEngineBase + public abstract class PredictionEngineBase : IDisposable where TSrc : class where TDst : class, new() { private readonly DataViewConstructionUtils.InputRow _inputRow; private readonly IRowReadableAs _outputRow; private readonly Action _disposer; + private bool _disposed; + [BestFriend] private protected ITransformer Transformer { get; } @@ -193,12 +196,14 @@ private protected PredictionEngineBase(IHostEnvironment env, ITransformer transf PredictionEngineCore(env, _inputRow, makeMapper(_inputRow.Schema), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition, out _disposer, out _outputRow); } - internal virtual void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns, + [BestFriend] + private protected virtual void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns, SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs outputRow) { - var cursorable = TypedCursorable.Create(env, new EmptyDataView(env, mapper.Schema), ignoreMissingColumns, outputSchemaDefinition); - var outputRowLocal = mapper.GetRow(_inputRow, col => true, out disposer); + var cursorable = TypedCursorable.Create(env, new EmptyDataView(env, mapper.OutputSchema), ignoreMissingColumns, outputSchemaDefinition); + 
var outputRowLocal = mapper.GetRow(inputRow, col => true); outputRow = cursorable.GetRow(outputRowLocal); + disposer = inputRow.Dispose; } protected virtual Func TransformerChecker(IExceptionContext ectx, ITransformer transformer) @@ -208,9 +213,20 @@ protected virtual Func TransformerChecker(IExceptionCon return transformer.GetRowToRowMapper; } - ~PredictionEngineBase() + public void Dispose() + { + Disposing(true); + GC.SuppressFinalize(this); + } + + [BestFriend] + private protected void Disposing(bool disposing) { - _disposer?.Invoke(); + if (_disposed) + return; + if (disposing) + _disposer?.Invoke(); + _disposed = true; } /// diff --git a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs index 849318c1e0..43587d4544 100644 --- a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs @@ -7,9 +7,13 @@ [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TestFramework" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.InferenceTesting" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransformTest" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries.Tests" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.EntryPoints" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Maml" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.ResultProcessor" + PublicKey.Value)] @@ -26,7 +30,7 @@ [assembly: 
InternalsVisibleTo(assemblyName: "Microsoft.ML.PCA" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.PipelineInference" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Recommender" + PublicKey.Value)] -[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Runtime.ImageAnalytics" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.ImageAnalytics" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Scoring" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StandardLearners" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Sweeper" + PublicKey.Value)] @@ -34,4 +38,6 @@ [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Transforms" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipe" + PublicKey.Value)] + [assembly: WantsToBeBestFriends] diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index 56a72d3f92..24b5433d9a 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -2,17 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Float = System.Single; - using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model.Onnx; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Model.Pfa; using Newtonsoft.Json.Linq; -using Microsoft.ML.Runtime.Model.Pfa; -using System.Collections.Generic; +using Float = System.Single; [assembly: LoadableClass(typeof(BinaryClassifierScorer), typeof(BinaryClassifierScorer.Arguments), typeof(SignatureDataScorer), "Binary Classifier Scorer", "BinaryClassifierScorer", "BinaryClassifier", "Binary", @@ -21,7 +19,7 @@ [assembly: LoadableClass(typeof(BinaryClassifierScorer), null, typeof(SignatureLoadDataTransform), "Binary Classifier Scorer", BinaryClassifierScorer.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public sealed class BinaryClassifierScorer : PredictedLabelScorerBase, ITransformCanSaveOnnx { @@ -65,7 +63,7 @@ private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoun if (trainSchema?.Label == null) return mapper; // We don't even have a label identified in a training schema. - var keyType = trainSchema.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, trainSchema.Label.Index); + var keyType = trainSchema.Label.Value.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; if (keyType == null || !CanWrap(mapper, keyType)) return mapper; @@ -94,11 +92,10 @@ private static bool CanWrap(ISchemaBoundMapper mapper, ColumnType labelNameType) if (rowMapper == null) return false; // We could cover this case, but it is of no practical worth as far as I see, so I decline to do so. 
- ISchema outSchema = mapper.Schema; int scoreIdx; - if (!outSchema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreIdx)) + if (!mapper.OutputSchema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreIdx)) return false; // The mapper doesn't even publish a score column to attach the metadata to. - if (outSchema.GetMetadataTypeOrNull(MetadataUtils.Kinds.TrainingLabelValues, scoreIdx) != null) + if (mapper.OutputSchema[scoreIdx].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.TrainingLabelValues)?.Type != null) return false; // The mapper publishes a score column, and already produces its own slot names. return labelNameType.IsVector && labelNameType.VectorSize == 2; @@ -110,24 +107,23 @@ private static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBound env.AssertValue(mapper); env.AssertValue(trainSchema); env.Assert(mapper is ISchemaBoundRowMapper); + env.Assert(trainSchema.Label.HasValue); + var labelColumn = trainSchema.Label.Value; // Key values from the training schema label, will map to slot names of the score output. - var type = trainSchema.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, trainSchema.Label.Index); + var type = labelColumn.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; env.AssertValue(type); env.Assert(type.IsVector); // Wrap the fetching of the metadata as a simple getter. 
- ValueGetter> getter = - (ref VBuffer value) => - { - trainSchema.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, - trainSchema.Label.Index, ref value); - }; + ValueGetter> getter = (ref VBuffer value) => + labelColumn.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref value); - return MultiClassClassifierScorer.LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type.AsVector, getter, MetadataUtils.Kinds.TrainingLabelValues, CanWrap); + return MultiClassClassifierScorer.LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type as VectorType, getter, MetadataUtils.Kinds.TrainingLabelValues, CanWrap); } - public BinaryClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) + [BestFriend] + internal BinaryClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) : base(args, env, data, WrapIfNeeded(env, mapper, trainSchema), trainSchema, RegistrationName, MetadataUtils.Const.ScoreColumnKind.BinaryClassification, Contracts.CheckRef(args, nameof(args)).ThresholdColumn, OutputTypeMatches, GetPredColType) { @@ -170,7 +166,7 @@ public static BinaryClassifierScorer Create(IHostEnvironment env, ModelLoadConte return h.Apply("Loading Model", ch => new BinaryClassifierScorer(h, ctx, input)); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -221,10 +217,10 @@ public override IDataTransform ApplyToData(IHostEnvironment env, IDataView newSo return new BinaryClassifierScorer(env, this, newSource); } - protected override Delegate GetPredictedLabelGetter(IRow output, out Delegate scoreGetter) + protected override Delegate GetPredictedLabelGetter(Row output, out Delegate scoreGetter) { Host.AssertValue(output); - Host.Assert(output.Schema == 
Bindings.RowMapper.Schema); + Host.Assert(output.Schema == Bindings.RowMapper.OutputSchema); Host.Assert(output.IsColumnActive(Bindings.ScoreColumnIndex)); ValueGetter mapperScoreGetter = output.GetGetter(Bindings.ScoreColumnIndex); @@ -271,7 +267,7 @@ private void GetPredictedLabelCoreAsKey(Float score, ref uint value) value = (uint)(score > _threshold ? 2 : score <= _threshold ? 1 : 0); } - protected override JToken PredictedLabelPfa(string[] mapperOutputs) + private protected override JToken PredictedLabelPfa(string[] mapperOutputs) { Contracts.CheckParam(Utils.Size(mapperOutputs) >= 1, nameof(mapperOutputs)); diff --git a/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs b/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs index 40fd5b621f..8f950bcf69 100644 --- a/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs @@ -2,16 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Float = System.Single; - using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Pfa; -using Microsoft.ML.Runtime.Numeric; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Pfa; +using Microsoft.ML.Numeric; using Newtonsoft.Json.Linq; +using Float = System.Single; [assembly: LoadableClass(typeof(ClusteringScorer), typeof(ClusteringScorer.Arguments), typeof(SignatureDataScorer), "Clustering Scorer", "ClusteringScorer", MetadataUtils.Const.ScoreColumnKind.Clustering)] @@ -19,7 +18,7 @@ [assembly: LoadableClass(typeof(ClusteringScorer), null, typeof(SignatureLoadDataTransform), "Clustering Scorer", ClusteringScorer.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public sealed class ClusteringScorer : PredictedLabelScorerBase { @@ -43,7 +42,8 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "ClusteringScore"; - public ClusteringScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) + [BestFriend] + internal ClusteringScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) : base(args, env, data, mapper, trainSchema, RegistrationName, MetadataUtils.Const.ScoreColumnKind.Clustering, MetadataUtils.Const.ScoreValueKind.Score, OutputTypeMatches, GetPredColType) { @@ -72,7 +72,7 @@ public static ClusteringScorer Create(IHostEnvironment env, ModelLoadContext ctx return h.Apply("Loading Model", ch => new ClusteringScorer(h, ctx, input)); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -91,10 +91,10 @@ 
public override IDataTransform ApplyToData(IHostEnvironment env, IDataView newSo return new ClusteringScorer(env, this, newSource); } - protected override Delegate GetPredictedLabelGetter(IRow output, out Delegate scoreGetter) + protected override Delegate GetPredictedLabelGetter(Row output, out Delegate scoreGetter) { Contracts.AssertValue(output); - Contracts.Assert(output.Schema == Bindings.RowMapper.Schema); + Contracts.Assert(output.Schema == Bindings.RowMapper.OutputSchema); Contracts.Assert(output.IsColumnActive(Bindings.ScoreColumnIndex)); ValueGetter> mapperScoreGetter = output.GetGetter>(Bindings.ScoreColumnIndex); @@ -126,7 +126,7 @@ protected override Delegate GetPredictedLabelGetter(IRow output, out Delegate sc return predFn; } - protected override JToken PredictedLabelPfa(string[] mapperOutputs) + private protected override JToken PredictedLabelPfa(string[] mapperOutputs) { Contracts.Assert(Utils.Size(mapperOutputs) == 1); return PfaUtils.Call("a.argmax", mapperOutputs[0]); diff --git a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs similarity index 65% rename from src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs rename to src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs index b112600f5a..cc9e4cce85 100644 --- a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs +++ b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs @@ -6,48 +6,34 @@ using System.Collections.Generic; using System.Reflection; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Numeric; +using 
Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Numeric; -[assembly: LoadableClass(typeof(IDataScorerTransform), typeof(FeatureContributionCalculationTransform), typeof(FeatureContributionCalculationTransform.Arguments), - typeof(SignatureDataScorer), "Feature Contribution Transform", "fct", "FeatureContributionCalculationTransform", MetadataUtils.Const.ScoreColumnKind.FeatureContribution)] +[assembly: LoadableClass(typeof(IDataScorerTransform), typeof(FeatureContributionScorer), typeof(FeatureContributionScorer.Arguments), + typeof(SignatureDataScorer), "Feature Contribution Scorer", "fcc", "wtf", "fct", "FeatureContributionCalculationScorer", MetadataUtils.Const.ScoreColumnKind.FeatureContribution)] -[assembly: LoadableClass(typeof(ISchemaBindableMapper), typeof(FeatureContributionCalculationTransform), typeof(FeatureContributionCalculationTransform.Arguments), - typeof(SignatureBindableMapper), "Feature Contribution Mapper", "fct", MetadataUtils.Const.ScoreColumnKind.FeatureContribution)] +[assembly: LoadableClass(typeof(ISchemaBindableMapper), typeof(FeatureContributionScorer), typeof(FeatureContributionScorer.Arguments), + typeof(SignatureBindableMapper), "Feature Contribution Mapper", "fcc", "wtf", "fct", MetadataUtils.Const.ScoreColumnKind.FeatureContribution)] -[assembly: LoadableClass(typeof(ISchemaBindableMapper), typeof(FeatureContributionCalculationTransform), null, typeof(SignatureLoadModel), - "Feature Contribution Mapper", FeatureContributionCalculationTransform.MapperLoaderSignature)] +[assembly: LoadableClass(typeof(ISchemaBindableMapper), typeof(FeatureContributionScorer), null, typeof(SignatureLoadModel), + "Feature Contribution Mapper", FeatureContributionScorer.MapperLoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// - /// Feature Contribution Calculation Transform. 
+ /// Used only by the command line API for scoring and calculation of feature contribution. /// - /// - /// The Feature Contribution Calculation Transform scores the model on an input dataset and - /// computes model-specific contribution scores for each feature. See the sample below for - /// an example of how to compute feature importance using the Feature Contribution Calculation Transform. - /// - /// - /// - /// - /// - /// - public sealed class FeatureContributionCalculationTransform + internal sealed class FeatureContributionScorer { // Apparently, loader signature is limited in length to 24 characters. internal const string MapperLoaderSignature = "WTFBindable"; - private const int MaxTopBottom = 1000; - public sealed class Arguments : ScorerArgumentsBase + internal sealed class Arguments : ScorerArgumentsBase { [Argument(ArgumentType.AtMostOnce, HelpText = "Number of top contributions", SortOrder = 1)] public int Top = 10; @@ -64,15 +50,16 @@ public sealed class Arguments : ScorerArgumentsBase // REVIEW: the scorer currently ignores the 'suffix' argument from the base class. It should respect it. } - public static IDataScorerTransform Create(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) + // Factory method for SignatureDataScorer. 
+ private static IDataScorerTransform Create(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(data, nameof(data)); env.CheckValue(mapper, nameof(mapper)); - if (args.Top <= 0 || args.Top > MaxTopBottom) - throw env.Except($"Number of top contribution must be in range (0,{MaxTopBottom}]"); - if (args.Bottom <= 0 || args.Bottom > MaxTopBottom) - throw env.Except($"Number of bottom contribution must be in range (0,{MaxTopBottom}]"); + if (args.Top< 0) + throw env.Except($"Number of top contribution must be non negative"); + if (args.Bottom < 0) + throw env.Except($"Number of bottom contribution must be non negative"); var contributionMapper = mapper as RowMapper; env.CheckParam(mapper != null, nameof(mapper), "Unexpected mapper"); @@ -82,45 +69,24 @@ public static IDataScorerTransform Create(IHostEnvironment env, Arguments args, return scoredPipe; } - public static ISchemaBindableMapper Create(IHostEnvironment env, Arguments args, IPredictor predictor) + // Factory method for SignatureBindableMapper. 
+ private static ISchemaBindableMapper Create(IHostEnvironment env, Arguments args, IPredictor predictor) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); env.CheckValue(predictor, nameof(predictor)); - if (args.Top <= 0 || args.Top > MaxTopBottom) - throw env.Except($"Number of top contribution must be in range (0,{MaxTopBottom}]"); - if (args.Bottom <= 0 || args.Bottom > MaxTopBottom) - throw env.Except($"Number of bottom contribution must be in range (0,{MaxTopBottom}]"); - var pred = predictor as IFeatureContributionMapper; env.CheckParam(pred != null, nameof(predictor), "Predictor doesn't support getting feature contributions"); return new BindableMapper(env, pred, args.Top, args.Bottom, args.Normalize, args.Stringify); } - public static IDataScorerTransform Create(IHostEnvironment env, Arguments args, IDataView data, IPredictor predictor, string features = DefaultColumnNames.Features) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); - env.CheckValue(predictor, nameof(predictor)); - if (args.Top <= 0 || args.Top > MaxTopBottom) - throw env.Except($"Number of top contribution must be in range (0,{MaxTopBottom}]"); - if (args.Bottom <= 0 || args.Bottom > MaxTopBottom) - throw env.Except($"Number of bottom contribution must be in range (0,{MaxTopBottom}]"); - - var roles = new List>(); - roles.Add(new KeyValuePair(RoleMappedSchema.ColumnRole.Feature, features)); - var schema = new RoleMappedSchema(data.Schema, roles); - - var mapper = Create(env, args, predictor); - var boundMapper = mapper.Bind(env, schema); - return Create(env, args, data, boundMapper, null); - } - - public static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadContext ctx) - { - return new BindableMapper(env, ctx); - } + // Factory method for SignatureLoadModel. 
+ private static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadContext ctx) + => new BindableMapper(env, ctx); + /// + /// Holds the definition of the getters for the FeatureContribution column. It also contains the generic mapper that is used to score the Predictor. + /// This is only used by the command line API. + /// private sealed class BindableMapper : ISchemaBindableMapper, ICanSaveModel, IPredictor { private readonly int _topContributionsCount; @@ -140,7 +106,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: MapperLoaderSignature, - loaderAssemblyName: typeof(FeatureContributionCalculationTransform).Assembly.FullName); + loaderAssemblyName: typeof(FeatureContributionScorer).Assembly.FullName); } public PredictionKind PredictionKind => Predictor.PredictionKind; @@ -150,10 +116,10 @@ public BindableMapper(IHostEnvironment env, IFeatureContributionMapper predictor Contracts.CheckValue(env, nameof(env)); _env = env; _env.CheckValue(predictor, nameof(predictor)); - if (topContributionsCount <= 0 || topContributionsCount > MaxTopBottom) - throw env.Except($"Number of top contribution must be in range (0,{MaxTopBottom}]"); - if (bottomContributionsCount <= 0 || bottomContributionsCount > MaxTopBottom) - throw env.Except($"Number of bottom contribution must be in range (0,{MaxTopBottom}]"); + if (topContributionsCount < 0) + throw env.Except($"Number of top contribution must be non negative"); + if (bottomContributionsCount < 0) + throw env.Except($"Number of bottom contribution must be non negative"); _topContributionsCount = topContributionsCount; _bottomContributionsCount = bottomContributionsCount; @@ -172,16 +138,18 @@ public BindableMapper(IHostEnvironment env, ModelLoadContext ctx) ctx.CheckAtModel(GetVersionInfo()); // *** Binary format *** + // IFeatureContributionMapper: Predictor // int: topContributionsCount // int: bottomContributionsCount // bool: normalize // 
bool: stringify + ctx.LoadModel(env, out Predictor, ModelFileUtils.DirPredictor); GenericMapper = ScoreUtils.GetSchemaBindableMapper(_env, Predictor, null); _topContributionsCount = ctx.Reader.ReadInt32(); - Contracts.CheckDecode(0 < _topContributionsCount && _topContributionsCount <= MaxTopBottom); + Contracts.CheckDecode(0 <= _topContributionsCount); _bottomContributionsCount = ctx.Reader.ReadInt32(); - Contracts.CheckDecode(0 < _bottomContributionsCount && _bottomContributionsCount <= MaxTopBottom); + Contracts.CheckDecode(0 <= _bottomContributionsCount); _normalize = ctx.Reader.ReadBoolByte(); Stringify = ctx.Reader.ReadBoolByte(); } @@ -200,38 +168,39 @@ public void Save(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** + // IFeatureContributionMapper: Predictor // int: topContributionsCount // int: bottomContributionsCount // bool: normalize // bool: stringify - ctx.SaveModel(Predictor, ModelFileUtils.DirPredictor); - Contracts.Assert(0 < _topContributionsCount && _topContributionsCount <= MaxTopBottom); + ctx.SaveModel(Predictor, ModelFileUtils.DirPredictor); + Contracts.Assert(0 <= _topContributionsCount); ctx.Writer.Write(_topContributionsCount); - Contracts.Assert(0 < _bottomContributionsCount && _bottomContributionsCount <= MaxTopBottom); + Contracts.Assert(0 <= _bottomContributionsCount); ctx.Writer.Write(_bottomContributionsCount); ctx.Writer.WriteBoolByte(_normalize); ctx.Writer.WriteBoolByte(Stringify); } - public Delegate GetTextContributionGetter(IRow input, int colSrc, VBuffer> slotNames) + public Delegate GetTextContributionGetter(Row input, int colSrc, VBuffer> slotNames) { Contracts.CheckValue(input, nameof(input)); - Contracts.Check(0 <= colSrc && colSrc < input.Schema.ColumnCount); - var typeSrc = input.Schema.GetColumnType(colSrc); + Contracts.Check(0 <= colSrc && colSrc < input.Schema.Count); + var typeSrc = input.Schema[colSrc].Type; - Func>, ValueGetter>> del = GetTextValueGetter; + Func>, ValueGetter>> 
del = GetTextValueGetter; var meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeSrc.RawType); return (Delegate)meth.Invoke(this, new object[] { input, colSrc, slotNames }); } - public Delegate GetContributionGetter(IRow input, int colSrc) + public Delegate GetContributionGetter(Row input, int colSrc) { Contracts.CheckValue(input, nameof(input)); - Contracts.Check(0 <= colSrc && colSrc < input.Schema.ColumnCount); + Contracts.Check(0 <= colSrc && colSrc < input.Schema.Count); - var typeSrc = input.Schema.GetColumnType(colSrc); - Func>> del = GetValueGetter; + var typeSrc = input.Schema[colSrc].Type; + Func>> del = GetValueGetter; // REVIEW: Assuming Feature contributions will be VBuffer. // For multiclass LR it needs to be(VBuffer[]. @@ -241,15 +210,14 @@ public Delegate GetContributionGetter(IRow input, int colSrc) private ReadOnlyMemory GetSlotName(int index, VBuffer> slotNames) { - var count = slotNames.GetValues().Length; - _env.Assert(count > index || count == 0 && slotNames.Length > index); + _env.Assert(0 <= index && index < slotNames.Length); var slotName = slotNames.GetItemOrDefault(index); return slotName.IsEmpty ? new ReadOnlyMemory($"f{index}".ToCharArray()) : slotName; } - private ValueGetter> GetTextValueGetter(IRow input, int colSrc, VBuffer> slotNames) + private ValueGetter> GetTextValueGetter(Row input, int colSrc, VBuffer> slotNames) { Contracts.AssertValue(input); Contracts.AssertValue(Predictor); @@ -264,16 +232,12 @@ private ValueGetter> GetTextValueGetter(IRow input, i { featureGetter(ref features); map(in features, ref contributions); - var indices = new Span(); - var values = new Span(); - if (contributions.IsDense) - Utils.GetIdentityPermutation(contributions.Length).AsSpan().CopyTo(indices); - else - contributions.GetIndices().CopyTo(indices); - contributions.GetValues().CopyTo(values); + var editor = VBufferEditor.CreateFromBuffer(ref contributions); + var indices = contributions.IsDense ? 
Utils.GetIdentityPermutation(contributions.Length) : editor.Indices; + var values = editor.Values; var count = values.Length; var sb = new StringBuilder(); - GenericSpanSortHelper.Sort(indices, values, 0, count); + GenericSpanSortHelper.Sort(values, indices, 0, count); for (var i = 0; i < count; i++) { var val = values[i]; @@ -292,14 +256,14 @@ private ValueGetter> GetTextValueGetter(IRow input, i }; } - private ValueGetter> GetValueGetter(IRow input, int colSrc) + private ValueGetter> GetValueGetter(Row input, int colSrc) { Contracts.AssertValue(input); Contracts.AssertValue(Predictor); var featureGetter = input.GetGetter(colSrc); - // REVIEW: Scorer can do call to Sparicification\Norm routine. + // REVIEW: Scorer can call Sparsification\Norm routine. var map = Predictor.GetFeatureContributionMapper>(_topContributionsCount, _bottomContributionsCount, _normalize); var features = default(TSrc); @@ -322,20 +286,24 @@ private static void CheckSchemaValid(IExceptionContext ectx, RoleMappedSchema sc } } + /// + /// Maps a schema from input columns to output columns. Keeps track of the input columns that are needed for the mapping. 
+ /// private sealed class RowMapper : ISchemaBoundRowMapper { private readonly IHostEnvironment _env; private readonly ISchemaBoundRowMapper _genericRowMapper; private readonly BindableMapper _parent; - private readonly ISchema _outputSchema; - private readonly ISchema _outputGenericSchema; + private readonly Schema _outputSchema; + private readonly Schema _outputGenericSchema; private VBuffer> _slotNames; public RoleMappedSchema InputRoleMappedSchema { get; } public Schema InputSchema => InputRoleMappedSchema.Schema; + private Schema.Column FeatureColumn => InputRoleMappedSchema.Feature.Value; - public Schema Schema { get; } + public Schema OutputSchema { get; } public ISchemaBindableMapper Bindable => _parent; @@ -345,7 +313,7 @@ public RowMapper(IHostEnvironment env, BindableMapper parent, RoleMappedSchema s _env = env; _env.AssertValue(schema); _env.AssertValue(parent); - _env.AssertValue(schema.Feature); + _env.Assert(schema.Feature.HasValue); _parent = parent; InputRoleMappedSchema = schema; var genericMapper = parent.GenericMapper.Bind(_env, schema); @@ -353,23 +321,23 @@ public RowMapper(IHostEnvironment env, BindableMapper parent, RoleMappedSchema s if (parent.Stringify) { - _outputSchema = new SimpleSchema(_env, - new KeyValuePair(DefaultColumnNames.FeatureContributions, TextType.Instance)); - if (InputSchema.HasSlotNames(InputRoleMappedSchema.Feature.Index, InputRoleMappedSchema.Feature.Type.VectorSize)) - InputSchema.GetMetadata(MetadataUtils.Kinds.SlotNames, InputRoleMappedSchema.Feature.Index, - ref _slotNames); + var builder = new SchemaBuilder(); + builder.AddColumn(DefaultColumnNames.FeatureContributions, TextType.Instance, null); + _outputSchema = builder.GetSchema(); + if (FeatureColumn.HasSlotNames(FeatureColumn.Type.VectorSize)) + FeatureColumn.Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref _slotNames); else - _slotNames = VBufferUtils.CreateEmpty>(InputRoleMappedSchema.Feature.Type.VectorSize); + _slotNames = 
VBufferUtils.CreateEmpty>(FeatureColumn.Type.VectorSize); } else { - _outputSchema = new FeatureContributionSchema(_env, DefaultColumnNames.FeatureContributions, - new VectorType(NumberType.R4, schema.Feature.Type.AsVector), - InputSchema, InputRoleMappedSchema.Feature.Index); + _outputSchema = Schema.Create(new FeatureContributionSchema(_env, DefaultColumnNames.FeatureContributions, + new VectorType(NumberType.R4, FeatureColumn.Type as VectorType), + InputSchema, FeatureColumn.Index)); } - _outputGenericSchema = _genericRowMapper.Schema; - Schema = new CompositeSchema(new ISchema[] { _outputGenericSchema, _outputSchema, }).AsSchema; + _outputGenericSchema = _genericRowMapper.OutputSchema; + OutputSchema = new CompositeSchema(new Schema[] { _outputGenericSchema, _outputSchema, }).AsSchema; } /// @@ -377,36 +345,36 @@ public RowMapper(IHostEnvironment env, BindableMapper parent, RoleMappedSchema s /// public Func GetDependencies(Func predicate) { - for (int i = 0; i < Schema.ColumnCount; i++) + for (int i = 0; i < OutputSchema.Count; i++) { if (predicate(i)) - return col => col == InputRoleMappedSchema.Feature.Index; + return col => col == FeatureColumn.Index; } return col => false; } - public IRow GetOutputRow(IRow input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func active) { Contracts.AssertValue(input); - Contracts.AssertValue(predicate); - var totalColumnsCount = 1 + _outputGenericSchema.ColumnCount; + Contracts.AssertValue(active); + var totalColumnsCount = 1 + _outputGenericSchema.Count; var getters = new Delegate[totalColumnsCount]; - if (predicate(totalColumnsCount - 1)) + if (active(totalColumnsCount - 1)) { getters[totalColumnsCount - 1] = _parent.Stringify - ? _parent.GetTextContributionGetter(input, InputRoleMappedSchema.Feature.Index, _slotNames) - : _parent.GetContributionGetter(input, InputRoleMappedSchema.Feature.Index); + ? 
_parent.GetTextContributionGetter(input, FeatureColumn.Index, _slotNames) + : _parent.GetContributionGetter(input, FeatureColumn.Index); } - var genericRow = _genericRowMapper.GetRow(input, GetGenericPredicate(predicate), out disposer); - for (var i = 0; i < _outputGenericSchema.ColumnCount; i++) + var genericRow = _genericRowMapper.GetRow(input, GetGenericPredicate(active)); + for (var i = 0; i < _outputGenericSchema.Count; i++) { if (genericRow.IsColumnActive(i)) getters[i] = RowCursorUtils.GetGetterAsDelegate(genericRow, i); } - return new SimpleRow(Schema, input, getters); + return new SimpleRow(OutputSchema, genericRow, getters); } public Func GetGenericPredicate(Func predicate) @@ -416,12 +384,7 @@ public Func GetGenericPredicate(Func predicate) public IEnumerable> GetInputColumnRoles() { - yield return RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature.Name); - } - - public IRow GetRow(IRow input, Func active, out Action disposer) - { - return GetOutputRow(input, active, out disposer); + yield return RoleMappedSchema.ColumnRole.Feature.Bind(FeatureColumn.Name); } } @@ -446,8 +409,8 @@ public FeatureContributionSchema(IExceptionContext ectx, string columnName, Colu _ectx.CheckNonEmpty(columnName, nameof(columnName)); _parentSchema = parentSchema; _featureCol = featureCol; - _featureVectorSize = _parentSchema.GetColumnType(_featureCol).VectorSize; - _hasSlotNames = _parentSchema.HasSlotNames(_featureCol, _featureVectorSize); + _featureVectorSize = _parentSchema[_featureCol].Type.VectorSize; + _hasSlotNames = _parentSchema[_featureCol].HasSlotNames(_featureVectorSize); _names = new string[] { columnName }; _types = new ColumnType[] { columnType }; @@ -491,7 +454,7 @@ public void GetMetadata(string kind, int col, ref TValue value) { _ectx.CheckParam(col == 0, nameof(col)); if (kind == MetadataUtils.Kinds.SlotNames && _hasSlotNames) - _parentSchema.GetMetadata(kind, _featureCol, ref value); + _parentSchema[_featureCol].Metadata.GetValue(kind, 
ref value); else throw MetadataUtils.ExceptGetMetadata(); } diff --git a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs index b1a1955452..cefa5af97e 100644 --- a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs @@ -4,12 +4,11 @@ using System; using System.Collections.Generic; +using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Model.Pfa; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Model.Pfa; [assembly: LoadableClass(typeof(GenericScorer), typeof(GenericScorer.Arguments), typeof(SignatureDataScorer), "Generic Scorer", GenericScorer.LoadName, "Generic")] @@ -17,7 +16,7 @@ [assembly: LoadableClass(typeof(GenericScorer), null, typeof(SignatureLoadDataTransform), "Generic Scorer", GenericScorer.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// This class is a scorer that passes through all the ISchemaBound columns without adding any "derived columns". @@ -138,9 +137,9 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "GenericScore"; private readonly Bindings _bindings; - protected override BindingsBase GetBindings() => _bindings; + private protected override BindingsBase GetBindings() => _bindings; - public override Schema Schema { get; } + public override Schema OutputSchema { get; } bool ICanSavePfa.CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true; @@ -149,7 +148,8 @@ private static VersionInfo GetVersionInfo() /// /// The entry point for creating a . 
/// - public GenericScorer(IHostEnvironment env, ScorerArgumentsBase args, IDataView data, + [BestFriend] + internal GenericScorer(IHostEnvironment env, ScorerArgumentsBase args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) : base(env, data, RegistrationName, Contracts.CheckRef(mapper, nameof(mapper)).Bindable) { @@ -160,7 +160,7 @@ public GenericScorer(IHostEnvironment env, ScorerArgumentsBase args, IDataView d var rowMapper = mapper as ISchemaBoundRowMapper; Host.CheckParam(rowMapper != null, nameof(mapper), "mapper should implement ISchemaBoundRowMapper"); _bindings = Bindings.Create(data.Schema, rowMapper, args.Suffix); - Schema = Schema.Create(_bindings); + OutputSchema = _bindings.AsSchema; } /// @@ -170,7 +170,7 @@ private GenericScorer(IHostEnvironment env, GenericScorer transform, IDataView d : base(env, data, RegistrationName, transform.Bindable) { _bindings = transform._bindings.ApplyToSchema(env, data.Schema); - Schema = Schema.Create(_bindings); + OutputSchema = _bindings.AsSchema; } /// @@ -181,7 +181,7 @@ private GenericScorer(IHost host, ModelLoadContext ctx, IDataView input) { Contracts.AssertValue(ctx); _bindings = Bindings.Create(ctx, host, Bindable, input.Schema); - Schema = Schema.Create(_bindings); + OutputSchema = _bindings.AsSchema; } /// @@ -199,7 +199,7 @@ public static GenericScorer Create(IHostEnvironment env, ModelLoadContext ctx, I return h.Apply("Loading Model", ch => new GenericScorer(h, ctx, input)); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -261,12 +261,12 @@ public override IDataTransform ApplyToData(IHostEnvironment env, IDataView newSo return new GenericScorer(env, this, newSource); } - protected override Delegate[] GetGetters(IRow output, Func predicate) + protected override Delegate[] GetGetters(Row output, Func predicate) { 
Host.Assert(_bindings.DerivedColumnCount == 0); Host.AssertValue(output); Host.AssertValue(predicate); - Host.Assert(output.Schema == _bindings.RowMapper.Schema); + Host.Assert(output.Schema == _bindings.RowMapper.OutputSchema); return GetGettersFromRow(output, predicate); } diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs index 66f30f565c..52e963b3a2 100644 --- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs @@ -2,19 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Model.Pfa; -using Microsoft.ML.Runtime.Numeric; -using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.Threading; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Model.Pfa; +using Microsoft.ML.Numeric; +using Newtonsoft.Json.Linq; using Float = System.Single; [assembly: LoadableClass(typeof(MultiClassClassifierScorer), @@ -28,7 +27,7 @@ [assembly: LoadableClass(typeof(ISchemaBindableMapper), typeof(MultiClassClassifierScorer.LabelNameBindableMapper), null, typeof(SignatureLoadModel), "Multi-Class Label-Name Mapper", MultiClassClassifierScorer.LabelNameBindableMapper.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public sealed class MultiClassClassifierScorer : PredictedLabelScorerBase { @@ -79,7 +78,6 @@ public sealed class LabelNameBindableMapper : 
ISchemaBindableMapper, ICanSaveMod public VectorType Type => _type; bool ICanSavePfa.CanSavePfa => (_bindable as ICanSavePfa)?.CanSavePfa == true; bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; - public ISchemaBindableMapper InnerBindable => _bindable; private static VersionInfo GetVersionInfo() { @@ -130,19 +128,14 @@ private LabelNameBindableMapper(IHost host, ModelLoadContext ctx) ColumnType type; object value; _host.CheckDecode(saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out type, out value)); - _host.CheckDecode(type.IsVector); + _type = type as VectorType; + _host.CheckDecode(_type != null); _host.CheckDecode(value != null); - _type = type.AsVector; _getter = Utils.MarshalInvoke(DecodeInit, _type.ItemType.RawType, value); _metadataKind = ctx.Header.ModelVerReadable >= VersionAddedMetadataKind ? ctx.LoadNonEmptyString() : MetadataUtils.Kinds.SlotNames; } - public ISchemaBindableMapper Clone(ISchemaBindableMapper inner) - { - return new LabelNameBindableMapper(_host, inner, _type, _getter, _metadataKind, _canWrap); - } - private Delegate DecodeInit(object value) { _host.CheckDecode(value is VBuffer); @@ -151,7 +144,10 @@ private Delegate DecodeInit(object value) return buffGetter; } - public static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadContext ctx) + /// + /// Method corresponding to . 
+ /// + private static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(LoaderSignature); @@ -214,7 +210,7 @@ bool IBindableCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, s return ((IBindableCanSaveOnnx)_bindable).SaveAsOnnx(ctx, schema, outputNames); } - public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) + ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema) { var innerBound = _bindable.Bind(env, schema); if (_canWrap != null && !_canWrap(innerBound, _type)) @@ -223,7 +219,7 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) return Utils.MarshalInvoke(CreateBound, _type.ItemType.RawType, env, (ISchemaBoundRowMapper)innerBound, _type, _getter, _metadataKind, _canWrap); } - public static ISchemaBoundMapper CreateBound(IHostEnvironment env, ISchemaBoundRowMapper mapper, VectorType type, Delegate getter, + internal static ISchemaBoundMapper CreateBound(IHostEnvironment env, ISchemaBoundRowMapper mapper, VectorType type, Delegate getter, string metadataKind, Func canWrap) { Contracts.AssertValue(env); @@ -245,15 +241,13 @@ private sealed class Bound : ISchemaBoundRowMapper private readonly VectorType _labelNameType; private readonly string _metadataKind; private readonly ValueGetter> _labelNameGetter; - private readonly SchemaImpl _outSchema; // Lazily initialized by the property. 
private LabelNameBindableMapper _bindable; private readonly Func _canWrap; - public Schema Schema => _outSchema.AsSchema; - public RoleMappedSchema InputRoleMappedSchema => _mapper.InputRoleMappedSchema; public Schema InputSchema => _mapper.InputSchema; + public Schema OutputSchema { get; } public ISchemaBindableMapper Bindable { @@ -286,140 +280,77 @@ public Bound(IHostEnvironment env, ISchemaBoundRowMapper mapper, VectorType type _mapper = mapper; int scoreIdx; - bool result = mapper.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreIdx); + bool result = mapper.OutputSchema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreIdx); if (!result) throw env.ExceptParam(nameof(mapper), "Mapper did not have a '{0}' column", MetadataUtils.Const.ScoreValueKind.Score); _labelNameType = type; _labelNameGetter = getter; _metadataKind = metadataKind; - - _outSchema = new SchemaImpl(mapper.Schema, scoreIdx, _labelNameType, _labelNameGetter, _metadataKind); _canWrap = canWrap; - } - public Func GetDependencies(Func predicate) - { - return _mapper.GetDependencies(predicate); + OutputSchema = DecorateOutputSchema(mapper.OutputSchema, scoreIdx, _labelNameType, _labelNameGetter, _metadataKind); } - public IEnumerable> GetInputColumnRoles() - { - return _mapper.GetInputColumnRoles(); - } - - public IRow GetRow(IRow input, Func predicate, out Action disposer) - { - var innerRow = _mapper.GetRow(input, predicate, out disposer); - return new RowImpl(innerRow, Schema); - } - - private sealed class SchemaImpl : ISchema + /// + /// Append label names to score column as its metadata. 
+ /// + private Schema DecorateOutputSchema(Schema partialSchema, int scoreColumnIndex, VectorType labelNameType, + ValueGetter> labelNameGetter, string labelNameKind) { - private readonly ISchema _parent; - private readonly int _scoreCol; - private readonly VectorType _labelNameType; - private readonly MetadataUtils.MetadataGetter> _labelNameGetter; - private readonly string _metadataKind; - - public Schema AsSchema { get; } - - public int ColumnCount { get { return _parent.ColumnCount; } } - - public SchemaImpl(ISchema parent, int col, VectorType type, ValueGetter> getter, string metadataKind) - { - Contracts.AssertValue(parent); - Contracts.Assert(0 <= col && col < parent.ColumnCount); - Contracts.AssertValue(type); - Contracts.AssertValue(getter); - Contracts.Assert(type.ItemType.RawType == typeof(T)); - Contracts.AssertNonEmpty(metadataKind); - Contracts.Assert(parent.GetMetadataTypeOrNull(metadataKind, col) == null); - - _parent = parent; - _scoreCol = col; - _labelNameType = type; - // We change to this metadata variant of the getter to enable the marshal call to work. - _labelNameGetter = (int c, ref VBuffer val) => getter(ref val); - _metadataKind = metadataKind; - - AsSchema = Schema.Create(this); - } - - public bool TryGetColumnIndex(string name, out int col) - { - return _parent.TryGetColumnIndex(name, out col); - } - - public string GetColumnName(int col) - { - return _parent.GetColumnName(col); - } - - public ColumnType GetColumnType(int col) + var builder = new SchemaBuilder(); + // Sequentially add columns so that the order of them is not changed comparing with the schema in the mapper + // that computes score column. + for (int i = 0; i < partialSchema.Count; ++i) { - return _parent.GetColumnType(col); + var meta = new MetadataBuilder(); + if (i == scoreColumnIndex) + { + // Add label names for score column. 
+ meta.Add(partialSchema[i].Metadata, selector: s => s != labelNameKind); + meta.Add(labelNameKind, labelNameType, labelNameGetter); + } + else + { + // Copy all existing metadata because this transform only affects score column. + meta.Add(partialSchema[i].Metadata, selector: s => true); + } + // Instead of appending extra metadata to the existing score column, we create new one because + // metadata is read-only. + builder.AddColumn(partialSchema[i].Name, partialSchema[i].Type, meta.GetMetadata()); } + return builder.GetSchema(); + } - public IEnumerable> GetMetadataTypes(int col) - { - var result = _parent.GetMetadataTypes(col); - if (col == _scoreCol) - return result.Prepend(_labelNameType.GetPair(_metadataKind)); - return result; - } + public Func GetDependencies(Func predicate) => _mapper.GetDependencies(predicate); - public ColumnType GetMetadataTypeOrNull(string kind, int col) - { - if (col == _scoreCol && kind == _metadataKind) - return _labelNameType; - return _parent.GetMetadataTypeOrNull(kind, col); - } + public IEnumerable> GetInputColumnRoles() => _mapper.GetInputColumnRoles(); - public void GetMetadata(string kind, int col, ref TValue value) - { - if (col == _scoreCol && kind == _metadataKind) - { - _labelNameGetter.Marshal(0, ref value); - return; - } - _parent.GetMetadata(kind, col, ref value); - } + public Row GetRow(Row input, Func predicate) + { + var innerRow = _mapper.GetRow(input, predicate); + return new RowImpl(innerRow, OutputSchema); } - private sealed class RowImpl : IRow + private sealed class RowImpl : WrappingRow { - private readonly IRow _row; private readonly Schema _schema; - public long Batch { get { return _row.Batch; } } - public long Position { get { return _row.Position; } } // The schema is of course the only difference from _row. 
- public Schema Schema => _schema; + public override Schema Schema => _schema; - public RowImpl(IRow row, Schema schema) + public RowImpl(Row row, Schema schema) + : base(row) { Contracts.AssertValue(row); Contracts.AssertValue(schema); - _row = row; _schema = schema; } - public bool IsColumnActive(int col) - { - return _row.IsColumnActive(col); - } + public override bool IsColumnActive(int col) => Input.IsColumnActive(col); - public ValueGetter GetGetter(int col) - { - return _row.GetGetter(col); - } - - public ValueGetter GetIdGetter() - { - return _row.GetIdGetter(); - } + public override ValueGetter GetGetter(int col) => Input.GetGetter(col); } } } @@ -439,9 +370,9 @@ private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoun // them as slot name metadata. But there are a number of conditions for this to actually // happen, so we test those here. If these are not - if (trainSchema == null || trainSchema.Label == null) + if (trainSchema?.Label == null) return mapper; // We don't even have a label identified in a training schema. 
- var keyType = trainSchema.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, trainSchema.Label.Index); + var keyType = trainSchema.Label.Value.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; if (keyType == null || !CanWrap(mapper, keyType)) return mapper; @@ -461,7 +392,7 @@ private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoun /// from the model of a bindable mapper) /// Whether we can call with /// this mapper and expect it to succeed - public static bool CanWrap(ISchemaBoundMapper mapper, ColumnType labelNameType) + private static bool CanWrap(ISchemaBoundMapper mapper, ColumnType labelNameType) { Contracts.AssertValue(mapper); Contracts.AssertValue(labelNameType); @@ -470,19 +401,20 @@ public static bool CanWrap(ISchemaBoundMapper mapper, ColumnType labelNameType) if (rowMapper == null) return false; // We could cover this case, but it is of no practical worth as far as I see, so I decline to do so. - ISchema outSchema = mapper.Schema; + var outSchema = mapper.OutputSchema; int scoreIdx; + var scoreCol = outSchema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); if (!outSchema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreIdx)) return false; // The mapper doesn't even publish a score column to attach the metadata to. - if (outSchema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, scoreIdx) != null) + if (outSchema[scoreIdx].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type != null) return false; // The mapper publishes a score column, and already produces its own slot names. - var scoreType = outSchema.GetColumnType(scoreIdx); + var scoreType = outSchema[scoreIdx].Type; // Check that the type is vector, and is of compatible size with the score output. 
return labelNameType.IsVector && labelNameType.VectorSize == scoreType.VectorSize; } - public static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) + private static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) { Contracts.AssertValue(env); env.AssertValue(mapper); @@ -490,7 +422,7 @@ public static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBoundM env.Assert(mapper is ISchemaBoundRowMapper); // Key values from the training schema label, will map to slot names of the score output. - var type = trainSchema.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, trainSchema.Label.Index); + var type = trainSchema.Label.Value.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; env.AssertValue(type); env.Assert(type.IsVector); @@ -498,14 +430,14 @@ public static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBoundM ValueGetter> getter = (ref VBuffer value) => { - trainSchema.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, - trainSchema.Label.Index, ref value); + trainSchema.Label.Value.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref value); }; - return LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type.AsVector, getter, MetadataUtils.Kinds.SlotNames, CanWrap); + return LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type as VectorType, getter, MetadataUtils.Kinds.SlotNames, CanWrap); } - public MultiClassClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) + [BestFriend] + internal MultiClassClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) : base(args, env, data, WrapIfNeeded(env, mapper, trainSchema), trainSchema, RegistrationName, MetadataUtils.Const.ScoreColumnKind.MultiClassClassification, 
MetadataUtils.Const.ScoreValueKind.Score, OutputTypeMatches, GetPredColType) { @@ -523,7 +455,10 @@ private MultiClassClassifierScorer(IHost host, ModelLoadContext ctx, IDataView i // } - public static MultiClassClassifierScorer Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + /// + /// Corresponds to . + /// + private static MultiClassClassifierScorer Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(RegistrationName); @@ -533,7 +468,7 @@ public static MultiClassClassifierScorer Create(IHostEnvironment env, ModelLoadC return h.Apply("Loading Model", ch => new MultiClassClassifierScorer(h, ctx, input)); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -552,10 +487,10 @@ public override IDataTransform ApplyToData(IHostEnvironment env, IDataView newSo return new MultiClassClassifierScorer(env, this, newSource); } - protected override Delegate GetPredictedLabelGetter(IRow output, out Delegate scoreGetter) + protected override Delegate GetPredictedLabelGetter(Row output, out Delegate scoreGetter) { Host.AssertValue(output); - Host.Assert(output.Schema == Bindings.RowMapper.Schema); + Host.Assert(output.Schema == Bindings.RowMapper.OutputSchema); Host.Assert(output.IsColumnActive(Bindings.ScoreColumnIndex)); ValueGetter> mapperScoreGetter = output.GetGetter>(Bindings.ScoreColumnIndex); @@ -587,20 +522,14 @@ protected override Delegate GetPredictedLabelGetter(IRow output, out Delegate sc return predFn; } - protected override JToken PredictedLabelPfa(string[] mapperOutputs) + private protected override JToken PredictedLabelPfa(string[] mapperOutputs) { Contracts.Assert(Utils.Size(mapperOutputs) == 1); return PfaUtils.Call("a.argmax", mapperOutputs[0]); } - private static ColumnType GetPredColType(ColumnType scoreType, 
ISchemaBoundRowMapper mapper) - { - return new KeyType(DataKind.U4, 0, scoreType.VectorSize); - } + private static ColumnType GetPredColType(ColumnType scoreType, ISchemaBoundRowMapper mapper) => new KeyType(DataKind.U4, 0, scoreType.VectorSize); - private static bool OutputTypeMatches(ColumnType scoreType) - { - return scoreType.IsKnownSizeVector && scoreType.ItemType == NumberType.Float; - } + private static bool OutputTypeMatches(ColumnType scoreType) => scoreType.IsKnownSizeVector && scoreType.ItemType == NumberType.Float; } } diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs index 6345914b90..4cea6fe59f 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs @@ -2,18 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Model.Pfa; -using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Model.Pfa; +using Newtonsoft.Json.Linq; using Float = System.Single; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Class for scorers that compute on additional "PredictedLabel" column from the score column. 
@@ -30,7 +29,8 @@ public abstract class ThresholdArgumentsBase : ScorerArgumentsBase public string ThresholdColumn = MetadataUtils.Const.ScoreValueKind.Score; } - protected sealed class BindingsImpl : BindingsBase + [BestFriend] + private protected sealed class BindingsImpl : BindingsBase { // Column index of the score column in Mapper's schema. public readonly int ScoreColumnIndex; @@ -41,7 +41,7 @@ protected sealed class BindingsImpl : BindingsBase private readonly MetadataUtils.MetadataGetter> _getScoreColumnKind; private readonly MetadataUtils.MetadataGetter> _getScoreValueKind; - private readonly IRow _predColMetadata; + private readonly Schema.Metadata _predColMetadata; private BindingsImpl(Schema input, ISchemaBoundRowMapper mapper, string suffix, string scoreColumnKind, bool user, int scoreColIndex, ColumnType predColType) : base(input, mapper, suffix, user, DefaultColumnNames.PredictedLabel) @@ -59,42 +59,39 @@ private BindingsImpl(Schema input, ISchemaBoundRowMapper mapper, string suffix, // REVIEW: This logic is very specific to multiclass, which is deeply // regrettable, but the class structure as designed and the status of this schema // bearing object makes pushing the logic into the multiclass scorer almost impossible. 
- if (predColType.IsKey) + if (predColType is KeyType predColKeyType && predColKeyType.Count > 0) { - ColumnType scoreSlotsType = mapper.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, scoreColIndex); - if (scoreSlotsType != null && scoreSlotsType.IsKnownSizeVector && - scoreSlotsType.VectorSize == predColType.KeyCount) + var scoreColMetadata = mapper.OutputSchema[scoreColIndex].Metadata; + + var slotColumn = scoreColMetadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames); + if (slotColumn?.Type is VectorType slotColVecType && slotColVecType.Size == predColKeyType.Count) { - Contracts.Assert(scoreSlotsType.VectorSize > 0); - IColumn col = Utils.MarshalInvoke(KeyValueMetadataFromMetadata, - scoreSlotsType.RawType, mapper.Schema, scoreColIndex, MetadataUtils.Kinds.SlotNames); - _predColMetadata = RowColumnUtils.GetRow(null, col); + Contracts.Assert(slotColVecType.Size > 0); + _predColMetadata = Utils.MarshalInvoke(KeyValueMetadataFromMetadata, slotColVecType.RawType, + scoreColMetadata, slotColumn.Value); } else { - scoreSlotsType = mapper.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.TrainingLabelValues, scoreColIndex); - if (scoreSlotsType != null && scoreSlotsType.IsKnownSizeVector && - scoreSlotsType.VectorSize == predColType.KeyCount) + var trainLabelColumn = scoreColMetadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.TrainingLabelValues); + if (trainLabelColumn?.Type is VectorType trainLabelColVecType && trainLabelColVecType.Size == predColKeyType.Count) { - Contracts.Assert(scoreSlotsType.VectorSize > 0); - IColumn col = Utils.MarshalInvoke(KeyValueMetadataFromMetadata, - scoreSlotsType.RawType, mapper.Schema, scoreColIndex, MetadataUtils.Kinds.TrainingLabelValues); - _predColMetadata = RowColumnUtils.GetRow(null, col); + Contracts.Assert(trainLabelColVecType.Size > 0); + _predColMetadata = Utils.MarshalInvoke(KeyValueMetadataFromMetadata, trainLabelColVecType.RawType, + scoreColMetadata, trainLabelColumn.Value); } } } } - private static 
IColumn KeyValueMetadataFromMetadata(ISchema schema, int col, string metadataName) + private static Schema.Metadata KeyValueMetadataFromMetadata(Schema.Metadata meta, Schema.Column metaCol) { - Contracts.AssertValue(schema); - Contracts.Assert(0 <= col && col < schema.ColumnCount); - var type = schema.GetMetadataTypeOrNull(metadataName, col); - Contracts.AssertValue(type); - Contracts.Assert(type.RawType == typeof(T)); - - ValueGetter getter = (ref T val) => schema.GetMetadata(metadataName, col, ref val); - return RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, type, getter); + Contracts.AssertValue(meta); + Contracts.Assert(0 <= metaCol.Index && metaCol.Index < meta.Schema.Count); + Contracts.Assert(metaCol.Type.RawType == typeof(T)); + var getter = meta.GetGetter(metaCol.Index); + var builder = new MetadataBuilder(); + builder.Add(MetadataUtils.Kinds.KeyValues, metaCol.Type, getter); + return builder.GetMetadata(); } public static BindingsImpl Create(Schema input, ISchemaBoundRowMapper mapper, string suffix, @@ -115,7 +112,7 @@ public BindingsImpl ApplyToSchema(Schema input, ISchemaBindableMapper bindable, env.AssertValue(input); env.AssertValue(bindable); - string scoreCol = RowMapper.Schema.GetColumnName(ScoreColumnIndex); + string scoreCol = RowMapper.OutputSchema[ScoreColumnIndex].Name; var schema = new RoleMappedSchema(input, RowMapper.GetInputColumnRoles()); // Checks compatibility of the predictor input types. 
@@ -123,7 +120,7 @@ public BindingsImpl ApplyToSchema(Schema input, ISchemaBindableMapper bindable, var rowMapper = mapper as ISchemaBoundRowMapper; env.CheckParam(rowMapper != null, nameof(bindable), "Mapper must implement ISchemaBoundRowMapper"); int mapperScoreColumn; - bool tmp = rowMapper.Schema.TryGetColumnIndex(scoreCol, out mapperScoreColumn); + bool tmp = rowMapper.OutputSchema.TryGetColumnIndex(scoreCol, out mapperScoreColumn); env.Check(tmp, "Mapper doesn't have expected score column"); return new BindingsImpl(input, rowMapper, Suffix, ScoreColumnKind, true, mapperScoreColumn, PredColType); @@ -153,9 +150,9 @@ public static BindingsImpl Create(ModelLoadContext ctx, Schema input, // Find the score column of the mapper. int scoreColIndex; - env.CheckDecode(mapper.Schema.TryGetColumnIndex(scoreCol, out scoreColIndex)); + env.CheckDecode(mapper.OutputSchema.TryGetColumnIndex(scoreCol, out scoreColIndex)); - var scoreType = mapper.Schema.GetColumnType(scoreColIndex); + var scoreType = mapper.OutputSchema[scoreColIndex].Type; env.CheckDecode(outputTypeMatches(scoreType)); var predColType = getPredColType(scoreType, rowMapper); @@ -172,7 +169,7 @@ public override void Save(ModelSaveContext ctx) // int: id of the column used for deriving the predicted label column SaveBase(ctx); ctx.SaveNonEmptyString(ScoreColumnKind); - ctx.SaveNonEmptyString(RowMapper.Schema.GetColumnName(ScoreColumnIndex)); + ctx.SaveNonEmptyString(RowMapper.OutputSchema[ScoreColumnIndex].Name); } protected override ColumnType GetColumnTypeCore(int iinfo) @@ -195,8 +192,8 @@ protected override IEnumerable> GetMetadataType if (_predColMetadata != null) { var sch = _predColMetadata.Schema; - for (int i = 0; i < sch.ColumnCount; ++i) - yield return new KeyValuePair(sch.GetColumnName(i), sch.GetColumnType(i)); + for (int i = 0; i < sch.Count; ++i) + yield return new KeyValuePair(sch[i].Name, sch[i].Type); } } foreach (var pair in base.GetMetadataTypesCore(iinfo)) @@ -218,7 +215,7 @@ protected 
override ColumnType GetMetadataTypeCore(string kind, int iinfo) { int mcol; if (_predColMetadata.Schema.TryGetColumnIndex(kind, out mcol)) - return _predColMetadata.Schema.GetColumnType(mcol); + return _predColMetadata.Schema[mcol].Type; } return base.GetMetadataTypeCore(kind, iinfo); } @@ -275,15 +272,18 @@ public override Func GetActiveMapperColumns(bool[] active) } } - protected readonly BindingsImpl Bindings; - protected override BindingsBase GetBindings() => Bindings; - public override Schema Schema { get; } + [BestFriend] + private protected readonly BindingsImpl Bindings; + [BestFriend] + private protected sealed override BindingsBase GetBindings() => Bindings; + public override Schema OutputSchema { get; } bool ICanSavePfa.CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true; bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; - protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnvironment env, IDataView data, + [BestFriend] + private protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnvironment env, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema, string registrationName, string scoreColKind, string scoreColName, Func outputTypeMatches, Func getPredColType) : base(env, data, registrationName, Contracts.CheckRef(mapper, nameof(mapper)).Bindable) @@ -298,15 +298,15 @@ protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnvironment en Host.CheckParam(rowMapper != null, nameof(mapper), "mapper should implement " + nameof(ISchemaBoundRowMapper)); int scoreColIndex; - if (!mapper.Schema.TryGetColumnIndex(scoreColName, out scoreColIndex)) + if (!mapper.OutputSchema.TryGetColumnIndex(scoreColName, out scoreColIndex)) throw Host.ExceptParam(nameof(scoreColName), "mapper does not contain a column '{0}'", scoreColName); - var scoreType = mapper.Schema.GetColumnType(scoreColIndex); + var scoreType = mapper.OutputSchema[scoreColIndex].Type; 
Host.Check(outputTypeMatches(scoreType), "Unexpected predictor output type"); var predColType = getPredColType(scoreType, rowMapper); Bindings = BindingsImpl.Create(data.Schema, rowMapper, args.Suffix, scoreColKind, scoreColIndex, predColType); - Schema = Schema.Create(Bindings); + OutputSchema = Bindings.AsSchema; } protected PredictedLabelScorerBase(IHostEnvironment env, PredictedLabelScorerBase transform, @@ -314,10 +314,11 @@ protected PredictedLabelScorerBase(IHostEnvironment env, PredictedLabelScorerBas : base(env, newSource, registrationName, transform.Bindable) { Bindings = transform.Bindings.ApplyToSchema(newSource.Schema, Bindable, env); - Schema = Schema.Create(Bindings); + OutputSchema = Bindings.AsSchema; } - protected PredictedLabelScorerBase(IHost host, ModelLoadContext ctx, IDataView input, + [BestFriend] + private protected PredictedLabelScorerBase(IHost host, ModelLoadContext ctx, IDataView input, Func outputTypeMatches, Func getPredColType) : base(host, ctx, input) { @@ -327,10 +328,10 @@ protected PredictedLabelScorerBase(IHost host, ModelLoadContext ctx, IDataView i Host.AssertValue(getPredColType); Bindings = BindingsImpl.Create(ctx, input.Schema, host, Bindable, outputTypeMatches, getPredColType); - Schema = Schema.Create(Bindings); + OutputSchema = Bindings.AsSchema; } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Host.AssertValue(ctx); Bindings.Save(ctx); @@ -363,7 +364,8 @@ void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx) ctx.DeclareVar(derivedName, predictedLabelExpression); } - protected abstract JToken PredictedLabelPfa(string[] mapperOutputs); + [BestFriend] + private protected abstract JToken PredictedLabelPfa(string[] mapperOutputs); void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx) => SaveAsOnnxCore(ctx); @@ -403,13 +405,13 @@ protected override bool WantParallelCursors(Func predicate) return Bindings.AnyNewColumnsActive(predicate); } - protected override 
Delegate[] GetGetters(IRow output, Func predicate) + protected override Delegate[] GetGetters(Row output, Func predicate) { Host.Assert(Bindings.DerivedColumnCount == 1); Host.AssertValue(output); Host.AssertValue(predicate); - Host.Assert(output.Schema == Bindings.RowMapper.Schema); - Host.Assert(Bindings.InfoCount == output.Schema.ColumnCount + 1); + Host.Assert(output.Schema == Bindings.RowMapper.OutputSchema); + Host.Assert(Bindings.InfoCount == output.Schema.Count + 1); var getters = new Delegate[Bindings.InfoCount]; @@ -435,10 +437,10 @@ protected override Delegate[] GetGetters(IRow output, Func predicate) return getters; } - protected abstract Delegate GetPredictedLabelGetter(IRow output, out Delegate scoreGetter); + protected abstract Delegate GetPredictedLabelGetter(Row output, out Delegate scoreGetter); protected void EnsureCachedPosition(ref long cachedPosition, ref TScore score, - IRow boundRow, ValueGetter scoreGetter) + Row boundRow, ValueGetter scoreGetter) { if (cachedPosition != boundRow.Position) { diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index c6f529429f..c0e7b9c0a9 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -3,11 +3,10 @@ // See the LICENSE file in the project root for more information. 
using System.IO; +using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Model; [assembly: LoadableClass(typeof(BinaryPredictionTransformer>), typeof(BinaryPredictionTransformer), null, typeof(SignatureLoadModel), "", BinaryPredictionTransformer.LoaderSignature)] @@ -27,7 +26,7 @@ [assembly: LoadableClass(typeof(ClusteringPredictionTransformer>>), typeof(ClusteringPredictionTransformer), null, typeof(SignatureLoadModel), "", ClusteringPredictionTransformer.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// @@ -47,7 +46,8 @@ public abstract class PredictionTransformerBase : IPredictionTr protected const string DirModel = "Model"; protected const string DirTransSchema = "TrainSchema"; protected readonly IHost Host; - protected ISchemaBindableMapper BindableMapper; + [BestFriend] + private protected ISchemaBindableMapper BindableMapper; protected Schema TrainSchema; public bool IsRowToRowMapper => true; @@ -178,7 +178,7 @@ public SingleFeaturePredictionTransformerBase(IHost host, TModel model, Schema t else if (!trainSchema.TryGetColumnIndex(featureColumn, out int col)) throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn); else - FeatureColumnType = trainSchema.GetColumnType(col); + FeatureColumnType = trainSchema[col].Type; BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); } @@ -193,7 +193,7 @@ internal SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx else if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col)) throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn); else - FeatureColumnType = TrainSchema.GetColumnType(col); + FeatureColumnType = TrainSchema[col].Type; BindableMapper = 
ScoreUtils.GetSchemaBindableMapper(Host, Model); } @@ -206,8 +206,8 @@ public override Schema GetOutputSchema(Schema inputSchema) { if (!inputSchema.TryGetColumnIndex(FeatureColumn, out int col)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), null); - if (!inputSchema.GetColumnType(col).Equals(FeatureColumnType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), inputSchema.GetColumnType(col).ToString()); + if (!inputSchema[col].Type.Equals(FeatureColumnType)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), inputSchema[col].Type.ToString()); } return Transform(new EmptyDataView(Host, inputSchema)).Schema; @@ -245,7 +245,7 @@ public sealed class AnomalyPredictionTransformer : SingleFeaturePredicti public AnomalyPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn, float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer)), model, inputSchema, featureColumn) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(AnomalyPredictionTransformer)), model, inputSchema, featureColumn) { Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn)); Threshold = threshold; diff --git a/src/Microsoft.ML.Data/Scorers/QuantileRegressionScorer.cs b/src/Microsoft.ML.Data/Scorers/QuantileRegressionScorer.cs index b46ba4042e..388cd236c0 100644 --- a/src/Microsoft.ML.Data/Scorers/QuantileRegressionScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/QuantileRegressionScorer.cs @@ -4,10 +4,10 @@ using System; using System.Linq; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using 
Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Internallearn; [assembly: LoadableClass(typeof(IDataScorerTransform), typeof(QuantileRegressionScorerTransform), typeof(QuantileRegressionScorerTransform.Arguments), typeof(SignatureDataScorer), "Quantile Regression Scorer", "QuantileRegressionScorer", MetadataUtils.Const.ScoreColumnKind.QuantileRegression)] @@ -15,9 +15,9 @@ [assembly: LoadableClass(typeof(ISchemaBindableMapper), typeof(QuantileRegressionScorerTransform), typeof(QuantileRegressionScorerTransform.Arguments), typeof(SignatureBindableMapper), "Quantile Regression Mapper", "QuantileRegressionScorer", MetadataUtils.Const.ScoreColumnKind.QuantileRegression)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { - public static class QuantileRegressionScorerTransform + internal static class QuantileRegressionScorerTransform { public sealed class Arguments : ScorerArgumentsBase { @@ -25,12 +25,18 @@ public sealed class Arguments : ScorerArgumentsBase public string Quantiles = "0,0.25,0.5,0.75,1"; } - public static IDataScorerTransform Create(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) + /// + /// Constructor corresponding to . + /// + private static IDataScorerTransform Create(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) { return new GenericScorer(env, args, data, mapper, trainSchema); } - public static ISchemaBindableMapper Create(IHostEnvironment env, Arguments args, IPredictor predictor) + /// + /// Constructor corresponding to . 
+ /// + private static ISchemaBindableMapper Create(IHostEnvironment env, Arguments args, IPredictor predictor) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(args, nameof(args)); diff --git a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs index e0d7d5e9ac..c3274c5748 100644 --- a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs @@ -6,11 +6,10 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Base class for scoring rows independently. This assumes that all columns produced by the @@ -19,7 +18,8 @@ namespace Microsoft.ML.Runtime.Data /// public abstract class RowToRowScorerBase : RowToRowMapperTransformBase, IDataScorerTransform { - public abstract class BindingsBase : ScorerBindingsBase + [BestFriend] + private protected abstract class BindingsBase : ScorerBindingsBase { public readonly ISchemaBoundRowMapper RowMapper; @@ -30,16 +30,19 @@ protected BindingsBase(Schema schema, ISchemaBoundRowMapper mapper, string suffi } } - protected readonly ISchemaBindableMapper Bindable; + [BestFriend] + private protected readonly ISchemaBindableMapper Bindable; - protected RowToRowScorerBase(IHostEnvironment env, IDataView input, string registrationName, ISchemaBindableMapper bindable) + [BestFriend] + private protected RowToRowScorerBase(IHostEnvironment env, IDataView input, string registrationName, ISchemaBindableMapper bindable) : base(env, registrationName, input) { Contracts.AssertValue(bindable); Bindable = bindable; } - protected RowToRowScorerBase(IHost host, ModelLoadContext ctx, IDataView input) + [BestFriend] + private protected RowToRowScorerBase(IHost host, 
ModelLoadContext ctx, IDataView input) : base(host, input) { ctx.LoadModel(host, out Bindable, "SchemaBindableMapper"); @@ -56,7 +59,8 @@ public sealed override void Save(ModelSaveContext ctx) /// /// The main save method handles saving the _bindable. This should do everything else. /// - protected abstract void SaveCore(ModelSaveContext ctx); + [BestFriend] + private protected abstract void SaveCore(ModelSaveContext ctx); /// /// For the ITransformTemplate implementation. @@ -66,7 +70,8 @@ public sealed override void Save(ModelSaveContext ctx) /// /// Derived classes provide the specific bindings object. /// - protected abstract BindingsBase GetBindings(); + [BestFriend] + private protected abstract BindingsBase GetBindings(); /// /// Produces the set of active columns for the scorer (as a bool[] of length bindings.ColumnCount), @@ -80,7 +85,7 @@ private static bool[] GetActive(BindingsBase bindings, Func predicate Contracts.Assert(active.Length == bindings.ColumnCount); var activeInput = bindings.GetActiveInput(predicate); - Contracts.Assert(activeInput.Length == bindings.Input.ColumnCount); + Contracts.Assert(activeInput.Length == bindings.Input.Count); // Get a predicate that determines which Mapper outputs are active. 
predicateMapper = bindings.GetActiveMapperColumns(active); @@ -114,7 +119,7 @@ private static bool[] GetActive(BindingsBase bindings, Func predicate /// protected abstract bool WantParallelCursors(Func predicate); - protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Contracts.AssertValue(predicate); Contracts.AssertValueOrNull(rand); @@ -124,10 +129,10 @@ protected override IRowCursor GetRowCursorCore(Func predicate, IRando Func predicateMapper; var active = GetActive(bindings, predicate, out predicateInput, out predicateMapper); var input = Source.GetRowCursor(predicateInput, rand); - return new RowCursor(Host, this, input, active, predicateMapper); + return new Cursor(Host, this, input, active, predicateMapper); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, IRandom rand = null) + public override RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); @@ -136,27 +141,28 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid Func predicateInput; Func predicateMapper; var active = GetActive(bindings, predicate, out predicateInput, out predicateMapper); - var inputs = Source.GetRowCursorSet(out consolidator, predicateInput, n, rand); + var inputs = Source.GetRowCursorSet(predicateInput, n, rand); Contracts.AssertNonEmpty(inputs); if (inputs.Length == 1 && n > 1 && WantParallelCursors(predicate) && (Source.GetRowCount() ?? 
int.MaxValue) > n) - inputs = DataViewUtils.CreateSplitCursors(out consolidator, Host, inputs[0], n); + inputs = DataViewUtils.CreateSplitCursors(Host, inputs[0], n); Contracts.AssertNonEmpty(inputs); - var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(Host, this, inputs[i], active, predicateMapper); + cursors[i] = new Cursor(Host, this, inputs[i], active, predicateMapper); return cursors; } - protected override Delegate[] CreateGetters(IRow input, Func active, out Action disp) + protected override Delegate[] CreateGetters(Row input, Func active, out Action disp) { var bindings = GetBindings(); Func predicateInput; Func predicateMapper; GetActive(bindings, active, out predicateInput, out predicateMapper); - var output = bindings.RowMapper.GetRow(input, predicateMapper, out disp); + var output = bindings.RowMapper.GetRow(input, predicateMapper); Func activeInfos = iinfo => active(bindings.MapIinfoToCol(iinfo)); + disp = output.Dispose; return GetGetters(output, activeInfos); } @@ -173,14 +179,14 @@ protected override Func GetDependenciesCore(Func predicate /// Create and fill an array of getters of size InfoCount. The indices of the non-null entries in the /// result should be exactly those for which predicate(iinfo) is true. 
/// - protected abstract Delegate[] GetGetters(IRow output, Func predicate); + protected abstract Delegate[] GetGetters(Row output, Func predicate); - protected static Delegate[] GetGettersFromRow(IRow row, Func predicate) + protected static Delegate[] GetGettersFromRow(Row row, Func predicate) { Contracts.AssertValue(row); Contracts.AssertValue(predicate); - var getters = new Delegate[row.Schema.ColumnCount]; + var getters = new Delegate[row.Schema.Count]; for (int col = 0; col < getters.Length; col++) { if (predicate(col)) @@ -189,22 +195,22 @@ protected static Delegate[] GetGettersFromRow(IRow row, Func predicat return getters; } - protected static Delegate GetGetterFromRow(IRow row, int col) + protected static Delegate GetGetterFromRow(Row row, int col) { Contracts.AssertValue(row); - Contracts.Assert(0 <= col && col < row.Schema.ColumnCount); + Contracts.Assert(0 <= col && col < row.Schema.Count); Contracts.Assert(row.IsColumnActive(col)); - var type = row.Schema.GetColumnType(col); - Func> del = GetGetterFromRow; + var type = row.Schema[col].Type; + Func> del = GetGetterFromRow; var meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.RawType); return (Delegate)meth.Invoke(null, new object[] { row, col }); } - protected static ValueGetter GetGetterFromRow(IRow output, int col) + protected static ValueGetter GetGetterFromRow(Row output, int col) { Contracts.AssertValue(output); - Contracts.Assert(0 <= col && col < output.Schema.ColumnCount); + Contracts.Assert(0 <= col && col < output.Schema.Count); Contracts.Assert(output.IsColumnActive(col)); return output.GetGetter(col); } @@ -215,16 +221,17 @@ protected override int MapColumnIndex(out bool isSrc, int col) return bindings.MapColumnIndex(out isSrc, col); } - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly BindingsBase _bindings; private readonly bool[] _active; private readonly Delegate[] 
_getters; - private readonly Action _disposer; + private readonly Row _output; + private bool _disposed; - public Schema Schema { get; } + public override Schema Schema { get; } - public RowCursor(IChannelProvider provider, RowToRowScorerBase parent, IRowCursor input, bool[] active, Func predicateMapper) + public Cursor(IChannelProvider provider, RowToRowScorerBase parent, RowCursor input, bool[] active, Func predicateMapper) : base(provider, input) { Ch.AssertValue(parent); @@ -232,36 +239,40 @@ public RowCursor(IChannelProvider provider, RowToRowScorerBase parent, IRowCurso Ch.AssertValue(predicateMapper); _bindings = parent.GetBindings(); - Schema = parent.Schema; + Schema = parent.OutputSchema; Ch.Assert(active.Length == _bindings.ColumnCount); _active = active; - var output = _bindings.RowMapper.GetRow(input, predicateMapper, out _disposer); + _output = _bindings.RowMapper.GetRow(input, predicateMapper); try { - Ch.Assert(output.Schema == _bindings.RowMapper.Schema); - _getters = parent.GetGetters(output, iinfo => active[_bindings.MapIinfoToCol(iinfo)]); + Ch.Assert(_output.Schema == _bindings.RowMapper.OutputSchema); + _getters = parent.GetGetters(_output, iinfo => active[_bindings.MapIinfoToCol(iinfo)]); } catch (Exception) { - _disposer?.Invoke(); + _output.Dispose(); throw; } } - public override void Dispose() + protected override void Dispose(bool disposing) { - _disposer?.Invoke(); - base.Dispose(); + if (_disposed) + return; + if (disposing) + _output.Dispose(); + _disposed = true; + base.Dispose(disposing); } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < _bindings.ColumnCount); return _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); @@ -290,11 +301,12 @@ public abstract class ScorerArgumentsBase } /// - /// Base bindings for a scorer based on an ISchemaBoundMapper. 
This assumes that input schema columns + /// Base bindings for a scorer based on an . This assumes that input schema columns /// are echoed, followed by zero or more derived columns, followed by the mapper generated columns. /// The names of the derived columns and mapper generated columns have an optional suffix appended. /// - public abstract class ScorerBindingsBase : ColumnBindingsBase + [BestFriend] + internal abstract class ScorerBindingsBase : ColumnBindingsBase { /// /// The schema bound mapper. @@ -337,14 +349,14 @@ private static string[] GetOutputNames(ISchemaBoundMapper mapper, string suffix, Contracts.AssertValueOrNull(suffix); Contracts.AssertValue(namesDerived); - var schema = mapper.Schema; - int count = namesDerived.Length + schema.ColumnCount; + var schema = mapper.OutputSchema; + int count = namesDerived.Length + schema.Count; var res = new string[count]; int dst = 0; for (int i = 0; i < namesDerived.Length; i++) res[dst++] = namesDerived[i] + suffix; - for (int i = 0; i < schema.ColumnCount; i++) - res[dst++] = schema.GetColumnName(i) + suffix; + for (int i = 0; i < schema.Count; i++) + res[dst++] = schema[i].Name + suffix; Contracts.Assert(dst == count); return res; } @@ -400,7 +412,7 @@ protected void SaveBase(ModelSaveContext ctx) protected override ColumnType GetColumnTypeCore(int iinfo) { Contracts.Assert(DerivedColumnCount <= iinfo && iinfo < InfoCount); - return Mapper.Schema.GetColumnType(iinfo - DerivedColumnCount); + return Mapper.OutputSchema[iinfo - DerivedColumnCount].Type; } protected override IEnumerable> GetMetadataTypesCore(int iinfo) @@ -410,7 +422,7 @@ protected override IEnumerable> GetMetadataType yield return MetadataUtils.ScoreColumnSetIdType.GetPair(MetadataUtils.Kinds.ScoreColumnSetId); if (iinfo < DerivedColumnCount) yield break; - foreach (var pair in Mapper.Schema.GetMetadataTypes(iinfo - DerivedColumnCount)) + foreach (var pair in Mapper.OutputSchema[iinfo - DerivedColumnCount].Metadata.Schema.Select(c => new 
KeyValuePair(c.Name, c.Type))) yield return pair; } @@ -421,7 +433,7 @@ protected override ColumnType GetMetadataTypeCore(string kind, int iinfo) return MetadataUtils.ScoreColumnSetIdType; if (iinfo < DerivedColumnCount) return null; - return Mapper.Schema.GetMetadataTypeOrNull(kind, iinfo - DerivedColumnCount); + return Mapper.OutputSchema[iinfo - DerivedColumnCount].Metadata.Schema.GetColumnOrNull(kind)?.Type; } protected override void GetMetadataCore(string kind, int iinfo, ref TValue value) @@ -435,7 +447,7 @@ protected override void GetMetadataCore(string kind, int iinfo, ref TVal default: if (iinfo < DerivedColumnCount) throw MetadataUtils.ExceptGetMetadata(); - Mapper.Schema.GetMetadata(kind, iinfo - DerivedColumnCount, ref value); + Mapper.OutputSchema[iinfo - DerivedColumnCount].Metadata.GetValue(kind, ref value); break; } } @@ -452,8 +464,8 @@ public virtual Func GetActiveMapperColumns(bool[] active) return col => { - Contracts.Assert(0 <= col && col < Mapper.Schema.ColumnCount); - return 0 <= col && col < Mapper.Schema.ColumnCount && + Contracts.Assert(0 <= col && col < Mapper.OutputSchema.Count); + return 0 <= col && col < Mapper.OutputSchema.Count && active[MapIinfoToCol(col + DerivedColumnCount)]; }; } diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index a2fdf96af4..2bb11d46ce 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -4,19 +4,18 @@ #pragma warning disable 420 // volatile with Interlocked.CompareExchange -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Model.Pfa; -using Newtonsoft.Json.Linq; using System; using 
System.Collections.Generic; using System.IO; using System.Reflection; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Model.Pfa; +using Newtonsoft.Json.Linq; using Float = System.Single; [assembly: LoadableClass(typeof(SchemaBindablePredictorWrapper), null, typeof(SignatureLoadModel), @@ -28,20 +27,20 @@ [assembly: LoadableClass(typeof(SchemaBindableBinaryPredictorWrapper), null, typeof(SignatureLoadModel), "Binary Classification Bindable Mapper", SchemaBindableBinaryPredictorWrapper.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { // REVIEW: Consider implementing ICanSaveAs(Code/Text/etc.) for these classes as well. /// /// This is a base class for wrapping s in an . /// - public abstract class SchemaBindablePredictorWrapperBase : ISchemaBindableMapper, ICanSaveModel, ICanSaveSummary, + internal abstract class SchemaBindablePredictorWrapperBase : ISchemaBindableMapper, ICanSaveModel, ICanSaveSummary, IBindableCanSavePfa, IBindableCanSaveOnnx { // The ctor guarantees that Predictor is non-null. It also ensures that either // ValueMapper or FloatPredictor is non-null (or both). With these guarantees, // the score value type (_scoreType) can be determined. 
protected readonly IPredictor Predictor; - protected readonly IValueMapper ValueMapper; + private protected readonly IValueMapper ValueMapper; protected readonly ColumnType ScoreType; bool ICanSavePfa.CanSavePfa => (ValueMapper as ICanSavePfa)?.CanSavePfa == true; @@ -115,17 +114,16 @@ bool IBindableCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, s [BestFriend] private protected virtual bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) => false; - public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) + ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema) { Contracts.CheckValue(env, nameof(env)); using (var ch = env.Register("SchemaBindableWrapper").Start("Bind")) { ch.CheckValue(schema, nameof(schema)); - if (schema.Feature != null) + if (schema.Feature?.Type is ColumnType type) { // Ensure that the feature column type is compatible with the needed input type. - var type = schema.Feature.Type; var typeIn = ValueMapper != null ? 
ValueMapper.InputType : new VectorType(NumberType.Float); if (type != typeIn) { @@ -142,20 +140,21 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) } } - protected abstract ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema); + [BestFriend] + private protected abstract ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema); - protected virtual Delegate GetPredictionGetter(IRow input, int colSrc) + protected virtual Delegate GetPredictionGetter(Row input, int colSrc) { Contracts.AssertValue(input); - Contracts.Assert(0 <= colSrc && colSrc < input.Schema.ColumnCount); + Contracts.Assert(0 <= colSrc && colSrc < input.Schema.Count); - var typeSrc = input.Schema.GetColumnType(colSrc); - Func> del = GetValueGetter; + var typeSrc = input.Schema[colSrc].Type; + Func> del = GetValueGetter; var meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeSrc.RawType, ScoreType.RawType); return (Delegate)meth.Invoke(this, new object[] { input, colSrc }); } - private ValueGetter GetValueGetter(IRow input, int colSrc) + private ValueGetter GetValueGetter(Row input, int colSrc) { Contracts.AssertValue(input); Contracts.Assert(ValueMapper != null); @@ -171,7 +170,7 @@ private ValueGetter GetValueGetter(IRow input, int colSrc) }; } - public void SaveSummary(TextWriter writer, RoleMappedSchema schema) + void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) { var summarySaver = Predictor as ICanSaveSummary; if (summarySaver == null) @@ -191,48 +190,47 @@ protected sealed class SingleValueRowMapper : ISchemaBoundRowMapper private readonly SchemaBindablePredictorWrapperBase _parent; public RoleMappedSchema InputRoleMappedSchema { get; } - public Schema Schema { get; } + public Schema OutputSchema { get; } public ISchemaBindableMapper Bindable => _parent; public SingleValueRowMapper(RoleMappedSchema schema, SchemaBindablePredictorWrapperBase parent, Schema outputSchema) { 
Contracts.AssertValue(schema); Contracts.AssertValue(parent); - Contracts.AssertValue(schema.Feature); - Contracts.Assert(outputSchema.ColumnCount == 1); + Contracts.Assert(schema.Feature.HasValue); + Contracts.Assert(outputSchema.Count == 1); _parent = parent; InputRoleMappedSchema = schema; - Schema = outputSchema; + OutputSchema = outputSchema; } public Func GetDependencies(Func predicate) { - for (int i = 0; i < Schema.ColumnCount; i++) + for (int i = 0; i < OutputSchema.Count; i++) { if (predicate(i)) - return col => col == InputRoleMappedSchema.Feature.Index; + return col => col == InputRoleMappedSchema.Feature.Value.Index; } return col => false; } public IEnumerable> GetInputColumnRoles() { - yield return RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature.Name); + yield return RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature.Value.Name); } public Schema InputSchema => InputRoleMappedSchema.Schema; - public IRow GetRow(IRow input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate) { Contracts.AssertValue(input); Contracts.AssertValue(predicate); var getters = new Delegate[1]; if (predicate(0)) - getters[0] = _parent.GetPredictionGetter(input, InputRoleMappedSchema.Feature.Index); - disposer = null; - return new SimpleRow(Schema, input, getters); + getters[0] = _parent.GetPredictionGetter(input, InputRoleMappedSchema.Feature.Value.Index); + return new SimpleRow(OutputSchema, input, getters); } } } @@ -241,7 +239,7 @@ public IRow GetRow(IRow input, Func predicate, out Action disposer) /// This class is a wrapper for all s except for quantile regression predictors, /// and calibrated binary classification predictors. 
/// - public sealed class SchemaBindablePredictorWrapper : SchemaBindablePredictorWrapperBase + internal sealed class SchemaBindablePredictorWrapper : SchemaBindablePredictorWrapperBase { public const string LoaderSignature = "SchemaBindableWrapper"; private static VersionInfo GetVersionInfo() @@ -290,11 +288,11 @@ private protected override void SaveAsPfaCore(BoundPfaContext ctx, RoleMappedSch Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); Contracts.Assert(ValueMapper is ISingleCanSavePfa); - Contracts.AssertValue(schema.Feature); + Contracts.Assert(schema.Feature.HasValue); Contracts.Assert(Utils.Size(outputNames) == 1); // Score. var mapper = (ISingleCanSavePfa)ValueMapper; // If the features column was not produced, we must hide the outputs. - var featureToken = ctx.TokenOrNullForName(schema.Feature.Name); + var featureToken = ctx.TokenOrNullForName(schema.Feature.Value.Name); if (featureToken == null) ctx.Hide(outputNames); var scoreToken = mapper.SaveAsPfa(ctx, featureToken); @@ -306,18 +304,17 @@ private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); Contracts.Assert(ValueMapper is ISingleCanSaveOnnx); - Contracts.AssertValue(schema.Feature); + Contracts.Assert(schema.Feature.HasValue); Contracts.Assert(Utils.Size(outputNames) <= 2); // PredictedLabel and/or Score. 
var mapper = (ISingleCanSaveOnnx)ValueMapper; - if (!ctx.ContainsColumn(schema.Feature.Name)) + string featName = schema.Feature.Value.Name; + if (!ctx.ContainsColumn(featName)) return false; - - Contracts.Assert(ctx.ContainsColumn(schema.Feature.Name)); - - return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(schema.Feature.Name)); + Contracts.Assert(ctx.ContainsColumn(featName)); + return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(featName)); } - protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) + private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) { var outputSchema = Schema.Create(new ScoreMapperSchema(ScoreType, _scoreColumnKind)); return new SingleValueRowMapper(schema, this, outputSchema); @@ -353,7 +350,7 @@ private static string GetScoreColumnKind(IPredictor predictor) /// This is an wrapper for calibrated binary classification predictors. /// They need a separate wrapper because they return two values instead of one: the raw score and the probability. /// - public sealed class SchemaBindableBinaryPredictorWrapper : SchemaBindablePredictorWrapperBase + internal sealed class SchemaBindableBinaryPredictorWrapper : SchemaBindablePredictorWrapperBase { public const string LoaderSignature = "BinarySchemaBindable"; private static VersionInfo GetVersionInfo() @@ -401,11 +398,11 @@ private protected override void SaveAsPfaCore(BoundPfaContext ctx, RoleMappedSch Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); Contracts.Assert(ValueMapper is IDistCanSavePfa); - Contracts.AssertValue(schema.Feature); + Contracts.Assert(schema.Feature.HasValue); Contracts.Assert(Utils.Size(outputNames) == 2); // Score and prob. var mapper = (IDistCanSavePfa)ValueMapper; // If the features column was not produced, we must hide the outputs. 
- string featureToken = ctx.TokenOrNullForName(schema.Feature.Name); + string featureToken = ctx.TokenOrNullForName(schema.Feature.Value.Name); if (featureToken == null) ctx.Hide(outputNames); @@ -423,15 +420,14 @@ private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema var mapper = ValueMapper as ISingleCanSaveOnnx; Contracts.CheckValue(mapper, nameof(mapper)); - Contracts.AssertValue(schema.Feature); + Contracts.Assert(schema.Feature.HasValue); Contracts.Assert(Utils.Size(outputNames) == 3); // Predicted Label, Score and Probablity. - if (!ctx.ContainsColumn(schema.Feature.Name)) + var featName = schema.Feature.Value.Name; + if (!ctx.ContainsColumn(featName)) return false; - - Contracts.Assert(ctx.ContainsColumn(schema.Feature.Name)); - - return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(schema.Feature.Name)); + Contracts.Assert(ctx.ContainsColumn(featName)); + return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(featName)); } private void CheckValid(out IValueMapperDist distMapper) @@ -451,7 +447,7 @@ private void CheckValid(out IValueMapperDist distMapper) "Invalid probability type for the IValueMapperDist"); } - protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) + private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) { if (Predictor.PredictionKind != PredictionKind.BinaryClassification) ch.Warning("Scoring predictor of kind '{0}' as '{1}'.", Predictor.PredictionKind, PredictionKind.BinaryClassification); @@ -472,7 +468,8 @@ private sealed class CalibratedRowMapper : ISchemaBoundRowMapper public RoleMappedSchema InputRoleMappedSchema { get; } public Schema InputSchema => InputRoleMappedSchema.Schema; - public Schema Schema { get; } + public Schema OutputSchema { get; } + public ISchemaBindableMapper Bindable => _parent; public CalibratedRowMapper(RoleMappedSchema schema, SchemaBindableBinaryPredictorWrapper parent) @@ -480,15 +477,13 @@ 
public CalibratedRowMapper(RoleMappedSchema schema, SchemaBindableBinaryPredicto Contracts.AssertValue(parent); Contracts.Assert(parent._distMapper != null); Contracts.AssertValue(schema); - Contracts.AssertValueOrNull(schema.Feature); _parent = parent; InputRoleMappedSchema = schema; - Schema = Schema.Create(new BinaryClassifierSchema()); + OutputSchema = Schema.Create(new BinaryClassifierSchema()); - if (schema.Feature != null) + if (schema.Feature?.Type is ColumnType typeSrc) { - var typeSrc = InputRoleMappedSchema.Feature.Type; Contracts.Check(typeSrc.IsKnownSizeVector && typeSrc.ItemType == NumberType.Float, "Invalid feature column type"); } @@ -496,10 +491,10 @@ public CalibratedRowMapper(RoleMappedSchema schema, SchemaBindableBinaryPredicto public Func GetDependencies(Func predicate) { - for (int i = 0; i < Schema.ColumnCount; i++) + for (int i = 0; i < OutputSchema.Count; i++) { - if (predicate(i) && InputRoleMappedSchema.Feature != null) - return col => col == InputRoleMappedSchema.Feature.Index; + if (predicate(i) && InputRoleMappedSchema.Feature?.Index is int idx) + return col => col == idx; } return col => false; } @@ -509,7 +504,7 @@ public Func GetDependencies(Func predicate) yield return RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature?.Name); } - private Delegate[] CreateGetters(IRow input, bool[] active) + private Delegate[] CreateGetters(Row input, bool[] active) { Contracts.Assert(Utils.Size(active) == 2); Contracts.Assert(_parent._distMapper != null); @@ -518,7 +513,7 @@ private Delegate[] CreateGetters(IRow input, bool[] active) if (active[0] || active[1]) { // Put all captured locals at this scope. - var featureGetter = InputRoleMappedSchema.Feature != null ? input.GetGetter>(InputRoleMappedSchema.Feature.Index) : null; + var featureGetter = InputRoleMappedSchema.Feature?.Index is int idx ? 
input.GetGetter>(idx) : null; Float prob = 0; Float score = 0; long cachedPosition = -1; @@ -552,7 +547,7 @@ private Delegate[] CreateGetters(IRow input, bool[] active) private static void EnsureCachedResultValueMapper(ValueMapper, Float, Float> mapper, ref long cachedPosition, ValueGetter> featureGetter, ref VBuffer features, - ref Float score, ref Float prob, IRow input) + ref Float score, ref Float prob, Row input) { Contracts.AssertValue(mapper); if (cachedPosition != input.Position) @@ -565,22 +560,22 @@ private static void EnsureCachedResultValueMapper(ValueMapper, Fl } } - public IRow GetRow(IRow input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate) { Contracts.AssertValue(input); - var active = Utils.BuildArray(Schema.ColumnCount, predicate); + var active = Utils.BuildArray(OutputSchema.Count, predicate); var getters = CreateGetters(input, active); - disposer = null; - return new SimpleRow(Schema, input, getters); + return new SimpleRow(OutputSchema, input, getters); } } } /// /// This is an wrapper for quantile regression predictors. They need a separate - /// wrapper because they need the quantiles to create the ISchemaBound. + /// wrapper because they need the quantiles to create the . 
/// - public sealed class SchemaBindableQuantileRegressionPredictor : SchemaBindablePredictorWrapperBase + [BestFriend] + internal sealed class SchemaBindableQuantileRegressionPredictor : SchemaBindablePredictorWrapperBase { public const string LoaderSignature = "QuantileSchemaBindable"; private static VersionInfo GetVersionInfo() @@ -651,17 +646,17 @@ public static SchemaBindableQuantileRegressionPredictor Create(IHostEnvironment return new SchemaBindableQuantileRegressionPredictor(env, ctx); } - protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) + private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) { return new SingleValueRowMapper(schema, this, Schema.Create(new SchemaImpl(ScoreType, _quantiles))); } - protected override Delegate GetPredictionGetter(IRow input, int colSrc) + protected override Delegate GetPredictionGetter(Row input, int colSrc) { Contracts.AssertValue(input); - Contracts.Assert(0 <= colSrc && colSrc < input.Schema.ColumnCount); + Contracts.Assert(0 <= colSrc && colSrc < input.Schema.Count); - var typeSrc = input.Schema.GetColumnType(colSrc); + var typeSrc = input.Schema[colSrc].Type; Contracts.Assert(typeSrc.IsVector && typeSrc.ItemType == NumberType.Float); Contracts.Assert(ValueMapper == null || typeSrc.VectorSize == ValueMapper.InputType.VectorSize || ValueMapper.InputType.VectorSize == 0); diff --git a/src/Microsoft.ML.Data/Scorers/ScoreMapperSchema.cs b/src/Microsoft.ML.Data/Scorers/ScoreMapperSchema.cs index 68cbd34333..412b11a57e 100644 --- a/src/Microsoft.ML.Data/Scorers/ScoreMapperSchema.cs +++ b/src/Microsoft.ML.Data/Scorers/ScoreMapperSchema.cs @@ -2,11 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Float = System.Single; -using System.Collections.Generic; using System; +using System.Collections.Generic; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// A base class for schemas for ISchemaBoundMappers. Takes care of all the metadata that has to do with diff --git a/src/Microsoft.ML.Data/StaticPipe/DataLoadSaveOperationsExtensions.cs b/src/Microsoft.ML.Data/StaticPipe/DataLoadSaveOperationsExtensions.cs index 57adb1be4d..db30ecd454 100644 --- a/src/Microsoft.ML.Data/StaticPipe/DataLoadSaveOperationsExtensions.cs +++ b/src/Microsoft.ML.Data/StaticPipe/DataLoadSaveOperationsExtensions.cs @@ -2,17 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.StaticPipe; using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Text; -using static Microsoft.ML.Runtime.Data.TextLoader; +using Microsoft.ML.Data; +using static Microsoft.ML.Data.TextLoader; namespace Microsoft.ML.StaticPipe { @@ -40,10 +32,10 @@ public static class DataLoadSaveOperationsExtensions /// Whether the input may include sparse representations. /// Remove trailing whitespace from lines. /// A configured statically-typed reader for text files. 
- public static DataReader TextReader<[IsShape] TShape>( + public static DataReader CreateTextReader<[IsShape] TShape>( this DataOperations catalog, Func func, IMultiStreamSource files = null, bool hasHeader = false, char separator = '\t', bool allowQuoting = true, bool allowSparse = true, bool trimWhitspace = false) - => TextLoader.CreateReader(catalog.Environment, func, files, hasHeader, separator, allowQuoting, allowSparse, trimWhitspace); + => CreateReader(catalog.Environment, func, files, hasHeader, separator, allowQuoting, allowSparse, trimWhitspace); } } diff --git a/src/Microsoft.ML.Data/StaticPipe/DataReader.cs b/src/Microsoft.ML.Data/StaticPipe/DataReader.cs index b17004c20d..f35e430eca 100644 --- a/src/Microsoft.ML.Data/StaticPipe/DataReader.cs +++ b/src/Microsoft.ML.Data/StaticPipe/DataReader.cs @@ -4,8 +4,6 @@ using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; using Microsoft.ML.StaticPipe.Runtime; namespace Microsoft.ML.StaticPipe diff --git a/src/Microsoft.ML.Data/StaticPipe/DataReaderEstimator.cs b/src/Microsoft.ML.Data/StaticPipe/DataReaderEstimator.cs index 540389414d..8639e3e785 100644 --- a/src/Microsoft.ML.Data/StaticPipe/DataReaderEstimator.cs +++ b/src/Microsoft.ML.Data/StaticPipe/DataReaderEstimator.cs @@ -4,8 +4,6 @@ using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; using Microsoft.ML.StaticPipe.Runtime; namespace Microsoft.ML.StaticPipe diff --git a/src/Microsoft.ML.Data/StaticPipe/DataView.cs b/src/Microsoft.ML.Data/StaticPipe/DataView.cs index 153623c5cc..e8dbaae416 100644 --- a/src/Microsoft.ML.Data/StaticPipe/DataView.cs +++ b/src/Microsoft.ML.Data/StaticPipe/DataView.cs @@ -2,12 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using System; +using System.Collections.Generic; +using System.Linq; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; using Microsoft.ML.StaticPipe.Runtime; -using System.Collections.Generic; -using System; namespace Microsoft.ML.StaticPipe { @@ -23,6 +22,19 @@ internal DataView(IHostEnvironment env, IDataView view, StaticSchemaShape shape) AsDynamic = view; Shape.Check(Env, AsDynamic.Schema); } + + /// + /// This function return a whose columns are all cached in memory. + /// This returned is almost the same to the source . + /// The only difference are cache-related properties. + /// + public DataView Cache() + { + // Generate all column indexes in the source data. + var prefetched = Enumerable.Range(0, AsDynamic.Schema.Count).ToArray(); + // Create a cached version of the source data by caching all columns. + return new DataView(Env, new CacheDataView(Env, AsDynamic, prefetched), Shape); + } } public static class DataViewExtensions diff --git a/src/Microsoft.ML.Data/StaticPipe/Estimator.cs b/src/Microsoft.ML.Data/StaticPipe/Estimator.cs index 1575a0a2f1..4c8b81cff9 100644 --- a/src/Microsoft.ML.Data/StaticPipe/Estimator.cs +++ b/src/Microsoft.ML.Data/StaticPipe/Estimator.cs @@ -5,9 +5,6 @@ using System; using System.Collections.Generic; using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; using Microsoft.ML.StaticPipe.Runtime; namespace Microsoft.ML.StaticPipe @@ -77,5 +74,14 @@ string NameMap(PipelineColumn col) return new Estimator(Env, est, _inShape, newOut); } } + + /// + /// Cache data produced in memory by this estimator. It may append an extra estimator to the this estimator + /// for caching. The newly added estimator would be returned. 
+ /// + public Estimator AppendCacheCheckpoint() + { + return new Estimator(Env, AsDynamic.AppendCacheCheckpoint(Env), _inShape, Shape); + } } } diff --git a/src/Microsoft.ML.Data/StaticPipe/PipelineColumn.cs b/src/Microsoft.ML.Data/StaticPipe/PipelineColumn.cs index 8b68e788e8..66686b94f1 100644 --- a/src/Microsoft.ML.Data/StaticPipe/PipelineColumn.cs +++ b/src/Microsoft.ML.Data/StaticPipe/PipelineColumn.cs @@ -3,8 +3,7 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data; using Microsoft.ML.StaticPipe.Runtime; namespace Microsoft.ML.StaticPipe diff --git a/src/Microsoft.ML.Data/StaticPipe/Reconciler.cs b/src/Microsoft.ML.Data/StaticPipe/Reconciler.cs index f60c4d5327..c249def102 100644 --- a/src/Microsoft.ML.Data/StaticPipe/Reconciler.cs +++ b/src/Microsoft.ML.Data/StaticPipe/Reconciler.cs @@ -2,11 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Transforms; using System; using System.Collections.Generic; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Transforms; namespace Microsoft.ML.StaticPipe.Runtime { diff --git a/src/Microsoft.ML.Data/StaticPipe/SchemaAssertionContext.cs b/src/Microsoft.ML.Data/StaticPipe/SchemaAssertionContext.cs index 4375edc75a..eab9600244 100644 --- a/src/Microsoft.ML.Data/StaticPipe/SchemaAssertionContext.cs +++ b/src/Microsoft.ML.Data/StaticPipe/SchemaAssertionContext.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data; namespace Microsoft.ML.StaticPipe.Runtime { diff --git a/src/Microsoft.ML.Data/StaticPipe/SchemaBearing.cs b/src/Microsoft.ML.Data/StaticPipe/SchemaBearing.cs index d4010dd28d..2c439e16a9 100644 --- a/src/Microsoft.ML.Data/StaticPipe/SchemaBearing.cs +++ b/src/Microsoft.ML.Data/StaticPipe/SchemaBearing.cs @@ -2,11 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Threading; using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data; using Microsoft.ML.StaticPipe.Runtime; -using System.Threading; namespace Microsoft.ML.StaticPipe { diff --git a/src/Microsoft.ML.Data/StaticPipe/StaticPipeExtensions.cs b/src/Microsoft.ML.Data/StaticPipe/StaticPipeExtensions.cs index d0e57b93da..13b571825c 100644 --- a/src/Microsoft.ML.Data/StaticPipe/StaticPipeExtensions.cs +++ b/src/Microsoft.ML.Data/StaticPipe/StaticPipeExtensions.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime.Data; using System; -using Microsoft.ML.Runtime; using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; using Microsoft.ML.StaticPipe.Runtime; namespace Microsoft.ML.StaticPipe diff --git a/src/Microsoft.ML.Data/StaticPipe/StaticPipeInternalUtils.cs b/src/Microsoft.ML.Data/StaticPipe/StaticPipeInternalUtils.cs index b5a4fbecf1..0eae14d2f6 100644 --- a/src/Microsoft.ML.Data/StaticPipe/StaticPipeInternalUtils.cs +++ b/src/Microsoft.ML.Data/StaticPipe/StaticPipeInternalUtils.cs @@ -8,10 +8,7 @@ using System.Reflection; using System.Runtime.CompilerServices; using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.StaticPipe.Runtime { diff --git a/src/Microsoft.ML.Data/StaticPipe/StaticPipeUtils.cs b/src/Microsoft.ML.Data/StaticPipe/StaticPipeUtils.cs index fa5cb7d477..ec893a0589 100644 --- a/src/Microsoft.ML.Data/StaticPipe/StaticPipeUtils.cs +++ b/src/Microsoft.ML.Data/StaticPipe/StaticPipeUtils.cs @@ -6,10 +6,8 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; +using Microsoft.ML; using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; using Microsoft.ML.Transforms; namespace Microsoft.ML.StaticPipe.Runtime diff --git a/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs b/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs index 08d6936721..f162693e1d 100644 --- a/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs +++ b/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs @@ -6,8 +6,7 @@ using System.Collections.Generic; using System.Reflection; using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data; namespace Microsoft.ML.StaticPipe.Runtime { @@ -51,7 +50,7 @@ public static StaticSchemaShape 
Make(ParameterInfo info) /// /// The context on which to throw exceptions /// The schema to check - public void Check(IExceptionContext ectx, ISchema schema) + public void Check(IExceptionContext ectx, Schema schema) { Contracts.AssertValue(ectx); ectx.AssertValue(schema); @@ -60,7 +59,7 @@ public void Check(IExceptionContext ectx, ISchema schema) { if (!schema.TryGetColumnIndex(pair.Key, out int colIdx)) throw ectx.ExceptParam(nameof(schema), $"Column named '{pair.Key}' was not found"); - var col = RowColumnUtils.GetColumn(schema, colIdx); + var col = schema[colIdx]; var type = GetTypeOrNull(col); if ((type != null && !pair.Value.IsAssignableFromStaticPipeline(type)) || (type == null && IsStandard(ectx, pair.Value))) { @@ -112,7 +111,7 @@ public void Check(IExceptionContext ectx, SchemaShape shape) private static Type GetTypeOrNull(SchemaShape.Column col) { - Contracts.AssertValue(col); + Contracts.Assert(col.IsValid); Type vecType = null; @@ -234,9 +233,8 @@ private static bool IsStandardCore(Type t) /// The column /// The .NET type for the static pipelines that should be used to reflect this type, given /// both the characteristics of the as well as one or two crucial pieces of metadata - private static Type GetTypeOrNull(IColumn col) + private static Type GetTypeOrNull(Schema.Column col) { - Contracts.AssertValue(col); var t = col.Type; Type vecType = null; @@ -251,7 +249,7 @@ private static Type GetTypeOrNull(IColumn col) var meta = col.Metadata; if (meta.Schema.TryGetColumnIndex(MetadataUtils.Kinds.IsNormalized, out int normcol)) { - var normtype = meta.Schema.GetColumnType(normcol); + var normtype = meta.Schema[normcol].Type; if (normtype == BoolType.Instance) { bool val = default; @@ -278,13 +276,14 @@ private static Type GetTypeOrNull(IColumn col) { // Check to see if we have key value metadata of the appropriate type, size, and whatnot. 
var meta = col.Metadata; - if (meta.Schema.TryGetColumnIndex(MetadataUtils.Kinds.KeyValues, out int kvcol)) + if (meta.Schema.TryGetColumnIndex(MetadataUtils.Kinds.KeyValues, out int kvcolIndex)) { - var kvType = meta.Schema.GetColumnType(kvcol); - if (kvType.VectorSize == kt.Count) + var kvcol = meta.Schema[kvcolIndex]; + var kvType = kvcol.Type; + if (kvType is VectorType kvVecType && kvVecType.Size == kt.Count) { Contracts.Assert(kt.Count > 0); - var subtype = GetTypeOrNull(RowColumnUtils.GetColumn(meta, kvcol)); + var subtype = GetTypeOrNull(kvcol); if (subtype != null && subtype.IsGenericType) { var sgtype = subtype.GetGenericTypeDefinition(); @@ -343,7 +342,7 @@ private static Type StaticKind(DataKind kind) case DataKind.U2: return typeof(ushort); case DataKind.U4: return typeof(uint); case DataKind.U8: return typeof(ulong); - case DataKind.U16: return typeof(UInt128); + case DataKind.U16: return typeof(RowId); case DataKind.R4: return typeof(float); case DataKind.R8: return typeof(double); diff --git a/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs b/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs index 49e559e78c..15753abedf 100644 --- a/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs +++ b/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs @@ -2,14 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using System.Collections.Generic; +using System.Linq; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Transforms; -using System.Collections.Generic; -using System.Linq; namespace Microsoft.ML.StaticPipe.Runtime { diff --git a/src/Microsoft.ML.Data/StaticPipe/Transformer.cs b/src/Microsoft.ML.Data/StaticPipe/Transformer.cs index e3cbff72d2..472d4a60a7 100644 --- a/src/Microsoft.ML.Data/StaticPipe/Transformer.cs +++ b/src/Microsoft.ML.Data/StaticPipe/Transformer.cs @@ -4,7 +4,6 @@ using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; using Microsoft.ML.StaticPipe.Runtime; namespace Microsoft.ML.StaticPipe diff --git a/src/Microsoft.ML.Data/TrainContext.cs b/src/Microsoft.ML.Data/TrainContext.cs index 9418efae89..b0deb9b24b 100644 --- a/src/Microsoft.ML.Data/TrainContext.cs +++ b/src/Microsoft.ML.Data/TrainContext.cs @@ -2,14 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Transforms; -using Microsoft.ML.Transforms.Conversions; using System; using System.Collections.Generic; using System.Linq; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.Transforms; +using Microsoft.ML.Transforms.Conversions; namespace Microsoft.ML { @@ -31,18 +30,20 @@ public abstract class TrainContextBase /// /// The dataset to split. /// The fraction of data to go into the test set. - /// Optional stratification column. - /// If two examples share the same value of the (if provided), - /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from - /// train to the test set. 
+ /// Optional name of the column to use as a stratification column. If two examples share the same value of the + /// (if provided), they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from train to the test set. + /// If this optional parameter is not provided, a stratification columns will be generated, and its values will be random numbers . + /// Optional parameter used in combination with the . + /// If the is not provided, the random numbers generated to create it, will use this seed as value. + /// And if it is not provided, the default value will be used. /// A pair of datasets, for the train and test set. - public (IDataView trainSet, IDataView testSet) TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null) + public (IDataView trainSet, IDataView testSet) TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null, uint? seed = null) { Host.CheckValue(data, nameof(data)); Host.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive"); Host.CheckValueOrNull(stratificationColumn); - EnsureStratificationColumn(ref data, ref stratificationColumn); + EnsureStratificationColumn(ref data, ref stratificationColumn, seed); var trainFilter = new RangeFilter(Host, new RangeFilter.Arguments() { @@ -67,14 +68,14 @@ public abstract class TrainContextBase /// Return each model and each scored test dataset. /// protected (IDataView scoredTestSet, ITransformer model)[] CrossValidateTrain(IDataView data, IEstimator estimator, - int numFolds, string stratificationColumn) + int numFolds, string stratificationColumn, uint? 
seed = null) { Host.CheckValue(data, nameof(data)); Host.CheckValue(estimator, nameof(estimator)); Host.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1"); Host.CheckValueOrNull(stratificationColumn); - EnsureStratificationColumn(ref data, ref stratificationColumn); + EnsureStratificationColumn(ref data, ref stratificationColumn, seed); Func foldFunction = fold => @@ -121,7 +122,7 @@ protected TrainContextBase(IHostEnvironment env, string registrationName) /// for , hash it if needed, or introduce a new one /// if needed. /// - private void EnsureStratificationColumn(ref IDataView data, ref string stratificationColumn) + private void EnsureStratificationColumn(ref IDataView data, ref string stratificationColumn, uint? seed = null) { // We need to handle two cases: if the stratification column is provided, we use hashJoin to // build a single hash of it. If it is not, we generate a random number. @@ -129,14 +130,14 @@ private void EnsureStratificationColumn(ref IDataView data, ref string stratific if (stratificationColumn == null) { stratificationColumn = data.Schema.GetTempColumnName("StratificationColumn"); - data = new GenerateNumberTransform(Host, data, stratificationColumn); + data = new GenerateNumberTransform(Host, data, stratificationColumn, seed); } else { if (!data.Schema.TryGetColumnIndex(stratificationColumn, out int stratCol)) throw Host.ExceptSchemaMismatch(nameof(stratificationColumn), "stratification", stratificationColumn); - var type = data.Schema.GetColumnType(stratCol); + var type = data.Schema[stratCol].Type; if (!RangeFilter.IsValidRangeFilterColumnType(Host, type)) { // Hash the stratification column. @@ -207,7 +208,7 @@ internal BinaryClassificationTrainers(BinaryClassificationContext ctx) /// The name of the probability column in , the calibrated version of . /// The name of the predicted label column in . /// The evaluation results for these calibrated outputs. 
- public BinaryClassifierEvaluator.CalibratedResult Evaluate(IDataView data, string label = DefaultColumnNames.Label, string score = DefaultColumnNames.Score, + public CalibratedBinaryClassificationMetrics Evaluate(IDataView data, string label = DefaultColumnNames.Label, string score = DefaultColumnNames.Score, string probability = DefaultColumnNames.Probability, string predictedLabel = DefaultColumnNames.PredictedLabel) { Host.CheckValue(data, nameof(data)); @@ -228,7 +229,7 @@ public BinaryClassifierEvaluator.CalibratedResult Evaluate(IDataView data, strin /// The name of the score column in . /// The name of the predicted label column in . /// The evaluation results for these uncalibrated outputs. - public BinaryClassifierEvaluator.Result EvaluateNonCalibrated(IDataView data, string label = DefaultColumnNames.Label, string score = DefaultColumnNames.Score, + public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string label = DefaultColumnNames.Label, string score = DefaultColumnNames.Score, string predictedLabel = DefaultColumnNames.PredictedLabel) { Host.CheckValue(data, nameof(data)); @@ -249,16 +250,19 @@ public BinaryClassifierEvaluator.Result EvaluateNonCalibrated(IDataView data, st /// The estimator to fit. /// Number of cross-validation folds. /// The label column (for evaluation). - /// Optional stratification column. - /// If two examples share the same value of the (if provided), - /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from - /// train to the test set. + /// Optional name of the column to use as a stratification column. If two examples share the same value of the + /// (if provided), they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from train to the test set. + /// If this optional parameter is not provided, a stratification columns will be generated, and its values will be random numbers . 
+ /// Optional parameter used in combination with the . + /// If the is not provided, the random numbers generated to create it, will use this seed as value. + /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. - public (BinaryClassifierEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidateNonCalibrated( - IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null) + public (BinaryClassificationMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidateNonCalibrated( + IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, + string stratificationColumn = null, uint? seed = null) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn); + var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); return result.Select(x => (EvaluateNonCalibrated(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); } @@ -275,12 +279,14 @@ public BinaryClassifierEvaluator.Result EvaluateNonCalibrated(IDataView data, st /// If two examples share the same value of the (if provided), /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from /// train to the test set. + /// If not present in dataset we will generate random filled column based on provided . /// Per-fold results: metrics, models, scored datasets. 
- public (BinaryClassifierEvaluator.CalibratedResult metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( - IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null) + public (CalibratedBinaryClassificationMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( + IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, + string stratificationColumn = null, uint? seed = null) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn); + var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); } } @@ -318,19 +324,19 @@ internal ClusteringTrainers(ClusteringContext ctx) /// The scored data. /// The name of the score column in . /// The name of the optional label column in . - /// If present, the metric will be computed. + /// If present, the metric will be computed. /// The name of the optional features column in . - /// If present, the metric will be computed. + /// If present, the metric will be computed. /// The evaluation result. - public ClusteringEvaluator.Result Evaluate(IDataView data, + public ClusteringMetrics Evaluate(IDataView data, string label = null, string score = DefaultColumnNames.Score, - string features = null ) + string features = null) { Host.CheckValue(data, nameof(data)); Host.CheckNonEmpty(score, nameof(score)); - if(features != null) + if (features != null) Host.CheckNonEmpty(features, nameof(features), "The features column name should be non-empty if you want to calculate the Dbi metric."); if (label != null) @@ -350,15 +356,18 @@ public ClusteringEvaluator.Result Evaluate(IDataView data, /// Number of cross-validation folds. 
/// Optional label column for evaluation (clustering tasks may not always have a label). /// Optional features column for evaluation (needed for calculating Dbi metric) - /// Optional stratification column. - /// If two examples share the same value of the (if provided), - /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from - /// train to the test set. + /// Optional name of the column to use as a stratification column. If two examples share the same value of the + /// (if provided), they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from train to the test set. + /// If this optional parameter is not provided, a stratification columns will be generated, and its values will be random numbers . + /// Optional parameter used in combination with the . + /// If the is not provided, the random numbers generated to create it, will use this seed as value. + /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. - public (ClusteringEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( - IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = null, string featuresColumn = null, string stratificationColumn = null) + public (ClusteringMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( + IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = null, string featuresColumn = null, + string stratificationColumn = null, uint? 
seed = null) { - var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn); + var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); return result.Select(x => (Evaluate(x.scoredTestSet, label: labelColumn, features: featuresColumn), x.model, x.scoredTestSet)).ToArray(); } } @@ -394,11 +403,11 @@ internal MulticlassClassificationTrainers(MulticlassClassificationContext ctx) /// The name of the label column in . /// The name of the score column in . /// The name of the predicted label column in . - /// If given a positive value, the will be filled with + /// If given a positive value, the will be filled with /// the top-K accuracy, that is, the accuracy assuming we consider an example with the correct class within /// the top-K values as being stored "correctly." /// The evaluation results for these calibrated outputs. - public MultiClassClassifierEvaluator.Result Evaluate(IDataView data, string label = DefaultColumnNames.Label, string score = DefaultColumnNames.Score, + public MultiClassClassifierMetrics Evaluate(IDataView data, string label = DefaultColumnNames.Label, string score = DefaultColumnNames.Score, string predictedLabel = DefaultColumnNames.PredictedLabel, int topK = 0) { Host.CheckValue(data, nameof(data)); @@ -422,16 +431,19 @@ public MultiClassClassifierEvaluator.Result Evaluate(IDataView data, string labe /// The estimator to fit. /// Number of cross-validation folds. /// The label column (for evaluation). - /// Optional stratification column. - /// If two examples share the same value of the (if provided), - /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from - /// train to the test set. + /// Optional name of the column to use as a stratification column. If two examples share the same value of the + /// (if provided), they are guaranteed to appear in the same subset (train or test). 
Use this to make sure there is no label leakage from train to the test set. + /// If this optional parameter is not provided, a stratification columns will be generated, and its values will be random numbers . + /// Optional parameter used in combination with the . + /// If the is not provided, the random numbers generated to create it, will use this seed as value. + /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. - public (MultiClassClassifierEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( - IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null) + public (MultiClassClassifierMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( + IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, + string stratificationColumn = null, uint? seed = null) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn); + var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); } } @@ -467,7 +479,7 @@ internal RegressionTrainers(RegressionContext ctx) /// The name of the label column in . /// The name of the score column in . /// The evaluation results for these calibrated outputs. 
- public RegressionEvaluator.Result Evaluate(IDataView data, string label = DefaultColumnNames.Label, string score = DefaultColumnNames.Score) + public RegressionMetrics Evaluate(IDataView data, string label = DefaultColumnNames.Label, string score = DefaultColumnNames.Score) { Host.CheckValue(data, nameof(data)); Host.CheckNonEmpty(label, nameof(label)); @@ -486,16 +498,19 @@ public RegressionEvaluator.Result Evaluate(IDataView data, string label = Defaul /// The estimator to fit. /// Number of cross-validation folds. /// The label column (for evaluation). - /// Optional stratification column. - /// If two examples share the same value of the (if provided), - /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from - /// train to the test set. + /// Optional name of the column to use as a stratification column. If two examples share the same value of the + /// (if provided), they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from train to the test set. + /// If this optional parameter is not provided, a stratification columns will be generated, and its values will be random numbers . + /// Optional parameter used in combination with the . + /// If the is not provided, the random numbers generated to create it, will use this seed as value. + /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. 
- public (RegressionEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( - IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null) + public (RegressionMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( + IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, + string stratificationColumn = null, uint? seed = null) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn); + var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); } } @@ -532,7 +547,7 @@ internal RankingTrainers(RankingContext ctx) /// The name of the groupId column in . /// The name of the score column in . /// The evaluation results for these calibrated outputs. - public RankerEvaluator.Result Evaluate(IDataView data, string label, string groupId, string score = DefaultColumnNames.Score) + public RankerMetrics Evaluate(IDataView data, string label, string groupId, string score = DefaultColumnNames.Score) { Host.CheckValue(data, nameof(data)); Host.CheckNonEmpty(label, nameof(label)); diff --git a/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs b/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs index 1da5a5562a..26a58b5b1a 100644 --- a/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs +++ b/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs @@ -2,14 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Float = System.Single; - -using System; using System.Collections.Generic; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; +using Float = System.Single; [assembly: LoadableClass(typeof(TolerantEarlyStoppingCriterion), typeof(TolerantEarlyStoppingCriterion.Arguments), typeof(SignatureEarlyStoppingCriterion), "Tolerant (TR)", "tr")] [assembly: LoadableClass(typeof(GLEarlyStoppingCriterion), typeof(GLEarlyStoppingCriterion.Arguments), typeof(SignatureEarlyStoppingCriterion), "Loss of Generality (GL)", "gl")] @@ -23,7 +21,7 @@ [assembly: EntryPointModule(typeof(PQEarlyStoppingCriterion))] [assembly: EntryPointModule(typeof(UPEarlyStoppingCriterion))] -namespace Microsoft.ML.Runtime.Internal.Internallearn +namespace Microsoft.ML.Internal.Internallearn { public delegate void SignatureEarlyStoppingCriterion(bool lowerIsBetter); diff --git a/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs b/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs index 2c9942e8d0..0384203a9e 100644 --- a/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs +++ b/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs @@ -4,11 +4,11 @@ using Microsoft.ML.Core.Data; -namespace Microsoft.ML.Runtime.Training +namespace Microsoft.ML.Training { - public interface ITrainerEstimator: IEstimator - where TTransformer: ISingleFeaturePredictionTransformer - where TPredictor: IPredictor + public interface ITrainerEstimator : IEstimator + where TTransformer : ISingleFeaturePredictionTransformer + where TPredictor : IPredictor { TrainerInfo Info { get; } diff --git a/src/Microsoft.ML.Data/Training/TrainerBase.cs b/src/Microsoft.ML.Data/Training/TrainerBase.cs index d82dfb7ba6..1a6a145514 100644 --- a/src/Microsoft.ML.Data/Training/TrainerBase.cs +++ 
b/src/Microsoft.ML.Data/Training/TrainerBase.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Runtime.Training +namespace Microsoft.ML.Training { public abstract class TrainerBase : ITrainer where TPredictor : IPredictor diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs index 1e49c32ed3..23651b0510 100644 --- a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs +++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs @@ -2,13 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.Collections.Generic; using System.Linq; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Data; -namespace Microsoft.ML.Runtime.Training +namespace Microsoft.ML.Training { /// /// This represents a basic class for 'simple trainer'. 
@@ -55,13 +53,11 @@ public abstract class TrainerEstimatorBase : ITrainerEstim private protected TrainerEstimatorBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, - SchemaShape.Column weight = null) + SchemaShape.Column weight = default) { Contracts.CheckValue(host, nameof(host)); Host = host; - Host.CheckValue(feature, nameof(feature)); - Host.CheckValueOrNull(label); - Host.CheckValueOrNull(weight); + Host.CheckParam(feature.IsValid, nameof(feature), "not initialized properly"); FeatureColumn = feature; LabelColumn = label; @@ -76,7 +72,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) CheckInputSchema(inputSchema); - var outColumns = inputSchema.Columns.ToDictionary(x => x.Name); + var outColumns = inputSchema.ToDictionary(x => x.Name); foreach (var col in GetOutputColumnsCore(inputSchema)) outColumns[col.Name] = col; @@ -102,7 +98,7 @@ private void CheckInputSchema(SchemaShape inputSchema) if (!FeatureColumn.IsCompatibleWith(featureCol)) throw Host.Except($"Feature column '{FeatureColumn.Name}' is not compatible"); - if (WeightColumn != null) + if (WeightColumn.IsValid) { if (!inputSchema.TryFindColumn(WeightColumn.Name, out var weightCol)) throw Host.Except($"Weight column '{WeightColumn.Name}' is not found"); @@ -112,7 +108,7 @@ private void CheckInputSchema(SchemaShape inputSchema) // Special treatment for label column: we allow different types of labels, so the trainers // may define their own requirements on the label column. 
- if (LabelColumn != null) + if (LabelColumn.IsValid) { if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol)) throw Host.Except($"Label column '{LabelColumn.Name}' is not found"); @@ -122,8 +118,8 @@ private void CheckInputSchema(SchemaShape inputSchema) protected virtual void CheckLabelCompatible(SchemaShape.Column labelCol) { - Contracts.CheckValue(labelCol, nameof(labelCol)); - Contracts.AssertValue(LabelColumn); + Contracts.CheckParam(labelCol.IsValid, nameof(labelCol), "not initialized properly"); + Host.Assert(LabelColumn.IsValid); if (!LabelColumn.IsCompatibleWith(labelCol)) throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible"); @@ -132,21 +128,10 @@ protected virtual void CheckLabelCompatible(SchemaShape.Column labelCol) protected TTransformer TrainTransformer(IDataView trainSet, IDataView validationSet = null, IPredictor initPredictor = null) { - var cachedTrain = Info.WantCaching ? new CacheDataView(Host, trainSet, prefetch: null) : trainSet; + var trainRoleMapped = MakeRoles(trainSet); + var validRoleMapped = validationSet == null ? null : MakeRoles(validationSet); - var trainRoles = MakeRoles(cachedTrain); - - RoleMappedData validRoles; - - if (validationSet == null) - validRoles = null; - else - { - var cachedValid = Info.WantCaching ? 
new CacheDataView(Host, validationSet, prefetch: null) : validationSet; - validRoles = MakeRoles(cachedValid); - } - - var pred = TrainModelCore(new TrainContext(trainRoles, validRoles, null, initPredictor)); + var pred = TrainModelCore(new TrainContext(trainRoleMapped, validRoleMapped, null, initPredictor)); return MakeTransformer(pred, trainSet.Schema); } @@ -155,8 +140,8 @@ protected TTransformer TrainTransformer(IDataView trainSet, protected abstract TTransformer MakeTransformer(TModel model, Schema trainSchema); - protected virtual RoleMappedData MakeRoles(IDataView data) => - new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, weight: WeightColumn?.Name); + private protected virtual RoleMappedData MakeRoles(IDataView data) => + new RoleMappedData(data, label: LabelColumn.Name, feature: FeatureColumn.Name, weight: WeightColumn.Name); IPredictor ITrainer.Train(TrainContext context) => ((ITrainer)this).Train(context); } @@ -178,16 +163,15 @@ public abstract class TrainerEstimatorBaseWithGroupId : Tr public TrainerEstimatorBaseWithGroupId(IHost host, SchemaShape.Column feature, SchemaShape.Column label, - SchemaShape.Column weight = null, - SchemaShape.Column groupId = null) + SchemaShape.Column weight = default, + SchemaShape.Column groupId = default) :base(host, feature, label, weight) { - Host.CheckValueOrNull(groupId); GroupIdColumn = groupId; } - protected override RoleMappedData MakeRoles(IDataView data) => - new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, group: GroupIdColumn?.Name, weight: WeightColumn?.Name); + private protected override RoleMappedData MakeRoles(IDataView data) => + new RoleMappedData(data, label: LabelColumn.Name, feature: FeatureColumn.Name, group: GroupIdColumn.Name, weight: WeightColumn.Name); } } diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs index f6947cd2f9..26a431c88d 100644 --- 
a/src/Microsoft.ML.Data/Training/TrainerUtils.cs +++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs @@ -2,14 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Collections.Generic; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Training +namespace Microsoft.ML.Training { /// /// Options for creating a row cursor from a RoleMappedData with specified standard columns active. @@ -38,7 +37,8 @@ public enum CursOpt : uint AllFeatures = Features | AllowBadFeatures, } - public static class TrainerUtils + [BestFriend] + internal static class TrainerUtils { /// /// Check for a standard (known-length vector of float) feature column. @@ -47,11 +47,11 @@ public static void CheckFeatureFloatVector(this RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); - var col = data.Schema.Feature; - if (col == null) + if (!data.Schema.Feature.HasValue) throw Contracts.ExceptParam(nameof(data), "Training data must specify a feature column."); - Contracts.Assert(!data.Schema.Schema.IsHidden(col.Index)); - if (!col.Type.IsKnownSizeVector || col.Type.ItemType != NumberType.Float) + var col = data.Schema.Feature.Value; + Contracts.Assert(!col.IsHidden); + if (!(col.Type is VectorType vecType && vecType.Size > 0 && vecType.ItemType == NumberType.Float)) throw Contracts.ExceptParam(nameof(data), "Training feature column '{0}' must be a known-size vector of R4, but has type: {1}.", col.Name, col.Type); } @@ -64,11 +64,12 @@ public static void CheckFeatureFloatVector(this RoleMappedData data, out int len // If the above function is generalized, this needs to be as well. 
Contracts.AssertValue(data); - Contracts.Assert(data.Schema.Feature != null); - Contracts.Assert(!data.Schema.Schema.IsHidden(data.Schema.Feature.Index)); - Contracts.Assert(data.Schema.Feature.Type.IsKnownSizeVector); - Contracts.Assert(data.Schema.Feature.Type.ItemType == NumberType.Float); - length = data.Schema.Feature.Type.VectorSize; + Contracts.Assert(data.Schema.Feature.HasValue); + var col = data.Schema.Feature.Value; + Contracts.Assert(!col.IsHidden); + Contracts.Assert(col.Type.IsKnownSizeVector); + Contracts.Assert(col.Type.ItemType == NumberType.Float); + length = col.Type.VectorSize; } /// @@ -78,11 +79,11 @@ public static void CheckBinaryLabel(this RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); - var col = data.Schema.Label; - if (col == null) + if (!data.Schema.Label.HasValue) throw Contracts.ExceptParam(nameof(data), "Training data must specify a label column."); - Contracts.Assert(!data.Schema.Schema.IsHidden(col.Index)); - if (!col.Type.IsBool && col.Type != NumberType.R4 && col.Type != NumberType.R8 && col.Type.KeyCount != 2) + var col = data.Schema.Label.Value; + Contracts.Assert(!col.IsHidden); + if (col.Type != BoolType.Instance && col.Type != NumberType.R4 && col.Type != NumberType.R8 && !(col.Type is KeyType keyType && keyType.Count == 2)) { if (col.Type.IsKey) { @@ -112,10 +113,10 @@ public static void CheckRegressionLabel(this RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); - var col = data.Schema.Label; - if (col == null) + if (!data.Schema.Label.HasValue) throw Contracts.ExceptParam(nameof(data), "Training data must specify a label column."); - Contracts.Assert(!data.Schema.Schema.IsHidden(col.Index)); + var col = data.Schema.Label.Value; + Contracts.Assert(!data.Schema.Schema[col.Index].IsHidden); if (col.Type != NumberType.R4 && col.Type != NumberType.R8) { throw Contracts.ExceptParam(nameof(data), @@ -132,13 +133,13 @@ public static void CheckMultiClassLabel(this RoleMappedData data, out int count) 
{ Contracts.CheckValue(data, nameof(data)); - var col = data.Schema.Label; - if (col == null) + if (!data.Schema.Label.HasValue) throw Contracts.ExceptParam(nameof(data), "Training data must specify a label column."); - Contracts.Assert(!data.Schema.Schema.IsHidden(col.Index)); - if (col.Type.KeyCount > 0) + var col = data.Schema.Label.Value; + Contracts.Assert(!col.IsHidden); + if (col.Type is KeyType keyType && keyType.Count > 0) { - count = col.Type.KeyCount; + count = keyType.Count; return; } @@ -178,10 +179,10 @@ public static void CheckMultiOutputRegressionLabel(this RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); - var col = data.Schema.Label; - if (col == null) + if (!data.Schema.Label.HasValue) throw Contracts.ExceptParam(nameof(data), "Training data must specify a label column."); - Contracts.Assert(!data.Schema.Schema.IsHidden(col.Index)); + var col = data.Schema.Label.Value; + Contracts.Assert(!col.IsHidden); if (!col.Type.IsKnownSizeVector || col.Type.ItemType != NumberType.Float) throw Contracts.ExceptParam(nameof(data), "Training label column '{0}' must be a known-size vector of R4, but has type: {1}.", col.Name, col.Type); } @@ -190,10 +191,10 @@ public static void CheckOptFloatWeight(this RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); - var col = data.Schema.Weight; - if (col == null) + if (!data.Schema.Weight.HasValue) return; - Contracts.Assert(!data.Schema.Schema.IsHidden(col.Index)); + var col = data.Schema.Weight.Value; + Contracts.Assert(!col.IsHidden); if (col.Type != NumberType.R4 && col.Type != NumberType.R8) throw Contracts.ExceptParam(nameof(data), "Training weight column '{0}' must be of floating point numeric type, but has type: {1}.", col.Name, col.Type); } @@ -202,11 +203,11 @@ public static void CheckOptGroup(this RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); - var col = data.Schema.Group; - if (col == null) + if (!data.Schema.Group.HasValue) return; - 
Contracts.Assert(!data.Schema.Schema.IsHidden(col.Index)); - if (col.Type.IsKey) + var col = data.Schema.Group.Value; + Contracts.Assert(!col.IsHidden); + if (col.Type is KeyType) return; throw Contracts.ExceptParam(nameof(data), "Training group column '{0}' type is invalid: {1}. Must be Key type.", col.Name, col.Type); } @@ -237,41 +238,41 @@ private static Func CreatePredicate(RoleMappedData data, CursOpt opt, /// Create a row cursor for the RoleMappedData with the indicated standard columns active. /// This does not verify that the columns exist, but merely activates the ones that do exist. /// - public static IRowCursor CreateRowCursor(this RoleMappedData data, CursOpt opt, IRandom rand, IEnumerable extraCols = null) + public static RowCursor CreateRowCursor(this RoleMappedData data, CursOpt opt, Random rand, IEnumerable extraCols = null) => data.Data.GetRowCursor(CreatePredicate(data, opt, extraCols), rand); /// - /// Create a row cursor set for the RoleMappedData with the indicated standard columns active. + /// Create a row cursor set for the with the indicated standard columns active. /// This does not verify that the columns exist, but merely activates the ones that do exist. /// - public static IRowCursor[] CreateRowCursorSet(this RoleMappedData data, out IRowCursorConsolidator consolidator, - CursOpt opt, int n, IRandom rand, IEnumerable extraCols = null) - => data.Data.GetRowCursorSet(out consolidator, CreatePredicate(data, opt, extraCols), n, rand); + public static RowCursor[] CreateRowCursorSet(this RoleMappedData data, + CursOpt opt, int n, Random rand, IEnumerable extraCols = null) + => data.Data.GetRowCursorSet(CreatePredicate(data, opt, extraCols), n, rand); - private static void AddOpt(HashSet cols, ColumnInfo info) + private static void AddOpt(HashSet cols, Schema.Column? 
info) { Contracts.AssertValue(cols); - if (info != null) - cols.Add(info.Index); + if (info.HasValue) + cols.Add(info.Value.Index); } /// /// Get the getter for the feature column, assuming it is a vector of float. /// - public static ValueGetter> GetFeatureFloatVectorGetter(this IRow row, RoleMappedSchema schema) + public static ValueGetter> GetFeatureFloatVectorGetter(this Row row, RoleMappedSchema schema) { Contracts.CheckValue(row, nameof(row)); Contracts.CheckValue(schema, nameof(schema)); Contracts.CheckParam(schema.Schema == row.Schema, nameof(schema), "schemas don't match!"); - Contracts.CheckParam(schema.Feature != null, nameof(schema), "Missing feature column"); + Contracts.CheckParam(schema.Feature.HasValue, nameof(schema), "Missing feature column"); - return row.GetGetter>(schema.Feature.Index); + return row.GetGetter>(schema.Feature.Value.Index); } /// /// Get the getter for the feature column, assuming it is a vector of float. /// - public static ValueGetter> GetFeatureFloatVectorGetter(this IRow row, RoleMappedData data) + public static ValueGetter> GetFeatureFloatVectorGetter(this Row row, RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); return GetFeatureFloatVectorGetter(row, data.Schema); @@ -281,21 +282,21 @@ public static ValueGetter> GetFeatureFloatVectorGetter(this IRow /// Get a getter for the label as a float. This assumes that the label column type /// has already been validated as appropriate for the kind of training being done. 
/// - public static ValueGetter GetLabelFloatGetter(this IRow row, RoleMappedSchema schema) + public static ValueGetter GetLabelFloatGetter(this Row row, RoleMappedSchema schema) { Contracts.CheckValue(row, nameof(row)); Contracts.CheckValue(schema, nameof(schema)); Contracts.CheckParam(schema.Schema == row.Schema, nameof(schema), "schemas don't match!"); - Contracts.CheckParam(schema.Label != null, nameof(schema), "Missing label column"); + Contracts.CheckParam(schema.Label.HasValue, nameof(schema), "Missing label column"); - return RowCursorUtils.GetLabelGetter(row, schema.Label.Index); + return RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index); } /// /// Get a getter for the label as a float. This assumes that the label column type /// has already been validated as appropriate for the kind of training being done. /// - public static ValueGetter GetLabelFloatGetter(this IRow row, RoleMappedData data) + public static ValueGetter GetLabelFloatGetter(this Row row, RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); return GetLabelFloatGetter(row, data.Schema); @@ -304,20 +305,19 @@ public static ValueGetter GetLabelFloatGetter(this IRow row, RoleMappedDa /// /// Get the getter for the weight column, or null if there is no weight column. 
/// - public static ValueGetter GetOptWeightFloatGetter(this IRow row, RoleMappedSchema schema) + public static ValueGetter GetOptWeightFloatGetter(this Row row, RoleMappedSchema schema) { Contracts.CheckValue(row, nameof(row)); Contracts.CheckValue(schema, nameof(schema)); Contracts.Check(schema.Schema == row.Schema, "schemas don't match!"); - Contracts.CheckValueOrNull(schema.Weight); var col = schema.Weight; - if (col == null) + if (!col.HasValue) return null; - return RowCursorUtils.GetGetterAs(NumberType.Float, row, col.Index); + return RowCursorUtils.GetGetterAs(NumberType.Float, row, col.Value.Index); } - public static ValueGetter GetOptWeightFloatGetter(this IRow row, RoleMappedData data) + public static ValueGetter GetOptWeightFloatGetter(this Row row, RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); return GetOptWeightFloatGetter(row, data.Schema); @@ -326,20 +326,19 @@ public static ValueGetter GetOptWeightFloatGetter(this IRow row, RoleMapp /// /// Get the getter for the group column, or null if there is no group column. 
/// - public static ValueGetter GetOptGroupGetter(this IRow row, RoleMappedSchema schema) + public static ValueGetter GetOptGroupGetter(this Row row, RoleMappedSchema schema) { Contracts.CheckValue(row, nameof(row)); Contracts.CheckValue(schema, nameof(schema)); Contracts.Check(schema.Schema == row.Schema, "schemas don't match!"); - Contracts.CheckValueOrNull(schema.Group); var col = schema.Group; - if (col == null) + if (!col.HasValue) return null; - return RowCursorUtils.GetGetterAs(NumberType.U8, row, col.Index); + return RowCursorUtils.GetGetterAs(NumberType.U8, row, col.Value.Index); } - public static ValueGetter GetOptGroupGetter(this IRow row, RoleMappedData data) + public static ValueGetter GetOptGroupGetter(this Row row, RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); return GetOptGroupGetter(row, data.Schema); @@ -353,11 +352,11 @@ public static SchemaShape.Column MakeBoolScalarLabel(string labelColumn) => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); /// - /// The for the label column for regression tasks. + /// The for the float type columns. /// - /// name of the weight column - public static SchemaShape.Column MakeR4ScalarLabel(string labelColumn) - => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); + /// name of the column + public static SchemaShape.Column MakeR4ScalarColumn(string columnName) + => new SchemaShape.Column(columnName, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); /// /// The for the label column for regression tasks. 
@@ -366,7 +365,7 @@ public static SchemaShape.Column MakeR4ScalarLabel(string labelColumn) public static SchemaShape.Column MakeU4ScalarColumn(string columnName) { if (columnName == null) - return null; + return default; return new SchemaShape.Column(columnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); } @@ -386,14 +385,14 @@ public static SchemaShape.Column MakeR4VecFeature(string featureColumn) public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, bool isExplicit = true) { if (weightColumn == null || !isExplicit) - return null; + return default; return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } } /// /// This is the base class for a data cursor. Data cursors are specially typed - /// "convenience" cursor-like objects, less general than a but + /// "convenience" cursor-like objects, less general than a but /// more convenient for common access patterns that occur in machine learning. For /// example, the common idiom of iterating over features/labels/weights while skipping /// "bad" features, labels, and weights. There will be two typical access patterns for @@ -402,25 +401,23 @@ public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, b /// repeated accesses, is to use a cursor factory (usually a nested class of the cursor /// class). This keeps track of what filtering options were actually useful. 
/// - public abstract class TrainingCursorBase : IDisposable + [BestFriend] + internal abstract class TrainingCursorBase : IDisposable { - public IRow Row { get { return _cursor; } } + public Row Row => _cursor; - private readonly IRowCursor _cursor; + private readonly RowCursor _cursor; private readonly Action _signal; - private long _skipCount; - private long _keptCount; - - public long SkippedRowCount { get { return _skipCount; } } - public long KeptRowCount { get { return _keptCount; } } + public long SkippedRowCount { get; private set; } + public long KeptRowCount { get; private set; } /// /// The base constructor class for the factory-based cursor creation. /// /// /// This method is called - protected TrainingCursorBase(IRowCursor input, Action signal) + protected TrainingCursorBase(RowCursor input, Action signal) { Contracts.AssertValue(input); Contracts.AssertValueOrNull(signal); @@ -428,7 +425,7 @@ protected TrainingCursorBase(IRowCursor input, Action signal) _signal = signal; } - protected static IRowCursor CreateCursor(RoleMappedData data, CursOpt opt, IRandom rand, params int[] extraCols) + protected static RowCursor CreateCursor(RoleMappedData data, CursOpt opt, Random rand, params int[] extraCols) { Contracts.AssertValue(data); Contracts.AssertValueOrNull(rand); @@ -471,10 +468,10 @@ public bool MoveNext() } if (Accept()) { - _keptCount++; + KeptRowCount++; return true; } - _skipCount++; + SkippedRowCount++; } } @@ -531,13 +528,13 @@ private void SignalCore(CursOpt opt) } /// - /// The typed analog to . + /// The typed analog to . /// /// Non-null if we are requesting a shuffled cursor. /// The extra columns to activate on the row cursor /// in addition to those required by the factory's options. /// The wrapping typed cursor. 
- public TCurs Create(IRandom rand = null, params int[] extraCols) + public TCurs Create(Random rand = null, params int[] extraCols) { CursOpt opt; lock (_lock) @@ -558,17 +555,15 @@ public TCurs Create(IRandom rand = null, params int[] extraCols) /// in addition to those required by the factory's options. /// The cursor set. Note that this needn't necessarily be of size /// . - public TCurs[] CreateSet(int n, IRandom rand = null, params int[] extraCols) + public TCurs[] CreateSet(int n, Random rand = null, params int[] extraCols) { CursOpt opt; lock (_lock) opt = _opts; - // The intended use of this sort of thing is for cases where we have no interest in - // doing consolidation at all, that is, the consuming endpoint using these typed - // cursors wants to consume them as a set. - IRowCursorConsolidator consolidator; - var inputs = _data.CreateRowCursorSet(out consolidator, opt, n, rand, extraCols); + // Users of this method will tend to consume the cursors in the set in separate + // threads, and so gain benefit from the parallel transformation of the data. + var inputs = _data.CreateRowCursorSet(opt, n, rand, extraCols); Contracts.Assert(Utils.Size(inputs) > 0); Action signal; @@ -596,7 +591,7 @@ public TCurs[] CreateSet(int n, IRandom rand = null, params int[] extraCols) /// , whose return value is used to call /// this action. /// - protected abstract TCurs CreateCursorCore(IRowCursor input, RoleMappedData data, CursOpt opt, Action signal); + protected abstract TCurs CreateCursorCore(RowCursor input, RoleMappedData data, CursOpt opt, Action signal); /// /// Accumulates signals from cursors, anding them together. Once it has @@ -634,31 +629,30 @@ public void Signal(CursOpt opt) } /// - /// This supports Weight (float), Group (ulong), and Id (UInt128) columns. + /// This supports Weight (float), Group (ulong), and Id (RowId) columns. 
/// - public class StandardScalarCursor : TrainingCursorBase + [BestFriend] + internal class StandardScalarCursor : TrainingCursorBase { private readonly ValueGetter _getWeight; private readonly ValueGetter _getGroup; - private readonly ValueGetter _getId; + private readonly ValueGetter _getId; private readonly bool _keepBadWeight; private readonly bool _keepBadGroup; - private long _badWeightCount; - private long _badGroupCount; - public long BadWeightCount { get { return _badWeightCount; } } - public long BadGroupCount { get { return _badGroupCount; } } + public long BadWeightCount { get; private set; } + public long BadGroupCount { get; private set; } public float Weight; public ulong Group; - public UInt128 Id; + public RowId Id; - public StandardScalarCursor(RoleMappedData data, CursOpt opt, IRandom rand = null, params int[] extraCols) + public StandardScalarCursor(RoleMappedData data, CursOpt opt, Random rand = null, params int[] extraCols) : this(CreateCursor(data, opt, rand, extraCols), data, opt) { } - protected StandardScalarCursor(IRowCursor input, RoleMappedData data, CursOpt opt, Action signal = null) + protected StandardScalarCursor(RowCursor input, RoleMappedData data, CursOpt opt, Action signal = null) : base(input, signal) { Contracts.AssertValue(data); @@ -698,7 +692,7 @@ public override bool Accept() _getWeight(ref Weight); if (!_keepBadWeight && !(0 < Weight && Weight < float.PositiveInfinity)) { - _badWeightCount++; + BadWeightCount++; return false; } } @@ -707,7 +701,7 @@ public override bool Accept() _getGroup(ref Group); if (!_keepBadGroup && Group == 0) { - _badGroupCount++; + BadGroupCount++; return false; } } @@ -723,7 +717,7 @@ public Factory(RoleMappedData data, CursOpt opt) { } - protected override StandardScalarCursor CreateCursorCore(IRowCursor input, RoleMappedData data, CursOpt opt, Action signal) + protected override StandardScalarCursor CreateCursorCore(RowCursor input, RoleMappedData data, CursOpt opt, Action signal) => new 
StandardScalarCursor(input, data, opt, signal); } } @@ -732,23 +726,23 @@ protected override StandardScalarCursor CreateCursorCore(IRowCursor input, RoleM /// This derives from and adds the feature column /// as a . /// - public class FeatureFloatVectorCursor : StandardScalarCursor + [BestFriend] + internal class FeatureFloatVectorCursor : StandardScalarCursor { private readonly ValueGetter> _get; private readonly bool _keepBad; - private long _badCount; - public long BadFeaturesRowCount { get { return _badCount; } } + public long BadFeaturesRowCount { get; private set; } public VBuffer Features; public FeatureFloatVectorCursor(RoleMappedData data, CursOpt opt = CursOpt.Features, - IRandom rand = null, params int[] extraCols) + Random rand = null, params int[] extraCols) : this(CreateCursor(data, opt, rand, extraCols), data, opt) { } - protected FeatureFloatVectorCursor(IRowCursor input, RoleMappedData data, CursOpt opt, Action signal = null) + protected FeatureFloatVectorCursor(RowCursor input, RoleMappedData data, CursOpt opt, Action signal = null) : base(input, data, opt, signal) { if ((opt & CursOpt.Features) != 0 && data.Schema.Feature != null) @@ -775,7 +769,7 @@ public override bool Accept() _get(ref Features); if (!_keepBad && !FloatUtils.IsFinite(Features.GetValues())) { - _badCount++; + BadFeaturesRowCount++; return false; } } @@ -789,7 +783,7 @@ public Factory(RoleMappedData data, CursOpt opt = CursOpt.Features) { } - protected override FeatureFloatVectorCursor CreateCursorCore(IRowCursor input, RoleMappedData data, CursOpt opt, Action signal) + protected override FeatureFloatVectorCursor CreateCursorCore(RowCursor input, RoleMappedData data, CursOpt opt, Action signal) { return new FeatureFloatVectorCursor(input, data, opt, signal); } @@ -799,24 +793,23 @@ protected override FeatureFloatVectorCursor CreateCursorCore(IRowCursor input, R /// /// This derives from the FeatureFloatVectorCursor and adds the Label (float) column. 
/// - public class FloatLabelCursor : FeatureFloatVectorCursor + [BestFriend] + internal class FloatLabelCursor : FeatureFloatVectorCursor { private readonly ValueGetter _get; private readonly bool _keepBad; - private long _badCount; - - public long BadLabelCount { get { return _badCount; } } + public long BadLabelCount { get; private set; } public float Label; public FloatLabelCursor(RoleMappedData data, CursOpt opt = CursOpt.Label, - IRandom rand = null, params int[] extraCols) + Random rand = null, params int[] extraCols) : this(CreateCursor(data, opt, rand, extraCols), data, opt) { } - protected FloatLabelCursor(IRowCursor input, RoleMappedData data, CursOpt opt, Action signal = null) + protected FloatLabelCursor(RowCursor input, RoleMappedData data, CursOpt opt, Action signal = null) : base(input, data, opt, signal) { if ((opt & CursOpt.Label) != 0 && data.Schema.Label != null) @@ -842,7 +835,7 @@ public override bool Accept() _get(ref Label); if (!_keepBad && !FloatUtils.IsFinite(Label)) { - _badCount++; + BadLabelCount++; return false; } } @@ -856,7 +849,7 @@ public Factory(RoleMappedData data, CursOpt opt = CursOpt.Label) { } - protected override FloatLabelCursor CreateCursorCore(IRowCursor input, RoleMappedData data, CursOpt opt, Action signal) + protected override FloatLabelCursor CreateCursorCore(RowCursor input, RoleMappedData data, CursOpt opt, Action signal) { return new FloatLabelCursor(input, data, opt, signal); } @@ -867,25 +860,25 @@ protected override FloatLabelCursor CreateCursorCore(IRowCursor input, RoleMappe /// This derives from the FeatureFloatVectorCursor and adds the Label (int) column, /// enforcing multi-class semantics. 
/// - public class MultiClassLabelCursor : FeatureFloatVectorCursor + [BestFriend] + internal class MultiClassLabelCursor : FeatureFloatVectorCursor { private readonly int _classCount; private readonly ValueGetter _get; private readonly bool _keepBad; - private long _badCount; - public long BadLabelCount { get { return _badCount; } } + public long BadLabelCount { get; private set; } private float _raw; public int Label; public MultiClassLabelCursor(int classCount, RoleMappedData data, CursOpt opt = CursOpt.Label, - IRandom rand = null, params int[] extraCols) + Random rand = null, params int[] extraCols) : this(classCount, CreateCursor(data, opt, rand, extraCols), data, opt) { } - protected MultiClassLabelCursor(int classCount, IRowCursor input, RoleMappedData data, CursOpt opt, Action signal = null) + protected MultiClassLabelCursor(int classCount, RowCursor input, RoleMappedData data, CursOpt opt, Action signal = null) : base(input, data, opt, signal) { Contracts.Assert(classCount >= 0); @@ -915,7 +908,7 @@ public override bool Accept() Label = (int)_raw; if (!_keepBad && !(Label == _raw && (0 <= _raw && (_raw < _classCount || _classCount == 0)))) { - _badCount++; + BadLabelCount++; return false; } } @@ -934,7 +927,7 @@ public Factory(int classCount, RoleMappedData data, CursOpt opt = CursOpt.Label) _classCount = classCount; } - protected override MultiClassLabelCursor CreateCursorCore(IRowCursor input, RoleMappedData data, CursOpt opt, Action signal) + protected override MultiClassLabelCursor CreateCursorCore(RowCursor input, RoleMappedData data, CursOpt opt, Action signal) { return new MultiClassLabelCursor(_classCount, input, data, opt, signal); } diff --git a/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs index 8cfda0485e..b950f87180 100644 --- a/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs +++ b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs @@ -2,15 +2,12 @@ 
// The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; +using System.Linq; using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; namespace Microsoft.ML { @@ -28,13 +25,15 @@ public static class TrainingStaticExtensions /// The training context. /// The dataset to split. /// The fraction of data to go into the test set. - /// Optional selector for the stratification column. - /// If two examples share the same value of the (if provided), - /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from - /// train to the test set. + /// Optional selector for the column to use as a stratification column. If two examples share the same value of the + /// (if provided), they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from train to the test set. + /// If this optional parameter is not provided, a stratification columns will be generated, and its values will be random numbers . + /// Optional parameter used in combination with the . + /// If the is not provided, the random numbers generated to create it, will use this seed as value. + /// And if it is not provided, the default value will be used. /// A pair of datasets, for the train and test set. public static (DataView trainSet, DataView testSet) TrainTestSplit(this TrainContextBase context, - DataView data, double testFraction = 0.1, Func stratificationColumn = null) + DataView data, double testFraction = 0.1, Func stratificationColumn = null, uint? 
seed = null) { var env = StaticPipeUtils.GetEnvironment(data); Contracts.AssertValue(env); @@ -51,7 +50,7 @@ public static (DataView trainSet, DataView testSet) TrainTestSplit(this stratName = indexer.Get(column); } - var (trainData, testData) = context.TrainTestSplit(data.AsDynamic, testFraction, stratName); + var (trainData, testData) = context.TrainTestSplit(data.AsDynamic, testFraction, stratName, seed); return (new DataView(env, trainData, data.Shape), new DataView(env, testData, data.Shape)); } @@ -68,18 +67,20 @@ public static (DataView trainSet, DataView testSet) TrainTestSplit(this /// The estimator to fit. /// Number of cross-validation folds. /// The label column (for evaluation). - /// Optional stratification column. - /// If two examples share the same value of the (if provided), - /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from - /// train to the test set. + /// Optional selector for the column to use as a stratification column. If two examples share the same value of the + /// (if provided), they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from train to the test set. + /// If this optional parameter is not provided, a stratification columns will be generated, and its values will be random numbers . + /// Optional parameter used in combination with the . + /// If the is not provided, the random numbers generated to create it, will use this seed as value. + /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. 
- public static (RegressionEvaluator.Result metrics, Transformer model, DataView scoredTestData)[] CrossValidate( + public static (RegressionMetrics metrics, Transformer model, DataView scoredTestData)[] CrossValidate( this RegressionContext context, DataView data, Estimator estimator, Func> label, int numFolds = 5, - Func stratificationColumn = null) + Func stratificationColumn = null, uint? seed = null) where TTransformer : class, ITransformer { var env = StaticPipeUtils.GetEnvironment(data); @@ -102,7 +103,7 @@ public static (RegressionEvaluator.Result metrics, Transformer ( x.metrics, @@ -124,18 +125,20 @@ public static (RegressionEvaluator.Result metrics, TransformerThe estimator to fit. /// Number of cross-validation folds. /// The label column (for evaluation). - /// Optional stratification column. - /// If two examples share the same value of the (if provided), - /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from - /// train to the test set. + /// Optional selector for the column to use as a stratification column. If two examples share the same value of the + /// (if provided), they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from train to the test set. + /// If this optional parameter is not provided, a stratification columns will be generated, and its values will be random numbers . + /// Optional parameter used in combination with the . + /// If the is not provided, the random numbers generated to create it, will use this seed as value. + /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. 
- public static (MultiClassClassifierEvaluator.Result metrics, Transformer model, DataView scoredTestData)[] CrossValidate( + public static (MultiClassClassifierMetrics metrics, Transformer model, DataView scoredTestData)[] CrossValidate( this MulticlassClassificationContext context, DataView data, Estimator estimator, Func> label, int numFolds = 5, - Func stratificationColumn = null) + Func stratificationColumn = null, uint? seed = null) where TTransformer : class, ITransformer { var env = StaticPipeUtils.GetEnvironment(data); @@ -158,7 +161,7 @@ public static (MultiClassClassifierEvaluator.Result metrics, Transformer ( x.metrics, @@ -180,18 +183,20 @@ public static (MultiClassClassifierEvaluator.Result metrics, TransformerThe estimator to fit. /// Number of cross-validation folds. /// The label column (for evaluation). - /// Optional stratification column. - /// If two examples share the same value of the (if provided), - /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from - /// train to the test set. + /// Optional selector for the column to use as a stratification column. If two examples share the same value of the + /// (if provided), they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from train to the test set. + /// If this optional parameter is not provided, a stratification columns will be generated, and its values will be random numbers . + /// Optional parameter used in combination with the . + /// If the is not provided, the random numbers generated to create it, will use this seed as value. + /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. 
- public static (BinaryClassifierEvaluator.Result metrics, Transformer model, DataView scoredTestData)[] CrossValidateNonCalibrated( + public static (BinaryClassificationMetrics metrics, Transformer model, DataView scoredTestData)[] CrossValidateNonCalibrated( this BinaryClassificationContext context, DataView data, Estimator estimator, Func> label, int numFolds = 5, - Func stratificationColumn = null) + Func stratificationColumn = null, uint? seed = null) where TTransformer : class, ITransformer { var env = StaticPipeUtils.GetEnvironment(data); @@ -214,7 +219,7 @@ public static (BinaryClassifierEvaluator.Result metrics, Transformer ( x.metrics, @@ -236,18 +241,20 @@ public static (BinaryClassifierEvaluator.Result metrics, TransformerThe estimator to fit. /// Number of cross-validation folds. /// The label column (for evaluation). - /// Optional stratification column. - /// If two examples share the same value of the (if provided), - /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from - /// train to the test set. + /// Optional selector for the column to use as a stratification column. If two examples share the same value of the + /// (if provided), they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from train to the test set. + /// If this optional parameter is not provided, a stratification columns will be generated, and its values will be random numbers . + /// Optional parameter used in combination with the . + /// If the is not provided, the random numbers generated to create it, will use this seed as value. + /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. 
- public static (BinaryClassifierEvaluator.CalibratedResult metrics, Transformer model, DataView scoredTestData)[] CrossValidate( + public static (CalibratedBinaryClassificationMetrics metrics, Transformer model, DataView scoredTestData)[] CrossValidate( this BinaryClassificationContext context, DataView data, Estimator estimator, Func> label, int numFolds = 5, - Func stratificationColumn = null) + Func stratificationColumn = null, uint? seed = null) where TTransformer : class, ITransformer { var env = StaticPipeUtils.GetEnvironment(data); @@ -270,7 +277,7 @@ public static (BinaryClassifierEvaluator.CalibratedResult metrics, Transformer ( x.metrics, diff --git a/src/Microsoft.ML.Data/Transforms/BindingsWrappedRowCursor.cs b/src/Microsoft.ML.Data/Transforms/BindingsWrappedRowCursor.cs index fecac3d005..d489acfe9a 100644 --- a/src/Microsoft.ML.Data/Transforms/BindingsWrappedRowCursor.cs +++ b/src/Microsoft.ML.Data/Transforms/BindingsWrappedRowCursor.cs @@ -2,9 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; - -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// A class for mapping an input to an output cursor assuming no output columns @@ -13,11 +11,11 @@ namespace Microsoft.ML.Runtime.Data /// inconvenient or inefficient to handle the "no output selected" case in their /// own implementation. 
/// - internal sealed class BindingsWrappedRowCursor : SynchronizedCursorBase, IRowCursor + internal sealed class BindingsWrappedRowCursor : SynchronizedCursorBase { private readonly ColumnBindingsBase _bindings; - public Schema Schema => _bindings.AsSchema; + public override Schema Schema => _bindings.AsSchema; /// /// Creates a wrapped version of the cursor @@ -25,7 +23,7 @@ internal sealed class BindingsWrappedRowCursor : SynchronizedCursorBaseChannel provider /// The input cursor /// The bindings object, - public BindingsWrappedRowCursor(IChannelProvider provider, IRowCursor input, ColumnBindingsBase bindings) + public BindingsWrappedRowCursor(IChannelProvider provider, RowCursor input, ColumnBindingsBase bindings) : base(provider, input) { Ch.CheckValue(input, nameof(input)); @@ -34,7 +32,7 @@ public BindingsWrappedRowCursor(IChannelProvider provider, IRowCursor input, Col _bindings = bindings; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col & col < _bindings.ColumnCount, "col"); bool isSrc; @@ -42,7 +40,7 @@ public bool IsColumnActive(int col) return isSrc && Input.IsColumnActive(col); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col), "col"); bool isSrc; diff --git a/src/Microsoft.ML.Data/Transforms/CatalogUtils.cs b/src/Microsoft.ML.Data/Transforms/CatalogUtils.cs index 3f7c3356c3..f7aed4a8de 100644 --- a/src/Microsoft.ML.Data/Transforms/CatalogUtils.cs +++ b/src/Microsoft.ML.Data/Transforms/CatalogUtils.cs @@ -1,12 +1,8 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; -using System.Collections.Generic; -using System.Text; - -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Set of extension methods to extract from various catalog classes. @@ -16,9 +12,9 @@ public static class CatalogUtils public static IHostEnvironment GetEnvironment(this TransformsCatalog catalog) => Contracts.CheckRef(catalog, nameof(catalog)).Environment; public static IHostEnvironment GetEnvironment(this TransformsCatalog.SubCatalogBase subCatalog) => Contracts.CheckRef(subCatalog, nameof(subCatalog)).Environment; public static IHostEnvironment GetEnvironment(this ModelOperationsCatalog catalog) => Contracts.CheckRef(catalog, nameof(catalog)).Environment; + public static IHostEnvironment GetEnvironment(this ModelOperationsCatalog.SubCatalogBase subCatalog) => Contracts.CheckRef(subCatalog, nameof(subCatalog)).Environment; public static IHostEnvironment GetEnvironment(this DataOperations catalog) => Contracts.CheckRef(catalog, nameof(catalog)).Environment; public static IHostEnvironment GetEnvironment(TrainContextBase.ContextInstantiatorBase obj) => Contracts.CheckRef(obj, nameof(obj)).Owner.Environment; public static IHostEnvironment GetEnvironment(TrainContextBase ctx) => Contracts.CheckRef(ctx, nameof(ctx)).Environment; - } } diff --git a/src/Microsoft.ML.Data/Transforms/ColumnBindingsBase.cs b/src/Microsoft.ML.Data/Transforms/ColumnBindingsBase.cs index c76ebbc57a..eb6eae0972 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnBindingsBase.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnBindingsBase.cs @@ -2,18 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Data; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms.Conversions; using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Transforms.Conversions; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { public abstract class SourceNameColumnBase { @@ -264,9 +262,10 @@ public static T Create(string name, params string[] source) /// the input column, and is placed immediately after the input column. Otherwise, the added column is placed /// at the end. By default, newly added columns have no metadata (but this can be overriden). /// - public abstract class ColumnBindingsBase : ISchema + [BestFriend] + internal abstract class ColumnBindingsBase { - public readonly ISchema Input; + public readonly Schema Input; // Mapping from name to index into Infos for the columns that we "generate". // Some of these might "hide" input columns. @@ -287,22 +286,29 @@ public abstract class ColumnBindingsBase : ISchema public Schema AsSchema => _convertedSchema.Value; - /// - /// Constructor that takes an input schema and adds no new columns. - /// This is utilized by lambda transforms if the output happens to have no columns. 
- /// - /// - protected ColumnBindingsBase(ISchema input) + private static Schema CreateSchema(ColumnBindingsBase inputBindings) { - Contracts.CheckValue(input, nameof(input)); + Contracts.CheckValue(inputBindings, nameof(inputBindings)); - Input = input; + var builder = new SchemaBuilder(); + for (int i = 0; i < inputBindings.ColumnCount; i++) + { + var meta = new MetadataBuilder(); + foreach (var kvp in inputBindings.GetMetadataTypes(i)) + { + var getter = Utils.MarshalInvoke(GetMetadataGetterDelegate, kvp.Value.RawType, inputBindings, i, kvp.Key); + meta.Add(kvp.Key, kvp.Value, getter); + } + builder.AddColumn(inputBindings.GetColumnName(i), inputBindings.GetColumnType(i), meta.GetMetadata()); + } - _names = new string[0]; - _nameToInfoIndex = new Dictionary(); - Contracts.Assert(_nameToInfoIndex.Count == _names.Length); - ComputeColumnMapping(Input, _names, out _colMap, out _mapIinfoToCol); - _convertedSchema = new Lazy(() => Schema.Create(this), LazyThreadSafetyMode.PublicationOnly); + return builder.GetSchema(); + } + + private static Delegate GetMetadataGetterDelegate(ColumnBindingsBase bindings, int col, string kind) + { + ValueGetter getter = (ref TValue value) => bindings.GetMetadata(kind, col, ref value); + return getter; } /// @@ -311,7 +317,7 @@ protected ColumnBindingsBase(ISchema input) /// in schemaInput. For error reporting, this assumes that the names come from a user-supplied /// parameter named "column". This takes ownership of the params array of names. 
/// - protected ColumnBindingsBase(ISchema input, bool user, params string[] names) + protected ColumnBindingsBase(Schema input, bool user, params string[] names) { Contracts.CheckValue(input, nameof(input)); Contracts.CheckNonEmpty(names, nameof(names)); @@ -351,15 +357,15 @@ protected ColumnBindingsBase(ISchema input, bool user, params string[] names) Contracts.Assert(_nameToInfoIndex.Count == names.Length); ComputeColumnMapping(Input, names, out _colMap, out _mapIinfoToCol); - _convertedSchema = new Lazy(() => Schema.Create(this), LazyThreadSafetyMode.PublicationOnly); + _convertedSchema = new Lazy(() => CreateSchema(this), LazyThreadSafetyMode.PublicationOnly); } - private static void ComputeColumnMapping(ISchema input, string[] names, out int[] colMap, out int[] mapIinfoToCol) + private static void ComputeColumnMapping(Schema input, string[] names, out int[] colMap, out int[] mapIinfoToCol) { // To compute the column mapping information, first populate: // * _colMap[src] with the ~ of the iinfo that hides src (zero for none). // * _mapIinfoToCol[iinfo] with the ~ of the source column that iinfo hides (zero for none). 
- colMap = new int[input.ColumnCount + names.Length]; + colMap = new int[input.Count + names.Length]; mapIinfoToCol = new int[names.Length]; for (int iinfo = 0; iinfo < names.Length; iinfo++) { @@ -367,8 +373,8 @@ private static void ComputeColumnMapping(ISchema input, string[] names, out int[ int colHidden; if (input.TryGetColumnIndex(name, out colHidden)) { - Contracts.Check(0 <= colHidden && colHidden < input.ColumnCount); - var str = input.GetColumnName(colHidden); + Contracts.Check(0 <= colHidden && colHidden < input.Count); + var str = input[colHidden].Name; Contracts.Check(str == name); Contracts.Check(colMap[colHidden] == 0); mapIinfoToCol[iinfo] = ~colHidden; @@ -387,7 +393,7 @@ private static void ComputeColumnMapping(ISchema input, string[] names, out int[ mapIinfoToCol[iinfo] = colDst; } } - for (int colSrc = input.ColumnCount; --colSrc >= 0;) + for (int colSrc = input.Count; --colSrc >= 0;) { Contracts.Assert(colMap[colSrc] <= 0); if (colMap[colSrc] < 0) @@ -436,7 +442,7 @@ public int MapColumnIndex(out bool isSrcColumn, int col) } else { - Contracts.Assert(index < Input.ColumnCount); + Contracts.Assert(index < Input.Count); isSrcColumn = true; } return index; @@ -476,7 +482,7 @@ public bool TryGetColumnIndex(string name, out int col) int src; if (Input.TryGetColumnIndex(name, out src)) { - Contracts.Assert(0 <= src && src < Input.ColumnCount); + Contracts.Assert(0 <= src && src < Input.Count); int res = src; for (; ; res++) { @@ -501,7 +507,7 @@ public string GetColumnName(int col) bool isSrc; int index = MapColumnIndex(out isSrc, col); if (isSrc) - return Input.GetColumnName(index); + return Input[index].Name; return GetColumnNameCore(index); } @@ -512,7 +518,7 @@ public ColumnType GetColumnType(int col) bool isSrc; int index = MapColumnIndex(out isSrc, col); if (isSrc) - return Input.GetColumnType(index); + return Input[index].Type; return GetColumnTypeCore(index); } @@ -523,7 +529,7 @@ public IEnumerable> GetMetadataTypes(int col) bool isSrc; int 
index = MapColumnIndex(out isSrc, col); if (isSrc) - return Input.GetMetadataTypes(index); + return Input[index].Metadata.Schema.Select(c => new KeyValuePair(c.Name, c.Type)); Contracts.Assert(0 <= index && index < InfoCount); return GetMetadataTypesCore(index); } @@ -536,7 +542,7 @@ public ColumnType GetMetadataTypeOrNull(string kind, int col) bool isSrc; int index = MapColumnIndex(out isSrc, col); if (isSrc) - return Input.GetMetadataTypeOrNull(kind, index); + return Input[index].Metadata.Schema.GetColumnOrNull(kind)?.Type; Contracts.Assert(0 <= index && index < InfoCount); return GetMetadataTypeCore(kind, index); } @@ -549,7 +555,7 @@ public void GetMetadata(string kind, int col, ref TValue value) bool isSrc; int index = MapColumnIndex(out isSrc, col); if (isSrc) - Input.GetMetadata(kind, index, ref value); + Input[index].Metadata.GetValue(kind, ref value); else { Contracts.Assert(0 <= index && index < InfoCount); @@ -610,11 +616,11 @@ public bool[] GetActiveInput(Func predicate) { Contracts.AssertValue(predicate); - var active = new bool[Input.ColumnCount]; + var active = new bool[Input.Count]; for (int dst = 0; dst < _colMap.Length; dst++) { int src = _colMap[dst]; - Contracts.Assert(-InfoCount <= src && src < Input.ColumnCount); + Contracts.Assert(-InfoCount <= src && src < Input.Count); if (src >= 0 && predicate(dst)) active[src] = true; } @@ -679,7 +685,7 @@ public ColumnBindings(Schema input, Schema.DetachedColumn[] addedColumns) // Construct the indices. var indices = new List(); var namesUsed = new HashSet(); - for (int i = 0; i < input.ColumnCount; i++) + for (int i = 0; i < input.Count; i++) { namesUsed.Add(input[i].Name); indices.Add(i); @@ -707,7 +713,7 @@ public ColumnBindings(Schema input, Schema.DetachedColumn[] addedColumns) } } } - Contracts.Assert(indices.Count == addedColumns.Length + input.ColumnCount); + Contracts.Assert(indices.Count == addedColumns.Length + input.Count); // Create the output schema. 
var schemaColumns = indices.Select(idx => idx >= 0 ? new Schema.DetachedColumn(input[idx]) : addedColumns[~idx]); @@ -749,7 +755,7 @@ public int MapColumnIndex(out bool isSrcColumn, int col) } else { - Contracts.Assert(index < InputSchema.ColumnCount); + Contracts.Assert(index < InputSchema.Count); isSrcColumn = true; } return index; @@ -764,11 +770,11 @@ public bool[] GetActiveInput(Func predicate) { Contracts.AssertValue(predicate); - var active = new bool[InputSchema.ColumnCount]; + var active = new bool[InputSchema.Count]; for (int dst = 0; dst < _colMap.Length; dst++) { int src = _colMap[dst]; - Contracts.Assert(-AddedColumnIndices.Count <= src && src < InputSchema.ColumnCount); + Contracts.Assert(-AddedColumnIndices.Count <= src && src < InputSchema.Count); if (src >= 0 && predicate(dst)) active[src] = true; } @@ -776,239 +782,6 @@ public bool[] GetActiveInput(Func predicate) } } - /// - /// Base type for bindings with multiple new columns, each mapping from multiple source columns. - /// The column strings are parsed as D:S where D is the name of the new column and S is a comma separated - /// list of source columns. A column string S with no colon is interpreted as S:S, so the destination - /// column has the same name as the source column. This form requires S to not contain commas. - /// Note that this base type requires columns of type typeMulti to have known size. - /// - public abstract class ManyToOneColumnBindingsBase : ColumnBindingsBase - { - public sealed class ColInfo - { - // Total size of all source columns, zero for variable (non-vector columns contribute 1). - public readonly int SrcSize; - // Indices of the source columns - public readonly int[] SrcIndices; - // Types of the source columns. - public readonly ColumnType[] SrcTypes; - - public ColInfo(int srcSize, int[] srcIndices, ColumnType[] srcTypes) - { - Contracts.Assert(srcSize >= 0); // Zero means variable. 
- Contracts.AssertNonEmpty(srcIndices); - Contracts.Assert(Utils.Size(srcTypes) == srcIndices.Length); - - SrcSize = srcSize; - SrcIndices = srcIndices; - SrcTypes = srcTypes; - } - } - - public readonly ColInfo[] Infos; - - protected ManyToOneColumnBindingsBase(ManyToOneColumn[] column, ISchema input, Func testTypes) - : base(input, true, GetNamesAndSanitize(column)) - { - Contracts.AssertNonEmpty(column); - Contracts.Assert(column.Length == InfoCount); - - // In lieu of actual protections, I have the following silly asserts, so we can have some - // warning if we decide to rename this argument, and so know to change the below hard-coded - // standard column name. - const string standardColumnArgName = "Column"; - Contracts.Assert(nameof(ValueToKeyMappingTransformer.Arguments.Column) == standardColumnArgName); - Contracts.Assert(nameof(ColumnConcatenatingTransformer.Arguments.Column) == standardColumnArgName); - - Infos = new ColInfo[InfoCount]; - for (int i = 0; i < Infos.Length; i++) - { - var item = column[i]; - Contracts.AssertNonEmpty(item.Name); - Contracts.AssertNonEmpty(item.Source); - - var src = item.Source; - var srcIndices = new int[src.Length]; - var srcTypes = new ColumnType[src.Length]; - int? srcSize = 0; - for (int j = 0; j < src.Length; j++) - { - Contracts.CheckUserArg(!string.IsNullOrWhiteSpace(src[j]), nameof(ManyToOneColumn.Source)); -#pragma warning disable MSML_ContractsNameUsesNameof // Unfortunately, there is no base class for the columns bindings. - if (!input.TryGetColumnIndex(src[j], out srcIndices[j])) - throw Contracts.ExceptUserArg(standardColumnArgName, "Source column '{0}' not found", src[j]); -#pragma warning restore MSML_ContractsNameUsesNameof - srcTypes[j] = input.GetColumnType(srcIndices[j]); - var size = srcTypes[j].ValueCount; - srcSize = size == 0 ? 
null : checked(srcSize + size); - } - - if (testTypes != null) - { - string reason = testTypes(srcTypes); - if (reason != null) - { -#pragma warning disable MSML_ContractsNameUsesNameof // Unfortunately, there is no base class for the columns bindings. - throw Contracts.ExceptUserArg(standardColumnArgName, "Column '{0}' has invalid source types: {1}. Source types: '{2}'.", - item.Name, reason, string.Join(", ", srcTypes.Select(type => type.ToString()))); -#pragma warning restore MSML_ContractsNameUsesNameof - } - } - Infos[i] = new ColInfo(srcSize.GetValueOrDefault(), srcIndices, srcTypes); - } - } - - /// - /// Gather the names from the objects. Also, cleanse the column - /// objects (propagate values that are shared). - /// - private static string[] GetNamesAndSanitize(ManyToOneColumn[] column) - { - Contracts.CheckUserArg(Utils.Size(column) > 0, nameof(column)); - - var names = new string[column.Length]; - for (int i = 0; i < column.Length; i++) - { - var item = column[i]; - if (string.IsNullOrWhiteSpace(item.Name)) - { - Contracts.CheckUserArg(Utils.Size(item.Source) > 0, nameof(item.Name), "Must specify name"); - Contracts.CheckUserArg(item.Source.Length == 1, nameof(item.Name), "New name is required when multiple source columns are specified"); - item.Name = item.Source[0]; - } - else if (Utils.Size(item.Source) == 0) - item.Source = new string[] { item.Name }; - names[i] = item.Name; - } - - return names; - } - - // Read everything into a new Contents object and pass it to the constructor below. - protected ManyToOneColumnBindingsBase(ModelLoadContext ctx, ISchema input, Func testTypes) - : this(new Contents(ctx, input, testTypes)) - { - } - - private ManyToOneColumnBindingsBase(Contents contents) - : base(contents.Input, false, contents.Names) - { - Contracts.Assert(InfoCount == Utils.Size(contents.Infos)); - Infos = contents.Infos; - } - - /// - /// This class is used to deserialize. 
We read everything into an instance of this - /// and pass that to another constructor. - /// - private sealed class Contents - { - public ISchema Input; - public ColInfo[] Infos; - public string[] Names; - - public Contents(ModelLoadContext ctx, ISchema input, Func testTypes) - { - Contracts.CheckValue(ctx, nameof(ctx)); - Contracts.CheckValue(input, nameof(input)); - Contracts.CheckValueOrNull(testTypes); - - Input = input; - - // *** Binary format *** - // int: number of added columns - // for each added column - // int: id of output column name - // int: number of input column names - // int[]: ids of input column names - int cinfo = ctx.Reader.ReadInt32(); - Contracts.CheckDecode(cinfo > 0); - - Infos = new ColInfo[cinfo]; - Names = new string[cinfo]; - for (int i = 0; i < cinfo; i++) - { - Names[i] = ctx.LoadNonEmptyString(); - - int csrc = ctx.Reader.ReadInt32(); - Contracts.CheckDecode(csrc > 0); - int[] indices = new int[csrc]; - var srcTypes = new ColumnType[csrc]; - int? srcSize = 0; - for (int j = 0; j < csrc; j++) - { - string src = ctx.LoadNonEmptyString(); - if (!input.TryGetColumnIndex(src, out indices[j])) - throw Contracts.Except("Source column '{0}' is required but not found", src); - srcTypes[j] = input.GetColumnType(indices[j]); - var size = srcTypes[j].ValueCount; - srcSize = size == 0 ? null : checked(srcSize + size); - } - - if (testTypes != null) - { - string reason = testTypes(srcTypes); - if (reason != null) - { - throw Contracts.Except("Source columns '{0}' have invalid types: {1}. 
Source types: '{2}'.", - string.Join(", ", indices.Select(k => input.GetColumnName(k))), - reason, - string.Join(", ", srcTypes.Select(type => type.ToString()))); - } - } - - Infos[i] = new ColInfo(srcSize.GetValueOrDefault(), indices, srcTypes); - } - } - } - - public virtual void Save(ModelSaveContext ctx) - { - Contracts.AssertValue(ctx); - - // *** Binary format *** - // int: number of added columns - // for each added column - // int: id of output column name - // int: number of input column names - // int[]: ids of input column names - ctx.Writer.Write(Infos.Length); - for (int i = 0; i < Infos.Length; i++) - { - var info = Infos[i]; - ctx.SaveNonEmptyString(GetColumnNameCore(i)); - ctx.Writer.Write(info.SrcIndices.Length); - foreach (int src in info.SrcIndices) - ctx.SaveNonEmptyString(Input.GetColumnName(src)); - } - } - - public Func GetDependencies(Func predicate) - { - Contracts.AssertValue(predicate); - - var active = new bool[Input.ColumnCount]; - for (int col = 0; col < ColumnCount; col++) - { - if (!predicate(col)) - continue; - - bool isSrc; - int index = MapColumnIndex(out isSrc, col); - if (isSrc) - active[index] = true; - else - { - foreach (var i in Infos[index].SrcIndices) - active[i] = true; - } - } - - return col => 0 <= col && col < active.Length && active[col]; - } - } - /// /// Parsing utilities for converting between transform column argument objects and /// command line representations. diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs index dc8ad9b389..7f01259379 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs @@ -2,16 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.StaticPipe; -using Microsoft.ML.StaticPipe.Runtime; using System; using System.Collections.Generic; using System.Linq; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; namespace Microsoft.ML.Transforms { @@ -49,7 +44,7 @@ public ITransformer Fit(IDataView input) private bool HasCategoricals(SchemaShape.Column col) { - _host.AssertValue(col); + _host.Assert(col.IsValid); if (!col.Metadata.TryFindColumn(MetadataUtils.Kinds.CategoricalSlotRanges, out var mcol)) return false; // The indices must be ints and of a definite size vector type. (Definite becuase @@ -116,218 +111,9 @@ private SchemaShape.Column CheckInputsAndMakeColumn( public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); result[_name] = CheckInputsAndMakeColumn(inputSchema, _name, _source); return new SchemaShape(result.Values); } } - - /// - /// The extension methods and implementation support for concatenating columns together. - /// - public static class ConcatStaticExtensions - { - /// - /// Given a scalar vector, produce a vector of length one. - /// - /// The value type. - /// The scalar column. - /// The vector column, whose single item has the same value as the input. - public static Vector AsVector(this Scalar me) - => new Impl(Join(me, (PipelineColumn[])null)); - - /// - /// Given a bunch of normalized vectors, concatenate them together into a normalized vector. - /// - /// The value type. - /// The first input column. - /// Subsequent input columns. - /// The result of concatenating all input columns together. 
- public static NormVector ConcatWith(this NormVector me, params NormVector[] others) - => new ImplNorm(Join(me, others)); - - /// - /// Given a set of columns, concatenate them together into a vector valued column of the same type. - /// - /// The value type. - /// The first input column. - /// Subsequent input columns. - /// The result of concatenating all input columns together. - public static Vector ConcatWith(this Scalar me, params ScalarOrVector[] others) - => new Impl(Join(me, others)); - - /// - /// Given a set of columns, concatenate them together into a vector valued column of the same type. - /// - /// The value type. - /// The first input column. - /// Subsequent input columns. - /// The result of concatenating all input columns together. - public static Vector ConcatWith(this Vector me, params ScalarOrVector[] others) - => new Impl(Join(me, others)); - - /// - /// Given a set of columns including at least one variable sized vector column, concatenate them - /// together into a vector valued column of the same type. - /// - /// The value type. - /// The first input column. - /// Subsequent input columns. - /// The result of concatenating all input columns together. - public static VarVector ConcatWith(this Scalar me, params ScalarOrVectorOrVarVector[] others) - => new ImplVar(Join(me, others)); - - /// - /// Given a set of columns including at least one variable sized vector column, concatenate them - /// together into a vector valued column of the same type. - /// - /// The value type. - /// The first input column. - /// Subsequent input columns. - /// The result of concatenating all input columns together. - public static VarVector ConcatWith(this Vector me, params ScalarOrVectorOrVarVector[] others) - => new ImplVar(Join(me, others)); - - /// - /// Given a set of columns including at least one variable sized vector column, concatenate them - /// together into a vector valued column of the same type. - /// - /// The value type. 
- /// The first input column. - /// Subsequent input columns. - /// The result of concatenating all input columns together. - public static VarVector ConcatWith(this VarVector me, params ScalarOrVectorOrVarVector[] others) - => new ImplVar(Join(me, others)); - - private interface IContainsColumn - { - PipelineColumn WrappedColumn { get; } - } - - /// - /// A wrapping object for the implicit conversions in - /// and other related methods. - /// - /// The value type. - public sealed class ScalarOrVector : ScalarOrVectorOrVarVector - { - private ScalarOrVector(PipelineColumn col) : base(col) { } - public static implicit operator ScalarOrVector(Scalar c) => new ScalarOrVector(c); - public static implicit operator ScalarOrVector(Vector c) => new ScalarOrVector(c); - public static implicit operator ScalarOrVector(NormVector c) => new ScalarOrVector(c); - } - - /// - /// A wrapping object for the implicit conversions in - /// and other related methods. - /// - /// The value type. - public class ScalarOrVectorOrVarVector : IContainsColumn - { - public PipelineColumn WrappedColumn { get; } - - private protected ScalarOrVectorOrVarVector(PipelineColumn col) - { - Contracts.CheckValue(col, nameof(col)); - WrappedColumn = col; - } - - public static implicit operator ScalarOrVectorOrVarVector(VarVector c) - => new ScalarOrVectorOrVarVector(c); - } - - #region Implementation support - private sealed class Rec : EstimatorReconciler - { - /// - /// For the moment the concat estimator can only do one at a time, so I want to apply these operations - /// one at a time, which means a separate reconciler. Otherwise there may be problems with name overwriting. - /// If that is ever adjusted, then we can make a slightly more efficient reconciler, though this is probably - /// not that important of a consideration from a runtime perspective. 
- /// - public static Rec Inst => new Rec(); - - private Rec() { } - - public override IEstimator Reconcile(IHostEnvironment env, - PipelineColumn[] toOutput, - IReadOnlyDictionary inputNames, - IReadOnlyDictionary outputNames, - IReadOnlyCollection usedNames) - { - // For the moment, the concat estimator can only do one concatenation at a time. - // So we will chain the estimators. - Contracts.AssertNonEmpty(toOutput); - IEstimator est = null; - for (int i = 0; i < toOutput.Length; ++i) - { - var ccol = (IConcatCol)toOutput[i]; - string[] inputs = ccol.Sources.Select(s => inputNames[s]).ToArray(); - var localEst = new ColumnConcatenatingEstimator (env, outputNames[toOutput[i]], inputs); - if (i == 0) - est = localEst; - else - est = est.Append(localEst); - } - return est; - } - } - - private static PipelineColumn[] Join(PipelineColumn col, IContainsColumn[] cols) - { - if (Utils.Size(cols) == 0) - return new[] { col }; - var retVal = new PipelineColumn[cols.Length + 1]; - retVal[0] = col; - for (int i = 0; i < cols.Length; ++i) - retVal[i + 1] = cols[i].WrappedColumn; - return retVal; - } - - private static PipelineColumn[] Join(PipelineColumn col, PipelineColumn[] cols) - { - if (Utils.Size(cols) == 0) - return new[] { col }; - var retVal = new PipelineColumn[cols.Length + 1]; - retVal[0] = col; - Array.Copy(cols, 0, retVal, 1, cols.Length); - return retVal; - } - - private interface IConcatCol - { - PipelineColumn[] Sources { get; } - } - - private sealed class Impl : Vector, IConcatCol - { - public PipelineColumn[] Sources { get; } - public Impl(PipelineColumn[] cols) - : base(Rec.Inst, cols) - { - Sources = cols; - } - } - - private sealed class ImplVar : VarVector, IConcatCol - { - public PipelineColumn[] Sources { get; } - public ImplVar(PipelineColumn[] cols) - : base(Rec.Inst, cols) - { - Sources = cols; - } - } - - private sealed class ImplNorm : NormVector, IConcatCol - { - public PipelineColumn[] Sources { get; } - public ImplNorm(PipelineColumn[] cols) 
- : base(Rec.Inst, cols) - { - Sources = cols; - } - } - #endregion - } } diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs index c33c2e55ff..0c245ceefc 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs @@ -2,21 +2,19 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Model.Pfa; -using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Model.Pfa; +using Newtonsoft.Json.Linq; [assembly: LoadableClass(ColumnConcatenatingTransformer.Summary, typeof(IDataTransform), typeof(ColumnConcatenatingTransformer), typeof(ColumnConcatenatingTransformer.TaggedArguments), typeof(SignatureDataTransform), ColumnConcatenatingTransformer.UserName, ColumnConcatenatingTransformer.LoadName, "ConcatTransform", DocName = "transform/ConcatTransform.md")] @@ -30,7 +28,7 @@ [assembly: LoadableClass(typeof(IRowMapper), typeof(ColumnConcatenatingTransformer), null, typeof(SignatureLoadRowMapper), ColumnConcatenatingTransformer.UserName, ColumnConcatenatingTransformer.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { using PfaType = 
PfaUtils.Type; @@ -393,19 +391,19 @@ public static IDataTransform Create(IHostEnvironment env, TaggedArguments args, return transformer.MakeDataTransform(input); } - protected override IRowMapper MakeRowMapper(Schema inputSchema) => new Mapper(this, inputSchema); + private protected override IRowMapper MakeRowMapper(Schema inputSchema) => new Mapper(this, inputSchema); /// /// Factory method for SignatureLoadDataTransform. /// - public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) => new ColumnConcatenatingTransformer(env, ctx).MakeDataTransform(input); /// /// Factory method for SignatureLoadRowMapper. /// - public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - => new ColumnConcatenatingTransformer(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) + => new ColumnConcatenatingTransformer(env, ctx).MakeRowMapper(inputSchema); private sealed class Mapper : MapperBase, ISaveAsOnnx, ISaveAsPfa { @@ -457,7 +455,7 @@ private BoundColumn MakeColumn(Schema inputSchema, int iinfo) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName); sources[i] = srcCol; - var curType = inputSchema.GetColumnType(srcCol); + var curType = inputSchema[srcCol].Type; if (itemType == null) { itemType = curType.ItemType; @@ -474,7 +472,7 @@ private BoundColumn MakeColumn(Schema inputSchema, int iinfo) else throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName, itemType.ToString(), curType.ToString()); - if (isNormalized && !inputSchema.IsNormalized(srcCol)) + if (isNormalized && !inputSchema[srcCol].IsNormalized()) isNormalized = false; if (MetadataUtils.TryGetCategoricalFeatureIndices(inputSchema, srcCol, out int[] typeCat)) @@ -484,7 +482,7 @@ private BoundColumn 
MakeColumn(Schema inputSchema, int iinfo) hasCategoricals = true; } - if (!hasSlotNames && !curType.IsVector || inputSchema.HasSlotNames(srcCol, curType.VectorSize)) + if (!hasSlotNames && !curType.IsVector || inputSchema[srcCol].HasSlotNames(curType.VectorSize)) hasSlotNames = true; } @@ -496,7 +494,7 @@ private BoundColumn MakeColumn(Schema inputSchema, int iinfo) hasSlotNames = false; } - return new BoundColumn(InputSchema, _parent._columns[iinfo], sources, new VectorType(itemType.AsPrimitive, totalSize), + return new BoundColumn(InputSchema, _parent._columns[iinfo], sources, new VectorType((PrimitiveType)itemType, totalSize), isNormalized, hasSlotNames, hasCategoricals, totalSize, catCount); } @@ -650,7 +648,7 @@ private void GetSlotNames(ref VBuffer> dst) bldr.GetResult(ref dst); } - public Delegate MakeGetter(IRow input) + public Delegate MakeGetter(Row input) { if (_isIdentity) return Utils.MarshalInvoke(MakeIdentityGetter, OutputType.RawType, input); @@ -658,13 +656,13 @@ public Delegate MakeGetter(IRow input) return Utils.MarshalInvoke(MakeGetter, OutputType.ItemType.RawType, input); } - private Delegate MakeIdentityGetter(IRow input) + private Delegate MakeIdentityGetter(Row input) { Contracts.Assert(SrcIndices.Length == 1); return input.GetGetter(SrcIndices[0]); } - private Delegate MakeGetter(IRow input) + private Delegate MakeGetter(Row input) { var srcGetterOnes = new ValueGetter[SrcIndices.Length]; var srcGetterVecs = new ValueGetter>[SrcIndices.Length]; @@ -691,7 +689,7 @@ private Delegate MakeGetter(IRow input) if (type.VectorSize != 0 && type.VectorSize != tmpBufs[i].Length) { throw Contracts.Except("Column '{0}': expected {1} slots, but got {2}", - input.Schema.GetColumnName(SrcIndices[i]), type.VectorSize, tmpBufs[i].Length) + input.Schema[SrcIndices[i]].Name, type.VectorSize, tmpBufs[i].Length) .MarkSensitive(MessageSensitivity.Schema); } dstLength = checked(dstLength + tmpBufs[i].Length); @@ -829,9 +827,9 @@ public KeyValuePair 
SavePfaInfo(BoundPfaContext ctx) } } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { - var active = new bool[InputSchema.ColumnCount]; + var active = new bool[InputSchema.Count]; for (int i = 0; i < _columns.Length; i++) { if (activeOutput(i)) @@ -847,7 +845,7 @@ public override Func GetDependencies(Func activeOutput) public override void Save(ModelSaveContext ctx) => _parent.Save(ctx); - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { disposer = null; return _columns[iinfo].MakeGetter(input); diff --git a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs index 6ab8fe9b7b..2997a9a1c4 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs @@ -6,15 +6,14 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; using Microsoft.ML.Transforms; [assembly: LoadableClass(ColumnCopyingTransformer.Summary, typeof(IDataTransform), typeof(ColumnCopyingTransformer), @@ -49,7 +48,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var resultDic = inputSchema.Columns.ToDictionary(x => x.Name); + var resultDic = 
inputSchema.ToDictionary(x => x.Name); foreach (var (Source, Name) in Transformer.Columns) { if (!inputSchema.TryFindColumn(Source, out var originalColumn)) @@ -149,8 +148,8 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, => Create(env, ctx).MakeDataTransform(input); // Factory method for SignatureLoadRowMapper. - private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); public override void Save(ModelSaveContext ctx) { @@ -158,7 +157,7 @@ public override void Save(ModelSaveContext ctx) SaveColumns(ctx); } - protected override IRowMapper MakeRowMapper(Schema inputSchema) + private protected override IRowMapper MakeRowMapper(Schema inputSchema) => new Mapper(this, inputSchema, ColumnPairs); private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx @@ -175,17 +174,17 @@ internal Mapper(ColumnCopyingTransformer parent, Schema inputSchema, (string Sou _columns = columns; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _columns.Length); disposer = null; - Delegate MakeGetter(IRow row, int index) + Delegate MakeGetter(Row row, int index) => input.GetGetter(index); input.Schema.TryGetColumnIndex(_columns[iinfo].Source, out int colIndex); - var type = input.Schema.GetColumnType(colIndex); + var type = input.Schema[colIndex].Type; return Utils.MarshalInvoke(MakeGetter, type.RawType, input, colIndex); } @@ -208,7 +207,7 @@ public void SaveAsOnnx(OnnxContext ctx) { var srcVariableName = ctx.GetVariableName(column.Source); _schema.TryGetColumnIndex(column.Source, out int 
colIndex); - var dstVariableName = ctx.AddIntermediateVariable(_schema.GetColumnType(colIndex), column.Name); + var dstVariableName = ctx.AddIntermediateVariable(_schema[colIndex].Type, column.Name); var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); node.AddAttribute("type", LoaderSignature); } diff --git a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs index 204421d1ad..ce7fb554ee 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs @@ -2,18 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms; using System; using System.Collections.Generic; using System.Linq; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Transforms; [assembly: LoadableClass(ColumnSelectingTransformer.Summary, typeof(IDataTransform), typeof(ColumnSelectingTransformer), typeof(ColumnSelectingTransformer.Arguments), typeof(SignatureDataTransform), @@ -98,14 +97,14 @@ public static ColumnSelectingEstimator DropColumns(IHostEnvironment env, params public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - if (!Transformer.IgnoreMissing && !ColumnSelectingTransformer.IsSchemaValid(inputSchema.Columns.Select(x => x.Name), + if (!Transformer.IgnoreMissing && 
!ColumnSelectingTransformer.IsSchemaValid(inputSchema.Select(x => x.Name), Transformer.SelectColumns, out IEnumerable invalidColumns)) { throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", string.Join(",", invalidColumns)); } - var columns = inputSchema.Columns.Where(c => _selectPredicate(c.Name)); + var columns = inputSchema.Where(c => _selectPredicate(c.Name)); return new SchemaShape(columns); } } @@ -442,19 +441,19 @@ public static bool IsSchemaValid(IEnumerable inputColumns, public Schema GetOutputSchema(Schema inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - if (!IgnoreMissing && !IsSchemaValid(inputSchema.GetColumns().Select(x => x.column.Name), + if (!IgnoreMissing && !IsSchemaValid(inputSchema.Select(x => x.Name), SelectColumns, out IEnumerable invalidColumns)) { throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", string.Join(",", invalidColumns)); } - return new Mapper(this, inputSchema).Schema; + return new Mapper(this, inputSchema).OutputSchema; } public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - if (!IgnoreMissing && !IsSchemaValid(inputSchema.GetColumns().Select(x => x.column.Name), + if (!IgnoreMissing && !IsSchemaValid(inputSchema.Select(x => x.Name), SelectColumns, out IEnumerable invalidColumns)) { throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", string.Join(",", invalidColumns)); @@ -468,7 +467,7 @@ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) public IDataView Transform(IDataView input) { _host.CheckValue(input, nameof(input)); - if (!IgnoreMissing && !IsSchemaValid(input.Schema.GetColumns().Select(x => x.column.Name), + if (!IgnoreMissing && !IsSchemaValid(input.Schema.Select(x => x.Name), SelectColumns, out IEnumerable invalidColumns)) { throw _host.ExceptSchemaMismatch(nameof(input), "input", string.Join(",", invalidColumns)); @@ -483,9 +482,9 @@ private sealed class Mapper private readonly Schema 
_inputSchema; private readonly int[] _outputToInputMap; - public ISchema InputSchema => _inputSchema; + public Schema InputSchema => _inputSchema; - public Schema Schema { get; } + public Schema OutputSchema { get; } public Mapper(ColumnSelectingTransformer transform, Schema inputSchema) { @@ -496,7 +495,7 @@ public Mapper(ColumnSelectingTransformer transform, Schema inputSchema) transform.KeepColumns, transform.KeepHidden, _inputSchema); - Schema = GenerateOutputSchema(_outputToInputMap, _inputSchema); + OutputSchema = GenerateOutputSchema(_outputToInputMap, _inputSchema); } public int GetInputIndex(int outputIndex) @@ -511,7 +510,7 @@ private static int[] BuildOutputToInputMap(IEnumerable selectedColumns, Schema inputSchema) { var outputToInputMapping = new List(); - var columnCount = inputSchema.ColumnCount; + var columnCount = inputSchema.Count; if (keepColumns) { @@ -523,9 +522,9 @@ private static int[] BuildOutputToInputMap(IEnumerable selectedColumns, // column name-> list of column indices. This dictionary is used for // building the final mapping. var columnDict = new Dictionary>(); - for (int colIdx = 0; colIdx < inputSchema.ColumnCount; ++colIdx) + for (int colIdx = 0; colIdx < inputSchema.Count; ++colIdx) { - if (!keepHidden && inputSchema.IsHidden(colIdx)) + if (!keepHidden && inputSchema[colIdx].IsHidden) continue; var columnName = inputSchema[colIdx].Name; @@ -559,7 +558,7 @@ private static int[] BuildOutputToInputMap(IEnumerable selectedColumns, // given an input of ABC and dropping column B will result in AC. // In drop mode, we drop all columns with the specified names and keep all the rest, // ignoring the keepHidden argument. 
- for(int colIdx = 0; colIdx < inputSchema.ColumnCount; colIdx++) + for(int colIdx = 0; colIdx < inputSchema.Count; colIdx++) { if (selectedColumns.Contains(inputSchema[colIdx].Name)) continue; @@ -579,32 +578,24 @@ private static Schema GenerateOutputSchema(IEnumerable map, } } - private sealed class Row : IRow + private sealed class RowImpl : WrappingRow { private readonly Mapper _mapper; - private readonly IRow _input; - public Row(IRow input, Mapper mapper) + public RowImpl(Row input, Mapper mapper) + : base(input) { _mapper = mapper; - _input = input; } - public long Position => _input.Position; + public override Schema Schema => _mapper.OutputSchema; - public long Batch => _input.Batch; - - Schema ISchematized.Schema => _mapper.Schema; - - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { int index = _mapper.GetInputIndex(col); - return _input.GetGetter(index); + return Input.GetGetter(index); } - public ValueGetter GetIdGetter() - => _input.GetIdGetter(); - - public bool IsColumnActive(int col) => true; + public override bool IsColumnActive(int col) => true; } private sealed class SelectColumnsDataTransform : IDataTransform, IRowToRowMapper, ITransformTemplate @@ -625,13 +616,15 @@ public SelectColumnsDataTransform(IHostEnvironment env, ColumnSelectingTransform public IDataView Source { get; } - Schema IRowToRowMapper.InputSchema => Source.Schema; + public Schema InputSchema => Source.Schema; - Schema ISchematized.Schema => _mapper.Schema; + Schema IDataView.Schema => OutputSchema; + + public Schema OutputSchema => _mapper.OutputSchema; public long? 
GetRowCount() => Source.GetRowCount(); - public IRowCursor GetRowCursor(Func needCol, IRandom rand = null) + public RowCursor GetRowCursor(Func needCol, Random rand = null) { _host.AssertValue(needCol, nameof(needCol)); _host.AssertValueOrNull(rand); @@ -641,29 +634,27 @@ public IRowCursor GetRowCursor(Func needCol, IRandom rand = null) var inputRowCursor = Source.GetRowCursor(inputPred, rand); // Build the active state for the output - var active = Utils.BuildArray(_mapper.Schema.ColumnCount, needCol); - return new RowCursor(_host, _mapper, inputRowCursor, active); + var active = Utils.BuildArray(_mapper.OutputSchema.Count, needCol); + return new Cursor(_host, _mapper, inputRowCursor, active); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func needCol, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func needCol, int n, Random rand = null) { _host.CheckValue(needCol, nameof(needCol)); _host.CheckValueOrNull(rand); // Build out the active state for the input var inputPred = GetDependencies(needCol); - var inputs = Source.GetRowCursorSet(out consolidator, inputPred, n, rand); + var inputs = Source.GetRowCursorSet(inputPred, n, rand); // Build out the acitve state for the output - var active = Utils.BuildArray(_mapper.Schema.ColumnCount, needCol); + var active = Utils.BuildArray(_mapper.OutputSchema.Count, needCol); _host.AssertNonEmpty(inputs); // No need to split if this is given 1 input cursor. 
- var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - { - cursors[i] = new RowCursor(_host, _mapper, inputs[i], active); - } + cursors[i] = new Cursor(_host, _mapper, inputs[i], active); return cursors; } @@ -671,8 +662,8 @@ public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Fun public Func GetDependencies(Func activeOutput) { - var active = new bool[_mapper.InputSchema.ColumnCount]; - var columnCount = _mapper.Schema.ColumnCount; + var active = new bool[_mapper.InputSchema.Count]; + var columnCount = _mapper.OutputSchema.Count; for (int colIdx = 0; colIdx < columnCount; ++colIdx) { if (activeOutput(colIdx)) @@ -682,22 +673,21 @@ public Func GetDependencies(Func activeOutput) return col => active[col]; } - public IRow GetRow(IRow input, Func active, out Action disposer) + public Row GetRow(Row input, Func active) { - disposer = null; - return new Row(input, _mapper); + return new RowImpl(input, _mapper); } public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource) => new SelectColumnsDataTransform(env, _transform, new Mapper(_transform, newSource.Schema), newSource); } - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly Mapper _mapper; - private readonly IRowCursor _inputCursor; + private readonly RowCursor _inputCursor; private readonly bool[] _active; - public RowCursor(IChannelProvider provider, Mapper mapper, IRowCursor input, bool[] active) + public Cursor(IChannelProvider provider, Mapper mapper, RowCursor input, bool[] active) : base(provider, input) { _mapper = mapper; @@ -705,15 +695,15 @@ public RowCursor(IChannelProvider provider, Mapper mapper, IRowCursor input, boo _active = active; } - Schema ISchematized.Schema => _mapper.Schema; + public override Schema Schema => _mapper.OutputSchema; - public ValueGetter GetGetter(int col) + public 
override ValueGetter GetGetter(int col) { int index = _mapper.GetInputIndex(col); return _inputCursor.GetGetter(index); } - public bool IsColumnActive(int col) => _active[col]; + public override bool IsColumnActive(int col) => _active[col]; } } } diff --git a/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs b/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs index f6b45af576..d2481fb4be 100644 --- a/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs +++ b/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs @@ -1,9 +1,9 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; +using System.Collections.Generic; +using Microsoft.ML.Data; using Microsoft.ML.Transforms.Conversions; namespace Microsoft.ML @@ -23,7 +23,10 @@ public static class ConversionsExtensionsCatalog /// Name of the input column. /// Name of the column to be transformed. If this is null '' will be used. /// Number of bits to hash into. Must be between 1 and 31, inclusive. - /// Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit. + /// During hashing we constuct mappings between original values and the produced hash values. + /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. + /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. + /// 0 does not retain any input values. -1 retains all input values mapping to each hash. 
public static HashingEstimator Hash(this TransformsCatalog.ConversionTransforms catalog, string inputColumn, string outputColumn = null, int hashBits = HashDefaults.HashBits, int invertHash = HashDefaults.InvertHash) => new HashingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, hashBits, invertHash); @@ -99,8 +102,8 @@ public static KeyToVectorMappingEstimator MapKeyToVector(this TransformsCatalog. /// Name of the column to be transformed. /// Name of the output column. If this is null '' will be used. /// Maximum number of keys to keep per column when auto-training. - /// How items should be ordered when vectorized. By default, they will be in the order encountered. - /// If by value items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a'). + /// How items should be ordered when vectorized. If choosen they will be in the order encountered. + /// If , items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a'). public static ValueToKeyMappingEstimator MapValueToKey(this TransformsCatalog.ConversionTransforms catalog, string inputColumn, string outputColumn = null, @@ -122,5 +125,23 @@ public static ValueToKeyMappingEstimator MapValueToKey(this TransformsCatalog.Co string termsColumn = null, IComponentFactory loaderFactory = null) => new ValueToKeyMappingEstimator(CatalogUtils.GetEnvironment(catalog), columns, file, termsColumn, loaderFactory); + + /// + /// Maps specified keys to specified values + /// + /// The key type. + /// The value type. + /// The categorical transform's catalog + /// The list of keys to use for the mapping. The mapping is 1-1 with values. This list must be the same length as values and + /// cannot contain duplicate keys. + /// The list of values to pair with the keys for the mapping. This list must be equal to the same length as keys. 
+ /// The columns to apply this transform on. + /// + public static ValueMappingEstimator ValueMap( + this TransformsCatalog.ConversionTransforms catalog, + IEnumerable keys, + IEnumerable values, + params (string source, string name)[] columns) + => new ValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), keys, values, columns); } } diff --git a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs index d1f6c70bee..f14144a45e 100644 --- a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs @@ -2,19 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms.FeatureSelection; using System; using System.Collections.Generic; using System.Linq; using System.Reflection; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Transforms.FeatureSelection; [assembly: LoadableClass(SlotsDroppingTransformer.Summary, typeof(IDataTransform), typeof(SlotsDroppingTransformer), typeof(SlotsDroppingTransformer.Arguments), typeof(SignatureDataTransform), SlotsDroppingTransformer.FriendlyName, SlotsDroppingTransformer.LoaderSignature, "DropSlots")] @@ -200,7 +199,7 @@ public sealed class ColumnInfo /// Describes how the transformer handles one input-output column pair. /// /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. 
+ /// Name of the column resulting from the transformation of . Null means is replaced. /// Ranges of indices in the input column to be dropped. Setting max in to null sets max to int.MaxValue. public ColumnInfo(string input, string output = null, params (int min, int? max)[] slots) { @@ -252,7 +251,7 @@ private static VersionInfo GetVersionInfo() /// /// The environment to use. /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . Null means is replaced. /// Specifies the lower bound of the range of slots to be dropped. The lower bound is inclusive. /// Specifies the upper bound of the range of slots to be dropped. The upper bound is exclusive. public SlotsDroppingTransformer(IHostEnvironment env, string input, string output = null, int min = default, int? max = null) @@ -316,8 +315,8 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, => Create(env, ctx).MakeDataTransform(input); // Factory method for SignatureLoadRowMapper. 
- private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); public override void Save(ModelSaveContext ctx) { @@ -434,7 +433,7 @@ private static bool AreRangesValid(int[][] slotsMin, int[][] slotsMax) return true; } - protected override IRowMapper MakeRowMapper(Schema schema) + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase @@ -463,7 +462,7 @@ public Mapper(SlotsDroppingTransformer parent, Schema inputSchema) { if (!InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _cols[i])) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); - _srcTypes[i] = inputSchema.GetColumnType(_cols[i]); + _srcTypes[i] = inputSchema[_cols[i]].Type; if (!IsValidColumnType(_srcTypes[i].ItemType)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); _slotDropper[i] = new SlotDropper(_srcTypes[i].ValueCount, _parent.SlotsMin[i], _parent.SlotsMax[i]); @@ -517,8 +516,8 @@ private void ComputeType(Schema input, int iinfo, SlotDropper slotDropper, { Host.Assert(typeSrc.IsKnownSizeVector); var dstLength = slotDropper.DstLength; - var hasSlotNames = input.HasSlotNames(_cols[iinfo], _srcTypes[iinfo].VectorSize); - type = new VectorType(typeSrc.ItemType.AsPrimitive, Math.Max(dstLength, 1)); + var hasSlotNames = input[_cols[iinfo]].HasSlotNames(_srcTypes[iinfo].VectorSize); + type = new VectorType((PrimitiveType)typeSrc.ItemType, Math.Max(dstLength, 1)); suppressed = dstLength == 0; } } @@ -528,7 +527,7 @@ private void GetSlotNames(int iinfo, ref VBuffer> dst) Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); var names = default(VBuffer>); - 
InputSchema.GetMetadata(MetadataUtils.Kinds.SlotNames, _cols[iinfo], ref names); + InputSchema[_cols[iinfo]].GetSlotNames(ref names); _slotDropper[iinfo].DropSlots(ref names, ref dst); } @@ -692,7 +691,7 @@ private void CombineRanges( newRangeMax = maxRange2; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -711,7 +710,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac return MakeVecGetter(input, iinfo); } - private Delegate MakeOneTrivialGetter(IRow input, int iinfo) + private Delegate MakeOneTrivialGetter(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -734,7 +733,7 @@ private void OneTrivialGetter(ref TDst value) value = default(TDst); } - private Delegate MakeVecTrivialGetter(IRow input, int iinfo) + private Delegate MakeVecTrivialGetter(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -757,19 +756,19 @@ private void VecTrivialGetter(ref VBuffer value) VBufferUtils.Resize(ref value, 1, 0); } - private Delegate MakeVecGetter(IRow input, int iinfo) + private Delegate MakeVecGetter(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); Host.Assert(_srcTypes[iinfo].IsVector); Host.Assert(!_suppressed[iinfo]); - Func>> del = MakeVecGetter; + Func>> del = MakeVecGetter; var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(_srcTypes[iinfo].ItemType.RawType); return (Delegate)methodInfo.Invoke(this, new object[] { input, iinfo }); } - private ValueGetter> MakeVecGetter(IRow input, int iinfo) + private ValueGetter> MakeVecGetter(Row input, int iinfo) { var srcGetter = 
GetSrcGetter>(input, iinfo); var typeDst = _dstTypes[iinfo]; @@ -786,7 +785,7 @@ private ValueGetter> MakeVecGetter(IRow input, int iinfo) }; } - private ValueGetter GetSrcGetter(IRow input, int iinfo) + private ValueGetter GetSrcGetter(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -795,12 +794,12 @@ private ValueGetter GetSrcGetter(IRow input, int iinfo) return input.GetGetter(src); } - private Delegate GetSrcGetter(ColumnType typeDst, IRow row, int iinfo) + private Delegate GetSrcGetter(ColumnType typeDst, Row row, int iinfo) { Host.CheckValue(typeDst, nameof(typeDst)); Host.CheckValue(row, nameof(row)); - Func> del = GetSrcGetter; + Func> del = GetSrcGetter; var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeDst.RawType); return (Delegate)methodInfo.Invoke(this, new object[] { row, iinfo }); } @@ -821,8 +820,8 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() if (_srcTypes[iinfo].IsVector && _srcTypes[iinfo].IsKnownSizeVector) { var dstLength = _slotDropper[iinfo].DstLength; - var hasSlotNames = InputSchema.HasSlotNames(_cols[iinfo], _srcTypes[iinfo].VectorSize); - var type = new VectorType(_srcTypes[iinfo].ItemType.AsPrimitive, Math.Max(dstLength, 1)); + var hasSlotNames = InputSchema[_cols[iinfo]].HasSlotNames(_srcTypes[iinfo].VectorSize); + var type = new VectorType((PrimitiveType)_srcTypes[iinfo].ItemType, Math.Max(dstLength, 1)); if (hasSlotNames && dstLength > 0) { diff --git a/src/Microsoft.ML.Data/Transforms/ExplainabilityCatalog.cs b/src/Microsoft.ML.Data/Transforms/ExplainabilityCatalog.cs new file mode 100644 index 0000000000..9baf2714ae --- /dev/null +++ b/src/Microsoft.ML.Data/Transforms/ExplainabilityCatalog.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Internallearn; + +namespace Microsoft.ML +{ + using FeatureContributionDefaults = FeatureContributionCalculatingEstimator.Defaults; + + public static class ExplainabilityCatalog + { + /// + /// Feature Contribution Calculation computes model-specific contribution scores for each feature. + /// Note that this functionality is not supported by all the models. See for a list of the suported models. + /// + /// The model explainability operations catalog. + /// Trained model parameters that support Feature Contribution Calculation and which will be used for scoring. + /// The name of the feature column that will be used as input. + /// The number of features with highest positive contributions for each data sample that will be retained in the FeatureContribution column. + /// Note that if there are fewer features with positive contributions than , the rest will be returned as zeros. + /// The number of features with least negative contributions for each data sample that will be retained in the FeatureContribution column. + /// Note that if there are fewer features with negative contributions than , the rest will be returned as zeros. + /// Whether the feature contributions should be normalized to the [-1, 1] interval. 
+ public static FeatureContributionCalculatingEstimator FeatureContributionCalculation(this ModelOperationsCatalog.ExplainabilityTransforms catalog, + ICalculateFeatureContribution modelParameters, + string featureColumn = DefaultColumnNames.Features, + int top = FeatureContributionDefaults.Top, + int bottom = FeatureContributionDefaults.Bottom, + bool normalize = FeatureContributionDefaults.Normalize) + => new FeatureContributionCalculatingEstimator(CatalogUtils.GetEnvironment(catalog), modelParameters, featureColumn, top, bottom, normalize); + } +} diff --git a/src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs b/src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs index defe52ee65..18393ba934 100644 --- a/src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs +++ b/src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs @@ -2,8 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data; using Microsoft.ML.Transforms; namespace Microsoft.ML diff --git a/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransform.cs b/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransform.cs new file mode 100644 index 0000000000..7b2bbd87ab --- /dev/null +++ b/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransform.cs @@ -0,0 +1,335 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; + +[assembly: LoadableClass(FeatureContributionCalculatingTransformer.Summary, typeof(FeatureContributionCalculatingTransformer), null, typeof(SignatureLoadModel), + FeatureContributionCalculatingTransformer.FriendlyName, FeatureContributionCalculatingTransformer.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(FeatureContributionCalculatingTransformer), null, typeof(SignatureLoadRowMapper), + FeatureContributionCalculatingTransformer.FriendlyName, FeatureContributionCalculatingTransformer.LoaderSignature)] + +[assembly: LoadableClass(typeof(void), typeof(FeatureContributionEntryPoint), null, typeof(SignatureEntryPointModule), FeatureContributionCalculatingTransformer.LoaderSignature)] + +namespace Microsoft.ML.Data +{ + /// + /// The FeatureContributionCalculationTransformer computes model-specific contribution scores for each feature. + /// See the list of currently supported models below. 
+ /// + /// + /// Feature Contribution Calculation is currently supported for the following models: + /// Regression: + /// OrdinaryLeastSquares, StochasticDualCoordinateAscent (SDCA), OnlineGradientDescent, PoissonRegression, + /// GeneralizedAdditiveModels (GAM), LightGbm, FastTree, FastForest, FastTreeTweedie + /// Binary Classification: + /// AveragedPerceptron, LinearSupportVectorMachines, LogisticRegression, StochasticDualCoordinateAscent (SDCA), + /// StochasticGradientDescent (SGD), SymbolicStochasticGradientDescent, GeneralizedAdditiveModels (GAM), + /// FastForest, FastTree, LightGbm + /// Ranking: + /// FastTree, LightGbm + /// + /// See the sample below for an example of how to compute feature importance using the FeatureContributionCalculatingTransformer. + /// + /// + /// + /// + /// + /// + public sealed class FeatureContributionCalculatingTransformer : OneToOneTransformerBase + { + public sealed class Arguments : TransformInputBase + { + [Argument(ArgumentType.Required, HelpText = "The predictor model to apply to data", SortOrder = 1)] + public PredictorModel PredictorModel; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Name of feature column", SortOrder = 2)] + public string FeatureColumn = DefaultColumnNames.Features; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Number of top contributions", SortOrder = 3)] + public int Top = FeatureContributionCalculatingEstimator.Defaults.Top; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Number of bottom contributions", SortOrder = 4)] + public int Bottom = FeatureContributionCalculatingEstimator.Defaults.Bottom; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Whether or not output of Features contribution should be normalized", ShortName = "norm", SortOrder = 5)] + public bool Normalize = FeatureContributionCalculatingEstimator.Defaults.Normalize; + } + + // Apparently, loader signature is limited in length to 24 characters. 
+ internal const string Summary = "For each data point, calculates the contribution of individual features to the model prediction."; + internal const string FriendlyName = "Feature Contribution Calculation"; + internal const string LoaderSignature = "FeatureContribution"; + + public readonly int Top; + public readonly int Bottom; + public readonly bool Normalize; + + private readonly IFeatureContributionMapper _predictor; + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "FCC TRAN", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(FeatureContributionCalculatingTransformer).Assembly.FullName); + } + + /// + /// Feature Contribution Calculation computes model-specific contribution scores for each feature. + /// Note that this functionality is not supported by all the models. See for a list of the suported models. + /// + /// The environment to use. + /// Trained model parameters that support Feature Contribution Calculation and which will be used for scoring. + /// The name of the feature column that will be used as input. + /// The number of features with highest positive contributions for each data sample that will be retained in the FeatureContribution column. + /// Note that if there are fewer features with positive contributions than , the rest will be returned as zeros. + /// The number of features with least negative contributions for each data sample that will be retained in the FeatureContribution column. + /// Note that if there are fewer features with negative contributions than , the rest will be returned as zeros. + /// Whether the feature contributions should be normalized to the [-1, 1] interval. 
+ public FeatureContributionCalculatingTransformer(IHostEnvironment env, ICalculateFeatureContribution modelParameters, + string featureColumn = DefaultColumnNames.Features, + int top = FeatureContributionCalculatingEstimator.Defaults.Top, + int bottom = FeatureContributionCalculatingEstimator.Defaults.Bottom, + bool normalize = FeatureContributionCalculatingEstimator.Defaults.Normalize) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(FeatureContributionCalculatingTransformer)), new[] { (input: featureColumn, output: DefaultColumnNames.FeatureContributions) }) + { + Host.CheckValue(modelParameters, nameof(modelParameters)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); + if (top < 0) + throw Host.Except($"Number of top contribution must be non negative"); + if (bottom < 0) + throw Host.Except($"Number of bottom contribution must be non negative"); + + // If a predictor implements ICalculateFeatureContribution, it also implements the internal interface IFeatureContributionMapper. + // This is how we keep the implementation of feature contribution calculation internal. 
+ _predictor = modelParameters as IFeatureContributionMapper; + Host.AssertValue(_predictor); + + Top = top; + Bottom = bottom; + Normalize = normalize; + } + + private FeatureContributionCalculatingTransformer(IHostEnvironment env, ModelLoadContext ctx) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(FeatureContributionCalculatingTransformer)), ctx) + { + Host.AssertValue(ctx); + + // *** Binary format *** + // base + // IFeatureContributionMapper: predictor + // int: top + // int: bottom + // bool: normalize + + ctx.LoadModel(env, out _predictor, ModelFileUtils.DirPredictor); + Top = ctx.Reader.ReadInt32(); + Contracts.CheckDecode(0 <= Top); + Bottom = ctx.Reader.ReadInt32(); + Contracts.CheckDecode(0 <= Bottom); + Normalize = ctx.Reader.ReadBoolByte(); + } + + public override void Save(ModelSaveContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + ctx.SetVersionInfo(GetVersionInfo()); + + // *** Binary format *** + // base + // IFeatureContributionMapper: predictor + // int: top + // int: bottom + // bool: normalize + + SaveColumns(ctx); + ctx.SaveModel(_predictor, ModelFileUtils.DirPredictor); + Contracts.Assert(0 <= Top); + ctx.Writer.Write(Top); + Contracts.Assert(0 <= Bottom); + ctx.Writer.Write(Bottom); + ctx.Writer.WriteBoolByte(Normalize); + } + + // Factory method for SignatureLoadModel. + private static FeatureContributionCalculatingTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + ctx.CheckAtModel(GetVersionInfo()); + return new FeatureContributionCalculatingTransformer(env, ctx); + } + + // Factory method for SignatureLoadRowMapper. 
+ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + + private protected override IRowMapper MakeRowMapper(Schema schema) + => new Mapper(this, schema); + + private class Mapper : OneToOneMapperBase + { + private readonly FeatureContributionCalculatingTransformer _parent; + private readonly VBuffer> _slotNames; + private readonly int _featureColumnIndex; + private readonly ColumnType _featureColumnType; + + public Mapper(FeatureContributionCalculatingTransformer parent, Schema schema) + : base(parent.Host, parent, schema) + { + _parent = parent; + + // Check that the featureColumn is present and has the expected type. + if (!schema.TryGetColumnIndex(_parent.ColumnPairs[0].input, out _featureColumnIndex)) + throw Host.ExceptSchemaMismatch(nameof(schema), "input", _parent.ColumnPairs[0].input); + _featureColumnType = schema[_featureColumnIndex].Type; + if (_featureColumnType.ItemType != NumberType.R4 || !_featureColumnType.IsVector) + throw Host.ExceptSchemaMismatch(nameof(schema), "feature column", _parent.ColumnPairs[0].input, "Expected type is vector of float.", _featureColumnType.ItemType.ToString()); + + if (InputSchema[_featureColumnIndex].HasSlotNames(_featureColumnType.VectorSize)) + InputSchema[_featureColumnIndex].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref _slotNames); + else + _slotNames = VBufferUtils.CreateEmpty>(_featureColumnType.VectorSize); + } + + // The FeatureContributionCalculatingTransformer produces two sets of columns: the columns obtained from scoring and the FeatureContribution column. + // If the argument stringify is true, the type of the FeatureContribution column is string, otherwise it is a vector of float. + protected override Schema.DetachedColumn[] GetOutputColumnsCore() + { + // Add FeatureContributions column. 
+ var builder = new MetadataBuilder(); + builder.Add(InputSchema[_featureColumnIndex].Metadata, x => x == MetadataUtils.Kinds.SlotNames); + return new[] { new Schema.DetachedColumn(DefaultColumnNames.FeatureContributions, new VectorType(NumberType.R4, _featureColumnType.ValueCount), builder.GetMetadata()) }; + } + + protected override Delegate MakeGetter(Row input, int iinfo, Func active, out Action disposer) + { + disposer = null; + Contracts.CheckValue(input, nameof(input)); + + // REVIEW: Assuming Feature contributions will be VBuffer. + // For multiclass LR it needs to be VBuffer[]. + return Utils.MarshalInvoke(GetValueGetter, _featureColumnType.RawType, input, ColMapNewToOld[iinfo]); + } + + private Delegate GetValueGetter(Row input, int colSrc) + { + Contracts.AssertValue(input); + Contracts.AssertValue(_parent._predictor); + + var featureGetter = input.GetGetter(colSrc); + + var map = _parent._predictor.GetFeatureContributionMapper>(_parent.Top, _parent.Bottom, _parent.Normalize); + var features = default(TSrc); + + return (ValueGetter>)((ref VBuffer dst) => + { + featureGetter(ref features); + map(in features, ref dst); + }); + } + } + } + + /// + /// Estimator producing a FeatureContributionCalculatingTransformer which scores the model on an input dataset and + /// computes model-specific contribution scores for each feature. + /// + public sealed class FeatureContributionCalculatingEstimator : TrivialEstimator + { + private readonly string _featureColumn; + private readonly ICalculateFeatureContribution _predictor; + + public static class Defaults + { + public const int Top = 10; + public const int Bottom = 10; + public const bool Normalize = true; + } + + /// + /// Feature Contribution Calculation computes model-specific contribution scores for each feature. + /// Note that this functionality is not supported by all the models. See for a list of the suported models. + /// + /// The environment to use. 
+ /// Trained model parameters that support Feature Contribution Calculation and which will be used for scoring. + /// The name of the feature column that will be used as input. + /// The number of features with highest positive contributions for each data sample that will be retained in the FeatureContribution column. + /// Note that if there are fewer features with positive contributions than , the rest will be returned as zeros. + /// The number of features with least negative contributions for each data sample that will be retained in the FeatureContribution column. + /// Note that if there are fewer features with negative contributions than , the rest will be returned as zeros. + /// Whether the feature contributions should be normalized to the [-1, 1] interval. + public FeatureContributionCalculatingEstimator(IHostEnvironment env, ICalculateFeatureContribution modelParameters, + string featureColumn = DefaultColumnNames.Features, + int top = Defaults.Top, + int bottom = Defaults.Bottom, + bool normalize = Defaults.Normalize) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(FeatureContributionCalculatingTransformer)), + new FeatureContributionCalculatingTransformer(env, modelParameters, featureColumn, top, bottom, normalize)) + { + _featureColumn = featureColumn; + _predictor = modelParameters; + } + + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + // Check that the featureColumn is present. + Host.CheckValue(inputSchema, nameof(inputSchema)); + if (!inputSchema.TryFindColumn(_featureColumn, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _featureColumn); + // Check that the feature column is of the correct type: a vector of float. 
+ if (col.ItemType != NumberType.R4 || col.Kind != SchemaShape.Column.VectorKind.Vector) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "feature column", _featureColumn, "Expected type is vector of float.", col.GetTypeString()); + + // Build output schemaShape. + var result = inputSchema.ToDictionary(x => x.Name); + + // Add FeatureContributions column. + var featContributionMetadata = new List(); + if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) + featContributionMetadata.Add(slotMeta); + result[DefaultColumnNames.FeatureContributions] = new SchemaShape.Column( + DefaultColumnNames.FeatureContributions, col.Kind, col.ItemType, false, new SchemaShape(featContributionMetadata)); + + return new SchemaShape(result.Values); + } + } + + internal static class FeatureContributionEntryPoint + { + [TlcModule.EntryPoint(Name = "Transforms.FeatureContributionCalculationTransformer", + Desc = FeatureContributionCalculatingTransformer.Summary, + UserName = FeatureContributionCalculatingTransformer.FriendlyName)] + public static CommonOutputs.TransformOutput FeatureContributionCalculation(IHostEnvironment env, FeatureContributionCalculatingTransformer.Arguments args) + { + Contracts.CheckValue(env, nameof(env)); + var host = env.Register(nameof(FeatureContributionCalculatingTransformer)); + host.CheckValue(args, nameof(args)); + EntryPointUtils.CheckInputArgs(host, args); + host.CheckValue(args.PredictorModel, nameof(args.PredictorModel)); + + var predictor = args.PredictorModel.Predictor as ICalculateFeatureContribution; + if (predictor == null) + throw host.ExceptUserArg(nameof(predictor), "The provided model parameters do not support feature contribution calculation."); + var outData = new FeatureContributionCalculatingTransformer(host, predictor, args.FeatureColumn, args.Top, args.Bottom, args.Normalize).Transform(args.Data); + + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, outData, args.Data), 
OutputData = outData}; + } + } +} diff --git a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs index 85d91e4301..e114d5e0bc 100644 --- a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs @@ -2,16 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms; using System; using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Transforms; using Float = System.Single; [assembly: LoadableClass(GenerateNumberTransform.Summary, typeof(GenerateNumberTransform), typeof(GenerateNumberTransform.Arguments), typeof(SignatureDataTransform), @@ -102,7 +101,7 @@ private sealed class Bindings : ColumnBindingsBase public readonly TauswortheHybrid.State[] States; private Bindings(bool[] useCounter, TauswortheHybrid.State[] states, - ISchema input, bool user, string[] names) + Schema input, bool user, string[] names) : base(input, user, names) { Contracts.Assert(Utils.Size(useCounter) == InfoCount); @@ -111,7 +110,7 @@ private Bindings(bool[] useCounter, TauswortheHybrid.State[] states, States = states; } - public static Bindings Create(Arguments args, ISchema input) + public static Bindings Create(Arguments args, Schema input) { var names = new string[args.Column.Length]; var useCounter = new bool[args.Column.Length]; @@ -128,7 +127,7 @@ public static Bindings Create(Arguments args, ISchema 
input) return new Bindings(useCounter, states, input, true, names); } - public static Bindings Create(ModelLoadContext ctx, ISchema input) + public static Bindings Create(ModelLoadContext ctx, Schema input) { Contracts.AssertValue(ctx); Contracts.AssertValue(input); @@ -232,7 +231,7 @@ public Func GetDependencies(Func predicate) Contracts.AssertValue(predicate); var active = GetActiveInput(predicate); - Contracts.Assert(active.Length == Input.ColumnCount); + Contracts.Assert(active.Length == Input.Count); return col => 0 <= col && col < active.Length && active[col]; } } @@ -264,9 +263,10 @@ private static VersionInfo GetVersionInfo() /// Host Environment. /// Input . This is the output from previous transform or loader. /// Name of the output column. + /// Seed to start random number generator. /// Use an auto-incremented integer starting at zero instead of a random number. - public GenerateNumberTransform(IHostEnvironment env, IDataView input, string name, bool useCounter = Defaults.UseCounter) - : this(env, new Arguments() { Column = new[] { new Column() { Name = name } }, UseCounter = useCounter }, input) + public GenerateNumberTransform(IHostEnvironment env, IDataView input, string name, uint? seed = null, bool useCounter = Defaults.UseCounter) + : this(env, new Arguments() { Column = new[] { new Column() { Name = name } }, Seed = seed ?? 
Defaults.Seed, UseCounter = useCounter }, input) { } @@ -318,7 +318,7 @@ public override void Save(ModelSaveContext ctx) _bindings.Save(ctx); } - public override Schema Schema => _bindings.AsSchema; + public override Schema OutputSchema => _bindings.AsSchema; public override bool CanShuffle { get { return false; } } @@ -332,7 +332,7 @@ public override void Save(ModelSaveContext ctx) return null; } - protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate, "predicate"); Host.AssertValueOrNull(rand); @@ -340,29 +340,28 @@ protected override IRowCursor GetRowCursorCore(Func predicate, IRando var inputPred = _bindings.GetDependencies(predicate); var active = _bindings.GetActive(predicate); var input = Source.GetRowCursor(inputPred); - return new RowCursor(Host, _bindings, input, active); + return new Cursor(Host, _bindings, input, active); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null) + public override RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); var inputPred = _bindings.GetDependencies(predicate); var active = _bindings.GetActive(predicate); - IRowCursor input; + RowCursor input; if (n > 1 && ShouldUseParallelCursors(predicate) != false) { - var inputs = Source.GetRowCursorSet(out consolidator, inputPred, n); + var inputs = Source.GetRowCursorSet(inputPred, n); Host.AssertNonEmpty(inputs); if (inputs.Length != 1) { - var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(Host, _bindings, inputs[i], active); + cursors[i] = new Cursor(Host, _bindings, inputs[i], active); return cursors; } input = inputs[0]; @@ -370,11 +369,10 @@ public 
override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid else input = Source.GetRowCursor(inputPred); - consolidator = null; - return new IRowCursor[] { new RowCursor(Host, _bindings, input, active) }; + return new RowCursor[] { new Cursor(Host, _bindings, input, active) }; } - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly Bindings _bindings; private readonly bool[] _active; @@ -383,7 +381,7 @@ private sealed class RowCursor : SynchronizedCursorBase, IRowCursor private readonly TauswortheHybrid[] _rngs; private readonly long[] _lastCounters; - public RowCursor(IChannelProvider provider, Bindings bindings, IRowCursor input, bool[] active) + public Cursor(IChannelProvider provider, Bindings bindings, RowCursor input, bool[] active) : base(provider, input) { Ch.CheckValue(bindings, nameof(bindings)); @@ -408,15 +406,15 @@ public RowCursor(IChannelProvider provider, Bindings bindings, IRowCursor input, } } - public Schema Schema => _bindings.AsSchema; + public override Schema Schema => _bindings.AsSchema; - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < _bindings.ColumnCount); return _active == null || _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); @@ -473,7 +471,7 @@ public static CommonOutputs.TransformOutput Generate(IHostEnvironment env, Gener var xf = new GenerateNumberTransform(h, input, input.Data); return new CommonOutputs.TransformOutput() { - Model = new TransformModel(h, xf, input.Data), + Model = new TransformModelImpl(h, xf, input.Data), OutputData = xf }; } diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index 60e799ddb6..9a0fbff6b1 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs 
@@ -2,19 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms.Conversions; using System; using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Transforms.Conversions; [assembly: LoadableClass(HashingTransformer.Summary, typeof(IDataTransform), typeof(HashingTransformer), typeof(HashingTransformer.Arguments), typeof(SignatureDataTransform), "Hash Transform", "HashTransform", "Hash", DocName = "transform/HashTransform.md")] @@ -132,7 +131,10 @@ public sealed class ColumnInfo /// Number of bits to hash into. Must be between 1 and 31, inclusive. /// Hashing seed. /// Whether the position of each term should be included in the hash. - /// Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit. + /// During hashing we constuct mappings between original values and the produced hash values. + /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. + /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. + /// 0 does not retain any input values. -1 retains all input values mapping to each hash. 
public ColumnInfo(string input, string output, int hashBits = HashingEstimator.Defaults.HashBits, uint seed = HashingEstimator.Defaults.Seed, @@ -203,9 +205,9 @@ private static VersionInfo GetVersionInfo() private readonly VBuffer>[] _keyValues; private readonly ColumnType[] _kvTypes; - protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) + protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) { - var type = inputSchema.GetColumnType(srcCol); + var type = inputSchema[srcCol].Type; if (!HashingEstimator.IsColumnTypeValid(type)) throw Host.ExceptParam(nameof(inputSchema), HashingEstimator.ExpectedColumnType); } @@ -216,12 +218,12 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum return columns.Select(x => (x.Input, x.Output)).ToArray(); } - private ColumnType GetOutputType(ISchema inputSchema, ColumnInfo column) + private ColumnType GetOutputType(Schema inputSchema, ColumnInfo column) { var keyCount = column.HashBits < 31 ? 
1 << column.HashBits : 0; inputSchema.TryGetColumnIndex(column.Input, out int srcCol); var itemType = new KeyType(DataKind.U4, 0, keyCount, keyCount > 0); - var srcType = inputSchema.GetColumnType(srcCol); + var srcType = inputSchema[srcCol].Type; if (!srcType.IsVector) return itemType; else @@ -273,7 +275,7 @@ internal HashingTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] } if (Utils.Size(sourceColumnsForInvertHash) > 0) { - using (IRowCursor srcCursor = input.GetRowCursor(sourceColumnsForInvertHash.Contains)) + using (RowCursor srcCursor = input.GetRowCursor(sourceColumnsForInvertHash.Contains)) { using (var ch = Host.Start("Invert hash building")) { @@ -307,19 +309,19 @@ internal HashingTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] } } - private Delegate GetGetterCore(IRow input, int iinfo, out Action disposer) + private Delegate GetGetterCore(Row input, int iinfo, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _columns.Length); disposer = null; input.Schema.TryGetColumnIndex(_columns[iinfo].Input, out int srcCol); - var srcType = input.Schema.GetColumnType(srcCol); + var srcType = input.Schema[srcCol].Type; if (!srcType.IsVector) return ComposeGetterOne(input, iinfo, srcCol, srcType); return ComposeGetterVec(input, iinfo, srcCol, srcType); } - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); // Factory method for SignatureLoadModel. private static HashingTransformer Create(IHostEnvironment env, ModelLoadContext ctx) @@ -367,8 +369,8 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, => Create(env, ctx).MakeDataTransform(input); // Factory method for SignatureLoadRowMapper. 
- private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); // Factory method for SignatureDataTransform. public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) @@ -394,7 +396,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV } #region Getters - private ValueGetter ComposeGetterOne(IRow input, int iinfo, int srcCol, ColumnType srcType) + private ValueGetter ComposeGetterOne(Row input, int iinfo, int srcCol, ColumnType srcType) { Host.Assert(HashingEstimator.IsColumnTypeValid(srcType)); @@ -431,7 +433,7 @@ private ValueGetter ComposeGetterOne(IRow input, int iinfo, int srcCol, Co case DataKind.U8: return MakeScalarHashGetter(input, srcCol, seed, mask); case DataKind.U16: - return MakeScalarHashGetter(input, srcCol, seed, mask); + return MakeScalarHashGetter(input, srcCol, seed, mask); case DataKind.I1: return MakeScalarHashGetter(input, srcCol, seed, mask); case DataKind.I2: @@ -452,7 +454,7 @@ private ValueGetter ComposeGetterOne(IRow input, int iinfo, int srcCol, Co } } - private ValueGetter> ComposeGetterVec(IRow input, int iinfo, int srcCol, ColumnType srcType) + private ValueGetter> ComposeGetterVec(Row input, int iinfo, int srcCol, ColumnType srcType) { Host.Assert(srcType.IsVector); Host.Assert(HashingEstimator.IsColumnTypeValid(srcType.ItemType)); @@ -484,7 +486,7 @@ private ValueGetter> ComposeGetterVec(IRow input, int iinfo, int s case DataKind.U8: return ComposeGetterVecCore(input, iinfo, srcCol, srcType); case DataKind.U16: - return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); case DataKind.I1: return ComposeGetterVecCore(input, iinfo, srcCol, srcType); case 
DataKind.I2: @@ -505,7 +507,7 @@ private ValueGetter> ComposeGetterVec(IRow input, int iinfo, int s } } - private ValueGetter> ComposeGetterVecCore(IRow input, int iinfo, int srcCol, ColumnType srcType) + private ValueGetter> ComposeGetterVecCore(Row input, int iinfo, int srcCol, ColumnType srcType) where THash : struct, IHasher { Host.Assert(srcType.IsVector); @@ -648,19 +650,19 @@ public uint HashCore(uint seed, uint mask, in ulong value) } } - private readonly struct HashU16: IHasher + private readonly struct HashU16: IHasher { [MethodImpl(MethodImplOptions.AggressiveInlining)] - public uint HashCore(uint seed, uint mask, in UInt128 value) + public uint HashCore(uint seed, uint mask, in RowId value) { - var hash = Hashing.MurmurRound(seed, Utils.GetLo(value.Lo)); - var hi = Utils.GetHi(value.Lo); + var hash = Hashing.MurmurRound(seed, Utils.GetLo(value.Low)); + var hi = Utils.GetHi(value.Low); if (hi != 0) hash = Hashing.MurmurRound(hash, hi); - if (value.Hi != 0) + if (value.High != 0) { - hash = Hashing.MurmurRound(hash, Utils.GetLo(value.Hi)); - hi = Utils.GetHi(value.Hi); + hash = Hashing.MurmurRound(hash, Utils.GetLo(value.High)); + hi = Utils.GetHi(value.High); if (hi != 0) hash = Hashing.MurmurRound(hash, hi); } @@ -709,13 +711,13 @@ public uint HashCore(uint seed, uint mask, in long value) } } - private static ValueGetter MakeScalarHashGetter(IRow input, int srcCol, uint seed, uint mask) + private static ValueGetter MakeScalarHashGetter(Row input, int srcCol, uint seed, uint mask) where THash : struct, IHasher { Contracts.Assert(Utils.IsPowerOfTwo(mask + 1)); Contracts.AssertValue(input); - Contracts.Assert(0 <= srcCol && srcCol < input.Schema.ColumnCount); - Contracts.Assert(input.Schema.GetColumnType(srcCol).RawType == typeof(T)); + Contracts.Assert(0 <= srcCol && srcCol < input.Schema.Count); + Contracts.Assert(input.Schema[srcCol].Type.RawType == typeof(T)); var srcGetter = input.GetGetter(srcCol); T src = default; @@ -911,27 +913,27 @@ private void 
AddMetaKeyValues(int i, MetadataBuilder builder) { _parent._keyValues[i].CopyTo(ref dst); }; - builder.AddKeyValues(_parent._kvTypes[i].VectorSize, _parent._kvTypes[i].ItemType.AsPrimitive, getter); + builder.AddKeyValues(_parent._kvTypes[i].VectorSize, (PrimitiveType)_parent._kvTypes[i].ItemType, getter); } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) => _parent.GetGetterCore(input, iinfo, out disposer); + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) => _parent.GetGetterCore(input, iinfo, out disposer); } private abstract class InvertHashHelper { - protected readonly IRow Row; + protected readonly Row Row; private readonly bool _includeSlot; private readonly ColumnInfo _ex; private readonly ColumnType _srcType; private readonly int _srcCol; - private InvertHashHelper(IRow row, ColumnInfo ex) + private InvertHashHelper(Row row, ColumnInfo ex) { Contracts.AssertValue(row); Row = row; row.Schema.TryGetColumnIndex(ex.Input, out int srcCol); _srcCol = srcCol; - _srcType = row.Schema.GetColumnType(srcCol); + _srcType = row.Schema[srcCol].Type; _ex = ex; // If this is a vector and ordered, then we must include the slot as part of the representation. _includeSlot = _srcType.IsVector && _ex.Ordered; @@ -946,13 +948,13 @@ private InvertHashHelper(IRow row, ColumnInfo ex) /// The extra column info /// The number of input hashed valuPres to accumulate per output hash value /// A hash getter, built on top of . - public static InvertHashHelper Create(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) + public static InvertHashHelper Create(Row row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) { row.Schema.TryGetColumnIndex(ex.Input, out int srcCol); - ColumnType typeSrc = row.Schema.GetColumnType(srcCol); + ColumnType typeSrc = row.Schema[srcCol].Type; Type t = typeSrc.IsVector ? (ex.Ordered ? 
typeof(ImplVecOrdered<>) : typeof(ImplVec<>)) : typeof(ImplOne<>); t = t.MakeGenericType(typeSrc.ItemType.RawType); - var consTypes = new Type[] { typeof(IRow), typeof(ColumnInfo), typeof(int), typeof(Delegate) }; + var consTypes = new Type[] { typeof(Row), typeof(ColumnInfo), typeof(int), typeof(Delegate) }; var constructorInfo = t.GetConstructor(consTypes); return (InvertHashHelper)constructorInfo.Invoke(new object[] { row, ex, invertHashMaxCount, dstGetter }); } @@ -1029,7 +1031,7 @@ private abstract class Impl : InvertHashHelper { protected readonly InvertHashCollector Collector; - protected Impl(IRow row, ColumnInfo ex, int invertHashMaxCount) + protected Impl(Row row, ColumnInfo ex, int invertHashMaxCount) : base(row, ex) { Contracts.AssertValue(row); @@ -1062,7 +1064,7 @@ private sealed class ImplOne : Impl private T _value; private uint _hash; - public ImplOne(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) + public ImplOne(Row row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) : base(row, ex, invertHashMaxCount) { _srcGetter = Row.GetGetter(_srcCol); @@ -1096,7 +1098,7 @@ private sealed class ImplVec : Impl private VBuffer _value; private VBuffer _hash; - public ImplVec(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) + public ImplVec(Row row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) : base(row, ex, invertHashMaxCount) { _srcGetter = Row.GetGetter>(_srcCol); @@ -1130,7 +1132,7 @@ private sealed class ImplVecOrdered : Impl> private VBuffer _value; private VBuffer _hash; - public ImplVecOrdered(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) + public ImplVecOrdered(Row row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) : base(row, ex, invertHashMaxCount) { _srcGetter = Row.GetGetter>(_srcCol); @@ -1211,7 +1213,10 @@ internal static bool IsColumnTypeValid(ColumnType type) /// Name of the column to be transformed. /// Name of the output column. 
If this is null '' will be used. /// Number of bits to hash into. Must be between 1 and 31, inclusive. - /// Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit. + /// During hashing we constuct mappings between original values and the produced hash values. + /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. + /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. + /// 0 does not retain any input values. -1 retains all input values mapping to each hash. public HashingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, int hashBits = Defaults.HashBits, int invertHash = Defaults.InvertHash) : this(env, new HashingTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, hashBits: hashBits, invertHash: invertHash)) @@ -1235,7 +1240,7 @@ public HashingEstimator(IHostEnvironment env, params HashingTransformer.ColumnIn public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) diff --git a/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs b/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs index 8d041dc6ce..355b3e0333 100644 --- a/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs +++ b/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs @@ -6,14 +6,14 @@ using System.Collections.Generic; using System.IO; using System.Text; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Internal.Utilities; +using 
Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { - public static class InvertHashUtils + [BestFriend] + internal static class InvertHashUtils { /// /// Clears a destination StringBuilder. If it is currently null, allocates it. @@ -35,23 +35,23 @@ private static void ClearDst(ref StringBuilder dst) public static ValueMapper GetSimpleMapper(Schema schema, int col) { Contracts.AssertValue(schema); - Contracts.Assert(0 <= col && col < schema.ColumnCount); - var type = schema.GetColumnType(col).ItemType; + Contracts.Assert(0 <= col && col < schema.Count); + var type = schema[col].Type.ItemType; Contracts.Assert(type.RawType == typeof(T)); var conv = Conversion.Conversions.Instance; // First: if not key, then get the standard string converison. - if (!type.IsKey) + if (!(type is KeyType keyType)) return conv.GetStringConversion(type); bool identity; // Second choice: if key, utilize the KeyValues metadata for that key, if it has one and is text. - if (schema.HasKeyValues(col, type.KeyCount)) + if (schema[col].HasKeyValues(keyType.KeyCount)) { // REVIEW: Non-textual KeyValues are certainly possible. Should we handle them? // Get the key names. VBuffer> keyValues = default; - schema.GetMetadata(MetadataUtils.Kinds.KeyValues, col, ref keyValues); + schema[col].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref keyValues); ReadOnlyMemory value = default; // REVIEW: We could optimize for identity, but it's probably not worthwhile. @@ -70,7 +70,7 @@ public static ValueMapper GetSimpleMapper(Schema schema, in } // Third choice: just use the key value itself, subject to offsetting by the min. 
- return conv.GetKeyStringConversion(type.AsKey); + return conv.GetKeyStringConversion(keyType); } public static ValueMapper, StringBuilder> GetPairMapper(ValueMapper submap) @@ -101,7 +101,8 @@ public static void AppendToEnd(StringBuilder src, StringBuilder dst, ref char[] } } - public sealed class InvertHashCollector + [BestFriend] + internal sealed class InvertHashCollector { /// /// This is a small struct that is meant to compare akin to the value, diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs index fd6af73412..0cf18b2772 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs @@ -2,24 +2,21 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Pfa; -using Microsoft.ML.StaticPipe; -using Microsoft.ML.StaticPipe.Runtime; -using Microsoft.ML.Transforms.Conversions; -using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.Linq; using System.Reflection; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Pfa; +using Microsoft.ML.Transforms.Conversions; +using Newtonsoft.Json.Linq; [assembly: LoadableClass(typeof(IDataTransform), typeof(KeyToValueMappingTransformer), typeof(KeyToValueMappingTransformer.Arguments), typeof(SignatureDataTransform), KeyToValueMappingTransformer.UserName, 
KeyToValueMappingTransformer.LoaderSignature, "KeyToValue", "KeyToVal", "Unterm")] @@ -138,8 +135,8 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, /// /// Factory method for SignatureLoadRowMapper. /// - private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); public override void Save(ModelSaveContext ctx) { @@ -153,7 +150,7 @@ public override void Save(ModelSaveContext ctx) SaveColumns(ctx); } - protected override IRowMapper MakeRowMapper(Schema inputSchema) => new Mapper(this, inputSchema); + private protected override IRowMapper MakeRowMapper(Schema inputSchema) => new Mapper(this, inputSchema); private sealed class Mapper : OneToOneMapperBase, ISaveAsPfa { @@ -211,7 +208,7 @@ public void SaveAsPfa(BoundPfaContext ctx) ctx.DeclareVar(toDeclare.ToArray()); } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _types.Length); @@ -220,7 +217,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac } // Computes the types of the columns and constructs the kvMaps. - private void ComputeKvMaps(ISchema schema, out ColumnType[] types, out KeyToValueMap[] kvMaps) + private void ComputeKvMaps(Schema schema, out ColumnType[] types, out KeyToValueMap[] kvMaps) { types = new ColumnType[_parent.ColumnPairs.Length]; kvMaps = new KeyToValueMap[_parent.ColumnPairs.Length]; @@ -228,14 +225,14 @@ private void ComputeKvMaps(ISchema schema, out ColumnType[] types, out KeyToValu { // Construct kvMaps. 
Contracts.Assert(types[iinfo] == null); - var typeSrc = schema.GetColumnType(ColMapNewToOld[iinfo]); - var typeVals = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, ColMapNewToOld[iinfo]); + var typeSrc = schema[ColMapNewToOld[iinfo]].Type; + var typeVals = schema[ColMapNewToOld[iinfo]].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; Host.Check(typeVals != null, "Metadata KeyValues does not exist"); Host.Check(typeVals.VectorSize == typeSrc.ItemType.KeyCount, "KeyValues metadata size does not match column type key count"); - if (!typeSrc.IsVector) + if (!(typeSrc is VectorType vectorType)) types[iinfo] = typeVals.ItemType; else - types[iinfo] = new VectorType(typeVals.ItemType.AsPrimitive, typeSrc.AsVector); + types[iinfo] = new VectorType((PrimitiveType)typeVals.ItemType, vectorType); // MarshalInvoke with two generic params. Func func = GetKeyMetadata; @@ -258,7 +255,7 @@ private KeyToValueMap GetKeyMetadata(int iinfo, ColumnType typeKey Host.Check(keyMetadata.Length == typeKey.ItemType.KeyCount); VBufferUtils.Densify(ref keyMetadata); - return new KeyToValueMap(this, typeKey.ItemType.AsKey, typeVal.ItemType.AsPrimitive, keyMetadata, iinfo); + return new KeyToValueMap(this, (KeyType)typeKey.ItemType, (PrimitiveType)typeVal.ItemType, keyMetadata, iinfo); } /// /// A map is an object capable of creating the association from an input type, to an output @@ -294,7 +291,7 @@ protected KeyToValueMap(Mapper mapper, PrimitiveType typeVal, int iinfo) InfoIndex = iinfo; } - public abstract Delegate GetMappingGetter(IRow input); + public abstract Delegate GetMappingGetter(Row input); public abstract JToken SavePfa(BoundPfaContext ctx, JToken srcToken); } @@ -318,16 +315,16 @@ public KeyToValueMap(Mapper parent, KeyType typeKey, PrimitiveType typeVal, VBuf _values = values; // REVIEW: May want to include more specific information about what the specific value is for the default. 
- _na = Runtime.Data.Conversion.Conversions.Instance.GetNAOrDefault(TypeOutput.ItemType, out _naMapsToDefault); + _na = Data.Conversion.Conversions.Instance.GetNAOrDefault(TypeOutput.ItemType, out _naMapsToDefault); if (_naMapsToDefault) { // Only initialize _isDefault if _defaultIsNA is true as this is the only case in which it is used. - _isDefault = Runtime.Data.Conversion.Conversions.Instance.GetIsDefaultPredicate(TypeOutput.ItemType); + _isDefault = Data.Conversion.Conversions.Instance.GetIsDefaultPredicate(TypeOutput.ItemType); } bool identity; - _convertToUInt = Runtime.Data.Conversion.Conversions.Instance.GetStandardConversion(typeKey, NumberType.U4, out identity); + _convertToUInt = Data.Conversion.Conversions.Instance.GetStandardConversion(typeKey, NumberType.U4, out identity); } private void MapKey(in TKey src, ref TValue dst) @@ -346,7 +343,7 @@ private void MapKey(in TKey src, ReadOnlySpan values, ref TValue dst) dst = _na; } - public override Delegate GetMappingGetter(IRow input) + public override Delegate GetMappingGetter(Row input) { // When constructing the getter, there are a few cases we have to consider: // If scalar then it's just a straightforward mapping. @@ -512,7 +509,7 @@ public KeyToValueMappingEstimator(IHostEnvironment env, params (string input, st public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.input, out var col)) @@ -533,117 +530,4 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } } - - /// - /// Extension methods for the static-pipeline over objects. 
- /// - public static class KeyToValueStaticExtensions - { - private interface IColInput - { - PipelineColumn Input { get; } - } - - private sealed class OutKeyColumn : Key, IColInput - { - public PipelineColumn Input { get; } - - public OutKeyColumn(Key> input) - : base(Reconciler.Inst, input) - { - Input = input; - } - } - - private sealed class OutScalarColumn : Scalar, IColInput - { - public PipelineColumn Input { get; } - - public OutScalarColumn(Key input) - : base(Reconciler.Inst, input) - { - Input = input; - } - } - - private sealed class OutVectorColumn : Vector, IColInput - { - public PipelineColumn Input { get; } - - public OutVectorColumn(Vector> input) - : base(Reconciler.Inst, input) - { - Input = input; - } - } - - private sealed class OutVarVectorColumn : VarVector, IColInput - { - public PipelineColumn Input { get; } - - public OutVarVectorColumn(VarVector> input) - : base(Reconciler.Inst, input) - { - Input = input; - } - } - - private sealed class Reconciler : EstimatorReconciler - { - public static Reconciler Inst = new Reconciler(); - - private Reconciler() { } - - public override IEstimator Reconcile(IHostEnvironment env, - PipelineColumn[] toOutput, - IReadOnlyDictionary inputNames, - IReadOnlyDictionary outputNames, - IReadOnlyCollection usedNames) - { - var cols = new (string input, string output)[toOutput.Length]; - for (int i = 0; i < toOutput.Length; ++i) - { - var outCol = (IColInput)toOutput[i]; - cols[i] = (inputNames[outCol.Input], outputNames[toOutput[i]]); - } - return new KeyToValueMappingEstimator(env, cols); - } - } - - /// - /// Convert a key column to a column containing the corresponding value. - /// - public static Key ToValue(this Key> input) - { - Contracts.CheckValue(input, nameof(input)); - return new OutKeyColumn(input); - } - - /// - /// Convert a key column to a column containing the corresponding value. 
- /// - public static Scalar ToValue(this Key input) - { - Contracts.CheckValue(input, nameof(input)); - return new OutScalarColumn(input); - } - - /// - /// Convert a key column to a column containing the corresponding value. - /// - public static Vector ToValue(this Vector> input) - { - Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn(input); - } - - /// - /// Convert a key column to a column containing the corresponding value. - /// - public static VarVector ToValue(this VarVector> input) - { - Contracts.CheckValue(input, nameof(input)); - return new OutVarVectorColumn(input); - } - } } diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs index 68bb844a3f..c49848465e 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs @@ -2,23 +2,20 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Model.Pfa; -using Microsoft.ML.StaticPipe; -using Microsoft.ML.StaticPipe.Runtime; -using Microsoft.ML.Transforms.Conversions; -using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Model.Pfa; +using Microsoft.ML.Transforms.Conversions; +using Newtonsoft.Json.Linq; [assembly: LoadableClass(KeyToVectorMappingTransformer.Summary, typeof(IDataTransform), typeof(KeyToVectorMappingTransformer), typeof(KeyToVectorMappingTransformer.Arguments), typeof(SignatureDataTransform), "Key To Vector Transform", KeyToVectorMappingTransformer.UserName, "KeyToVector", "ToVector", DocName = "transform/KeyToVectorTransform.md")] @@ -88,16 +85,26 @@ public sealed class Arguments public bool Bag = KeyToVectorMappingEstimator.Defaults.Bag; } - public class ColumnInfo + /// + /// Describes how the transformer handles one column pair. + /// + public sealed class ColumnInfo { public readonly string Input; public readonly string Output; public readonly bool Bag; - public ColumnInfo(string input, string output, bool bag = KeyToVectorMappingEstimator.Defaults.Bag) + /// + /// Describes how the transformer handles one column pair. + /// + /// Name of input column. + /// Name of the column resulting from the transformation of . Null means is replaced. + /// Whether to combine multiple indicator vectors into a single bag vector instead of concatenating them. 
This is only relevant when the input column is a vector. + public ColumnInfo(string input, string output = null, bool bag = KeyToVectorMappingEstimator.Defaults.Bag) { + Contracts.CheckNonWhiteSpace(input, nameof(input)); Input = input; - Output = output; + Output = output ?? input; Bag = bag; } } @@ -120,9 +127,9 @@ private string TestIsKey(ColumnType type) return "key type of known cardinality"; } - protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) + protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) { - var type = inputSchema.GetColumnType(srcCol); + var type = inputSchema[srcCol].Type; string reason = TestIsKey(type); if (reason != null) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, reason, type.ToString()); @@ -226,10 +233,10 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, => Create(env, ctx).MakeDataTransform(input); // Factory method for SignatureLoadRowMapper. 
- private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx, ISaveAsPfa { @@ -266,7 +273,7 @@ public Mapper(KeyToVectorMappingTransformer parent, Schema inputSchema) } } - private ColInfo[] CreateInfos(ISchema inputSchema) + private ColInfo[] CreateInfos(Schema inputSchema) { Host.AssertValue(inputSchema); var infos = new ColInfo[_parent.ColumnPairs.Length]; @@ -274,7 +281,7 @@ private ColInfo[] CreateInfos(ISchema inputSchema) { if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); - var type = inputSchema.GetColumnType(colSrc); + var type = inputSchema[colSrc].Type; _parent.CheckInputColumn(inputSchema, i, colSrc); infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type); } @@ -438,7 +445,7 @@ private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) dst = new VBuffer(ranges.Length, ranges); } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _infos.Length); @@ -456,7 +463,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac /// This is for the singleton case. This should be equivalent to both Bag and Ord over /// a vector of size one. 
/// - private ValueGetter> MakeGetterOne(IRow input, int iinfo) + private ValueGetter> MakeGetterOne(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(_infos[iinfo].TypeSrc.IsKey); @@ -478,7 +485,7 @@ private ValueGetter> MakeGetterOne(IRow input, int iinfo) return; } - var editor = VBufferEditor.Create(ref dst, size, 1); + var editor = VBufferEditor.Create(ref dst, size, 1, requireIndicesOnDense: true); editor.Values[0] = 1; editor.Indices[0] = (int)src - 1; @@ -489,7 +496,7 @@ private ValueGetter> MakeGetterOne(IRow input, int iinfo) /// /// This is for the bagging case - vector input and outputs should be added. /// - private ValueGetter> MakeGetterBag(IRow input, int iinfo) + private ValueGetter> MakeGetterBag(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(_infos[iinfo].TypeSrc.IsVector); @@ -533,7 +540,7 @@ private ValueGetter> MakeGetterBag(IRow input, int iinfo) /// /// This is for the indicator (non-bagging) case - vector input and outputs should be concatenated. /// - private ValueGetter> MakeGetterInd(IRow input, int iinfo) + private ValueGetter> MakeGetterInd(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(_infos[iinfo].TypeSrc.IsVector); @@ -746,7 +753,7 @@ private KeyToVectorMappingEstimator(IHostEnvironment env, KeyToVectorMappingTran public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) @@ -769,256 +776,4 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } } - - /// - /// Extension methods for the static-pipeline over objects. 
- /// - public static class KeyToVectorExtensions - { - private interface IColInput - { - PipelineColumn Input { get; } - bool Bag { get; } - } - - private sealed class OutVectorColumn : Vector, IColInput - { - public PipelineColumn Input { get; } - public bool Bag { get; } - - public OutVectorColumn(Key input) - : base(Reconciler.Inst, input) - { - Input = input; - Bag = false; - } - - public OutVectorColumn(Vector> input, bool bag) - : base(Reconciler.Inst, input) - { - Input = input; - Bag = bag; - } - - public OutVectorColumn(VarVector> input) - : base(Reconciler.Inst, input) - { - Input = input; - Bag = true; - } - } - - private sealed class OutVarVectorColumn : VarVector, IColInput - { - public PipelineColumn Input { get; } - public bool Bag { get; } - - public OutVarVectorColumn(VarVector> input) - : base(Reconciler.Inst, input) - { - Input = input; - Bag = false; - } - } - - private sealed class OutVectorColumn : Vector, IColInput - { - public PipelineColumn Input { get; } - public bool Bag { get; } - - public OutVectorColumn(Key input) - : base(Reconciler.Inst, input) - { - Input = input; - Bag = false; - } - - public OutVectorColumn(Vector> input, bool bag) - : base(Reconciler.Inst, input) - { - Input = input; - Bag = bag; - } - - public OutVectorColumn(VarVector> input) - : base(Reconciler.Inst, input) - { - Input = input; - Bag = true; - } - } - - private sealed class OutVarVectorColumn : VarVector, IColInput - { - public PipelineColumn Input { get; } - public bool Bag { get; } - - public OutVarVectorColumn(VarVector> input) - : base(Reconciler.Inst, input) - { - Input = input; - Bag = false; - } - } - - private sealed class Reconciler : EstimatorReconciler - { - public static Reconciler Inst = new Reconciler(); - - private Reconciler() { } - - public override IEstimator Reconcile(IHostEnvironment env, - PipelineColumn[] toOutput, - IReadOnlyDictionary inputNames, - IReadOnlyDictionary outputNames, - IReadOnlyCollection usedNames) - { - var infos = new 
KeyToVectorMappingTransformer.ColumnInfo[toOutput.Length]; - for (int i = 0; i < toOutput.Length; ++i) - { - var col = (IColInput)toOutput[i]; - infos[i] = new KeyToVectorMappingTransformer.ColumnInfo(inputNames[col.Input], outputNames[toOutput[i]], col.Bag); - } - return new KeyToVectorMappingEstimator(env, infos); - } - } - - /// - /// Takes a column of key type of known cardinality and produces an indicator vector of floats. - /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, - /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. - /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. - /// - public static Vector ToVector(this Key input) - { - Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn(input); - } - - /// - /// Takes a column of key type of known cardinality and produces an indicator vector of floats. - /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, - /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. - /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. - /// - public static Vector ToVector(this Vector> input) - { - Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn(input, false); - } - - /// - /// Takes a column of key type of known cardinality and produces an indicator vector of floats. - /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, - /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. - /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. 
- /// In this case then the indicator vectors for all values in the column will be simply added together, - /// to produce the final vector with type equal to the key cardinality; so, in all cases, whether vector or scalar, - /// the output column will be a vector type of length equal to that cardinality. - /// - public static VarVector ToVector(this VarVector> input) - { - Contracts.CheckValue(input, nameof(input)); - return new OutVarVectorColumn(input); - } - - /// - /// Takes a column of key type of known cardinality and produces an indicator vector of floats. - /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, - /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. - /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. - /// In this case then the indicator vectors for all values in the column will be simply added together, - /// to produce the final vector with type equal to the key cardinality; so, in all cases, whether vector or scalar, - /// the output column will be a vector type of length equal to that cardinality. - /// - public static Vector ToBaggedVector(this Vector> input) - { - Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn(input, true); - } - - /// - /// Takes a column of key type of known cardinality and produces an indicator vector of floats. - /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, - /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. - /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. 
- /// In this case then the indicator vectors for all values in the column will be simply added together, - /// to produce the final vector with type equal to the key cardinality; so, in all cases, whether vector or scalar, - /// the output column will be a vector type of length equal to that cardinality. - /// - public static Vector ToBaggedVector(this VarVector> input) - { - Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn(input); - } - - /// - /// Takes a column of key type of known cardinality and produces an indicator vector of floats. - /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, - /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. - /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. - /// - public static Vector ToVector(this Key input) - { - Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn(input); - } - - /// - /// Takes a column of key type of known cardinality and produces an indicator vector of floats. - /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, - /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. - /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. - /// - public static Vector ToVector(this Vector> input) - { - Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn(input, false); - } - - /// - /// Takes a column of key type of known cardinality and produces an indicator vector of floats. 
- /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, - /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. - /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. - /// In this case then the indicator vectors for all values in the column will be simply added together, - /// to produce the final vector with type equal to the key cardinality; so, in all cases, whether vector or scalar, - /// the output column will be a vector type of length equal to that cardinality. - /// - public static VarVector ToVector(this VarVector> input) - { - Contracts.CheckValue(input, nameof(input)); - return new OutVarVectorColumn(input); - } - - /// - /// Takes a column of key type of known cardinality and produces an indicator vector of floats. - /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, - /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. - /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. - /// In this case then the indicator vectors for all values in the column will be simply added together, - /// to produce the final vector with type equal to the key cardinality; so, in all cases, whether vector or scalar, - /// the output column will be a vector type of length equal to that cardinality. - /// - public static Vector ToBaggedVector(this Vector> input) - { - Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn(input, true); - } - - /// - /// Takes a column of key type of known cardinality and produces an indicator vector of floats. 
- /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, - /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. - /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. - /// In this case then the indicator vectors for all values in the column will be simply added together, - /// to produce the final vector with type equal to the key cardinality; so, in all cases, whether vector or scalar, - /// the output column will be a vector type of length equal to that cardinality. - /// - public static Vector ToBaggedVector(this VarVector> input) - { - Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn(input); - } - } } diff --git a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs index f63b97333b..2e27b89496 100644 --- a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs @@ -2,15 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms; using System; using System.Text; using System.Threading; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Transforms; using Float = System.Single; [assembly: LoadableClass(LabelConvertTransform.Summary, typeof(LabelConvertTransform), typeof(LabelConvertTransform.Arguments), typeof(SignatureDataTransform), @@ -165,7 +165,7 @@ private bool PassThrough(string kind, int iinfo) return kind != MetadataUtils.Kinds.KeyValues; } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + protected override Delegate GetGetterCore(IChannel ch, Row input, int iinfo, out Action disposer) { Contracts.AssertValueOrNull(ch); Contracts.AssertValue(input); @@ -173,7 +173,7 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou disposer = null; int col = Infos[iinfo].Source; - var typeSrc = input.Schema.GetColumnType(col); + var typeSrc = input.Schema[col].Type; Contracts.Assert(RowCursorUtils.TestGetLabelGetter(typeSrc) == null); return RowCursorUtils.GetLabelGetter(input, col); } @@ -190,34 +190,34 @@ protected override VectorType GetSlotTypeCore(int iinfo) return _slotType; } - protected override ISlotCursor GetSlotCursorCore(int iinfo) + protected override SlotCursor GetSlotCursorCore(int iinfo) { Host.Assert(0 <= iinfo && iinfo < Infos.Length); Host.AssertValue(Infos[iinfo].SlotTypeSrc); - ISlotCursor cursor = InputTranspose.GetSlotCursor(Infos[iinfo].Source); - return new SlotCursor(Host, cursor, GetSlotTypeCore(iinfo)); + var cursor = InputTranspose.GetSlotCursor(Infos[iinfo].Source); + return new SlotCursorImpl(Host, cursor, GetSlotTypeCore(iinfo)); } - private sealed class 
SlotCursor : SynchronizedCursorBase, ISlotCursor + private sealed class SlotCursorImpl : SlotCursor.SynchronizedSlotCursor { private readonly Delegate _getter; private readonly VectorType _type; - public SlotCursor(IChannelProvider provider, ISlotCursor cursor, VectorType typeDst) + public SlotCursorImpl(IChannelProvider provider, SlotCursor cursor, VectorType typeDst) : base(provider, cursor) { Ch.AssertValue(typeDst); - _getter = RowCursorUtils.GetLabelGetter(Input); + _getter = RowCursorUtils.GetLabelGetter(cursor); _type = typeDst; } - public VectorType GetSlotType() + public override VectorType GetSlotType() { return _type; } - public ValueGetter> GetGetter() + public override ValueGetter> GetGetter() { ValueGetter> getter = _getter as ValueGetter>; if (getter == null) diff --git a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs index 3734245258..905c123375 100644 --- a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs @@ -2,15 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms; using System; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Transforms; [assembly: LoadableClass(typeof(LabelIndicatorTransform), typeof(LabelIndicatorTransform.Arguments), typeof(SignatureDataTransform), LabelIndicatorTransform.UserName, LabelIndicatorTransform.LoadName, "LabelIndicator")] @@ -163,7 +163,7 @@ protected override ColumnType GetColumnTypeCore(int iinfo) return BoolType.Instance; } - protected override Delegate GetGetterCore(IChannel ch, IRow input, + protected override Delegate GetGetterCore(IChannel ch, Row input, int iinfo, out Action disposer) { Host.AssertValue(ch); @@ -175,7 +175,7 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, return GetGetter(ch, input, iinfo); } - private ValueGetter GetGetter(IChannel ch, IRow input, int iinfo) + private ValueGetter GetGetter(IChannel ch, Row input, int iinfo) { Host.AssertValue(ch); ch.AssertValue(input); @@ -234,7 +234,7 @@ public static CommonOutputs.TransformOutput LabelIndicator(IHostEnvironment env, EntryPointUtils.CheckInputArgs(host, input); var xf = Create(host, input, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } } } diff --git a/src/Microsoft.ML.Data/Transforms/MetadataDispatcher.cs b/src/Microsoft.ML.Data/Transforms/MetadataDispatcher.cs index bff3a20f9b..38ee0f82e0 100644 --- a/src/Microsoft.ML.Data/Transforms/MetadataDispatcher.cs +++ 
b/src/Microsoft.ML.Data/Transforms/MetadataDispatcher.cs @@ -5,12 +5,9 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Base class for handling the schema metadata API. @@ -25,7 +22,7 @@ public abstract class MetadataDispatcherBase protected sealed class ColInfo { // The source schema to pass through metadata from. May be null, indicating none. - public readonly ISchema SchemaSrc; + public readonly Schema SchemaSrc; // The source column index to pass through metadata from. public readonly int IndexSrc; // The metadata kind predicate indicating the kinds of metadata to pass through @@ -46,7 +43,7 @@ public IEnumerable Getters } } - public ColInfo(ISchema schemaSrc, int indexSrc, Func filterSrc, + public ColInfo(Schema schemaSrc, int indexSrc, Func filterSrc, IEnumerable getters = null) { SchemaSrc = schemaSrc; @@ -157,11 +154,11 @@ protected MetadataDispatcherBase(int colCount) /// the same ColInfo, if desired. Simply call RegisterColumn multiple times, passing /// the same ColInfo but different index values. This can only be called before Seal is called. 
/// - protected ColInfo CreateInfo(ISchema schemaSrc = null, int indexSrc = -1, + protected ColInfo CreateInfo(Schema schemaSrc = null, int indexSrc = -1, Func filterSrc = null) { Contracts.Check(!_sealed, "MetadataDispatcher sealed"); - Contracts.Check(schemaSrc == null || (0 <= indexSrc && indexSrc < schemaSrc.ColumnCount), "indexSrc out of range"); + Contracts.Check(schemaSrc == null || (0 <= indexSrc && indexSrc < schemaSrc.Count), "indexSrc out of range"); Contracts.Check(filterSrc == null || schemaSrc != null, "filterSrc should be null if schemaSrc is null"); return new ColInfo(schemaSrc, indexSrc, filterSrc); } @@ -235,7 +232,7 @@ private IEnumerable> GetTypesCore(int index, Co yield break; // Pass through from base, with filtering. - foreach (var kvp in info.SchemaSrc.GetMetadataTypes(info.IndexSrc)) + foreach (var kvp in info.SchemaSrc[info.IndexSrc].Metadata.Schema.Select(c => new KeyValuePair(c.Name, c.Type))) { if (kinds != null && kinds.Contains(kvp.Key)) continue; @@ -267,7 +264,7 @@ public ColumnType GetMetadataTypeOrNull(string kind, int index) return null; if (info.FilterSrc != null && !info.FilterSrc(kind, index)) return null; - return info.SchemaSrc.GetMetadataTypeOrNull(kind, info.IndexSrc); + return info.SchemaSrc[info.IndexSrc].Metadata.Schema.GetColumnOrNull(kind)?.Type; } /// @@ -297,7 +294,7 @@ public void GetMetadata(IExceptionContext ectx, string kind, int index, if (info.SchemaSrc == null || info.FilterSrc != null && !info.FilterSrc(kind, index)) throw ectx.ExceptGetMetadata(); - info.SchemaSrc.GetMetadata(kind, info.IndexSrc, ref value); + info.SchemaSrc[info.IndexSrc].Metadata.GetValue(kind, ref value); } } @@ -326,7 +323,7 @@ public Builder BuildMetadata(int index) /// Start building metadata for a column that passes through all metadata from /// a source column. 
/// - public Builder BuildMetadata(int index, ISchema schemaSrc, int indexSrc) + public Builder BuildMetadata(int index, Schema schemaSrc, int indexSrc) { Contracts.CheckValue(schemaSrc, nameof(schemaSrc)); return new Builder(this, index, schemaSrc, indexSrc); @@ -337,7 +334,7 @@ public Builder BuildMetadata(int index, ISchema schemaSrc, int indexSrc) /// a source column. The kinds that are passed through are those for which /// returns true. /// - public Builder BuildMetadata(int index, ISchema schemaSrc, int indexSrc, Func filterSrc) + public Builder BuildMetadata(int index, Schema schemaSrc, int indexSrc, Func filterSrc) { Contracts.CheckValue(schemaSrc, nameof(schemaSrc)); return new Builder(this, index, schemaSrc, indexSrc, filterSrc); @@ -347,7 +344,7 @@ public Builder BuildMetadata(int index, ISchema schemaSrc, int indexSrc, Func - public Builder BuildMetadata(int index, ISchema schemaSrc, int indexSrc, string kindSrc) + public Builder BuildMetadata(int index, Schema schemaSrc, int indexSrc, string kindSrc) { Contracts.CheckValue(schemaSrc, nameof(schemaSrc)); Contracts.CheckNonWhiteSpace(kindSrc, nameof(kindSrc)); @@ -358,7 +355,7 @@ public Builder BuildMetadata(int index, ISchema schemaSrc, int indexSrc, string /// Start building metadata for a column that passes through metadata of the given kinds from /// a source column. /// - public Builder BuildMetadata(int index, ISchema schemaSrc, int indexSrc, params string[] kindsSrc) + public Builder BuildMetadata(int index, Schema schemaSrc, int indexSrc, params string[] kindsSrc) { Contracts.CheckValue(schemaSrc, nameof(schemaSrc)); Contracts.CheckParam(Utils.Size(kindsSrc) >= 2, nameof(kindsSrc)); @@ -388,7 +385,7 @@ public sealed class Builder : IDisposable /// allow restricting to an outer class. 
/// internal Builder(MetadataDispatcher md, int index, - ISchema schemaSrc = null, int indexSrc = -1, Func filterSrc = null) + Schema schemaSrc = null, int indexSrc = -1, Func filterSrc = null) { Contracts.CheckValue(md, nameof(md)); Contracts.CheckParam(0 <= index && index < md.ColCount, nameof(index)); diff --git a/src/Microsoft.ML.Data/Transforms/NAFilter.cs b/src/Microsoft.ML.Data/Transforms/NAFilter.cs index 280546f785..1442de304d 100644 --- a/src/Microsoft.ML.Data/Transforms/NAFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/NAFilter.cs @@ -3,16 +3,16 @@ // See the LICENSE file in the project root for more information. // REVIEW: As soon as we stop writing sizeof(Float), or when we retire the double builds, we can remove this. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms; using System; using System.Collections.Generic; using System.Reflection; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Transforms; using Float = System.Single; [assembly: LoadableClass(NAFilter.Summary, typeof(NAFilter), typeof(NAFilter.Arguments), typeof(SignatureDataTransform), @@ -111,7 +111,7 @@ public NAFilter(IHostEnvironment env, Arguments args, IDataView input) if (_srcIndexToInfoIndex.ContainsKey(index)) throw Host.ExceptUserArg(nameof(args.Column), "Source column '{0}' specified multiple times", src); - var type = schema.GetColumnType(index); + var type = schema[index].Type; if (!TestType(type)) throw Host.ExceptUserArg(nameof(args.Column), $"Column '{src}' has type {type} which does not support missing values, so we cannot filter on them", src); @@ -147,7 +147,7 @@ public NAFilter(IHost host, ModelLoadContext ctx, 
IDataView input) if (_srcIndexToInfoIndex.ContainsKey(index)) throw Host.Except("Source column '{0}' specified multiple times", src); - var type = schema.GetColumnType(index); + var type = schema[index].Type; if (!TestType(type)) throw Host.Except($"Column '{src}' has type {type} which does not support missing values, so we cannot filter on them", src); @@ -180,7 +180,7 @@ public override void Save(ModelSaveContext ctx) Host.Assert(_infos.Length > 0); ctx.Writer.Write(_infos.Length); foreach (var info in _infos) - ctx.SaveNonEmptyString(Source.Schema.GetColumnName(info.Index)); + ctx.SaveNonEmptyString(Source.Schema[info.Index].Name); } private static bool TestType(ColumnType type) @@ -204,7 +204,7 @@ private static bool TestType(ColumnType type) return null; } - protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate, "predicate"); Host.AssertValueOrNull(rand); @@ -212,32 +212,31 @@ protected override IRowCursor GetRowCursorCore(Func predicate, IRando bool[] active; Func inputPred = GetActive(predicate, out active); var input = Source.GetRowCursor(inputPred, rand); - return new RowCursor(this, input, active); + return new Cursor(this, input, active); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null) + public override RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); bool[] active; Func inputPred = GetActive(predicate, out active); - var inputs = Source.GetRowCursorSet(out consolidator, inputPred, n, rand); + var inputs = Source.GetRowCursorSet(inputPred, n, rand); Host.AssertNonEmpty(inputs); // No need to split if this is given 1 input cursor. 
- var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(this, inputs[i], active); + cursors[i] = new Cursor(this, inputs[i], active); return cursors; } private Func GetActive(Func predicate, out bool[] active) { Host.AssertValue(predicate); - active = new bool[Source.Schema.ColumnCount]; - bool[] activeInput = new bool[Source.Schema.ColumnCount]; + active = new bool[Source.Schema.Count]; + bool[] activeInput = new bool[Source.Schema.Count]; for (int i = 0; i < active.Length; i++) activeInput[i] = active[i] = predicate(i); for (int i = 0; i < _infos.Length; i++) @@ -245,13 +244,13 @@ private Func GetActive(Func predicate, out bool[] active) return col => activeInput[col]; } - private sealed class RowCursor : LinkedRowFilterCursorBase + private sealed class Cursor : LinkedRowFilterCursorBase { private abstract class Value { - protected readonly RowCursor Cursor; + protected readonly Cursor Cursor; - protected Value(RowCursor cursor) + protected Value(Cursor cursor) { Contracts.AssertValue(cursor); Cursor = cursor; @@ -261,7 +260,7 @@ protected Value(RowCursor cursor) public abstract Delegate GetGetter(); - public static Value Create(RowCursor cursor, ColInfo info) + public static Value Create(Cursor cursor, ColInfo info) { Contracts.AssertValue(cursor); Contracts.AssertValue(info); @@ -269,18 +268,18 @@ public static Value Create(RowCursor cursor, ColInfo info) MethodInfo meth; if (info.Type is VectorType vecType) { - Func> d = CreateVec; + Func> d = CreateVec; meth = d.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(vecType.ItemType.RawType); } else { - Func> d = CreateOne; + Func> d = CreateOne; meth = d.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(info.Type.RawType); } return (Value)meth.Invoke(null, new object[] { cursor, info }); } - private static ValueOne CreateOne(RowCursor cursor, ColInfo info) + private static ValueOne 
CreateOne(Cursor cursor, ColInfo info) { Contracts.AssertValue(cursor); Contracts.AssertValue(info); @@ -288,11 +287,11 @@ private static ValueOne CreateOne(RowCursor cursor, ColInfo info) Contracts.Assert(info.Type.RawType == typeof(T)); var getSrc = cursor.Input.GetGetter(info.Index); - var hasBad = Runtime.Data.Conversion.Conversions.Instance.GetIsNAPredicate(info.Type); + var hasBad = Data.Conversion.Conversions.Instance.GetIsNAPredicate(info.Type); return new ValueOne(cursor, getSrc, hasBad); } - private static ValueVec CreateVec(RowCursor cursor, ColInfo info) + private static ValueVec CreateVec(Cursor cursor, ColInfo info) { Contracts.AssertValue(cursor); Contracts.AssertValue(info); @@ -300,7 +299,7 @@ private static ValueVec CreateVec(RowCursor cursor, ColInfo info) Contracts.Assert(info.Type.RawType == typeof(VBuffer)); var getSrc = cursor.Input.GetGetter>(info.Index); - var hasBad = Runtime.Data.Conversion.Conversions.Instance.GetHasMissingPredicate((VectorType)info.Type); + var hasBad = Data.Conversion.Conversions.Instance.GetHasMissingPredicate((VectorType)info.Type); return new ValueVec(cursor, getSrc, hasBad); } @@ -310,7 +309,7 @@ private abstract class TypedValue : Value private readonly InPredicate _hasBad; public T Src; - protected TypedValue(RowCursor cursor, ValueGetter getSrc, InPredicate hasBad) + protected TypedValue(Cursor cursor, ValueGetter getSrc, InPredicate hasBad) : base(cursor) { Contracts.AssertValue(getSrc); @@ -330,7 +329,7 @@ private sealed class ValueOne : TypedValue { private readonly ValueGetter _getter; - public ValueOne(RowCursor cursor, ValueGetter getSrc, InPredicate hasBad) + public ValueOne(Cursor cursor, ValueGetter getSrc, InPredicate hasBad) : base(cursor, getSrc, hasBad) { _getter = GetValue; @@ -352,7 +351,7 @@ private sealed class ValueVec : TypedValue> { private readonly ValueGetter> _getter; - public ValueVec(RowCursor cursor, ValueGetter> getSrc, InPredicate> hasBad) + public ValueVec(Cursor cursor, ValueGetter> 
getSrc, InPredicate> hasBad) : base(cursor, getSrc, hasBad) { _getter = GetValue; @@ -374,8 +373,8 @@ public override Delegate GetGetter() private readonly NAFilter _parent; private readonly Value[] _values; - public RowCursor(NAFilter parent, IRowCursor input, bool[] active) - : base(parent.Host, input, parent.Schema, active) + public Cursor(NAFilter parent, RowCursor input, bool[] active) + : base(parent.Host, input, parent.OutputSchema, active) { _parent = parent; _values = new Value[_parent._infos.Length]; diff --git a/src/Microsoft.ML.Data/Transforms/NopTransform.cs b/src/Microsoft.ML.Data/Transforms/NopTransform.cs index 2384c4abb4..128df8b36a 100644 --- a/src/Microsoft.ML.Data/Transforms/NopTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/NopTransform.cs @@ -3,18 +3,17 @@ // See the LICENSE file in the project root for more information. using System; +using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Model; [assembly: LoadableClass(NopTransform.Summary, typeof(NopTransform), null, typeof(SignatureLoadDataTransform), "", NopTransform.LoaderSignature)] [assembly: LoadableClass(typeof(void), typeof(NopTransform), null, typeof(SignatureEntryPointModule), "NopTransform")] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// A transform that does nothing. @@ -102,21 +101,30 @@ public bool CanShuffle get { return Source.CanShuffle; } } - public Schema Schema => Source.Schema; + /// + /// Explicit implementation prevents Schema from being accessed from derived classes. + /// It's our first step to separate data produced by transform from transform. + /// + Schema IDataView.Schema => OutputSchema; + + /// + /// Shape information of the produced output. Note that the input and the output of this transform (and their types) are identical. 
+ /// + public Schema OutputSchema => Source.Schema; public long? GetRowCount() { return Source.GetRowCount(); } - public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { return Source.GetRowCursor(predicate, rand); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { - return Source.GetRowCursorSet(out consolidator, predicate, n, rand); + return Source.GetRowCursorSet(predicate, n, rand); } public Func GetDependencies(Func predicate) @@ -124,13 +132,11 @@ public Func GetDependencies(Func predicate) return predicate; } - public IRow GetRow(IRow input, Func active, out Action disposer) + public Row GetRow(Row input, Func active) { Contracts.CheckValue(input, nameof(input)); Contracts.CheckValue(active, nameof(active)); Contracts.CheckParam(input.Schema == Source.Schema, nameof(input), "Schema of input row must be the same as the schema the mapper is bound to"); - - disposer = null; return input; } @@ -147,7 +153,7 @@ public static CommonOutputs.TransformOutput Nop(IHostEnvironment env, NopInput i EntryPointUtils.CheckInputArgs(host, input); var xf = CreateIfNeeded(host, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } } } \ No newline at end of file diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs index 7c3bdd363c..1d0ee991d1 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs @@ -2,21 +2,21 @@ // The .NET Foundation licenses this file to you under the MIT license. 
// See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Model.Pfa; -using Microsoft.ML.Transforms.Normalizers; -using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Model.Pfa; +using Microsoft.ML.Transforms.Normalizers; +using Newtonsoft.Json.Linq; [assembly: LoadableClass(NormalizeTransform.MinMaxNormalizerSummary, typeof(IDataTransform), typeof(NormalizeTransform), typeof(NormalizeTransform.MinMaxArguments), typeof(SignatureDataTransform), NormalizeTransform.MinMaxNormalizerUserName, "MinMaxNormalizer", NormalizeTransform.MinMaxNormalizerShortName)] @@ -361,7 +361,7 @@ private AffineColumnFunction(IHost host) public bool CanSaveOnnx(OnnxContext ctx) => true; public abstract bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount); - public abstract Delegate GetGetter(IRow input, int icol); + public abstract Delegate GetGetter(Row input, int icol); public abstract void AttachMetadata(MetadataDispatcher.Builder bldr, ColumnType typeSrc); @@ -480,7 +480,7 @@ private CdfColumnFunction(IHost host) public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount) => throw Host.ExceptNotSupp(); - public abstract Delegate GetGetter(IRow input, int icol); + public abstract Delegate GetGetter(Row input, int icol); public abstract void AttachMetadata(MetadataDispatcher.Builder bldr, ColumnType typeSrc); 
public abstract NormalizingTransformer.NormalizerModelParametersBase GetNormalizerModelParams(); @@ -609,7 +609,7 @@ protected BinColumnFunction(IHost host) public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount) => throw Host.ExceptNotSupp(); - public abstract Delegate GetGetter(IRow input, int icol); + public abstract Delegate GetGetter(Row input, int icol); public void AttachMetadata(MetadataDispatcher.Builder bldr, ColumnType typeSrc) { @@ -732,7 +732,7 @@ private abstract class SupervisedBinFunctionBuilderBase : IColumnFunctionBuilder protected readonly int LabelCardinality; private readonly ValueGetter _labelGetterSrc; - protected SupervisedBinFunctionBuilderBase(IHost host, long lim, int labelColId, IRow dataRow) + protected SupervisedBinFunctionBuilderBase(IHost host, long lim, int labelColId, Row dataRow) { Contracts.CheckValue(host, nameof(host)); Host = host; @@ -742,10 +742,10 @@ protected SupervisedBinFunctionBuilderBase(IHost host, long lim, int labelColId, _labelGetterSrc = GetLabelGetter(dataRow, labelColId, out LabelCardinality); } - private ValueGetter GetLabelGetter(IRow row, int col, out int labelCardinality) + private ValueGetter GetLabelGetter(Row row, int col, out int labelCardinality) { // The label column type is checked as part of args validation. 
- var type = row.Schema.GetColumnType(col); + var type = row.Schema[col].Type; Host.Assert(type.IsKey || type.IsNumber); if (type.IsKey) @@ -816,7 +816,7 @@ private abstract class OneColumnSupervisedBinFunctionBuilderBase : Super protected readonly List ColValues; protected OneColumnSupervisedBinFunctionBuilderBase(IHost host, long lim, int valueColId, int labelColId, - IRow dataRow) + Row dataRow) : base(host, lim, labelColId, dataRow) { _colGetterSrc = dataRow.GetGetter(valueColId); @@ -844,11 +844,11 @@ private abstract class VecColumnSupervisedBinFunctionBuilderBase : Super protected readonly List[] ColValues; protected readonly int ColumnSlotCount; - protected VecColumnSupervisedBinFunctionBuilderBase(IHost host, long lim, int valueColId, int labelColId, IRow dataRow) + protected VecColumnSupervisedBinFunctionBuilderBase(IHost host, long lim, int valueColId, int labelColId, Row dataRow) : base(host, lim, labelColId, dataRow) { _colValueGetter = dataRow.GetGetter>(valueColId); - var valueColType = dataRow.Schema.GetColumnType(valueColId); + var valueColType = dataRow.Schema[valueColId].Type; Host.Assert(valueColType.IsKnownSizeVector); ColumnSlotCount = valueColType.ValueCount; @@ -899,7 +899,7 @@ protected override bool AcceptColumnValue() internal static partial class MinMaxUtils { public static IColumnFunctionBuilder CreateBuilder(MinMaxArguments args, IHost host, - int icol, int srcIndex, ColumnType srcType, IRowCursor cursor) + int icol, int srcIndex, ColumnType srcType, RowCursor cursor) { Contracts.AssertValue(host); host.AssertValue(args); @@ -912,7 +912,7 @@ public static IColumnFunctionBuilder CreateBuilder(MinMaxArguments args, IHost h } public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.MinMaxColumn column, IHost host, - int srcIndex, ColumnType srcType, IRowCursor cursor) + int srcIndex, ColumnType srcType, RowCursor cursor) { if (srcType.IsNumber) { @@ -935,7 +935,7 @@ public static IColumnFunctionBuilder 
CreateBuilder(NormalizingEstimator.MinMaxCo internal static partial class MeanVarUtils { public static IColumnFunctionBuilder CreateBuilder(MeanVarArguments args, IHost host, - int icol, int srcIndex, ColumnType srcType, IRowCursor cursor) + int icol, int srcIndex, ColumnType srcType, RowCursor cursor) { Contracts.AssertValue(host); host.AssertValue(args); @@ -949,7 +949,7 @@ public static IColumnFunctionBuilder CreateBuilder(MeanVarArguments args, IHost } public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.MeanVarColumn column, IHost host, - int srcIndex, ColumnType srcType, IRowCursor cursor) + int srcIndex, ColumnType srcType, RowCursor cursor) { Contracts.AssertValue(host); @@ -975,7 +975,7 @@ public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.MeanVarC internal static partial class LogMeanVarUtils { public static IColumnFunctionBuilder CreateBuilder(LogMeanVarArguments args, IHost host, - int icol, int srcIndex, ColumnType srcType, IRowCursor cursor) + int icol, int srcIndex, ColumnType srcType, RowCursor cursor) { Contracts.AssertValue(host); host.AssertValue(args); @@ -988,7 +988,7 @@ public static IColumnFunctionBuilder CreateBuilder(LogMeanVarArguments args, IHo } public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.LogMeanVarColumn column, IHost host, - int srcIndex, ColumnType srcType, IRowCursor cursor) + int srcIndex, ColumnType srcType, RowCursor cursor) { Contracts.AssertValue(host); host.AssertValue(column); @@ -1014,7 +1014,7 @@ public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.LogMeanV internal static partial class BinUtils { public static IColumnFunctionBuilder CreateBuilder(BinArguments args, IHost host, - int icol, int srcIndex, ColumnType srcType, IRowCursor cursor) + int icol, int srcIndex, ColumnType srcType, RowCursor cursor) { Contracts.AssertValue(host); host.AssertValue(args); @@ -1028,7 +1028,7 @@ public static IColumnFunctionBuilder 
CreateBuilder(BinArguments args, IHost host } public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.BinningColumn column, IHost host, - int srcIndex, ColumnType srcType, IRowCursor cursor) + int srcIndex, ColumnType srcType, RowCursor cursor) { Contracts.AssertValue(host); @@ -1053,7 +1053,7 @@ public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.BinningC internal static class SupervisedBinUtils { public static IColumnFunctionBuilder CreateBuilder(SupervisedBinArguments args, IHost host, - int icol, int srcIndex, ColumnType srcType, IRowCursor cursor) + int icol, int srcIndex, ColumnType srcType, RowCursor cursor) { Contracts.AssertValue(host); host.AssertValue(args); @@ -1061,31 +1061,57 @@ public static IColumnFunctionBuilder CreateBuilder(SupervisedBinArguments args, // checking for label column host.CheckUserArg(!string.IsNullOrWhiteSpace(args.LabelColumn), nameof(args.LabelColumn), "Must specify the label column name"); int labelColumnId = GetLabelColumnId(host, cursor.Schema, args.LabelColumn); - var labelColumnType = cursor.Schema.GetColumnType(labelColumnId); + var labelColumnType = cursor.Schema[labelColumnId].Type; if (labelColumnType.IsKey) host.CheckUserArg(labelColumnType.KeyCount > 0, nameof(args.LabelColumn), "Label column must have a known cardinality"); else host.CheckUserArg(labelColumnType.IsNumber, nameof(args.LabelColumn), "Label column must be a number or a key type"); + return CreateBuilder( + new NormalizingEstimator.SupervisedBinningColumn( + args.Column[icol].Source ?? args.Column[icol].Name, + args.Column[icol].Name, + args.LabelColumn ?? DefaultColumnNames.Label, + args.Column[icol].MaxTrainingExamples ?? args.MaxTrainingExamples, + args.Column[icol].FixZero ?? args.FixZero, + args.Column[icol].NumBins ?? 
args.NumBins, + args.MinBinSize), + host, labelColumnId, srcIndex, srcType, cursor); + } + + public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.SupervisedBinningColumn column, IHost host, + string labelColumn, int srcIndex, ColumnType srcType, RowCursor cursor) + { + int labelColumnId = GetLabelColumnId(host, cursor.Schema, labelColumn); + return CreateBuilder(column, host, labelColumnId, srcIndex, srcType, cursor); + } + + private static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.SupervisedBinningColumn column, IHost host, + int labelColumnId, int srcIndex, ColumnType srcType, RowCursor cursor) + { + Contracts.AssertValue(host); + if (srcType.IsNumber) { if (srcType == NumberType.R4) - return Sng.SupervisedBinOneColumnFunctionBuilder.Create(args, host, icol, srcIndex, labelColumnId, cursor); + return Sng.SupervisedBinOneColumnFunctionBuilder.Create(column, host, srcIndex, labelColumnId, cursor); if (srcType == NumberType.R8) - return Dbl.SupervisedBinOneColumnFunctionBuilder.Create(args, host, icol, srcIndex, labelColumnId, cursor); + return Dbl.SupervisedBinOneColumnFunctionBuilder.Create(column, host, srcIndex, labelColumnId, cursor); } if (srcType.IsVector && srcType.ItemType.IsNumber) { if (srcType.ItemType == NumberType.R4) - return Sng.SupervisedBinVecColumnFunctionBuilder.Create(args, host, icol, srcIndex, labelColumnId, cursor); + return Sng.SupervisedBinVecColumnFunctionBuilder.Create(column, host, srcIndex, labelColumnId, cursor); if (srcType.ItemType == NumberType.R8) - return Dbl.SupervisedBinVecColumnFunctionBuilder.Create(args, host, icol, srcIndex, labelColumnId, cursor); + return Dbl.SupervisedBinVecColumnFunctionBuilder.Create(column, host, srcIndex, labelColumnId, cursor); } - throw host.ExceptUserArg(nameof(args.Column), "Wrong column type for column {0}. Expected: R4, R8, Vec or Vec. 
Got: {1}.", args.Column[icol].Source, srcType.ToString()); + throw host.ExceptParam(nameof(column), "Wrong column type for column {0}. Expected: R4, R8, Vec or Vec. Got: {1}.", + column.Input, + srcType.ToString()); } - public static int GetLabelColumnId(IExceptionContext host, ISchema schema, string labelColumnName) + public static int GetLabelColumnId(IExceptionContext host, Schema schema, string labelColumnName) { Contracts.AssertValue(host); host.AssertValue(schema); diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs index 5792486012..3c804e754a 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs @@ -2,18 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Model.Pfa; -using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; using System.Runtime.CompilerServices; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Model.Pfa; +using Newtonsoft.Json.Linq; namespace Microsoft.ML.Transforms.Normalizers { @@ -583,7 +582,7 @@ public override bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int fe return true; } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter(icol); ValueGetter del = @@ -659,7 +658,7 @@ public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount) return true; } - public override Delegate GetGetter(IRow input, 
int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter>(icol); var bldr = new BufferBuilder(R8Adder.Instance); @@ -901,7 +900,7 @@ public override void Save(ModelSaveContext ctx) CdfNormSerializationUtils.SaveModel(ctx, UseLog, new[] { Mean }, new[] { Stddev }); } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { if (Stddev <= TFloat.Epsilon) { @@ -956,7 +955,7 @@ public override void Save(ModelSaveContext ctx) CdfNormSerializationUtils.SaveModel(ctx, UseLog, Mean, Stddev); } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter>(icol); var bldr = new BufferBuilder(R8Adder.Instance); @@ -1085,7 +1084,7 @@ public override void Save(ModelSaveContext ctx) c => BinNormSerializationUtils.SaveModel(c, new[] { _binUpperBounds }, saveText: true)); } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter(icol); ValueGetter del = @@ -1170,7 +1169,7 @@ public override void Save(ModelSaveContext ctx) ctx.SaveSubModel("BinNormalizer", c => BinNormSerializationUtils.SaveModel(c, _binUpperBounds, saveText: true)); } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter>(icol); var bldr = new BufferBuilder(R8Adder.Instance); @@ -1842,7 +1841,7 @@ public sealed class SupervisedBinOneColumnFunctionBuilder : OneColumnSupervisedB private readonly int _numBins; private readonly int _minBinSize; - private SupervisedBinOneColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, IRow dataRow) + private SupervisedBinOneColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, 
Row dataRow) : base(host, lim, valueColumnId, labelColumnId, dataRow) { _fix = fix; @@ -1862,15 +1861,15 @@ public override IColumnFunction CreateColumnFunction() return BinColumnFunction.Create(Host, binUpperBounds, _fix); } - public static IColumnFunctionBuilder Create(SupervisedBinArguments args, IHost host, int argsColumnIndex, int valueColumnId, int labelColumnId, IRow dataRow) + public static IColumnFunctionBuilder Create(NormalizingEstimator.SupervisedBinningColumn column, IHost host, int valueColumnId, int labelColumnId, Row dataRow) { - var lim = args.Column[argsColumnIndex].MaxTrainingExamples ?? args.MaxTrainingExamples; - host.CheckUserArg(lim > 1, nameof(args.MaxTrainingExamples), "Must be greater than 1"); - bool fix = args.Column[argsColumnIndex].FixZero ?? args.FixZero; - var numBins = args.Column[argsColumnIndex].NumBins ?? args.NumBins; - host.CheckUserArg(numBins > 1, nameof(args.NumBins), "Must be greater than 1"); - host.CheckUserArg(args.MinBinSize > 0, nameof(args.MinBinSize), "Must be positive"); - return new SupervisedBinOneColumnFunctionBuilder(host, lim, fix, numBins, args.MinBinSize, valueColumnId, labelColumnId, dataRow); + var lim = column.MaxTrainingExamples; + host.CheckUserArg(lim > 1, nameof(column.MaxTrainingExamples), "Must be greater than 1"); + bool fix = column.FixZero; + var numBins = column.NumBins; + host.CheckUserArg(numBins > 1, nameof(column.NumBins), "Must be greater than 1"); + host.CheckUserArg(column.MinBinSize > 0, nameof(column.MinBinSize), "Must be positive"); + return new SupervisedBinOneColumnFunctionBuilder(host, lim, fix, numBins, column.MinBinSize, valueColumnId, labelColumnId, dataRow); } } @@ -1880,7 +1879,7 @@ public sealed class SupervisedBinVecColumnFunctionBuilder : VecColumnSupervisedB private readonly int _numBins; private readonly int _minBinSize; - private SupervisedBinVecColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, IRow 
dataRow) + private SupervisedBinVecColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, Row dataRow) : base(host, lim, valueColumnId, labelColumnId, dataRow) { _fix = fix; @@ -1902,15 +1901,15 @@ public override IColumnFunction CreateColumnFunction() return BinColumnFunction.Create(Host, binUpperBounds, _fix); } - public static IColumnFunctionBuilder Create(SupervisedBinArguments args, IHost host, int argsColumnIndex, int valueColumnId, int labelColumnId, IRow dataRow) + public static IColumnFunctionBuilder Create(NormalizingEstimator.SupervisedBinningColumn column, IHost host, int valueColumnId, int labelColumnId, Row dataRow) { - var lim = args.Column[argsColumnIndex].MaxTrainingExamples ?? args.MaxTrainingExamples; - host.CheckUserArg(lim > 1, nameof(args.MaxTrainingExamples), "Must be greater than 1"); - bool fix = args.Column[argsColumnIndex].FixZero ?? args.FixZero; - var numBins = args.Column[argsColumnIndex].NumBins ?? 
args.NumBins; - host.CheckUserArg(numBins > 1, nameof(args.NumBins), "Must be greater than 1"); - host.CheckUserArg(args.MinBinSize > 0, nameof(args.MinBinSize), "Must be positive"); - return new SupervisedBinVecColumnFunctionBuilder(host, lim, fix, numBins, args.MinBinSize, valueColumnId, labelColumnId, dataRow); + var lim = column.MaxTrainingExamples; + host.CheckUserArg(lim > 1, nameof(column.MaxTrainingExamples), "Must be greater than 1"); + bool fix = column.FixZero; + var numBins = column.NumBins; + host.CheckUserArg(numBins > 1, nameof(column.NumBins), "Must be greater than 1"); + host.CheckUserArg(column.MinBinSize > 0, nameof(column.MinBinSize), "Must be positive"); + return new SupervisedBinVecColumnFunctionBuilder(host, lim, fix, numBins, column.MinBinSize, valueColumnId, labelColumnId, dataRow); } } } diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs index 54e6693f76..cb84845988 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs @@ -2,18 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Model.Pfa; -using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; using System.Runtime.CompilerServices; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Model.Pfa; +using Newtonsoft.Json.Linq; namespace Microsoft.ML.Transforms.Normalizers { @@ -316,7 +315,7 @@ public static void LoadModel(ModelLoadContext ctx, int cv, out bool useLog, out /// It tracks min, max, number of non-sparse values (vCount) and number of ProcessValue() calls (trainCount). /// NaNs are ignored when updating min and max. /// - public sealed class MinMaxSngAggregator : IColumnAggregator> + internal sealed class MinMaxSngAggregator : IColumnAggregator> { private readonly TFloat[] _min; private readonly TFloat[] _max; @@ -585,7 +584,7 @@ public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount) return true; } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter(icol); ValueGetter del = @@ -660,7 +659,7 @@ public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount) return true; } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter>(icol); var bldr = new BufferBuilder(R4Adder.Instance); @@ -905,7 +904,7 @@ public override void Save(ModelSaveContext ctx) CdfNormSerializationUtils.SaveModel(ctx, UseLog, new[] { Mean }, new[] { Stddev }); } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { if (Stddev <= TFloat.Epsilon) { @@ 
-960,7 +959,7 @@ public override void Save(ModelSaveContext ctx) CdfNormSerializationUtils.SaveModel(ctx, UseLog, Mean, Stddev); } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter>(icol); var bldr = new BufferBuilder(R4Adder.Instance); @@ -1090,7 +1089,7 @@ public override void Save(ModelSaveContext ctx) c => BinNormSerializationUtils.SaveModel(c, new[] { _binUpperBounds }, saveText: true)); } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter(icol); ValueGetter del = @@ -1175,7 +1174,7 @@ public override void Save(ModelSaveContext ctx) ctx.SaveSubModel("BinNormalizer", c => BinNormSerializationUtils.SaveModel(c, _binUpperBounds, saveText: true)); } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter>(icol); var bldr = new BufferBuilder(R4Adder.Instance); @@ -1850,7 +1849,7 @@ public sealed class SupervisedBinOneColumnFunctionBuilder : OneColumnSupervisedB private readonly int _numBins; private readonly int _minBinSize; - private SupervisedBinOneColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, IRow dataRow) + private SupervisedBinOneColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, Row dataRow) : base(host, lim, valueColumnId, labelColumnId, dataRow) { _fix = fix; @@ -1870,15 +1869,15 @@ public override IColumnFunction CreateColumnFunction() return BinColumnFunction.Create(Host, binUpperBounds, _fix); } - public static IColumnFunctionBuilder Create(SupervisedBinArguments args, IHost host, int argsColumnIndex, int valueColumnId, int labelColumnId, IRow dataRow) + public static IColumnFunctionBuilder 
Create(NormalizingEstimator.SupervisedBinningColumn column, IHost host, int valueColumnId, int labelColumnId, Row dataRow) { - var lim = args.Column[argsColumnIndex].MaxTrainingExamples ?? args.MaxTrainingExamples; - host.CheckUserArg(lim > 1, nameof(args.MaxTrainingExamples), "Must be greater than 1"); - bool fix = args.Column[argsColumnIndex].FixZero ?? args.FixZero; - var numBins = args.Column[argsColumnIndex].NumBins ?? args.NumBins; - host.CheckUserArg(numBins > 1, nameof(args.NumBins), "Must be greater than 1"); - host.CheckUserArg(args.MinBinSize > 0, nameof(args.MinBinSize), "Must be positive"); - return new SupervisedBinOneColumnFunctionBuilder(host, lim, fix, numBins, args.MinBinSize, valueColumnId, labelColumnId, dataRow); + var lim = column.MaxTrainingExamples; + host.CheckUserArg(lim > 1, nameof(column.MaxTrainingExamples), "Must be greater than 1"); + bool fix = column.FixZero; + var numBins = column.NumBins; + host.CheckUserArg(numBins > 1, nameof(column.NumBins), "Must be greater than 1"); + host.CheckUserArg(column.MinBinSize > 0, nameof(column.MinBinSize), "Must be positive"); + return new SupervisedBinOneColumnFunctionBuilder(host, lim, fix, numBins, column.MinBinSize, valueColumnId, labelColumnId, dataRow); } } @@ -1888,7 +1887,7 @@ public sealed class SupervisedBinVecColumnFunctionBuilder : VecColumnSupervisedB private readonly int _numBins; private readonly int _minBinSize; - private SupervisedBinVecColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, IRow dataRow) + private SupervisedBinVecColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, Row dataRow) : base(host, lim, valueColumnId, labelColumnId, dataRow) { _fix = fix; @@ -1910,15 +1909,15 @@ public override IColumnFunction CreateColumnFunction() return BinColumnFunction.Create(Host, binUpperBounds, _fix); } - public static IColumnFunctionBuilder 
Create(SupervisedBinArguments args, IHost host, int argsColumnIndex, int valueColumnId, int labelColumnId, IRow dataRow) + public static IColumnFunctionBuilder Create(NormalizingEstimator.SupervisedBinningColumn column, IHost host, int valueColumnId, int labelColumnId, Row dataRow) { - var lim = args.Column[argsColumnIndex].MaxTrainingExamples ?? args.MaxTrainingExamples; - host.CheckUserArg(lim > 1, nameof(args.MaxTrainingExamples), "Must be greater than 1"); - bool fix = args.Column[argsColumnIndex].FixZero ?? args.FixZero; - var numBins = args.Column[argsColumnIndex].NumBins ?? args.NumBins; - host.CheckUserArg(numBins > 1, nameof(args.NumBins), "Must be greater than 1"); - host.CheckUserArg(args.MinBinSize > 0, nameof(args.MinBinSize), "Must be positive"); - return new SupervisedBinVecColumnFunctionBuilder(host, lim, fix, numBins, args.MinBinSize, valueColumnId, labelColumnId, dataRow); + var lim = column.MaxTrainingExamples; + host.CheckUserArg(lim > 1, nameof(column.MaxTrainingExamples), "Must be greater than 1"); + bool fix = column.FixZero; + var numBins = column.NumBins; + host.CheckUserArg(numBins > 1, nameof(column.NumBins), "Must be greater than 1"); + host.CheckUserArg(column.MinBinSize > 0, nameof(column.MinBinSize), "Must be positive"); + return new SupervisedBinVecColumnFunctionBuilder(host, lim, fix, numBins, column.MinBinSize, valueColumnId, labelColumnId, dataRow); } } } diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs b/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs index cc3f4e352f..1dc57a833e 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs @@ -2,21 +2,20 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using System; +using System.Collections.Generic; +using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Model.Pfa; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Model.Pfa; using Microsoft.ML.Transforms.Normalizers; using Newtonsoft.Json.Linq; -using System; -using System.Collections.Generic; [assembly: LoadableClass(typeof(void), typeof(Normalize), null, typeof(SignatureEntryPointModule), "Normalize")] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Signature for a repository based loader of a IColumnFunction @@ -40,7 +39,8 @@ internal interface IColumnFunctionBuilder /// /// Interface to define an aggregate function over values /// - public interface IColumnAggregator + [BestFriend] + internal interface IColumnAggregator { /// /// Updates the aggregate function with a value @@ -53,9 +53,10 @@ public interface IColumnAggregator void Finish(); } + [BestFriend] internal interface IColumnFunction : ICanSaveModel { - Delegate GetGetter(IRow input, int icol); + Delegate GetGetter(Row input, int icol); void AttachMetadata(MetadataDispatcher.Builder bldr, ColumnType typeSrc); @@ -68,27 +69,6 @@ internal interface IColumnFunction : ICanSaveModel NormalizingTransformer.NormalizerModelParametersBase GetNormalizerModelParams(); } - public static class NormalizeUtils - { - /// - /// Returns whether the feature column in the schema is indicated to be normalized. If the features column is not - /// specified on the schema, then this will return null. - /// - /// The role-mapped schema to query - /// Returns null if does not have - /// defined, and otherwise returns a Boolean value as returned from - /// on that feature column - /// - public static bool? 
FeaturesAreNormalized(this RoleMappedSchema schema) - { - // REVIEW: The role mapped data has the ability to have multiple columns fill the role of features, which is - // useful in some trainers that are nonetheless parameteric and can therefore benefit from normalization. - Contracts.CheckValue(schema, nameof(schema)); - var featInfo = schema.Feature; - return featInfo == null ? default(bool?) : schema.Schema.IsNormalized(featInfo.Index); - } - } - /// /// This contains entry-point definitions related to . /// @@ -103,7 +83,7 @@ public static CommonOutputs.TransformOutput MinMax(IHostEnvironment env, Normali EntryPointUtils.CheckInputArgs(host, input); var xf = NormalizeTransform.Create(host, input, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } [TlcModule.EntryPoint(Name = "Transforms.MeanVarianceNormalizer", Desc = NormalizeTransform.MeanVarNormalizerSummary, UserName = NormalizeTransform.MeanVarNormalizerUserName, ShortName = NormalizeTransform.MeanVarNormalizerShortName)] @@ -115,7 +95,7 @@ public static CommonOutputs.TransformOutput MeanVar(IHostEnvironment env, Normal EntryPointUtils.CheckInputArgs(host, input); var xf = NormalizeTransform.Create(host, input, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } [TlcModule.EntryPoint(Name = "Transforms.LogMeanVarianceNormalizer", Desc = NormalizeTransform.LogMeanVarNormalizerSummary, UserName = NormalizeTransform.LogMeanVarNormalizerUserName, ShortName = NormalizeTransform.LogMeanVarNormalizerShortName)] @@ -127,7 +107,7 @@ public static CommonOutputs.TransformOutput LogMeanVar(IHostEnvironment env, Nor 
EntryPointUtils.CheckInputArgs(host, input); var xf = NormalizeTransform.Create(host, input, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } [TlcModule.EntryPoint(Name = "Transforms.BinNormalizer", Desc = NormalizeTransform.BinNormalizerSummary, UserName = NormalizeTransform.BinNormalizerUserName, ShortName = NormalizeTransform.BinNormalizerShortName)] @@ -139,7 +119,7 @@ public static CommonOutputs.TransformOutput Bin(IHostEnvironment env, NormalizeT EntryPointUtils.CheckInputArgs(host, input); var xf = NormalizeTransform.Create(host, input, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } [TlcModule.EntryPoint(Name = "Transforms.ConditionalNormalizer", Desc = "Normalize the columns only if needed", UserName = "Normalize If Needed")] @@ -154,7 +134,7 @@ public static CommonOutputs.TransformOutput Bin(IHostEnvironment env, NormalizeT { if (!schema.TryGetColumnIndex(column.Source, out int col)) throw env.ExceptUserArg(nameof(input.Column), $"Column '{column.Source}' does not exist."); - if (!schema.IsNormalized(col)) + if (!schema[col].IsNormalized()) columnsToNormalize.Add(column); } diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs index 1434b680a7..2dae36c95a 100644 --- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs +++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs @@ -2,20 +2,20 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Model.Pfa; -using Microsoft.ML.Transforms.Normalizers; -using Newtonsoft.Json.Linq; using System; using System.Collections; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; +using Microsoft.ML; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Model.Pfa; +using Microsoft.ML.Transforms.Normalizers; +using Newtonsoft.Json.Linq; +using static Microsoft.ML.Transforms.Normalizers.NormalizeTransform; [assembly: LoadableClass(typeof(NormalizingTransformer), null, typeof(SignatureLoadModel), "", NormalizingTransformer.LoaderSignature)] @@ -23,10 +23,14 @@ [assembly: LoadableClass(typeof(IRowMapper), typeof(NormalizingTransformer), null, typeof(SignatureLoadRowMapper), "", NormalizingTransformer.LoaderSignature)] +[assembly: LoadableClass(typeof(IDataTransform), typeof(NormalizingTransformer), null, typeof(SignatureLoadDataTransform), + "", NormalizingTransformer.LoaderSignature, "NormalizeTransform")] + namespace Microsoft.ML.Transforms.Normalizers { public sealed class NormalizingEstimator : IEstimator { + [BestFriend] internal static class Defaults { public const bool FixZero = true; @@ -54,7 +58,11 @@ public enum NormalizerMode /// /// Bucketize and then rescale to between -1 and 1. /// - Binning = 3 + Binning = 3, + /// + /// Bucketize and then rescale to between -1 and 1. Calculates bins based on correlation with the Label column. 
+ /// + SupervisedBinning = 4 } public abstract class ColumnBase @@ -74,7 +82,7 @@ private protected ColumnBase(string input, string output, long maxTrainingExampl MaxTrainingExamples = maxTrainingExamples; } - internal abstract IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, IRowCursor cursor); + internal abstract IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, RowCursor cursor); internal static ColumnBase Create(string input, string output, NormalizerMode mode) { @@ -88,6 +96,8 @@ internal static ColumnBase Create(string input, string output, NormalizerMode mo return new LogMeanVarColumn(input, output); case NormalizerMode.Binning: return new BinningColumn(input, output); + case NormalizerMode.SupervisedBinning: + return new SupervisedBinningColumn(input, output); default: throw Contracts.ExceptParam(nameof(mode), "Unknown normalizer mode"); } @@ -112,7 +122,7 @@ public MinMaxColumn(string input, string output = null, long maxTrainingExamples { } - internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, IRowCursor cursor) + internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, RowCursor cursor) => NormalizeTransform.MinMaxUtils.CreateBuilder(this, host, srcIndex, srcType, cursor); } @@ -127,7 +137,7 @@ public MeanVarColumn(string input, string output = null, UseCdf = useCdf; } - internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, IRowCursor cursor) + internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, RowCursor cursor) => NormalizeTransform.MeanVarUtils.CreateBuilder(this, host, srcIndex, srcType, cursor); } @@ -142,7 +152,7 @@ public LogMeanVarColumn(string input, string output = null, UseCdf = useCdf; } - internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, IRowCursor 
cursor) + internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, RowCursor cursor) => NormalizeTransform.LogMeanVarUtils.CreateBuilder(this, host, srcIndex, srcType, cursor); } @@ -157,10 +167,33 @@ public BinningColumn(string input, string output = null, NumBins = numBins; } - internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, IRowCursor cursor) + internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, RowCursor cursor) => NormalizeTransform.BinUtils.CreateBuilder(this, host, srcIndex, srcType, cursor); } + public sealed class SupervisedBinningColumn : FixZeroColumnBase + { + public readonly int NumBins; + public readonly string LabelColumn; + public readonly int MinBinSize; + + public SupervisedBinningColumn(string input, string output = null, + string labelColumn = DefaultColumnNames.Label, + long maxTrainingExamples = Defaults.MaxTrainingExamples, + bool fixZero = true, + int numBins = Defaults.NumBins, + int minBinSize = Defaults.MinBinSize) + : base(input, output ?? 
input, maxTrainingExamples, fixZero) + { + NumBins = numBins; + LabelColumn = labelColumn; + MinBinSize = minBinSize; + } + + internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, RowCursor cursor) + => NormalizeTransform.SupervisedBinUtils.CreateBuilder(this, host, LabelColumn, srcIndex, srcType, cursor); + } + private readonly IHost _host; private readonly ColumnBase[] _columns; @@ -213,7 +246,7 @@ public NormalizingTransformer Fit(IDataView input) public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { @@ -241,6 +274,23 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) public sealed partial class NormalizingTransformer : OneToOneTransformerBase { public const string LoaderSignature = "Normalizer"; + + internal const string LoaderSignatureOld = "NormalizeFunction"; + + private static VersionInfo GetOldVersionInfo() + { + return new VersionInfo( + modelSignature: "NORMFUNC", + // verWrittenCur: 0x00010001, // Initial + // verWrittenCur: 0x00010002, // Changed to OneToOneColumn + verWrittenCur: 0x00010003, // Support generic column functions + verReadableCur: 0x00010003, + verWeCanReadBack: 0x00010003, + loaderSignature: LoaderSignature, + loaderSignatureAlt: LoaderSignatureOld, + loaderAssemblyName: typeof(NormalizingTransformer).Assembly.FullName); + } + private static VersionInfo GetVersionInfo() { return new VersionInfo( @@ -322,6 +372,7 @@ public ColumnFunctionAccessor(ImmutableArray infos) } /// An accessor of the column functions within . 
+ [BestFriend] internal readonly IReadOnlyList ColumnFunctions; public readonly ImmutableArray Columns; @@ -338,7 +389,7 @@ public static NormalizingTransformer Train(IHostEnvironment env, IDataView data, env.CheckValue(data, nameof(data)); env.CheckValue(columns, nameof(columns)); - bool[] activeInput = new bool[data.Schema.ColumnCount]; + bool[] activeInput = new bool[data.Schema.Count]; var srcCols = new int[columns.Length]; var srcTypes = new ColumnType[columns.Length]; @@ -348,8 +399,15 @@ public static NormalizingTransformer Train(IHostEnvironment env, IDataView data, bool success = data.Schema.TryGetColumnIndex(info.Input, out srcCols[i]); if (!success) throw env.ExceptSchemaMismatch(nameof(data), "input", info.Input); - srcTypes[i] = data.Schema.GetColumnType(srcCols[i]); + srcTypes[i] = data.Schema[srcCols[i]].Type; activeInput[srcCols[i]] = true; + + var supervisedBinColumn = info as NormalizingEstimator.SupervisedBinningColumn; + if(supervisedBinColumn != null) + { + var labelColumnId = SupervisedBinUtils.GetLabelColumnId(env, data.Schema, supervisedBinColumn.LabelColumn); + activeInput[labelColumnId] = true; + } } var functionBuilders = new IColumnFunctionBuilder[columns.Length]; @@ -416,6 +474,7 @@ private NormalizingTransformer(IHost host, ModelLoadContext ctx) // for each added column: // - source type // - separate model for column function + var cols = new ColumnInfo[ColumnPairs.Length]; ColumnFunctions = new ColumnFunctionAccessor(Columns); for (int iinfo = 0; iinfo < ColumnPairs.Length; iinfo++) @@ -429,6 +488,27 @@ private NormalizingTransformer(IHost host, ModelLoadContext ctx) Columns = ImmutableArray.Create(cols); } + // This constructor for models in old format. 
+ private NormalizingTransformer(IHost host, ModelLoadContext ctx, IDataView input) + : base(host, ctx) + { + // *** Binary format *** + // + // for each added column: + // - separate model for column function + var cols = new ColumnInfo[ColumnPairs.Length]; + ColumnFunctions = new ColumnFunctionAccessor(Columns); + for (int iinfo = 0; iinfo < ColumnPairs.Length; iinfo++) + { + var dir = string.Format("Normalizer_{0:000}", iinfo); + var typeSrc = input.Schema[ColumnPairs[iinfo].input].Type; + ctx.LoadModel(Host, out var function, dir, Host, typeSrc); + cols[iinfo] = new ColumnInfo(ColumnPairs[iinfo].input, ColumnPairs[iinfo].output, typeSrc, function); + } + + Columns = ImmutableArray.Create(cols); + } + public static NormalizingTransformer Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); @@ -437,9 +517,21 @@ public static NormalizingTransformer Create(IHostEnvironment env, ModelLoadConte return new NormalizingTransformer(env.Register(nameof(NormalizingTransformer)), ctx); } + // Factory method for SignatureLoadDataTransform. + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetOldVersionInfo()); + int cbFloat = ctx.Reader.ReadInt32(); + env.CheckDecode(cbFloat == sizeof(float)); + var transformer = new NormalizingTransformer(env.Register(nameof(NormalizingTransformer)), ctx, input); + return transformer.MakeDataTransform(input); + } + // Factory method for SignatureLoadRowMapper. 
- private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); public override void Save(ModelSaveContext ctx) { @@ -463,11 +555,11 @@ public override void Save(ModelSaveContext ctx) } } - protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) + protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) { const string expectedType = "scalar or known-size vector of R4"; - var colType = inputSchema.GetColumnType(srcCol); + var colType = inputSchema[srcCol].Type; if (colType.IsVector && !colType.IsKnownSizeVector) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, expectedType, "variable-size vector"); if (!colType.ItemType.Equals(NumberType.R4) && !colType.ItemType.Equals(NumberType.R8)) @@ -478,7 +570,7 @@ protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCo public new IDataTransform MakeDataTransform(IDataView input) => base.MakeDataTransform(input); - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx, ISaveAsPfa { @@ -516,7 +608,7 @@ private void IsNormalizedGetter(ref bool dst) dst = true; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { disposer = null; return _parent.Columns[iinfo].ColumnFunction.GetGetter(input, ColMapNewToOld[iinfo]); @@ -725,4 +817,4 @@ internal BinNormalizerModelParameters(ImmutableArray upperBounds, TData d } } } -} \ No newline 
at end of file +} diff --git a/src/Microsoft.ML.Data/Transforms/NormalizerCatalog.cs b/src/Microsoft.ML.Data/Transforms/NormalizerCatalog.cs index df9c435def..a89a8a0ee2 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizerCatalog.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizerCatalog.cs @@ -2,8 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data; using Microsoft.ML.Transforms.Normalizers; namespace Microsoft.ML diff --git a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs index 0a74c8a39c..991ee26284 100644 --- a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs +++ b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs @@ -2,12 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Model; using System; using System.Collections.Generic; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Base class for transformer which operates on pairs input and output columns. 
@@ -68,7 +67,7 @@ protected void SaveColumns(ModelSaveContext ctx) } } - private void CheckInput(ISchema inputSchema, int col, out int srcCol) + private void CheckInput(Schema inputSchema, int col, out int srcCol) { Contracts.AssertValue(inputSchema); Contracts.Assert(0 <= col && col < ColumnPairs.Length); @@ -78,7 +77,7 @@ private void CheckInput(ISchema inputSchema, int col, out int srcCol) CheckInputColumn(inputSchema, col, srcCol); } - protected virtual void CheckInputColumn(ISchema inputSchema, int col, int srcCol) + protected virtual void CheckInputColumn(Schema inputSchema, int col, int srcCol) { // By default, there are no extra checks. } @@ -101,9 +100,9 @@ protected OneToOneMapperBase(IHost host, OneToOneTransformerBase parent, Schema } } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { - var active = new bool[InputSchema.ColumnCount]; + var active = new bool[InputSchema.Count]; foreach (var pair in ColMapNewToOld) if (activeOutput(pair.Key)) active[pair.Value] = true; diff --git a/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs b/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs index 7fddf36dbd..5d6b71246a 100644 --- a/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs @@ -3,10 +3,9 @@ // See the LICENSE file in the project root for more information. using System; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// This is a base implementation for a transform that in order to compute its output columns, needs to look @@ -25,13 +24,13 @@ public abstract class PerGroupTransformBase : IDataTrans /// Deriving classes only need to implement . /// If any of the output columns have metadata, then the metadata methods should be overridden. 
/// - protected abstract class BindingsBase : ColumnBindingsBase + private protected abstract class BindingsBase : ColumnBindingsBase { public readonly int LabelIndex; public readonly int ScoreIndex; public readonly int GroupIndex; - protected BindingsBase(IExceptionContext ectx, ISchema input, string labelCol, string scoreCol, string groupCol, bool user, params string[] names) + protected BindingsBase(IExceptionContext ectx, Schema input, string labelCol, string scoreCol, string groupCol, bool user, params string[] names) : base(input, user, names) { ectx.AssertNonWhiteSpace(labelCol); @@ -63,7 +62,7 @@ public Func GetDependencies(Func predicate) { Contracts.AssertValue(predicate); - var active = new bool[Input.ColumnCount]; + var active = new bool[Input.Count]; for (int col = 0; col < ColumnCount; col++) { if (!predicate(col)) @@ -91,7 +90,9 @@ public Func GetDependencies(Func predicate) protected readonly string ScoreCol; protected readonly string GroupCol; - public Schema Schema => GetBindings().AsSchema; + Schema IDataView.Schema => OutputSchema; + + public Schema OutputSchema => GetBindings().AsSchema; public IDataView Source { get; } @@ -143,22 +144,21 @@ public virtual void Save(ModelSaveContext ctx) ctx.SaveNonEmptyString(GroupCol); } - protected abstract BindingsBase GetBindings(); + private protected abstract BindingsBase GetBindings(); public long? 
GetRowCount() { return Source.GetRowCount(); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, IRandom rand = null) + public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); - consolidator = null; - return new IRowCursor[] { GetRowCursor(predicate, rand) }; + return new RowCursor[] { GetRowCursor(predicate, rand) }; } - public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); @@ -176,14 +176,14 @@ public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) return GetRowCursorCore(predicate); } - private IRowCursor GetRowCursorCore(Func predicate) + private RowCursor GetRowCursorCore(Func predicate) { var bindings = GetBindings(); var active = bindings.GetActive(predicate); Contracts.Assert(active.Length == bindings.ColumnCount); var predInput = bindings.GetDependencies(predicate); - return new RowCursor(this, Source.GetRowCursor(predInput, null), Source.GetRowCursor(predInput, null), active); + return new Cursor(this, Source.GetRowCursor(predInput, null), Source.GetRowCursor(predInput, null), active); } /// @@ -197,17 +197,17 @@ private IRowCursor GetRowCursorCore(Func predicate) /// /// Get the getter for the first input column. /// - protected abstract ValueGetter GetLabelGetter(IRow row); + protected abstract ValueGetter GetLabelGetter(Row row); /// /// Get the getter for the second input column. /// - protected abstract ValueGetter GetScoreGetter(IRow row); + protected abstract ValueGetter GetScoreGetter(Row row); /// /// Return a new state object. /// - protected abstract TState InitializeState(IRow input); + protected abstract TState InitializeState(Row input); /// /// Update the state object with one example. 
@@ -220,11 +220,11 @@ private IRowCursor GetRowCursorCore(Func predicate) /// protected abstract void UpdateState(TState state); - private sealed class RowCursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private readonly PerGroupTransformBase _parent; - private readonly IRowCursor _groupCursor; - private readonly IRowCursor _input; + private readonly RowCursor _groupCursor; + private readonly RowCursor _input; private readonly bool[] _active; private readonly Delegate[] _getters; @@ -235,11 +235,11 @@ private sealed class RowCursor : RootCursorBase, IRowCursor private readonly ValueGetter _labelGetter; private readonly ValueGetter _scoreGetter; - public Schema Schema => _parent.Schema; + public override Schema Schema => _parent.OutputSchema; - public override long Batch { get { return 0; } } + public override long Batch => 0; - public RowCursor(PerGroupTransformBase parent, IRowCursor input, IRowCursor groupCursor, bool[] active) + public Cursor(PerGroupTransformBase parent, RowCursor input, RowCursor groupCursor, bool[] active) : base(parent.Host) { Ch.AssertValue(parent); @@ -265,13 +265,13 @@ public RowCursor(PerGroupTransformBase parent, IRowCurso _scoreGetter = _parent.GetScoreGetter(_groupCursor); } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < _parent.GetBindings().ColumnCount); return _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Contracts.CheckParam(IsColumnActive(col), nameof(col), "requested column is not active"); @@ -292,13 +292,13 @@ public ValueGetter GetGetter(int col) return fn; } - public override ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { return - (ref UInt128 val) => + (ref RowId val) => { Ch.Check(IsGood, "Cannot call ID getter in current state"); - val = new UInt128((ulong)Position, 0); + val = new RowId((ulong)Position, 0); }; } @@ -322,8 +322,8 @@ 
protected override bool MoveNextCore() // Read the whole group from the auxiliary cursor. while (_groupCursor.State != CursorState.Done && !_newGroupInGroupCursorDel()) { - TLabel label = default(TLabel); - TScore score = default(TScore); + TLabel label = default; + TScore score = default; _labelGetter(ref label); _scoreGetter(ref score); _parent.ProcessExample(_state, label, score); diff --git a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs index 59542b13aa..04d0a75c47 100644 --- a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs @@ -2,15 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms; using System; using System.Reflection; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Transforms; using Float = System.Single; [assembly: LoadableClass(RangeFilter.Summary, typeof(RangeFilter), typeof(RangeFilter.Arguments), typeof(SignatureDataTransform), @@ -102,7 +102,7 @@ public RangeFilter(IHostEnvironment env, Arguments args, IDataView input) using (var ch = Host.Start("Checking parameters")) { - _type = schema.GetColumnType(_index); + _type = schema[_index].Type; if (!IsValidRangeFilterColumnType(ch, _type)) throw ch.ExceptUserArg(nameof(args.Column), "Column '{0}' does not have compatible type", args.Column); if (_type.IsKey) @@ -150,7 +150,7 @@ private RangeFilter(IHost host, ModelLoadContext ctx, IDataView input) if (!schema.TryGetColumnIndex(column, out _index)) 
throw Host.Except("column", "Source column '{0}' not found", column); - _type = schema.GetColumnType(_index); + _type = schema[_index].Type; if (_type != NumberType.R4 && _type != NumberType.R8 && _type.KeyCount == 0) throw Host.Except("column", "Column '{0}' does not have compatible type", column); @@ -188,7 +188,7 @@ public override void Save(ModelSaveContext ctx) // byte: includeMin // byte: includeMax ctx.Writer.Write(sizeof(Float)); - ctx.SaveNonEmptyString(Source.Schema.GetColumnName(_index)); + ctx.SaveNonEmptyString(Source.Schema[_index].Name); Host.Assert(_min < _max); ctx.Writer.Write(_min); ctx.Writer.Write(_max); @@ -204,7 +204,7 @@ public override void Save(ModelSaveContext ctx) return null; } - protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate, "predicate"); Host.AssertValueOrNull(rand); @@ -215,25 +215,24 @@ protected override IRowCursor GetRowCursorCore(Func predicate, IRando return CreateCursorCore(input, active); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null) + public override RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); bool[] active; Func inputPred = GetActive(predicate, out active); - var inputs = Source.GetRowCursorSet(out consolidator, inputPred, n, rand); + var inputs = Source.GetRowCursorSet(inputPred, n, rand); Host.AssertNonEmpty(inputs); // No need to split if this is given 1 input cursor. 
- var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) cursors[i] = CreateCursorCore(inputs[i], active); return cursors; } - private IRowCursor CreateCursorCore(IRowCursor input, bool[] active) + private RowCursor CreateCursorCore(RowCursor input, bool[] active) { if (_type == NumberType.R4) return new SingleRowCursor(this, input, active); @@ -246,8 +245,8 @@ private IRowCursor CreateCursorCore(IRowCursor input, bool[] active) private Func GetActive(Func predicate, out bool[] active) { Host.AssertValue(predicate); - active = new bool[Source.Schema.ColumnCount]; - bool[] activeInput = new bool[Source.Schema.ColumnCount]; + active = new bool[Source.Schema.Count]; + bool[] activeInput = new bool[Source.Schema.Count]; for (int i = 0; i < active.Length; i++) activeInput[i] = active[i] = predicate(i); activeInput[_index] = true; @@ -268,8 +267,8 @@ private abstract class RowCursorBase : LinkedRowFilterCursorBase private readonly Double _min; private readonly Double _max; - protected RowCursorBase(RangeFilter parent, IRowCursor input, bool[] active) - : base(parent.Host, input, parent.Schema, active) + protected RowCursorBase(RangeFilter parent, RowCursor input, bool[] active) + : base(parent.Host, input, parent.OutputSchema, active) { Parent = parent; _min = Parent._min; @@ -307,7 +306,7 @@ protected RowCursorBase(RangeFilter parent, IRowCursor input, bool[] active) public override ValueGetter GetGetter(int col) { - Ch.Check(0 <= col && col < Schema.ColumnCount); + Ch.Check(0 <= col && col < Schema.Count); Ch.Check(IsColumnActive(col)); if (col != Parent._index) @@ -319,15 +318,15 @@ public override ValueGetter GetGetter(int col) return fn; } - public static IRowCursor CreateKeyRowCursor(RangeFilter filter, IRowCursor input, bool[] active) + public static RowCursor CreateKeyRowCursor(RangeFilter filter, RowCursor input, bool[] active) { Contracts.Assert(filter._type.IsKey); - Func del = 
CreateKeyRowCursor; + Func del = CreateKeyRowCursor; var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(filter._type.RawType); - return (IRowCursor)methodInfo.Invoke(null, new object[] { filter, input, active }); + return (RowCursor)methodInfo.Invoke(null, new object[] { filter, input, active }); } - private static IRowCursor CreateKeyRowCursor(RangeFilter filter, IRowCursor input, bool[] active) + private static RowCursor CreateKeyRowCursor(RangeFilter filter, RowCursor input, bool[] active) { Contracts.Assert(filter._type.IsKey); return new KeyRowCursor(filter, input, active); @@ -340,7 +339,7 @@ private sealed class SingleRowCursor : RowCursorBase private readonly ValueGetter _getter; private Single _value; - public SingleRowCursor(RangeFilter parent, IRowCursor input, bool[] active) + public SingleRowCursor(RangeFilter parent, RowCursor input, bool[] active) : base(parent, input, active) { Ch.Assert(Parent._type == NumberType.R4); @@ -373,7 +372,7 @@ private sealed class DoubleRowCursor : RowCursorBase private readonly ValueGetter _getter; private Double _value; - public DoubleRowCursor(RangeFilter parent, IRowCursor input, bool[] active) + public DoubleRowCursor(RangeFilter parent, RowCursor input, bool[] active) : base(parent, input, active) { Ch.Assert(Parent._type == NumberType.R8); @@ -408,7 +407,7 @@ private sealed class KeyRowCursor : RowCursorBase private readonly ValueMapper _conv; private readonly int _count; - public KeyRowCursor(RangeFilter parent, IRowCursor input, bool[] active) + public KeyRowCursor(RangeFilter parent, RowCursor input, bool[] active) : base(parent, input, active) { Ch.Assert(Parent._type.KeyCount > 0); @@ -421,7 +420,7 @@ public KeyRowCursor(RangeFilter parent, IRowCursor input, bool[] active) dst = _value; }; bool identity; - _conv = Runtime.Data.Conversion.Conversions.Instance.GetStandardConversion(Parent._type, NumberType.U8, out identity); + _conv = 
Data.Conversion.Conversions.Instance.GetStandardConversion(Parent._type, NumberType.U8, out identity); } protected override Delegate GetGetter() diff --git a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs index 4a74df460e..028a94cb87 100644 --- a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs @@ -4,18 +4,17 @@ #pragma warning disable 420 // volatile with Interlocked.CompareExchange -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms; using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Dataflow; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Transforms; [assembly: LoadableClass(RowShufflingTransformer.Summary, typeof(RowShufflingTransformer), typeof(RowShufflingTransformer.Arguments), typeof(SignatureDataTransform), "Shuffle Transform", "ShuffleTransform", "Shuffle", "shuf")] @@ -191,9 +190,9 @@ private static IDataView SelectCachableColumns(IDataView data, IHostEnvironment { List columnsToDrop = null; var schema = data.Schema; - for (int c = 0; c < schema.ColumnCount; ++c) + for (int c = 0; c < schema.Count; ++c) { - var type = schema.GetColumnType(c); + var type = schema[c].Type; if (!type.IsCachable()) Utils.Add(ref columnsToDrop, c); } @@ -209,11 +208,11 @@ private static IDataView SelectCachableColumns(IDataView data, IHostEnvironment /// /// Utility to check whether all types in an input schema are shufflable. 
/// - internal static bool CanShuffleAll(ISchema schema) + internal static bool CanShuffleAll(Schema schema) { - for (int c = 0; c < schema.ColumnCount; ++c) + for (int c = 0; c < schema.Count; ++c) { - var type = schema.GetColumnType(c); + var type = schema[c].Type; if (!type.IsCachable()) return false; } @@ -223,7 +222,7 @@ internal static bool CanShuffleAll(ISchema schema) /// /// Utility to take a cursor, and get a shuffled version of this cursor. /// - public static IRowCursor GetShuffledCursor(IChannelProvider provider, int poolRows, IRowCursor cursor, IRandom rand) + public static RowCursor GetShuffledCursor(IChannelProvider provider, int poolRows, RowCursor cursor, Random rand) { Contracts.CheckValue(provider, nameof(provider)); @@ -236,12 +235,12 @@ public static IRowCursor GetShuffledCursor(IChannelProvider provider, int poolRo if (poolRows == 1) return cursor; - return new RowCursor(provider, poolRows, cursor, rand); + return new Cursor(provider, poolRows, cursor, rand); } public override bool CanShuffle { get { return true; } } - public override Schema Schema { get { return _subsetInput.Schema; } } + public override Schema OutputSchema { get { return _subsetInput.Schema; } } protected override bool? ShouldUseParallelCursors(Func predicate) { @@ -249,7 +248,7 @@ public static IRowCursor GetShuffledCursor(IChannelProvider provider, int poolRo return false; } - protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate, "predicate"); Host.AssertValueOrNull(rand); @@ -257,7 +256,7 @@ protected override IRowCursor GetRowCursorCore(Func predicate, IRando // REVIEW: This is slightly interesting. Our mechanism for inducing // randomness in the source cursor is this Random object, but this can change // from release to release. 
The correct solution, it seems, is to instead have - // randomness injected into cursor creation by using IRandom (or something akin + // randomness injected into cursor creation by using Random (or something akin // to it), vs. just a straight system Random. // The desired functionality is to support some permutations of whether we allow @@ -275,10 +274,10 @@ protected override IRowCursor GetRowCursorCore(Func predicate, IRando bool shouldShuffleMe = _forceShuffle || rand != null; bool shouldShuffleSource = _forceShuffleSource || (!_poolOnly && rand != null); - IRandom myRandom = rand ?? (shouldShuffleMe || shouldShuffleSource ? RandomUtils.Create(_forceShuffleSeed) : null); + Random myRandom = rand ?? (shouldShuffleMe || shouldShuffleSource ? RandomUtils.Create(_forceShuffleSeed) : null); if (shouldShuffleMe) rand = myRandom; - IRandom sourceRand = shouldShuffleSource ? RandomUtils.Create(myRandom) : null; + Random sourceRand = shouldShuffleSource ? RandomUtils.Create(myRandom) : null; var input = _subsetInput.GetRowCursor(predicate, sourceRand); // If rand is null (so we're not doing pool shuffling) or number of pool rows is 1 @@ -286,16 +285,14 @@ protected override IRowCursor GetRowCursorCore(Func predicate, IRando // source cursor. 
if (rand == null || _poolRows == 1) return input; - return new RowCursor(Host, _poolRows, input, rand); + return new Cursor(Host, _poolRows, input, rand); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null) + public override RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); - consolidator = null; - return new IRowCursor[] { GetRowCursorCore(predicate, rand) }; + return new RowCursor[] { GetRowCursorCore(predicate, rand) }; } /// @@ -344,7 +341,7 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid /// The result is something functionally equivalent to but but considerably faster than the /// simple implementation described in the first paragraph. /// - private sealed class RowCursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { /// /// Pipes, in addition to column values, will also communicate extra information @@ -465,8 +462,8 @@ public void Fetch(int idx, ref T value) private const int _bufferDepth = 3; private readonly int _poolRows; - private readonly IRowCursor _input; - private readonly IRandom _rand; + private readonly RowCursor _input; + private readonly Random _rand; // This acts as mapping from the "circular" index to the actual index within the pipe. private readonly int[] _pipeIndices; @@ -476,7 +473,7 @@ public void Fetch(int idx, ref T value) // Each delegate here corresponds to a pipe holding column data. private readonly Delegate[] _getters; // This delegate corresponds to the pipe holding ID data. - private readonly ValueGetter _idGetter; + private readonly ValueGetter _idGetter; // The current position of the output cursor in circular "space". 
private int _circularIndex; @@ -495,16 +492,14 @@ public void Fetch(int idx, ref T value) private Exception _producerTaskException; private readonly int[] _colToActivesIndex; + private bool _disposed; - public Schema Schema { get { return _input.Schema; } } + public override Schema Schema => _input.Schema; - public override long Batch - { - // REVIEW: Implement cursor set support. - get { return 0; } - } + // REVIEW: Implement cursor set support. + public override long Batch => 0; - public RowCursor(IChannelProvider provider, int poolRows, IRowCursor input, IRandom rand) + public Cursor(IChannelProvider provider, int poolRows, RowCursor input, Random rand) : base(provider) { Ch.AssertValue(input); @@ -520,7 +515,7 @@ public RowCursor(IChannelProvider provider, int poolRows, IRowCursor input, IRan _pipeIndices = Utils.GetIdentityPermutation(_poolRows - 1 + _bufferDepth * _blockSize); - int colLim = Schema.ColumnCount; + int colLim = Schema.Count; int numActive = 0; _colToActivesIndex = new int[colLim]; for (int c = 0; c < colLim; ++c) @@ -533,11 +528,11 @@ public RowCursor(IChannelProvider provider, int poolRows, IRowCursor input, IRan if (ia < 0) continue; _pipes[ia] = ShufflePipe.Create(_pipeIndices.Length, - input.Schema.GetColumnType(c), RowCursorUtils.GetGetterAsDelegate(input, c)); + input.Schema[c].Type, RowCursorUtils.GetGetterAsDelegate(input, c)); _getters[ia] = CreateGetterDelegate(c); } var idPipe = _pipes[numActive + (int)ExtraIndex.Id] = ShufflePipe.Create(_pipeIndices.Length, NumberType.UG, input.GetIdGetter()); - _idGetter = CreateGetterDelegate(idPipe); + _idGetter = CreateGetterDelegate(idPipe); // Initially, after the preamble to MoveNextCore, we want: // liveCount=0, deadCount=0, circularIndex=0. So we set these // funky values accordingly. 
@@ -557,14 +552,17 @@ public RowCursor(IChannelProvider provider, int poolRows, IRowCursor input, IRan _producerTask = LoopProducerWorker(); } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (_producerTask.Status == TaskStatus.Running) + if (_disposed) + return; + if (disposing && _producerTask.Status == TaskStatus.Running) { _toProduce.Post(0); _producerTask.Wait(); } - base.Dispose(); + _disposed = true; + base.Dispose(disposing); } public static void PostAssert(ITargetBlock target, T item) @@ -573,7 +571,7 @@ public static void PostAssert(ITargetBlock target, T item) Contracts.Assert(retval); } - public override ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { return _idGetter; } @@ -669,7 +667,7 @@ protected override bool MoveNextCore() return true; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.CheckParam(0 <= col && col < _colToActivesIndex.Length, nameof(col)); Ch.Assert((_colToActivesIndex[col] >= 0) == _input.IsColumnActive(col)); @@ -681,14 +679,14 @@ private Delegate CreateGetterDelegate(int col) Ch.Assert(0 <= col && col < _colToActivesIndex.Length); Ch.Assert(_colToActivesIndex[col] >= 0); Func createDel = CreateGetterDelegate; - return Utils.MarshalInvoke(createDel, Schema.GetColumnType(col).RawType, col); + return Utils.MarshalInvoke(createDel, Schema[col].Type.RawType, col); } private Delegate CreateGetterDelegate(int col) { Ch.Assert(0 <= col && col < _colToActivesIndex.Length); Ch.Assert(_colToActivesIndex[col] >= 0); - Ch.Assert(Schema.GetColumnType(col).RawType == typeof(TValue)); + Ch.Assert(Schema[col].Type.RawType == typeof(TValue)); return CreateGetterDelegate(_pipes[_colToActivesIndex[col]]); } @@ -706,7 +704,7 @@ private ValueGetter CreateGetterDelegate(ShufflePipe pipe) return getter; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.CheckParam(0 <= col && col < 
_colToActivesIndex.Length, nameof(col)); Ch.CheckParam(_colToActivesIndex[col] >= 0, nameof(col), "requested column not active"); diff --git a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs index cce93e439b..e98b9aa231 100644 --- a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs +++ b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs @@ -2,13 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Model; using System; using System.Linq; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Base class for transformer which produce new columns, but doesn't affect existing ones. @@ -33,7 +32,8 @@ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) return new RowToRowMapperTransform(Host, new EmptyDataView(Host, inputSchema), MakeRowMapper(inputSchema), MakeRowMapper); } - protected abstract IRowMapper MakeRowMapper(Schema schema); + [BestFriend] + private protected abstract IRowMapper MakeRowMapper(Schema schema); public Schema GetOutputSchema(Schema inputSchema) { @@ -67,9 +67,9 @@ protected MapperBase(IHost host, Schema inputSchema) protected abstract Schema.DetachedColumn[] GetOutputColumnsCore(); - public Schema.DetachedColumn[] GetOutputColumns() => _outputColumns.Value; + Schema.DetachedColumn[] IRowMapper.GetOutputColumns() => _outputColumns.Value; - public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) + Delegate[] IRowMapper.CreateGetters(Row input, Func activeOutput, out Action disposer) { // REVIEW: it used to be that the mapper's input schema in the constructor was required to be reference-equal to the schema // of the input row. 
@@ -98,9 +98,13 @@ public Delegate[] CreateGetters(IRow input, Func activeOutput, out Ac return result; } - protected abstract Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer); + protected abstract Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer); + + Func IRowMapper.GetDependencies(Func activeOutput) + => GetDependenciesCore(activeOutput); - public abstract Func GetDependencies(Func activeOutput); + [BestFriend] + private protected abstract Func GetDependenciesCore(Func activeOutput); public abstract void Save(ModelSaveContext ctx); } diff --git a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs index bf3331d0a5..f4faea48eb 100644 --- a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs @@ -2,15 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using System; +using Microsoft.ML; +using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; using Microsoft.ML.Transforms; -using System; [assembly: LoadableClass(SkipTakeFilter.SkipTakeFilterSummary, typeof(SkipTakeFilter), typeof(SkipTakeFilter.Arguments), typeof(SignatureDataTransform), SkipTakeFilter.SkipTakeFilterUserName, "SkipTakeFilter", SkipTakeFilter.SkipTakeFilterShortName)] @@ -188,26 +187,24 @@ public override void Save(ModelSaveContext ctx) return false; } - protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate); Host.AssertValueOrNull(rand); var input = Source.GetRowCursor(predicate); - var activeColumns = Utils.BuildArray(Schema.ColumnCount, predicate); - return new RowCursor(Host, input, Schema, activeColumns, _skip, _take); + var activeColumns = Utils.BuildArray(OutputSchema.Count, predicate); + return new Cursor(Host, input, OutputSchema, activeColumns, _skip, _take); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null) + public override RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); - consolidator = null; - return new IRowCursor[] { GetRowCursorCore(predicate) }; + return new RowCursor[] { GetRowCursorCore(predicate) }; } - private sealed class RowCursor : LinkedRowRootCursorBase + private sealed class Cursor : LinkedRowRootCursorBase { private readonly long _skip; private readonly long _take; @@ 
-219,7 +216,7 @@ public override long Batch { get { return 0; } } - public RowCursor(IChannelProvider provider, IRowCursor input, Schema schema, bool[] active, long skip, long take) + public Cursor(IChannelProvider provider, RowCursor input, Schema schema, bool[] active, long skip, long take) : base(provider, input, schema, active) { Ch.Assert(skip >= 0); @@ -229,7 +226,7 @@ public RowCursor(IChannelProvider provider, IRowCursor input, Schema schema, boo _take = take; } - public override ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { return Input.GetIdGetter(); } diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs index e8ec5f4e53..b1fa2491f9 100644 --- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs @@ -2,15 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Calibration; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms; -using System; using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Model; +using Microsoft.ML.Transforms; [assembly: LoadableClass(ScoringTransformer.Summary, typeof(IDataTransform), typeof(ScoringTransformer), typeof(ScoringTransformer.Arguments), typeof(SignatureDataTransform), "Score Predictor", "Score")] @@ -20,7 +19,8 @@ namespace Microsoft.ML.Transforms { - public static class ScoringTransformer + [BestFriend] + internal static class ScoringTransformer { public sealed class Arguments : TransformInputBase { diff --git a/src/Microsoft.ML.Data/Transforms/TransformBase.cs b/src/Microsoft.ML.Data/Transforms/TransformBase.cs index 83dce385ad..141c6aceee 100644 --- a/src/Microsoft.ML.Data/Transforms/TransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/TransformBase.cs @@ -6,14 +6,13 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Model.Pfa; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Model.Pfa; using Newtonsoft.Json.Linq; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Base class for transforms. 
@@ -49,9 +48,18 @@ protected TransformBase(IHost host, IDataView input) public virtual bool CanShuffle { get { return Source.CanShuffle; } } - public abstract Schema Schema { get; } + /// + /// The field is the type information of the produced IDataView of this transformer. + /// + /// Explicit interface implementation hides in all derived classes. The reason + /// is that a transformer should know the type it will produce but shouldn't contain the type of the data it produces. + /// Thus, this field will be eventually removed while legacy code can still access for now. + /// + Schema IDataView.Schema => OutputSchema; - public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) + public abstract Schema OutputSchema { get; } + + public RowCursor GetRowCursor(Func predicate, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); @@ -63,7 +71,7 @@ public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) // When the input wants to be split, this puts the consolidation after this transform // instead of before. This is likely to produce better performance, for example, when // this is RangeFilter. - IRowCursor curs; + RowCursor curs; if (useParallel != false && DataViewUtils.TryCreateConsolidatingCursor(out curs, this, predicate, Host, rng)) { @@ -84,10 +92,9 @@ public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) /// /// Create a single (non-parallel) row cursor. /// - protected abstract IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null); + protected abstract RowCursor GetRowCursorCore(Func predicate, Random rand = null); - public abstract IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null); + public abstract RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null); } /// @@ -127,7 +134,7 @@ private protected FilterBase(IHost host, IDataView input) public override long? 
GetRowCount() => null; - public sealed override Schema Schema => Source.Schema; + public override Schema OutputSchema => Source.Schema; bool ICanSavePfa.CanSavePfa => true; @@ -157,56 +164,56 @@ public Func GetDependencies(Func predicate) protected abstract Func GetDependenciesCore(Func predicate); - Schema IRowToRowMapper.InputSchema => Source.Schema; + public Schema InputSchema => Source.Schema; - public IRow GetRow(IRow input, Func active, out Action disposer) + public Row GetRow(Row input, Func active) { Host.CheckValue(input, nameof(input)); Host.CheckValue(active, nameof(active)); Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to"); - disposer = null; using (var ch = Host.Start("GetEntireRow")) { - Action disp; - var getters = CreateGetters(input, active, out disp); - disposer += disp; - return new Row(input, this, Schema, getters); + var getters = CreateGetters(input, active, out Action disp); + return new RowImpl(input, this, OutputSchema, getters, disp); } } - protected abstract Delegate[] CreateGetters(IRow input, Func active, out Action disp); + protected abstract Delegate[] CreateGetters(Row input, Func active, out Action disp); protected abstract int MapColumnIndex(out bool isSrc, int col); - private sealed class Row : IRow + private sealed class RowImpl : WrappingRow { private readonly Schema _schema; - private readonly IRow _input; private readonly Delegate[] _getters; + private readonly Action _disposer; private readonly RowToRowMapperTransformBase _parent; - public long Batch { get { return _input.Batch; } } + public override Schema Schema => _schema; - public long Position { get { return _input.Position; } } - - public Schema Schema { get { return _schema; } } - - public Row(IRow input, RowToRowMapperTransformBase parent, Schema schema, Delegate[] getters) + public RowImpl(Row input, RowToRowMapperTransformBase parent, Schema schema, Delegate[] getters, Action disposer) + : 
base(input) { - _input = input; _parent = parent; _schema = schema; _getters = getters; + _disposer = disposer; + } + + protected override void DisposeCore(bool disposing) + { + if (disposing) + _disposer?.Invoke(); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { bool isSrc; int index = _parent.MapColumnIndex(out isSrc, col); if (isSrc) - return _input.GetGetter(index); + return Input.GetGetter(index); Contracts.Assert(_getters[index] != null); var fn = _getters[index] as ValueGetter; @@ -215,17 +222,12 @@ public ValueGetter GetGetter(int col) return fn; } - public ValueGetter GetIdGetter() - { - return _input.GetIdGetter(); - } - - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { bool isSrc; int index = _parent.MapColumnIndex(out isSrc, col); if (isSrc) - return _input.IsColumnActive((index)); + return Input.IsColumnActive((index)); return _getters[index] != null; } } @@ -283,7 +285,7 @@ private sealed class Bindings : ColumnBindingsBase, ITransposeSchema private const string InvalidTypeErrorFormat = "Source column '{0}' has invalid type ('{1}'): {2}."; private Bindings(OneToOneTransformBase parent, ColInfo[] infos, - ISchema input, bool user, string[] names) + Schema input, bool user, string[] names) : base(input, user, names) { Contracts.AssertValue(parent); @@ -296,7 +298,7 @@ private Bindings(OneToOneTransformBase parent, ColInfo[] infos, Infos = infos; } - public static Bindings Create(OneToOneTransformBase parent, OneToOneColumn[] column, ISchema input, + public static Bindings Create(OneToOneTransformBase parent, OneToOneColumn[] column, Schema input, ITransposeSchema transInput, Func testType) { Contracts.AssertValue(parent); @@ -318,7 +320,7 @@ public static Bindings Create(OneToOneTransformBase parent, OneToOneColumn[] col if (!input.TryGetColumnIndex(item.Source, out colSrc)) throw host.ExceptUserArg(nameof(OneToOneColumn.Source), "Source column '{0}' not found", 
item.Source); - var type = input.GetColumnType(colSrc); + var type = input[colSrc].Type; if (testType != null) { string reason = testType(type); @@ -333,7 +335,7 @@ public static Bindings Create(OneToOneTransformBase parent, OneToOneColumn[] col return new Bindings(parent, infos, input, true, names); } - public static Bindings Create(OneToOneTransformBase parent, ModelLoadContext ctx, ISchema input, + public static Bindings Create(OneToOneTransformBase parent, ModelLoadContext ctx, Schema input, ITransposeSchema transInput, Func testType) { Contracts.AssertValue(parent); @@ -367,7 +369,7 @@ public static Bindings Create(OneToOneTransformBase parent, ModelLoadContext ctx int colSrc; if (!input.TryGetColumnIndex(src, out colSrc)) throw host.Except("Source column '{0}' is required but not found", src); - var type = input.GetColumnType(colSrc); + var type = input[colSrc].Type; if (testType != null) { string reason = testType(type); @@ -394,7 +396,7 @@ public void Save(ModelSaveContext ctx) foreach (var info in Infos) { ctx.SaveNonEmptyString(info.Name); - ctx.SaveNonEmptyString(Input.GetColumnName(info.Source)); + ctx.SaveNonEmptyString(Input[info.Source].Name); } } @@ -402,7 +404,7 @@ public Func GetDependencies(Func predicate) { Contracts.AssertValue(predicate); - var active = new bool[Input.ColumnCount]; + var active = new bool[Input.Count]; for (int col = 0; col < ColumnCount; col++) { if (!predicate(col)) @@ -467,9 +469,9 @@ private sealed class ColumnTmp : OneToOneColumn // The ColInfos are exposed to sub-classes. They should be considered readonly. protected readonly ColInfo[] Infos; // The _input as a transposed data view, non-null iff _input is a transposed data view. - protected readonly ITransposeDataView InputTranspose; + private protected readonly ITransposeDataView InputTranspose; // The InputTranspose transpose schema, null iff InputTranspose is null. 
- protected ITransposeSchema InputTransposeSchema => InputTranspose?.TransposeSchema; + private protected ITransposeSchema InputTransposeSchema => InputTranspose?.TransposeSchema; bool ICanSavePfa.CanSavePfa => CanSavePfaCore; @@ -601,7 +603,7 @@ void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx) } if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(sourceColumnName), - ctx.AddIntermediateVariable(Schema[_bindings.MapIinfoToCol(iinfo)].Type, info.Name))) + ctx.AddIntermediateVariable(OutputSchema[_bindings.MapIinfoToCol(iinfo)].Type, info.Name))) { ctx.RemoveColumn(info.Name, true); } @@ -635,9 +637,9 @@ private protected virtual JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, C private protected virtual bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) => false; - public sealed override Schema Schema => _bindings.AsSchema; + public sealed override Schema OutputSchema => _bindings.AsSchema; - public ITransposeSchema TransposeSchema => _bindings; + ITransposeSchema ITransposeDataView.TransposeSchema => _bindings; /// /// Return the (destination) column index for the indicated added column. @@ -674,9 +676,9 @@ protected virtual void ActivateSourceColumns(int iinfo, bool[] active) /// otherwise it should be set to a delegate to be invoked by the cursor's Dispose method. It's best /// for this action to be idempotent - calling it multiple times should be equivalent to calling it once. 
/// - protected abstract Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer); + protected abstract Delegate GetGetterCore(IChannel ch, Row input, int iinfo, out Action disposer); - protected ValueGetter GetSrcGetter(IRow input, int iinfo) + protected ValueGetter GetSrcGetter(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < Infos.Length); @@ -685,12 +687,12 @@ protected ValueGetter GetSrcGetter(IRow input, int iinfo) return input.GetGetter(src); } - protected Delegate GetSrcGetter(ColumnType typeDst, IRow row, int iinfo) + protected Delegate GetSrcGetter(ColumnType typeDst, Row row, int iinfo) { Host.CheckValue(typeDst, nameof(typeDst)); Host.CheckValue(row, nameof(row)); - Func> del = GetSrcGetter; + Func> del = GetSrcGetter; var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeDst.RawType); return (Delegate)methodInfo.Invoke(this, new object[] { row, iinfo }); } @@ -718,7 +720,7 @@ protected virtual bool WantParallelCursors(Func predicate) return _bindings.AnyNewColumnsActive(predicate); } - protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate, "predicate"); Host.AssertValueOrNull(rand); @@ -726,27 +728,26 @@ protected override IRowCursor GetRowCursorCore(Func predicate, IRando var inputPred = _bindings.GetDependencies(predicate); var active = _bindings.GetActive(predicate); var input = Source.GetRowCursor(inputPred, rand); - return new RowCursor(Host, this, input, active); + return new Cursor(Host, this, input, active); } - public sealed override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null) + public sealed override RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); 
Host.CheckValueOrNull(rand); var inputPred = _bindings.GetDependencies(predicate); var active = _bindings.GetActive(predicate); - var inputs = Source.GetRowCursorSet(out consolidator, inputPred, n, rand); + var inputs = Source.GetRowCursorSet(inputPred, n, rand); Host.AssertNonEmpty(inputs); if (inputs.Length == 1 && n > 1 && WantParallelCursors(predicate)) - inputs = DataViewUtils.CreateSplitCursors(out consolidator, Host, inputs[0], n); + inputs = DataViewUtils.CreateSplitCursors(Host, inputs[0], n); Host.AssertNonEmpty(inputs); - var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(Host, this, inputs[i], active); + cursors[i] = new Cursor(Host, this, inputs[i], active); return cursors; } @@ -758,10 +759,10 @@ protected Exception ExceptGetSlotCursor(int col) { Host.Assert(0 <= col && col < _bindings.ColumnCount); return Host.ExceptParam(nameof(col), "Bad call to GetSlotCursor on untransposable column '{0}'", - Schema[col].Name); + OutputSchema[col].Name); } - public ISlotCursor GetSlotCursor(int col) + public SlotCursor GetSlotCursor(int col) { Host.CheckParam(0 <= col && col < _bindings.ColumnCount, nameof(col)); @@ -786,7 +787,7 @@ public ISlotCursor GetSlotCursor(int col) /// null for all new columns, and so reaching this is only possible if there is a /// bug. 
/// - protected virtual ISlotCursor GetSlotCursorCore(int iinfo) + protected virtual SlotCursor GetSlotCursorCore(int iinfo) { Host.Assert(false); throw Host.ExceptNotImpl("Data view indicated it could transpose a column, but apparently it could not"); @@ -802,7 +803,7 @@ protected override Func GetDependenciesCore(Func predicate return _bindings.GetDependencies(predicate); } - protected override Delegate[] CreateGetters(IRow input, Func active, out Action disposer) + protected override Delegate[] CreateGetters(Row input, Func active, out Action disposer) { Func activeInfos = iinfo => @@ -827,15 +828,16 @@ protected override Delegate[] CreateGetters(IRow input, Func active, } } - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly Bindings _bindings; private readonly bool[] _active; private readonly Delegate[] _getters; - private readonly Action[] _disposers; + private readonly Action _disposer; + private bool _disposed; - public RowCursor(IChannelProvider provider, OneToOneTransformBase parent, IRowCursor input, bool[] active) + public Cursor(IChannelProvider provider, OneToOneTransformBase parent, RowCursor input, bool[] active) : base(provider, input) { Ch.AssertValue(parent); @@ -845,35 +847,34 @@ public RowCursor(IChannelProvider provider, OneToOneTransformBase parent, IRowCu _active = active; _getters = new Delegate[parent.Infos.Length]; - // Build the delegates. - List disposers = null; + // Build the disposing delegate. 
+ Action masterDisposer = null; for (int iinfo = 0; iinfo < _getters.Length; iinfo++) { if (!IsColumnActive(parent._bindings.MapIinfoToCol(iinfo))) continue; - Action disposer; - _getters[iinfo] = parent.GetGetterCore(Ch, Input, iinfo, out disposer); + _getters[iinfo] = parent.GetGetterCore(Ch, Input, iinfo, out Action disposer); if (disposer != null) - Utils.Add(ref disposers, disposer); + masterDisposer += disposer; } - - if (Utils.Size(disposers) > 0) - _disposers = disposers.ToArray(); + _disposer = masterDisposer; } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (_disposers != null) + if (_disposed) + return; + if (disposing) { - foreach (var act in _disposers) - act(); + _disposer?.Invoke(); } - base.Dispose(); + _disposed = true; + base.Dispose(disposing); } - public Schema Schema => _bindings.AsSchema; + public override Schema Schema => _bindings.AsSchema; - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); @@ -889,7 +890,7 @@ public ValueGetter GetGetter(int col) return fn; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < _bindings.ColumnCount); return _active == null || _active[col]; diff --git a/src/Microsoft.ML.Data/Transforms/TransformsCatalog.cs b/src/Microsoft.ML.Data/Transforms/TransformsCatalog.cs index 42f2031807..38033db8cf 100644 --- a/src/Microsoft.ML.Data/Transforms/TransformsCatalog.cs +++ b/src/Microsoft.ML.Data/Transforms/TransformsCatalog.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { /// /// Similar to training context, a transform context is an object serving as a 'catalog' of available transforms. 
diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs index 331e5fff71..59ef43e684 100644 --- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs +++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs @@ -4,22 +4,19 @@ #pragma warning disable 420 // volatile with Interlocked.CompareExchange -using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.StaticPipe; -using Microsoft.ML.StaticPipe.Runtime; -using Microsoft.ML.Transforms.Conversions; using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Transforms.Conversions; [assembly: LoadableClass(TypeConvertingTransformer.Summary, typeof(IDataTransform), typeof(TypeConvertingTransformer), typeof(TypeConvertingTransformer.Arguments), typeof(SignatureDataTransform), TypeConvertingTransformer.UserName, TypeConvertingTransformer.ShortName, "ConvertTransform", DocName = "transform/ConvertTransform.md")] @@ -50,7 +47,7 @@ public static CommonOutputs.TransformOutput Convert(IHostEnvironment env, TypeCo return new CommonOutputs.TransformOutput() { - Model = new TransformModel(h, view, input.Data), + Model = new TransformModelImpl(h, view, input.Data), OutputData = view }; } @@ -357,10 +354,10 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, => Create(env, ctx).MakeDataTransform(input); // Factory method for SignatureLoadRowMapper. 
- private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); internal static bool GetNewType(IExceptionContext ectx, ColumnType srcType, DataKind kind, KeyRange range, out PrimitiveType itemType) { @@ -370,7 +367,7 @@ internal static bool GetNewType(IExceptionContext ectx, ColumnType srcType, Data if (!srcType.ItemType.IsKey && !srcType.ItemType.IsText) return false; } - else if (!srcType.ItemType.IsKey) + else if (!(srcType.ItemType is KeyType key)) itemType = PrimitiveType.FromKind(kind); else if (!KeyType.IsValidDataKind(kind)) { @@ -379,7 +376,6 @@ internal static bool GetNewType(IExceptionContext ectx, ColumnType srcType, Data } else { - var key = srcType.ItemType.AsKey; ectx.Assert(KeyType.IsValidDataKind(key.RawKind)); int count = key.Count; // Technically, it's an error for the counts not to match, but we'll let the Conversions @@ -432,12 +428,12 @@ private static bool CanConvertToType(IExceptionContext ectx, ColumnType srcType, // Ensure that the conversion is legal. We don't actually cache the delegate here. It will get // re-fetched by the utils code when needed. 
- if (!Runtime.Data.Conversion.Conversions.Instance.TryGetStandardConversion(srcType.ItemType, itemType, out Delegate del, out bool identity)) + if (!Data.Conversion.Conversions.Instance.TryGetStandardConversion(srcType.ItemType, itemType, out Delegate del, out bool identity)) return false; typeDst = itemType; - if (srcType.IsVector) - typeDst = new VectorType(itemType, srcType.AsVector); + if (srcType is VectorType vectorType) + typeDst = new VectorType(itemType, vectorType); return true; } @@ -468,14 +464,14 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() return result; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); disposer = null; - if (!_types[iinfo].IsVector) + if (!(_types[iinfo] is VectorType vectorType)) return RowCursorUtils.GetGetterAs(_types[iinfo], input, _srcCols[iinfo]); - return RowCursorUtils.GetVecGetterAs(_types[iinfo].AsVector.ItemType, input, _srcCols[iinfo]); + return RowCursorUtils.GetVecGetterAs(vectorType.ItemType, input, _srcCols[iinfo]); } public void SaveAsOnnx(OnnxContext ctx) @@ -507,7 +503,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, node.AddAttribute("to", (byte)_parent._columns[iinfo].OutputKind); if (_parent._columns[iinfo].OutputKeyRange != null) { - var key = _types[iinfo].ItemType.AsKey; + var key = (KeyType)_types[iinfo].ItemType; node.AddAttribute("min", key.Min); node.AddAttribute("max", key.Count); node.AddAttribute("contiguous", key.Contiguous); @@ -552,14 +548,14 @@ public TypeConvertingEstimator(IHostEnvironment env, params TypeConvertingTransf public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = 
inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); if (!TypeConvertingTransformer.GetNewType(Host, col.ItemType, colInfo.OutputKind, colInfo.OutputKeyRange, out PrimitiveType newType)) throw Host.ExceptParam(nameof(inputSchema), $"Can't convert {colInfo.Input} into {newType.ToString()}"); - if (!Runtime.Data.Conversion.Conversions.Instance.TryGetStandardConversion(col.ItemType, newType, out Delegate del, out bool identity)) + if (!Data.Conversion.Conversions.Instance.TryGetStandardConversion(col.ItemType, newType, out Delegate del, out bool identity)) throw Host.ExceptParam(nameof(inputSchema), $"Don't know how to convert {colInfo.Input} into {newType.ToString()}"); var metadata = new List(); if (col.ItemType.IsBool && newType.ItemType.IsNumber) @@ -578,64 +574,4 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } } - - public static partial class ConvertStaticExtensions - { - - private interface IConvertCol - { - PipelineColumn Input { get; } - DataKind Kind { get; } - } - - private sealed class ImplScalar : Scalar, IConvertCol - { - public PipelineColumn Input { get; } - public DataKind Kind { get; } - public ImplScalar(PipelineColumn input, DataKind kind) : base(Rec.Inst, input) - { - Input = input; - Kind = kind; - } - } - - private sealed class ImplVector : Vector, IConvertCol - { - public PipelineColumn Input { get; } - public DataKind Kind { get; } - public ImplVector(PipelineColumn input, DataKind kind) : base(Rec.Inst, input) - { - Input = input; - Kind = kind; - } - } - - private sealed class ImplVarVector : VarVector, IConvertCol - { - public PipelineColumn Input { get; } - public DataKind Kind { get; } - public ImplVarVector(PipelineColumn input, DataKind kind) : 
base(Rec.Inst, input) - { - Input = input; - Kind = kind; - } - } - - private sealed class Rec : EstimatorReconciler - { - public static readonly Rec Inst = new Rec(); - - public override IEstimator Reconcile(IHostEnvironment env, PipelineColumn[] toOutput, - IReadOnlyDictionary inputNames, IReadOnlyDictionary outputNames, IReadOnlyCollection usedNames) - { - var infos = new TypeConvertingTransformer.ColumnInfo[toOutput.Length]; - for (int i = 0; i < toOutput.Length; ++i) - { - var tcol = (IConvertCol)toOutput[i]; - infos[i] = new TypeConvertingTransformer.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], tcol.Kind); - } - return new TypeConvertingEstimator(env, infos); - } - } - } } diff --git a/src/Microsoft.ML.Data/Transforms/ValueMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueMappingTransformer.cs new file mode 100644 index 0000000000..7ab229169b --- /dev/null +++ b/src/Microsoft.ML.Data/Transforms/ValueMappingTransformer.cs @@ -0,0 +1,974 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Transforms.Conversions; + +[assembly: LoadableClass(ValueMappingTransformer.Summary, typeof(IDataTransform), typeof(ValueMappingTransformer), + typeof(ValueMappingTransformer.Arguments), typeof(SignatureDataTransform), + ValueMappingTransformer.UserName, "ValueMapping", "ValueMappingTransformer", ValueMappingTransformer.ShortName, + "TermLookup", "Lookup", "LookupTransform", DocName = "transform/ValueMappingTransformer.md")] + +[assembly: LoadableClass(ValueMappingTransformer.Summary, typeof(IDataTransform), typeof(ValueMappingTransformer), null, typeof(SignatureLoadDataTransform), + "Value Mapping Transform", ValueMappingTransformer.LoaderSignature, ValueMappingTransformer.TermLookupLoaderSignature)] + +[assembly: LoadableClass(ValueMappingTransformer.Summary, typeof(ValueMappingTransformer), null, typeof(SignatureLoadModel), + "Value Mapping Transform", ValueMappingTransformer.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(ValueMappingTransformer), null, typeof(SignatureLoadRowMapper), + ValueMappingTransformer.UserName, ValueMappingTransformer.LoaderSignature)] + +namespace Microsoft.ML.Transforms.Conversions +{ + /// + /// The ValueMappingEstimator is a 1-1 mapping from a key to value. The key type and value type are specified + /// through TKey and TValue. Arrays are supported for vector types which can be used as either a key or a value + /// or both. The mapping is specified, not trained by providiing a list of keys and a list of values. + /// + /// Specifies the key type. + /// Specifies the value type. 
+ public sealed class ValueMappingEstimator : TrivialEstimator> + { + private (string input, string output)[] _columns; + + /// + /// Constructs the ValueMappingEstimator, key type -> value type mapping + /// + /// The environment to use. + /// The list of keys of TKey. + /// The list of values of TValue. + /// The list of columns to apply. + public ValueMappingEstimator(IHostEnvironment env, IEnumerable keys, IEnumerable values, params (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ValueMappingEstimator)), + new ValueMappingTransformer(env, keys, values, false, columns)) + { + _columns = columns; + } + + /// + /// Constructs the ValueMappingEstimator, key type -> value type mapping + /// + /// The environment to use. + /// The list of keys of TKey. + /// The list of values of TValue. + /// Specifies to treat the values as a . + /// The list of columns to apply. + public ValueMappingEstimator(IHostEnvironment env, IEnumerable keys, IEnumerable values, bool treatValuesAsKeyType, params (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ValueMappingEstimator)), + new ValueMappingTransformer(env, keys, values, treatValuesAsKeyType, columns)) + { + _columns = columns; + } + + /// + /// Constructs the ValueMappingEstimator, key type -> value array type mapping + /// + /// The environment to use. + /// The list of keys of TKey. + /// The list of values of TValue[]. + /// The list of columns to apply. 
+ public ValueMappingEstimator(IHostEnvironment env, IEnumerable keys, IEnumerable values, params (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ValueMappingEstimator)), + new ValueMappingTransformer(env, keys, values, columns)) + { + _columns = columns; + } + + /// + /// Retrieves the output schema given the input schema + /// + /// Input schema + /// Returns the generated output schema + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + + var resultDic = inputSchema.ToDictionary(x => x.Name); + var vectorKind = Transformer.ValueColumnType.IsVector ? SchemaShape.Column.VectorKind.Vector : SchemaShape.Column.VectorKind.Scalar; + var isKey = Transformer.ValueColumnType.IsKey; + var columnType = (isKey) ? PrimitiveType.FromKind(DataKind.U4) : + Transformer.ValueColumnType; + foreach (var (Input, Output) in _columns) + { + if (!inputSchema.TryFindColumn(Input, out var originalColumn)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Input); + + // Get the type from TOutputType + var col = new SchemaShape.Column(Output, vectorKind, columnType, isKey, originalColumn.Metadata); + resultDic[Output] = col; + } + return new SchemaShape(resultDic.Values); + } + } + + /// + /// The DataViewHelper provides a set of static functions to create a DataView given a list of keys and values. 
+ /// + internal class DataViewHelper + { + /// + /// Helper function to retrieve the Primitie type given a Type + /// + internal static PrimitiveType GetPrimitiveType(Type rawType, out bool isVectorType) + { + Type type = rawType; + isVectorType = false; + if (type.IsArray) + { + type = rawType.GetElementType(); + isVectorType = true; + } + + if (!type.TryGetDataKind(out DataKind kind)) + throw new InvalidOperationException($"Unsupported type {type} used in mapping."); + + return PrimitiveType.FromKind(kind); + } + + /// + /// Helper function for a reverse lookup given value. This is used for generating the metadata of the value column. + /// + + private static ValueGetter>> GetKeyValueGetter(TKey[] keys) + { + return + (ref VBuffer> dst) => + { + var editor = VBufferEditor.Create(ref dst, keys.Length); + for (int i = 0; i < keys.Length; i++) + editor.Values[i] = keys[i].ToString().AsMemory(); + dst = editor.Commit(); + }; + } + + /// + /// Helper function to create an IDataView given a list of key and vector-based values + /// + internal static IDataView CreateDataView(IHostEnvironment env, + IEnumerable keys, + IEnumerable values, + string keyColumnName, + string valueColumnName) + { + var keyType = GetPrimitiveType(typeof(TKey), out bool isKeyVectorType); + var valueType = GetPrimitiveType(typeof(TValue), out bool isValueVectorType); + var dataViewBuilder = new ArrayDataViewBuilder(env); + dataViewBuilder.AddColumn(keyColumnName, keyType, keys.ToArray()); + dataViewBuilder.AddColumn(valueColumnName, valueType, values.ToArray()); + return dataViewBuilder.GetDataView(); + } + + /// + /// Helper function that builds the IDataView given a list of keys and non-vector values + /// + internal static IDataView CreateDataView(IHostEnvironment env, + IEnumerable keys, + IEnumerable values, + string keyColumnName, + string valueColumnName, + bool treatValuesAsKeyTypes) + { + var keyType = GetPrimitiveType(typeof(TKey), out bool isKeyVectorType); + var valueType = 
GetPrimitiveType(typeof(TValue), out bool isValueVectorType); + + var dataViewBuilder = new ArrayDataViewBuilder(env); + dataViewBuilder.AddColumn(keyColumnName, keyType, keys.ToArray()); + if (treatValuesAsKeyTypes) + { + // When treating the values as KeyTypes, generate the unique + // set of values. This is used for generating the metadata of + // the column. + HashSet valueSet = new HashSet(); + HashSet keySet = new HashSet(); + for (int i = 0; i < values.Count(); ++i) + { + var v = values.ElementAt(i); + if (valueSet.Contains(v)) + continue; + valueSet.Add(v); + + var k = keys.ElementAt(i); + keySet.Add(k); + } + var metaKeys = keySet.ToArray(); + + // Key Values are treated in one of two ways: + // If the values are of type uint or ulong, these values are used directly as the keys types and no new keys are created. + // If the values are not of uint or ulong, then key values are generated as uints starting from 1, since 0 is missing key. + if (valueType.RawKind == DataKind.U4) + { + uint[] indices = values.Select((x) => Convert.ToUInt32(x)).ToArray(); + dataViewBuilder.AddColumn(valueColumnName, GetKeyValueGetter(metaKeys), 0, metaKeys.Length, indices); + } + else if (valueType.RawKind == DataKind.U8) + { + ulong[] indices = values.Select((x) => Convert.ToUInt64(x)).ToArray(); + dataViewBuilder.AddColumn(valueColumnName, GetKeyValueGetter(metaKeys), 0, metaKeys.Length, indices); + } + else + { + // When generating the indices, treat each value as being unique, i.e. two values that are the same will + // be assigned the same index. The dictionary is used to maintain uniqueness, indices will contain + // the full list of indices (equal to the same length of values). 
+ Dictionary keyTypeValueMapping = new Dictionary(); + uint[] indices = new uint[values.Count()]; + // Start the index at 1 + uint index = 1; + for (int i = 0; i < values.Count(); ++i) + { + TValue value = values.ElementAt(i); + if (!keyTypeValueMapping.ContainsKey(value)) + { + keyTypeValueMapping.Add(value, index); + index++; + } + + var keyValue = keyTypeValueMapping[value]; + indices[i] = keyValue; + } + + dataViewBuilder.AddColumn(valueColumnName, GetKeyValueGetter(metaKeys), 0, metaKeys.Count(), indices); + } + } + else + dataViewBuilder.AddColumn(valueColumnName, valueType, values.ToArray()); + + return dataViewBuilder.GetDataView(); + } + } + + /// + /// The ValueMappingTransformer is a 1-1 mapping from a key to value. The key type and value type are specified + /// through TKey and TValue. Arrays are supported for vector types which can be used as either a key or a value + /// or both. The mapping is specified, not trained by providiing a list of keys and a list of values. + /// + /// Specifies the key type + /// Specifies the value type + public sealed class ValueMappingTransformer : ValueMappingTransformer + { + /// + /// Constructs a ValueMappingTransformer with a key type to value type. + /// + /// The environment to use. + /// The list of keys that are TKey. + /// The list of values that are TValue. + /// Specifies to treat the values as a . + /// The specified columns to apply + public ValueMappingTransformer(IHostEnvironment env, IEnumerable keys, IEnumerable values, bool treatValuesAsKeyTypes, (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ValueMappingTransformer)), + ConvertToDataView(env, keys, values, treatValuesAsKeyTypes), KeyColumnName, ValueColumnName, columns) + { } + + /// + /// Constructs a ValueMappingTransformer with a key type to value array type. + /// + /// The environment to use. + /// The list of keys that are TKey. + /// The list of values that are TValue[]. 
+ /// The specified columns to apply. + public ValueMappingTransformer(IHostEnvironment env, IEnumerable keys, IEnumerable values, (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ValueMappingTransformer)), + ConvertToDataView(env, keys, values), KeyColumnName, ValueColumnName, columns) + { } + + private static IDataView ConvertToDataView(IHostEnvironment env, IEnumerable keys, IEnumerable values, bool treatValuesAsKeyValue) + => DataViewHelper.CreateDataView(env, + keys, + values, + ValueMappingTransformer.KeyColumnName, + ValueMappingTransformer.ValueColumnName, + treatValuesAsKeyValue); + + // Handler for vector value types + private static IDataView ConvertToDataView(IHostEnvironment env, IEnumerable keys, IEnumerable values) + => DataViewHelper.CreateDataView(env, keys, values, ValueMappingTransformer.KeyColumnName, ValueMappingTransformer.ValueColumnName); + } + + public class ValueMappingTransformer : OneToOneTransformerBase + { + internal const string Summary = "Maps text values columns to new columns using a map dataset."; + internal const string LoaderSignature = "ValueMappingTransformer"; + internal const string UserName = "Value Mapping Transform"; + internal const string ShortName = "ValueMap"; + + internal const string TermLookupLoaderSignature = "TermLookupTransform"; + + // Stream names for the binary idv streams. + private const string DefaultMapName = "DefaultMap.idv"; + protected static string KeyColumnName = "Key"; + protected static string ValueColumnName = "Value"; + private ValueMap _valueMap; + private Schema.Metadata _valueMetadata; + private byte[] _dataView; + + public ColumnType ValueColumnType => _valueMap.ValueType; + public Schema.Metadata ValueColumnMetadata => _valueMetadata; + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "VALUMAPG", + verWrittenCur: 0x00010001, // Initial. 
+ verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(ValueMappingTransformer).Assembly.FullName); + } + + private static VersionInfo GetTermLookupVersionInfo() + { + return new VersionInfo( + modelSignature: "TXTLOOKT", + // verWrittenCur: 0x00010001, // Initial. + verWrittenCur: 0x00010002, // Dropped sizeof(Float). + verReadableCur: 0x00010002, + verWeCanReadBack: 0x00010002, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(ValueMappingTransformer).Assembly.FullName); + } + + public sealed class Column : OneToOneColumn + { + public static Column Parse(string str) + { + var res = new Column(); + if (res.TryParse(str)) + return res; + return null; + } + + public bool TryUnparse(StringBuilder sb) + { + Contracts.AssertValue(sb); + return TryUnparseCore(sb); + } + } + + public sealed class Arguments + { + [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)] + public Column[] Column; + + [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The data file containing the terms", ShortName = "data", SortOrder = 2)] + public string DataFile; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The name of the column containing the keys", ShortName = "keyCol, term, TermColumn")] + public string KeyColumn; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The name of the column containing the values", ShortName = "valueCol, value")] + public string ValueColumn; + + [Argument(ArgumentType.Multiple, HelpText = "The data loader", NullName = "", SignatureType = typeof(SignatureDataLoader))] + public IComponentFactory Loader; + + [Argument(ArgumentType.AtMostOnce, + HelpText = "Specifies whether the values are key values or numeric, only valid when loader is not specified and the type of data is not an idv.", + ShortName = "key")] + public bool ValuesAsKeyType = true; + } 
+ + protected ValueMappingTransformer(IHostEnvironment env, IDataView lookupMap, + string keyColumn, string valueColumn, (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ValueMappingTransformer)), columns) + { + Host.CheckNonEmpty(keyColumn, nameof(keyColumn), "A key column must be specified when passing in an IDataView for the value mapping"); + Host.CheckNonEmpty(valueColumn, nameof(valueColumn), "A value column must be specified when passing in an IDataView for the value mapping"); + _valueMap = CreateValueMapFromDataView(lookupMap, keyColumn, valueColumn); + int valueColumnIdx = 0; + Host.Assert(lookupMap.Schema.TryGetColumnIndex(valueColumn, out valueColumnIdx)); + _valueMetadata = lookupMap.Schema[valueColumnIdx].Metadata; + + // Create the byte array of the original IDataView, this is used for saving out the data. + _dataView = GetBytesFromDataView(Host, lookupMap, keyColumn, valueColumn); + } + + private ValueMap CreateValueMapFromDataView(IDataView dataView, string keyColumn, string valueColumn) + { + // Confirm that the key and value columns exist in the dataView + Host.Check(dataView.Schema.TryGetColumnIndex(keyColumn, out int keyIdx), "Key column " + keyColumn + " does not exist in the given dataview"); + Host.Check(dataView.Schema.TryGetColumnIndex(valueColumn, out int valueIdx), "Value column " + valueColumn + " does not exist in the given dataview"); + var keyType = dataView.Schema[keyIdx].Type; + var valueType = dataView.Schema[valueIdx].Type; + var valueMap = ValueMap.Create(keyType, valueType, _valueMetadata); + using (var cursor = dataView.GetRowCursor(c => c == keyIdx || c == valueIdx)) + valueMap.Train(Host, cursor); + return valueMap; + } + + private static TextLoader.Column GenerateValueColumn(IHostEnvironment env, + IDataView loader, + string valueColumnName, + int keyIdx, + int valueIdx, + string fileName) + { + // Scan the source to determine the min max of the column + ulong keyMin = 
ulong.MaxValue; + ulong keyMax = ulong.MinValue; + + // scan the input to create convert the values as key types + using (var cursor = loader.GetRowCursor(c => true)) + { + using (var ch = env.Start($"Processing key values from file {fileName}")) + { + var getKey = cursor.GetGetter>(keyIdx); + var getValue = cursor.GetGetter>(valueIdx); + int countNonKeys = 0; + + ReadOnlyMemory key = default; + ReadOnlyMemory value = default; + while (cursor.MoveNext()) + { + getKey(ref key); + getValue(ref value); + + ulong res; + // Try to parse the text as a key value between 1 and ulong.MaxValue. If this succeeds and res>0, + // we update max and min accordingly. If res==0 it means the value is missing, in which case we ignore it for + // computing max and min. + if (Microsoft.ML.Data.Conversion.Conversions.Instance.TryParseKey(in value, 1, ulong.MaxValue, out res)) + { + if (res < keyMin && res != 0) + keyMin = res; + if (res > keyMax) + keyMax = res; + } + // If parsing as key did not succeed, the value can still be 0, so we try parsing it as a ulong. If it succeeds, + // then the value is 0, and we update min accordingly. + else if (Microsoft.ML.Data.Conversion.Conversions.Instance.TryParse(in value, out res)) + { + keyMin = 0; + } + //If parsing as a ulong fails, we increment the counter for the non-key values. 
+ else + { + if (countNonKeys < 5) + ch.Warning($"Key '{key}' in mapping file is mapped to non key value '{value}'"); + countNonKeys++; + } + } + + if (countNonKeys > 0) + ch.Warning($"Found {countNonKeys} non key values in the file '{fileName}'"); + if (keyMin > keyMax) + { + keyMin = 0; + keyMax = uint.MaxValue - 1; + ch.Warning($"Did not find any valid key values in the file '{fileName}'"); + } + else + ch.Info($"Found key values in the range {keyMin} to {keyMax} in the file '{fileName}'"); + } + } + + TextLoader.Column valueColumn = new TextLoader.Column(valueColumnName, DataKind.U4, 1); + if (keyMax - keyMin < (ulong)int.MaxValue) + { + valueColumn.KeyRange = new KeyRange(keyMin, keyMax); + } + else if (keyMax - keyMin < (ulong)uint.MaxValue) + { + valueColumn.KeyRange = new KeyRange(keyMin); + } + else + { + valueColumn.Type = DataKind.U8; + valueColumn.KeyRange = new KeyRange(keyMin); + } + + return valueColumn; + } + + private static ValueMappingTransformer CreateTransformInvoke(IHostEnvironment env, + IDataView idv, + string keyColumnName, + string valueColumnName, + bool treatValuesAsKeyTypes, + (string input, string output)[] columns) + { + // Read in the data + // scan the input to create convert the values as key types + List keys = new List(); + List values = new List(); + + idv.Schema.TryGetColumnIndex(keyColumnName, out int keyIdx); + idv.Schema.TryGetColumnIndex(valueColumnName, out int valueIdx); + using (var cursor = idv.GetRowCursor(c => true)) + { + using (var ch = env.Start("Processing key values")) + { + TKey key = default; + TValue value = default; + var getKey = cursor.GetGetter(keyIdx); + var getValue = cursor.GetGetter(valueIdx); + while (cursor.MoveNext()) + { + try + { + getKey(ref key); + } + catch (InvalidOperationException) + { + ch.Warning("Invalid key parsed, row will be skipped."); + continue; + } + + try + { + getValue(ref value); + } + catch (InvalidOperationException) + { + ch.Warning("Invalid value parsed for key {key}, row 
will be skipped."); + continue; + } + + keys.Add(key); + values.Add(value); + } + } + } + + return new ValueMappingTransformer(env, keys, values, treatValuesAsKeyTypes, columns); + } + + private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.Assert(!string.IsNullOrWhiteSpace(args.DataFile)); + env.CheckValueOrNull(args.KeyColumn); + env.CheckValueOrNull(args.ValueColumn); + + var keyColumnName = (string.IsNullOrEmpty(args.KeyColumn)) ? KeyColumnName : args.KeyColumn; + var valueColumnName = (string.IsNullOrEmpty(args.ValueColumn)) ? ValueColumnName : args.ValueColumn; + + IMultiStreamSource fileSource = new MultiFileSource(args.DataFile); + IDataView loader; + if (args.Loader != null) + { + loader = args.Loader.CreateComponent(env, fileSource); + } + else + { + var extension = Path.GetExtension(args.DataFile); + if (extension.Equals(".idv", StringComparison.OrdinalIgnoreCase)) + loader = new BinaryLoader(env, new BinaryLoader.Arguments(), fileSource); + else if (extension.Equals(".tdv")) + loader = new TransposeLoader(env, new TransposeLoader.Arguments(), fileSource); + else + { + // The user has not specified how to load this file. This will attempt to load the + // data file as two text columns. If the user has also specified ValuesAsKeyTypes, + // this will default to the key column as a text column and the value column as a uint column + + // Set the keyColumnName and valueColumnName to the default values. + keyColumnName = KeyColumnName; + valueColumnName = ValueColumnName; + TextLoader.Column keyColumn = default; + TextLoader.Column valueColumn = default; + + // Default to a text loader. KeyType and ValueType are assumed to be string + // types unless ValueAsKeyType is specified. 
+ if (args.ValuesAsKeyType) + { + keyColumn = new TextLoader.Column(keyColumnName, DataKind.TXT, 0); + valueColumn = new TextLoader.Column(valueColumnName, DataKind.TXT, 1); + var txtArgs = new TextLoader.Arguments() + { + Column = new TextLoader.Column[] + { + keyColumn, + valueColumn + } + }; + + try + { + var textLoader = TextLoader.ReadFile(env, txtArgs, fileSource); + valueColumn = GenerateValueColumn(env, textLoader, valueColumnName, 0, 1, args.DataFile); + } + catch (Exception ex) + { + throw env.Except(ex, "Failed to parse the lookup file '{args.DataFile}' in ValueMappingTransformerer"); + } + } + else + { + keyColumn = new TextLoader.Column(keyColumnName, DataKind.TXT, 0); + valueColumn = new TextLoader.Column(valueColumnName, DataKind.R4, 1); + } + + loader = TextLoader.Create( + env, + new TextLoader.Arguments() + { + Column = new TextLoader.Column[] + { + keyColumn, + valueColumn + } + }, + fileSource); + } + } + + env.AssertValue(loader); + env.Assert(loader.Schema.TryGetColumnIndex(keyColumnName, out int keyColumnIndex)); + env.Assert(loader.Schema.TryGetColumnIndex(valueColumnName, out int valueColumnIndex)); + + ValueMappingTransformer transformer = null; + (string Source, string Name)[] columns = args.Column.Select(x => (x.Source, x.Name)).ToArray(); + transformer = new ValueMappingTransformer(env, loader, keyColumnName, valueColumnName, columns); + return transformer.MakeDataTransform(input); + } + + /// + /// Helper function to determine the model version that is being loaded. 
+ /// + private static bool CheckModelVersion(ModelLoadContext ctx, VersionInfo versionInfo) + { + try + { + ctx.CheckVersionInfo(versionInfo); + return true; + } + catch (Exception) + { + //consume + return false; + } + } + + protected static ValueMappingTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + + // Checks for both the TermLookup for backwards compatibility + var termLookupModel = CheckModelVersion(ctx, GetTermLookupVersionInfo()); + env.Check(termLookupModel || CheckModelVersion(ctx, GetVersionInfo())); + + // *** Binary format *** + // int: number of added columns + // for each added column + // string: output column name + // string: input column name + // Binary stream of mapping + + var length = ctx.Reader.ReadInt32(); + var columns = new (string Source, string Name)[length]; + for (int i = 0; i < length; i++) + { + columns[i].Name = ctx.LoadNonEmptyString(); + columns[i].Source = ctx.LoadNonEmptyString(); + } + + byte[] rgb = null; + Action fn = r => rgb = ReadAllBytes(env, r); + + if (!ctx.TryLoadBinaryStream(DefaultMapName, fn)) + throw env.ExceptDecode(); + + var binaryLoader = GetLoader(env, rgb); + var keyColumnName = (termLookupModel) ? 
"Term" : KeyColumnName; + return new ValueMappingTransformer(env, binaryLoader, keyColumnName, ValueColumnName, columns); + } + + private static byte[] ReadAllBytes(IExceptionContext ectx, BinaryReader rdr) + { + Contracts.AssertValue(ectx); + ectx.AssertValue(rdr); + ectx.Assert(rdr.BaseStream.CanSeek); + + long size = rdr.BaseStream.Length; + ectx.CheckDecode(size <= int.MaxValue); + + var rgb = new byte[(int)size]; + int cb = rdr.Read(rgb, 0, rgb.Length); + ectx.CheckDecode(cb == rgb.Length); + + return rgb; + } + + protected static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); + + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + + protected static PrimitiveType GetPrimitiveType(Type rawType, out bool isVectorType) + { + Type type = rawType; + isVectorType = false; + if (type.IsArray) + { + type = rawType.GetElementType(); + isVectorType = true; + } + + if (!type.TryGetDataKind(out DataKind kind)) + throw Contracts.Except($"Unsupported type {type} used in mapping."); + + return PrimitiveType.FromKind(kind); + } + + public override void Save(ModelSaveContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + ctx.SetVersionInfo(GetVersionInfo()); + SaveColumns(ctx); + + // Save out the byte stream of the IDataView of the data source + ctx.SaveBinaryStream(DefaultMapName, w => w.Write(_dataView)); + } + + /// + /// Base class that contains the mapping of keys to values. 
+ /// + private abstract class ValueMap + { + public readonly ColumnType KeyType; + public readonly ColumnType ValueType; + + public ValueMap(ColumnType keyType, ColumnType valueType) + { + KeyType = keyType; + ValueType = valueType; + } + + public static ValueMap Create(ColumnType keyType, ColumnType valueType, Schema.Metadata valueMetadata) + { + Func del = CreateValueMapInvoke; + var meth = del.Method.GetGenericMethodDefinition().MakeGenericMethod(keyType.RawType, valueType.RawType); + return (ValueMap)meth.Invoke(null, new object[] { keyType, valueType, valueMetadata }); + } + + private static ValueMap CreateValueMapInvoke(ColumnType keyType, + ColumnType valueType, + Schema.Metadata valueMetadata) + => new ValueMap(keyType, valueType, valueMetadata); + + public abstract void Train(IHostEnvironment env, RowCursor cursor); + + public abstract Delegate GetGetter(Row input, int index); + + public abstract IDataView GetDataView(IHostEnvironment env); + } + + /// + /// Implementation mapping class that maps a key of TKey to a specified value of TValue. + /// + private class ValueMap : ValueMap + { + private Dictionary _mapping; + private TValue _missingValue; + private Schema.Metadata _valueMetadata; + + private Dictionary CreateDictionary() + { + if (typeof(TKey) == typeof(ReadOnlyMemory)) + return new Dictionary, TValue>(new ReadOnlyMemoryUtils.ReadonlyMemoryCharComparer()) as Dictionary; + return new Dictionary(); + } + + public ValueMap(ColumnType keyType, ColumnType valueType, Schema.Metadata valueMetadata) + : base(keyType, valueType) + { + _mapping = CreateDictionary(); + _valueMetadata = valueMetadata; + } + + /// + /// Generates the mapping based on the IDataView + /// + public override void Train(IHostEnvironment env, RowCursor cursor) + { + // Validate that the conversion is supported for non-vector types + bool identity; + ValueMapper, TValue> conv; + + // For keys that are not in the mapping, the missingValue will be returned. 
+ _missingValue = default; + if (!ValueType.IsVector) + { + // For handling missing values, this follows how a missing value is handled when loading from a text source. + // First check if there is a String->ValueType conversion method. If so, call the conversion method with an + // empty string, the returned value will be the new missing value. + // NOTE this will return NA for R4 and R8 types. + if (Microsoft.ML.Data.Conversion.Conversions.Instance.TryGetStandardConversion, TValue>( + TextType.Instance, + ValueType, + out conv, + out identity)) + { + TValue value = default; + conv(string.Empty.AsMemory(), ref value); + _missingValue = value; + } + } + + var keyGetter = cursor.GetGetter(0); + var valueGetter = cursor.GetGetter(1); + while (cursor.MoveNext()) + { + TKey key = default; + TValue value = default; + keyGetter(ref key); + valueGetter(ref value); + if (_mapping.ContainsKey(key)) + throw env.Except($"Duplicate keys in data '{key}'"); + + _mapping.Add(key, value); + } + } + + public override Delegate GetGetter(Row input, int index) + { + var src = default(TKey); + ValueGetter getSrc = input.GetGetter(index); + ValueGetter retVal = + (ref TValue dst) => + { + getSrc(ref src); + if (_mapping.ContainsKey(src)) + { + if (ValueType.IsVector) + dst = Utils.MarshalInvoke(GetVector, ValueType.ItemType.RawType, _mapping[src]); + else + dst = Utils.MarshalInvoke(GetValue, ValueType.RawType, _mapping[src]); + } + else + dst = _missingValue; + }; + return retVal; + } + + public override IDataView GetDataView(IHostEnvironment env) + => DataViewHelper.CreateDataView(env, + _mapping.Keys, + _mapping.Values, + ValueMappingTransformer.KeyColumnName, + ValueMappingTransformer.ValueColumnName, + ValueType.IsKey); + + private static TValue GetVector(TValue value) + { + if (value is VBuffer valueRef) + { + VBuffer dest = default; + valueRef.CopyTo(ref dest); + if (dest is TValue destRef) + return destRef; + } + + return default; + } + + private static TValue GetValue(TValue 
value) => value; + } + + /// + /// Retrieves the byte array given a dataview and columns + /// + private static byte[] GetBytesFromDataView(IHost host, IDataView lookup, string keyColumn, string valueColumn) + { + Contracts.AssertValue(host); + host.AssertValue(lookup); + host.AssertNonEmpty(keyColumn); + host.AssertNonEmpty(valueColumn); + + var schema = lookup.Schema; + + if (!schema.GetColumnOrNull(keyColumn).HasValue) + throw host.ExceptUserArg(nameof(Arguments.KeyColumn), $"Key column not found: '{keyColumn}'"); + if (!schema.GetColumnOrNull(valueColumn).HasValue) + throw host.ExceptUserArg(nameof(Arguments.ValueColumn), $"Value column not found: '{valueColumn}'"); + + var cols = new List<(string Source, string Name)>() + { + (keyColumn, KeyColumnName), + (valueColumn, ValueColumnName) + }; + + var view = new ColumnCopyingTransformer(host, cols.ToArray()).Transform(lookup); + view = ColumnSelectingTransformer.CreateKeep(host, view, cols.Select(x => x.Name).ToArray()); + + var saver = new BinarySaver(host, new BinarySaver.Arguments()); + using (var strm = new MemoryStream()) + { + saver.SaveData(strm, view, 0, 1); + return strm.ToArray(); + } + } + + private static BinaryLoader GetLoader(IHostEnvironment env, byte[] bytes) + { + env.AssertValue(env); + env.AssertValue(bytes); + + var strm = new MemoryStream(bytes, writable: false); + return new BinaryLoader(env, new BinaryLoader.Arguments(), strm); + } + + private protected override IRowMapper MakeRowMapper(Schema schema) + { + return new Mapper(this, schema, _valueMap, _valueMetadata, ColumnPairs); + } + + private sealed class Mapper : OneToOneMapperBase + { + private readonly Schema _inputSchema; + private readonly ValueMap _valueMap; + private readonly Schema.Metadata _valueMetadata; + private readonly (string Source, string Name)[] _columns; + private readonly ValueMappingTransformer _parent; + + internal Mapper(ValueMappingTransformer transform, + Schema inputSchema, + ValueMap valueMap, + Schema.Metadata 
valueMetadata, + (string input, string output)[] columns) + : base(transform.Host.Register(nameof(Mapper)), transform, inputSchema) + { + _inputSchema = inputSchema; + _valueMetadata = valueMetadata; + _valueMap = valueMap; + _columns = columns; + _parent = transform; + } + + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) + { + Host.AssertValue(input); + Host.Assert(0 <= iinfo && iinfo < _columns.Length); + disposer = null; + + return _valueMap.GetGetter(input, ColMapNewToOld[iinfo]); + } + + protected override Schema.DetachedColumn[] GetOutputColumnsCore() + { + var result = new Schema.DetachedColumn[_columns.Length]; + for (int i = 0; i < _columns.Length; i++) + { + var srcCol = _inputSchema[_columns[i].Source]; + result[i] = new Schema.DetachedColumn(_columns[i].Name, _valueMap.ValueType, _valueMetadata); + } + return result; + } + } + } +} diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs index 4d01a62541..a66c731820 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs @@ -2,15 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Linq; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.StaticPipe; -using Microsoft.ML.StaticPipe.Runtime; -using System; -using System.Collections.Generic; -using System.Linq; namespace Microsoft.ML.Transforms.Conversions { @@ -36,8 +30,8 @@ public static class Defaults /// Name of the column to be transformed. /// Name of the output column. If this is null '' will be used. /// Maximum number of keys to keep per column when auto-training. - /// How items should be ordered when vectorized. 
By default, they will be in the order encountered. - /// If by value items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a'). + /// How items should be ordered when vectorized. If chosen they will be in the order encountered. + /// If , items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a'). public ValueToKeyMappingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, int maxNumTerms = Defaults.MaxNumTerms, ValueToKeyMappingTransformer.SortOrder sort = Defaults.Sort) : this(env, new [] { new ValueToKeyMappingTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, maxNumTerms, sort) }) { @@ -60,7 +54,7 @@ public ValueToKeyMappingEstimator(IHostEnvironment env, ValueToKeyMappingTransfo public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) @@ -77,7 +71,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) kv = new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, colInfo.TextKeyValues ? TextType.Instance : col.ItemType, col.IsKey); } - Contracts.AssertValue(kv); + Contracts.Assert(kv.IsValid); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) metadata = new SchemaShape(new[] { slotMeta, kv }); @@ -118,108 +112,9 @@ public sealed class ToKeyFitResult // At the moment this is empty. Once PR #863 clears, we can change this class to hold the output // key-values metadata.
- public ToKeyFitResult(ValueToKeyMappingTransformer.TermMap map) - { - } - } - - public static partial class TermStaticExtensions - { - // I am not certain I see a good way to cover the distinct types beyond complete enumeration. - // Raw generics would allow illegal possible inputs, for example, Scalar. So, this is a partial - // class, and all the public facing extension methods for each possible type are in a T4 generated result. - - private const KeyValueOrder DefSort = (KeyValueOrder)ValueToKeyMappingEstimator.Defaults.Sort; - private const int DefMax = ValueToKeyMappingEstimator.Defaults.MaxNumTerms; - - private readonly struct Config - { - public readonly KeyValueOrder Order; - public readonly int Max; - public readonly Action OnFit; - - public Config(KeyValueOrder order, int max, Action onFit) - { - Order = order; - Max = max; - OnFit = onFit; - } - } - - private static Action Wrap(ToKeyFitResult.OnFit onFit) - { - if (onFit == null) - return null; - // The type T asociated with the delegate will be the actual value type once #863 goes in. - // However, until such time as #863 goes in, it would be too awkward to attempt to extract the metadata. - // For now construct the useless object then pass it into the delegate. 
- return map => onFit(new ToKeyFitResult(map)); - } - - private interface ITermCol - { - PipelineColumn Input { get; } - Config Config { get; } - } - - private sealed class ImplScalar : Key, ITermCol + [BestFriend] + internal ToKeyFitResult(ValueToKeyMappingTransformer.TermMap map) { - public PipelineColumn Input { get; } - public Config Config { get; } - public ImplScalar(PipelineColumn input, Config config) : base(Rec.Inst, input) - { - Input = input; - Config = config; - } - } - - private sealed class ImplVector : Vector>, ITermCol - { - public PipelineColumn Input { get; } - public Config Config { get; } - public ImplVector(PipelineColumn input, Config config) : base(Rec.Inst, input) - { - Input = input; - Config = config; - } - } - - private sealed class ImplVarVector : VarVector>, ITermCol - { - public PipelineColumn Input { get; } - public Config Config { get; } - public ImplVarVector(PipelineColumn input, Config config) : base(Rec.Inst, input) - { - Input = input; - Config = config; - } - } - - private sealed class Rec : EstimatorReconciler - { - public static readonly Rec Inst = new Rec(); - - public override IEstimator Reconcile(IHostEnvironment env, PipelineColumn[] toOutput, - IReadOnlyDictionary inputNames, IReadOnlyDictionary outputNames, IReadOnlyCollection usedNames) - { - var infos = new ValueToKeyMappingTransformer.ColumnInfo[toOutput.Length]; - Action onFit = null; - for (int i = 0; i < toOutput.Length; ++i) - { - var tcol = (ITermCol)toOutput[i]; - infos[i] = new ValueToKeyMappingTransformer.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], - tcol.Config.Max, (ValueToKeyMappingTransformer.SortOrder)tcol.Config.Order); - if (tcol.Config.OnFit != null) - { - int ii = i; // Necessary because if we capture i that will change to toOutput.Length on call. 
- onFit += tt => tcol.Config.OnFit(tt.GetTermMap(ii)); - } - } - var est = new ValueToKeyMappingEstimator(env, infos); - if (onFit == null) - return est; - return est.WithOnFitDelegate(onFit); - } } } } diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs index 15b10c3e74..e35571c320 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs @@ -2,24 +2,23 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Model.Pfa; -using Microsoft.ML.Transforms.Conversions; -using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.IO; using System.Linq; using System.Text; using System.Threading; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Data.IO; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Model.Pfa; +using Microsoft.ML.Transforms.Conversions; +using Newtonsoft.Json.Linq; [assembly: LoadableClass(ValueToKeyMappingTransformer.Summary, typeof(IDataTransform), typeof(ValueToKeyMappingTransformer), typeof(ValueToKeyMappingTransformer.Arguments), typeof(SignatureDataTransform), @@ -158,18 +157,11 @@ public ColInfo(string name, string source, ColumnType type) } } + /// + /// Describes how the transformer handles one column pair. 
+ /// public class ColumnInfo { - public ColumnInfo(string input, string output, int maxNumTerms = ValueToKeyMappingEstimator.Defaults.MaxNumTerms, SortOrder sort = ValueToKeyMappingEstimator.Defaults.Sort, string[] term = null, bool textKeyValues = false) - { - Input = input; - Output = output; - Sort = sort; - MaxNumTerms = maxNumTerms; - Term = term; - TextKeyValues = textKeyValues; - } - public readonly string Input; public readonly string Output; public readonly SortOrder Sort; @@ -178,13 +170,41 @@ public ColumnInfo(string input, string output, int maxNumTerms = ValueToKeyMappi public readonly bool TextKeyValues; protected internal string Terms { get; set; } - } - - public const string Summary = "Converts input values (words, numbers, etc.) to index in a dictionary."; - public const string UserName = "Term Transform"; - public const string LoaderSignature = "TermTransform"; - public const string FriendlyName = "To Key"; + /// + /// Describes how the transformer handles one column pair. + /// + /// Name of input column. + /// Name of the column resulting from the transformation of . Null means is replaced. + /// Maximum number of terms to keep per column when auto-training. + /// How items should be ordered when vectorized. If chosen they will be in the order encountered. + /// If , items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a'). + /// List of terms. + /// Whether key value metadata should be text, regardless of the actual input type. + public ColumnInfo(string input, string output = null, + int maxNumTerms = ValueToKeyMappingEstimator.Defaults.MaxNumTerms, + SortOrder sort = ValueToKeyMappingEstimator.Defaults.Sort, + string[] term = null, + bool textKeyValues = false + ) + { + Contracts.CheckNonWhiteSpace(input, nameof(input)); + Input = input; + Output = output ?? 
input; + Sort = sort; + MaxNumTerms = maxNumTerms; + Term = term; + TextKeyValues = textKeyValues; + } + } + [BestFriend] + internal const string Summary = "Converts input values (words, numbers, etc.) to index in a dictionary."; + [BestFriend] + internal const string UserName = "Term Transform"; + [BestFriend] + internal const string LoaderSignature = "TermTransform"; + [BestFriend] + internal const string FriendlyName = "To Key"; private static VersionInfo GetVersionInfo() { @@ -202,7 +222,7 @@ private static VersionInfo GetVersionInfo() private const uint VerNonTextTypesSupported = 0x00010003; private const uint VerManagerNonTextTypesSupported = 0x00010002; - public const string TermManagerLoaderSignature = "TermManager"; + internal const string TermManagerLoaderSignature = "TermManager"; private static volatile MemoryStreamPool _codecFactoryPool; private volatile CodecFactory _codecFactory; @@ -248,7 +268,7 @@ internal string TestIsKnownDataKind(ColumnType type) return "standard type or a vector of standard type"; } - private ColInfo[] CreateInfos(ISchema inputSchema) + private ColInfo[] CreateInfos(Schema inputSchema) { Host.AssertValue(inputSchema); var infos = new ColInfo[ColumnPairs.Length]; @@ -256,7 +276,7 @@ private ColInfo[] CreateInfos(ISchema inputSchema) { if (!inputSchema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input); - var type = inputSchema.GetColumnType(colSrc); + var type = inputSchema[colSrc].Type; string reason = TestIsKnownDataKind(type); if (reason != null) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input, reason, type.ToString()); @@ -289,8 +309,9 @@ internal ValueToKeyMappingTransformer(IHostEnvironment env, IDataView input, } } + [BestFriend] // Factory method for SignatureDataTransform. 
- public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(args, nameof(args)); @@ -388,49 +409,8 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, => Create(env, ctx).MakeDataTransform(input); // Factory method for SignatureLoadRowMapper. - private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); - - /// - /// Initializes a new instance of . - /// - /// Host Environment. - /// Input . This is the output from previous transform or loader. - /// Name of the output column. - /// Name of the column to be transformed. If this is null '' will be used. - /// Maximum number of terms to keep per column when auto-training. - /// How items should be ordered when vectorized. By default, they will be in the order encountered. - /// If by value items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a'). - public static IDataView Create(IHostEnvironment env, - IDataView input, string name, string source = null, - int maxNumTerms = ValueToKeyMappingEstimator.Defaults.MaxNumTerms, SortOrder sort = ValueToKeyMappingEstimator.Defaults.Sort) => - new ValueToKeyMappingTransformer(env, input, new[] { new ColumnInfo(source ?? 
name, name, maxNumTerms, sort) }).MakeDataTransform(input); - - public static IDataTransform Create(IHostEnvironment env, ArgumentsBase args, ColumnBase[] column, IDataView input) - { - return Create(env, new Arguments() - { - Column = column.Select(x => new Column() - { - MaxNumTerms = x.MaxNumTerms, - Name = x.Name, - Sort = x.Sort, - Source = x.Source, - Term = x.Term, - Terms = x.Terms, - TextKeyValues = x.TextKeyValues - }).ToArray(), - Data = args.Data, - DataFile = args.DataFile, - Loader = args.Loader, - MaxNumTerms = args.MaxNumTerms, - Sort = args.Sort, - Term = args.Term, - Terms = args.Terms, - TermsColumn = args.TermsColumn, - TextKeyValues = args.TextKeyValues - }, input); - } + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); /// /// Utility method to create the file-based . @@ -483,13 +463,10 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, stri "{0} should not be specified when default loader is TextLoader. 
Ignoring {0}={1}", nameof(Arguments.TermsColumn), src); } - termData = TextLoader.ReadFile(env, - new TextLoader.Arguments() - { - Separator = "tab", - Column = new[] { new TextLoader.Column("Term", DataKind.TX, 0) } - }, - fileSource); + termData = new TextLoader(env, + columns: new[] { new TextLoader.Column("Term", DataKind.TX, 0) }, + dataSample: fileSource) + .Read(fileSource); src = "Term"; autoConvert = true; } @@ -499,7 +476,7 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, stri int colSrc; if (!termData.Schema.TryGetColumnIndex(src, out colSrc)) throw ch.ExceptUserArg(nameof(termsColumn), "Unknown column '{0}'", src); - var typeSrc = termData.Schema.GetColumnType(colSrc); + var typeSrc = termData.Schema[colSrc].Type; if (!autoConvert && !typeSrc.Equals(bldr.ItemType)) throw ch.ExceptUserArg(nameof(termsColumn), "Must be of type '{0}' but was '{1}'", bldr.ItemType, typeSrc); @@ -707,13 +684,14 @@ public override void Save(ModelSaveContext ctx) }); } - public TermMap GetTermMap(int iinfo) + [BestFriend] + internal TermMap GetTermMap(int iinfo) { Contracts.Assert(0 <= iinfo && iinfo < _unboundMaps.Length); return _unboundMaps[iinfo]; } - protected override IRowMapper MakeRowMapper(Schema schema) + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx, ISaveAsPfa @@ -721,7 +699,6 @@ private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx, ISaveAsPfa private readonly ColumnType[] _types; private readonly ValueToKeyMappingTransformer _parent; private readonly ColInfo[] _infos; - private readonly BoundTermMap[] _termMap; public bool CanSaveOnnx(OnnxContext ctx) => true; @@ -739,8 +716,8 @@ public Mapper(ValueToKeyMappingTransformer parent, Schema inputSchema) var type = _infos[i].TypeSrc; KeyType keyType = _parent._unboundMaps[i].OutputType; ColumnType colType; - if (type.IsVector) - colType = new VectorType(keyType, 
type.AsVector); + if (type is VectorType vectorType) + colType = new VectorType(keyType, vectorType); else colType = keyType; _types[i] = colType; @@ -768,7 +745,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() return result; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -777,7 +754,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac return Utils.MarshalInvoke(MakeGetter, type.RawType, input, iinfo); } - private Delegate MakeGetter(IRow row, int src) => _termMap[src].GetMappingGetter(row); + private Delegate MakeGetter(Row row, int src) => _termMap[src].GetMappingGetter(row); private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) { diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs index c136bda9c6..624e0e1c89 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs @@ -2,15 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; using System; using System.IO; using System.Text; +using Microsoft.ML.Data; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; namespace Microsoft.ML.Transforms.Conversions { @@ -49,7 +47,7 @@ public static Builder Create(ColumnType type, SortOrder sortOrder) Contracts.Assert(sortOrder == SortOrder.Occurrence || sortOrder == SortOrder.Value); bool sorted = sortOrder == SortOrder.Value; - PrimitiveType itemType = type.ItemType.AsPrimitive; + PrimitiveType itemType = type.ItemType as PrimitiveType; Contracts.AssertValue(itemType); if (itemType.IsText) return new TextImpl(sorted); @@ -66,7 +64,7 @@ private static Builder CreateCore(PrimitiveType type, bool sorted) // of building our term dictionary. For the other types (practically, only the UX types), // we should ignore nothing. 
InPredicate mapsToMissing; - if (!Runtime.Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(type, out mapsToMissing)) + if (!Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(type, out mapsToMissing)) mapsToMissing = (in T val) => false; return new Impl(type, mapsToMissing, sorted); } @@ -206,7 +204,7 @@ protected Builder(PrimitiveType type) public override void ParseAddTermArg(ref ReadOnlyMemory terms, IChannel ch) { T val; - var tryParse = Runtime.Data.Conversion.Conversions.Instance.GetTryParseConversion(ItemType); + var tryParse = Data.Conversion.Conversions.Instance.GetTryParseConversion(ItemType); for (bool more = true; more;) { ReadOnlyMemory term; @@ -232,7 +230,7 @@ public override void ParseAddTermArg(ref ReadOnlyMemory terms, IChannel ch public override void ParseAddTermArg(string[] terms, IChannel ch) { T val; - var tryParse = Runtime.Data.Conversion.Conversions.Instance.GetTryParseConversion(ItemType); + var tryParse = Data.Conversion.Conversions.Instance.GetTryParseConversion(ItemType); foreach (var sterm in terms) { ReadOnlyMemory term = sterm.AsMemory(); @@ -280,15 +278,15 @@ private Trainer(Builder bldr, int max) /// the input type to the desired type /// The builder we add items to /// An associated training pipe - public static Trainer Create(IRow row, int col, bool autoConvert, int count, Builder bldr) + public static Trainer Create(Row row, int col, bool autoConvert, int count, Builder bldr) { Contracts.AssertValue(row); var schema = row.Schema; - Contracts.Assert(0 <= col && col < schema.ColumnCount); + Contracts.Assert(0 <= col && col < schema.Count); Contracts.Assert(count > 0); Contracts.AssertValue(bldr); - var type = schema.GetColumnType(col); + var type = schema[col].Type; Contracts.Assert(autoConvert || bldr.ItemType == type.ItemType); // Auto conversion should only be possible when the type is text. 
Contracts.Assert(type.IsText || !autoConvert); @@ -297,7 +295,7 @@ public static Trainer Create(IRow row, int col, bool autoConvert, int count, Bui return Utils.MarshalInvoke(CreateOne, bldr.ItemType.RawType, row, col, autoConvert, count, bldr); } - private static Trainer CreateOne(IRow row, int col, bool autoConvert, int count, Builder bldr) + private static Trainer CreateOne(Row row, int col, bool autoConvert, int count, Builder bldr) { Contracts.AssertValue(row); Contracts.AssertValue(bldr); @@ -313,7 +311,7 @@ private static Trainer CreateOne(IRow row, int col, bool autoConvert, int cou return new ImplOne(inputGetter, count, bldrT); } - private static Trainer CreateVec(IRow row, int col, int count, Builder bldr) + private static Trainer CreateVec(Row row, int col, int count, Builder bldr) { Contracts.AssertValue(row); Contracts.AssertValue(bldr); @@ -474,7 +472,8 @@ private static BoundTermMap Bind(IHostEnvironment env, Schema schema, TermMap un /// These are the immutable and serializable analogs to the used in /// training. 
/// - public abstract class TermMap + [BestFriend] + internal abstract class TermMap { /// /// The item type of the input type, that is, either the input type or, @@ -568,7 +567,7 @@ private static TermMap LoadCodecCore(ModelLoadContext ctx, IExceptionContext } } - return new HashArrayImpl(codec.Type.AsPrimitive, values); + return new HashArrayImpl((PrimitiveType)codec.Type, values); } internal abstract void WriteTextTerms(TextWriter writer); @@ -746,7 +745,7 @@ internal override void WriteTextTerms(TextWriter writer) { writer.WriteLine("# Number of terms of type '{0}' = {1}", ItemType, Count); StringBuilder sb = null; - var stringMapper = Runtime.Data.Conversion.Conversions.Instance.GetStringConversion(ItemType); + var stringMapper = Data.Conversion.Conversions.Instance.GetStringConversion(ItemType); for (int i = 0; i < _values.Count; ++i) { T val = _values.GetItem(i); @@ -757,7 +756,7 @@ internal override void WriteTextTerms(TextWriter writer) } } - public abstract class TermMap : TermMap + internal abstract class TermMap : TermMap { protected TermMap(PrimitiveType type, int count) : base(type, count) @@ -849,7 +848,7 @@ public static BoundTermMap CreateCore(IHostEnvironment env, Schema schema, Te return new Impl(env, schema, mapT, infos, textMetadata, iinfo); } - public abstract Delegate GetMappingGetter(IRow row); + public abstract Delegate GetMappingGetter(Row row); /// /// Allows us to optionally register metadata. It is also perfectly legal for @@ -890,7 +889,7 @@ private static uint MapDefault(ValueMapper map) return dst; } - public override Delegate GetMappingGetter(IRow input) + public override Delegate GetMappingGetter(Row input) { // When constructing the getter, there are a few cases we have to consider: // If scalar then it's just a straightforward mapping. 
@@ -1044,7 +1043,7 @@ public override void AddMetadata(MetadataBuilder builder) return; if (IsTextMetadata && !TypedMap.ItemType.IsText) { - var conv = Runtime.Data.Conversion.Conversions.Instance; + var conv = Data.Conversion.Conversions.Instance; var stringMapper = conv.GetStringConversion(TypedMap.ItemType); ValueGetter>> getter = @@ -1087,7 +1086,7 @@ public override void AddMetadata(MetadataBuilder builder) return; _schema.TryGetColumnIndex(_infos[_iinfo].Source, out int srcCol); - ColumnType srcMetaType = _schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, srcCol); + ColumnType srcMetaType = _schema[srcCol].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; if (srcMetaType == null || srcMetaType.VectorSize != TypedMap.ItemType.KeyCount || TypedMap.ItemType.KeyCount == 0 || !Utils.MarshalInvoke(AddMetadataCore, srcMetaType.ItemType.RawType, srcMetaType.ItemType, builder)) { @@ -1101,10 +1100,10 @@ private bool AddMetadataCore(ColumnType srcMetaType, MetadataBuilder buil _host.AssertValue(srcMetaType); _host.Assert(srcMetaType.RawType == typeof(TMeta)); _host.AssertValue(builder); - var srcType = TypedMap.ItemType.AsKey; + var srcType = TypedMap.ItemType as KeyType; _host.AssertValue(srcType); var dstType = new KeyType(DataKind.U4, srcType.Min, srcType.Count); - var convInst = Runtime.Data.Conversion.Conversions.Instance; + var convInst = Data.Conversion.Conversions.Instance; ValueMapper conv; bool identity; // If we can't convert this type to U4, don't try to pass along the metadata. 
@@ -1116,7 +1115,7 @@ private bool AddMetadataCore(ColumnType srcMetaType, MetadataBuilder buil (ref VBuffer dst) => { VBuffer srcMeta = default(VBuffer); - _schema.GetMetadata(MetadataUtils.Kinds.KeyValues, srcCol, ref srcMeta); + _schema[srcCol].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref srcMeta); _host.Assert(srcMeta.Length == srcType.Count); VBuffer keyVals = default(VBuffer); @@ -1156,7 +1155,7 @@ private bool AddMetadataCore(ColumnType srcMetaType, MetadataBuilder buil getter(ref dst); _host.Assert(dst.Length == TypedMap.OutputType.KeyCount); }; - builder.AddKeyValues(TypedMap.OutputType.KeyCount, srcMetaType.ItemType.AsPrimitive, mgetter); + builder.AddKeyValues(TypedMap.OutputType.KeyCount, (PrimitiveType)srcMetaType.ItemType, mgetter); } return true; } @@ -1167,9 +1166,9 @@ public override void WriteTextTerms(TextWriter writer) return; _schema.TryGetColumnIndex(_infos[_iinfo].Source, out int srcCol); - ColumnType srcMetaType = _schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, srcCol); + ColumnType srcMetaType = _schema[srcCol].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; if (srcMetaType == null || srcMetaType.VectorSize != TypedMap.ItemType.KeyCount || - TypedMap.ItemType.KeyCount == 0 || !Utils.MarshalInvoke(WriteTextTermsCore, srcMetaType.ItemType.RawType, srcMetaType.AsVector.ItemType, writer)) + TypedMap.ItemType.KeyCount == 0 || !Utils.MarshalInvoke(WriteTextTermsCore, srcMetaType.ItemType.RawType, ((VectorType)srcMetaType).ItemType, writer)) { // No valid input key-value metadata. Back off to the base implementation. 
base.WriteTextTerms(writer); @@ -1180,10 +1179,10 @@ private bool WriteTextTermsCore(PrimitiveType srcMetaType, TextWriter wri { _host.AssertValue(srcMetaType); _host.Assert(srcMetaType.RawType == typeof(TMeta)); - var srcType = TypedMap.ItemType.AsKey; + var srcType = TypedMap.ItemType as KeyType; _host.AssertValue(srcType); var dstType = new KeyType(DataKind.U4, srcType.Min, srcType.Count); - var convInst = Runtime.Data.Conversion.Conversions.Instance; + var convInst = Data.Conversion.Conversions.Instance; ValueMapper conv; bool identity; // If we can't convert this type to U4, don't try. @@ -1192,7 +1191,7 @@ private bool WriteTextTermsCore(PrimitiveType srcMetaType, TextWriter wri _schema.TryGetColumnIndex(_infos[_iinfo].Source, out int srcCol); VBuffer srcMeta = default(VBuffer); - _schema.GetMetadata(MetadataUtils.Kinds.KeyValues, srcCol, ref srcMeta); + _schema[srcCol].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref srcMeta); if (srcMeta.Length != srcType.Count) return false; diff --git a/src/Microsoft.ML.Data/Utilities/ApplyTransformUtils.cs b/src/Microsoft.ML.Data/Utilities/ApplyTransformUtils.cs index 7d0695083b..c8510c8413 100644 --- a/src/Microsoft.ML.Data/Utilities/ApplyTransformUtils.cs +++ b/src/Microsoft.ML.Data/Utilities/ApplyTransformUtils.cs @@ -4,9 +4,9 @@ using System.Collections.Generic; using System.IO; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Utilities to rebind data transforms diff --git a/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs b/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs index 75d24b3340..5928e8ece7 100644 --- a/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs +++ b/src/Microsoft.ML.Data/Utilities/ColumnCursor.cs @@ -2,8 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; using System; using System.Collections.Generic; @@ -39,7 +37,7 @@ public static IEnumerable GetColumn(this IDataView data, IHostEnvironment // - If this is the same type, we can map directly. // - Otherwise, we need a conversion delegate. - var colType = data.Schema.GetColumnType(col); + var colType = data.Schema[col].Type; if (colType.RawType == typeof(T)) { // Direct mapping is possible. @@ -82,7 +80,7 @@ public static IEnumerable GetColumn(this IDataView data, IHostEnvironment private static IEnumerable GetColumnDirect(IDataView data, int col) { Contracts.AssertValue(data); - Contracts.Assert(0 <= col && col < data.Schema.ColumnCount); + Contracts.Assert(0 <= col && col < data.Schema.Count); using (var cursor = data.GetRowCursor(col.Equals)) { @@ -99,7 +97,7 @@ private static IEnumerable GetColumnDirect(IDataView data, int col) private static IEnumerable GetColumnConvert(IDataView data, int col, Func convert) { Contracts.AssertValue(data); - Contracts.Assert(0 <= col && col < data.Schema.ColumnCount); + Contracts.Assert(0 <= col && col < data.Schema.Count); using (var cursor = data.GetRowCursor(col.Equals)) { @@ -116,7 +114,7 @@ private static IEnumerable GetColumnConvert(IDataView data, i private static IEnumerable GetColumnArrayDirect(IDataView data, int col) { Contracts.AssertValue(data); - Contracts.Assert(0 <= col && col < data.Schema.ColumnCount); + Contracts.Assert(0 <= col && col < data.Schema.Count); using (var cursor = data.GetRowCursor(col.Equals)) { @@ -137,7 +135,7 @@ private static IEnumerable GetColumnArrayDirect(IDataView data, int col) private static IEnumerable GetColumnArrayConvert(IDataView data, int col, Func convert) { Contracts.AssertValue(data); - Contracts.Assert(0 <= col && col < data.Schema.ColumnCount); + Contracts.Assert(0 <= col && col < data.Schema.Count); using (var cursor = data.GetRowCursor(col.Equals)) { diff --git a/src/Microsoft.ML.Api/ComponentCreation.cs 
b/src/Microsoft.ML.Data/Utilities/ComponentCreation.cs similarity index 93% rename from src/Microsoft.ML.Api/ComponentCreation.cs rename to src/Microsoft.ML.Data/Utilities/ComponentCreation.cs index e83cfe10b6..8af4fde39f 100644 --- a/src/Microsoft.ML.Api/ComponentCreation.cs +++ b/src/Microsoft.ML.Data/Utilities/ComponentCreation.cs @@ -5,12 +5,11 @@ using System; using System.Collections.Generic; using System.IO; +using Microsoft.ML.CommandLine; using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Api +namespace Microsoft.ML.Data { /// /// This class defines extension methods for an to facilitate creating @@ -44,7 +43,8 @@ public static IDataView Zip(this IHostEnvironment env, IEnumerable so /// The name of the weight column. Can be null. /// Additional column mapping to be passed to the trainer or scorer (specific to the prediction type). Can be null or empty. /// The constructed examples. - public static RoleMappedData CreateExamples(this IHostEnvironment env, IDataView data, string features, string label = null, + [BestFriend] + internal static RoleMappedData CreateExamples(this IHostEnvironment env, IDataView data, string features, string label = null, string group = null, string weight = null, IEnumerable> custom = null) { Contracts.CheckValue(env, nameof(env)); @@ -118,7 +118,7 @@ public static IDataView CreateStreamingDataView(this IHostEnvironment env, /// Whether to ignore missing columns in the data view. /// The optional input schema. If null, the schema is inferred from the type. /// The optional output schema. If null, the schema is inferred from the type. 
- public static BatchPredictionEngine CreateBatchPredictionEngine(this IHostEnvironment env, Stream modelStream, + internal static BatchPredictionEngine CreateBatchPredictionEngine(this IHostEnvironment env, Stream modelStream, bool ignoreMissingColumns = false, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) where TSrc : class where TDst : class, new() @@ -138,7 +138,7 @@ public static BatchPredictionEngine CreateBatchPredictionEngineWhether to ignore missing columns in the data view. /// The optional input schema. If null, the schema is inferred from the type. /// The optional output schema. If null, the schema is inferred from the type. - public static BatchPredictionEngine CreateBatchPredictionEngine(this IHostEnvironment env, IDataView dataPipe, + internal static BatchPredictionEngine CreateBatchPredictionEngine(this IHostEnvironment env, IDataView dataPipe, bool ignoreMissingColumns = false, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) where TSrc : class where TDst : class, new() @@ -263,8 +263,8 @@ public static IDataTransform CreateTransform(this IHostEnvironment env, string s /// additional information, for example, label names. If this is null, no information will be /// extracted. /// The scored data. 
- public static IDataScorerTransform CreateScorer(this IHostEnvironment env, string settings, - RoleMappedData data, Predictor predictor, RoleMappedSchema trainSchema = null) + internal static IDataScorerTransform CreateScorer(this IHostEnvironment env, string settings, + RoleMappedData data, IPredictor predictor, RoleMappedSchema trainSchema = null) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(data, nameof(data)); @@ -279,7 +279,7 @@ public static IDataScorerTransform CreateScorer(this IHostEnvironment env, strin signatureType, settings); - var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, scorerFactorySettings: scorerFactorySettings); + var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor, scorerFactorySettings: scorerFactorySettings); var mapper = bindable.Bind(env, data.Schema); return CreateCore(env, factoryType, signatureType, settings, data.Data, mapper, trainSchema); } @@ -294,18 +294,20 @@ public static IDataScorerTransform CreateScorer(this IHostEnvironment env, strin /// additional information, for example, label names. If this is null, no information will be /// extracted. /// The scored data. 
- public static IDataScorerTransform CreateDefaultScorer(this IHostEnvironment env, RoleMappedData data, - Predictor predictor, RoleMappedSchema trainSchema = null) + [BestFriend] + internal static IDataScorerTransform CreateDefaultScorer(this IHostEnvironment env, RoleMappedData data, + IPredictor predictor, RoleMappedSchema trainSchema = null) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(data, nameof(data)); env.CheckValue(predictor, nameof(predictor)); env.CheckValueOrNull(trainSchema); - return ScoreUtils.GetScorer(predictor.Pred, data, env, trainSchema); + return ScoreUtils.GetScorer(predictor, data, env, trainSchema); } - public static IEvaluator CreateEvaluator(this IHostEnvironment env, string settings) + [BestFriend] + internal static IEvaluator CreateEvaluator(this IHostEnvironment env, string settings) { Contracts.CheckValue(env, nameof(env)); env.CheckNonWhiteSpace(settings, nameof(settings)); @@ -317,11 +319,10 @@ public static IEvaluator CreateEvaluator(this IHostEnvironment env, string setti /// /// The host environment to use. /// The model stream. - public static Predictor LoadPredictorOrNull(this IHostEnvironment env, Stream modelStream) + public static IPredictor LoadPredictorOrNull(this IHostEnvironment env, Stream modelStream) { Contracts.CheckValue(modelStream, nameof(modelStream)); - var p = ModelFileUtils.LoadPredictorOrNull(env, modelStream); - return p == null ? 
null : new Predictor(p); + return ModelFileUtils.LoadPredictorOrNull(env, modelStream); } internal static ITrainer CreateTrainer(this IHostEnvironment env, TArgs arguments, out string loadName) diff --git a/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs b/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs index f53f6f3558..a9e3a4a80e 100644 --- a/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs +++ b/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs @@ -5,7 +5,7 @@ using System; using System.ComponentModel.Composition.Hosting; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { using Stopwatch = System.Diagnostics.Stopwatch; @@ -71,7 +71,7 @@ public void RemoveListener(Action listener) protected override IFileHandle CreateTempFileCore(IHostEnvironment env, string suffix = null, string prefix = null) => base.CreateTempFileCore(env, suffix, "Local_" + prefix); - protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc) + protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose, int? conc) { Contracts.AssertValue(rand); Contracts.AssertValueOrNull(parentFullName); @@ -105,7 +105,7 @@ public override CompositionContainer GetCompositionContainer() private sealed class Host : HostBase { - public Host(HostEnvironmentBase source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc) + public Host(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose, int? 
conc) : base(source, shortName, parentFullName, rand, verbose, conc) { IsCancelled = source.IsCancelled; @@ -127,7 +127,7 @@ protected override IPipe CreatePipe(ChannelProviderBase pare return new Pipe(parent, name, GetDispatchDelegate()); } - protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc) + protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose, int? conc) { return new Host(source, shortName, parentFullName, rand, verbose, conc); } diff --git a/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs b/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs index 97948b764b..6fcbbaaaf5 100644 --- a/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs +++ b/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs @@ -6,13 +6,12 @@ using System.Collections.Generic; using System.IO; using System.Linq; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Data; +using Microsoft.ML.Data.IO; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Model +namespace Microsoft.ML.Model { using ColumnRole = RoleMappedSchema.ColumnRole; using Conditional = System.Diagnostics.ConditionalAttribute; @@ -20,7 +19,8 @@ namespace Microsoft.ML.Runtime.Model /// /// This class provides utilities for loading components from the model file generated by MAML commands. 
/// - public static class ModelFileUtils + [BestFriend] + internal static class ModelFileUtils { public const string DirPredictor = "Predictor"; public const string DirDataLoaderModel = "DataLoaderModel"; @@ -247,7 +247,7 @@ internal static void SaveRoleMappings(IHostEnvironment env, IChannel ch, RoleMap // ever go ahead and do something so stupid as that. var saver = new TextSaver(env, new TextSaver.Arguments() { Dense = true, Silent = true }); var view = builder.GetDataView(); - saver.SaveData(entry.Stream, view, Utils.GetIdentityPermutation(view.Schema.ColumnCount)); + saver.SaveData(entry.Stream, view, Utils.GetIdentityPermutation(view.Schema.Count)); } } @@ -283,8 +283,8 @@ public static IEnumerable> LoadRoleMappingsOrNu { // REVIEW: Should really validate the schema here, and consider // ignoring this stream if it isn't as expected. - var loader = TextLoader.ReadFile(env, new TextLoader.Arguments(), - new RepositoryStreamWrapper(rep, DirTrainingInfo, RoleMappingFile)); + var repoStreamWrapper = new RepositoryStreamWrapper(rep, DirTrainingInfo, RoleMappingFile); + var loader = new TextLoader(env, dataSample: repoStreamWrapper).Read(repoStreamWrapper); using (var cursor = loader.GetRowCursor(c => true)) { diff --git a/src/Microsoft.ML.Data/Utilities/PartitionedPathUtils.cs b/src/Microsoft.ML.Data/Utilities/PartitionedPathUtils.cs index b13a0d5cee..88717c91c7 100644 --- a/src/Microsoft.ML.Data/Utilities/PartitionedPathUtils.cs +++ b/src/Microsoft.ML.Data/Utilities/PartitionedPathUtils.cs @@ -6,7 +6,7 @@ using System.Collections.Generic; using System.IO; -namespace Microsoft.ML.Runtime.Data.Utilities +namespace Microsoft.ML.Data.Utilities { internal static class PartitionedPathUtils { diff --git a/src/Microsoft.ML.Data/Utilities/SlotDropper.cs b/src/Microsoft.ML.Data/Utilities/SlotDropper.cs index 188f8e72b5..402efda828 100644 --- a/src/Microsoft.ML.Data/Utilities/SlotDropper.cs +++ b/src/Microsoft.ML.Data/Utilities/SlotDropper.cs @@ -3,11 +3,10 @@ // See the 
LICENSE file in the project root for more information. using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Internal.Internallearn +namespace Microsoft.ML.Internal.Internallearn { /// /// Drops slots from a fixed or variable sized column based on slot ranges. @@ -29,7 +28,7 @@ public sealed class SlotDropper /// Constructs slot dropper. It expects the slot ranges to be in sorted order and not overlap. /// /// 0 indicates variable sized vector. - /// Lower limit of ranges to be dropped. + /// Low limit of ranges to be dropped. /// Upper limit of ranges to be dropped. public SlotDropper(int srcLength, int[] slotsMin, int[] slotsMax) { diff --git a/src/Microsoft.ML.Data/Utilities/StreamUtils.cs b/src/Microsoft.ML.Data/Utilities/StreamUtils.cs index 5d52b45b3b..f157d09ea8 100644 --- a/src/Microsoft.ML.Data/Utilities/StreamUtils.cs +++ b/src/Microsoft.ML.Data/Utilities/StreamUtils.cs @@ -6,7 +6,7 @@ using System.IO; using System.Linq; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { // REVIEW: Implement properly on CoreCLR. 
[BestFriend] @@ -15,7 +15,7 @@ internal static class StreamUtils public static Stream OpenInStream(string fileName) { #if !CORECLR - return Microsoft.ML.Runtime.Internal.IO.ZStreamIn.Open(fileName); + return Microsoft.ML.Internal.IO.ZStreamIn.Open(fileName); #else return new FileStream(fileName, FileMode.Open, FileAccess.Read, FileShare.Read); #endif diff --git a/src/Microsoft.ML.Data/Utilities/TimerScope.cs b/src/Microsoft.ML.Data/Utilities/TimerScope.cs index c2a590ffe8..663ba669d5 100644 --- a/src/Microsoft.ML.Data/Utilities/TimerScope.cs +++ b/src/Microsoft.ML.Data/Utilities/TimerScope.cs @@ -4,7 +4,7 @@ using System; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace Microsoft.ML.Internal.Utilities { using Stopwatch = System.Diagnostics.Stopwatch; diff --git a/src/Microsoft.ML.Data/Utilities/TypeParsingUtils.cs b/src/Microsoft.ML.Data/Utilities/TypeParsingUtils.cs index a24d7d3883..da90e13e35 100644 --- a/src/Microsoft.ML.Data/Utilities/TypeParsingUtils.cs +++ b/src/Microsoft.ML.Data/Utilities/TypeParsingUtils.cs @@ -4,10 +4,9 @@ using System; using System.Text; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; +using Microsoft.ML.CommandLine; -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// Utilities to parse command-line representations of types. 
diff --git a/src/Microsoft.ML.Data/Utilities/TypeUtils.cs b/src/Microsoft.ML.Data/Utilities/TypeUtils.cs index 8459a7d4cd..eb91b97100 100644 --- a/src/Microsoft.ML.Data/Utilities/TypeUtils.cs +++ b/src/Microsoft.ML.Data/Utilities/TypeUtils.cs @@ -4,13 +4,11 @@ using System; using System.Linq; -using System.Reflection; using System.Text; using System.Text.RegularExpressions; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Internal.Internallearn +namespace Microsoft.ML.Internal.Internallearn { public static class TypeUtils { diff --git a/src/Microsoft.ML.Api/ApiUtils.cs b/src/Microsoft.ML.Data/Utils/ApiUtils.cs similarity index 98% rename from src/Microsoft.ML.Api/ApiUtils.cs rename to src/Microsoft.ML.Data/Utils/ApiUtils.cs index 760ed1e768..d1b14f2b94 100644 --- a/src/Microsoft.ML.Api/ApiUtils.cs +++ b/src/Microsoft.ML.Data/Utils/ApiUtils.cs @@ -5,9 +5,9 @@ using System; using System.Reflection; using System.Reflection.Emit; -using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime.Api +namespace Microsoft.ML { internal delegate void Peek(TRow row, long position, ref TValue value); @@ -18,11 +18,11 @@ internal static class ApiUtils private static OpCode GetAssignmentOpCode(Type t) { // REVIEW: This should be a Dictionary based solution. - // DvTypes, strings, arrays, all nullable types, VBuffers and UInt128. + // DvTypes, strings, arrays, all nullable types, VBuffers and RowId. 
if (t == typeof(ReadOnlyMemory) || t == typeof(string) || t.IsArray || (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>)) || - t == typeof(DateTime) || t == typeof(DateTimeOffset) || t == typeof(TimeSpan) || t == typeof(UInt128)) + t == typeof(DateTime) || t == typeof(DateTimeOffset) || t == typeof(TimeSpan) || t == typeof(RowId)) { return OpCodes.Stobj; } diff --git a/src/Microsoft.ML.Data/Utils/LossFunctions.cs b/src/Microsoft.ML.Data/Utils/LossFunctions.cs index 1659b39387..ed3ad17765 100644 --- a/src/Microsoft.ML.Data/Utils/LossFunctions.cs +++ b/src/Microsoft.ML.Data/Utils/LossFunctions.cs @@ -2,13 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Float = System.Single; [assembly: LoadableClass(LogLoss.Summary, typeof(LogLoss), null, typeof(SignatureClassificationLoss), "Log Loss", "LogLoss", "Logistic", "CrossEntropy")] @@ -39,7 +38,7 @@ [assembly: EntryPointModule(typeof(SquaredLossFactory))] [assembly: EntryPointModule(typeof(TweedieLoss.Arguments))] -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { /// /// The loss function may know the close-form solution to the optimal dual update diff --git a/src/Microsoft.ML.Data/Utils/SequencePool.cs b/src/Microsoft.ML.Data/Utils/SequencePool.cs index 6b9ebe01e3..989bbd195f 100644 --- a/src/Microsoft.ML.Data/Utils/SequencePool.cs +++ b/src/Microsoft.ML.Data/Utils/SequencePool.cs @@ -5,7 +5,7 @@ using System; using System.IO; -namespace Microsoft.ML.Runtime.Internal.Utilities +namespace 
Microsoft.ML.Internal.Utilities { using Conditional = System.Diagnostics.ConditionalAttribute; diff --git a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs index 5179c60e56..67f261a393 100644 --- a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs +++ b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs @@ -2,8 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Transforms.TensorFlow; using System; +using Microsoft.ML.Transforms.TensorFlow; namespace Microsoft.ML.DnnAnalyzer { diff --git a/src/Microsoft.ML.DnnImageFeaturizer.AlexNet/AlexNetExtension.cs b/src/Microsoft.ML.DnnImageFeaturizer.AlexNet/AlexNetExtension.cs index 186d083679..e59a1da6e6 100644 --- a/src/Microsoft.ML.DnnImageFeaturizer.AlexNet/AlexNetExtension.cs +++ b/src/Microsoft.ML.DnnImageFeaturizer.AlexNet/AlexNetExtension.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; using System.IO; using System.Reflection; +using Microsoft.ML.Data; namespace Microsoft.ML.Transforms { diff --git a/src/Microsoft.ML.DnnImageFeaturizer.ResNet101/ResNet101Extension.cs b/src/Microsoft.ML.DnnImageFeaturizer.ResNet101/ResNet101Extension.cs index 0e5298211e..c5978fe198 100644 --- a/src/Microsoft.ML.DnnImageFeaturizer.ResNet101/ResNet101Extension.cs +++ b/src/Microsoft.ML.DnnImageFeaturizer.ResNet101/ResNet101Extension.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; using System.IO; using System.Reflection; +using Microsoft.ML.Data; namespace Microsoft.ML.Transforms { diff --git a/src/Microsoft.ML.DnnImageFeaturizer.ResNet18/ResNet18Extension.cs b/src/Microsoft.ML.DnnImageFeaturizer.ResNet18/ResNet18Extension.cs index 8a0d0504b7..01b8436950 100644 --- a/src/Microsoft.ML.DnnImageFeaturizer.ResNet18/ResNet18Extension.cs +++ b/src/Microsoft.ML.DnnImageFeaturizer.ResNet18/ResNet18Extension.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; using System.IO; using System.Reflection; +using Microsoft.ML.Data; namespace Microsoft.ML.Transforms { diff --git a/src/Microsoft.ML.DnnImageFeaturizer.ResNet50/ResNet50Extension.cs b/src/Microsoft.ML.DnnImageFeaturizer.ResNet50/ResNet50Extension.cs index b46d34f708..c57e7f2560 100644 --- a/src/Microsoft.ML.DnnImageFeaturizer.ResNet50/ResNet50Extension.cs +++ b/src/Microsoft.ML.DnnImageFeaturizer.ResNet50/ResNet50Extension.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; using System.IO; using System.Reflection; +using Microsoft.ML.Data; namespace Microsoft.ML.Transforms { diff --git a/src/Microsoft.ML.Ensemble/Batch.cs b/src/Microsoft.ML.Ensemble/Batch.cs index e9c8fcf179..caaf6bf4f3 100644 --- a/src/Microsoft.ML.Ensemble/Batch.cs +++ b/src/Microsoft.ML.Ensemble/Batch.cs @@ -2,11 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime.Ensemble +namespace Microsoft.ML.Ensemble { - public sealed class Batch + internal sealed class Batch { public readonly RoleMappedData TrainInstances; public readonly RoleMappedData TestInstances; diff --git a/src/Microsoft.ML.Ensemble/EnsembleUtils.cs b/src/Microsoft.ML.Ensemble/EnsembleUtils.cs index 66a6ff165e..0cb43ef7a4 100644 --- a/src/Microsoft.ML.Ensemble/EnsembleUtils.cs +++ b/src/Microsoft.ML.Ensemble/EnsembleUtils.cs @@ -4,10 +4,10 @@ using System; using System.Collections; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Ensemble +namespace Microsoft.ML.Ensemble { internal static class EnsembleUtils { @@ -18,17 +18,18 @@ public static RoleMappedData SelectFeatures(IHost host, RoleMappedData data, Bit { Contracts.AssertValue(host); Contracts.AssertValue(data); - Contracts.Assert(data.Schema.Feature != null); + Contracts.Assert(data.Schema.Feature.HasValue); Contracts.AssertValue(features); + var featCol = data.Schema.Feature.Value; - var type = data.Schema.Feature.Type; + var type = featCol.Type; Contracts.Assert(features.Length == type.VectorSize); int card = Utils.GetCardinality(features); if (card == type.VectorSize) return data; // REVIEW: This doesn't preserve metadata on the features column. Should it? 
- var name = data.Schema.Feature.Name; + var name = featCol.Name; var view = LambdaColumnMapper.Create( host, "FeatureSelector", data.Data, name, name, type, type, (in VBuffer src, ref VBuffer dst) => SelectFeatures(in src, features, card, ref dst)); diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs index 7011f7976c..2416701f3c 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs @@ -3,22 +3,21 @@ // See the LICENSE file in the project root for more information. using System; -using System.Collections.Generic; using System.IO; using System.IO.Compression; using System.Linq; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble; using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; [assembly: LoadableClass(typeof(void), typeof(EnsembleCreator), null, typeof(SignatureEntryPointModule), "CreateEnsemble")] -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { /// /// A component to combine given models into an ensemble model. 
@@ -47,13 +46,13 @@ public enum ScoreCombiner public abstract class PipelineInputBase { [Argument(ArgumentType.Required, ShortName = "models", HelpText = "The models to combine into an ensemble", SortOrder = 1)] - public IPredictorModel[] Models; + public PredictorModel[] Models; } public abstract class InputBase { [Argument(ArgumentType.Required, ShortName = "models", HelpText = "The models to combine into an ensemble", SortOrder = 1)] - public IPredictorModel[] Models; + public PredictorModel[] Models; [Argument(ArgumentType.AtMostOnce, ShortName = "validate", HelpText = "Whether to validate that all the pipelines are identical", SortOrder = 5)] public bool ValidatePipelines = true; @@ -95,7 +94,7 @@ private static void GetPipeline(IHostEnvironment env, InputBase input, out IData env.AssertValue(input); env.AssertNonEmpty(input.Models); - ISchema inputSchema = null; + Schema inputSchema = null; startingData = null; transformedData = null; byte[][] transformedDataSerialized = null; @@ -159,7 +158,7 @@ public static CommonOutputs.BinaryClassificationOutput CreateBinaryEnsemble(IHos var trainer = new EnsembleTrainer(host, args); var ensemble = trainer.CombineModels(input.Models.Select(pm => pm.Predictor as IPredictorProducing)); - var predictorModel = new PredictorModel(host, transformedData, startingData, ensemble); + var predictorModel = new PredictorModelImpl(host, transformedData, startingData, ensemble); var output = new CommonOutputs.BinaryClassificationOutput { PredictorModel = predictorModel }; return output; @@ -191,7 +190,7 @@ public static CommonOutputs.RegressionOutput CreateRegressionEnsemble(IHostEnvir var trainer = new RegressionEnsembleTrainer(host, args); var ensemble = trainer.CombineModels(input.Models.Select(pm => pm.Predictor as IPredictorProducing)); - var predictorModel = new PredictorModel(host, transformedData, startingData, ensemble); + var predictorModel = new PredictorModelImpl(host, transformedData, startingData, ensemble); var output = 
new CommonOutputs.RegressionOutput { PredictorModel = predictorModel }; return output; @@ -299,7 +298,7 @@ public static CommonOutputs.AnomalyDetectionOutput CreateAnomalyPipelineEnsemble return CreatePipelineEnsemble(host, input.Models, ensemble); } - private static TOut CreatePipelineEnsemble(IHostEnvironment env, IPredictorModel[] predictors, SchemaBindablePipelineEnsembleBase ensemble) + private static TOut CreatePipelineEnsemble(IHostEnvironment env, PredictorModel[] predictors, SchemaBindablePipelineEnsembleBase ensemble) where TOut : CommonOutputs.TrainerOutput, new() { var inputSchema = predictors[0].TransformModel.InputSchema; @@ -307,7 +306,7 @@ private static TOut CreatePipelineEnsemble(IHostEnvironment env, IPredicto // The role mappings are specific to the individual predictors. var rmd = new RoleMappedData(dv); - var predictorModel = new PredictorModel(env, rmd, dv, ensemble); + var predictorModel = new PredictorModelImpl(env, rmd, dv, ensemble); var output = new TOut { PredictorModel = predictorModel }; return output; @@ -322,7 +321,7 @@ private static TOut CreatePipelineEnsemble(IHostEnvironment env, IPredicto /// This method is used for comparing pipelines. Its outputs can be passed to /// to check if this pipeline is identical to another pipeline. /// - public static void SerializeRoleMappedData(IHostEnvironment env, IChannel ch, RoleMappedData data, + private static void SerializeRoleMappedData(IHostEnvironment env, IChannel ch, RoleMappedData data, out byte[][] dataSerialized, out string[] dataZipEntryNames) { Contracts.CheckValue(env, nameof(env)); @@ -356,7 +355,7 @@ public static void SerializeRoleMappedData(IHostEnvironment env, IChannel ch, Ro /// and . /// This method throws if for any of the entries the name/byte sequence are not identical. 
/// - public static void CheckSamePipeline(IHostEnvironment env, IChannel ch, + private static void CheckSamePipeline(IHostEnvironment env, IChannel ch, RoleMappedData dataToCompare, byte[][] dataSerialized, string[] dataZipEntryNames) { Contracts.CheckValue(env, nameof(env)); diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs b/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs index b13cff3b35..f295697f6e 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs @@ -2,13 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.DiversityMeasure; +using Microsoft.ML.EntryPoints; [assembly: EntryPointModule(typeof(DisagreementDiversityFactory))] [assembly: EntryPointModule(typeof(RegressionDisagreementDiversityFactory))] diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs index e5e9b79afc..6cf8b418bc 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using Microsoft.ML; using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Ensemble; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.EntryPoints; [assembly: LoadableClass(typeof(void), typeof(Ensemble), null, typeof(SignatureEntryPointModule), "TrainEnsemble")] diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/FeatureSelector.cs b/src/Microsoft.ML.Ensemble/EntryPoints/FeatureSelector.cs index 65ca5e9d06..c2d7862a31 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/FeatureSelector.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/FeatureSelector.cs @@ -3,10 +3,9 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.FeatureSelector; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.FeatureSelector; +using Microsoft.ML.EntryPoints; [assembly: EntryPointModule(typeof(AllFeatureSelectorFactory))] [assembly: EntryPointModule(typeof(RandomFeatureSelector))] diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/OutputCombiner.cs b/src/Microsoft.ML.Ensemble/EntryPoints/OutputCombiner.cs index 537b35f47b..7583d72cfb 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/OutputCombiner.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/OutputCombiner.cs @@ -3,9 +3,8 @@ // See the LICENSE file in the project root for more information. 
using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.EntryPoints; [assembly: EntryPointModule(typeof(AverageFactory))] [assembly: EntryPointModule(typeof(MedianFactory))] diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs index bcfaaefb89..66482b44a9 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs @@ -2,14 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.EntryPoints; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Calibration; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.EntryPoints; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Calibration; [assembly: EntryPointModule(typeof(PipelineEnsemble))] -namespace Microsoft.ML.Runtime.Ensemble.EntryPoints +namespace Microsoft.ML.Ensemble.EntryPoints { public static class PipelineEnsemble { diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/SubModelSelector.cs b/src/Microsoft.ML.Ensemble/EntryPoints/SubModelSelector.cs index 57001190ac..10b1551c62 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/SubModelSelector.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/SubModelSelector.cs @@ -2,13 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.SubModelSelector; +using Microsoft.ML.EntryPoints; [assembly: EntryPointModule(typeof(AllSelectorFactory))] [assembly: EntryPointModule(typeof(AllSelectorMultiClassFactory))] diff --git a/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs b/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs index 4518666d34..fadd096810 100644 --- a/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs +++ b/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs @@ -4,9 +4,9 @@ using System.Collections; using System.Collections.Generic; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.Ensemble +namespace Microsoft.ML.Ensemble { public sealed class FeatureSubsetModel where TPredictor : IPredictor { diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs index de1e5ef505..05b85b37ec 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs @@ -3,14 +3,14 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.Model; [assembly: LoadableClass(typeof(Average), null, typeof(SignatureCombiner), Average.UserName)] [assembly: LoadableClass(typeof(Average), null, typeof(SignatureLoadModel), Average.UserName, Average.LoaderSignature)] -namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners +namespace Microsoft.ML.Ensemble.OutputCombiners { public sealed class Average : BaseAverager, ICanSaveModel, IRegressionOutputCombiner { diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs index 824300e594..1ffbe13ec1 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs @@ -3,9 +3,9 @@ // See the LICENSE file in the project root for more information. using System; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners +namespace Microsoft.ML.Ensemble.OutputCombiners { public abstract class BaseAverager : IBinaryOutputCombiner { diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs index e7a50c11c3..73e6f4ea7e 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs @@ -3,12 +3,12 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Numeric; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Numeric; -namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners +namespace Microsoft.ML.Ensemble.OutputCombiners { public abstract class BaseMultiAverager : BaseMultiCombiner { diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs index 350833aebb..5e8296862b 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs @@ -3,13 +3,13 @@ // See the LICENSE file in the project root for more information. using System; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Numeric; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Numeric; -namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners +namespace Microsoft.ML.Ensemble.OutputCombiners { public abstract class BaseMultiCombiner : IMultiClassOutputCombiner { diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs index 257499cd13..ba75e05080 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs @@ -3,11 +3,11 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners +namespace Microsoft.ML.Ensemble.OutputCombiners { internal abstract class BaseScalarStacking : BaseStacking { diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs index d62650b7c2..c791307bd4 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs @@ -5,17 +5,15 @@ using System; using System.Collections.Generic; using System.Threading.Tasks; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Training; - -namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Training; + +namespace Microsoft.ML.Ensemble.OutputCombiners { - using ColumnRole = RoleMappedSchema.ColumnRole; internal abstract class BaseStacking : IStackingTrainer { public abstract class ArgumentsBase diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs index 512974b717..53aec1f8d8 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs @@ -4,10 +4,10 @@ using System; using System.Collections.Generic; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; +using 
Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; -namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners +namespace Microsoft.ML.Ensemble.OutputCombiners { /// /// Signature for combiners. @@ -28,7 +28,7 @@ public interface IOutputCombiner : IOutputCombiner Combiner GetCombiner(); } - public interface IStackingTrainer + internal interface IStackingTrainer { void Train(List>> models, RoleMappedData data, IHostEnvironment env); Single ValidationDatasetProportion { get; } diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs index 95bc0cc991..3b94196eb9 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs @@ -3,15 +3,15 @@ // See the LICENSE file in the project root for more information. using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; [assembly: LoadableClass(typeof(Median), null, typeof(SignatureCombiner), Median.UserName, Median.LoadName)] [assembly: LoadableClass(typeof(Median), null, typeof(SignatureLoadModel), Median.UserName, Median.LoaderSignature)] -namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners +namespace Microsoft.ML.Ensemble.OutputCombiners { /// /// Generic interface for combining outputs of multiple models diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs index fef6fa087e..1ba5cdf028 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs @@ -3,18 +3,18 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Model; [assembly: LoadableClass(typeof(MultiAverage), typeof(MultiAverage.Arguments), typeof(SignatureCombiner), Average.UserName, MultiAverage.LoadName)] [assembly: LoadableClass(typeof(MultiAverage), null, typeof(SignatureLoadModel), Average.UserName, MultiAverage.LoadName, MultiAverage.LoaderSignature)] -namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners +namespace Microsoft.ML.Ensemble.OutputCombiners { public sealed class MultiAverage : BaseMultiAverager, ICanSaveModel { diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs index 3b11146203..477a658355 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs @@ -3,18 +3,18 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; [assembly: LoadableClass(typeof(MultiMedian), typeof(MultiMedian.Arguments), typeof(SignatureCombiner), Median.UserName, MultiMedian.LoadName)] [assembly: LoadableClass(typeof(MultiMedian), null, typeof(SignatureLoadModel), Median.UserName, MultiMedian.LoaderSignature)] -namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners +namespace Microsoft.ML.Ensemble.OutputCombiners { /// /// Generic interface for combining outputs of multiple models diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs index f9e3b246f7..3ec8ea85c8 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs @@ -3,14 +3,13 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Model; [assembly: LoadableClass(typeof(MultiStacking), typeof(MultiStacking.Arguments), typeof(SignatureCombiner), Stacking.UserName, MultiStacking.LoadName)] @@ -18,7 +17,7 @@ [assembly: LoadableClass(typeof(MultiStacking), null, typeof(SignatureLoadModel), Stacking.UserName, MultiStacking.LoaderSignature)] -namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners +namespace Microsoft.ML.Ensemble.OutputCombiners { using TVectorPredictor = IPredictorProducing>; internal sealed class MultiStacking : BaseStacking>, ICanSaveModel, IMultiClassOutputCombiner diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs index ffa1b9c647..a8ba932926 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs @@ -3,17 +3,17 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Numeric; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Numeric; [assembly: LoadableClass(typeof(MultiVoting), null, typeof(SignatureCombiner), Voting.UserName, MultiVoting.LoadName)] [assembly: LoadableClass(typeof(MultiVoting), null, typeof(SignatureLoadModel), Voting.UserName, MultiVoting.LoaderSignature)] -namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners +namespace Microsoft.ML.Ensemble.OutputCombiners { // REVIEW: Why is MultiVoting based on BaseMultiCombiner? Normalizing the model outputs // is senseless, so the base adds no real functionality. diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs index a2f52b1451..deef23abcd 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs @@ -3,13 +3,13 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Model; [assembly: LoadableClass(typeof(MultiWeightedAverage), typeof(MultiWeightedAverage.Arguments), typeof(SignatureCombiner), MultiWeightedAverage.UserName, MultiWeightedAverage.LoadName)] @@ -17,7 +17,7 @@ [assembly: LoadableClass(typeof(MultiWeightedAverage), null, typeof(SignatureLoadModel), MultiWeightedAverage.UserName, MultiWeightedAverage.LoaderSignature)] -namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners +namespace Microsoft.ML.Ensemble.OutputCombiners { /// /// Generic interface for combining outputs of multiple models diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs index 8c984613db..ab57cfe9aa 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs @@ -2,12 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Model; [assembly: LoadableClass(typeof(RegressionStacking), typeof(RegressionStacking.Arguments), typeof(SignatureCombiner), Stacking.UserName, RegressionStacking.LoadName)] @@ -15,7 +15,7 @@ [assembly: LoadableClass(typeof(RegressionStacking), null, typeof(SignatureLoadModel), Stacking.UserName, RegressionStacking.LoaderSignature)] -namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners +namespace Microsoft.ML.Ensemble.OutputCombiners { using TScalarPredictor = IPredictorProducing; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs index f44f987b05..891963dbea 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs @@ -3,17 +3,17 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Model; [assembly: LoadableClass(typeof(Stacking), typeof(Stacking.Arguments), typeof(SignatureCombiner), Stacking.UserName, Stacking.LoadName)] [assembly: LoadableClass(typeof(Stacking), null, typeof(SignatureLoadModel), Stacking.UserName, Stacking.LoaderSignature)] -namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners +namespace Microsoft.ML.Ensemble.OutputCombiners { using TScalarPredictor = IPredictorProducing; internal sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner, ICanSaveModel diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs index d352439d55..b18de67b18 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs @@ -3,15 +3,15 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; [assembly: LoadableClass(typeof(Voting), null, typeof(SignatureCombiner), Voting.UserName, Voting.LoadName)] [assembly: LoadableClass(typeof(Voting), null, typeof(SignatureLoadModel), Voting.UserName, Voting.LoaderSignature)] -namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners +namespace Microsoft.ML.Ensemble.OutputCombiners { public sealed class Voting : IBinaryOutputCombiner, ICanSaveModel { diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs index e11f0c9bb9..ec1d2dcd11 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs @@ -3,13 +3,13 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Model; [assembly: LoadableClass(typeof(WeightedAverage), typeof(WeightedAverage.Arguments), typeof(SignatureCombiner), WeightedAverage.UserName, WeightedAverage.LoadName)] @@ -17,7 +17,7 @@ [assembly: LoadableClass(typeof(WeightedAverage), null, typeof(SignatureLoadModel), WeightedAverage.UserName, WeightedAverage.LoaderSignature)] -namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners +namespace Microsoft.ML.Ensemble.OutputCombiners { public sealed class WeightedAverage : BaseAverager, IWeightedAverager, ICanSaveModel { diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs index 385d8dc7b5..0c2655fff1 100644 --- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs @@ -7,21 +7,20 @@ using System.IO; using System.Linq; using System.Text; +using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Calibration; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Ensemble; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Internal.Internallearn; 
+using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; [assembly: LoadableClass(typeof(SchemaBindablePipelineEnsembleBase), null, typeof(SignatureLoadModel), SchemaBindablePipelineEnsembleBase.UserName, SchemaBindablePipelineEnsembleBase.LoaderSignature)] -namespace Microsoft.ML.Runtime.Ensemble +namespace Microsoft.ML.Ensemble { /// /// This class represents an ensemble predictor, where each predictor has its own featurization pipeline. It is @@ -43,13 +42,13 @@ private abstract class BoundBase : ISchemaBoundRowMapper public ISchemaBindableMapper Bindable => Parent; public RoleMappedSchema InputRoleMappedSchema { get; } public Schema InputSchema => InputRoleMappedSchema.Schema; - public Schema Schema { get; } + public Schema OutputSchema { get; } public BoundBase(SchemaBindablePipelineEnsembleBase parent, RoleMappedSchema schema) { Parent = parent; InputRoleMappedSchema = schema; - Schema = Schema.Create(new ScoreMapperSchema(Parent.ScoreType, Parent._scoreColumnKind)); + OutputSchema = Schema.Create(new ScoreMapperSchema(Parent.ScoreType, Parent._scoreColumnKind)); _inputColIndices = new HashSet(); for (int i = 0; i < Parent._inputCols.Length; i++) { @@ -75,12 +74,12 @@ public BoundBase(SchemaBindablePipelineEnsembleBase parent, RoleMappedSchema sch throw Parent.Host.Except("Predictor {0} is not a row to row mapper", i); // Make sure there is a score column, and remember its index. - if (!Mappers[i].Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out ScoreCols[i])) + if (!Mappers[i].OutputSchema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out ScoreCols[i])) throw Parent.Host.Except("Predictor {0} does not contain a score column", i); // Get the pipeline. 
var dv = new EmptyDataView(Parent.Host, schema.Schema); - var tm = new TransformModel(Parent.Host, dv, dv); + var tm = new TransformModelImpl(Parent.Host, dv, dv); var pipeline = Parent.PredictorModels[i].TransformModel.Apply(Parent.Host, tm); BoundPipelines[i] = pipeline.AsRowToRowMapper(Parent.Host); if (BoundPipelines[i] == null) @@ -90,7 +89,7 @@ public BoundBase(SchemaBindablePipelineEnsembleBase parent, RoleMappedSchema sch public Func GetDependencies(Func predicate) { - for (int i = 0; i < Schema.ColumnCount; i++) + for (int i = 0; i < OutputSchema.Count; i++) { if (predicate(i)) return col => _inputColIndices.Contains(col); @@ -103,12 +102,13 @@ public Func GetDependencies(Func predicate) yield break; } - public IRow GetRow(IRow input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate) { - return new SimpleRow(Schema, input, new[] { CreateScoreGetter(input, predicate, out disposer) }); + var scoreGetter = CreateScoreGetter(input, predicate, out Action disposer); + return new SimpleRow(OutputSchema, input, new[] { scoreGetter }, disposer); } - public abstract Delegate CreateScoreGetter(IRow input, Func mapperPredicate, out Action disposer); + internal abstract Delegate CreateScoreGetter(Row input, Func mapperPredicate, out Action disposer); } // A generic base class for pipeline ensembles. This class contains the combiner. @@ -124,7 +124,7 @@ public Bound(SchemaBindablePipelineEnsemble parent, RoleMappedSchema schema) _combiner = parent.Combiner; } - public override Delegate CreateScoreGetter(IRow input, Func mapperPredicate, out Action disposer) + internal override Delegate CreateScoreGetter(Row input, Func mapperPredicate, out Action disposer) { disposer = null; @@ -137,13 +137,12 @@ public override Delegate CreateScoreGetter(IRow input, Func mapperPre // First get the output row from the pipelines. The input predicate of the predictor // is the output predicate of the pipeline. 
var inputPredicate = Mappers[i].GetDependencies(mapperPredicate); - var pipelineRow = BoundPipelines[i].GetRow(input, inputPredicate, out Action disp); - disposer += disp; + var pipelineRow = BoundPipelines[i].GetRow(input, inputPredicate); // Next we get the output row from the predictors. We activate the score column as output predicate. - var predictorRow = Mappers[i].GetRow(pipelineRow, col => col == ScoreCols[i], out disp); - disposer += disp; + var predictorRow = Mappers[i].GetRow(pipelineRow, col => col == ScoreCols[i]); getters[i] = predictorRow.GetGetter(ScoreCols[i]); + disposer += predictorRow.Dispose; } var comb = _combiner.GetCombiner(); @@ -158,36 +157,41 @@ public override Delegate CreateScoreGetter(IRow input, Func mapperPre return scoreGetter; } - public ValueGetter GetLabelGetter(IRow input, int i, out Action disposer) + public ValueGetter GetLabelGetter(Row input, int i, out Action disposer) { Parent.Host.Assert(0 <= i && i < Mappers.Length); - Parent.Host.Check(Mappers[i].InputRoleMappedSchema.Label != null, "Mapper was not trained using a label column"); + Parent.Host.Check(Mappers[i].InputRoleMappedSchema.Label.HasValue, "Mapper was not trained using a label column"); + var labelCol = Mappers[i].InputRoleMappedSchema.Label.Value; // The label should be in the output row of the i'th pipeline - var pipelineRow = BoundPipelines[i].GetRow(input, col => col == Mappers[i].InputRoleMappedSchema.Label.Index, out disposer); - return RowCursorUtils.GetLabelGetter(pipelineRow, Mappers[i].InputRoleMappedSchema.Label.Index); + var pipelineRow = BoundPipelines[i].GetRow(input, col => col == labelCol.Index); + disposer = pipelineRow.Dispose; + return RowCursorUtils.GetLabelGetter(pipelineRow, labelCol.Index); } - public ValueGetter GetWeightGetter(IRow input, int i, out Action disposer) + public ValueGetter GetWeightGetter(Row input, int i, out Action disposer) { Parent.Host.Assert(0 <= i && i < Mappers.Length); - if (Mappers[i].InputRoleMappedSchema.Weight 
== null) + if (!Mappers[i].InputRoleMappedSchema.Weight.HasValue) { - ValueGetter weight = (ref Single dst) => dst = 1; + ValueGetter weight = (ref float dst) => dst = 1; disposer = null; return weight; } + var weightCol = Mappers[i].InputRoleMappedSchema.Weight.Value; // The weight should be in the output row of the i'th pipeline if it exists. - var inputPredicate = Mappers[i].GetDependencies(col => col == Mappers[i].InputRoleMappedSchema.Weight.Index); - var pipelineRow = BoundPipelines[i].GetRow(input, inputPredicate, out disposer); - return pipelineRow.GetGetter(Mappers[i].InputRoleMappedSchema.Weight.Index); + var inputPredicate = Mappers[i].GetDependencies(col => col == weightCol.Index); + var pipelineRow = BoundPipelines[i].GetRow(input, inputPredicate); + disposer = pipelineRow.Dispose; + return pipelineRow.GetGetter(weightCol.Index); + } } protected readonly IOutputCombiner Combiner; - protected SchemaBindablePipelineEnsemble(IHostEnvironment env, IPredictorModel[] predictors, + protected SchemaBindablePipelineEnsemble(IHostEnvironment env, PredictorModel[] predictors, IOutputCombiner combiner, string registrationName, string scoreColumnKind) : base(env, predictors, registrationName, scoreColumnKind) { @@ -215,7 +219,7 @@ protected override void SaveCore(ModelSaveContext ctx) ctx.SaveModel(Combiner, "Combiner"); } - public override ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) + private protected override ISchemaBoundMapper BindCore(IHostEnvironment env, RoleMappedSchema schema) { return new Bound(this, schema); } @@ -238,7 +242,7 @@ public override PredictionKind PredictionKind } } - public ImplOne(IHostEnvironment env, IPredictorModel[] predictors, IRegressionOutputCombiner combiner, string scoreColumnKind) + public ImplOne(IHostEnvironment env, PredictorModel[] predictors, IRegressionOutputCombiner combiner, string scoreColumnKind) : base(env, predictors, combiner, LoaderSignature, scoreColumnKind) { } @@ -266,7 +270,7 @@ public 
override PredictionKind PredictionKind private readonly VectorType _scoreType; - public ImplVec(IHostEnvironment env, IPredictorModel[] predictors, IMultiClassOutputCombiner combiner) + public ImplVec(IHostEnvironment env, PredictorModel[] predictors, IMultiClassOutputCombiner combiner) : base(env, predictors, combiner, LoaderSignature, MetadataUtils.Const.ScoreColumnKind.MultiClassClassification) { int classCount = CheckLabelColumn(Host, predictors, false); @@ -288,7 +292,7 @@ private sealed class ImplOneWithCalibrator : SchemaBindablePipelineEnsemble(); - for (int j = 0; j < inputSchema.ColumnCount; j++) + for (int j = 0; j < inputSchema.Count; j++) { - if (inputSchema.IsHidden(j)) + if (inputSchema[j].IsHidden) continue; - inputCols.Add(inputSchema.GetColumnName(j)); + inputCols.Add(inputSchema[j].Name); } _inputCols = inputCols.ToArray(); } else { int nonHiddenCols = 0; - for (int j = 0; j < inputSchema.ColumnCount; j++) + for (int j = 0; j < inputSchema.Count; j++) { - if (inputSchema.IsHidden(j)) + if (inputSchema[j].IsHidden) continue; - var name = inputSchema.GetColumnName(j); + var name = inputSchema[j].Name; if (!inputCols.Contains(name)) throw Host.Except("Inconsistent schemas: Some schemas do not contain the column '{0}'", name); nonHiddenCols++; @@ -456,7 +460,7 @@ protected SchemaBindablePipelineEnsembleBase(IHostEnvironment env, ModelLoadCont var length = ctx.Reader.ReadInt32(); Host.CheckDecode(length > 0); - PredictorModels = new IPredictorModel[length]; + PredictorModels = new PredictorModel[length]; for (int i = 0; i < PredictorModels.Length; i++) { string dir = @@ -464,7 +468,7 @@ protected SchemaBindablePipelineEnsembleBase(IHostEnvironment env, ModelLoadCont ? 
"PredictorModels" : Path.Combine(ctx.Directory, "PredictorModels"); using (var ent = ctx.Repository.OpenEntry(dir, $"PredictorModel_{i:000}")) - PredictorModels[i] = new PredictorModel(Host, ent.Stream); + PredictorModels[i] = new PredictorModelImpl(Host, ent.Stream); } length = ctx.Reader.ReadInt32(); @@ -509,7 +513,7 @@ public void Save(ModelSaveContext ctx) protected abstract void SaveCore(ModelSaveContext ctx); - public static SchemaBindablePipelineEnsembleBase Create(IHostEnvironment env, IPredictorModel[] predictors, IOutputCombiner combiner, string scoreColumnKind) + public static SchemaBindablePipelineEnsembleBase Create(IHostEnvironment env, PredictorModel[] predictors, IOutputCombiner combiner, string scoreColumnKind) { switch (scoreColumnKind) { @@ -555,9 +559,11 @@ public static SchemaBindablePipelineEnsembleBase Create(IHostEnvironment env, Mo } } - public abstract ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema); + ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema) => BindCore(env, schema); + + private protected abstract ISchemaBoundMapper BindCore(IHostEnvironment env, RoleMappedSchema schema); - public void SaveSummary(TextWriter writer, RoleMappedSchema schema) + void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) { for (int i = 0; i < PredictorModels.Length; i++) { @@ -577,7 +583,7 @@ public void SaveSummary(TextWriter writer, RoleMappedSchema schema) } // Checks that the predictors have matching label columns, and returns the number of classes in all predictors. 
- protected static int CheckLabelColumn(IHostEnvironment env, IPredictorModel[] models, bool isBinary) + protected static int CheckLabelColumn(IHostEnvironment env, PredictorModel[] models, bool isBinary) { Contracts.CheckValue(env, nameof(env)); env.CheckNonEmpty(models, nameof(models)); @@ -585,28 +591,28 @@ protected static int CheckLabelColumn(IHostEnvironment env, IPredictorModel[] mo var model = models[0]; var edv = new EmptyDataView(env, model.TransformModel.InputSchema); model.PrepareData(env, edv, out RoleMappedData rmd, out IPredictor pred); - var labelInfo = rmd.Schema.Label; - if (labelInfo == null) + if (!rmd.Schema.Label.HasValue) throw env.Except("Training schema for model 0 does not have a label column"); + var labelCol = rmd.Schema.Label.Value; - var labelType = rmd.Schema.Schema.GetColumnType(rmd.Schema.Label.Index); + var labelType = labelCol.Type; if (!labelType.IsKey) return CheckNonKeyLabelColumnCore(env, pred, models, isBinary, labelType); if (isBinary && labelType.KeyCount != 2) throw env.Except("Label is not binary"); var schema = rmd.Schema.Schema; - var mdType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, labelInfo.Index); + var mdType = labelCol.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; if (mdType == null || !mdType.IsKnownSizeVector) throw env.Except("Label column of type key must have a vector of key values metadata"); - return Utils.MarshalInvoke(CheckKeyLabelColumnCore, mdType.ItemType.RawType, env, models, labelType.AsKey, schema, labelInfo.Index, mdType); + return Utils.MarshalInvoke(CheckKeyLabelColumnCore, mdType.ItemType.RawType, env, models, (KeyType)labelType, schema, labelCol.Index, mdType); } // When the label column is not a key, we check that the number of classes is the same for all the predictors, by checking the // OutputType property of the IValueMapper. // If any of the predictors do not implement IValueMapper we throw an exception. Returns the class count. 
- private static int CheckNonKeyLabelColumnCore(IHostEnvironment env, IPredictor pred, IPredictorModel[] models, bool isBinary, ColumnType labelType) + private static int CheckNonKeyLabelColumnCore(IHostEnvironment env, IPredictor pred, PredictorModel[] models, bool isBinary, ColumnType labelType) { env.Assert(!labelType.IsKey); env.AssertNonEmpty(models); @@ -633,13 +639,13 @@ private static int CheckNonKeyLabelColumnCore(IHostEnvironment env, IPredictor p // Checks that all the label columns of the model have the same key type as their label column - including the same // cardinality and the same key values, and returns the cardinality of the label column key. - private static int CheckKeyLabelColumnCore(IHostEnvironment env, IPredictorModel[] models, KeyType labelType, ISchema schema, int labelIndex, ColumnType keyValuesType) + private static int CheckKeyLabelColumnCore(IHostEnvironment env, PredictorModel[] models, KeyType labelType, Schema schema, int labelIndex, ColumnType keyValuesType) where T : IEquatable { env.Assert(keyValuesType.ItemType.RawType == typeof(T)); env.AssertNonEmpty(models); var labelNames = default(VBuffer); - schema.GetMetadata(MetadataUtils.Kinds.KeyValues, labelIndex, ref labelNames); + schema[labelIndex].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref labelNames); var classCount = labelNames.Length; var curLabelNames = default(VBuffer); @@ -648,18 +654,19 @@ private static int CheckKeyLabelColumnCore(IHostEnvironment env, IPredictorMo var model = models[i]; var edv = new EmptyDataView(env, model.TransformModel.InputSchema); model.PrepareData(env, edv, out RoleMappedData rmd, out IPredictor pred); - var labelInfo = rmd.Schema.Label; - if (labelInfo == null) + var labelInfo = rmd.Schema.Label.HasValue; + if (!rmd.Schema.Label.HasValue) throw env.Except("Training schema for model {0} does not have a label column", i); + var labelCol = rmd.Schema.Label.Value; - var curLabelType = 
rmd.Schema.Schema.GetColumnType(rmd.Schema.Label.Index); - if (!labelType.Equals(curLabelType.AsKey)) + var curLabelType = labelCol.Type as KeyType; + if (!labelType.Equals(curLabelType)) throw env.Except("Label column of model {0} has different type than model 0", i); - var mdType = rmd.Schema.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, labelInfo.Index); + var mdType = labelCol.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; if (!mdType.Equals(keyValuesType)) throw env.Except("Label column of model {0} has different key value type than model 0", i); - rmd.Schema.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, labelInfo.Index, ref curLabelNames); + labelCol.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref curLabelNames); if (!AreEqual(in labelNames, in curLabelNames)) throw env.Except("Label of model {0} has different values than model 0", i); } @@ -688,7 +695,7 @@ private static bool AreEqual(in VBuffer v1, in VBuffer v2) /// - If neither of those interfaces are implemented then the value is a string containing the name of the type of model. 
/// /// - public IList> GetSummaryInKeyValuePairs(RoleMappedSchema schema) + IList> ICanGetSummaryInKeyValuePairs.GetSummaryInKeyValuePairs(RoleMappedSchema schema) { Host.CheckValueOrNull(schema); diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs index 52d5dd9a58..f6ac5da19e 100644 --- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs +++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs @@ -6,7 +6,7 @@ using System.Collections.Concurrent; using System.Collections.Generic; -namespace Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure +namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure { public abstract class BaseDisagreementDiversityMeasure : IDiversityMeasure { diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs index d36f00ea9d..c5af6268c3 100644 --- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs +++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs @@ -3,14 +3,14 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure; +using Microsoft.ML; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.DiversityMeasure; [assembly: LoadableClass(typeof(DisagreementDiversityMeasure), null, typeof(SignatureEnsembleDiversityMeasure), DisagreementDiversityMeasure.UserName, DisagreementDiversityMeasure.LoadName)] -namespace Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure +namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure { public class DisagreementDiversityMeasure : BaseDisagreementDiversityMeasure, IBinaryDiversityMeasure { diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs index 1ee03a9489..621671c2f0 100644 --- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs +++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs @@ -4,7 +4,7 @@ using System; -namespace Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure +namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure { public class ModelDiversityMetric { diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs index 424e2a328d..1645ae9a8a 100644 --- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs +++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs @@ -3,16 +3,16 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure; -using Microsoft.ML.Runtime.Numeric; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.DiversityMeasure; +using Microsoft.ML.Numeric; [assembly: LoadableClass(typeof(MultiDisagreementDiversityMeasure), null, typeof(SignatureEnsembleDiversityMeasure), DisagreementDiversityMeasure.UserName, MultiDisagreementDiversityMeasure.LoadName)] -namespace Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure +namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure { public class MultiDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure>, IMulticlassDiversityMeasure { diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs index 68f882d19c..67b9eb0d21 100644 --- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs +++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs @@ -3,14 +3,14 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure; +using Microsoft.ML; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.DiversityMeasure; [assembly: LoadableClass(typeof(RegressionDisagreementDiversityMeasure), null, typeof(SignatureEnsembleDiversityMeasure), DisagreementDiversityMeasure.UserName, RegressionDisagreementDiversityMeasure.LoadName)] -namespace Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure +namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure { public class RegressionDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure, IRegressionDiversityMeasure { diff --git a/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/AllFeatureSelector.cs b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/AllFeatureSelector.cs index 84a70a6ee4..f2ffadf877 100644 --- a/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/AllFeatureSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/AllFeatureSelector.cs @@ -2,17 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.FeatureSelector; +using System; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.FeatureSelector; [assembly: LoadableClass(typeof(AllFeatureSelector), null, typeof(SignatureEnsembleFeatureSelector), AllFeatureSelector.UserName, AllFeatureSelector.LoadName)] -namespace Microsoft.ML.Runtime.Ensemble.Selector.FeatureSelector +namespace Microsoft.ML.Ensemble.Selector.FeatureSelector { - public sealed class AllFeatureSelector : IFeatureSelector + internal sealed class AllFeatureSelector : IFeatureSelector { public const string UserName = "All Feature Selector"; public const string LoadName = "AllFeatureSelector"; @@ -21,7 +22,7 @@ public AllFeatureSelector(IHostEnvironment env) { } - public Subset SelectFeatures(RoleMappedData data, IRandom rand) + public Subset SelectFeatures(RoleMappedData data, Random rand) { return new Subset(data); } diff --git a/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs index c0c9b8968f..e28911c981 100644 --- a/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs @@ -4,20 +4,20 @@ using System; using System.Collections; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.FeatureSelector; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Training; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.FeatureSelector; +using 
Microsoft.ML.EntryPoints; +using Microsoft.ML.Training; [assembly: LoadableClass(typeof(RandomFeatureSelector), typeof(RandomFeatureSelector.Arguments), typeof(SignatureEnsembleFeatureSelector), RandomFeatureSelector.UserName, RandomFeatureSelector.LoadName)] -namespace Microsoft.ML.Runtime.Ensemble.Selector.FeatureSelector +namespace Microsoft.ML.Ensemble.Selector.FeatureSelector { - public class RandomFeatureSelector : IFeatureSelector + internal class RandomFeatureSelector : IFeatureSelector { public const string UserName = "Random Feature Selector"; public const string LoadName = "RandomFeatureSelector"; @@ -45,12 +45,12 @@ public RandomFeatureSelector(IHostEnvironment env, Arguments args) "The feature proportion for RandomFeatureSelector should be greater than 0 and lesser than 1"); } - public Subset SelectFeatures(RoleMappedData data, IRandom rand) + public Subset SelectFeatures(RoleMappedData data, Random rand) { _host.CheckValue(data, nameof(data)); data.CheckFeatureFloatVector(); - var type = data.Schema.Feature.Type; + var type = data.Schema.Feature.Value.Type; int len = type.VectorSize; var features = new BitArray(len); for (int j = 0; j < len; j++) diff --git a/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs index 01589c4714..4209761c94 100644 --- a/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs +++ b/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs @@ -5,11 +5,11 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.Selector.DiversityMeasure; +using Microsoft.ML.EntryPoints; -namespace Microsoft.ML.Runtime.Ensemble.Selector +namespace Microsoft.ML.Ensemble.Selector { public interface IDiversityMeasure { diff --git 
a/src/Microsoft.ML.Ensemble/Selector/IFeatureSelector.cs b/src/Microsoft.ML.Ensemble/Selector/IFeatureSelector.cs index 99e90c5c01..e4eb986294 100644 --- a/src/Microsoft.ML.Ensemble/Selector/IFeatureSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/IFeatureSelector.cs @@ -2,20 +2,21 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; +using System; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; -namespace Microsoft.ML.Runtime.Ensemble.Selector +namespace Microsoft.ML.Ensemble.Selector { - public interface IFeatureSelector + internal interface IFeatureSelector { - Subset SelectFeatures(RoleMappedData data, IRandom rand); + Subset SelectFeatures(RoleMappedData data, Random rand); } public delegate void SignatureEnsembleFeatureSelector(); [TlcModule.ComponentKind("EnsembleFeatureSelector")] - public interface ISupportFeatureSelectorFactory : IComponentFactory + internal interface ISupportFeatureSelectorFactory : IComponentFactory { } } diff --git a/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs index 96e9f5b886..64ecf45bce 100644 --- a/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs @@ -2,14 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; using System; using System.Collections.Generic; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; -namespace Microsoft.ML.Runtime.Ensemble.Selector +namespace Microsoft.ML.Ensemble.Selector { - public interface ISubModelSelector + internal interface ISubModelSelector { IList>> Prune(IList>> models); @@ -19,33 +19,33 @@ void CalculateMetrics(FeatureSubsetModel> model, IS Single ValidationDatasetProportion { get; } } - public interface IRegressionSubModelSelector : ISubModelSelector + internal interface IRegressionSubModelSelector : ISubModelSelector { } - public interface IBinarySubModelSelector : ISubModelSelector + internal interface IBinarySubModelSelector : ISubModelSelector { } - public interface IMulticlassSubModelSelector : ISubModelSelector> + internal interface IMulticlassSubModelSelector : ISubModelSelector> { } - public delegate void SignatureEnsembleSubModelSelector(); + internal delegate void SignatureEnsembleSubModelSelector(); [TlcModule.ComponentKind("EnsembleMulticlassSubModelSelector")] - public interface ISupportMulticlassSubModelSelectorFactory : IComponentFactory + internal interface ISupportMulticlassSubModelSelectorFactory : IComponentFactory { } [TlcModule.ComponentKind("EnsembleBinarySubModelSelector")] - public interface ISupportBinarySubModelSelectorFactory: IComponentFactory + internal interface ISupportBinarySubModelSelectorFactory: IComponentFactory { } [TlcModule.ComponentKind("EnsembleRegressionSubModelSelector")] - public interface ISupportRegressionSubModelSelectorFactory : IComponentFactory + internal interface ISupportRegressionSubModelSelectorFactory : IComponentFactory { } diff --git a/src/Microsoft.ML.Ensemble/Selector/ISubsetSelector.cs b/src/Microsoft.ML.Ensemble/Selector/ISubsetSelector.cs index 2a0f088219..6ba7002508 100644 --- a/src/Microsoft.ML.Ensemble/Selector/ISubsetSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/ISubsetSelector.cs 
@@ -4,23 +4,23 @@ using System; using System.Collections.Generic; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; -namespace Microsoft.ML.Runtime.Ensemble.Selector +namespace Microsoft.ML.Ensemble.Selector { - public interface ISubsetSelector + internal interface ISubsetSelector { void Initialize(RoleMappedData data, int size, int batchSize, Single validationDatasetProportion); - IEnumerable GetBatches(IRandom rand); - IEnumerable GetSubsets(Batch batch, IRandom rand); + IEnumerable GetBatches(Random rand); + IEnumerable GetSubsets(Batch batch, Random rand); RoleMappedData GetTestData(Subset subset, Batch batch); } public delegate void SignatureEnsembleDataSelector(); [TlcModule.ComponentKind("EnsembleSubsetSelector")] - public interface ISupportSubsetSelectorFactory : IComponentFactory + internal interface ISupportSubsetSelectorFactory : IComponentFactory { } } diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelector.cs index 4196ab3558..f88df3bfee 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelector.cs @@ -3,15 +3,15 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector; +using Microsoft.ML; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.SubModelSelector; [assembly: LoadableClass(typeof(AllSelector), null, typeof(SignatureEnsembleSubModelSelector), AllSelector.UserName, AllSelector.LoadName)] -namespace Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Ensemble.Selector.SubModelSelector { - public class AllSelector : BaseSubModelSelector, IBinarySubModelSelector, IRegressionSubModelSelector + internal sealed class AllSelector : BaseSubModelSelector, IBinarySubModelSelector, IRegressionSubModelSelector { public const string UserName = "All Selector"; public const string LoadName = "AllSelector"; diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelectorMultiClass.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelectorMultiClass.cs index 6c82fc25f5..2158b05733 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelectorMultiClass.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelectorMultiClass.cs @@ -3,17 +3,17 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.SubModelSelector; [assembly: LoadableClass(typeof(AllSelectorMultiClass), null, typeof(SignatureEnsembleSubModelSelector), AllSelectorMultiClass.UserName, AllSelectorMultiClass.LoadName)] -namespace Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Ensemble.Selector.SubModelSelector { - public class AllSelectorMultiClass : BaseSubModelSelector>, IMulticlassSubModelSelector + internal sealed class AllSelectorMultiClass : BaseSubModelSelector>, IMulticlassSubModelSelector { public const string UserName = "All Selector"; public const string LoadName = "AllSelectorMultiClass"; diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs index 1e51044fb9..cc6292e9b2 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs @@ -6,11 +6,11 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; -using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.CommandLine; -namespace Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Ensemble.Selector.SubModelSelector { - public abstract class BaseBestPerformanceSelector : SubModelDataSelector + internal abstract class BaseBestPerformanceSelector : SubModelDataSelector { protected abstract string MetricName { get; } diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs index 
ac2e85dda6..98d5c47c02 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs @@ -5,15 +5,14 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Training; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.Selector.DiversityMeasure; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Training; -namespace Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Ensemble.Selector.SubModelSelector { - public abstract class BaseDiverseSelector : SubModelDataSelector + internal abstract class BaseDiverseSelector : SubModelDataSelector where TDiversityMetric : class, IDiversityMeasure { public abstract class DiverseSelectorArguments : ArgumentsBase diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs index 87ab03bb7c..e4df5351da 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs @@ -6,12 +6,10 @@ using System.Collections.Generic; using System.Linq; using Microsoft.ML.Data; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -namespace Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Ensemble.Selector.SubModelSelector { - public abstract class BaseSubModelSelector : ISubModelSelector + internal abstract class BaseSubModelSelector : ISubModelSelector { protected readonly IHost Host; @@ -106,27 +104,27 @@ public virtual void CalculateMetrics(FeatureSubsetModel t == NumberType.Float); - if (probInfo 
!= null) - yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probInfo.Name); + var probCol = EvaluateUtils.GetOptAuxScoreColumn(Host, scoredSchema, null, nameof(BinaryClassifierMamlEvaluator.Arguments.ProbabilityColumn), + scoreCol.Index, MetadataUtils.Const.ScoreValueKind.Probability, NumberType.Float.Equals); + if (probCol.HasValue) + yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probCol.Value.Name); yield break; case PredictionKind.Regression: - yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, testSchema.Label.Name); - scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, scoredSchema, null, nameof(RegressionMamlEvaluator.Arguments.ScoreColumn), + yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, testSchema.Label.Value.Name); + scoreCol = EvaluateUtils.GetScoreColumn(Host, scoredSchema, null, nameof(RegressionMamlEvaluator.Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.Regression); - yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreInfo.Name); + yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreCol.Name); yield break; case PredictionKind.MultiClassClassification: - yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, testSchema.Label.Name); - scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, scoredSchema, null, nameof(MultiClassMamlEvaluator.Arguments.ScoreColumn), + yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, testSchema.Label.Value.Name); + scoreCol = EvaluateUtils.GetScoreColumn(Host, scoredSchema, null, nameof(MultiClassMamlEvaluator.Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.MultiClassClassification); - yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreInfo.Name); + yield return 
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreCol.Name); yield break; default: throw Host.Except("Unrecognized prediction kind '{0}'", PredictionKind); diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorBinary.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorBinary.cs index 75d25df9d6..cb56769b79 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorBinary.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorBinary.cs @@ -5,23 +5,23 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.CommandLine; using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure; -using Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.DiversityMeasure; +using Microsoft.ML.Ensemble.Selector.SubModelSelector; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; [assembly: LoadableClass(typeof(BestDiverseSelectorBinary), typeof(BestDiverseSelectorBinary.Arguments), typeof(SignatureEnsembleSubModelSelector), BestDiverseSelectorBinary.UserName, BestDiverseSelectorBinary.LoadName)] -namespace Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Ensemble.Selector.SubModelSelector { using TScalarPredictor = IPredictorProducing; - public sealed class BestDiverseSelectorBinary : BaseDiverseSelector, IBinarySubModelSelector + internal sealed class BestDiverseSelectorBinary : BaseDiverseSelector, IBinarySubModelSelector { public const string UserName = "Best Diverse Selector"; public const string 
LoadName = "BestDiverseSelector"; diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorMultiClass.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorMultiClass.cs index 8856982f00..d5e6b9ae91 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorMultiClass.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorMultiClass.cs @@ -5,24 +5,24 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure; -using Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.DiversityMeasure; +using Microsoft.ML.Ensemble.Selector.SubModelSelector; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; [assembly: LoadableClass(typeof(BestDiverseSelectorMultiClass), typeof(BestDiverseSelectorMultiClass.Arguments), typeof(SignatureEnsembleSubModelSelector), BestDiverseSelectorMultiClass.UserName, BestDiverseSelectorMultiClass.LoadName)] -namespace Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Ensemble.Selector.SubModelSelector { using TVectorPredictor = IPredictorProducing>; - public sealed class BestDiverseSelectorMultiClass : BaseDiverseSelector, IDiversityMeasure>>, IMulticlassSubModelSelector + internal sealed class BestDiverseSelectorMultiClass : BaseDiverseSelector, IDiversityMeasure>>, IMulticlassSubModelSelector { public const string UserName = "Best Diverse Selector"; public const 
string LoadName = "BestDiverseSelectorMultiClass"; diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorRegression.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorRegression.cs index cbd2d47330..a276995a9a 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorRegression.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorRegression.cs @@ -5,23 +5,23 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.CommandLine; using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure; -using Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.DiversityMeasure; +using Microsoft.ML.Ensemble.Selector.SubModelSelector; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; [assembly: LoadableClass(typeof(BestDiverseSelectorRegression), typeof(BestDiverseSelectorRegression.Arguments), typeof(SignatureEnsembleSubModelSelector), BestDiverseSelectorRegression.UserName, BestDiverseSelectorRegression.LoadName)] -namespace Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Ensemble.Selector.SubModelSelector { using TScalarPredictor = IPredictorProducing; - public sealed class BestDiverseSelectorRegression : BaseDiverseSelector, IRegressionSubModelSelector + internal sealed class BestDiverseSelectorRegression : BaseDiverseSelector, IRegressionSubModelSelector { public const string UserName = "Best Diverse Selector"; public const string LoadName = "BestDiverseSelectorRegression"; diff --git 
a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceRegressionSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceRegressionSelector.cs index 46f13e9cd1..be7b7e8f35 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceRegressionSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceRegressionSelector.cs @@ -3,20 +3,20 @@ // See the LICENSE file in the project root for more information. using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.SubModelSelector; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; [assembly: LoadableClass(typeof(BestPerformanceRegressionSelector), typeof(BestPerformanceRegressionSelector.Arguments), typeof(SignatureEnsembleSubModelSelector), BestPerformanceRegressionSelector.UserName, BestPerformanceRegressionSelector.LoadName)] -namespace Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Ensemble.Selector.SubModelSelector { - public sealed class BestPerformanceRegressionSelector : BaseBestPerformanceSelector, IRegressionSubModelSelector + internal sealed class BestPerformanceRegressionSelector : BaseBestPerformanceSelector, IRegressionSubModelSelector { [TlcModule.Component(Name = LoadName, FriendlyName = UserName)] public sealed class Arguments : ArgumentsBase, ISupportRegressionSubModelSelectorFactory diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelector.cs 
b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelector.cs index 76742ad0ec..9b24798276 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelector.cs @@ -3,20 +3,20 @@ // See the LICENSE file in the project root for more information. using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.SubModelSelector; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; [assembly: LoadableClass(typeof(BestPerformanceSelector), typeof(BestPerformanceSelector.Arguments), typeof(SignatureEnsembleSubModelSelector), BestPerformanceSelector.UserName, BestPerformanceSelector.LoadName)] -namespace Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Ensemble.Selector.SubModelSelector { - public sealed class BestPerformanceSelector : BaseBestPerformanceSelector, IBinarySubModelSelector + internal sealed class BestPerformanceSelector : BaseBestPerformanceSelector, IBinarySubModelSelector { [TlcModule.Component(Name = LoadName, FriendlyName = UserName)] public sealed class Arguments : ArgumentsBase, ISupportBinarySubModelSelectorFactory diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelectorMultiClass.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelectorMultiClass.cs index 0a9b9ac497..36f9635c87 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelectorMultiClass.cs +++ 
b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelectorMultiClass.cs @@ -3,20 +3,20 @@ // See the LICENSE file in the project root for more information. using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.SubModelSelector; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; [assembly: LoadableClass(typeof(BestPerformanceSelectorMultiClass), typeof(BestPerformanceSelectorMultiClass.Arguments), typeof(SignatureEnsembleSubModelSelector), BestPerformanceSelectorMultiClass.UserName, BestPerformanceSelectorMultiClass.LoadName)] -namespace Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Ensemble.Selector.SubModelSelector { - public class BestPerformanceSelectorMultiClass : BaseBestPerformanceSelector>, IMulticlassSubModelSelector + internal sealed class BestPerformanceSelectorMultiClass : BaseBestPerformanceSelector>, IMulticlassSubModelSelector { [TlcModule.Component(Name = LoadName, FriendlyName = UserName)] public sealed class Arguments : ArgumentsBase, ISupportMulticlassSubModelSelectorFactory diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/SubModelDataSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/SubModelDataSelector.cs index 5953b30d97..0ba6497f25 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/SubModelDataSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/SubModelDataSelector.cs @@ -3,12 +3,12 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Internal.Internallearn; -namespace Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Ensemble.Selector.SubModelSelector { - public abstract class SubModelDataSelector : BaseSubModelSelector + internal abstract class SubModelDataSelector : BaseSubModelSelector { public abstract class ArgumentsBase { diff --git a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/AllInstanceSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/AllInstanceSelector.cs index 98e6eeb8eb..43203e9ca9 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/AllInstanceSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/AllInstanceSelector.cs @@ -2,20 +2,21 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections.Generic; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.SubsetSelector; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.SubsetSelector; +using Microsoft.ML.EntryPoints; [assembly: LoadableClass(typeof(AllInstanceSelector), typeof(AllInstanceSelector.Arguments), typeof(SignatureEnsembleDataSelector), AllInstanceSelector.UserName, AllInstanceSelector.LoadName)] [assembly: EntryPointModule(typeof(AllInstanceSelector))] -namespace Microsoft.ML.Runtime.Ensemble.Selector.SubsetSelector +namespace Microsoft.ML.Ensemble.Selector.SubsetSelector { - public sealed class AllInstanceSelector : BaseSubsetSelector + internal sealed class AllInstanceSelector : BaseSubsetSelector { public const string UserName = "All Instance Selector"; public const string LoadName = "AllInstanceSelector"; @@ -31,7 
+32,7 @@ public AllInstanceSelector(IHostEnvironment env, Arguments args) { } - public override IEnumerable GetSubsets(Batch batch, IRandom rand) + public override IEnumerable GetSubsets(Batch batch, Random rand) { for (int i = 0; i < Size; i++) yield return FeatureSelector.SelectFeatures(batch.TrainInstances, rand); diff --git a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BaseSubsetSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BaseSubsetSelector.cs index 78564618db..8e163bee6b 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BaseSubsetSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BaseSubsetSelector.cs @@ -2,16 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Transforms; using System; using System.Collections.Generic; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.EntryPoints; +using Microsoft.ML.Transforms; -namespace Microsoft.ML.Runtime.Ensemble.Selector.SubsetSelector +namespace Microsoft.ML.Ensemble.Selector.SubsetSelector { - public abstract class BaseSubsetSelector : ISubsetSelector + internal abstract class BaseSubsetSelector : ISubsetSelector where TArgs : BaseSubsetSelector.ArgumentsBase { public abstract class ArgumentsBase @@ -52,9 +52,9 @@ public void Initialize(RoleMappedData data, int size, int batchSize, Single vali ValidationDatasetProportion = validationDatasetProportion; } - public abstract IEnumerable GetSubsets(Batch batch, IRandom rand); + public abstract IEnumerable GetSubsets(Batch batch, Random rand); - public IEnumerable GetBatches(IRandom rand) + public IEnumerable GetBatches(Random rand) { Host.Assert(Data != null, "Must call Initialize first!"); Host.AssertValue(rand); diff --git 
a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs index fd045f71a8..149fb91159 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs @@ -2,12 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections.Generic; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.SubsetSelector; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.SubsetSelector; +using Microsoft.ML.EntryPoints; using Microsoft.ML.Transforms; [assembly: LoadableClass(typeof(BootstrapSelector), typeof(BootstrapSelector.Arguments), @@ -15,9 +16,9 @@ [assembly: EntryPointModule(typeof(BootstrapSelector))] -namespace Microsoft.ML.Runtime.Ensemble.Selector.SubsetSelector +namespace Microsoft.ML.Ensemble.Selector.SubsetSelector { - public sealed class BootstrapSelector : BaseSubsetSelector + internal sealed class BootstrapSelector : BaseSubsetSelector { public const string UserName = "Bootstrap Selector"; public const string LoadName = "BootstrapSelector"; @@ -41,7 +42,7 @@ public BootstrapSelector(IHostEnvironment env, Arguments args) { } - public override IEnumerable GetSubsets(Batch batch, IRandom rand) + public override IEnumerable GetSubsets(Batch batch, Random rand) { for (int i = 0; i < Size; i++) { diff --git a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs index 0768b2a79e..af7b93a0b5 100644 --- 
a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs @@ -2,23 +2,23 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.SubsetSelector; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Transforms; using System; using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.SubsetSelector; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Transforms; [assembly: LoadableClass(typeof(RandomPartitionSelector), typeof(RandomPartitionSelector.Arguments), typeof(SignatureEnsembleDataSelector), RandomPartitionSelector.UserName, RandomPartitionSelector.LoadName)] [assembly: EntryPointModule(typeof(RandomPartitionSelector))] -namespace Microsoft.ML.Runtime.Ensemble.Selector.SubsetSelector +namespace Microsoft.ML.Ensemble.Selector.SubsetSelector { - public sealed class RandomPartitionSelector : BaseSubsetSelector + internal sealed class RandomPartitionSelector : BaseSubsetSelector { public const string UserName = "Random Partition Selector"; public const string LoadName = "RandomPartitionSelector"; @@ -34,7 +34,7 @@ public RandomPartitionSelector(IHostEnvironment env, Arguments args) { } - public override IEnumerable GetSubsets(Batch batch, IRandom rand) + public override IEnumerable GetSubsets(Batch batch, Random rand) { string name = Data.Data.Schema.GetTempColumnName(); var args = new GenerateNumberTransform.Arguments(); diff --git a/src/Microsoft.ML.Ensemble/Subset.cs b/src/Microsoft.ML.Ensemble/Subset.cs index 77183de6f3..743be33df6 100644 --- a/src/Microsoft.ML.Ensemble/Subset.cs +++ 
b/src/Microsoft.ML.Ensemble/Subset.cs @@ -3,11 +3,11 @@ // See the LICENSE file in the project root for more information. using System.Collections; -using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data; -namespace Microsoft.ML.Runtime.Ensemble +namespace Microsoft.ML.Ensemble { - public sealed class Subset + internal sealed class Subset { public readonly RoleMappedData Data; public readonly BitArray SelectedFeatures; diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs index 27fdf19d62..bd74ae8b27 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs @@ -5,14 +5,13 @@ using System; using System.Collections.Generic; using System.Linq; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Ensemble; using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Ensemble; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Trainers.Online; [assembly: LoadableClass(EnsembleTrainer.Summary, typeof(EnsembleTrainer), typeof(EnsembleTrainer.Arguments), @@ -22,7 +21,7 @@ [assembly: LoadableClass(typeof(EnsembleTrainer), typeof(EnsembleTrainer.Arguments), typeof(SignatureModelCombiner), "Binary Classification Ensemble Model Combiner", EnsembleTrainer.LoadNameValue, "pe", "ParallelEnsemble")] -namespace Microsoft.ML.Runtime.Ensemble +namespace Microsoft.ML.Ensemble { using TDistPredictor = IDistPredictorProducing; using TScalarPredictor = IPredictorProducing; @@ -59,7 +58,7 @@ public Arguments() BasePredictors = new[] { 
ComponentFactoryUtils.CreateFromFunction( - env => new LinearSvm(env)) + env => new LinearSvmTrainer(env)) }; } } @@ -85,8 +84,8 @@ private EnsembleTrainer(IHostEnvironment env, Arguments args, PredictionKind pre private protected override TScalarPredictor CreatePredictor(List> models) { if (models.All(m => m.Predictor is TDistPredictor)) - return new EnsembleDistributionPredictor(Host, PredictionKind, CreateModels(models), Combiner); - return new EnsemblePredictor(Host, PredictionKind, CreateModels(models), Combiner); + return new EnsembleDistributionModelParameters(Host, PredictionKind, CreateModels(models), Combiner); + return new EnsembleModelParameters(Host, PredictionKind, CreateModels(models), Combiner); } public IPredictor CombineModels(IEnumerable models) @@ -98,12 +97,12 @@ public IPredictor CombineModels(IEnumerable models) if (p is TDistPredictor) { Host.CheckParam(models.All(m => m is TDistPredictor), nameof(models)); - return new EnsembleDistributionPredictor(Host, p.PredictionKind, + return new EnsembleDistributionModelParameters(Host, p.PredictionKind, models.Select(k => new FeatureSubsetModel((TDistPredictor)k)).ToArray(), combiner); } Host.CheckParam(models.All(m => m is TScalarPredictor), nameof(models)); - return new EnsemblePredictor(Host, p.PredictionKind, + return new EnsembleModelParameters(Host, p.PredictionKind, models.Select(k => new FeatureSubsetModel((TScalarPredictor)k)).ToArray(), combiner); } } diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs similarity index 75% rename from src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs rename to src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs index 5a92512e6d..f19e96fb47 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs @@ -6,27 +6,27 @@ using 
System.Collections.Generic; using System.Linq; using System.Threading.Tasks; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; // These are for deserialization from a model repository. -[assembly: LoadableClass(typeof(EnsembleDistributionPredictor), null, typeof(SignatureLoadModel), - EnsembleDistributionPredictor.UserName, EnsembleDistributionPredictor.LoaderSignature)] +[assembly: LoadableClass(typeof(EnsembleDistributionModelParameters), null, typeof(SignatureLoadModel), + EnsembleDistributionModelParameters.UserName, EnsembleDistributionModelParameters.LoaderSignature)] -namespace Microsoft.ML.Runtime.Ensemble +namespace Microsoft.ML.Ensemble { using TDistPredictor = IDistPredictorProducing; - public sealed class EnsembleDistributionPredictor : EnsemblePredictorBase, + public sealed class EnsembleDistributionModelParameters : EnsembleModelParametersBase, TDistPredictor, IValueMapperDist { - public const string UserName = "Ensemble Distribution Executor"; - public const string LoaderSignature = "EnsemDbExec"; - public const string RegistrationName = "EnsembleDistributionPredictor"; + internal const string UserName = "Ensemble Distribution Executor"; + internal const string LoaderSignature = "EnsemDbExec"; + internal const string RegistrationName = "EnsembleDistributionPredictor"; private static VersionInfo GetVersionInfo() { @@ -38,35 +38,45 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010003, verWeCanReadBack: 0x00010002, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(EnsembleDistributionPredictor).Assembly.FullName); + loaderAssemblyName: 
typeof(EnsembleDistributionModelParameters).Assembly.FullName); } private readonly Single[] _averagedWeights; private readonly Median _probabilityCombiner; private readonly IValueMapperDist[] _mappers; - public ColumnType InputType { get; } - public ColumnType OutputType => NumberType.Float; - public ColumnType DistType => NumberType.Float; + private readonly ColumnType _inputType; + + ColumnType IValueMapper.InputType => _inputType; + ColumnType IValueMapper.OutputType => NumberType.Float; + ColumnType IValueMapperDist.DistType => NumberType.Float; public override PredictionKind PredictionKind { get; } - internal EnsembleDistributionPredictor(IHostEnvironment env, PredictionKind kind, + /// + /// Instantiate new ensemble model from existing sub-models. + /// + /// The host environment. + /// The prediction kind + /// Array of sub-models that you want to ensemble together. + /// The combiner class to use to ensemble the models. + /// The weights assigned to each model to be ensembled. + public EnsembleDistributionModelParameters(IHostEnvironment env, PredictionKind kind, FeatureSubsetModel[] models, IOutputCombiner combiner, Single[] weights = null) : base(env, RegistrationName, models, combiner, weights) { PredictionKind = kind; _probabilityCombiner = new Median(env); - InputType = InitializeMappers(out _mappers); + _inputType = InitializeMappers(out _mappers); ComputeAveragedWeights(out _averagedWeights); } - private EnsembleDistributionPredictor(IHostEnvironment env, ModelLoadContext ctx) + private EnsembleDistributionModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx) { PredictionKind = (PredictionKind)ctx.Reader.ReadInt32(); _probabilityCombiner = new Median(env); - InputType = InitializeMappers(out _mappers); + _inputType = InitializeMappers(out _mappers); ComputeAveragedWeights(out _averagedWeights); } @@ -101,15 +111,15 @@ private bool IsValid(IValueMapperDist mapper) && mapper.DistType == NumberType.Float; } - 
public static EnsembleDistributionPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static EnsembleDistributionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new EnsembleDistributionPredictor(env, ctx); + return new EnsembleDistributionModelParameters(env, ctx); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -119,7 +129,7 @@ protected override void SaveCore(ModelSaveContext ctx) ctx.Writer.Write((int)PredictionKind); } - public ValueMapper GetMapper() + ValueMapper IValueMapper.GetMapper() { Host.Check(typeof(TIn) == typeof(VBuffer)); Host.Check(typeof(TOut) == typeof(Single)); @@ -132,8 +142,8 @@ public ValueMapper GetMapper() ValueMapper, Single> del = (in VBuffer src, ref Single dst) => { - if (InputType.VectorSize > 0) - Host.Check(src.Length == InputType.VectorSize); + if (_inputType.VectorSize > 0) + Host.Check(src.Length == _inputType.VectorSize); var tmp = src; Parallel.For(0, maps.Length, i => @@ -155,7 +165,7 @@ public ValueMapper GetMapper() return (ValueMapper)(Delegate)del; } - public ValueMapper GetMapper() + ValueMapper IValueMapperDist.GetMapper() { Host.Check(typeof(TIn) == typeof(VBuffer)); Host.Check(typeof(TOut) == typeof(Single)); @@ -170,8 +180,8 @@ public ValueMapper GetMapper() ValueMapper, Single, Single> del = (in VBuffer src, ref Single score, ref Single prob) => { - if (InputType.VectorSize > 0) - Host.Check(src.Length == InputType.VectorSize); + if (_inputType.VectorSize > 0) + Host.Check(src.Length == _inputType.VectorSize); var tmp = src; Parallel.For(0, maps.Length, i => diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs similarity index 64% rename from 
src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs rename to src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs index b69b76f40e..cd7e82b52f 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs @@ -4,27 +4,30 @@ using System; using System.Threading.Tasks; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Model; -[assembly: LoadableClass(typeof(EnsemblePredictor), null, typeof(SignatureLoadModel), EnsemblePredictor.UserName, - EnsemblePredictor.LoaderSignature)] +[assembly: LoadableClass(typeof(EnsembleModelParameters), null, typeof(SignatureLoadModel), EnsembleModelParameters.UserName, + EnsembleModelParameters.LoaderSignature)] -[assembly: EntryPointModule(typeof(EnsemblePredictor))] +[assembly: EntryPointModule(typeof(EnsembleModelParameters))] -namespace Microsoft.ML.Runtime.Ensemble +namespace Microsoft.ML.Ensemble { using TScalarPredictor = IPredictorProducing; - public sealed class EnsemblePredictor : EnsemblePredictorBase, IValueMapper + /// + /// A class for artifacts of ensembled models. 
+ /// + public sealed class EnsembleModelParameters : EnsembleModelParametersBase, IValueMapper { - public const string UserName = "Ensemble Executor"; - public const string LoaderSignature = "EnsembleFloatExec"; - public const string RegistrationName = "EnsemblePredictor"; + internal const string UserName = "Ensemble Executor"; + internal const string LoaderSignature = "EnsembleFloatExec"; + internal const string RegistrationName = "EnsemblePredictor"; private static VersionInfo GetVersionInfo() { @@ -36,28 +39,37 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010003, verWeCanReadBack: 0x00010002, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(EnsemblePredictor).Assembly.FullName); + loaderAssemblyName: typeof(EnsembleModelParameters).Assembly.FullName); } private readonly IValueMapper[] _mappers; - public ColumnType InputType { get; } - public ColumnType OutputType => NumberType.Float; + private readonly ColumnType _inputType; + ColumnType IValueMapper.InputType => _inputType; + ColumnType IValueMapper.OutputType => NumberType.Float; public override PredictionKind PredictionKind { get; } - internal EnsemblePredictor(IHostEnvironment env, PredictionKind kind, + /// + /// Instantiate new ensemble model from existing sub-models. + /// + /// The host environment. + /// The prediction kind + /// Array of sub-models that you want to ensemble together. + /// The combiner class to use to ensemble the models. + /// The weights assigned to each model to be ensembled. 
+ public EnsembleModelParameters(IHostEnvironment env, PredictionKind kind, FeatureSubsetModel[] models, IOutputCombiner combiner, Single[] weights = null) : base(env, LoaderSignature, models, combiner, weights) { PredictionKind = kind; - InputType = InitializeMappers(out _mappers); + _inputType = InitializeMappers(out _mappers); } - private EnsemblePredictor(IHostEnvironment env, ModelLoadContext ctx) + private EnsembleModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx) { PredictionKind = (PredictionKind)ctx.Reader.ReadInt32(); - InputType = InitializeMappers(out _mappers); + _inputType = InitializeMappers(out _mappers); } private ColumnType InitializeMappers(out IValueMapper[] mappers) @@ -91,15 +103,15 @@ private bool IsValid(IValueMapper mapper) && mapper.OutputType == NumberType.Float; } - public static EnsemblePredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static EnsembleModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new EnsemblePredictor(env, ctx); + return new EnsembleModelParameters(env, ctx); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -109,7 +121,7 @@ protected override void SaveCore(ModelSaveContext ctx) ctx.Writer.Write((int)PredictionKind); } - public ValueMapper GetMapper() + ValueMapper IValueMapper.GetMapper() { Host.Check(typeof(TIn) == typeof(VBuffer)); Host.Check(typeof(TOut) == typeof(Single)); @@ -124,8 +136,8 @@ public ValueMapper GetMapper() ValueMapper, Single> del = (in VBuffer src, ref Single dst) => { - if (InputType.VectorSize > 0) - Host.Check(src.Length == InputType.VectorSize); + if (_inputType.VectorSize > 0) + Host.Check(src.Length == _inputType.VectorSize); var tmp = src; 
Parallel.For(0, maps.Length, i => diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs similarity index 88% rename from src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs rename to src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs index 032312acde..f4a1d6d72d 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs @@ -5,16 +5,16 @@ using System; using System.Collections.Generic; using System.IO; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; -namespace Microsoft.ML.Runtime.Ensemble +namespace Microsoft.ML.Ensemble { - public abstract class EnsemblePredictorBase : PredictorBase, - IPredictorProducing, ICanSaveInTextFormat, ICanSaveModel, ICanSaveSummary + public abstract class EnsembleModelParametersBase : ModelParametersBase, + IPredictorProducing, ICanSaveInTextFormat, ICanSaveSummary where TPredictor : class, IPredictorProducing { private const string SubPredictorFmt = "SubPredictor_{0:000}"; @@ -25,7 +25,7 @@ public abstract class EnsemblePredictorBase : PredictorBase private const uint VerOld = 0x00010002; - protected EnsemblePredictorBase(IHostEnvironment env, string name, FeatureSubsetModel[] models, + internal EnsembleModelParametersBase(IHostEnvironment env, string name, FeatureSubsetModel[] models, IOutputCombiner combiner, Single[] weights) : base(env, name) { @@ -38,7 +38,7 @@ protected EnsemblePredictorBase(IHostEnvironment env, string name, FeatureSubset Weights = weights; } - protected 
EnsemblePredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx) + protected EnsembleModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx) : base(env, name, ctx) { // *** Binary format *** @@ -86,7 +86,7 @@ protected EnsemblePredictorBase(IHostEnvironment env, string name, ModelLoadCont ctx.LoadModel, SignatureLoadModel>(Host, out Combiner, @"Combiner"); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); @@ -128,7 +128,7 @@ protected override void SaveCore(ModelSaveContext ctx) /// /// Output the INI model to a given writer /// - public void SaveAsText(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { using (var ch = Host.Start("SaveAsText")) { @@ -144,7 +144,7 @@ public void SaveAsText(TextWriter writer, RoleMappedSchema schema) /// /// Saves the model summary /// - public void SaveSummary(TextWriter writer, RoleMappedSchema schema) + void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) { for (int i = 0; i < Models.Length; i++) { diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs index a8fc896c5b..305e05c788 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs @@ -6,27 +6,28 @@ using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Ensemble.Selector.SubsetSelector; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Training; - 
-namespace Microsoft.ML.Runtime.Ensemble +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Ensemble.Selector.SubsetSelector; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Training; + +namespace Microsoft.ML.Ensemble { using Stopwatch = System.Diagnostics.Stopwatch; - public abstract class EnsembleTrainerBase : TrainerBase + internal abstract class EnsembleTrainerBase : TrainerBase where TPredictor : class, IPredictorProducing where TSelector : class, ISubModelSelector where TCombiner : class, IOutputCombiner { public abstract class ArgumentsBase : LearnerInputBaseWithLabel { +#pragma warning disable CS0649 // These are set via reflection. [Argument(ArgumentType.AtMostOnce, HelpText = "Number of models per batch. If not specified, will default to 50 if there is only one base predictor, " + "or the number of base predictors otherwise.", ShortName = "nm", SortOrder = 3)] @@ -54,6 +55,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel public bool ShowMetrics; internal abstract IComponentFactory>>[] GetPredictorFactories(); +#pragma warning restore CS0649 } private const int DefaultNumModels = 50; diff --git a/src/Microsoft.ML.Ensemble/Trainer/IModelCombiner.cs b/src/Microsoft.ML.Ensemble/Trainer/IModelCombiner.cs index 196df4dc54..85f55d8111 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/IModelCombiner.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/IModelCombiner.cs @@ -3,9 +3,8 @@ // See the LICENSE file in the project root for more information. 
using System.Collections.Generic; -using Microsoft.ML.Runtime; -namespace Microsoft.ML.Runtime.Ensemble +namespace Microsoft.ML.Ensemble { public delegate void SignatureModelCombiner(PredictionKind kind); diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs similarity index 72% rename from src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs rename to src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs index 6f931f3e9c..ecf6ed0a57 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs @@ -4,24 +4,24 @@ using System; using System.Threading.Tasks; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.Model; -[assembly: LoadableClass(typeof(EnsembleMultiClassPredictor), null, typeof(SignatureLoadModel), - EnsembleMultiClassPredictor.UserName, EnsembleMultiClassPredictor.LoaderSignature)] +[assembly: LoadableClass(typeof(EnsembleMultiClassModelParameters), null, typeof(SignatureLoadModel), + EnsembleMultiClassModelParameters.UserName, EnsembleMultiClassModelParameters.LoaderSignature)] -namespace Microsoft.ML.Runtime.Ensemble +namespace Microsoft.ML.Ensemble { using TVectorPredictor = IPredictorProducing>; - public sealed class EnsembleMultiClassPredictor : EnsemblePredictorBase>, IValueMapper + public sealed class EnsembleMultiClassModelParameters : EnsembleModelParametersBase>, IValueMapper { - public const string UserName = "Ensemble Multiclass Executor"; - public const string LoaderSignature = 
"EnsemMcExec"; - public const string RegistrationName = "EnsembleMultiClassPredictor"; + internal const string UserName = "Ensemble Multiclass Executor"; + internal const string LoaderSignature = "EnsemMcExec"; + internal const string RegistrationName = "EnsembleMultiClassPredictor"; private static VersionInfo GetVersionInfo() { @@ -33,24 +33,31 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010003, verWeCanReadBack: 0x00010002, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(EnsembleMultiClassPredictor).Assembly.FullName); + loaderAssemblyName: typeof(EnsembleMultiClassModelParameters).Assembly.FullName); } private readonly ColumnType _inputType; private readonly ColumnType _outputType; private readonly IValueMapper[] _mappers; - public ColumnType InputType { get { return _inputType; } } - public ColumnType OutputType { get { return _outputType; } } - - internal EnsembleMultiClassPredictor(IHostEnvironment env, FeatureSubsetModel[] models, + ColumnType IValueMapper.InputType => _inputType; + ColumnType IValueMapper.OutputType => _outputType; + + /// + /// Instantiate new ensemble model from existing sub-models. + /// + /// The host environment. + /// Array of sub-models that you want to ensemble together. + /// The combiner class to use to ensemble the models. + /// The weights assigned to each model to be ensembled. 
+ public EnsembleMultiClassModelParameters(IHostEnvironment env, FeatureSubsetModel[] models, IMultiClassOutputCombiner combiner, Single[] weights = null) : base(env, RegistrationName, models, combiner, weights) { InitializeMappers(out _mappers, out _inputType, out _outputType); } - private EnsembleMultiClassPredictor(IHostEnvironment env, ModelLoadContext ctx) + private EnsembleMultiClassModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx) { InitializeMappers(out _mappers, out _inputType, out _outputType); @@ -87,23 +94,23 @@ private void InitializeMappers(out IValueMapper[] mappers, out ColumnType inputT inputType = new VectorType(NumberType.Float); } - public static EnsembleMultiClassPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static EnsembleMultiClassModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new EnsembleMultiClassPredictor(env, ctx); + return new EnsembleMultiClassModelParameters(env, ctx); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); } - public override PredictionKind PredictionKind { get { return PredictionKind.MultiClassClassification; } } + public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; - public ValueMapper GetMapper() + ValueMapper IValueMapper.GetMapper() { Host.Check(typeof(TIn) == typeof(VBuffer)); Host.Check(typeof(TOut) == typeof(VBuffer)); diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs index 1961e6a785..7d13483f4a 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs +++ 
b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs @@ -5,15 +5,15 @@ using System; using System.Collections.Generic; using System.Linq; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble; using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Learners; [assembly: LoadableClass(MulticlassDataPartitionEnsembleTrainer.Summary, typeof(MulticlassDataPartitionEnsembleTrainer), typeof(MulticlassDataPartitionEnsembleTrainer.Arguments), @@ -24,14 +24,14 @@ [assembly: LoadableClass(typeof(MulticlassDataPartitionEnsembleTrainer), typeof(MulticlassDataPartitionEnsembleTrainer.Arguments), typeof(SignatureModelCombiner), "Multiclass Classification Ensemble Model Combiner", MulticlassDataPartitionEnsembleTrainer.LoadNameValue)] -namespace Microsoft.ML.Runtime.Ensemble +namespace Microsoft.ML.Ensemble { using TVectorPredictor = IPredictorProducing>; /// /// A generic ensemble classifier for multi-class classification /// internal sealed class MulticlassDataPartitionEnsembleTrainer : - EnsembleTrainerBase, EnsembleMultiClassPredictor, + EnsembleTrainerBase, EnsembleMultiClassModelParameters, IMulticlassSubModelSelector, IMultiClassOutputCombiner>, IModelCombiner { @@ -83,9 +83,9 @@ private MulticlassDataPartitionEnsembleTrainer(IHostEnvironment env, Arguments a public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; - private protected override EnsembleMultiClassPredictor CreatePredictor(List> 
models) + private protected override EnsembleMultiClassModelParameters CreatePredictor(List> models) { - return new EnsembleMultiClassPredictor(Host, CreateModels(models), Combiner as IMultiClassOutputCombiner); + return new EnsembleMultiClassModelParameters(Host, CreateModels(models), Combiner as IMultiClassOutputCombiner); } public IPredictor CombineModels(IEnumerable models) @@ -94,7 +94,7 @@ public IPredictor CombineModels(IEnumerable models) Host.CheckParam(models.All(m => m is TVectorPredictor), nameof(models)); var combiner = _outputCombiner.CreateComponent(Host); - var predictor = new EnsembleMultiClassPredictor(Host, + var predictor = new EnsembleMultiClassModelParameters(Host, models.Select(k => new FeatureSubsetModel((TVectorPredictor)k)).ToArray(), combiner); return predictor; diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs index 09d394d596..4545ce3c34 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs @@ -5,14 +5,14 @@ using System; using System.Collections.Generic; using System.Linq; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble; using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Ensemble; -using Microsoft.ML.Runtime.Ensemble.OutputCombiners; -using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Ensemble.OutputCombiners; +using Microsoft.ML.Ensemble.Selector; +using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Trainers.Online; [assembly: LoadableClass(typeof(RegressionEnsembleTrainer), typeof(RegressionEnsembleTrainer.Arguments), @@ -23,7 +23,7 @@ [assembly: 
LoadableClass(typeof(RegressionEnsembleTrainer), typeof(RegressionEnsembleTrainer.Arguments), typeof(SignatureModelCombiner), "Regression Ensemble Model Combiner", RegressionEnsembleTrainer.LoadNameValue)] -namespace Microsoft.ML.Runtime.Ensemble +namespace Microsoft.ML.Ensemble { using TScalarPredictor = IPredictorProducing; internal sealed class RegressionEnsembleTrainer : EnsembleTrainerBase> models) { - return new EnsemblePredictor(Host, PredictionKind, CreateModels(models), Combiner); + return new EnsembleModelParameters(Host, PredictionKind, CreateModels(models), Combiner); } public IPredictor CombineModels(IEnumerable models) @@ -90,7 +90,7 @@ public IPredictor CombineModels(IEnumerable models) var combiner = _outputCombiner.CreateComponent(Host); var p = models.First(); - var predictor = new EnsemblePredictor(Host, p.PredictionKind, + var predictor = new EnsembleModelParameters(Host, p.PredictionKind, models.Select(k => new FeatureSubsetModel((TScalarPredictor)k)).ToArray(), combiner); return predictor; diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CVSplit.cs b/src/Microsoft.ML.EntryPoints/CVSplit.cs similarity index 94% rename from src/Microsoft.ML.Legacy/Runtime/EntryPoints/CVSplit.cs rename to src/Microsoft.ML.EntryPoints/CVSplit.cs index 0b7f0fe94a..9b29584370 100644 --- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CVSplit.cs +++ b/src/Microsoft.ML.EntryPoints/CVSplit.cs @@ -2,15 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; using Microsoft.ML.Transforms; [assembly: LoadableClass(typeof(void), typeof(CVSplit), null, typeof(SignatureEntryPointModule), "CVSplit")] -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { /// /// The module that splits the input dataset into the specified number of cross-validation folds, and outputs the 'training' diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML.EntryPoints/CrossValidationMacro.cs similarity index 65% rename from src/Microsoft.ML.Legacy/Runtime/EntryPoints/CrossValidationMacro.cs rename to src/Microsoft.ML.EntryPoints/CrossValidationMacro.cs index 9522b8aaa5..ce64661b57 100644 --- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML.EntryPoints/CrossValidationMacro.cs @@ -5,11 +5,11 @@ using System; using System.Collections.Generic; using System.Linq; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; using Newtonsoft.Json.Linq; [assembly: LoadableClass(typeof(void), typeof(CrossValidationMacro), null, typeof(SignatureEntryPointModule), "CrossValidationMacro")] @@ -17,7 +17,7 @@ // The warning #612 is disabled because the following code uses a lot of things in Legacy.Models and Legacy.Transforms while Legacy is marked as obsolete. // Because that dependency will be removed form ML.NET, one needs to rewrite all places where legacy APIs are used. 
#pragma warning disable 612 -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { /// @@ -34,10 +34,7 @@ public sealed class SubGraphInput public sealed class SubGraphOutput { [Argument(ArgumentType.AtMostOnce, HelpText = "The predictor model", SortOrder = 1)] - public Var PredictorModel; - - [Argument(ArgumentType.AtMostOnce, HelpText = "The transform model", SortOrder = 2)] - public Var TransformModel; + public Var PredictorModel; } public sealed class Arguments @@ -51,7 +48,7 @@ public sealed class Arguments [TlcModule.OptionalInput] [Argument(ArgumentType.AtMostOnce, HelpText = "The transform model from the pipeline before this command. " + "It gets included in the Output.PredictorModel.", SortOrder = 2)] - public ITransformModel TransformModel; + public TransformModel TransformModel; // This is the subgraph that describes how to train a model for each fold. It should // accept one IDataView input and output one IPredictorModel output (see Inputs and Outputs). @@ -101,11 +98,7 @@ public sealed class Output { [TlcModule.Output(Desc = "The final model including the trained predictor model and the model from the transforms, " + "provided as the Input.TransformModel.", SortOrder = 1)] - public IPredictorModel[] PredictorModel; - - [TlcModule.Output(Desc = "The final model including the trained predictor model and the model from the transforms, " + - "provided as the Input.TransformModel.", SortOrder = 2)] - public ITransformModel[] TransformModel; + public PredictorModel[] PredictorModel; [TlcModule.Output(Desc = "Warning dataset", SortOrder = 3)] public IDataView Warnings; @@ -182,17 +175,26 @@ public static CommonOutputs.MacroOutput CrossValidate( transformModelVarName = node.GetInputVariable(nameof(input.TransformModel)); // Split the input data into folds. 
- var exp = new Experiment(env); - var cvSplit = new Legacy.Models.CrossValidatorDatasetSplitter(); - cvSplit.Data.VarName = node.GetInputVariable("Data").ToJson(); - cvSplit.NumFolds = input.NumFolds; - cvSplit.StratificationColumn = input.StratificationColumn; - var cvSplitOutput = exp.Add(cvSplit); - subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes())); - - var predModelVars = new Var[input.NumFolds]; - var transformModelVars = new Var[input.NumFolds]; - var inputTransformModelVars = new Var[input.NumFolds]; + var splitArgs = new CVSplit.Input(); + splitArgs.NumFolds = input.NumFolds; + splitArgs.StratificationColumn = input.StratificationColumn; + var inputBindingMap = new Dictionary>(); + var inputMap = new Dictionary(); + var inputData = node.GetInputVariable(nameof(splitArgs.Data)); + ParameterBinding paramBinding = new SimpleParameterBinding(nameof(splitArgs.Data)); + inputBindingMap.Add(nameof(splitArgs.Data), new List() { paramBinding }); + inputMap.Add(paramBinding, inputData); + var outputMap = new Dictionary(); + var splitOutputTrainData = new ArrayVar(); + var splitOutputTestData = new ArrayVar(); + outputMap.Add(nameof(CVSplit.Output.TrainData), splitOutputTrainData.VarName); + outputMap.Add(nameof(CVSplit.Output.TestData), splitOutputTestData.VarName); + var splitNode = EntryPointNode.Create(env, "Models.CrossValidatorDatasetSplitter", splitArgs, + node.Context, inputBindingMap, inputMap, outputMap); + subGraphNodes.Add(splitNode); + + var predModelVars = new Var[input.NumFolds]; + var inputTransformModelVars = new Var[input.NumFolds]; var warningsVars = new Var[input.NumFolds]; var overallMetricsVars = new Var[input.NumFolds]; var instanceMetricsVars = new Var[input.NumFolds]; @@ -221,88 +223,56 @@ public static CommonOutputs.MacroOutput CrossValidate( }; if (transformModelVarName != null) - args.TransformModel = new Var { VarName = transformModelVarName.VariableName }; + args.TransformModel = new Var { VarName = 
transformModelVarName.VariableName }; args.Inputs.Data = new Var { VarName = mapping[input.Inputs.Data.VarName] }; - - if (input.Outputs.PredictorModel != null && mapping.ContainsKey(input.Outputs.PredictorModel.VarName)) + args.Outputs.PredictorModel = new Var { - args.Outputs.PredictorModel = new Var - { - VarName = mapping[input.Outputs.PredictorModel.VarName] - }; - } - else - args.Outputs.PredictorModel = null; - - if (input.Outputs.TransformModel != null && mapping.ContainsKey(input.Outputs.TransformModel.VarName)) - { - args.Outputs.TransformModel = new Var - { - VarName = mapping[input.Outputs.TransformModel.VarName] - }; - } - else - args.Outputs.TransformModel = null; + VarName = mapping[input.Outputs.PredictorModel.VarName] + }; // Set train/test trainer kind to match. args.Kind = input.Kind; // Set the input bindings for the TrainTest entry point. - var inputBindingMap = new Dictionary>(); - var inputMap = new Dictionary(); + inputBindingMap = new Dictionary>(); + inputMap = new Dictionary(); var trainingData = new SimpleParameterBinding(nameof(args.TrainingData)); inputBindingMap.Add(nameof(args.TrainingData), new List { trainingData }); - inputMap.Add(trainingData, new ArrayIndexVariableBinding(cvSplitOutput.TrainData.VarName, k)); + inputMap.Add(trainingData, new ArrayIndexVariableBinding(splitOutputTrainData.VarName, k)); var testingData = new SimpleParameterBinding(nameof(args.TestingData)); inputBindingMap.Add(nameof(args.TestingData), new List { testingData }); - inputMap.Add(testingData, new ArrayIndexVariableBinding(cvSplitOutput.TestData.VarName, k)); - var outputMap = new Dictionary(); - var transformModelVar = new Var(); - var predModelVar = new Var(); - if (input.Outputs.PredictorModel == null) + inputMap.Add(testingData, new ArrayIndexVariableBinding(splitOutputTestData.VarName, k)); + outputMap = new Dictionary(); + var transformModelVar = new Var(); + var predModelVar = new Var(); + 
outputMap.Add(nameof(TrainTestMacro.Output.PredictorModel), predModelVar.VarName); + predModelVars[k] = predModelVar; + if (transformModelVarName != null && transformModelVarName.VariableName != null) { - outputMap.Add(nameof(TrainTestMacro.Output.TransformModel), transformModelVar.VarName); - transformModelVars[k] = transformModelVar; - Legacy.Transforms.ModelCombiner.Output modelCombineOutput = null; - if (transformModelVarName != null && transformModelVarName.VariableName != null) - { - var modelCombine = new Legacy.Transforms.ModelCombiner - { - Models = new ArrayVar( - new Var[] { - new Var { VarName = transformModelVarName.VariableName }, - transformModelVar } - ) - }; - - exp.Reset(); - modelCombineOutput = exp.Add(modelCombine); - subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes())); - transformModelVars[k] = modelCombineOutput.OutputModel; - } - } - else - { - outputMap.Add(nameof(TrainTestMacro.Output.PredictorModel), predModelVar.VarName); - predModelVars[k] = predModelVar; - Legacy.Transforms.TwoHeterogeneousModelCombiner.Output modelCombineOutput = null; - if (transformModelVarName != null && transformModelVarName.VariableName != null) - { - var modelCombine = new Legacy.Transforms.TwoHeterogeneousModelCombiner - { - TransformModel = { VarName = transformModelVarName.VariableName }, - PredictorModel = predModelVar - }; - - exp.Reset(); - modelCombineOutput = exp.Add(modelCombine); - subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes())); - predModelVars[k] = modelCombineOutput.PredictorModel; - } + var combineModelsArgs = new ModelOperations.SimplePredictorModelInput(); + inputBindingMap = new Dictionary>(); + inputMap = new Dictionary(); + + var inputTransformModel = new SimpleVariableBinding(transformModelVarName.VariableName); + var inputPredictorModel = new SimpleVariableBinding(predModelVar.VarName); + paramBinding = new 
SimpleParameterBinding(nameof(combineModelsArgs.TransformModel)); + inputBindingMap.Add(nameof(combineModelsArgs.TransformModel), new List() { paramBinding }); + inputMap.Add(paramBinding, inputTransformModel); + paramBinding = new SimpleParameterBinding(nameof(combineModelsArgs.PredictorModel)); + inputBindingMap.Add(nameof(combineModelsArgs.PredictorModel), new List() { paramBinding }); + inputMap.Add(paramBinding, inputPredictorModel); + outputMap = new Dictionary(); + + var combineNodeOutputPredictorModel = new Var(); + predModelVars[k] = combineNodeOutputPredictorModel; + outputMap.Add(nameof(ModelOperations.PredictorModelOutput.PredictorModel), combineNodeOutputPredictorModel.VarName); + EntryPointNode combineNode = EntryPointNode.Create(env, "Transforms.TwoHeterogeneousModelCombiner", combineModelsArgs, + node.Context, inputBindingMap, inputMap, outputMap); + subGraphNodes.Add(combineNode); } var warningVar = new Var(); @@ -321,66 +291,22 @@ public static CommonOutputs.MacroOutput CrossValidate( subGraphNodes.Add(EntryPointNode.Create(env, trainTestEvaluatorMacroEntryPoint, args, node.Context, inputBindingMap, inputMap, outputMap)); } - exp.Reset(); - - // Convert predictors from all folds into an array of predictors. 
- - if (input.Outputs.PredictorModel == null) - { - var outModels = new Legacy.Data.TransformModelArrayConverter - { - TransformModel = new ArrayVar(transformModelVars) - }; - var outModelsOutput = new Legacy.Data.TransformModelArrayConverter.Output(); - outModelsOutput.OutputModel.VarName = node.GetOutputVariableName(nameof(Output.TransformModel)); - exp.Add(outModels, outModelsOutput); - } - else - { - var outModels = new Legacy.Data.PredictorModelArrayConverter - { - Model = new ArrayVar(predModelVars) - }; - var outModelsOutput = new Legacy.Data.PredictorModelArrayConverter.Output(); - outModelsOutput.OutputModel.VarName = node.GetOutputVariableName(nameof(Output.PredictorModel)); - exp.Add(outModels, outModelsOutput); - } - - // Convert warnings data views from all folds into an array of data views. - var warnings = new Legacy.Data.IDataViewArrayConverter - { - Data = new ArrayVar(warningsVars) - }; - var warningsOutput = new Legacy.Data.IDataViewArrayConverter.Output(); - exp.Add(warnings, warningsOutput); - - // Convert overall metrics data views from all folds into an array of data views. - var overallMetrics = new Legacy.Data.IDataViewArrayConverter - { - Data = new ArrayVar(overallMetricsVars) - }; - var overallMetricsOutput = new Legacy.Data.IDataViewArrayConverter.Output(); - exp.Add(overallMetrics, overallMetricsOutput); - - // Convert per instance data views from all folds into an array of data views. - var instanceMetrics = new Legacy.Data.IDataViewArrayConverter - { - Data = new ArrayVar(instanceMetricsVars) - }; - var instanceMetricsOutput = new Legacy.Data.IDataViewArrayConverter.Output(); - exp.Add(instanceMetrics, instanceMetricsOutput); - - Legacy.Data.IDataViewArrayConverter.Output confusionMatricesOutput = null; + // Convert the predictor models to an array of predictor models. 
+ MacroUtils.ConvertIPredictorModelsToArray(env, node.Context, subGraphNodes, predModelVars, node.GetOutputVariableName(nameof(Output.PredictorModel))); + + // Convert the warnings, overall, per instance and confusion matrix data views into an array. + var warningsArrayVar = new ArrayVar(); + var overallArrayVar = new ArrayVar(); + var instanceArrayVar = new ArrayVar(); + ArrayVar confusionMatrixArrayVar = null; + MacroUtils.ConvertIdataViewsToArray(env, node.Context, subGraphNodes, warningsVars, warningsArrayVar.VarName); + MacroUtils.ConvertIdataViewsToArray(env, node.Context, subGraphNodes, overallMetricsVars, overallArrayVar.VarName); + MacroUtils.ConvertIdataViewsToArray(env, node.Context, subGraphNodes, instanceMetricsVars, instanceArrayVar.VarName); if (input.Kind == MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer || input.Kind == MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer) { - // Convert confusion matrix data views from all folds into an array of data views. 
- var confusionMatrices = new Legacy.Data.IDataViewArrayConverter - { - Data = new ArrayVar(confusionMatrixVars) - }; - confusionMatricesOutput = new Legacy.Data.IDataViewArrayConverter.Output(); - exp.Add(confusionMatrices, confusionMatricesOutput); + confusionMatrixArrayVar = new ArrayVar(); + MacroUtils.ConvertIdataViewsToArray(env, node.Context, subGraphNodes, confusionMatrixVars, confusionMatrixArrayVar.VarName); } var combineArgs = new CombineMetricsInput(); @@ -396,18 +322,18 @@ public static CommonOutputs.MacroOutput CrossValidate( var warningsArray = new SimpleParameterBinding(nameof(combineArgs.Warnings)); combineInputBindingMap.Add(nameof(combineArgs.Warnings), new List { warningsArray }); - combineInputMap.Add(warningsArray, new SimpleVariableBinding(warningsOutput.OutputData.VarName)); + combineInputMap.Add(warningsArray, new SimpleVariableBinding(warningsArrayVar.VarName)); var overallArray = new SimpleParameterBinding(nameof(combineArgs.OverallMetrics)); combineInputBindingMap.Add(nameof(combineArgs.OverallMetrics), new List { overallArray }); - combineInputMap.Add(overallArray, new SimpleVariableBinding(overallMetricsOutput.OutputData.VarName)); + combineInputMap.Add(overallArray, new SimpleVariableBinding(overallArrayVar.VarName)); var combinePerInstArray = new SimpleParameterBinding(nameof(combineArgs.PerInstanceMetrics)); combineInputBindingMap.Add(nameof(combineArgs.PerInstanceMetrics), new List { combinePerInstArray }); - combineInputMap.Add(combinePerInstArray, new SimpleVariableBinding(instanceMetricsOutput.OutputData.VarName)); - if (confusionMatricesOutput != null) + combineInputMap.Add(combinePerInstArray, new SimpleVariableBinding(instanceArrayVar.VarName)); + if (confusionMatrixArrayVar != null) { var combineConfArray = new SimpleParameterBinding(nameof(combineArgs.ConfusionMatrix)); combineInputBindingMap.Add(nameof(combineArgs.ConfusionMatrix), new List { combineConfArray }); - combineInputMap.Add(combineConfArray, new 
SimpleVariableBinding(confusionMatricesOutput.OutputData.VarName)); + combineInputMap.Add(combineConfArray, new SimpleVariableBinding(confusionMatrixArrayVar.VarName)); } var combineOutputMap = new Dictionary(); @@ -420,14 +346,15 @@ public static CommonOutputs.MacroOutput CrossValidate( var combineInstanceMetric = new Var(); combineInstanceMetric.VarName = node.GetOutputVariableName(nameof(Output.PerInstanceMetrics)); combineOutputMap.Add(nameof(Output.PerInstanceMetrics), combineInstanceMetric.VarName); - if (confusionMatricesOutput != null) + if (confusionMatrixArrayVar != null) { var combineConfusionMatrix = new Var(); combineConfusionMatrix.VarName = node.GetOutputVariableName(nameof(Output.ConfusionMatrix)); combineOutputMap.Add(nameof(TrainTestMacro.Output.ConfusionMatrix), combineConfusionMatrix.VarName); } - subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes())); - subGraphNodes.Add(EntryPointNode.Create(env, "Models.CrossValidationResultsCombiner", combineArgs, node.Context, combineInputBindingMap, combineInputMap, combineOutputMap)); + var combineMetricsNode = EntryPointNode.Create(env, "Models.CrossValidationResultsCombiner", + combineArgs, node.Context, combineInputBindingMap, combineInputMap, combineOutputMap); + subGraphNodes.Add(combineMetricsNode); return new CommonOutputs.MacroOutput() { Nodes = subGraphNodes }; } @@ -468,10 +395,10 @@ public static CombinedOutput CombineMetrics(IHostEnvironment env, CombineMetrics { var idv = input.ConfusionMatrix[i]; // Find the old Count column and drop it. 
- for (int col = 0; col < idv.Schema.ColumnCount; col++) + for (int col = 0; col < idv.Schema.Count; col++) { - if (idv.Schema.IsHidden(col) && - idv.Schema.GetColumnName(col).Equals(MetricKinds.ColumnNames.Count)) + if (idv.Schema[col].IsHidden && + idv.Schema[col].Name.Equals(MetricKinds.ColumnNames.Count)) { input.ConfusionMatrix[i] = new ChooseColumnsByIndexTransform(env, new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = new[] { col } }, idv); @@ -498,22 +425,22 @@ private static IMamlEvaluator GetEvaluator(IHostEnvironment env, MacroUtils.Trai { switch (kind) { - case MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer: - return new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer: - return new MultiClassMamlEvaluator(env, new MultiClassMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureRegressorTrainer: - return new RegressionMamlEvaluator(env, new RegressionMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureRankerTrainer: - return new RankerMamlEvaluator(env, new RankerMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureAnomalyDetectorTrainer: - return new AnomalyDetectionMamlEvaluator(env, new AnomalyDetectionMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureClusteringTrainer: - return new ClusteringMamlEvaluator(env, new ClusteringMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureMultiOutputRegressorTrainer: - return new MultiOutputRegressionMamlEvaluator(env, new MultiOutputRegressionMamlEvaluator.Arguments()); - default: - throw env.ExceptParam(nameof(kind), $"Trainer kind {kind} does not have an evaluator"); + case MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer: + return new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer: + return 
new MultiClassMamlEvaluator(env, new MultiClassMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureRegressorTrainer: + return new RegressionMamlEvaluator(env, new RegressionMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureRankerTrainer: + return new RankerMamlEvaluator(env, new RankerMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureAnomalyDetectorTrainer: + return new AnomalyDetectionMamlEvaluator(env, new AnomalyDetectionMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureClusteringTrainer: + return new ClusteringMamlEvaluator(env, new ClusteringMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureMultiOutputRegressorTrainer: + return new MultiOutputRegressionMamlEvaluator(env, new MultiOutputRegressionMamlEvaluator.Arguments()); + default: + throw env.ExceptParam(nameof(kind), $"Trainer kind {kind} does not have an evaluator"); } } } diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/DataViewReference.cs b/src/Microsoft.ML.EntryPoints/DataViewReference.cs similarity index 87% rename from src/Microsoft.ML.Legacy/Runtime/EntryPoints/DataViewReference.cs rename to src/Microsoft.ML.EntryPoints/DataViewReference.cs index 19d51028ea..f799ed6cdf 100644 --- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/DataViewReference.cs +++ b/src/Microsoft.ML.EntryPoints/DataViewReference.cs @@ -2,14 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; [assembly: LoadableClass(typeof(void), typeof(DataViewReference), null, typeof(SignatureEntryPointModule), "DataViewReference")] -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { public class DataViewReference { diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/FeatureCombiner.cs b/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs similarity index 92% rename from src/Microsoft.ML.Legacy/Runtime/EntryPoints/FeatureCombiner.cs rename to src/Microsoft.ML.EntryPoints/FeatureCombiner.cs index 6735cb43e6..cb04432ec1 100644 --- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/FeatureCombiner.cs +++ b/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs @@ -2,21 +2,20 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Transforms.Categorical; -using Microsoft.ML.Transforms.Conversions; using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Transforms.Conversions; [assembly: LoadableClass(typeof(void), typeof(FeatureCombiner), null, typeof(SignatureEntryPointModule), "FeatureCombiner")] -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { public static class FeatureCombiner { @@ -80,7 +79,7 @@ public static CommonOutputs.TransformOutput PrepareFeatures(IHostEnvironment env new[] { new ColumnConcatenatingTransformer.TaggedColumn() { Name = nameFeat, Source = concatNames.ToArray() } } }, viewTrain); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, viewTrain, input.Data), OutputData = viewTrain }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, viewTrain, input.Data), OutputData = viewTrain }; } } @@ -120,11 +119,11 @@ private static string GetTerms(IDataView data, string colName) var schema = data.Schema; if (!schema.TryGetColumnIndex(colName, out col)) return null; - var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, col); + var type = schema[col].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; if (type == null || !type.IsKnownSizeVector || !type.ItemType.IsText) return null; var metadata = default(VBuffer>); - schema.GetMetadata(MetadataUtils.Kinds.KeyValues, col, ref metadata); + schema[col].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref metadata); if (!metadata.IsDense) return null; var sb = new StringBuilder(); @@ -149,7 +148,7 @@ 
private static IDataView ApplyConvert(List return viewTrain; } - private static List ConvertFeatures(ColumnInfo[] feats, HashSet featNames, List> concatNames, IChannel ch, + private static List ConvertFeatures(IEnumerable feats, HashSet featNames, List> concatNames, IChannel ch, out List cvt, out int errCount) { Contracts.AssertValue(feats); @@ -235,11 +234,11 @@ public static CommonOutputs.TransformOutput PrepareClassificationLabel(IHostEnvi int labelCol; if (!input.Data.Schema.TryGetColumnIndex(input.LabelColumn, out labelCol)) throw host.Except($"Column '{input.LabelColumn}' not found."); - var labelType = input.Data.Schema.GetColumnType(labelCol); + var labelType = input.Data.Schema[labelCol].Type; if (labelType.IsKey || labelType.IsBool) { var nop = NopTransform.CreateIfNeeded(env, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, nop, input.Data), OutputData = nop }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, nop, input.Data), OutputData = nop }; } var args = new ValueToKeyMappingTransformer.Arguments() @@ -256,7 +255,7 @@ public static CommonOutputs.TransformOutput PrepareClassificationLabel(IHostEnvi } }; var xf = ValueToKeyMappingTransformer.Create(host, args, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } [TlcModule.EntryPoint(Name = "Transforms.PredictedLabelColumnOriginalValueConverter", Desc = "Transforms a predicted label column to its original values, unless it is of type bool.", UserName = "Convert Predicted Label")] @@ -270,15 +269,15 @@ public static CommonOutputs.TransformOutput ConvertPredictedLabel(IHostEnvironme int predictedLabelCol; if (!input.Data.Schema.TryGetColumnIndex(input.PredictedLabelColumn, out predictedLabelCol)) throw host.Except($"Column 
'{input.PredictedLabelColumn}' not found."); - var predictedLabelType = input.Data.Schema.GetColumnType(predictedLabelCol); + var predictedLabelType = input.Data.Schema[predictedLabelCol].Type; if (predictedLabelType.IsNumber || predictedLabelType.IsBool) { var nop = NopTransform.CreateIfNeeded(env, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, nop, input.Data), OutputData = nop }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, nop, input.Data), OutputData = nop }; } var xf = new KeyToValueMappingTransformer(host, input.PredictedLabelColumn).Transform(input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } [TlcModule.EntryPoint(Name = "Transforms.LabelToFloatConverter", Desc = "Transforms the label to float to make it suitable for regression.", UserName = "Prepare Regression Label")] @@ -292,11 +291,11 @@ public static CommonOutputs.TransformOutput PrepareRegressionLabel(IHostEnvironm int labelCol; if (!input.Data.Schema.TryGetColumnIndex(input.LabelColumn, out labelCol)) throw host.Except($"Column '{input.LabelColumn}' not found."); - var labelType = input.Data.Schema.GetColumnType(labelCol); + var labelType = input.Data.Schema[labelCol].Type; if (labelType == NumberType.R4 || !labelType.IsNumber) { var nop = NopTransform.CreateIfNeeded(env, input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, nop, input.Data), OutputData = nop }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, nop, input.Data), OutputData = nop }; } var args = new TypeConvertingTransformer.Arguments() @@ -312,7 +311,7 @@ public static CommonOutputs.TransformOutput PrepareRegressionLabel(IHostEnvironm } }; var xf = new TypeConvertingTransformer(host, new 
TypeConvertingTransformer.ColumnInfo(input.LabelColumn, input.LabelColumn, DataKind.R4)).Transform(input.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } } } diff --git a/src/Microsoft.ML.EntryPoints/ImportTextData.cs b/src/Microsoft.ML.EntryPoints/ImportTextData.cs new file mode 100644 index 0000000000..69e7c6c073 --- /dev/null +++ b/src/Microsoft.ML.EntryPoints/ImportTextData.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; + +[assembly: LoadableClass(typeof(void), typeof(ImportTextData), null, typeof(SignatureEntryPointModule), "ImportTextData")] + +// The warning #612 is disabled because the following code uses legacy TextLoader. +// Because that dependency will be removed from ML.NET, one needs to rewrite all places where legacy APIs are used. +#pragma warning disable 612 +namespace Microsoft.ML.EntryPoints +{ + /// + /// A component for importing text files as . 
+ /// + public static class ImportTextData + { + public sealed class Input + { + [Argument(ArgumentType.Required, ShortName = "data", HelpText = "Location of the input file", SortOrder = 1)] + public IFileHandle InputFile; + + [Argument(ArgumentType.AtMostOnce, ShortName = "schema", HelpText = "Custom schema to use for parsing", SortOrder = 2)] + public string CustomSchema = null; + } + + public sealed class Output + { + [TlcModule.Output(Desc = "The resulting data view", SortOrder = 1)] + public IDataView Data; + } + + [TlcModule.EntryPoint(Name = "Data.CustomTextLoader", Desc = "Import a dataset from a text file")] + public static Output ImportText(IHostEnvironment env, Input input) + { + Contracts.CheckValue(env, nameof(env)); + var host = env.Register("ImportTextData"); + env.CheckValue(input, nameof(input)); + EntryPointUtils.CheckInputArgs(host, input); + var loader = host.CreateLoader(string.Format("Text{{{0}}}", input.CustomSchema), new FileHandleSource(input.InputFile)); + return new Output { Data = loader }; + } + } +} +#pragma warning restore 612 diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/ExecuteGraphCommand.cs b/src/Microsoft.ML.EntryPoints/JsonUtils/ExecuteGraphCommand.cs similarity index 80% rename from src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/ExecuteGraphCommand.cs rename to src/Microsoft.ML.EntryPoints/JsonUtils/ExecuteGraphCommand.cs index d6e95961ec..a9e6c6f274 100644 --- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/ExecuteGraphCommand.cs +++ b/src/Microsoft.ML.EntryPoints/JsonUtils/ExecuteGraphCommand.cs @@ -2,24 +2,22 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; -using System.Collections.Generic; using System.IO; using System.Linq; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.EntryPoints.JsonUtils; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML; +using Microsoft.ML.Command; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Data.IO; +using Microsoft.ML.EntryPoints.JsonUtils; +using Microsoft.ML.Internal.Utilities; using Newtonsoft.Json; using Newtonsoft.Json.Linq; [assembly: LoadableClass(typeof(ExecuteGraphCommand), typeof(ExecuteGraphCommand.Arguments), typeof(SignatureCommand), "", "ExecGraph")] -namespace Microsoft.ML.Runtime.EntryPoints.JsonUtils +namespace Microsoft.ML.EntryPoints.JsonUtils { internal sealed class ExecuteGraphCommand : ICommand { @@ -115,15 +113,15 @@ public void SetInputFromPath(GraphRunner runner, string varName, string path, Tl runner.SetInput(varName, loader); break; case TlcModule.DataKind.PredictorModel: - PredictorModel pm; + PredictorModelImpl pm; using (var fs = File.OpenRead(path)) - pm = new PredictorModel(_host, fs); + pm = new PredictorModelImpl(_host, fs); runner.SetInput(varName, pm); break; case TlcModule.DataKind.TransformModel: - TransformModel tm; + TransformModelImpl tm; using (var fs = File.OpenRead(path)) - tm = new TransformModel(_host, fs); + tm = new TransformModelImpl(_host, fs); runner.SetInput(varName, tm); break; default: @@ -145,14 +143,22 @@ public void GetOutputToPath(GraphRunner runner, string varName, string path, Tlc throw _host.ExceptNotSupp("File handle outputs not yet supported."); case TlcModule.DataKind.DataView: var idv = runner.GetOutput(varName); - SaveDataView(idv, path, extension); + if (idv != null) + SaveDataView(idv, path, extension); + else + using (var ch = _host.Start("Get outputs from executed graph")) + { + string msg = 
string.Format("Ignoring empty graph output (output name: {0}, type: {1}, expected output's file: {2})", + varName, nameof(idv), path + extension); + ch.Warning(msg); + } break; case TlcModule.DataKind.PredictorModel: - var pm = runner.GetOutput(varName); + var pm = runner.GetOutput(varName); SavePredictorModel(pm, path); break; case TlcModule.DataKind.TransformModel: - var tm = runner.GetOutput(varName); + var tm = runner.GetOutput(varName); using (var handle = _host.CreateOutputFile(path)) using (var fs = handle.CreateWriteStream()) tm.Save(_host, fs); @@ -160,17 +166,17 @@ public void GetOutputToPath(GraphRunner runner, string varName, string path, Tlc case TlcModule.DataKind.Array: string partialPath = path.Substring(0, path.Length - extension.Length); - var ipmArray = runner.GetOutputOrDefault(varName); + var ipmArray = runner.GetOutputOrDefault(varName); if (ipmArray != null && !ipmArray.GetType().IsValueType) { - SaveArrayToFile(ipmArray.ToList(), TlcModule.DataKind.PredictorModel, partialPath, extension); + SaveArrayToFile(ipmArray, partialPath, extension); break; } var idvArray = runner.GetOutputOrDefault(varName); if (idvArray != null && !idvArray.GetType().IsValueType) { - SaveArrayToFile(idvArray.ToList(), TlcModule.DataKind.DataView, partialPath, extension); + SaveArrayToFile(idvArray, partialPath, extension); break; } goto default; @@ -180,28 +186,28 @@ public void GetOutputToPath(GraphRunner runner, string varName, string path, Tlc } - private void SaveArrayToFile(List array, TlcModule.DataKind kind, string partialPath, string extension) - where T : class + private void SaveArrayToFile(PredictorModel[] array, string partialPath, string extension) { - for (int i = 0; i < array.Count; i++) + for (int i = 0; i < array.Length; i++) { string path = $"{partialPath}_{i}{extension}"; - switch (kind) - { - case TlcModule.DataKind.DataView: - SaveDataView((IDataView)array[i], path, extension); - break; - case TlcModule.DataKind.PredictorModel: - 
SavePredictorModel((IPredictorModel)array[i], path); - break; - } + SavePredictorModel(array[i], path); + } + } + + private void SaveArrayToFile(IDataView[] array, string partialPath, string extension) + { + for (int i = 0; i < array.Length; i++) + { + string path = $"{partialPath}_{i}{extension}"; + SaveDataView(array[i], path, extension); } } /// /// Saves the PredictorModel to the given path /// - private void SavePredictorModel(IPredictorModel pm, string path) + private void SavePredictorModel(PredictorModel pm, string path) { Contracts.CheckValue(pm, nameof(pm)); @@ -234,8 +240,8 @@ private void SaveDataView(IDataView idv, string path, string extension) using (var handle = _host.CreateOutputFile(path)) using (var fs = handle.CreateWriteStream()) { - saver.SaveData(fs, idv, Utils.GetIdentityPermutation(idv.Schema.ColumnCount) - .Where(x => saver.IsColumnSavable(idv.Schema.GetColumnType(x))).ToArray()); + saver.SaveData(fs, idv, Utils.GetIdentityPermutation(idv.Schema.Count) + .Where(x => saver.IsColumnSavable(idv.Schema[x].Type)).ToArray()); } } } diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/GraphRunner.cs b/src/Microsoft.ML.EntryPoints/JsonUtils/GraphRunner.cs similarity index 99% rename from src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/GraphRunner.cs rename to src/Microsoft.ML.EntryPoints/JsonUtils/GraphRunner.cs index 9cdee0946b..8cff9c0021 100644 --- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/GraphRunner.cs +++ b/src/Microsoft.ML.EntryPoints/JsonUtils/GraphRunner.cs @@ -5,7 +5,7 @@ using System.Linq; using Newtonsoft.Json.Linq; -namespace Microsoft.ML.Runtime.EntryPoints.JsonUtils +namespace Microsoft.ML.EntryPoints.JsonUtils { /// /// This class runs a graph of entry points with the specified inputs, and produces the specified outputs. 
diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/JsonManifestUtils.cs b/src/Microsoft.ML.EntryPoints/JsonUtils/JsonManifestUtils.cs similarity index 96% rename from src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/JsonManifestUtils.cs rename to src/Microsoft.ML.EntryPoints/JsonUtils/JsonManifestUtils.cs index 2a5c20e3c9..4e6341dfd1 100644 --- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/JsonManifestUtils.cs +++ b/src/Microsoft.ML.EntryPoints/JsonUtils/JsonManifestUtils.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. @@ -6,12 +6,11 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Internal.Tools; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Internal.Utilities; using Newtonsoft.Json.Linq; -namespace Microsoft.ML.Runtime.EntryPoints.JsonUtils +namespace Microsoft.ML.EntryPoints.JsonUtils { /// /// Utilities to generate JSON manifests for entry points and other components. @@ -68,7 +67,7 @@ public static JObject BuildAllManifests(IExceptionContext ectx, ComponentCatalog { var jField = new JObject(); jField[FieldNames.Name] = fieldInfo.Name; - var type = CSharpGeneratorUtils.ExtractOptionalOrNullableType(fieldInfo.PropertyType); + var type = ExtractOptionalOrNullableType(fieldInfo.PropertyType); // Dive inside Var. 
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Var<>)) type = type.GetGenericArguments()[0]; @@ -87,6 +86,14 @@ public static JObject BuildAllManifests(IExceptionContext ectx, ComponentCatalog return jResult; } + private static Type ExtractOptionalOrNullableType(Type type) + { + if (type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(Optional<>) || type.GetGenericTypeDefinition() == typeof(Nullable<>))) + type = type.GetGenericArguments()[0]; + + return type; + } + private static JObject BuildComponentManifest(IExceptionContext ectx, ComponentCatalog.ComponentInfo componentInfo, ComponentCatalog catalog) { Contracts.AssertValueOrNull(ectx); @@ -103,7 +110,7 @@ private static JObject BuildComponentManifest(IExceptionContext ectx, ComponentC return result; } - public static JObject BuildEntryPointManifest(IExceptionContext ectx, ComponentCatalog.EntryPointInfo entryPointInfo, ComponentCatalog catalog) + private static JObject BuildEntryPointManifest(IExceptionContext ectx, ComponentCatalog.EntryPointInfo entryPointInfo, ComponentCatalog catalog) { Contracts.CheckValueOrNull(ectx); ectx.CheckValue(entryPointInfo, nameof(entryPointInfo)); @@ -298,12 +305,12 @@ private static JToken BuildTypeToken(IExceptionContext ectx, FieldInfo fieldInfo type == typeof(CommonOutputs.IEvaluatorOutput)) { var jo = new JObject(); - var typeString = $"{type}".Replace("Microsoft.ML.Runtime.EntryPoints.", ""); + var typeString = $"{type}".Replace("Microsoft.ML.EntryPoints.", ""); jo[FieldNames.Kind] = "EntryPoint"; jo[FieldNames.ItemType] = typeString; return jo; } - type = CSharpGeneratorUtils.ExtractOptionalOrNullableType(type); + type = ExtractOptionalOrNullableType(type); // Dive inside Var. 
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Var<>)) @@ -366,12 +373,6 @@ private static JToken BuildTypeToken(IExceptionContext ectx, FieldInfo fieldInfo jo[FieldNames.Kind] = typeEnum.ToString(); jo[FieldNames.ComponentKind] = kind; return jo; - case TlcModule.DataKind.State: - jo = new JObject(); - var typeString = $"{type}".Replace("Microsoft.ML.Runtime.Interfaces.", ""); - jo[FieldNames.Kind] = "C# Object"; - jo[FieldNames.ItemType] = typeString; - return jo; default: ectx.Assert(false); throw ectx.ExceptNotSupp(); diff --git a/src/Microsoft.ML.EntryPoints/MacroUtils.cs b/src/Microsoft.ML.EntryPoints/MacroUtils.cs new file mode 100644 index 0000000000..466bd96d41 --- /dev/null +++ b/src/Microsoft.ML.EntryPoints/MacroUtils.cs @@ -0,0 +1,168 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; + +[assembly: EntryPointModule(typeof(MacroUtils))] + +// The warning #612 is disabled because the following code uses a lot of things in Legacy.Models while Legacy.Model is marked as obsolete. +// Because that dependency will be removed from ML.NET, one needs to rewrite all places where legacy APIs are used. +#pragma warning disable 612 +namespace Microsoft.ML.EntryPoints +{ + public static class MacroUtils + { + /// + /// Lists the types of trainer signatures. Used by entry points and autoML system + /// to know what types of evaluators to use for the train test / pipeline sweeper. 
+ /// + public enum TrainerKinds + { + SignatureBinaryClassifierTrainer, + SignatureMultiClassClassifierTrainer, + SignatureRankerTrainer, + SignatureRegressorTrainer, + SignatureMultiOutputRegressorTrainer, + SignatureAnomalyDetectorTrainer, + SignatureClusteringTrainer, + } + + public sealed class EvaluatorSettings + { + public string LabelColumn { get; set; } + public string NameColumn { get; set; } + public string WeightColumn { get; set; } + public string GroupColumn { get; set; } + public string FeatureColumn { get; set; } + + public EvaluatorSettings() + { + LabelColumn = DefaultColumnNames.Label; + } + } + + public static EvaluateInputBase GetEvaluatorArgs(TrainerKinds kind, out string entryPointName, EvaluatorSettings settings = null) + { + switch (kind) + { + case TrainerKinds.SignatureBinaryClassifierTrainer: + entryPointName = "Models.BinaryClassificationEvaluator"; + return new BinaryClassifierMamlEvaluator.Arguments() { LabelColumn = settings.LabelColumn, WeightColumn = settings.WeightColumn, NameColumn = settings.NameColumn }; + case TrainerKinds.SignatureMultiClassClassifierTrainer: + entryPointName = "Models.ClassificationEvaluator"; + return new MultiClassMamlEvaluator.Arguments() { LabelColumn = settings.LabelColumn, WeightColumn = settings.WeightColumn, NameColumn = settings.NameColumn }; + case TrainerKinds.SignatureRankerTrainer: + entryPointName = "Models.RankerEvaluator"; + return new RankerMamlEvaluator.Arguments() { LabelColumn = settings.LabelColumn, WeightColumn = settings.WeightColumn, NameColumn = settings.NameColumn, GroupIdColumn = settings.GroupColumn }; + case TrainerKinds.SignatureRegressorTrainer: + entryPointName = "Models.RegressionEvaluator"; + return new RegressionMamlEvaluator.Arguments() { LabelColumn = settings.LabelColumn, WeightColumn = settings.WeightColumn, NameColumn = settings.NameColumn }; + case TrainerKinds.SignatureMultiOutputRegressorTrainer: + entryPointName = "Models.MultiOutputRegressionEvaluator"; + return 
new MultiOutputRegressionMamlEvaluator.Arguments() { LabelColumn = settings.LabelColumn, WeightColumn = settings.WeightColumn, NameColumn = settings.NameColumn }; + case TrainerKinds.SignatureAnomalyDetectorTrainer: + entryPointName = "Models.AnomalyDetectionEvaluator"; + return new AnomalyDetectionMamlEvaluator.Arguments() { LabelColumn = settings.LabelColumn, WeightColumn = settings.WeightColumn, NameColumn = settings.NameColumn }; + case TrainerKinds.SignatureClusteringTrainer: + entryPointName = "Models.ClusterEvaluator"; + return new ClusteringMamlEvaluator.Arguments() { LabelColumn = settings.LabelColumn, WeightColumn = settings.WeightColumn, NameColumn = settings.NameColumn }; + default: + throw Contracts.Except("Trainer kind not supported"); + } + } + + public sealed class ArrayIPredictorModelInput + { + [Argument(ArgumentType.Required, HelpText = "The models", SortOrder = 1)] + public PredictorModel[] Model; + } + + public sealed class ArrayIPredictorModelOutput + { + [TlcModule.Output(Desc = "The model array", SortOrder = 1)] + public PredictorModel[] OutputModel; + } + + [TlcModule.EntryPoint(Desc = "Create an array variable of " + nameof(PredictorModel), Name = "Data.PredictorModelArrayConverter")] + public static ArrayIPredictorModelOutput MakeArray(IHostEnvironment env, ArrayIPredictorModelInput input) + { + var result = new ArrayIPredictorModelOutput + { + OutputModel = input.Model + }; + return result; + } + + public sealed class ArrayIDataViewInput + { + [Argument(ArgumentType.Required, HelpText = "The data sets", SortOrder = 1)] + public IDataView[] Data; + } + + public sealed class ArrayIDataViewOutput + { + [TlcModule.Output(Desc = "The data set array", SortOrder = 1)] + public IDataView[] OutputData; + } + + [TlcModule.EntryPoint(Desc = "Create an array variable of IDataView", Name = "Data.IDataViewArrayConverter")] + public static ArrayIDataViewOutput MakeArray(IHostEnvironment env, ArrayIDataViewInput input) + { + var result = new 
ArrayIDataViewOutput + { + OutputData = input.Data + }; + return result; + } + + internal static void ConvertIPredictorModelsToArray(IHostEnvironment env, RunContext context, List subGraphNodes, + Var[] predModelVars, string outputVarName) + { + var predictorArrayConverterArgs = new ArrayIPredictorModelInput(); + var inputBindingMap = new Dictionary>(); + var inputMap = new Dictionary(); + + var argName = nameof(predictorArrayConverterArgs.Model); + inputBindingMap.Add(argName, new List()); + for (int i = 0; i < predModelVars.Length; i++) + { + var paramBinding = new ArrayIndexParameterBinding(argName, i); + inputBindingMap[argName].Add(paramBinding); + inputMap[paramBinding] = new SimpleVariableBinding(predModelVars[i].VarName); + } + var outputMap = new Dictionary(); + var output = new ArrayVar(); + outputMap.Add(nameof(MacroUtils.ArrayIPredictorModelOutput.OutputModel), outputVarName); + var arrayConvertNode = EntryPointNode.Create(env, "Data.PredictorModelArrayConverter", predictorArrayConverterArgs, + context, inputBindingMap, inputMap, outputMap); + subGraphNodes.Add(arrayConvertNode); + } + + internal static void ConvertIdataViewsToArray(IHostEnvironment env, RunContext context, List subGraphNodes, + Var[] vars, string outputVarName) + { + var dataviewArrayConverterArgs = new ArrayIDataViewInput(); + var inputBindingMap = new Dictionary>(); + var inputMap = new Dictionary(); + + var argName = nameof(dataviewArrayConverterArgs.Data); + inputBindingMap.Add(argName, new List()); + for (int i = 0; i < vars.Length; i++) + { + var paramBinding = new ArrayIndexParameterBinding(argName, i); + inputBindingMap[argName].Add(paramBinding); + inputMap[paramBinding] = new SimpleVariableBinding(vars[i].VarName); + } + var outputMap = new Dictionary(); + outputMap.Add(nameof(ArrayIDataViewOutput.OutputData), outputVarName); + var arrayConvertNode = EntryPointNode.Create(env, "Data.IDataViewArrayConverter", dataviewArrayConverterArgs, + context, inputBindingMap, inputMap, 
outputMap); + subGraphNodes.Add(arrayConvertNode); + } + } +} +#pragma warning restore 612 diff --git a/src/Microsoft.ML.EntryPoints/Microsoft.ML.EntryPoints.csproj b/src/Microsoft.ML.EntryPoints/Microsoft.ML.EntryPoints.csproj new file mode 100644 index 0000000000..52f8cf594e --- /dev/null +++ b/src/Microsoft.ML.EntryPoints/Microsoft.ML.EntryPoints.csproj @@ -0,0 +1,15 @@ + + + + netstandard2.0 + Microsoft.ML.EntryPoints + + + + + + + + + + diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/ModelOperations.cs b/src/Microsoft.ML.EntryPoints/ModelOperations.cs similarity index 87% rename from src/Microsoft.ML.Legacy/Runtime/EntryPoints/ModelOperations.cs rename to src/Microsoft.ML.EntryPoints/ModelOperations.cs index f92cca99fc..3fca2855e0 100644 --- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/ModelOperations.cs +++ b/src/Microsoft.ML.EntryPoints/ModelOperations.cs @@ -2,60 +2,59 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Learners; -using Microsoft.ML.Trainers; using System.Linq; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Trainers; [assembly: LoadableClass(typeof(void), typeof(ModelOperations), null, typeof(SignatureEntryPointModule), "ModelOperations")] -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { public static class ModelOperations { public sealed class CombineTransformModelsInput { [Argument(ArgumentType.Multiple, HelpText = "Input models", SortOrder = 1)] - public ITransformModel[] Models; + public TransformModel[] Models; } public sealed class CombineTransformModelsOutput { [TlcModule.Output(Desc = "Combined model", SortOrder = 1)] - public ITransformModel OutputModel; + public TransformModel OutputModel; } public sealed class PredictorModelInput { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Transform model", SortOrder = 1)] - public ITransformModel[] TransformModels; + public TransformModel[] TransformModels; [Argument(ArgumentType.Required, HelpText = "Predictor model", SortOrder = 2)] - public IPredictorModel PredictorModel; + public PredictorModel PredictorModel; } public sealed class SimplePredictorModelInput { [Argument(ArgumentType.Required, HelpText = "Transform model", SortOrder = 1)] - public ITransformModel TransformModel; + public TransformModel TransformModel; [Argument(ArgumentType.Required, HelpText = "Predictor model", SortOrder = 2)] - public IPredictorModel PredictorModel; + public PredictorModel PredictorModel; } public sealed class PredictorModelOutput { [TlcModule.Output(Desc = "Predictor model", SortOrder = 1)] - public IPredictorModel PredictorModel; + public PredictorModel PredictorModel; } public sealed class 
CombineOvaPredictorModelsInput : LearnerInputBaseWithWeight { [Argument(ArgumentType.Multiple, HelpText = "Input models", SortOrder = 1)] - public IPredictorModel[] ModelArray; + public PredictorModel[] ModelArray; [Argument(ArgumentType.AtMostOnce, HelpText = "Use probabilities from learners instead of raw values.", SortOrder = 2)] public bool UseProbabilities = true; @@ -64,13 +63,13 @@ public sealed class CombineOvaPredictorModelsInput : LearnerInputBaseWithWeight public sealed class CombinePredictorModelsInput { [Argument(ArgumentType.Multiple, HelpText = "Input models", SortOrder = 1)] - public IPredictorModel[] Models; + public PredictorModel[] Models; } public sealed class ApplyTransformModelInput : TransformInputBase { [Argument(ArgumentType.Required, HelpText = "Transform model", SortOrder = 2)] - public ITransformModel TransformModel; + public TransformModel TransformModel; } public sealed class ApplyTransformModelOutput @@ -88,7 +87,7 @@ public static CombineTransformModelsOutput CombineTransformModels(IHostEnvironme EntryPointUtils.CheckInputArgs(host, input); host.CheckNonEmpty(input.Models, nameof(input.Models)); - ITransformModel model = input.Models[input.Models.Length - 1]; + TransformModel model = input.Models[input.Models.Length - 1]; for (int i = input.Models.Length - 2; i >= 0; i--) model = model.Apply(env, input.Models[i]); @@ -104,7 +103,7 @@ public static PredictorModelOutput CombineModels(IHostEnvironment env, Predictor EntryPointUtils.CheckInputArgs(host, input); host.CheckNonEmpty(input.TransformModels, nameof(input.TransformModels)); - ITransformModel model = input.TransformModels[input.TransformModels.Length - 1]; + TransformModel model = input.TransformModels[input.TransformModels.Length - 1]; for (int i = input.TransformModels.Length - 2; i >= 0; i--) model = model.Apply(env, input.TransformModels[i]); return new PredictorModelOutput() { PredictorModel = input.PredictorModel.Apply(env, model) }; @@ -153,8 +152,8 @@ public static 
PredictorModelOutput CombineOvaModels(IHostEnvironment env, Combin return new PredictorModelOutput { - PredictorModel = new PredictorModel(env, data, input.TrainingData, - OvaPredictor.Create(host, input.UseProbabilities, + PredictorModel = new PredictorModelImpl(env, data, input.TrainingData, + OvaModelParameters.Create(host, input.UseProbabilities, input.ModelArray.Select(p => p.Predictor as IPredictorProducing).ToArray())) }; } diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/OneVersusAllMacro.cs b/src/Microsoft.ML.EntryPoints/OneVersusAllMacro.cs similarity index 61% rename from src/Microsoft.ML.Legacy/Runtime/EntryPoints/OneVersusAllMacro.cs rename to src/Microsoft.ML.EntryPoints/OneVersusAllMacro.cs index f1d60d87eb..857ec4afe7 100644 --- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/OneVersusAllMacro.cs +++ b/src/Microsoft.ML.EntryPoints/OneVersusAllMacro.cs @@ -4,11 +4,12 @@ using System; using System.Collections.Generic; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Training; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Training; +using Microsoft.ML.Transforms; using Newtonsoft.Json.Linq; [assembly: LoadableClass(typeof(void), typeof(OneVersusAllMacro), null, typeof(SignatureEntryPointModule), "OneVersusAllMacro")] @@ -16,7 +17,7 @@ // The warning #612 is disabled because the following code uses Legacy.Models and Legacy.Transforms while Legacy is marked as obsolete. // Because that dependency will be removed form ML.NET, one needs to rewrite all places where legacy APIs are used. #pragma warning disable 612 -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { /// /// This macro entrypoint implements OVA. 
@@ -26,7 +27,7 @@ public static class OneVersusAllMacro public sealed class SubGraphOutput { [Argument(ArgumentType.Required, HelpText = "The predictor model for the subgraph exemplar.", SortOrder = 1)] - public Var Model; + public Var Model; } public sealed class Arguments : LearnerInputBaseWithWeight @@ -46,32 +47,30 @@ public sealed class Arguments : LearnerInputBaseWithWeight public sealed class Output { [TlcModule.Output(Desc = "The trained multiclass model", SortOrder = 1)] - public IPredictorModel PredictorModel; + public PredictorModel PredictorModel; } - private static Tuple, Var> ProcessClass(IHostEnvironment env, int k, string label, Arguments input, EntryPointNode node) + private static Var ProcessClass(IHostEnvironment env, List macroNodes, int k, string label, Arguments input, EntryPointNode node) { - var macroNodes = new List(); + Contracts.AssertValue(macroNodes); // Convert label into T,F based on k. - var remapper = new Legacy.Transforms.LabelIndicator - { - ClassIndex = k, - Column = new[] - { - new Legacy.Transforms.LabelIndicatorTransformColumn - { - ClassIndex = k, - Name = label, - Source = label - } - }, - Data = { VarName = node.GetInputVariable(nameof(input.TrainingData)).ToJson() } - }; - var exp = new Experiment(env); - var remapperOutNode = exp.Add(remapper); - var subNodes = EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes()); - macroNodes.AddRange(subNodes); + var labelIndicatorArgs = new LabelIndicatorTransform.Arguments(); + labelIndicatorArgs.ClassIndex = k; + labelIndicatorArgs.Column = new[] { new LabelIndicatorTransform.Column() { Name = label, Source = label } }; + + var inputBindingMap = new Dictionary>(); + var inputMap = new Dictionary(); + var paramBinding = new SimpleParameterBinding(nameof(labelIndicatorArgs.Data)); + inputBindingMap.Add(nameof(labelIndicatorArgs.Data), new List() { paramBinding }); + inputMap.Add(paramBinding, node.GetInputVariable(nameof(input.TrainingData))); + + var outputMap = new 
Dictionary(); + var remappedLabelVar = new Var(); + outputMap.Add(nameof(CommonOutputs.TransformOutput.OutputData), remappedLabelVar.VarName); + var labelIndicatorNode = EntryPointNode.Create(env, "Transforms.LabelIndicator", labelIndicatorArgs, node.Context, + inputBindingMap, inputMap, outputMap); + macroNodes.Add(labelIndicatorNode); // Parse the nodes in input.Nodes into a temporary run context. var subGraphRunContext = new RunContext(env); @@ -80,7 +79,7 @@ private static Tuple, Var> ProcessClass(IH // Rename all the variables such that they don't conflict with the ones in the outer run context. var mapping = new Dictionary(); bool foundOutput = false; - Var predModelVar = null; + Var predModelVar = null; foreach (var entryPointNode in subGraphNodes) { // Rename variables in input/output maps, and in subgraph context. @@ -91,13 +90,13 @@ private static Tuple, Var> ProcessClass(IH // Grab a hold of output model from this subgraph. if (entryPointNode.GetOutputVariableName("PredictorModel") is string mvn) { - predModelVar = new Var { VarName = mvn }; + predModelVar = new Var { VarName = mvn }; foundOutput = true; } // Connect label remapper output to wherever training data was expected within the input graph. if (entryPointNode.GetInputVariable(nameof(input.TrainingData)) is VariableBinding vb) - vb.Rename(remapperOutNode.OutputData.VarName); + vb.Rename(remappedLabelVar.VarName); // Change node to use the main context. entryPointNode.SetContext(node.Context); @@ -112,8 +111,7 @@ private static Tuple, Var> ProcessClass(IH // Add training subgraph to our context. 
macroNodes.AddRange(subGraphNodes); - - return new Tuple, Var>(macroNodes, predModelVar); + return predModelVar; } private static int GetNumberOfClasses(IHostEnvironment env, Arguments input, out string label) @@ -122,7 +120,7 @@ private static int GetNumberOfClasses(IHostEnvironment env, Arguments input, out using (var ch = host.Start("OVA Macro GetNumberOfClasses")) { // RoleMappedData creation - ISchema schema = input.TrainingData.Schema; + var schema = input.TrainingData.Schema; label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn), input.LabelColumn, DefaultColumnNames.Label); @@ -151,48 +149,44 @@ public static CommonOutputs.MacroOutput OneVersusAll( env.Assert(input.Nodes.Count > 0); var numClasses = GetNumberOfClasses(env, input, out var label); - var predModelVars = new Var[numClasses]; + var predModelVars = new Var[numClasses]; // This will be the final resulting list of nodes that is returned from the macro. var macroNodes = new List(); // Instantiate the subgraph for each label value. for (int k = 0; k < numClasses; k++) - { - var result = ProcessClass(env, k, label, input, node); - predModelVars[k] = result.Item2; - macroNodes.AddRange(result.Item1); - } + predModelVars[k] = ProcessClass(env, macroNodes, k, label, input, node); + + // Convert the predictor models to an array of predictor models. + var modelsArray = new Var(); + MacroUtils.ConvertIPredictorModelsToArray(env, node.Context, macroNodes, predModelVars, modelsArray.VarName); // Use OVA model combiner to combine these models into one. // Takes in array of models that are binary predictor models and // produces single multiclass predictor model. 
- var macroExperiment = new Experiment(env); - var combinerNode = new Legacy.Models.OvaModelCombiner - { - ModelArray = new ArrayVar(predModelVars), - TrainingData = new Var { VarName = node.GetInputVariable(nameof(input.TrainingData)).VariableName }, - Caching = (Legacy.Models.CachingOptions)input.Caching, - FeatureColumn = input.FeatureColumn, - NormalizeFeatures = (Legacy.Models.NormalizeOption)input.NormalizeFeatures, - LabelColumn = input.LabelColumn, - UseProbabilities = input.UseProbabilities - }; - - // Get output model variable. - if (!node.OutputMap.TryGetValue(nameof(Output.PredictorModel), out var outVariableName)) - throw new Exception("Cannot find OVA model output."); - - // Map macro's output back to OVA combiner (so OVA combiner will set the value on our output variable). - var combinerOutput = new Legacy.Models.OvaModelCombiner.Output { PredictorModel = new Var { VarName = outVariableName } }; - - // Add to experiment (must be done AFTER we assign variable name to output). - macroExperiment.Add(combinerNode, combinerOutput); - - // Add nodes to main experiment. 
- var nodes = macroExperiment.GetNodes(); - var expNodes = EntryPointNode.ValidateNodes(env, node.Context, nodes); - macroNodes.AddRange(expNodes); + var combineArgs = new ModelOperations.CombineOvaPredictorModelsInput(); + combineArgs.Caching = input.Caching; + combineArgs.FeatureColumn = input.FeatureColumn; + combineArgs.LabelColumn = input.LabelColumn; + combineArgs.NormalizeFeatures = input.NormalizeFeatures; + combineArgs.UseProbabilities = input.UseProbabilities; + + var inputBindingMap = new Dictionary>(); + var inputMap = new Dictionary(); + var combineNodeModelArrayInput = new SimpleVariableBinding(modelsArray.VarName); + var paramBinding = new SimpleParameterBinding(nameof(combineArgs.ModelArray)); + inputBindingMap.Add(nameof(combineArgs.ModelArray), new List() { paramBinding }); + inputMap.Add(paramBinding, combineNodeModelArrayInput); + paramBinding = new SimpleParameterBinding(nameof(combineArgs.TrainingData)); + inputBindingMap.Add(nameof(combineArgs.TrainingData), new List() { paramBinding }); + inputMap.Add(paramBinding, node.GetInputVariable(nameof(input.TrainingData))); + + var outputMap = new Dictionary(); + outputMap.Add(nameof(Output.PredictorModel), node.GetOutputVariableName(nameof(Output.PredictorModel))); + var combineModelsNode = EntryPointNode.Create(env, "Models.OvaModelCombiner", + combineArgs, node.Context, inputBindingMap, inputMap, outputMap); + macroNodes.Add(combineModelsNode); return new CommonOutputs.MacroOutput() { Nodes = macroNodes }; } diff --git a/src/Microsoft.ML.EntryPoints/Properties/AssemblyInfo.cs b/src/Microsoft.ML.EntryPoints/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..0169e5d896 --- /dev/null +++ b/src/Microsoft.ML.EntryPoints/Properties/AssemblyInfo.cs @@ -0,0 +1,9 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System.Runtime.CompilerServices; +using Microsoft.ML; + +[assembly: InternalsVisibleTo("Microsoft.ML.Tests" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo("Microsoft.ML.Core.Tests" + PublicKey.TestValue)] \ No newline at end of file diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/TrainTestMacro.cs b/src/Microsoft.ML.EntryPoints/TrainTestMacro.cs similarity index 51% rename from src/Microsoft.ML.Legacy/Runtime/EntryPoints/TrainTestMacro.cs rename to src/Microsoft.ML.EntryPoints/TrainTestMacro.cs index c61ea1d268..260ba7b756 100644 --- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/TrainTestMacro.cs +++ b/src/Microsoft.ML.EntryPoints/TrainTestMacro.cs @@ -2,19 +2,20 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Newtonsoft.Json.Linq; using System; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Newtonsoft.Json.Linq; [assembly: LoadableClass(typeof(void), typeof(TrainTestMacro), null, typeof(SignatureEntryPointModule), "TrainTestMacro")] // The warning #612 is disabled because the following code uses a lot of things in Legacy.Models and Legacy.Transforms while Legacy is marked as obsolete. // Because that dependency will be removed form ML.NET, one needs to rewrite all places where legacy APIs are used. 
#pragma warning disable 612 -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { public static class TrainTestMacro { @@ -27,10 +28,7 @@ public sealed class SubGraphInput public sealed class SubGraphOutput { [Argument(ArgumentType.AtMostOnce, HelpText = "The predictor model", SortOrder = 1)] - public Var PredictorModel; - - [Argument(ArgumentType.AtMostOnce, HelpText = "Transform model", SortOrder = 2)] - public Var TransformModel; + public Var PredictorModel; } public sealed class Arguments @@ -45,7 +43,7 @@ public sealed class Arguments [TlcModule.OptionalInput] [Argument(ArgumentType.AtMostOnce, HelpText = "The aggregated transform model from the pipeline before this command, to apply to the test data, and also include in the final model, together with the predictor model.", SortOrder = 3)] - public Var TransformModel = null; + public Var TransformModel = null; [Argument(ArgumentType.Required, HelpText = "The training subgraph", SortOrder = 4)] public JArray Nodes; @@ -82,11 +80,7 @@ public sealed class Output { [TlcModule.Output(Desc = "The final model including the trained predictor model and the model from the transforms, " + "provided as the Input.TransformModel.", SortOrder = 1)] - public IPredictorModel PredictorModel; - - [TlcModule.Output(Desc = "The final model including the trained predictor model and the model from the transforms, " + - "provided as the Input.TransformModel.", SortOrder = 2)] - public ITransformModel TransformModel; + public PredictorModel PredictorModel; [TlcModule.Output(Desc = "Warning dataset", SortOrder = 3)] public IDataView Warnings; @@ -143,15 +137,14 @@ public static CommonOutputs.MacroOutput TrainTest( subGraphRunContext.RemoveVariable(dataVariable); // Change the subgraph to use the model variable as output. - varName = input.Outputs.PredictorModel == null ? 
input.Outputs.TransformModel.VarName : input.Outputs.PredictorModel.VarName; + varName = input.Outputs.PredictorModel.VarName; if (!subGraphRunContext.TryGetVariable(varName, out dataVariable)) throw env.Except($"Invalid variable name '{varName}'."); - string outputVarName = input.Outputs.PredictorModel == null ? node.GetOutputVariableName(nameof(Output.TransformModel)) : - node.GetOutputVariableName(nameof(Output.PredictorModel)); + string predictorModelVarName = node.GetOutputVariableName(nameof(Output.PredictorModel)); foreach (var subGraphNode in subGraphNodes) - subGraphNode.RenameOutputVariable(dataVariable.Name, outputVarName); + subGraphNode.RenameOutputVariable(dataVariable.Name, predictorModelVarName); subGraphRunContext.RemoveVariable(dataVariable); // Move the variables from the subcontext to the main context. @@ -163,67 +156,62 @@ public static CommonOutputs.MacroOutput TrainTest( // Testing using test data set var testingVar = node.GetInputVariable(nameof(input.TestingData)); - var exp = new Experiment(env); + //var exp = new Experiment(env); - Legacy.Transforms.DatasetScorer.Output scoreNodeOutput = null; - Legacy.Models.DatasetTransformer.Output datasetTransformNodeOutput = null; - if (input.Outputs.PredictorModel == null) - { - //combine the predictor model with any potential transfrom model passed from the outer graph - if (transformModelVarName != null && transformModelVarName.VariableName != null) - { - var modelCombine = new ML.Legacy.Transforms.ModelCombiner - { - Models = new ArrayVar( - new Var[] { - new Var { VarName = transformModelVarName.VariableName }, - new Var { VarName = outputVarName} } - ) - }; - - var modelCombineOutput = exp.Add(modelCombine); - outputVarName = modelCombineOutput.OutputModel.VarName; - } - - var datasetTransformerNode = new Legacy.Models.DatasetTransformer - { - Data = { VarName = testingVar.ToJson() }, - TransformModel = { VarName = outputVarName } - }; - - datasetTransformNodeOutput = 
exp.Add(datasetTransformerNode); - } - else + Dictionary> inputBindingMap; + Dictionary inputMap; + ParameterBinding paramBinding; + Dictionary outputMap; + + //combine the predictor model with any potential transfrom model passed from the outer graph + if (transformModelVarName != null && transformModelVarName.VariableName != null) { - //combine the predictor model with any potential transfrom model passed from the outer graph - if (transformModelVarName != null && transformModelVarName.VariableName != null) - { - var modelCombine = new Legacy.Transforms.TwoHeterogeneousModelCombiner - { - TransformModel = { VarName = transformModelVarName.VariableName }, - PredictorModel = { VarName = outputVarName } - }; - - var modelCombineOutput = exp.Add(modelCombine); - outputVarName = modelCombineOutput.PredictorModel.VarName; - } - - // Add the scoring node for testing. - var scoreNode = new Legacy.Transforms.DatasetScorer - { - Data = { VarName = testingVar.ToJson() }, - PredictorModel = { VarName = outputVarName } - }; - - scoreNodeOutput = exp.Add(scoreNode); + var combineArgs = new ModelOperations.SimplePredictorModelInput(); + inputBindingMap = new Dictionary>(); + inputMap = new Dictionary(); + + var inputTransformModel = new SimpleVariableBinding(transformModelVarName.VariableName); + var inputPredictorModel = new SimpleVariableBinding(predictorModelVarName); + paramBinding = new SimpleParameterBinding(nameof(combineArgs.TransformModel)); + inputBindingMap.Add(nameof(combineArgs.TransformModel), new List() { paramBinding }); + inputMap.Add(paramBinding, inputTransformModel); + paramBinding = new SimpleParameterBinding(nameof(combineArgs.PredictorModel)); + inputBindingMap.Add(nameof(combineArgs.PredictorModel), new List() { paramBinding }); + inputMap.Add(paramBinding, inputPredictorModel); + outputMap = new Dictionary(); + + var combineNodeOutputPredictorModel = new Var(); + predictorModelVarName = combineNodeOutputPredictorModel.VarName; + 
outputMap.Add(nameof(ModelOperations.PredictorModelOutput.PredictorModel), combineNodeOutputPredictorModel.VarName); + EntryPointNode combineNode = EntryPointNode.Create(env, "Transforms.TwoHeterogeneousModelCombiner", combineArgs, + node.Context, inputBindingMap, inputMap, outputMap); + subGraphNodes.Add(combineNode); } - subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes())); - - // Do not double-add previous nodes. - exp.Reset(); - - // REVIEW: add similar support for NameColumn and FeatureColumn. + // Add the scoring node for testing. + var args = new ScoreModel.Input(); + inputBindingMap = new Dictionary>(); + inputMap = new Dictionary(); + paramBinding = new SimpleParameterBinding(nameof(args.Data)); + inputBindingMap.Add(nameof(args.Data), new List() { paramBinding }); + inputMap.Add(paramBinding, testingVar); + var scoreNodeInputPredictorModel = new SimpleVariableBinding(predictorModelVarName); + paramBinding = new SimpleParameterBinding(nameof(args.PredictorModel)); + inputBindingMap.Add(nameof(args.PredictorModel), new List() { paramBinding }); + inputMap.Add(paramBinding, scoreNodeInputPredictorModel); + + var scoreNodeOutputScoredData = new Var(); + var scoreNodeOutputScoringTransform = new Var(); + outputMap = new Dictionary(); + outputMap.Add(nameof(ScoreModel.Output.ScoredData), scoreNodeOutputScoredData.VarName); + outputMap.Add(nameof(ScoreModel.Output.ScoringTransform), scoreNodeOutputScoringTransform.VarName); + + EntryPointNode scoreNode = EntryPointNode.Create(env, "Transforms.DatasetScorer", args, + node.Context, inputBindingMap, inputMap, outputMap); + subGraphNodes.Add(scoreNode); + var evalDataVarName = scoreNodeOutputScoredData.VarName; + + // REVIEW: add similar support for FeatureColumn. var settings = new MacroUtils.EvaluatorSettings { LabelColumn = input.LabelColumn, @@ -232,80 +220,73 @@ public static CommonOutputs.MacroOutput TrainTest( NameColumn = input.NameColumn.IsExplicit ? 
input.NameColumn.Value : null }; - string outVariableName; - if (input.IncludeTrainingMetrics) { - Legacy.Transforms.DatasetScorer.Output scoreNodeTrainingOutput = null; - Legacy.Models.DatasetTransformer.Output datasetTransformNodeTrainingOutput = null; - if (input.Outputs.PredictorModel == null) - { - var datasetTransformerNode = new Legacy.Models.DatasetTransformer - { - Data = { VarName = testingVar.ToJson() }, - TransformModel = { VarName = outputVarName } - }; - - datasetTransformNodeTrainingOutput = exp.Add(datasetTransformerNode); - } - else - { - // Add the scoring node for training. - var scoreNodeTraining = new Legacy.Transforms.DatasetScorer - { - Data = { VarName = trainingVar.ToJson() }, - PredictorModel = { VarName = outputVarName } - }; - scoreNodeTrainingOutput = exp.Add(scoreNodeTraining); - } - - subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes())); - - // Do not double-add previous nodes. - exp.Reset(); + string evalTrainingDataVarName; + args = new ScoreModel.Input(); + inputBindingMap = new Dictionary>(); + inputMap = new Dictionary(); + paramBinding = new SimpleParameterBinding(nameof(args.Data)); + inputBindingMap.Add(nameof(args.Data), new List() { paramBinding }); + inputMap.Add(paramBinding, trainingVar); + scoreNodeInputPredictorModel = new SimpleVariableBinding(predictorModelVarName); + paramBinding = new SimpleParameterBinding(nameof(args.PredictorModel)); + inputBindingMap.Add(nameof(args.PredictorModel), new List() { paramBinding }); + inputMap.Add(paramBinding, scoreNodeInputPredictorModel); + + scoreNodeOutputScoredData = new Var(); + scoreNodeOutputScoringTransform = new Var(); + outputMap = new Dictionary(); + outputMap.Add(nameof(ScoreModel.Output.ScoredData), scoreNodeOutputScoredData.VarName); + outputMap.Add(nameof(ScoreModel.Output.ScoringTransform), scoreNodeOutputScoringTransform.VarName); + + scoreNode = EntryPointNode.Create(env, "Transforms.DatasetScorer", args, + node.Context, 
inputBindingMap, inputMap, outputMap); + subGraphNodes.Add(scoreNode); + evalTrainingDataVarName = scoreNodeOutputScoredData.VarName; // Add the evaluator node for training. - var evalInputOutputTraining = MacroUtils.GetEvaluatorInputOutput(input.Kind, settings); - var evalNodeTraining = evalInputOutputTraining.Item1; - var evalOutputTraining = evalInputOutputTraining.Item2; - evalNodeTraining.Data.VarName = input.Outputs.PredictorModel == null ? datasetTransformNodeTrainingOutput.OutputData.VarName : - scoreNodeTrainingOutput.ScoredData.VarName; - - if (node.OutputMap.TryGetValue(nameof(Output.TrainingWarnings), out outVariableName)) - evalOutputTraining.Warnings.VarName = outVariableName; - if (node.OutputMap.TryGetValue(nameof(Output.TrainingOverallMetrics), out outVariableName)) - evalOutputTraining.OverallMetrics.VarName = outVariableName; - if (node.OutputMap.TryGetValue(nameof(Output.TrainingPerInstanceMetrics), out outVariableName)) - evalOutputTraining.PerInstanceMetrics.VarName = outVariableName; - if (node.OutputMap.TryGetValue(nameof(Output.TrainingConfusionMatrix), out outVariableName) - && evalOutputTraining is CommonOutputs.IClassificationEvaluatorOutput eoTraining) - eoTraining.ConfusionMatrix.VarName = outVariableName; - - exp.Add(evalNodeTraining, evalOutputTraining); - subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes())); + var evalTrainingArgs = MacroUtils.GetEvaluatorArgs(input.Kind, out var evalTrainingEntryPointName, settings); + inputBindingMap = new Dictionary>(); + inputMap = new Dictionary(); + var evalTrainingNodeInputData = new SimpleVariableBinding(evalTrainingDataVarName); + paramBinding = new SimpleParameterBinding(nameof(evalTrainingArgs.Data)); + inputBindingMap.Add(nameof(evalTrainingArgs.Data), new List() { paramBinding }); + inputMap.Add(paramBinding, evalTrainingNodeInputData); + + outputMap = new Dictionary(); + if (node.OutputMap.TryGetValue(nameof(Output.TrainingWarnings), out var 
outTrainingVariableName)) + outputMap.Add(nameof(CommonOutputs.ClassificationEvaluateOutput.Warnings), outTrainingVariableName); + if (node.OutputMap.TryGetValue(nameof(Output.TrainingOverallMetrics), out outTrainingVariableName)) + outputMap.Add(nameof(CommonOutputs.ClassificationEvaluateOutput.OverallMetrics), outTrainingVariableName); + if (node.OutputMap.TryGetValue(nameof(Output.TrainingPerInstanceMetrics), out outTrainingVariableName)) + outputMap.Add(nameof(CommonOutputs.ClassificationEvaluateOutput.PerInstanceMetrics), outTrainingVariableName); + if (node.OutputMap.TryGetValue(nameof(Output.TrainingConfusionMatrix), out outTrainingVariableName)) + outputMap.Add(nameof(CommonOutputs.ClassificationEvaluateOutput.ConfusionMatrix), outTrainingVariableName); + EntryPointNode evalTrainingNode = EntryPointNode.Create(env, evalTrainingEntryPointName, evalTrainingArgs, node.Context, inputBindingMap, inputMap, outputMap); + subGraphNodes.Add(evalTrainingNode); } - // Do not double-add previous nodes. - exp.Reset(); - // Add the evaluator node for testing. - var evalInputOutput = MacroUtils.GetEvaluatorInputOutput(input.Kind, settings); - var evalNode = evalInputOutput.Item1; - var evalOutput = evalInputOutput.Item2; - evalNode.Data.VarName = input.Outputs.PredictorModel == null ? 
datasetTransformNodeOutput.OutputData.VarName : scoreNodeOutput.ScoredData.VarName; - - if (node.OutputMap.TryGetValue(nameof(Output.Warnings), out outVariableName)) - evalOutput.Warnings.VarName = outVariableName; + var evalArgs = MacroUtils.GetEvaluatorArgs(input.Kind, out var evalEntryPointName, settings); + inputBindingMap = new Dictionary>(); + inputMap = new Dictionary(); + var evalNodeInputData = new SimpleVariableBinding(evalDataVarName); + paramBinding = new SimpleParameterBinding(nameof(evalArgs.Data)); + inputBindingMap.Add(nameof(evalArgs.Data), new List() { paramBinding }); + inputMap.Add(paramBinding, evalNodeInputData); + + outputMap = new Dictionary(); + if (node.OutputMap.TryGetValue(nameof(Output.Warnings), out var outVariableName)) + outputMap.Add(nameof(CommonOutputs.ClassificationEvaluateOutput.Warnings), outVariableName); if (node.OutputMap.TryGetValue(nameof(Output.OverallMetrics), out outVariableName)) - evalOutput.OverallMetrics.VarName = outVariableName; + outputMap.Add(nameof(CommonOutputs.ClassificationEvaluateOutput.OverallMetrics), outVariableName); if (node.OutputMap.TryGetValue(nameof(Output.PerInstanceMetrics), out outVariableName)) - evalOutput.PerInstanceMetrics.VarName = outVariableName; - if (node.OutputMap.TryGetValue(nameof(Output.ConfusionMatrix), out outVariableName) - && evalOutput is CommonOutputs.IClassificationEvaluatorOutput eo) - eo.ConfusionMatrix.VarName = outVariableName; - - exp.Add(evalNode, evalOutput); - subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes())); + outputMap.Add(nameof(CommonOutputs.ClassificationEvaluateOutput.PerInstanceMetrics), outVariableName); + if (node.OutputMap.TryGetValue(nameof(Output.ConfusionMatrix), out outVariableName)) + outputMap.Add(nameof(CommonOutputs.ClassificationEvaluateOutput.ConfusionMatrix), outVariableName); + EntryPointNode evalNode = EntryPointNode.Create(env, evalEntryPointName, evalArgs, node.Context, inputBindingMap, inputMap, 
outputMap); + subGraphNodes.Add(evalNode); // Marks as an atomic unit that can be run in // a distributed fashion. diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/TrainTestSplit.cs b/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs similarity index 95% rename from src/Microsoft.ML.Legacy/Runtime/EntryPoints/TrainTestSplit.cs rename to src/Microsoft.ML.EntryPoints/TrainTestSplit.cs index 928e557e37..a4b8ec98a6 100644 --- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/TrainTestSplit.cs +++ b/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs @@ -2,16 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Conversions; [assembly: LoadableClass(typeof(void), typeof(TrainTestSplit), null, typeof(SignatureEntryPointModule), "TrainTestSplit")] -namespace Microsoft.ML.Runtime.EntryPoints +namespace Microsoft.ML.EntryPoints { public static class TrainTestSplit { diff --git a/src/Microsoft.ML.FastTree/Application/DominationLossApplication.cs b/src/Microsoft.ML.FastTree/Application/DominationLossApplication.cs index e9330adfe4..d12a0471fe 100644 --- a/src/Microsoft.ML.FastTree/Application/DominationLossApplication.cs +++ b/src/Microsoft.ML.FastTree/Application/DominationLossApplication.cs @@ -2,9 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; -using Microsoft.ML.Runtime.CommandLine; - namespace Microsoft.ML.Trainers.FastTree.Internal { #if OLD_DATALOAD diff --git a/src/Microsoft.ML.FastTree/Application/LogLossApplication.cs b/src/Microsoft.ML.FastTree/Application/LogLossApplication.cs index 27fe732fc9..9dd896a401 100644 --- a/src/Microsoft.ML.FastTree/Application/LogLossApplication.cs +++ b/src/Microsoft.ML.FastTree/Application/LogLossApplication.cs @@ -2,12 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using Microsoft.ML.Runtime.CommandLine; - namespace Microsoft.ML.Trainers.FastTree.Internal { #if OLD_DATALOAD diff --git a/src/Microsoft.ML.FastTree/Application/SizeAdjustedLogLossApplication.cs b/src/Microsoft.ML.FastTree/Application/SizeAdjustedLogLossApplication.cs index 7ac4ce8dac..0cb8c9496f 100644 --- a/src/Microsoft.ML.FastTree/Application/SizeAdjustedLogLossApplication.cs +++ b/src/Microsoft.ML.FastTree/Application/SizeAdjustedLogLossApplication.cs @@ -2,13 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections.Generic; -using System.Globalization; -using System.IO; -using System.Linq; -using Microsoft.ML.Runtime.CommandLine; - namespace Microsoft.ML.Trainers.FastTree.Internal { #if OLD_DATALOAD diff --git a/src/Microsoft.ML.FastTree/Application/WinLossSurplusApplication.cs b/src/Microsoft.ML.FastTree/Application/WinLossSurplusApplication.cs index 93618e2fa1..1708033161 100644 --- a/src/Microsoft.ML.FastTree/Application/WinLossSurplusApplication.cs +++ b/src/Microsoft.ML.FastTree/Application/WinLossSurplusApplication.cs @@ -2,11 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. 
// See the LICENSE file in the project root for more information. -using System; -using System.Linq; -using System.Runtime.InteropServices; -using Microsoft.ML.Runtime.CommandLine; - namespace Microsoft.ML.Trainers.FastTree.Internal { #if OLD_DATALOAD diff --git a/src/Microsoft.ML.FastTree/BinFile/BinFinder.cs b/src/Microsoft.ML.FastTree/BinFile/BinFinder.cs index eaae7d5448..2978971d22 100644 --- a/src/Microsoft.ML.FastTree/BinFile/BinFinder.cs +++ b/src/Microsoft.ML.FastTree/BinFile/BinFinder.cs @@ -2,11 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Threading; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Trainers.FastTree.Internal { diff --git a/src/Microsoft.ML.FastTree/BinFile/IniFileParserInterface.cs b/src/Microsoft.ML.FastTree/BinFile/IniFileParserInterface.cs index 70d7899b90..9b28df134c 100644 --- a/src/Microsoft.ML.FastTree/BinFile/IniFileParserInterface.cs +++ b/src/Microsoft.ML.FastTree/BinFile/IniFileParserInterface.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; using System; using System.Linq; using System.Runtime.InteropServices; diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index 4ae0bf16b6..41995b1585 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -2,12 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Trainers.FastTree.Internal; using System; using System.Linq; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Trainers.FastTree.Internal; using Float = System.Single; namespace Microsoft.ML.Trainers.FastTree diff --git a/src/Microsoft.ML.FastTree/Dataset/Dataset.cs b/src/Microsoft.ML.FastTree/Dataset/Dataset.cs index a1e70ab5b9..330ef028ca 100644 --- a/src/Microsoft.ML.FastTree/Dataset/Dataset.cs +++ b/src/Microsoft.ML.FastTree/Dataset/Dataset.cs @@ -2,12 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Trainers.FastTree.Internal { diff --git a/src/Microsoft.ML.FastTree/Dataset/DenseIntArray.cs b/src/Microsoft.ML.FastTree/Dataset/DenseIntArray.cs index 1e437456e6..c6f10843d8 100644 --- a/src/Microsoft.ML.FastTree/Dataset/DenseIntArray.cs +++ b/src/Microsoft.ML.FastTree/Dataset/DenseIntArray.cs @@ -2,11 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; using System; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; +using System.Security; namespace Microsoft.ML.Trainers.FastTree.Internal { @@ -70,13 +70,14 @@ public override IntArray[] Split(int[][] assignment) } #if USE_FASTTREENATIVE - [DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall)] + internal const string NativePath = "FastTreeNative"; + [DllImport(NativePath), SuppressUnmanagedCodeSecurity] private static extern unsafe int C_Sumup_float( int numBits, byte* pData, int* pIndices, float* pSampleOutputs, double* pSampleOutputWeights, FloatType* pSumTargetsByBin, double* pSumTargets2ByBin, int* pCountByBin, int totalCount, double totalSampleOutputs, double totalSampleOutputWeights); - [DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall)] + [DllImport(NativePath), SuppressUnmanagedCodeSecurity] private static extern unsafe int C_Sumup_double( int numBits, byte* pData, int* pIndices, double* pSampleOutputs, double* pSampleOutputWeights, FloatType* pSumTargetsByBin, double* pSumTargets2ByBin, int* pCountByBin, @@ -154,7 +155,8 @@ public Dense0BitIntArray(byte[] buffer, ref int position) { } - public override MD5Hash MD5Hash { + public override MD5Hash MD5Hash + { get { return MD5Hasher.Hash(Length); } } @@ -178,13 +180,16 @@ public override void ToByteArray(byte[] buffer, ref int position) Length.ToByteArray(buffer, ref position); } - public override int this[int index] { - get { + public override int this[int index] + { + get + { Contracts.Assert(0 <= index && index < Length); return 0; } - set { + set + { Contracts.Assert(0 <= index && index < Length); Contracts.Assert(value == 0); } @@ -266,7 +271,8 @@ private void Set(long offset, uint mask, int value) _data[major + 1] = (_data[major + 1] & ~major1Mask) | (uint)(val >> 32); } - public override MD5Hash MD5Hash { + public override MD5Hash MD5Hash + { get { return MD5Hasher.Hash(_data); } } @@ 
-291,8 +297,10 @@ public override void ToByteArray(byte[] buffer, ref int position) _data.ToByteArray(buffer, ref position); } - public sealed override unsafe int this[int index] { - get { + public sealed override unsafe int this[int index] + { + get + { long offset = index; offset = (offset << 3) + (offset << 1); int minor = (int)(offset & 0x1f); @@ -301,7 +309,8 @@ public sealed override unsafe int this[int index] { return (int)(((*(ulong*)(pData + major)) >> minor) & _mask); } - set { + set + { Contracts.Assert(0 <= value && value < (1 << 10)); Set(((long)index) * 10, _mask, value); } @@ -436,10 +445,12 @@ public override unsafe void Callback(Action callback) } } - public override unsafe int this[int index] { + public override unsafe int this[int index] + { get { return _data[index]; } - set { + set + { Contracts.Assert(0 <= value && value <= byte.MaxValue); _data[index] = (byte)value; } @@ -471,7 +482,8 @@ internal sealed class Dense4BitIntArray : DenseIntArray public override IntArrayBits BitsPerItem { get { return IntArrayBits.Bits4; } } - public override MD5Hash MD5Hash { + public override MD5Hash MD5Hash + { get { return MD5Hasher.Hash(_data); } } @@ -532,8 +544,10 @@ public override void ToByteArray(byte[] buffer, ref int position) _data.ToByteArray(buffer, ref position); } - public override unsafe int this[int index] { - get { + public override unsafe int this[int index] + { + get + { int dataIndex = index / 2; bool highBits = (index % 2 == 0); @@ -546,7 +560,8 @@ public override unsafe int this[int index] { return v; } - set { + set + { Contracts.Assert(0 <= value && value < (1 << 4)); byte v; v = (byte)value; @@ -607,7 +622,8 @@ public Dense16BitIntArray(byte[] buffer, ref int position) _data = buffer.ToUShortArray(ref position); } - public override MD5Hash MD5Hash { + public override MD5Hash MD5Hash + { get { return MD5Hasher.Hash(_data); } } @@ -640,12 +656,15 @@ public override void ToByteArray(byte[] buffer, ref int position) 
_data.ToByteArray(buffer, ref position); } - public override unsafe int this[int index] { - get { + public override unsafe int this[int index] + { + get + { return _data[index]; } - set { + set + { Contracts.Assert(0 <= value && value <= ushort.MaxValue); _data[index] = (ushort)value; } @@ -700,7 +719,8 @@ public override unsafe void Callback(Action callback) } } - public override MD5Hash MD5Hash { + public override MD5Hash MD5Hash + { get { return MD5Hasher.Hash(_data); } } @@ -725,12 +745,15 @@ public override void ToByteArray(byte[] buffer, ref int position) _data.ToByteArray(buffer, ref position); } - public override int this[int index] { - get { + public override int this[int index] + { + get + { return _data[index]; } - set { + set + { Contracts.Assert(value >= 0); _data[index] = value; } diff --git a/src/Microsoft.ML.FastTree/Dataset/Feature.cs b/src/Microsoft.ML.FastTree/Dataset/Feature.cs index 0b02f8e4f2..c74ef5baf4 100644 --- a/src/Microsoft.ML.FastTree/Dataset/Feature.cs +++ b/src/Microsoft.ML.FastTree/Dataset/Feature.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. 
using System.Linq; -using Microsoft.ML.Runtime; namespace Microsoft.ML.Trainers.FastTree.Internal { diff --git a/src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs b/src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs index cb3070d45b..67f7dec621 100644 --- a/src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs +++ b/src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs @@ -12,9 +12,8 @@ using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Internal.CpuMath; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.CpuMath; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Trainers.FastTree.Internal { diff --git a/src/Microsoft.ML.FastTree/Dataset/FeatureHistogram.cs b/src/Microsoft.ML.FastTree/Dataset/FeatureHistogram.cs index 5ba2508374..2294525a9f 100644 --- a/src/Microsoft.ML.FastTree/Dataset/FeatureHistogram.cs +++ b/src/Microsoft.ML.FastTree/Dataset/FeatureHistogram.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; using System; using System.Linq; diff --git a/src/Microsoft.ML.FastTree/Dataset/FileObjectStore.cs b/src/Microsoft.ML.FastTree/Dataset/FileObjectStore.cs index f1a8337c53..9f7623d0ad 100644 --- a/src/Microsoft.ML.FastTree/Dataset/FileObjectStore.cs +++ b/src/Microsoft.ML.FastTree/Dataset/FileObjectStore.cs @@ -1,16 +1,6 @@ -// ----------------------------------------------------------------------- -// -// Copyright (C) All Rights Reserved -// -// ----------------------------------------------------------------------- - -using System; -using System.Collections.Generic; -using System.IO; -using System.IO.MemoryMappedFiles; -using System.Linq; -using System.Runtime.Serialization; -using System.Threading; +// Licensed to the .NET Foundation under one or more agreements. 
+// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. namespace Microsoft.ML.Trainers.FastTree.Internal { @@ -229,7 +219,7 @@ private void InitializeAsFileStream(string fileName) this.objectCacheFileStream = new FileStream(this.fileStreamName, FileMode.OpenOrCreate, FileAccess.ReadWrite, FileShare.ReadWrite); } - #region Process shutdown event handlers + #region Process shutdown event handlers /// /// Process exit and Domain Unload event handler @@ -262,7 +252,7 @@ private void CurrentDomain_UnhandledException(object sender, UnhandledExceptionE this.Close(); throw Contracts.Except(e.ExceptionObject as Exception, "Unhandled Exception detected"); } - #endregion + #endregion } /// diff --git a/src/Microsoft.ML.FastTree/Dataset/IntArray.cs b/src/Microsoft.ML.FastTree/Dataset/IntArray.cs index dd17900733..e2c8de872b 100644 --- a/src/Microsoft.ML.FastTree/Dataset/IntArray.cs +++ b/src/Microsoft.ML.FastTree/Dataset/IntArray.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; using System; using System.Collections.Generic; using System.Linq; @@ -12,7 +11,6 @@ namespace Microsoft.ML.Trainers.FastTree.Internal #if USE_SINGLE_PRECISION using FloatType = System.Single; #else - using FloatType = System.Double; #endif public enum IntArrayType { Dense, Sparse, Repeat, Segmented, Current }; diff --git a/src/Microsoft.ML.FastTree/Dataset/NHotFeatureFlock.cs b/src/Microsoft.ML.FastTree/Dataset/NHotFeatureFlock.cs index 272f9e6d92..fcfa6c938a 100644 --- a/src/Microsoft.ML.FastTree/Dataset/NHotFeatureFlock.cs +++ b/src/Microsoft.ML.FastTree/Dataset/NHotFeatureFlock.cs @@ -2,8 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; - namespace Microsoft.ML.Trainers.FastTree.Internal { /// diff --git a/src/Microsoft.ML.FastTree/Dataset/OneHotFeatureFlock.cs b/src/Microsoft.ML.FastTree/Dataset/OneHotFeatureFlock.cs index 01a26e85d5..da8592fc14 100644 --- a/src/Microsoft.ML.FastTree/Dataset/OneHotFeatureFlock.cs +++ b/src/Microsoft.ML.FastTree/Dataset/OneHotFeatureFlock.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; using System.Linq; namespace Microsoft.ML.Trainers.FastTree.Internal diff --git a/src/Microsoft.ML.FastTree/Dataset/RepeatIntArray.cs b/src/Microsoft.ML.FastTree/Dataset/RepeatIntArray.cs index d5b9e2964e..3b9502a8db 100644 --- a/src/Microsoft.ML.FastTree/Dataset/RepeatIntArray.cs +++ b/src/Microsoft.ML.FastTree/Dataset/RepeatIntArray.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; using System; using System.Collections.Generic; using System.Linq; diff --git a/src/Microsoft.ML.FastTree/Dataset/SegmentIntArray.cs b/src/Microsoft.ML.FastTree/Dataset/SegmentIntArray.cs index 5ddb8db873..156f5e227c 100644 --- a/src/Microsoft.ML.FastTree/Dataset/SegmentIntArray.cs +++ b/src/Microsoft.ML.FastTree/Dataset/SegmentIntArray.cs @@ -2,10 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; +using System.Security; namespace Microsoft.ML.Trainers.FastTree.Internal { @@ -489,31 +489,31 @@ public static unsafe void SegmentFindOptimalCost31(uint[] array, int len, out lo } bits = b; } - + internal const string NativePath = "FastTreeNative"; #pragma warning disable TLC_GeneralName // Externs follow their own rules. - [DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Ansi)] + [DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity] private static extern unsafe void C_SegmentFindOptimalPath21(uint* valv, int valc, long* pBits, int* pTransitions); - [DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Ansi)] + [DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity] private static extern unsafe void C_SegmentFindOptimalPath15(uint* valv, int valc, long* pBits, int* pTransitions); - [DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Ansi)] + [DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity] private static extern unsafe void C_SegmentFindOptimalPath7(uint* valv, int valc, long* pBits, int* pTransitions); - [DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Ansi)] + [DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity] private static extern unsafe void C_SegmentFindOptimalCost15(uint* valv, int valc, long* pBits); - [DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Ansi)] + [DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity] private static extern unsafe void C_SegmentFindOptimalCost31(uint* valv, int valc, long* pBits); - [DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall)] + 
[DllImport(NativePath)] private static extern unsafe int C_SumupSegment_float( uint* pData, byte* pSegType, int* pSegLength, int* pIndices, float* pSampleOutputs, double* pSampleOutputWeights, float* pSumTargetsByBin, double* pSumWeightsByBin, int* pCountByBin, int totalCount, double totalSampleOutputs); - [DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall)] + [DllImport(NativePath)] private static extern unsafe int C_SumupSegment_double( uint* pData, byte* pSegType, int* pSegLength, int* pIndices, double* pSampleOutputs, double* pSampleOutputWeights, diff --git a/src/Microsoft.ML.FastTree/Dataset/SingletonFeatureFlock.cs b/src/Microsoft.ML.FastTree/Dataset/SingletonFeatureFlock.cs index 91f74faa9a..c84a38db28 100644 --- a/src/Microsoft.ML.FastTree/Dataset/SingletonFeatureFlock.cs +++ b/src/Microsoft.ML.FastTree/Dataset/SingletonFeatureFlock.cs @@ -2,9 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Internal.Utilities; using System.Linq; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Trainers.FastTree.Internal { diff --git a/src/Microsoft.ML.FastTree/Dataset/SparseIntArray.cs b/src/Microsoft.ML.FastTree/Dataset/SparseIntArray.cs index 6d4098f3b7..b3dc61b07a 100644 --- a/src/Microsoft.ML.FastTree/Dataset/SparseIntArray.cs +++ b/src/Microsoft.ML.FastTree/Dataset/SparseIntArray.cs @@ -2,10 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; +using System.Security; namespace Microsoft.ML.Trainers.FastTree.Internal { @@ -490,12 +490,13 @@ public override void Sumup(SumupInputData input, FeatureHistogram histogram) } #if USE_FASTTREENATIVE - [DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall)] + internal const string NativePath = "FastTreeNative"; + [DllImport(NativePath), SuppressUnmanagedCodeSecurity] private static extern unsafe int C_SumupDeltaSparse_float(int numBits, byte* pValues, byte* pDeltas, int numDeltas, int* pIndices, float* pSampleOutputs, double* pSampleOutputWeights, float* pSumTargetsByBin, double* pSumTargets2ByBin, int* pCountByBin, int totalCount, double totalSampleOutputs, double totalSampleOutputWeights); - [DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall)] + [DllImport(NativePath), SuppressUnmanagedCodeSecurity] private static extern unsafe int C_SumupDeltaSparse_double(int numBits, byte* pValues, byte* pDeltas, int numDeltas, int* pIndices, double* pSampleOutputs, double* pSampleOutputWeights, double* pSumTargetsByBin, double* pSumTargets2ByBin, int* pCountByBin, int totalCount, double totalSampleOutputs, double totalSampleOutputWeights); diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 1238b8392c..9c5e7ec4cd 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -2,24 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.Conversion; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Calibration; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Model.Pfa; -using Microsoft.ML.Runtime.Training; -using Microsoft.ML.Runtime.TreePredictor; -using Microsoft.ML.Trainers.FastTree.Internal; -using Microsoft.ML.Transforms; -using Microsoft.ML.Transforms.Conversions; -using Newtonsoft.Json.Linq; using System; using System.Collections; using System.Collections.Generic; @@ -28,6 +10,24 @@ using System.IO; using System.Linq; using System.Text; +using Microsoft.ML.Calibrator; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.Data.Conversion; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Model.Pfa; +using Microsoft.ML.Trainers.FastTree.Internal; +using Microsoft.ML.Training; +using Microsoft.ML.Transforms; +using Microsoft.ML.Transforms.Conversions; +using Microsoft.ML.TreePredictor; +using Newtonsoft.Json.Linq; using Float = System.Single; // All of these reviews apply in general to fast tree and random forest implementations. @@ -57,12 +57,12 @@ public abstract class FastTreeTrainerBase : protected readonly bool AllowGC; protected TreeEnsemble TrainedEnsemble; protected int FeatureCount; - protected RoleMappedData ValidData; + private protected RoleMappedData ValidData; /// /// If not null, it's a test data set passed in from training context. It will be converted to one element in /// by calling in . 
/// - protected RoleMappedData TestData; + private protected RoleMappedData TestData; protected IParallelTraining ParallelTraining; protected OptimizationAlgorithm OptimizationAlgorithm; protected Dataset TrainSet; @@ -208,9 +208,9 @@ private void Initialize(IHostEnvironment env) InitializeThreads(numThreads); } - protected void ConvertData(RoleMappedData trainData) + private protected void ConvertData(RoleMappedData trainData) { - MetadataUtils.TryGetCategoricalFeatureIndices(trainData.Schema.Schema, trainData.Schema.Feature.Index, out CategoricalFeatures); + MetadataUtils.TryGetCategoricalFeatureIndices(trainData.Schema.Schema, trainData.Schema.Feature.Value.Index, out CategoricalFeatures); var useTranspose = UseTranspose(Args.DiskTranspose, trainData) && (ValidData == null || UseTranspose(Args.DiskTranspose, ValidData)); var instanceConverter = new ExamplesToFastTreeBins(Host, Args.MaxBins, useTranspose, !Args.FeatureFlocks, Args.MinDocumentsInLeafs, GetMaxLabel()); @@ -225,13 +225,13 @@ protected void ConvertData(RoleMappedData trainData) private bool UseTranspose(bool? 
useTranspose, RoleMappedData data) { Host.AssertValue(data); - Host.AssertValue(data.Schema.Feature); + Host.Assert(data.Schema.Feature.HasValue); if (useTranspose.HasValue) return useTranspose.Value; ITransposeDataView td = data.Data as ITransposeDataView; - return td != null && td.TransposeSchema.GetSlotType(data.Schema.Feature.Index) != null; + return td != null && td.TransposeSchema.GetSlotType(data.Schema.Feature.Value.Index) != null; } protected void TrainCore(IChannel ch) @@ -947,11 +947,11 @@ private DataConverter(RoleMappedData data, IHost host, Double[][] binUpperBounds Contracts.AssertValue(host, "host"); Host = host; Host.CheckValue(data, nameof(data)); - data.CheckFeatureFloatVector(); + data.CheckFeatureFloatVector(out int featLen); data.CheckOptFloatWeight(); data.CheckOptGroup(); - NumFeatures = data.Schema.Feature.Type.VectorSize; + NumFeatures = featLen; if (binUpperBounds != null) { Host.AssertValue(binUpperBounds); @@ -1318,14 +1318,15 @@ public override Dataset GetDataset() return _dataset; } - private static int AddColumnIfNeeded(ColumnInfo info, List toTranspose) + private static int AddColumnIfNeeded(Schema.Column? info, List toTranspose) { - if (info == null) + if (!info.HasValue) return -1; // It is entirely possible that a single column could have two roles, // and so be added twice, but this case is handled by the transposer. 
- toTranspose.Add(info.Index); - return info.Index; + var idx = info.Value.Index; + toTranspose.Add(idx); + return idx; } private ValueMapper, VBuffer> GetCopier(ColumnType itemType1, ColumnType itemType2) @@ -1358,10 +1359,7 @@ private ValueMapper, VBuffer> GetCopier(ColumnType itemT private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxBins, IParallelTraining parallelTraining) { Host.AssertValue(examples); - Host.AssertValue(examples.Schema.Feature); - Host.AssertValueOrNull(examples.Schema.Label); - Host.AssertValueOrNull(examples.Schema.Group); - Host.AssertValueOrNull(examples.Schema.Weight); + Host.Assert(examples.Schema.Feature.HasValue); if (parallelTraining == null) Host.AssertValue(BinUpperBounds); @@ -1386,8 +1384,8 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB data = new LabelConvertTransform(Host, convArgs, data); } // Convert the group column, if one exists. - if (examples.Schema.Group != null) - data = new TypeConvertingTransformer(Host, new TypeConvertingTransformer.ColumnInfo(examples.Schema.Group.Name, examples.Schema.Group.Name, DataKind.U8)).Transform(data); + if (examples.Schema.Group?.Name is string groupName) + data = new TypeConvertingTransformer(Host, new TypeConvertingTransformer.ColumnInfo(groupName, groupName, DataKind.U8)).Transform(data); // Since we've passed it through a few transforms, reconstitute the mapping on the // newly transformed data. 
@@ -1440,7 +1438,7 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB pch.SetHeader(new ProgressHeader("features"), e => e.SetProgress(0, iFeature, features.Length)); while (cursor.MoveNext()) { - iFeature = checked((int)cursor.Position); + iFeature = cursor.SlotIndex; if (!localConstructBinFeatures[iFeature]) continue; @@ -1646,7 +1644,7 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB else { if (groupIdx >= 0) - ch.Warning("This is not ranking problem, Group Id '{0}' column will be ignored", examples.Schema.Group.Name); + ch.Warning("This is not ranking problem, Group Id '{0}' column will be ignored", examples.Schema.Group.Value.Name); const int queryChunkSize = 100; qids = new ulong[(numExamples - 1) / queryChunkSize + 1]; boundaries = new int[qids.Length + 1]; @@ -1669,19 +1667,19 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB return result; } - private void GetFeatureValues(ISlotCursor cursor, int iFeature, ValueGetter> getter, + private void GetFeatureValues(SlotCursor cursor, int iFeature, ValueGetter> getter, ref VBuffer temp, ref VBuffer doubleTemp, ValueMapper, VBuffer> copier) { while (cursor.MoveNext()) { - Contracts.Assert(iFeature >= checked((int)cursor.Position)); + Contracts.Assert(iFeature >= cursor.SlotIndex); - if (iFeature == checked((int)cursor.Position)) + if (iFeature == cursor.SlotIndex) break; } - Contracts.Assert(cursor.Position == iFeature); + Contracts.Assert(cursor.SlotIndex == iFeature); getter(ref temp); copier(in temp, ref doubleTemp); @@ -1699,13 +1697,13 @@ private static ValueGetter> SubsetGetter(ValueGetter> g /// Returns a slot dropper object that has ranges of slots to be dropped, /// based on an examination of the feature values. 
/// - private static SlotDropper ConstructDropSlotRanges(ISlotCursor cursor, + private static SlotDropper ConstructDropSlotRanges(SlotCursor cursor, ValueGetter> getter, ref VBuffer temp) { // The iteration here is slightly differently from a usual cursor iteration. Here, temp // already holds the value of the cursor's current position, and we don't really want // to re-fetch it, and the cursor is necessarily advanced. - Contracts.Assert(cursor.State == CursorState.Good); + Contracts.Assert(cursor.SlotIndex >= 0); BitArray rowHasMissing = new BitArray(temp.Length); for (; ; ) { @@ -1858,7 +1856,7 @@ private void MakeBoundariesAndCheckLabels(out long missingInstances, out long to else { if (_data.Schema.Group != null) - ch.Warning("This is not ranking problem, Group Id '{0}' column will be ignored", _data.Schema.Group.Name); + ch.Warning("This is not ranking problem, Group Id '{0}' column will be ignored", _data.Schema.Group.Value.Name); } using (var cursor = new FloatLabelCursor(_data, curOptions)) { @@ -2793,24 +2791,25 @@ public Dataset GetCompatibleDataset(RoleMappedData data, PredictionKind kind, in } } - public abstract class FastTreePredictionWrapper : - PredictorBase, + public abstract class TreeEnsembleModelParameters : + ModelParametersBase, IValueMapper, ICanSaveInTextFormat, ICanSaveInIniFormat, ICanSaveInSourceCode, - ICanSaveModel, ICanSaveSummary, ICanGetSummaryInKeyValuePairs, ITreeEnsemble, IPredictorWithFeatureWeights, IFeatureContributionMapper, + ICalculateFeatureContribution, ICanGetSummaryAsIRow, ISingleCanSavePfa, ISingleCanSaveOnnx { //The below two properties are necessary for tree Visualizer - public TreeEnsemble TrainedEnsemble { get; } + [BestFriend] + internal TreeEnsemble TrainedEnsemble { get; } int ITreeEnsemble.NumTrees => TrainedEnsemble.NumTrees; // Inner args is used only for documentation purposes when saving comments to INI files. 
@@ -2829,12 +2828,18 @@ public abstract class FastTreePredictionWrapper : protected abstract uint VerCategoricalSplitSerialized { get; } - public ColumnType InputType { get; } - public ColumnType OutputType => NumberType.Float; + protected internal readonly ColumnType InputType; + ColumnType IValueMapper.InputType => InputType; + + protected readonly ColumnType OutputType; + ColumnType IValueMapper.OutputType => OutputType; + bool ICanSavePfa.CanSavePfa => true; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; + public FeatureContributionCalculator FeatureContributionClaculator => new FeatureContributionCalculator(this); - protected FastTreePredictionWrapper(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs) + public TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs) : base(env, name) { Host.CheckValue(trainedEnsemble, nameof(trainedEnsemble)); @@ -2852,9 +2857,10 @@ protected FastTreePredictionWrapper(IHostEnvironment env, string name, TreeEnsem Contracts.Assert(NumFeatures > MaxSplitFeatIdx); InputType = new VectorType(NumberType.Float, NumFeatures); + OutputType = NumberType.Float; } - protected FastTreePredictionWrapper(IHostEnvironment env, string name, ModelLoadContext ctx, VersionInfo ver) + protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLoadContext ctx, VersionInfo ver) : base(env, name, ctx) { // *** Binary format *** @@ -2889,9 +2895,11 @@ protected FastTreePredictionWrapper(IHostEnvironment env, string name, ModelLoad // tricks. 
InputType = new VectorType(NumberType.Float, NumFeatures); + OutputType = NumberType.Float; } - protected override void SaveCore(ModelSaveContext ctx) + [BestFriend] + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); @@ -2906,7 +2914,7 @@ protected override void SaveCore(ModelSaveContext ctx) ctx.Writer.Write(NumFeatures); } - public ValueMapper GetMapper() + ValueMapper IValueMapper.GetMapper() { Host.Check(typeof(TIn) == typeof(VBuffer)); Host.Check(typeof(TOut) == typeof(Float)); @@ -2925,7 +2933,7 @@ protected virtual void Map(in VBuffer src, ref Float dst) dst = (Float)TrainedEnsemble.GetOutput(in src); } - public ValueMapper> GetFeatureContributionMapper(int top, int bottom, bool normalize) + ValueMapper> IFeatureContributionMapper.GetFeatureContributionMapper(int top, int bottom, bool normalize) { Host.Check(typeof(TSrc) == typeof(VBuffer)); Host.Check(typeof(TDst) == typeof(VBuffer)); @@ -2937,7 +2945,7 @@ public ValueMapper> GetFeatureContributionMapper src, ref VBuffer dst) => { FeatureContributionMap(in src, ref dst, ref builder); - Runtime.Numeric.VectorUtils.SparsifyNormalize(ref dst, top, bottom, normalize); + Numeric.VectorUtils.SparsifyNormalize(ref dst, top, bottom, normalize); }; return (ValueMapper>)(Delegate)del; } @@ -2955,7 +2963,7 @@ private void FeatureContributionMap(in VBuffer src, ref VBuffer ds /// /// write out a C# representation of the ensemble /// - public void SaveAsCode(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema) { Host.CheckValueOrNull(schema); SaveEnsembleAsCode(writer, schema); @@ -2964,54 +2972,24 @@ public void SaveAsCode(TextWriter writer, RoleMappedSchema schema) /// /// Output the INI model to a given writer /// - public void SaveAsText(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(writer, nameof(writer)); 
Host.CheckValueOrNull(schema); - SaveAsIni(writer, schema); + ((ICanSaveInIniFormat)this).SaveAsIni(writer, schema); } /// /// Output the INI model to a given writer /// - public void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator = null) + void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator) { Host.CheckValue(writer, nameof(writer)); - Host.CheckValue(schema, nameof(schema)); - Host.CheckValueOrNull(calibrator); - string ensembleIni = TrainedEnsemble.ToTreeEnsembleIni(new FeaturesToContentMap(schema), + var ensembleIni = FastTreeIniFileUtils.TreeEnsembleToIni(Host, TrainedEnsemble, schema, calibrator, InnerArgs, appendFeatureGain: true, includeZeroGainFeatures: false); - ensembleIni = AddCalibrationToIni(ensembleIni, calibrator); writer.WriteLine(ensembleIni); } - /// - /// Get the calibration summary in INI format - /// - private string AddCalibrationToIni(string ini, ICalibrator calibrator) - { - Host.AssertValue(ini); - Host.AssertValueOrNull(calibrator); - - if (calibrator == null) - return ini; - - if (calibrator is PlattCalibrator) - { - string calibratorEvaluatorIni = IniFileUtils.GetCalibratorEvaluatorIni(ini, calibrator as PlattCalibrator); - return IniFileUtils.AddEvaluator(ini, calibratorEvaluatorIni); - } - else - { - StringBuilder newSection = new StringBuilder(); - newSection.AppendLine(); - newSection.AppendLine(); - newSection.AppendLine("[TLCCalibration]"); - newSection.AppendLine("Type=" + calibrator.GetType().Name); - return ini + newSection; - } - } - JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) { Host.CheckValue(ctx, nameof(ctx)); @@ -3148,12 +3126,12 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string return true; } - public void SaveSummary(TextWriter writer, RoleMappedSchema schema) + void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) { writer.WriteLine(); 
writer.WriteLine("Per-feature gain summary for the boosted tree ensemble:"); - foreach (var pair in GetSummaryInKeyValuePairs(schema)) + foreach (var pair in ((ICanGetSummaryInKeyValuePairs)this).GetSummaryInKeyValuePairs(schema)) { Host.Assert(pair.Value is Double); writer.WriteLine("\t{0}\t{1}", pair.Key, (Double)pair.Value); @@ -3179,7 +3157,7 @@ private IEnumerable> GetSortedFeatureGains(RoleMapp } /// - public IList> GetSummaryInKeyValuePairs(RoleMappedSchema schema) + IList> ICanGetSummaryInKeyValuePairs.GetSummaryInKeyValuePairs(RoleMappedSchema schema) { List> results = new List>(); @@ -3301,20 +3279,22 @@ public int GetLeaf(int treeId, in VBuffer features, ref List path) return TrainedEnsemble.GetTreeAt(treeId).GetLeaf(in features, ref path); } - public IRow GetSummaryIRowOrNull(RoleMappedSchema schema) + Row ICanGetSummaryAsIRow.GetSummaryIRowOrNull(RoleMappedSchema schema) { var names = default(VBuffer>); MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, NumFeatures, ref names); - var slotNamesCol = RowColumnUtils.GetColumn(MetadataUtils.Kinds.SlotNames, - new VectorType(TextType.Instance, NumFeatures), ref names); - var slotNamesRow = RowColumnUtils.GetRow(null, slotNamesCol); + var metaBuilder = new MetadataBuilder(); + metaBuilder.AddSlotNames(NumFeatures, names.CopyTo); var weights = default(VBuffer); - GetFeatureWeights(ref weights); - return RowColumnUtils.GetRow(null, RowColumnUtils.GetColumn("Gains", new VectorType(NumberType.R4, NumFeatures), ref weights, slotNamesRow)); + ((IHaveFeatureWeights)this).GetFeatureWeights(ref weights); + var builder = new MetadataBuilder(); + builder.Add>("Gains", new VectorType(NumberType.R4, NumFeatures), weights.CopyTo, metaBuilder.GetMetadata()); + + return MetadataUtils.MetadataAsRow(builder.GetMetadata()); } - public IRow GetStatsIRowOrNull(RoleMappedSchema schema) + Row ICanGetSummaryAsIRow.GetStatsIRowOrNull(RoleMappedSchema schema) { return null; } diff --git 
a/src/Microsoft.ML.FastTree/FastTreeArguments.cs b/src/Microsoft.ML.FastTree/FastTreeArguments.cs index f68108d2e8..e14fd33db0 100644 --- a/src/Microsoft.ML.FastTree/FastTreeArguments.cs +++ b/src/Microsoft.ML.FastTree/FastTreeArguments.cs @@ -2,12 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Trainers.FastTree; using System; +using Microsoft.ML.CommandLine; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Trainers.FastTree; [assembly: EntryPointModule(typeof(FastTreeBinaryClassificationTrainer.Arguments))] [assembly: EntryPointModule(typeof(FastTreeRegressionTrainer.Arguments))] diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 7de0cd1ac5..be17d13661 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -2,20 +2,20 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML; +using Microsoft.ML.Calibrator; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Calibration; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Training; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.FastTree.Internal; -using System; -using System.Collections.Generic; -using System.Linq; +using Microsoft.ML.Training; [assembly: LoadableClass(FastTreeBinaryClassificationTrainer.Summary, typeof(FastTreeBinaryClassificationTrainer), typeof(FastTreeBinaryClassificationTrainer.Arguments), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) }, @@ -36,17 +36,17 @@ "fastrank", "fastrankwrapper")] -[assembly: LoadableClass(typeof(IPredictorProducing), typeof(FastTreeBinaryPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(IPredictorProducing), typeof(FastTreeBinaryModelParameters), null, typeof(SignatureLoadModel), "FastTree Binary Executor", - FastTreeBinaryPredictor.LoaderSignature)] + FastTreeBinaryModelParameters.LoaderSignature)] namespace Microsoft.ML.Trainers.FastTree { - public sealed class FastTreeBinaryPredictor : - FastTreePredictionWrapper + public sealed class FastTreeBinaryModelParameters : + TreeEnsembleModelParameters { - public const string LoaderSignature = "FastTreeBinaryExec"; - public const string RegistrationName = "FastTreeBinaryPredictor"; + internal const string LoaderSignature = "FastTreeBinaryExec"; + internal const string RegistrationName = 
"FastTreeBinaryPredictor"; private static VersionInfo GetVersionInfo() { @@ -60,7 +60,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010005, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(FastTreeBinaryPredictor).Assembly.FullName); + loaderAssemblyName: typeof(FastTreeBinaryModelParameters).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010002; @@ -69,28 +69,28 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010005; - internal FastTreeBinaryPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + public FastTreeBinaryModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } - private FastTreeBinaryPredictor(IHostEnvironment env, ModelLoadContext ctx) + private FastTreeBinaryModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) { } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); } - public static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) + private static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - var predictor = new FastTreeBinaryPredictor(env, ctx); + var predictor = new FastTreeBinaryModelParameters(env, ctx); ICalibrator calibrator; ctx.LoadModelOrNull(env, out calibrator, @"Calibrator"); if (calibrator == null) @@ -108,7 +108,7 @@ public sealed partial class FastTreeBinaryClassificationTrainer : /// /// The LoadName for the assembly containing the trainer. 
/// - public const string LoadNameValue = "FastTreeBinaryClassification"; + internal const string LoadNameValue = "FastTreeBinaryClassification"; internal const string UserNameValue = "FastTree (Boosted Trees) Classification"; internal const string Summary = "Uses a logit-boost boosted tree learner to perform binary classification."; internal const string ShortName = "ftc"; @@ -168,7 +168,7 @@ private protected override IPredictorWithFeatureWeights TrainModelCore(Tr trainData.CheckBinaryLabel(); trainData.CheckFeatureFloatVector(); trainData.CheckOptFloatWeight(); - FeatureCount = trainData.Schema.Feature.Type.ValueCount; + FeatureCount = trainData.Schema.Feature.Value.Type.ValueCount; ConvertData(trainData); TrainCore(ch); } @@ -177,7 +177,7 @@ private protected override IPredictorWithFeatureWeights TrainModelCore(Tr // output probabilities when transformed using a scaled logistic function, // so transform the scores using that. - var pred = new FastTreeBinaryPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs); + var pred = new FastTreeBinaryModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs); // FastTree's binary classification boosting framework's natural probabilistic interpretation // is explained in "From RankNet to LambdaRank to LambdaMART: An Overview" by Chris Burges. // The correctness of this scaling depends upon the gradient calculation in diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 9fe6023764..0c967ab1e5 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -2,22 +2,22 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Training; -using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using System; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; +using System.Security; using System.Text; +using Microsoft.ML; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Trainers.FastTree; +using Microsoft.ML.Trainers.FastTree.Internal; +using Microsoft.ML.Training; // REVIEW: Do we really need all these names? [assembly: LoadableClass(FastTreeRankingTrainer.Summary, typeof(FastTreeRankingTrainer), typeof(FastTreeRankingTrainer.Arguments), @@ -33,9 +33,9 @@ "frrank", "btrank")] -[assembly: LoadableClass(typeof(FastTreeRankingPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(FastTreeRankingModelParameters), null, typeof(SignatureLoadModel), "FastTree Ranking Executor", - FastTreeRankingPredictor.LoaderSignature)] + FastTreeRankingModelParameters.LoaderSignature)] [assembly: LoadableClass(typeof(void), typeof(FastTree), null, typeof(SignatureEntryPointModule), "FastTree")] @@ -43,7 +43,7 @@ namespace Microsoft.ML.Trainers.FastTree { /// public sealed partial class FastTreeRankingTrainer - : BoostingFastTreeTrainerBase, FastTreeRankingPredictor> + : BoostingFastTreeTrainerBase, FastTreeRankingModelParameters> { internal const string LoadNameValue = "FastTreeRanking"; internal const string UserNameValue = "FastTree (Boosted Trees) Ranking"; @@ -82,7 +82,7 @@ public FastTreeRankingTrainer(IHostEnvironment env, 
int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, double learningRate = Defaults.LearningRates, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings) + : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings) { Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn)); } @@ -91,13 +91,13 @@ public FastTreeRankingTrainer(IHostEnvironment env, /// Initializes a new instance of by using the legacy class. /// internal FastTreeRankingTrainer(IHostEnvironment env, Arguments args) - : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) + : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) { } protected override void CheckLabelCompatible(SchemaShape.Column labelCol) { - Contracts.AssertValue(labelCol); + Contracts.Assert(labelCol.IsValid); Action error = () => throw Host.ExceptSchemaMismatch(nameof(labelCol), RoleMappedSchema.ColumnRole.Label.Value, labelCol.Name, "R4 or a Key", labelCol.GetTypeString()); @@ -112,7 +112,7 @@ protected override float GetMaxLabel() return GetLabelGains().Length - 1; } - private protected override FastTreeRankingPredictor TrainModelCore(TrainContext context) + private protected override FastTreeRankingModelParameters TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var trainData = context.TrainingSet; @@ -124,9 +124,9 @@ private protected override FastTreeRankingPredictor TrainModelCore(TrainContext var maxLabel = GetLabelGains().Length - 1; ConvertData(trainData); TrainCore(ch); - FeatureCount = trainData.Schema.Feature.Type.ValueCount; + FeatureCount = trainData.Schema.Feature.Value.Type.ValueCount; } - return new FastTreeRankingPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs); + return 
new FastTreeRankingModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs); } private Double[] GetLabelGains() @@ -454,10 +454,10 @@ protected override string GetTestGraphHeader() return headerBuilder.ToString(); } - protected override RankingPredictionTransformer MakeTransformer(FastTreeRankingPredictor model, Schema trainSchema) - => new RankingPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + protected override RankingPredictionTransformer MakeTransformer(FastTreeRankingModelParameters model, Schema trainSchema) + => new RankingPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - public RankingPredictionTransformer Train(IDataView trainData, IDataView validationData = null) + public RankingPredictionTransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData); protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) @@ -1090,7 +1090,7 @@ private static void PermutationSort(int[] permutation, double[] scores, short[] })); } - [DllImport("FastTreeNative", EntryPoint = "C_GetDerivatives", CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Ansi)] + [DllImport("FastTreeNative", EntryPoint = "C_GetDerivatives", CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity] private static extern unsafe void GetDerivatives( int numDocuments, int begin, int* pPermutation, short* pLabels, double* pScores, double* pLambdas, double* pWeights, double* pDiscount, @@ -1104,10 +1104,10 @@ private static extern unsafe void GetDerivatives( } } - public sealed class FastTreeRankingPredictor : FastTreePredictionWrapper + public sealed class FastTreeRankingModelParameters : TreeEnsembleModelParameters { - public const string LoaderSignature = "FastTreeRankerExec"; - public const string RegistrationName = "FastTreeRankingPredictor"; + internal const string LoaderSignature = "FastTreeRankerExec"; + internal const string RegistrationName = 
"FastTreeRankingPredictor"; private static VersionInfo GetVersionInfo() { @@ -1121,7 +1121,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010004, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(FastTreeRankingPredictor).Assembly.FullName); + loaderAssemblyName: typeof(FastTreeRankingModelParameters).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010002; @@ -1130,25 +1130,25 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010005; - internal FastTreeRankingPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + public FastTreeRankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } - private FastTreeRankingPredictor(IHostEnvironment env, ModelLoadContext ctx) + private FastTreeRankingModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) { } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); } - public static FastTreeRankingPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static FastTreeRankingModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { - return new FastTreeRankingPredictor(env, ctx); + return new FastTreeRankingModelParameters(env, ctx); } public override PredictionKind PredictionKind => PredictionKind.Ranking; diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 5b6aac44f7..2cb321fafe 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -2,19 +2,18 @@ // The .NET Foundation 
licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; +using System.Linq; +using System.Text; +using Microsoft.ML; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Training; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.FastTree.Internal; -using System; -using System.Linq; -using System.Text; +using Microsoft.ML.Training; [assembly: LoadableClass(FastTreeRegressionTrainer.Summary, typeof(FastTreeRegressionTrainer), typeof(FastTreeRegressionTrainer.Arguments), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) }, @@ -28,15 +27,15 @@ "frr", "btr")] -[assembly: LoadableClass(typeof(FastTreeRegressionPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(FastTreeRegressionModelParameters), null, typeof(SignatureLoadModel), "FastTree Regression Executor", - FastTreeRegressionPredictor.LoaderSignature)] + FastTreeRegressionModelParameters.LoaderSignature)] namespace Microsoft.ML.Trainers.FastTree { /// public sealed partial class FastTreeRegressionTrainer - : BoostingFastTreeTrainerBase, FastTreeRegressionPredictor> + : BoostingFastTreeTrainerBase, FastTreeRegressionModelParameters> { public const string LoadNameValue = "FastTreeRegression"; internal const string UserNameValue = "FastTree (Boosted Trees) Regression"; @@ -73,7 +72,7 @@ public FastTreeRegressionTrainer(IHostEnvironment env, int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, double learningRate = Defaults.LearningRates, Action advancedSettings = null) - : base(env, 
TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings) + : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings) { } @@ -81,11 +80,11 @@ public FastTreeRegressionTrainer(IHostEnvironment env, /// Initializes a new instance of by using the legacy class. /// internal FastTreeRegressionTrainer(IHostEnvironment env, Arguments args) - : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) + : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) { } - private protected override FastTreeRegressionPredictor TrainModelCore(TrainContext context) + private protected override FastTreeRegressionModelParameters TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var trainData = context.TrainingSet; @@ -97,11 +96,11 @@ private protected override FastTreeRegressionPredictor TrainModelCore(TrainConte trainData.CheckRegressionLabel(); trainData.CheckFeatureFloatVector(); trainData.CheckOptFloatWeight(); - FeatureCount = trainData.Schema.Feature.Type.ValueCount; + FeatureCount = trainData.Schema.Feature.Value.Type.ValueCount; ConvertData(trainData); TrainCore(ch); } - return new FastTreeRegressionPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs); + return new FastTreeRegressionModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs); } protected override void CheckArgs(IChannel ch) @@ -164,10 +163,10 @@ protected override Test ConstructTestForTrainingData() return new RegressionTest(ConstructScoreTracker(TrainSet)); } - protected override RegressionPredictionTransformer MakeTransformer(FastTreeRegressionPredictor model, Schema trainSchema) - => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + protected override RegressionPredictionTransformer 
MakeTransformer(FastTreeRegressionModelParameters model, Schema trainSchema) + => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) + public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData); protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) @@ -441,10 +440,10 @@ protected override void GetGradientInOneQuery(int query, int threadIndex) } } - public sealed class FastTreeRegressionPredictor : FastTreePredictionWrapper + public sealed class FastTreeRegressionModelParameters : TreeEnsembleModelParameters { - public const string LoaderSignature = "FastTreeRegressionExec"; - public const string RegistrationName = "FastTreeRegressionPredictor"; + internal const string LoaderSignature = "FastTreeRegressionExec"; + internal const string RegistrationName = "FastTreeRegressionPredictor"; private static VersionInfo GetVersionInfo() { @@ -458,7 +457,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010004, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(FastTreeRegressionPredictor).Assembly.FullName); + loaderAssemblyName: typeof(FastTreeRegressionModelParameters).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010002; @@ -467,28 +466,28 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010005; - internal FastTreeRegressionPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + public FastTreeRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } - private 
FastTreeRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx) + private FastTreeRegressionModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) { } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); } - public static FastTreeRegressionPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static FastTreeRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new FastTreeRegressionPredictor(env, ctx); + return new FastTreeRegressionModelParameters(env, ctx); } public override PredictionKind PredictionKind => PredictionKind.Regression; diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 1a40239082..8e9234c808 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -2,20 +2,19 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using System; +using System.Linq; +using System.Text; +using Microsoft.ML; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Training; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.FastTree.Internal; -using System; -using System.Linq; -using System.Text; +using Microsoft.ML.Training; [assembly: LoadableClass(FastTreeTweedieTrainer.Summary, typeof(FastTreeTweedieTrainer), typeof(FastTreeTweedieTrainer.Arguments), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) }, @@ -23,9 +22,9 @@ FastTreeTweedieTrainer.LoadNameValue, FastTreeTweedieTrainer.ShortName)] -[assembly: LoadableClass(typeof(FastTreeTweediePredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(FastTreeTweedieModelParameters), null, typeof(SignatureLoadModel), "FastTree Tweedie Regression Executor", - FastTreeTweediePredictor.LoaderSignature)] + FastTreeTweedieModelParameters.LoaderSignature)] namespace Microsoft.ML.Trainers.FastTree { @@ -34,7 +33,7 @@ namespace Microsoft.ML.Trainers.FastTree // https://arxiv.org/pdf/1508.06378.pdf /// public sealed partial class FastTreeTweedieTrainer - : BoostingFastTreeTrainerBase, FastTreeTweediePredictor> + : BoostingFastTreeTrainerBase, FastTreeTweedieModelParameters> { internal const string LoadNameValue = "FastTreeTweedieRegression"; internal const string UserNameValue = "FastTree (Boosted Trees) Tweedie Regression"; @@ -70,7 +69,7 @@ public FastTreeTweedieTrainer(IHostEnvironment env, int minDatapointsInLeaves = 
Defaults.MinDocumentsInLeaves, double learningRate = Defaults.LearningRates, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings) + : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); @@ -82,12 +81,12 @@ public FastTreeTweedieTrainer(IHostEnvironment env, /// Initializes a new instance of by using the legacy class. /// internal FastTreeTweedieTrainer(IHostEnvironment env, Arguments args) - : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) + : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) { Initialize(); } - private protected override FastTreeTweediePredictor TrainModelCore(TrainContext context) + private protected override FastTreeTweedieModelParameters TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var trainData = context.TrainingSet; @@ -100,11 +99,11 @@ private protected override FastTreeTweediePredictor TrainModelCore(TrainContext trainData.CheckRegressionLabel(); trainData.CheckFeatureFloatVector(); trainData.CheckOptFloatWeight(); - FeatureCount = trainData.Schema.Feature.Type.ValueCount; + FeatureCount = trainData.Schema.Feature.Value.Type.ValueCount; ConvertData(trainData); TrainCore(ch); } - return new FastTreeTweediePredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs); + return new FastTreeTweedieModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs); } protected override void CheckArgs(IChannel ch) @@ -316,10 +315,10 @@ protected override void Train(IChannel ch) PrintTestGraph(ch); } - protected override RegressionPredictionTransformer MakeTransformer(FastTreeTweediePredictor model, Schema 
trainSchema) - => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + protected override RegressionPredictionTransformer MakeTransformer(FastTreeTweedieModelParameters model, Schema trainSchema) + => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) + public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData); protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) @@ -446,10 +445,10 @@ protected override void GetGradientInOneQuery(int query, int threadIndex) } } - public sealed class FastTreeTweediePredictor : FastTreePredictionWrapper + public sealed class FastTreeTweedieModelParameters : TreeEnsembleModelParameters { - public const string LoaderSignature = "FastTreeTweedieExec"; - public const string RegistrationName = "FastTreeTweediePredictor"; + internal const string LoaderSignature = "FastTreeTweedieExec"; + internal const string RegistrationName = "FastTreeTweediePredictor"; private static VersionInfo GetVersionInfo() { @@ -461,7 +460,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(FastTreeTweediePredictor).Assembly.FullName); + loaderAssemblyName: typeof(FastTreeTweedieModelParameters).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010001; @@ -470,28 +469,28 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010003; - internal FastTreeTweediePredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + public FastTreeTweedieModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : 
base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } - private FastTreeTweediePredictor(IHostEnvironment env, ModelLoadContext ctx) + private FastTreeTweedieModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) { } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); } - public static FastTreeTweediePredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static FastTreeTweedieModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new FastTreeTweediePredictor(env, ctx); + return new FastTreeTweedieModelParameters(env, ctx); } protected override void Map(in VBuffer src, ref float dst) diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs index ee094a384d..3f234f2ca8 100644 --- a/src/Microsoft.ML.FastTree/GamClassification.cs +++ b/src/Microsoft.ML.FastTree/GamClassification.cs @@ -2,19 +2,19 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using System; +using System.Threading.Tasks; +using Microsoft.ML; +using Microsoft.ML.Calibrator; +using Microsoft.ML.CommandLine; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Calibration; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Training; +using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.FastTree.Internal; -using System; -using System.Threading.Tasks; +using Microsoft.ML.Training; [assembly: LoadableClass(BinaryClassificationGamTrainer.Summary, typeof(BinaryClassificationGamTrainer), typeof(BinaryClassificationGamTrainer.Arguments), @@ -23,9 +23,9 @@ BinaryClassificationGamTrainer.LoadNameValue, BinaryClassificationGamTrainer.ShortName, DocName = "trainer/GAM.md")] -[assembly: LoadableClass(typeof(IPredictorProducing), typeof(BinaryClassificationGamPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(IPredictorProducing), typeof(BinaryClassificationGamModelParameters), null, typeof(SignatureLoadModel), "GAM Binary Class Predictor", - BinaryClassificationGamPredictor.LoaderSignature)] + BinaryClassificationGamModelParameters.LoaderSignature)] namespace Microsoft.ML.Trainers.FastTree { @@ -80,7 +80,7 @@ public BinaryClassificationGamTrainer(IHostEnvironment env, _sigmoidParameter = 1; } - internal override void CheckLabel(RoleMappedData data) + private protected override void CheckLabel(RoleMappedData data) { data.CheckBinaryLabel(); } @@ -108,7 +108,7 @@ private static bool[] ConvertTargetsToBool(double[] targets) private protected override IPredictorProducing TrainModelCore(TrainContext context) { TrainBase(context); - var predictor = new BinaryClassificationGamPredictor(Host, InputLength, 
TrainSet, + var predictor = new BinaryClassificationGamModelParameters(Host, InputLength, TrainSet, MeanEffect, BinEffects, FeatureMap); var calibrator = new PlattCalibrator(Host, -1.0 * _sigmoidParameter, 0); return new CalibratedPredictor(Host, predictor, calibrator); @@ -157,19 +157,19 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc } } - public class BinaryClassificationGamPredictor : GamPredictorBase, IPredictorProducing + public class BinaryClassificationGamModelParameters : GamModelParametersBase, IPredictorProducing { - public const string LoaderSignature = "BinaryClassGamPredictor"; + internal const string LoaderSignature = "BinaryClassGamPredictor"; public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public BinaryClassificationGamPredictor(IHostEnvironment env, int inputLength, Dataset trainset, + internal BinaryClassificationGamModelParameters(IHostEnvironment env, int inputLength, Dataset trainset, double meanEffect, double[][] binEffects, int[] featureMap) : base(env, LoaderSignature, inputLength, trainset, meanEffect, binEffects, featureMap) { } - private BinaryClassificationGamPredictor(IHostEnvironment env, ModelLoadContext ctx) + private BinaryClassificationGamModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature, ctx) { } - public static VersionInfo GetVersionInfo() + private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "GAM BINP", @@ -177,16 +177,16 @@ public static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(BinaryClassificationGamPredictor).Assembly.FullName); + loaderAssemblyName: typeof(BinaryClassificationGamModelParameters).Assembly.FullName); } - public static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) + private static IPredictorProducing Create(IHostEnvironment env, 
ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - var predictor = new BinaryClassificationGamPredictor(env, ctx); + var predictor = new BinaryClassificationGamModelParameters(env, ctx); ICalibrator calibrator; ctx.LoadModelOrNull(env, out calibrator, @"Calibrator"); if (calibrator == null) @@ -194,12 +194,12 @@ public static IPredictorProducing Create(IHostEnvironment env, ModelLoadC return new SchemaBindableCalibratedPredictor(env, predictor, calibrator); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - base.Save(ctx); + base.SaveCore(ctx); } } } diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs index a0e4ead4be..315b94ec78 100644 --- a/src/Microsoft.ML.FastTree/GamRegression.cs +++ b/src/Microsoft.ML.FastTree/GamRegression.cs @@ -2,17 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using System; +using Microsoft.ML; +using Microsoft.ML.CommandLine; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Training; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.FastTree.Internal; -using System; +using Microsoft.ML.Training; [assembly: LoadableClass(RegressionGamTrainer.Summary, typeof(RegressionGamTrainer), typeof(RegressionGamTrainer.Arguments), @@ -21,13 +20,13 @@ RegressionGamTrainer.LoadNameValue, RegressionGamTrainer.ShortName, DocName = "trainer/GAM.md")] -[assembly: LoadableClass(typeof(RegressionGamPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(RegressionGamModelParameters), null, typeof(SignatureLoadModel), "GAM Regression Predictor", - RegressionGamPredictor.LoaderSignature)] + RegressionGamModelParameters.LoaderSignature)] namespace Microsoft.ML.Trainers.FastTree { - public sealed class RegressionGamTrainer : GamTrainerBase, RegressionGamPredictor> + public sealed class RegressionGamTrainer : GamTrainerBase, RegressionGamModelParameters> { public partial class Arguments : ArgumentsBase { @@ -43,7 +42,7 @@ public partial class Arguments : ArgumentsBase public override PredictionKind PredictionKind => PredictionKind.Regression; internal RegressionGamTrainer(IHostEnvironment env, Arguments args) - : base(env, args, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { } + : base(env, args, LoadNameValue, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) { } /// /// Initializes a new instance of @@ -64,19 +63,19 @@ public RegressionGamTrainer(IHostEnvironment env, double learningRate = GamDefaults.LearningRates, int maxBins = GamDefaults.MaxBins, Action advancedSettings = null) - : 
base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, numIterations, learningRate, maxBins, advancedSettings) + : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, numIterations, learningRate, maxBins, advancedSettings) { } - internal override void CheckLabel(RoleMappedData data) + private protected override void CheckLabel(RoleMappedData data) { data.CheckRegressionLabel(); } - private protected override RegressionGamPredictor TrainModelCore(TrainContext context) + private protected override RegressionGamModelParameters TrainModelCore(TrainContext context) { TrainBase(context); - return new RegressionGamPredictor(Host, InputLength, TrainSet, MeanEffect, BinEffects, FeatureMap); + return new RegressionGamModelParameters(Host, InputLength, TrainSet, MeanEffect, BinEffects, FeatureMap); } protected override ObjectiveFunctionBase CreateObjectiveFunction() @@ -92,10 +91,10 @@ protected override void DefinePruningTest() PruningTest = new TestHistory(validTest, PruningLossIndex); } - protected override RegressionPredictionTransformer MakeTransformer(RegressionGamPredictor model, Schema trainSchema) - => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + protected override RegressionPredictionTransformer MakeTransformer(RegressionGamModelParameters model, Schema trainSchema) + => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) + public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData); protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) @@ -107,19 +106,19 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc } } - public class RegressionGamPredictor : 
GamPredictorBase + public class RegressionGamModelParameters : GamModelParametersBase { - public const string LoaderSignature = "RegressionGamPredictor"; + internal const string LoaderSignature = "RegressionGamPredictor"; public override PredictionKind PredictionKind => PredictionKind.Regression; - public RegressionGamPredictor(IHostEnvironment env, int inputLength, Dataset trainset, + internal RegressionGamModelParameters(IHostEnvironment env, int inputLength, Dataset trainset, double meanEffect, double[][] binEffects, int[] featureMap) : base(env, LoaderSignature, inputLength, trainset, meanEffect, binEffects, featureMap) { } - private RegressionGamPredictor(IHostEnvironment env, ModelLoadContext ctx) + private RegressionGamModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature, ctx) { } - public static VersionInfo GetVersionInfo() + private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "GAM REGP", @@ -127,24 +126,24 @@ public static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(RegressionGamPredictor).Assembly.FullName); + loaderAssemblyName: typeof(RegressionGamModelParameters).Assembly.FullName); } - public static RegressionGamPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static RegressionGamModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new RegressionGamPredictor(env, ctx); + return new RegressionGamModelParameters(env, ctx); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - base.Save(ctx); + base.SaveCore(ctx); } } } diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs 
b/src/Microsoft.ML.FastTree/GamTrainer.cs index 05fc8eab4f..3c6edf0acc 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -2,29 +2,30 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; -using Microsoft.ML.Runtime.Internal.Calibration; -using Microsoft.ML.Runtime.Internal.CpuMath; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Training; using System; using System.Collections.Generic; using System.IO; using System.Linq; using System.Threading; +using Microsoft.ML; +using Microsoft.ML.Calibrator; +using Microsoft.ML.Command; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Internal.CpuMath; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Trainers.FastTree; +using Microsoft.ML.Trainers.FastTree.Internal; +using Microsoft.ML.Training; using Timer = Microsoft.ML.Trainers.FastTree.Internal.Timer; -[assembly: LoadableClass(typeof(GamPredictorBase.VisualizationCommand), typeof(GamPredictorBase.VisualizationCommand.Arguments), typeof(SignatureCommand), - "GAM Vizualization Command", GamPredictorBase.VisualizationCommand.LoadName, "gamviz", DocName = "command/GamViz.md")] +[assembly: LoadableClass(typeof(GamModelParametersBase.VisualizationCommand), typeof(GamModelParametersBase.VisualizationCommand.Arguments), 
typeof(SignatureCommand), + "GAM Vizualization Command", GamModelParametersBase.VisualizationCommand.LoadName, "gamviz", DocName = "command/GamViz.md")] [assembly: LoadableClass(typeof(void), typeof(Gam), null, typeof(SignatureEntryPointModule), "GAM")] @@ -226,7 +227,7 @@ private protected void TrainBase(TrainContext context) DefineScoreTrackers(); if (HasValidSet) DefinePruningTest(); - InputLength = context.TrainingSet.Schema.Feature.Type.ValueCount; + InputLength = context.TrainingSet.Schema.Feature.Value.Type.ValueCount; TrainCore(ch); } @@ -241,7 +242,7 @@ private void DefineScoreTrackers() protected abstract void DefinePruningTest(); - internal abstract void CheckLabel(RoleMappedData data); + private protected abstract void CheckLabel(RoleMappedData data); private void ConvertData(RoleMappedData trainData, RoleMappedData validationData) { @@ -263,13 +264,11 @@ private void ConvertData(RoleMappedData trainData, RoleMappedData validationData private bool UseTranspose(bool? useTranspose, RoleMappedData data) { Host.AssertValue(data); - Host.AssertValue(data.Schema.Feature); + Host.Assert(data.Schema.Feature.HasValue); if (useTranspose.HasValue) return useTranspose.Value; - - ITransposeDataView td = data.Data as ITransposeDataView; - return td != null && td.TransposeSchema.GetSlotType(data.Schema.Feature.Index) != null; + return data.Data is ITransposeDataView td && td.TransposeSchema.GetSlotType(data.Schema.Feature.Value.Index) != null; } private void TrainCore(IChannel ch) @@ -648,14 +647,15 @@ public Stump(uint splitPoint, double lteValue, double gtValue) } } - public abstract class GamPredictorBase : PredictorBase, - IValueMapper, ICanSaveModel, ICanSaveInTextFormat, ICanSaveSummary + public abstract class GamModelParametersBase : ModelParametersBase, IValueMapper, ICalculateFeatureContribution, + IFeatureContributionMapper, ICanSaveInTextFormat, ICanSaveSummary, ICanSaveInIniFormat { private readonly double[][] _binUpperBounds; private readonly double[][] 
_binEffects; public readonly double Intercept; private readonly int _numFeatures; private readonly ColumnType _inputType; + private readonly ColumnType _outputType; // These would be the bins for a totally sparse input. private readonly int[] _binsAtAllZero; // The output value for all zeros @@ -665,11 +665,12 @@ public abstract class GamPredictorBase : PredictorBase, private readonly int _inputLength; private readonly Dictionary _inputFeatureToDatasetFeatureMap; - public ColumnType InputType => _inputType; + ColumnType IValueMapper.InputType => _inputType; + ColumnType IValueMapper.OutputType => _outputType; - public ColumnType OutputType => NumberType.Float; + public FeatureContributionCalculator FeatureContributionClaculator => new FeatureContributionCalculator(this); - private protected GamPredictorBase(IHostEnvironment env, string name, + private protected GamModelParametersBase(IHostEnvironment env, string name, int inputLength, Dataset trainSet, double meanEffect, double[][] binEffects, int[] featureMap) : base(env, name) { @@ -683,6 +684,7 @@ private protected GamPredictorBase(IHostEnvironment env, string name, _numFeatures = binEffects.Length; _inputType = new VectorType(NumberType.Float, _inputLength); + _outputType = NumberType.Float; _featureMap = featureMap; Intercept = meanEffect; @@ -739,7 +741,7 @@ private protected GamPredictorBase(IHostEnvironment env, string name, } } - protected GamPredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx) + protected GamModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx) : base(env, name) { Host.CheckValue(ctx, nameof(ctx)); @@ -787,9 +789,10 @@ protected GamPredictorBase(IHostEnvironment env, string name, ModelLoadContext c } _inputType = new VectorType(NumberType.Float, _inputLength); + _outputType = NumberType.Float; } - public override void Save(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); @@ 
-816,7 +819,7 @@ public override void Save(ModelSaveContext ctx) } } - public ValueMapper GetMapper() + ValueMapper IValueMapper.GetMapper() { Host.Check(typeof(TIn) == typeof(VBuffer)); Host.Check(typeof(TOut) == typeof(float)); @@ -831,6 +834,7 @@ private void Map(in VBuffer features, ref float response) double value = Intercept; var featuresValues = features.GetValues(); + if (features.IsDense) { for (int i = 0; i < featuresValues.Length; ++i) @@ -855,60 +859,6 @@ private void Map(in VBuffer features, ref float response) response = (float)value; } - /// - /// Returns a vector of feature contributions for a given example. - /// is used as a buffer to accumulate the contributions across trees. - /// If is null, it will be created, otherwise it will be reused. - /// - internal void GetFeatureContributions(in VBuffer features, ref VBuffer contribs, ref BufferBuilder builder) - { - if (builder == null) - builder = new BufferBuilder(R4Adder.Instance); - - // The model is Intercept + Features - builder.Reset(features.Length + 1, false); - builder.AddFeature(0, (float)Intercept); - - var featuresValues = features.GetValues(); - if (features.IsDense) - { - for (int i = 0; i < featuresValues.Length; ++i) - { - if (_inputFeatureToDatasetFeatureMap.TryGetValue(i, out int j)) - builder.AddFeature(i+1, (float) GetBinEffect(j, featuresValues[i])); - } - } - else - { - int k = -1; - var featuresIndices = features.GetIndices(); - int index = featuresIndices[++k]; - for (int i = 0; i < _numFeatures; ++i) - { - if (_inputFeatureToDatasetFeatureMap.TryGetValue(i, out int j)) - { - double value; - if (i == index) - { - // Get the computed value - value = GetBinEffect(j, featuresValues[index]); - // Increment index to the next feature - if (k < featuresIndices.Length - 1) - index = featuresIndices[++k]; - } - else - // For features not defined, the impact is the impact at 0 - value = GetBinEffect(i, 0); - builder.AddFeature(i + 1, (float)value); - } - } - } - - builder.GetResult(ref 
contribs); - - return; - } - internal double GetFeatureBinsAndScore(in VBuffer features, int[] bins) { Host.CheckParam(features.Length == _inputLength, nameof(features)); @@ -1000,7 +950,7 @@ public double[] GetFeatureWeights(int featureIndex) return featureWeights; } - public void SaveAsText(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(writer, nameof(writer)); Host.CheckValueOrNull(schema); @@ -1041,14 +991,156 @@ public void SaveAsText(TextWriter writer, RoleMappedSchema schema) } } - public void SaveSummary(TextWriter writer, RoleMappedSchema schema) + void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) + { + ((ICanSaveInTextFormat)this).SaveAsText(writer, schema); + } + + ValueMapper> IFeatureContributionMapper.GetFeatureContributionMapper + (int top, int bottom, bool normalize) + { + Contracts.Check(typeof(TSrc) == typeof(VBuffer)); + Contracts.Check(typeof(TDstContributions) == typeof(VBuffer)); + + ValueMapper, VBuffer> del = + (in VBuffer srcFeatures, ref VBuffer dstContributions) => + { + GetFeatureContributions(in srcFeatures, ref dstContributions, top, bottom, normalize); + }; + return (ValueMapper>)(Delegate)del; + } + + private void GetFeatureContributions(in VBuffer features, ref VBuffer contributions, + int top, int bottom, bool normalize) + { + var editor = VBufferEditor.Create(ref contributions, features.Length); + + // We need to use dense value of features, b/c the feature contributions could be significant + // even for features with value 0. 
+ var featureIndex = 0; + foreach (var featureValue in features.DenseValues()) + { + float contribution = 0; + if (_inputFeatureToDatasetFeatureMap.TryGetValue(featureIndex, out int j)) + contribution = (float)GetBinEffect(j, featureValue); + editor.Values[featureIndex] = contribution; + featureIndex++; + } + contributions = editor.Commit(); + Numeric.VectorUtils.SparsifyNormalize(ref contributions, top, bottom, normalize); + } + + void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator) + { + Host.CheckValue(writer, nameof(writer)); + var ensemble = new TreeEnsemble(); + + for (int featureIndex = 0; featureIndex < _numFeatures; featureIndex++) + { + var effects = _binEffects[featureIndex]; + var binThresholds = _binUpperBounds[featureIndex]; + + Host.Assert(effects.Length == binThresholds.Length); + var numLeaves = effects.Length; + var numInternalNodes = numLeaves - 1; + + var splitFeatures = Enumerable.Repeat(featureIndex, numInternalNodes).ToArray(); + var (treeThresholds, lteChild, gtChild) = CreateBalancedTree(numInternalNodes, binThresholds); + var tree = CreateRegressionTree(numLeaves, splitFeatures, treeThresholds, lteChild, gtChild, effects); + ensemble.AddTree(tree); + } + + // Adding the intercept as a dummy tree with the output values being the model intercept, + // works for reaching parity. + var interceptTree = CreateRegressionTree( + numLeaves: 2, + splitFeatures: new[] { 0 }, + rawThresholds: new[] { 0f }, + lteChild: new[] { ~0 }, + gtChild: new[] { ~1 }, + leafValues: new[] { Intercept, Intercept }); + ensemble.AddTree(interceptTree); + + var ini = FastTreeIniFileUtils.TreeEnsembleToIni( + Host, ensemble, schema, calibrator, string.Empty, false, false); + + // Remove the SplitGain values which are all 0. + // It's eaiser to remove them here, than to modify the FastTree code. 
+ var goodLines = ini.Split(new[] { '\n' }).Where(line => !line.StartsWith("SplitGain=")); + ini = string.Join("\n", goodLines); + writer.WriteLine(ini); + } + + // GAM bins should be converted to balanced trees / binary search trees + // so that scoring takes O(log(n)) instead of O(n). The following utility + // creates a balanced tree. + private (float[], int[], int[]) CreateBalancedTree(int numInternalNodes, double[] binThresholds) + { + var binIndices = Enumerable.Range(0, numInternalNodes).ToArray(); + var internalNodeIndices = new List(); + var lteChild = new List(); + var gtChild = new List(); + var internalNodeId = numInternalNodes; + + CreateBalancedTreeRecursive( + 0, binIndices.Length - 1, internalNodeIndices, lteChild, gtChild, ref internalNodeId); + // internalNodeId should have been counted all the way down to 0 (root node) + Host.Assert(internalNodeId == 0); + + var tree = ( + thresholds: internalNodeIndices.Select(x => (float)binThresholds[binIndices[x]]).ToArray(), + lteChild: lteChild.ToArray(), + gtChild: gtChild.ToArray()); + return tree; + } + + private int CreateBalancedTreeRecursive(int lower, int upper, + List internalNodeIndices, List lteChild, List gtChild, ref int internalNodeId) + { + if (lower > upper) + { + // Base case: we've reached a leaf node + Host.Assert(lower == upper + 1); + return ~lower; + } + else + { + // This is postorder traversal algorithm and populating the internalNodeIndices/lte/gt lists in reverse. + // Preorder is the only option, because we need the results of both left/right recursions for populating the lists. + // As a result, lists are populated in reverse, because the root node should be the first item on the lists. + // Binary search tree algorithm (recursive splitting to half) is used for creating balanced tree. 
+ var mid = (lower + upper) / 2; + var left = CreateBalancedTreeRecursive( + lower, mid - 1, internalNodeIndices, lteChild, gtChild, ref internalNodeId); + var right = CreateBalancedTreeRecursive( + mid + 1, upper, internalNodeIndices, lteChild, gtChild, ref internalNodeId); + internalNodeIndices.Insert(0, mid); + lteChild.Insert(0, left); + gtChild.Insert(0, right); + return --internalNodeId; + } + } + private static RegressionTree CreateRegressionTree( + int numLeaves, int[] splitFeatures, float[] rawThresholds, int[] lteChild, int[] gtChild, double[] leafValues) { - SaveAsText(writer, schema); + var numInternalNodes = numLeaves - 1; + return RegressionTree.Create( + numLeaves: numLeaves, + splitFeatures: splitFeatures, + rawThresholds: rawThresholds, + lteChild: lteChild, + gtChild: gtChild.ToArray(), + leafValues: leafValues, + // Ignored arguments + splitGain: new double[numInternalNodes], + defaultValueForMissing: new float[numInternalNodes], + categoricalSplitFeatures: new int[numInternalNodes][], + categoricalSplit: new bool[numInternalNodes]); } /// /// The GAM model visualization command. Because the data access commands must access private members of - /// , it is convenient to have the command itself nested within the base + /// , it is convenient to have the command itself nested within the base /// predictor class. /// internal sealed class VisualizationCommand : DataCommand.ImplBase @@ -1094,7 +1186,7 @@ public override void Run() private sealed class Context { - private readonly GamPredictorBase _pred; + private readonly GamModelParametersBase _pred; private readonly RoleMappedData _data; private readonly VBuffer> _featNames; @@ -1122,9 +1214,9 @@ private sealed class Context /// These are the number of input features, as opposed to the number of features used within GAM /// which may be lower. 
/// - public int NumFeatures => _pred.InputType.VectorSize; + public int NumFeatures => _pred._inputType.VectorSize; - public Context(IChannel ch, GamPredictorBase pred, RoleMappedData data, IEvaluator eval) + public Context(IChannel ch, GamModelParametersBase pred, RoleMappedData data, IEvaluator eval) { Contracts.AssertValue(ch); ch.AssertValue(pred); @@ -1135,11 +1227,12 @@ public Context(IChannel ch, GamPredictorBase pred, RoleMappedData data, IEvaluat _pred = pred; _data = data; var schema = _data.Schema; - ch.Check(schema.Feature.Type.ValueCount == _pred._inputLength); + var featCol = schema.Feature.Value; + ch.Check(featCol.Type.ValueCount == _pred._inputLength); - int len = schema.Feature.Type.ValueCount; - if (schema.Schema.HasSlotNames(schema.Feature.Index, len)) - schema.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, schema.Feature.Index, ref _featNames); + int len = featCol.Type.ValueCount; + if (featCol.HasSlotNames(len)) + featCol.Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref _featNames); else _featNames = VBufferUtils.CreateEmpty>(len); @@ -1342,7 +1435,7 @@ private FeatureInfo(Context context, int index, int internalIndex, int[] catsMap public static FeatureInfo GetInfoForIndex(Context context, int index) { Contracts.AssertValue(context); - Contracts.Assert(0 <= index && index < context._pred.InputType.ValueCount); + Contracts.Assert(0 <= index && index < context._pred._inputType.ValueCount); lock (context._pred) { int internalIndex; @@ -1385,8 +1478,8 @@ private Context Init(IChannel ch) rawPred = calibrated.SubPredictor; calibrated = rawPred as CalibratedPredictorBase; } - var pred = rawPred as GamPredictorBase; - ch.CheckUserArg(pred != null, nameof(Args.InputModelFile), "Predictor was not a " + nameof(GamPredictorBase)); + var pred = rawPred as GamModelParametersBase; + ch.CheckUserArg(pred != null, nameof(Args.InputModelFile), "Predictor was not a " + nameof(GamModelParametersBase)); var data = new RoleMappedData(loader, 
schema.GetColumnRoleNames(), opt: true); if (hadCalibrator && !string.IsNullOrWhiteSpace(Args.OutputModelFile)) ch.Warning("If you save the GAM model, only the GAM model, not the wrapping calibrator, will be saved."); @@ -1394,7 +1487,7 @@ private Context Init(IChannel ch) return new Context(ch, pred, data, InitEvaluator(pred)); } - private IEvaluator InitEvaluator(GamPredictorBase pred) + private IEvaluator InitEvaluator(GamModelParametersBase pred) { switch (pred.PredictionKind) { diff --git a/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs b/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs index a03d7bdab6..cf6d8d8d42 100644 --- a/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs @@ -7,4 +7,7 @@ [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.LightGBM" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Sweeper" + PublicKey.Value)] + [assembly: WantsToBeBestFriends] diff --git a/src/Microsoft.ML.FastTree/QuantileStatistics.cs b/src/Microsoft.ML.FastTree/QuantileStatistics.cs index b93c1919a9..62d05cdc4a 100644 --- a/src/Microsoft.ML.FastTree/QuantileStatistics.cs +++ b/src/Microsoft.ML.FastTree/QuantileStatistics.cs @@ -2,56 +2,54 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Float = System.Single; - using System; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { - public sealed class QuantileStatistics : IQuantileDistribution + public sealed class QuantileStatistics : IQuantileDistribution { - private readonly Float[] _data; - private readonly Float[] _weights; + private readonly float[] _data; + private readonly float[] _weights; //This holds the cumulative sum of _weights to search the rank easily by binary search. - private Float[] _weightedSums; + private float[] _weightedSums; private SummaryStatistics _summaryStatistics; - public Float Minimum + float IDistribution.Minimum { get { if (_data.Length == 0) - return Float.NaN; + return float.NaN; return _data[0]; } } - public Float Maximum + float IDistribution.Maximum { get { if (_data.Length == 0) - return Float.NaN; + return float.NaN; return _data[_data.Length - 1]; } } - public Float Median { get { return GetQuantile(0.5F); } } + float IQuantileDistribution.Median { get { return ((IQuantileDistribution)this).GetQuantile(0.5F); } } - public Float Mean { get { return (Float)SummaryStatistics.Mean; } } + float IDistribution.Mean { get { return (float)SummaryStatistics.Mean; } } - public Float StandardDeviation { get { return (Float)SummaryStatistics.SampleStdDev; } } + float IDistribution.StandardDeviation { get { return (float)SummaryStatistics.SampleStdDev; } } /// /// data array will be modified because of sorting if it is not already sorted yet and this class owns the data. 
/// Modifying the data outside will lead to erroneous output by this class /// - public QuantileStatistics(Float[] data, Float[] weights = null, bool isSorted = false) + public QuantileStatistics(float[] data, float[] weights = null, bool isSorted = false) { Contracts.CheckValue(data, nameof(data)); Contracts.Check(weights == null || weights.Length == data.Length, "weights"); @@ -69,19 +67,19 @@ public QuantileStatistics(Float[] data, Float[] weights = null, bool isSorted = /// There are many ways to estimate quantile. This implementations is based on R-8, SciPy-(1/3,1/3) /// https://en.wikipedia.org/wiki/Quantile#Estimating_the_quantiles_of_a_population /// - public Float GetQuantile(Float p) + float IQuantileDistribution.GetQuantile(float p) { Contracts.CheckParam(0 <= p && p <= 1, nameof(p), "Probablity argument for Quantile function should be between 0 to 1 inclusive"); if (_data.Length == 0) - return Float.NaN; + return float.NaN; if (p == 0 || _data.Length == 1) return _data[0]; if (p == 1) return _data[_data.Length - 1]; - Float h = GetRank(p); + float h = GetRank(p); if (h <= 1) return _data[0]; @@ -90,12 +88,12 @@ public Float GetQuantile(Float p) return _data[_data.Length - 1]; var hf = (int)h; - return (Float)(_data[hf - 1] + (h - hf) * (_data[hf] - _data[hf - 1])); + return (float)(_data[hf - 1] + (h - hf) * (_data[hf] - _data[hf - 1])); } - public Float[] GetSupportSample(out Float[] weights) + float[] ISampleableDistribution.GetSupportSample(out float[] weights) { - var result = new Float[_data.Length]; + var result = new float[_data.Length]; Array.Copy(_data, result, _data.Length); if (_weights == null) { @@ -103,25 +101,25 @@ public Float[] GetSupportSample(out Float[] weights) } else { - weights = new Float[_data.Length]; + weights = new float[_data.Length]; Array.Copy(_weights, weights, _weights.Length); } return result; } - private Float GetRank(Float p) + private float GetRank(float p) { - const Float oneThird = (Float)1 / 3; + const float 
oneThird = (float)1 / 3; // holds length of the _data array if the weights is null or holds the sum of weights - Float weightedLength = _data.Length; + float weightedLength = _data.Length; if (_weights != null) { if (_weightedSums == null) { - _weightedSums = new Float[_weights.Length]; + _weightedSums = new float[_weights.Length]; _weightedSums[0] = _weights[0]; for (int i = 1; i < _weights.Length; i++) _weightedSums[i] = _weights[i] + _weightedSums[i - 1]; diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs index ab91a02e21..43aa5a0149 100644 --- a/src/Microsoft.ML.FastTree/RandomForest.cs +++ b/src/Microsoft.ML.FastTree/RandomForest.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime; using Microsoft.ML.Trainers.FastTree.Internal; -using System; namespace Microsoft.ML.Trainers.FastTree { diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 21e6f50722..7417a67d5f 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -2,20 +2,20 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using System; +using System.Linq; +using Microsoft.ML; +using Microsoft.ML.Calibrator; +using Microsoft.ML.CommandLine; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Calibration; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Training; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.FastTree.Internal; -using System; -using System.Linq; +using Microsoft.ML.Training; [assembly: LoadableClass(FastForestClassification.Summary, typeof(FastForestClassification), typeof(FastForestClassification.Arguments), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) }, @@ -25,9 +25,9 @@ FastForestClassification.ShortName, "ffc")] -[assembly: LoadableClass(typeof(IPredictorProducing), typeof(FastForestClassificationPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(IPredictorProducing), typeof(FastForestClassificationModelParameters), null, typeof(SignatureLoadModel), "FastForest Binary Executor", - FastForestClassificationPredictor.LoaderSignature)] + FastForestClassificationModelParameters.LoaderSignature)] [assembly: LoadableClass(typeof(void), typeof(FastForest), null, typeof(SignatureEntryPointModule), "FastForest")] @@ -46,11 +46,11 @@ public FastForestArgumentsBase() } } - public sealed class FastForestClassificationPredictor : - FastTreePredictionWrapper + public sealed class FastForestClassificationModelParameters : + TreeEnsembleModelParameters { - public const string LoaderSignature = "FastForestBinaryExec"; - public const string 
RegistrationName = "FastForestClassificationPredictor"; + internal const string LoaderSignature = "FastForestBinaryExec"; + internal const string RegistrationName = "FastForestClassificationPredictor"; private static VersionInfo GetVersionInfo() { @@ -65,7 +65,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010005, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(FastForestClassificationPredictor).Assembly.FullName); + loaderAssemblyName: typeof(FastForestClassificationModelParameters).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010003; @@ -79,27 +79,27 @@ private static VersionInfo GetVersionInfo() /// public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public FastForestClassificationPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + public FastForestClassificationModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } - private FastForestClassificationPredictor(IHostEnvironment env, ModelLoadContext ctx) + private FastForestClassificationModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) { } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); } - public static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) + private static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - var predictor = new FastForestClassificationPredictor(env, ctx); + var predictor = new FastForestClassificationModelParameters(env, ctx); 
ICalibrator calibrator; ctx.LoadModelOrNull(env, out calibrator, @"Calibrator"); if (calibrator == null) @@ -182,7 +182,7 @@ private protected override IPredictorWithFeatureWeights TrainModelCore(Tr trainData.CheckBinaryLabel(); trainData.CheckFeatureFloatVector(); trainData.CheckOptFloatWeight(); - FeatureCount = trainData.Schema.Feature.Type.ValueCount; + FeatureCount = trainData.Schema.Feature.Value.Type.ValueCount; ConvertData(trainData); TrainCore(ch); } @@ -192,7 +192,7 @@ private protected override IPredictorWithFeatureWeights TrainModelCore(Tr // calibrator, transform the scores using that. // REVIEW: Need a way to signal the outside world that we prefer simple sigmoid? - return new FastForestClassificationPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs); + return new FastForestClassificationModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs); } protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch) diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index 71daa1b169..f90a37bfa7 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -2,19 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using System; +using Microsoft.ML; +using Microsoft.ML.CommandLine; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Training; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.FastTree.Internal; -using System; +using Microsoft.ML.Training; [assembly: LoadableClass(FastForestRegression.Summary, typeof(FastForestRegression), typeof(FastForestRegression.Arguments), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) }, @@ -22,21 +21,21 @@ FastForestRegression.LoadNameValue, FastForestRegression.ShortName)] -[assembly: LoadableClass(typeof(FastForestRegressionPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(FastForestRegressionModelParameters), null, typeof(SignatureLoadModel), "FastForest Regression Executor", - FastForestRegressionPredictor.LoaderSignature)] + FastForestRegressionModelParameters.LoaderSignature)] namespace Microsoft.ML.Trainers.FastTree { - public sealed class FastForestRegressionPredictor : - FastTreePredictionWrapper, + public sealed class FastForestRegressionModelParameters : + TreeEnsembleModelParameters, IQuantileValueMapper, IQuantileRegressionPredictor { private readonly int _quantileSampleCount; - public const string LoaderSignature = "FastForestRegressionExec"; - public const string RegistrationName = "FastForestRegressionPredictor"; + internal const string LoaderSignature = "FastForestRegressionExec"; + internal const string RegistrationName = 
"FastForestRegressionPredictor"; private static VersionInfo GetVersionInfo() { @@ -51,7 +50,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010005, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(FastForestRegressionPredictor).Assembly.FullName); + loaderAssemblyName: typeof(FastForestRegressionModelParameters).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010003; @@ -60,13 +59,13 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010006; - public FastForestRegressionPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount) + public FastForestRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { _quantileSampleCount = samplesCount; } - private FastForestRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx) + private FastForestRegressionModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) { // *** Binary format *** @@ -76,7 +75,7 @@ private FastForestRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx _quantileSampleCount = ctx.Reader.ReadInt32(); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -91,12 +90,12 @@ protected override void SaveCore(ModelSaveContext ctx) ctx.Writer.Write(_quantileSampleCount); } - public static FastForestRegressionPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static FastForestRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); 
ctx.CheckAtModel(GetVersionInfo()); - return new FastForestRegressionPredictor(env, ctx); + return new FastForestRegressionModelParameters(env, ctx); } public override PredictionKind PredictionKind => PredictionKind.Regression; @@ -111,7 +110,7 @@ protected override void Map(in VBuffer src, ref float dst) dst = (float)TrainedEnsemble.GetOutput(in src) / TrainedEnsemble.NumTrees; } - public ValueMapper, VBuffer> GetMapper(float[] quantiles) + ValueMapper, VBuffer> IQuantileValueMapper.GetMapper(float[] quantiles) { return (in VBuffer src, ref VBuffer dst) => @@ -119,7 +118,7 @@ public ValueMapper, VBuffer> GetMapper(float[] quantiles) // REVIEW: Should make this more efficient - it repeatedly allocates too much stuff. float[] weights = null; var distribution = TrainedEnsemble.GetDistribution(in src, _quantileSampleCount, out weights); - var qdist = new QuantileStatistics(distribution, weights); + IQuantileDistribution qdist = new QuantileStatistics(distribution, weights); var editor = VBufferEditor.Create(ref dst, quantiles.Length); for (int i = 0; i < quantiles.Length; i++) @@ -128,7 +127,7 @@ public ValueMapper, VBuffer> GetMapper(float[] quantiles) }; } - public ISchemaBindableMapper CreateMapper(Double[] quantiles) + ISchemaBindableMapper IQuantileRegressionPredictor.CreateMapper(Double[] quantiles) { Host.CheckNonEmpty(quantiles, nameof(quantiles)); return new SchemaBindableQuantileRegressionPredictor(this, quantiles); @@ -137,7 +136,7 @@ public ISchemaBindableMapper CreateMapper(Double[] quantiles) /// public sealed partial class FastForestRegression - : RandomForestTrainerBase, FastForestRegressionPredictor> + : RandomForestTrainerBase, FastForestRegressionModelParameters> { public sealed class Arguments : FastForestArgumentsBase { @@ -174,7 +173,7 @@ public FastForestRegression(IHostEnvironment env, int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, double learningRate = Defaults.LearningRates, Action advancedSettings = null) - : base(env, 
TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings) + : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); @@ -184,11 +183,11 @@ public FastForestRegression(IHostEnvironment env, /// Initializes a new instance of by using the legacy class. /// public FastForestRegression(IHostEnvironment env, Arguments args) - : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn), true) + : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn), true) { } - private protected override FastForestRegressionPredictor TrainModelCore(TrainContext context) + private protected override FastForestRegressionModelParameters TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var trainData = context.TrainingSet; @@ -201,11 +200,11 @@ private protected override FastForestRegressionPredictor TrainModelCore(TrainCon trainData.CheckRegressionLabel(); trainData.CheckFeatureFloatVector(); trainData.CheckOptFloatWeight(); - FeatureCount = trainData.Schema.Feature.Type.ValueCount; + FeatureCount = trainData.Schema.Feature.Value.Type.ValueCount; ConvertData(trainData); TrainCore(ch); } - return new FastForestRegressionPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs, Args.QuantileSampleCount); + return new FastForestRegressionModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs, Args.QuantileSampleCount); } protected override void PrepareLabels(IChannel ch) @@ -222,10 +221,10 @@ protected override Test ConstructTestForTrainingData() return new RegressionTest(ConstructScoreTracker(TrainSet)); } - protected override RegressionPredictionTransformer MakeTransformer(FastForestRegressionPredictor model, Schema 
trainSchema) - => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + protected override RegressionPredictionTransformer MakeTransformer(FastForestRegressionModelParameters model, Schema trainSchema) + => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) + public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData); protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) diff --git a/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs b/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs index 14fcad759e..7e1214f2d5 100644 --- a/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs +++ b/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs @@ -12,19 +12,19 @@ using System.Collections.Generic; using System.Linq; using System.Threading; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML; +using Microsoft.ML.Command; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.FastTree.Internal; -using Microsoft.ML.Runtime.Internal.Utilities; [assembly: LoadableClass(typeof(SumupPerformanceCommand), typeof(SumupPerformanceCommand.Arguments), typeof(SignatureCommand), "", "FastTreeSumupPerformance", "ftsumup")] namespace Microsoft.ML.Trainers.FastTree { - using Stopwatch = System.Diagnostics.Stopwatch; +using Stopwatch = System.Diagnostics.Stopwatch; /// /// This is an internal utility command to measure the performance of the IntArray sumup operation. 
@@ -183,7 +183,7 @@ private IntArray[] CreateRandomIntArrays(IChannel ch) } } - private IEnumerator Geometric(double p, IRandom rgen) + private IEnumerator Geometric(double p, Random rgen) { double denom = Math.Log(1 - p); @@ -209,7 +209,7 @@ private IEnumerator Geometric(double p, IRandom rgen) } } - private IEnumerable CreateDocIndicesCore(double sparsity, IRandom rgen) + private IEnumerable CreateDocIndicesCore(double sparsity, Random rgen) { _host.Assert(0 < sparsity && sparsity < 1); int remaining = _len; @@ -227,7 +227,7 @@ private IEnumerable CreateDocIndicesCore(double sparsity, IRandom rgen) } } - private IEnumerable CreateDocIndices(double sparsity, IRandom rgen) + private IEnumerable CreateDocIndices(double sparsity, Random rgen) { _host.Assert(0 <= sparsity && sparsity <= 1); if (sparsity == 1) @@ -237,7 +237,7 @@ private IEnumerable CreateDocIndices(double sparsity, IRandom rgen) return CreateDocIndicesCore(sparsity, rgen); } - private void InitSumupInputData(SumupInputData data, double sparsity, IRandom rgen) + private void InitSumupInputData(SumupInputData data, double sparsity, Random rgen) { int count = 0; foreach (int d in CreateDocIndices(sparsity, rgen)) diff --git a/src/Microsoft.ML.FastTree/Training/Applications/ObjectiveFunction.cs b/src/Microsoft.ML.FastTree/Training/Applications/ObjectiveFunction.cs index 775f15779b..2808e24cb7 100644 --- a/src/Microsoft.ML.FastTree/Training/Applications/ObjectiveFunction.cs +++ b/src/Microsoft.ML.FastTree/Training/Applications/ObjectiveFunction.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; using System; using System.Collections.Concurrent; using System.Linq; diff --git a/src/Microsoft.ML.FastTree/Training/DcgCalculator.cs b/src/Microsoft.ML.FastTree/Training/DcgCalculator.cs index d81a93d6bc..e9284a6a5e 100644 --- a/src/Microsoft.ML.FastTree/Training/DcgCalculator.cs +++ b/src/Microsoft.ML.FastTree/Training/DcgCalculator.cs @@ -2,11 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; using System; using System.Linq; using System.Threading.Tasks; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Trainers.FastTree.Internal { diff --git a/src/Microsoft.ML.FastTree/Training/DcgPermutationComparer.cs b/src/Microsoft.ML.FastTree/Training/DcgPermutationComparer.cs index 6467c9d0eb..05e2955ec9 100644 --- a/src/Microsoft.ML.FastTree/Training/DcgPermutationComparer.cs +++ b/src/Microsoft.ML.FastTree/Training/DcgPermutationComparer.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; using System.Collections.Generic; namespace Microsoft.ML.Trainers.FastTree.Internal diff --git a/src/Microsoft.ML.FastTree/Training/DocumentPartitioning.cs b/src/Microsoft.ML.FastTree/Training/DocumentPartitioning.cs index 9b2947b556..aac8885a54 100644 --- a/src/Microsoft.ML.FastTree/Training/DocumentPartitioning.cs +++ b/src/Microsoft.ML.FastTree/Training/DocumentPartitioning.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; using System; using System.Collections.Generic; using System.Linq; diff --git a/src/Microsoft.ML.FastTree/Training/EnsembleCompression/IEnsembleCompressor.cs b/src/Microsoft.ML.FastTree/Training/EnsembleCompression/IEnsembleCompressor.cs index 47fca98092..1bdb32fb10 100644 --- a/src/Microsoft.ML.FastTree/Training/EnsembleCompression/IEnsembleCompressor.cs +++ b/src/Microsoft.ML.FastTree/Training/EnsembleCompression/IEnsembleCompressor.cs @@ -1,10 +1,6 @@ -// ----------------------------------------------------------------------- -// -// Copyright (C) All Rights Reserved -// -// ----------------------------------------------------------------------- - -using Microsoft.ML.Runtime; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. namespace Microsoft.ML.Trainers.FastTree.Internal { diff --git a/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoBasedEnsembleCompressor.cs b/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoBasedEnsembleCompressor.cs index d005573d25..b2f7750e8e 100644 --- a/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoBasedEnsembleCompressor.cs +++ b/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoBasedEnsembleCompressor.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; using System; using System.Collections.Generic; using System.Diagnostics; diff --git a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/AcceleratedGradientDescent.cs b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/AcceleratedGradientDescent.cs index 110c40155e..eb6e27137f 100644 --- a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/AcceleratedGradientDescent.cs +++ b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/AcceleratedGradientDescent.cs @@ -2,8 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; - namespace Microsoft.ML.Trainers.FastTree.Internal { //Accelerated gradient descent score tracker diff --git a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/ConjugateGradientDescent.cs b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/ConjugateGradientDescent.cs index 7eae8d90e9..7e10e4d479 100644 --- a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/ConjugateGradientDescent.cs +++ b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/ConjugateGradientDescent.cs @@ -2,8 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; - namespace Microsoft.ML.Trainers.FastTree.Internal { // Conjugate gradient descent diff --git a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/GradientDescent.cs b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/GradientDescent.cs index 159e833b2d..917eb541fd 100644 --- a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/GradientDescent.cs +++ b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/GradientDescent.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; using System; using System.Collections.Generic; using System.Linq; diff --git a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/NoOptimizationAlgorithm.cs b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/NoOptimizationAlgorithm.cs index 6822bdf4c2..fa5ff98fe8 100644 --- a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/NoOptimizationAlgorithm.cs +++ b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/NoOptimizationAlgorithm.cs @@ -2,8 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; - namespace Microsoft.ML.Trainers.FastTree.Internal { /// diff --git a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/OptimizationAlgorithm.cs b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/OptimizationAlgorithm.cs index d18107e25d..1366dfd372 100644 --- a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/OptimizationAlgorithm.cs +++ b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/OptimizationAlgorithm.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; using System; using System.Collections.Generic; diff --git a/src/Microsoft.ML.FastTree/Training/Parallel/IParallelTraining.cs b/src/Microsoft.ML.FastTree/Training/Parallel/IParallelTraining.cs index 0b28e2157c..03cc1900fb 100644 --- a/src/Microsoft.ML.FastTree/Training/Parallel/IParallelTraining.cs +++ b/src/Microsoft.ML.FastTree/Training/Parallel/IParallelTraining.cs @@ -2,15 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Trainers.FastTree.Internal; using System; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Trainers.FastTree.Internal; namespace Microsoft.ML.Trainers.FastTree { - using SplitInfo = Internal.LeastSquaresRegressionTreeLearner.SplitInfo; using LeafSplitCandidates = Internal.LeastSquaresRegressionTreeLearner.LeafSplitCandidates; + using SplitInfo = Internal.LeastSquaresRegressionTreeLearner.SplitInfo; #if USE_SINGLE_PRECISION using FloatType = System.Single; diff --git a/src/Microsoft.ML.FastTree/Training/Parallel/SingleTrainer.cs b/src/Microsoft.ML.FastTree/Training/Parallel/SingleTrainer.cs index c5859cd559..7d31b86adc 100644 --- a/src/Microsoft.ML.FastTree/Training/Parallel/SingleTrainer.cs +++ b/src/Microsoft.ML.FastTree/Training/Parallel/SingleTrainer.cs @@ -3,10 +3,10 @@ // See the LICENSE file in the project root for more information. using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Runtime.Internal.Utilities; [assembly: LoadableClass(typeof(Microsoft.ML.Trainers.FastTree.SingleTrainer), null, typeof(Microsoft.ML.Trainers.FastTree.SignatureParallelTrainer), "single")] @@ -16,8 +16,8 @@ namespace Microsoft.ML.Trainers.FastTree { using Microsoft.ML.Trainers.FastTree.Internal; - using SplitInfo = Internal.LeastSquaresRegressionTreeLearner.SplitInfo; using LeafSplitCandidates = Internal.LeastSquaresRegressionTreeLearner.LeafSplitCandidates; + using SplitInfo = Internal.LeastSquaresRegressionTreeLearner.SplitInfo; public sealed class SingleTrainer : IParallelTraining { diff --git a/src/Microsoft.ML.FastTree/Training/ScoreTracker.cs b/src/Microsoft.ML.FastTree/Training/ScoreTracker.cs index c383681958..9f931d5926 100644 --- a/src/Microsoft.ML.FastTree/Training/ScoreTracker.cs +++ 
b/src/Microsoft.ML.FastTree/Training/ScoreTracker.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; using System; using System.Collections.Generic; using System.Threading.Tasks; diff --git a/src/Microsoft.ML.FastTree/Training/StepSearch.cs b/src/Microsoft.ML.FastTree/Training/StepSearch.cs index 13139f7e68..66ffcb943d 100644 --- a/src/Microsoft.ML.FastTree/Training/StepSearch.cs +++ b/src/Microsoft.ML.FastTree/Training/StepSearch.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; using System; using System.Linq; diff --git a/src/Microsoft.ML.FastTree/Training/Test.cs b/src/Microsoft.ML.FastTree/Training/Test.cs index 6792c84b81..dc587d8e76 100644 --- a/src/Microsoft.ML.FastTree/Training/Test.cs +++ b/src/Microsoft.ML.FastTree/Training/Test.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; using System; using System.Collections.Generic; using System.Linq; diff --git a/src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs b/src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs index cd42672c74..d45e3e00ab 100644 --- a/src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs +++ b/src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; using System; namespace Microsoft.ML.Trainers.FastTree.Internal diff --git a/src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs b/src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs index 7b1a5f8b7d..a78e117c75 100644 --- a/src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs +++ b/src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs @@ -2,13 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Internal.CpuMath; -using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; +using Microsoft.ML.Internal.CpuMath; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Trainers.FastTree.Internal { diff --git a/src/Microsoft.ML.FastTree/Training/TreeLearners/TreeLearner.cs b/src/Microsoft.ML.FastTree/Training/TreeLearners/TreeLearner.cs index 8d61f63e8b..19e3dea71b 100644 --- a/src/Microsoft.ML.FastTree/Training/TreeLearners/TreeLearner.cs +++ b/src/Microsoft.ML.FastTree/Training/TreeLearners/TreeLearner.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; using System; namespace Microsoft.ML.Trainers.FastTree.Internal diff --git a/src/Microsoft.ML.FastTree/Training/WinLossCalculator.cs b/src/Microsoft.ML.FastTree/Training/WinLossCalculator.cs index 0c3d89ed10..0320a23ded 100644 --- a/src/Microsoft.ML.FastTree/Training/WinLossCalculator.cs +++ b/src/Microsoft.ML.FastTree/Training/WinLossCalculator.cs @@ -6,7 +6,7 @@ using System.Collections.Concurrent; using System.Linq; using System.Threading.Tasks; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Trainers.FastTree.Internal { diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/QuantileRegressionTree.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/QuantileRegressionTree.cs index 6025160c20..b40011bc68 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/QuantileRegressionTree.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/QuantileRegressionTree.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; using Float = System.Single; namespace Microsoft.ML.Trainers.FastTree.Internal diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/RegressionTree.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/RegressionTree.cs index 834a4188f1..9ed3b25e3c 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/RegressionTree.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/RegressionTree.cs @@ -2,19 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Pfa; -using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.IO; using System.Linq; using System.Text; using System.Threading.Tasks; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Pfa; +using Newtonsoft.Json.Linq; using Float = System.Single; namespace Microsoft.ML.Trainers.FastTree.Internal @@ -1119,7 +1118,7 @@ public void RemapFeatures(int[] oldToNewFeatures) /// A map of feature index (in the features array) /// to the ID as it will be written in the file. This instance should be /// used for all - public void ToTreeEnsembleFormat(StringBuilder sbEvaluator, StringBuilder sbInput, FeaturesToContentMap featureContents, + internal void ToTreeEnsembleFormat(StringBuilder sbEvaluator, StringBuilder sbInput, FeaturesToContentMap featureContents, ref int evaluatorCounter, Dictionary featureToId) { Contracts.AssertValue(sbEvaluator); @@ -1259,7 +1258,7 @@ private void ToTreeEnsembleFormatForCategoricalSplit(StringBuilder sbEvaluator, } // prints the tree out as a string (in old Bing format used by LambdaMART and AdIndex) - public string ToOldIni(FeatureNameCollection featureNames) + internal string ToOldIni(FeatureNameCollection featureNames) { // print the root node StringBuilder output = new StringBuilder(); diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs index f4bcffd873..746882f51c 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs @@ -2,18 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. 
// See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Pfa; -using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.IO; using System.Linq; using System.Text; using System.Threading.Tasks; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Pfa; +using Newtonsoft.Json.Linq; namespace Microsoft.ML.Trainers.FastTree.Internal { @@ -106,7 +105,7 @@ public void RemapFeatures(int[] oldToNewFeatures) /// /// returns the ensemble in the production TreeEnsemble format /// - public string ToTreeEnsembleIni(FeaturesToContentMap fmap, + internal string ToTreeEnsembleIni(FeaturesToContentMap fmap, string trainingParams, bool appendFeatureGain, bool includeZeroGainFeatures = true) { StringBuilder sbEvaluator = new StringBuilder(); @@ -305,7 +304,7 @@ public void GetOutputs(Dataset dataset, double[] outputs, int prefix) Parallel.Invoke(new ParallelOptions { MaxDegreeOfParallelism = BlockingThreadPool.NumThreads }, actions); } - public string ToGainSummary(FeaturesToContentMap fmap, Dictionary featureToID, int prefix, bool includeZeroGainFeatures, bool normalize, int startingCommentNumber) + internal string ToGainSummary(FeaturesToContentMap fmap, Dictionary featureToID, int prefix, bool includeZeroGainFeatures, bool normalize, int startingCommentNumber) { if (_trees.Count == 0) return string.Empty; @@ -396,7 +395,7 @@ public FeatureToGainMap(IList trees, bool normalize) /// A class that given either a /// provides a mechanism for getting the corresponding input INI content for the features. 
/// - public sealed class FeaturesToContentMap + internal sealed class FeaturesToContentMap { private readonly VBuffer> _content; private readonly VBuffer> _names; @@ -412,13 +411,13 @@ public sealed class FeaturesToContentMap public FeaturesToContentMap(RoleMappedSchema schema) { Contracts.AssertValue(schema); - var feat = schema.Feature; - Contracts.AssertValue(feat); + Contracts.Assert(schema.Feature.HasValue); + var feat = schema.Feature.Value; Contracts.Assert(feat.Type.ValueCount > 0); var sch = schema.Schema; - if (sch.HasSlotNames(feat.Index, feat.Type.ValueCount)) - sch.GetMetadata(MetadataUtils.Kinds.SlotNames, feat.Index, ref _names); + if (sch[feat.Index].HasSlotNames(feat.Type.ValueCount)) + sch[feat.Index].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref _names); else _names = VBufferUtils.CreateEmpty>(feat.Type.ValueCount); #if !CORECLR diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs index 84b04ed793..46ee38d566 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs @@ -2,11 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Ensemble; -using Microsoft.ML.Runtime.Internal.Calibration; -using Microsoft.ML.Trainers.FastTree.Internal; using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Ensemble; +using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Trainers.FastTree.Internal; [assembly: LoadableClass(typeof(TreeEnsembleCombiner), null, typeof(SignatureModelCombiner), "Fast Tree Model Combiner", "FastTreeCombiner")] @@ -55,9 +55,9 @@ public IPredictor CombineModels(IEnumerable models) _host.Check(calibrated.Calibrator is PlattCalibrator, "Combining FastTree models can only be done when the models are calibrated with Platt calibrator"); predictor = calibrated.SubPredictor; - paramA = -(calibrated.Calibrator as PlattCalibrator).ParamA; + paramA = -(calibrated.Calibrator as PlattCalibrator).Slope; } - var tree = predictor as FastTreePredictionWrapper; + var tree = predictor as TreeEnsembleModelParameters; if (tree == null) throw _host.Except("Model is not a tree ensemble"); foreach (var t in tree.TrainedEnsemble.Trees) @@ -99,14 +99,14 @@ public IPredictor CombineModels(IEnumerable models) { case PredictionKind.BinaryClassification: if (!binaryClassifier) - return new FastTreeBinaryPredictor(_host, ensemble, featureCount, null); + return new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null); var cali = new PlattCalibrator(_host, -1, 0); - return new FeatureWeightsCalibratedPredictor(_host, new FastTreeBinaryPredictor(_host, ensemble, featureCount, null), cali); + return new FeatureWeightsCalibratedPredictor(_host, new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null), cali); case PredictionKind.Regression: - return new FastTreeRegressionPredictor(_host, ensemble, featureCount, null); + return new FastTreeRegressionModelParameters(_host, ensemble, featureCount, null); case PredictionKind.Ranking: - return new FastTreeRankingPredictor(_host, ensemble, featureCount, null); + 
return new FastTreeRankingModelParameters(_host, ensemble, featureCount, null); default: _host.Assert(false); throw _host.ExceptNotSupp(); diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index e8b57dda6e..fd0fbd832c 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -2,21 +2,20 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.Conversion; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Calibration; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.TreePredictor; -using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Transforms; using System; using System.Collections.Generic; using System.IO; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Data.Conversion; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Trainers.FastTree; +using Microsoft.ML.Transforms; +using Microsoft.ML.TreePredictor; [assembly: LoadableClass(typeof(ISchemaBindableMapper), typeof(TreeEnsembleFeaturizerTransform), typeof(TreeEnsembleFeaturizerBindableMapper.Arguments), typeof(SignatureBindableMapper), "Tree Ensemble Featurizer Mapper", TreeEnsembleFeaturizerBindableMapper.LoadNameShort)] @@ -33,7 +32,7 @@ [assembly: LoadableClass(typeof(void), typeof(TreeFeaturize), null, typeof(SignatureEntryPointModule), "TreeFeaturize")] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Data { /// /// A bindable mapper wrapper for tree ensembles, that 
creates a bound mapper with three outputs: @@ -174,8 +173,9 @@ private void IsNormalized(int iinfo, ref bool dst) public RoleMappedSchema InputRoleMappedSchema { get; } - public Schema Schema { get; } public Schema InputSchema => InputRoleMappedSchema.Schema; + public Schema OutputSchema { get; } + private Schema.Column FeatureColumn => InputRoleMappedSchema.Feature.Value; public ISchemaBindableMapper Bindable => _owner; @@ -185,7 +185,7 @@ public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper Contracts.AssertValue(ectx); ectx.AssertValue(owner); ectx.AssertValue(schema); - ectx.AssertValue(schema.Feature); + ectx.Assert(schema.Feature.HasValue); _ectx = ectx; @@ -205,18 +205,17 @@ public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper // which means that #internal = #leaf - 1. // Therefore, the number of internal nodes in the ensemble is #leaf - #trees. var pathIdType = new VectorType(NumberType.Float, _owner._totalLeafCount - _owner._ensemble.TrainedEnsemble.NumTrees); - Schema = Schema.Create(new SchemaImpl(ectx, owner, treeValueType, leafIdType, pathIdType)); + OutputSchema = Schema.Create(new SchemaImpl(ectx, owner, treeValueType, leafIdType, pathIdType)); } - public IRow GetRow(IRow input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate) { _ectx.CheckValue(input, nameof(input)); _ectx.CheckValue(predicate, nameof(predicate)); - disposer = null; - return new SimpleRow(Schema, input, CreateGetters(input, predicate)); + return new SimpleRow(OutputSchema, input, CreateGetters(input, predicate)); } - private Delegate[] CreateGetters(IRow input, Func predicate) + private Delegate[] CreateGetters(Row input, Func predicate) { _ectx.AssertValue(input); _ectx.AssertValue(predicate); @@ -230,7 +229,7 @@ private Delegate[] CreateGetters(IRow input, Func predicate) if (!treeValueActive && !leafIdActive && !pathIdActive) return delegates; - var state = new State(_ectx, input, 
_owner._ensemble, _owner._totalLeafCount, InputRoleMappedSchema.Feature.Index); + var state = new State(_ectx, input, _owner._ensemble, _owner._totalLeafCount, FeatureColumn.Index); // Get the tree value getter. if (treeValueActive) @@ -259,8 +258,8 @@ private Delegate[] CreateGetters(IRow input, Func predicate) private sealed class State { private readonly IExceptionContext _ectx; - private readonly IRow _input; - private readonly FastTreePredictionWrapper _ensemble; + private readonly Row _input; + private readonly TreeEnsembleModelParameters _ensemble; private readonly int _numTrees; private readonly int _numLeaves; @@ -276,7 +275,7 @@ private sealed class State private long _cachedLeafBuilderPosition; private long _cachedPathBuilderPosition; - public State(IExceptionContext ectx, IRow input, FastTreePredictionWrapper ensemble, int numLeaves, int featureIndex) + public State(IExceptionContext ectx, Row input, TreeEnsembleModelParameters ensemble, int numLeaves, int featureIndex) { Contracts.AssertValue(ectx); _ectx = ectx; @@ -392,15 +391,15 @@ private void EnsureCachedPosition() public IEnumerable> GetInputColumnRoles() { - yield return RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature.Name); + yield return RoleMappedSchema.ColumnRole.Feature.Bind(FeatureColumn.Name); } public Func GetDependencies(Func predicate) { - for (int i = 0; i < Schema.ColumnCount; i++) + for (int i = 0; i < OutputSchema.Count; i++) { if (predicate(i)) - return col => col == InputRoleMappedSchema.Feature.Index; + return col => col == FeatureColumn.Index; } return col => false; } @@ -422,7 +421,7 @@ private static VersionInfo GetVersionInfo() } private readonly IHost _host; - private readonly FastTreePredictionWrapper _ensemble; + private readonly TreeEnsembleModelParameters _ensemble; private readonly int _totalLeafCount; public TreeEnsembleFeaturizerBindableMapper(IHostEnvironment env, Arguments args, IPredictor predictor) @@ -434,7 +433,7 @@ public 
TreeEnsembleFeaturizerBindableMapper(IHostEnvironment env, Arguments args if (predictor is CalibratedPredictorBase) predictor = ((CalibratedPredictorBase)predictor).SubPredictor; - _ensemble = predictor as FastTreePredictionWrapper; + _ensemble = predictor as TreeEnsembleModelParameters; _host.Check(_ensemble != null, "Predictor in model file does not have compatible type"); _totalLeafCount = CountLeaves(_ensemble); @@ -449,7 +448,7 @@ public TreeEnsembleFeaturizerBindableMapper(IHostEnvironment env, ModelLoadConte // *** Binary format *** // ensemble - ctx.LoadModel(env, out _ensemble, "Ensemble"); + ctx.LoadModel(env, out _ensemble, "Ensemble"); _totalLeafCount = CountLeaves(_ensemble); } @@ -466,7 +465,7 @@ public void Save(ModelSaveContext ctx) ctx.SaveModel(_ensemble, "Ensemble"); } - private static int CountLeaves(FastTreePredictionWrapper ensemble) + private static int CountLeaves(TreeEnsembleModelParameters ensemble) { Contracts.AssertValue(ensemble); @@ -526,7 +525,7 @@ private void GetPathSlotNames(int col, ref VBuffer> dst) dst = editor.Commit(); } - public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) + ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema) { Contracts.AssertValue(env); env.AssertValue(schema); @@ -576,7 +575,7 @@ public sealed class ArgumentsForEntryPoint : TransformInputBase public int LabelPermutationSeed; [Argument(ArgumentType.Required, HelpText = "Trainer to use", SortOrder = 10, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] - public IPredictorModel PredictorModel; + public PredictorModel PredictorModel; } #pragma warning restore CS0649 @@ -640,22 +639,23 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV ch.Trace("Creating scorer"); var data = TrainAndScoreTransformer.CreateDataFromArgs(ch, input, args); + Contracts.Assert(data.Schema.Feature.HasValue); // Make sure that the given predictor has the correct number of 
input features. if (predictor is CalibratedPredictorBase) predictor = ((CalibratedPredictorBase)predictor).SubPredictor; - // Predictor should be a FastTreePredictionWrapper, which implements IValueMapper, so this should + // Predictor should be a TreeEnsembleModelParameters, which implements IValueMapper, so this should // be non-null. var vm = predictor as IValueMapper; ch.CheckUserArg(vm != null, nameof(args.TrainedModelFile), "Predictor in model file does not have compatible type"); - if (vm.InputType.VectorSize != data.Schema.Feature.Type.VectorSize) + if (vm.InputType.VectorSize != data.Schema.Feature.Value.Type.VectorSize) { throw ch.ExceptUserArg(nameof(args.TrainedModelFile), "Predictor in model file expects {0} features, but data has {1} features", - vm.InputType.VectorSize, data.Schema.Feature.Type.VectorSize); + vm.InputType.VectorSize, data.Schema.Feature.Value.Type.VectorSize); } - var bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor); + ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor); var bound = bindable.Bind(env, data.Schema); xf = new GenericScorer(env, scorerArgs, input, bound, data.Schema); } @@ -703,23 +703,24 @@ public static IDataTransform CreateForEntryPoint(IHostEnvironment env, Arguments RoleMappedData data = null; args.PredictorModel.PrepareData(env, input, out data, out var predictor2); ch.AssertValue(data); + ch.Assert(data.Schema.Feature.HasValue); ch.Assert(predictor == predictor2); // Make sure that the given predictor has the correct number of input features. if (predictor is CalibratedPredictorBase) predictor = ((CalibratedPredictorBase)predictor).SubPredictor; - // Predictor should be a FastTreePredictionWrapper, which implements IValueMapper, so this should + // Predictor should be a TreeEnsembleModelParameters, which implements IValueMapper, so this should // be non-null. 
var vm = predictor as IValueMapper; ch.CheckUserArg(vm != null, nameof(args.PredictorModel), "Predictor does not have compatible type"); - if (data != null && vm.InputType.VectorSize != data.Schema.Feature.Type.VectorSize) + if (data != null && vm.InputType.VectorSize != data.Schema.Feature.Value.Type.VectorSize) { throw ch.ExceptUserArg(nameof(args.PredictorModel), "Predictor expects {0} features, but data has {1} features", - vm.InputType.VectorSize, data.Schema.Feature.Type.VectorSize); + vm.InputType.VectorSize, data.Schema.Feature.Value.Type.VectorSize); } - var bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor); + ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor); var bound = bindable.Bind(env, data.Schema); return new GenericScorer(env, scorerArgs, data.Data, bound, data.Schema); } @@ -781,7 +782,7 @@ private static IDataView AppendLabelTransform(IHostEnvironment env, IChannel ch, int col; if (!input.Schema.TryGetColumnIndex(labelName, out col)) throw ch.Except("Label column '{0}' not found.", labelName); - ColumnType labelType = input.Schema.GetColumnType(col); + ColumnType labelType = input.Schema[col].Type; if (!labelType.IsKey) { if (labelPermutationSeed != 0) @@ -789,7 +790,7 @@ private static IDataView AppendLabelTransform(IHostEnvironment env, IChannel ch, "labelPermutationSeed != 0 only applies on a multi-class learning problem when the label type is a key."); return input; } - return Utils.MarshalInvoke(AppendFloatMapper, labelType.RawType, env, ch, input, labelName, labelType.AsKey, + return Utils.MarshalInvoke(AppendFloatMapper, labelType.RawType, env, ch, input, labelName, (KeyType)labelType, labelPermutationSeed); } } @@ -810,7 +811,7 @@ public static CommonOutputs.TransformOutput Featurizer(IHostEnvironment env, Tre EntryPointUtils.CheckInputArgs(host, input); var xf = TreeEnsembleFeaturizerTransform.CreateForEntryPoint(env, input, input.Data); - return new 
CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } #pragma warning restore CS0649 } diff --git a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs index 584484de10..afc670cdb4 100644 --- a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Trainers.FastTree; using System; +using Microsoft.ML.Data; +using Microsoft.ML.Trainers.FastTree; namespace Microsoft.ML { diff --git a/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs index 176f0cead8..1e9ff0f4de 100644 --- a/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs @@ -2,12 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.StaticPipe.Runtime; using System; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.StaticPipe.Runtime; +using Microsoft.ML.Trainers.FastTree; namespace Microsoft.ML.StaticPipe { @@ -48,7 +46,7 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, double learningRate = Defaults.LearningRates, Action advancedSettings = null, - Action onFit = null) + Action onFit = null) { CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings, onFit); @@ -144,7 +142,7 @@ public static Scalar FastTree(this RankingContext.RankingTrainers c int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, double learningRate = Defaults.LearningRates, Action advancedSettings = null, - Action onFit = null) + Action onFit = null) { CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings, onFit); diff --git a/src/Microsoft.ML.FastTree/Utils/Algorithms.cs b/src/Microsoft.ML.FastTree/Utils/Algorithms.cs index 304e185c4f..9ad94b359a 100644 --- a/src/Microsoft.ML.FastTree/Utils/Algorithms.cs +++ b/src/Microsoft.ML.FastTree/Utils/Algorithms.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; using System; using System.Linq; diff --git a/src/Microsoft.ML.FastTree/Utils/BufferPoolManager.cs b/src/Microsoft.ML.FastTree/Utils/BufferPoolManager.cs index 5197e7ed4c..cd0d03cfc8 100644 --- a/src/Microsoft.ML.FastTree/Utils/BufferPoolManager.cs +++ b/src/Microsoft.ML.FastTree/Utils/BufferPoolManager.cs @@ -4,7 +4,6 @@ // // ----------------------------------------------------------------------- -using Microsoft.ML.Runtime; using System; using System.Collections.Concurrent; using System.Collections.Generic; diff --git a/src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs b/src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs new file mode 100644 index 0000000000..1a2a35136b --- /dev/null +++ b/src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs @@ -0,0 +1,55 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System.Text; +using Microsoft.ML.Calibrator; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Internal.Utilities; + +namespace Microsoft.ML.Trainers.FastTree.Internal +{ + internal static class FastTreeIniFileUtils + { + public static string TreeEnsembleToIni( + IHost host, TreeEnsemble ensemble, RoleMappedSchema schema, ICalibrator calibrator, + string trainingParams, bool appendFeatureGain, bool includeZeroGainFeatures) + { + host.CheckValue(ensemble, nameof(ensemble)); + host.CheckValue(schema, nameof(schema)); + + string ensembleIni = ensemble.ToTreeEnsembleIni(new FeaturesToContentMap(schema), + trainingParams, appendFeatureGain, includeZeroGainFeatures); + ensembleIni = AddCalibrationToIni(host, ensembleIni, calibrator); + return ensembleIni; + } + + /// + /// Get the calibration summary in INI format + /// + private static string AddCalibrationToIni(IHost host, string ini, ICalibrator calibrator) + { + host.AssertValue(ini); + host.AssertValueOrNull(calibrator); + + if (calibrator == null) + return ini; + + if (calibrator is PlattCalibrator) + { + string calibratorEvaluatorIni = IniFileUtils.GetCalibratorEvaluatorIni(ini, calibrator as PlattCalibrator); + return IniFileUtils.AddEvaluator(ini, calibratorEvaluatorIni); + } + else + { + StringBuilder newSection = new StringBuilder(); + newSection.AppendLine(); + newSection.AppendLine(); + newSection.AppendLine("[TLCCalibration]"); + newSection.AppendLine("Type=" + calibrator.GetType().Name); + return ini + newSection; + } + } + } +} diff --git a/src/Microsoft.ML.FastTree/Utils/LinqExtensions.cs b/src/Microsoft.ML.FastTree/Utils/LinqExtensions.cs index 21e19dfd60..88d8825205 100644 --- a/src/Microsoft.ML.FastTree/Utils/LinqExtensions.cs +++ b/src/Microsoft.ML.FastTree/Utils/LinqExtensions.cs @@ -23,14 +23,14 @@ public static int ArgMin(this T[] arr) where T : IComparable return argMin; } - public static int ArgMax(this T[] arr) where T : IComparable + public static 
int ArgMax(this ReadOnlySpan span) where T : IComparable { - if (arr.Length == 0) + if (span.Length == 0) return -1; int argMax = 0; - for (int i = 1; i < arr.Length; i++) + for (int i = 1; i < span.Length; i++) { - if (arr[i].CompareTo(arr[argMax]) > 0) + if (span[i].CompareTo(span[argMax]) > 0) argMax = i; } return argMax; diff --git a/src/Microsoft.ML.FastTree/Utils/MD5Hasher.cs b/src/Microsoft.ML.FastTree/Utils/MD5Hasher.cs index 4367d7940d..fc31ec097e 100644 --- a/src/Microsoft.ML.FastTree/Utils/MD5Hasher.cs +++ b/src/Microsoft.ML.FastTree/Utils/MD5Hasher.cs @@ -2,11 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.IO; using System.Security.Cryptography; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Trainers.FastTree.Internal { diff --git a/src/Microsoft.ML.FastTree/Utils/ThreadTaskManager.cs b/src/Microsoft.ML.FastTree/Utils/ThreadTaskManager.cs index e53bc7f0c6..d479c2c53b 100644 --- a/src/Microsoft.ML.FastTree/Utils/ThreadTaskManager.cs +++ b/src/Microsoft.ML.FastTree/Utils/ThreadTaskManager.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; using System; using System.Collections.Generic; using System.Linq; diff --git a/src/Microsoft.ML.FastTree/Utils/ToByteArrayExtensions.cs b/src/Microsoft.ML.FastTree/Utils/ToByteArrayExtensions.cs index 86df84b0cb..53d475df1e 100644 --- a/src/Microsoft.ML.FastTree/Utils/ToByteArrayExtensions.cs +++ b/src/Microsoft.ML.FastTree/Utils/ToByteArrayExtensions.cs @@ -5,7 +5,7 @@ using System; using System.Linq; using System.Text; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Trainers.FastTree.Internal { @@ -194,9 +194,9 @@ public static ulong ToULong(this byte[] buffer, ref int position) return a; } - // UInt128 + // RowId - public static MD5Hash ToUInt128(this byte[] buffer, ref int position) + public static MD5Hash ToRowId(this byte[] buffer, ref int position) { MD5Hash a = new MD5Hash { @@ -550,7 +550,7 @@ public static unsafe ulong[] ToULongArray(this byte[] buffer, ref int position) return a; } - // UInt128[] + // RowId[] public static int SizeInBytes(this MD5Hash[] array) { @@ -572,7 +572,7 @@ public static unsafe MD5Hash[] ToUInt128Array(this byte[] buffer, ref int positi MD5Hash[] a = new MD5Hash[length]; for (int i = 0; i < length; ++i) { - a[i] = buffer.ToUInt128(ref position); + a[i] = buffer.ToRowId(ref position); } return a; } diff --git a/src/Microsoft.ML.HalLearners.StaticPipe/Microsoft.ML.HalLearners.StaticPipe.csproj b/src/Microsoft.ML.HalLearners.StaticPipe/Microsoft.ML.HalLearners.StaticPipe.csproj new file mode 100644 index 0000000000..aabe2dec4c --- /dev/null +++ b/src/Microsoft.ML.HalLearners.StaticPipe/Microsoft.ML.HalLearners.StaticPipe.csproj @@ -0,0 +1,12 @@ + + + + netstandard2.0 + + + + + + + + diff --git a/src/Microsoft.ML.HalLearners/TransformsStatic.cs b/src/Microsoft.ML.HalLearners.StaticPipe/VectorWhiteningStaticExtensions.cs similarity index 89% rename from src/Microsoft.ML.HalLearners/TransformsStatic.cs rename to 
src/Microsoft.ML.HalLearners.StaticPipe/VectorWhiteningStaticExtensions.cs index 87ea5a2c0f..97e6005bc7 100644 --- a/src/Microsoft.ML.HalLearners/TransformsStatic.cs +++ b/src/Microsoft.ML.HalLearners.StaticPipe/VectorWhiteningStaticExtensions.cs @@ -1,19 +1,19 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Collections.Generic; using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime; +using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; using Microsoft.ML.Transforms.Projections; -using System.Collections.Generic; -namespace Microsoft.ML.StaticPipe +namespace Microsoft.ML.HalLearners.StaticPipe { /// /// Extensions for statically typed Whitening estimator. /// - public static class VectorWhiteningExtensions + public static class VectorWhiteningStaticExtensions { private sealed class OutPipelineColumn : Vector { @@ -57,7 +57,7 @@ public override IEstimator Reconcile(IHostEnvironment env, } } - /// + /// /// The column to which the transform will be applied. /// Whitening constant, prevents division by zero when scaling the data by inverse of eigenvalues. /// Maximum number of rows used to train the transform. @@ -68,7 +68,7 @@ public static Vector PcaWhitening(this Vector input, int pcaNum = VectorWhiteningTransformer.Defaults.PcaNum) => new OutPipelineColumn(input, WhiteningKind.Pca, eps, maxRows, pcaNum); - /// + /// /// The column to which the transform will be applied. /// Whitening constant, prevents division by zero. /// Maximum number of rows used to train the transform. 
diff --git a/src/Microsoft.ML.HalLearners/ComputeLRTrainingStdThroughHal.cs b/src/Microsoft.ML.HalLearners/ComputeLRTrainingStdThroughHal.cs index 0a387c1e03..bc30deabfc 100644 --- a/src/Microsoft.ML.HalLearners/ComputeLRTrainingStdThroughHal.cs +++ b/src/Microsoft.ML.HalLearners/ComputeLRTrainingStdThroughHal.cs @@ -2,12 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Trainers.HalLearners; using System; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Trainers.HalLearners; -namespace Microsoft.ML.Runtime.Learners +namespace Microsoft.ML.Learners { using Mkl = OlsLinearRegressionTrainer.Mkl; diff --git a/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs b/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs index 72d163e65a..62420c9b0f 100644 --- a/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs +++ b/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs @@ -2,12 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; +using System; +using Microsoft.ML.Data; using Microsoft.ML.Trainers.HalLearners; using Microsoft.ML.Trainers.SymSgd; using Microsoft.ML.Transforms.Projections; -using System; namespace Microsoft.ML { diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs index 5f69c3e83d..1dd6fa550e 100644 --- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs +++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs @@ -2,22 +2,22 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Learners; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Training; -using Microsoft.ML.Trainers.HalLearners; using System; using System.Collections.Generic; using System.IO; using System.Runtime.InteropServices; +using System.Security; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Learners; +using Microsoft.ML.Model; +using Microsoft.ML.Trainers.HalLearners; +using Microsoft.ML.Training; [assembly: LoadableClass(OlsLinearRegressionTrainer.Summary, typeof(OlsLinearRegressionTrainer), typeof(OlsLinearRegressionTrainer.Arguments), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, @@ -25,16 +25,16 @@ OlsLinearRegressionTrainer.LoadNameValue, OlsLinearRegressionTrainer.ShortName)] -[assembly: LoadableClass(typeof(OlsLinearRegressionPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(OlsLinearRegressionModelParameters), null, typeof(SignatureLoadModel), "OLS Linear Regression Executor", - OlsLinearRegressionPredictor.LoaderSignature)] + OlsLinearRegressionModelParameters.LoaderSignature)] [assembly: LoadableClass(typeof(void), typeof(OlsLinearRegressionTrainer), null, typeof(SignatureEntryPointModule), OlsLinearRegressionTrainer.LoadNameValue)] namespace Microsoft.ML.Trainers.HalLearners { /// - public sealed class OlsLinearRegressionTrainer : TrainerEstimatorBase, OlsLinearRegressionPredictor> + public sealed class OlsLinearRegressionTrainer : 
TrainerEstimatorBase, OlsLinearRegressionModelParameters> { public sealed class Arguments : LearnerInputBaseWithWeight { @@ -87,7 +87,7 @@ public OlsLinearRegressionTrainer(IHostEnvironment env, /// internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args) : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), - TrainerUtils.MakeR4ScalarLabel(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) + TrainerUtils.MakeR4ScalarColumn(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { Host.CheckValue(args, nameof(args)); Host.CheckUserArg(args.L2Weight >= 0, nameof(args.L2Weight), "L2 regularization term cannot be negative"); @@ -110,10 +110,10 @@ private static Arguments ArgsInit(string featureColumn, return args; } - protected override RegressionPredictionTransformer MakeTransformer(OlsLinearRegressionPredictor model, Schema trainSchema) - => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + protected override RegressionPredictionTransformer MakeTransformer(OlsLinearRegressionModelParameters model, Schema trainSchema) + => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - public RegressionPredictionTransformer Train(IDataView trainData, IPredictor initialPredictor = null) + public RegressionPredictionTransformer Train(IDataView trainData, IPredictor initialPredictor = null) => TrainTransformer(trainData, initPredictor: initialPredictor); protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) @@ -134,22 +134,22 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc private static Double ProbClamp(Double p) => Math.Max(0, Math.Min(p, 1)); - private protected override OlsLinearRegressionPredictor TrainModelCore(TrainContext context) + private protected 
override OlsLinearRegressionModelParameters TrainModelCore(TrainContext context) { using (var ch = Host.Start("Training")) { ch.CheckValue(context, nameof(context)); var examples = context.TrainingSet; - ch.CheckParam(examples.Schema.Feature != null, nameof(examples), "Need a feature column"); - ch.CheckParam(examples.Schema.Label != null, nameof(examples), "Need a labelColumn column"); + ch.CheckParam(examples.Schema.Feature.HasValue, nameof(examples), "Need a feature column"); + ch.CheckParam(examples.Schema.Label.HasValue, nameof(examples), "Need a labelColumn column"); // The labelColumn type must be either Float or a key type based on int (if allowKeyLabels is true). - var typeLab = examples.Schema.Label.Type; + var typeLab = examples.Schema.Label.Value.Type; if (typeLab != NumberType.Float) throw ch.Except("Incompatible labelColumn column type {0}, must be {1}", typeLab, NumberType.Float); // The feature type must be a vector of Float. - var typeFeat = examples.Schema.Feature.Type; + var typeFeat = examples.Schema.Feature.Value.Type; if (!typeFeat.IsKnownSizeVector) throw ch.Except("Incompatible feature column type {0}, must be known sized vector of {1}", typeFeat, NumberType.Float); if (typeFeat.ItemType != NumberType.Float) @@ -161,7 +161,7 @@ private protected override OlsLinearRegressionPredictor TrainModelCore(TrainCont } } - private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount) + private OlsLinearRegressionModelParameters TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount) { Host.AssertValue(ch); ch.AssertValue(cursorFactory); @@ -292,14 +292,14 @@ private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Fac { // We would expect the solution to the problem to be exact in this case. 
ch.Info("Number of examples equals number of parameters, solution is exact but no statistics can be derived"); - return new OlsLinearRegressionPredictor(Host, in weights, bias, null, null, null, 1, float.NaN); + return new OlsLinearRegressionModelParameters(Host, in weights, bias); } Double rss = 0; // residual sum of squares Double tss = 0; // total sum of squares using (var cursor = cursorFactory.Create()) { - var lrPredictor = new LinearRegressionPredictor(Host, in weights, bias); + IValueMapper lrPredictor = new LinearRegressionModelParameters(Host, in weights, bias); var lrMap = lrPredictor.GetMapper, float>(); float yh = default; while (cursor.MoveNext()) @@ -328,7 +328,7 @@ private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Fac // Also we can't estimate it, unless we can estimate the variance, which requires more examples than // parameters. if (!_perParameterSignificance || m >= n) - return new OlsLinearRegressionPredictor(Host, in weights, bias, null, null, null, rSquared, rSquaredAdjusted); + return new OlsLinearRegressionModelParameters(Host, in weights, bias, rSquared: rSquared, rSquaredAdjusted: rSquaredAdjusted); ch.Assert(!Double.IsNaN(rSquaredAdjusted)); var standardErrors = new Double[m]; @@ -375,12 +375,12 @@ private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Fac ch.Check(0 <= pValues[i] && pValues[i] <= 1, "p-Value calculated outside expected [0,1] range"); } - return new OlsLinearRegressionPredictor(Host, in weights, bias, standardErrors, tValues, pValues, rSquared, rSquaredAdjusted); + return new OlsLinearRegressionModelParameters(Host, in weights, bias, standardErrors, tValues, pValues, rSquared, rSquaredAdjusted); } internal static class Mkl { - private const string DllName = "MklImports"; + private const string MklPath = "MklImports"; public enum Layout { @@ -394,7 +394,7 @@ public enum UpLo : byte Lo = (byte)'L' } - [DllImport(DllName, EntryPoint = "LAPACKE_dpptrf")] + [DllImport(MklPath, 
CallingConvention = CallingConvention.Cdecl, EntryPoint = "LAPACKE_dpptrf"), SuppressUnmanagedCodeSecurity] private static extern int PptrfInternal(Layout layout, UpLo uplo, int n, Double[] ap); /// @@ -429,7 +429,7 @@ public static void Pptrf(Layout layout, UpLo uplo, int n, Double[] ap) } } - [DllImport(DllName, EntryPoint = "LAPACKE_dpptrs")] + [DllImport(MklPath, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LAPACKE_dpptrs"), SuppressUnmanagedCodeSecurity] private static extern int PptrsInternal(Layout layout, UpLo uplo, int n, int nrhs, Double[] ap, Double[] b, int ldb); /// @@ -476,7 +476,7 @@ public static void Pptrs(Layout layout, UpLo uplo, int n, int nrhs, Double[] ap, } - [DllImport(DllName, EntryPoint = "LAPACKE_dpptri")] + [DllImport(MklPath, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LAPACKE_dpptri"), SuppressUnmanagedCodeSecurity] private static extern int PptriInternal(Layout layout, UpLo uplo, int n, Double[] ap); /// @@ -535,10 +535,10 @@ public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment en /// /// A linear predictor for which per parameter significance statistics are available. 
/// - public sealed class OlsLinearRegressionPredictor : RegressionPredictor + public sealed class OlsLinearRegressionModelParameters : RegressionModelParameters { - public const string LoaderSignature = "OlsLinearRegressionExec"; - public const string RegistrationName = "OlsLinearRegressionPredictor"; + internal const string LoaderSignature = "OlsLinearRegressionExec"; + internal const string RegistrationName = "OlsLinearRegressionPredictor"; /// /// Version information to be saved in binary format @@ -551,7 +551,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(OlsLinearRegressionPredictor).Assembly.FullName); + loaderAssemblyName: typeof(OlsLinearRegressionModelParameters).Assembly.FullName); } // The following will be null iff RSquaredAdjusted is NaN. @@ -564,16 +564,14 @@ private static VersionInfo GetVersionInfo() /// /// The coefficient of determination. /// - public Double RSquared - { get { return _rSquared; } } + public Double RSquared => _rSquared; /// /// The adjusted coefficient of determination. It is only possible to produce /// an adjusted R-squared if there are more examples than parameters in the model /// plus one. If this condition is not met, this value will be NaN. /// - public Double RSquaredAdjusted - { get { return _rSquaredAdjusted; } } + public Double RSquaredAdjusted => _rSquaredAdjusted; /// /// Whether the model has per parameter statistics. This is false iff @@ -584,33 +582,41 @@ public Double RSquaredAdjusted /// /// to false. /// - public bool HasStatistics - { get { return _standardErrors != null; } } + public bool HasStatistics => _standardErrors != null; /// /// The standard error per model parameter, where the first corresponds to the bias, /// and all subsequent correspond to each weight in turn. This is null if and /// only if is false. 
/// - public IReadOnlyCollection StandardErrors - { get { return _standardErrors.AsReadOnly(); } } + public IReadOnlyCollection StandardErrors => _standardErrors.AsReadOnly(); /// /// t-Statistic values corresponding to each of the model standard errors. This is /// null if and only if is false. /// - public IReadOnlyCollection TValues - { get { return _tValues.AsReadOnly(); } } + public IReadOnlyCollection TValues => _tValues.AsReadOnly(); /// /// p-values corresponding to each of the model standard errors. This is null /// if and only if is false. /// - public IReadOnlyCollection PValues - { get { return _pValues.AsReadOnly(); } } + public IReadOnlyCollection PValues => _pValues.AsReadOnly(); - internal OlsLinearRegressionPredictor(IHostEnvironment env, in VBuffer weights, float bias, - Double[] standardErrors, Double[] tValues, Double[] pValues, Double rSquared, Double rSquaredAdjusted) + /// + /// Constructs a new OLS regression model parameters from trained model. + /// + /// The Host environment. + /// The weights for the linear model. The i-th element of weights is the coefficient + /// of the i-th feature. Note that this will take ownership of the . + /// The bias added to every output score. + /// Optional: The statndard errors of the weights and bias. + /// Optional: The t-statistics for the estimates of the weights and bias. + /// Optional: The p-values of the weights and bias. + /// The coefficient of determination. + /// The adjusted coefficient of determination. 
+ public OlsLinearRegressionModelParameters(IHostEnvironment env, in VBuffer weights, float bias, + Double[] standardErrors = null, Double[] tValues = null, Double[] pValues = null, Double rSquared = 1, Double rSquaredAdjusted = float.NaN) : base(env, RegistrationName, in weights, bias) { Contracts.AssertValueOrNull(standardErrors); @@ -646,7 +652,7 @@ internal OlsLinearRegressionPredictor(IHostEnvironment env, in VBuffer we _rSquaredAdjusted = rSquaredAdjusted; } - private OlsLinearRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx) + private OlsLinearRegressionModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx) { // *** Binary format *** @@ -685,7 +691,7 @@ private OlsLinearRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx) ProbCheckDecode(_pValues[i]); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -730,15 +736,15 @@ private static void ProbCheckDecode(Double p) Contracts.CheckDecode(0 <= p && p <= 1); } - public static OlsLinearRegressionPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static OlsLinearRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new OlsLinearRegressionPredictor(env, ctx); + return new OlsLinearRegressionModelParameters(env, ctx); } - public override void SaveSummary(TextWriter writer, RoleMappedSchema schema) + private protected override void SaveSummary(TextWriter writer, RoleMappedSchema schema) { var names = default(VBuffer>); MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, Weight.Length, ref names); diff --git a/src/Microsoft.ML.HalLearners/Properties/AssemblyInfo.cs b/src/Microsoft.ML.HalLearners/Properties/AssemblyInfo.cs new file mode 
100644 index 0000000000..694c53c65f --- /dev/null +++ b/src/Microsoft.ML.HalLearners/Properties/AssemblyInfo.cs @@ -0,0 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.CompilerServices; +using Microsoft.ML; + +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.HalLearners.StaticPipe" + PublicKey.Value)] + +[assembly: WantsToBeBestFriends] diff --git a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs index 0cbe520303..cc7c07edc2 100644 --- a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs +++ b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs @@ -2,24 +2,23 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.Conversion; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Calibration; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Learners; -using Microsoft.ML.Runtime.Training; -using Microsoft.ML.Trainers.SymSgd; -using Microsoft.ML.Transforms; using System; using System.Collections.Generic; using System.Runtime.InteropServices; using System.Security; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.Data.Conversion; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Learners; +using 
Microsoft.ML.Trainers.SymSgd; +using Microsoft.ML.Training; +using Microsoft.ML.Transforms; [assembly: LoadableClass(typeof(SymSgdClassificationTrainer), typeof(SymSgdClassificationTrainer.Arguments), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, @@ -124,13 +123,12 @@ private RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedDa var roles = examples.Schema.GetColumnRoleNames(); var examplesToFeedTrain = new RoleMappedData(idvToFeedTrain, roles); - ch.AssertValue(examplesToFeedTrain.Schema.Label); - ch.AssertValue(examplesToFeedTrain.Schema.Feature); - if (examples.Schema.Weight != null) - ch.AssertValue(examplesToFeedTrain.Schema.Weight); + ch.Assert(examplesToFeedTrain.Schema.Label.HasValue); + ch.Assert(examplesToFeedTrain.Schema.Feature.HasValue); + if (examples.Schema.Weight.HasValue) + ch.Assert(examplesToFeedTrain.Schema.Weight.HasValue); - int numFeatures = examplesToFeedTrain.Schema.Feature.Type.VectorSize; - ch.Check(numFeatures > 0, "Training set has no features, aborting training."); + ch.Check(examplesToFeedTrain.Schema.Feature.Value.Type is VectorType vecType && vecType.Size > 0, "Training set has no features, aborting training."); return examplesToFeedTrain; } @@ -141,8 +139,8 @@ private protected override TPredictor TrainModelCore(TrainContext context) { var preparedData = PrepareDataFromTrainingExamples(ch, context.TrainingSet, out int weightSetCount); var initPred = context.InitialPredictor; - var linInitPred = (initPred as CalibratedPredictorBase)?.SubPredictor as LinearPredictor; - linInitPred = linInitPred ?? initPred as LinearPredictor; + var linInitPred = (initPred as CalibratedPredictorBase)?.SubPredictor as LinearModelParameters; + linInitPred = linInitPred ?? 
initPred as LinearModelParameters; Host.CheckParam(context.InitialPredictor == null || linInitPred != null, nameof(context), "Initial predictor was not a linear predictor."); return TrainCore(ch, preparedData, linInitPred, weightSetCount); @@ -173,7 +171,7 @@ public SymSgdClassificationTrainer(IHostEnvironment env, _args.FeatureColumn = featureColumn; _args.LabelColumn = labelColumn; - Info = new TrainerInfo(supportIncrementalTrain:true); + Info = new TrainerInfo(supportIncrementalTrain: true); } /// @@ -195,7 +193,7 @@ private TPredictor CreatePredictor(VBuffer weights, float bias) VBuffer maybeSparseWeights = default; VBufferUtils.CreateMaybeSparseCopy(in weights, ref maybeSparseWeights, Conversions.Instance.GetIsDefaultPredicate(NumberType.R4)); - var predictor = new LinearBinaryPredictor(Host, in maybeSparseWeights, bias); + var predictor = new LinearBinaryModelParameters(Host, in maybeSparseWeights, bias); return new ParameterMixingCalibratedPredictor(Host, predictor, new PlattCalibrator(Host, -1, 0)); } @@ -203,7 +201,7 @@ protected override BinaryPredictionTransformer MakeTransformer(TPred => new BinaryPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); public BinaryPredictionTransformer Train(IDataView trainData, TPredictor initialPredictor = null) - => TrainTransformer(trainData, initPredictor: initialPredictor); + => TrainTransformer(trainData, initPredictor: initialPredictor); protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) { @@ -635,9 +633,9 @@ public void Dispose() } } - private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearPredictor predictor, int weightSetCount) + private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearModelParameters predictor, int weightSetCount) { - int numFeatures = data.Schema.Feature.Type.VectorSize; + int numFeatures = data.Schema.Feature.Value.Type.VectorSize; var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Label | 
CursOpt.Features | CursOpt.Weight); int numThreads = 1; ch.CheckUserArg(numThreads > 0, nameof(_args.NumberOfThreads), @@ -692,7 +690,8 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearPredictor p entry => entry.SetProgress(0, state.PassIteration, _args.NumberOfIterations)); // If fully loaded, call the SymSGDNative and do not come back until learned for all iterations. Native.LearnAll(inputDataManager, tuneLR, ref lr, l2Const, piw, weightsEditor.Values, ref bias, numFeatures, - _args.NumberOfIterations, numThreads, tuneNumLocIter, ref numLocIter, _args.Tolerance, _args.Shuffle, shouldInitialize, stateGCHandle); + _args.NumberOfIterations, numThreads, tuneNumLocIter, ref numLocIter, _args.Tolerance, _args.Shuffle, shouldInitialize, + stateGCHandle, ch.Info); shouldInitialize = false; } else @@ -713,7 +712,8 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearPredictor p numPassesForThisBatch = Math.Max(1, numPassesForThisBatch); state.PassIteration = iter; Native.LearnAll(inputDataManager, tuneLR, ref lr, l2Const, piw, weightsEditor.Values, ref bias, numFeatures, - numPassesForThisBatch, numThreads, tuneNumLocIter, ref numLocIter, _args.Tolerance, _args.Shuffle, shouldInitialize, stateGCHandle); + numPassesForThisBatch, numThreads, tuneNumLocIter, ref numLocIter, _args.Tolerance, _args.Shuffle, shouldInitialize, + stateGCHandle, ch.Info); shouldInitialize = false; // Check if we are done with going through the data @@ -760,12 +760,16 @@ private static unsafe class Native //To triger the loading of MKL library since SymSGD native library depends on it. 
static Native() => ErrorMessage(0); - internal const string DllName = "SymSgdNative"; + internal const string NativePath = "SymSgdNative"; + internal const string MklPath = "MklImports"; + + public delegate void ChannelCallBack(string message); - [DllImport(DllName), SuppressUnmanagedCodeSecurity] + [DllImport(NativePath), SuppressUnmanagedCodeSecurity] private static extern void LearnAll(int totalNumInstances, int* instSizes, int** instIndices, float** instValues, float* labels, bool tuneLR, ref float lr, float l2Const, float piw, float* weightVector, ref float bias, - int numFeatres, int numPasses, int numThreads, bool tuneNumLocIter, ref int numLocIter, float tolerance, bool needShuffle, bool shouldInitialize, State* state); + int numFeatres, int numPasses, int numThreads, bool tuneNumLocIter, ref int numLocIter, float tolerance, bool needShuffle, bool shouldInitialize, + State* state, ChannelCallBack info); /// /// This method puts all of the buffered instances in array of pointers to pass it to SymSGDNative. 
@@ -786,9 +790,10 @@ private static extern void LearnAll(int totalNumInstances, int* instSizes, int** /// Specifies if data needs to be shuffled /// Specifies if this is the first time to run SymSGD /// + /// public static void LearnAll(InputDataManager inputDataManager, bool tuneLR, ref float lr, float l2Const, float piw, Span weightVector, ref float bias, int numFeatres, int numPasses, - int numThreads, bool tuneNumLocIter, ref int numLocIter, float tolerance, bool needShuffle, bool shouldInitialize, GCHandle stateGCHandle) + int numThreads, bool tuneNumLocIter, ref int numLocIter, float tolerance, bool needShuffle, bool shouldInitialize, GCHandle stateGCHandle, ChannelCallBack info) { inputDataManager.PrepareCursoring(); @@ -829,11 +834,12 @@ public static void LearnAll(InputDataManager inputDataManager, bool tuneLR, fixed (float* pInstLabels = &instLabels[0]) { LearnAll(totalNumInstances, pInstSizes, pIndicesPointer, pValuesPointer, pInstLabels, tuneLR, ref lr, l2Const, piw, - pweightVector, ref bias, numFeatres, numPasses, numThreads, tuneNumLocIter, ref numLocIter, tolerance, needShuffle, shouldInitialize, (State*)stateGCHandle.AddrOfPinnedObject()); + pweightVector, ref bias, numFeatres, numPasses, numThreads, tuneNumLocIter, ref numLocIter, tolerance, needShuffle, + shouldInitialize, (State*)stateGCHandle.AddrOfPinnedObject(), info); } } - [DllImport(DllName), SuppressUnmanagedCodeSecurity] + [DllImport(NativePath), SuppressUnmanagedCodeSecurity] private static extern void MapBackWeightVector(float* weightVector, State* state); /// @@ -847,7 +853,7 @@ public static void MapBackWeightVector(Span weightVector, GCHandle stateG MapBackWeightVector(pweightVector, (State*)stateGCHandle.AddrOfPinnedObject()); } - [DllImport(DllName), SuppressUnmanagedCodeSecurity] + [DllImport(NativePath), SuppressUnmanagedCodeSecurity] private static extern void DeallocateSequentially(State* state); public static void DeallocateSequentially(GCHandle stateGCHandle) @@ -856,8 
+862,7 @@ public static void DeallocateSequentially(GCHandle stateGCHandle) } // See: https://software.intel.com/en-us/node/521990 - [System.Security.SuppressUnmanagedCodeSecurity] - [DllImport("MklImports", EntryPoint = "DftiErrorMessage", CallingConvention = CallingConvention.Cdecl, CharSet = CharSet.Auto)] + [DllImport(MklPath, EntryPoint = "DftiErrorMessage", CallingConvention = CallingConvention.Cdecl, CharSet = CharSet.Auto), SuppressUnmanagedCodeSecurity] private static extern IntPtr ErrorMessage(int status); } diff --git a/src/Microsoft.ML.HalLearners/VectorWhitening.cs b/src/Microsoft.ML.HalLearners/VectorWhitening.cs index 9c9cd820b4..9d7daf09ec 100644 --- a/src/Microsoft.ML.HalLearners/VectorWhitening.cs +++ b/src/Microsoft.ML.HalLearners/VectorWhitening.cs @@ -2,23 +2,21 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.CpuMath; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.StaticPipe; -using Microsoft.ML.StaticPipe.Runtime; -using Microsoft.ML.Transforms.Projections; using System; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; +using System.Security; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.CpuMath; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Transforms.Projections; [assembly: LoadableClass(VectorWhiteningTransformer.Summary, typeof(IDataTransform), typeof(VectorWhiteningTransformer), 
typeof(VectorWhiteningTransformer.Arguments), typeof(SignatureDataTransform), VectorWhiteningTransformer.FriendlyName, VectorWhiteningTransformer.LoaderSignature, "Whitening")] @@ -46,6 +44,7 @@ public enum WhiteningKind /// public sealed class VectorWhiteningTransformer : OneToOneTransformerBase { + [BestFriend] internal static class Defaults { public const WhiteningKind Kind = WhiteningKind.Zca; @@ -130,7 +129,7 @@ public sealed class ColumnInfo /// Describes how the transformer handles one input-output column pair. /// /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . Null means is replaced. /// Whitening kind (PCA/ZCA). /// Whitening constant, prevents division by zero. /// Maximum number of rows used to train the transform. @@ -305,15 +304,15 @@ internal static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx => Create(env, ctx).MakeDataTransform(input); // Factory method for SignatureLoadRowMapper. - internal static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); + internal static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) => columns.Select(c => (c.Input, c.Output ?? 
c.Input)).ToArray(); - protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) + protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) { - var inType = inputSchema.GetColumnType(srcCol); + var inType = inputSchema[srcCol].Type; var reason = TestColumn(inType); if (reason != null) throw Host.ExceptParam(nameof(inputSchema), reason); @@ -322,7 +321,8 @@ protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCo // Check if the input column's type is supported. Note that only float vector with a known shape is allowed. internal static string TestColumn(ColumnType type) { - if ((type.IsVector && !type.IsKnownSizeVector && (type.AsVector.Dimensions.Length > 1)) || type.ItemType != NumberType.R4) + if ((type is VectorType vectorType && !vectorType.IsKnownSizeVector && vectorType.Dimensions.Length > 1) + || type.ItemType != NumberType.R4) return "Expected float or float vector of known size"; if ((long)type.ValueCount * type.ValueCount > Utils.ArrayMaxSize) @@ -383,7 +383,7 @@ private static void GetColTypesAndIndex(IHostEnvironment env, IDataView inputDat { if (!inputSchema.TryGetColumnIndex(columns[i].Input, out cols[i])) throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", columns[i].Input); - srcTypes[i] = inputSchema.GetColumnType(cols[i]); + srcTypes[i] = inputSchema[cols[i]].Type; var reason = TestColumn(srcTypes[i]); if (reason != null) throw env.ExceptParam(nameof(inputData.Schema), reason); @@ -585,7 +585,7 @@ public override void Save(ModelSaveContext ctx) private static class Mkl { - private const string DllName = "MklImports"; + private const string MklPath = "MklImports"; // The allowed value of Layout is specified in Intel's MLK library. See Layout parameter in this // [doc](https://software.intel.com/en-us/mkl-developer-reference-c-cblas-gemm) for details. 
@@ -624,22 +624,22 @@ public static unsafe void Gemv(Layout layout, Transpose trans, int m, int n, flo } // See: https://software.intel.com/en-us/node/520750 - [DllImport(DllName, EntryPoint = "cblas_sgemv")] + [DllImport(MklPath, CallingConvention = CallingConvention.Cdecl, EntryPoint = "cblas_sgemv"), SuppressUnmanagedCodeSecurity] private static unsafe extern void Gemv(Layout layout, Transpose trans, int m, int n, float alpha, float* a, int lda, float* x, int incx, float beta, float* y, int incy); // See: https://software.intel.com/en-us/node/520775 - [DllImport(DllName, EntryPoint = "cblas_sgemm")] + [DllImport(MklPath, CallingConvention = CallingConvention.Cdecl, EntryPoint = "cblas_sgemm"), SuppressUnmanagedCodeSecurity] public static extern void Gemm(Layout layout, Transpose transA, Transpose transB, int m, int n, int k, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int ldc); // See: https://software.intel.com/en-us/node/521150 - [DllImport(DllName, EntryPoint = "LAPACKE_sgesvd")] + [DllImport(MklPath, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LAPACKE_sgesvd"), SuppressUnmanagedCodeSecurity] public static extern int Svd(Layout layout, SvdJob jobu, SvdJob jobvt, int m, int n, float[] a, int lda, float[] s, float[] u, int ldu, float[] vt, int ldvt, float[] superb); } - protected override IRowMapper MakeRowMapper(Schema schema) + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase @@ -659,7 +659,7 @@ public Mapper(VectorWhiteningTransformer parent, Schema inputSchema) { if (!InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _cols[i])) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); - _srcTypes[i] = inputSchema.GetColumnType(_cols[i]); + _srcTypes[i] = inputSchema[_cols[i]].Type; ValidateModel(Host, _parent._models[i], _srcTypes[i]); if (_parent._columns[i].SaveInv) 
ValidateModel(Host, _parent._invModels[i], _srcTypes[i]); @@ -689,7 +689,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() return result; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -714,7 +714,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac return del; } - private ValueGetter GetSrcGetter(IRow input, int iinfo) + private ValueGetter GetSrcGetter(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -806,7 +806,7 @@ public VectorWhiteningTransformer Fit(IDataView input) public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colPair in _infos) { if (!inputSchema.TryFindColumn(colPair.Input, out var col)) diff --git a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs index fe5a83b5f6..18e5cbc61f 100644 --- a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs +++ b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs @@ -1,16 +1,16 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints; + +using Microsoft.ML; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.ImageAnalytics.EntryPoints; [assembly: LoadableClass(typeof(void), typeof(ImageAnalytics), null, typeof(SignatureEntryPointModule), "ImageAnalytics")] -namespace Microsoft.ML.Runtime.ImageAnalytics.EntryPoints +namespace Microsoft.ML.ImageAnalytics.EntryPoints { public static class ImageAnalytics { - [TlcModule.EntryPoint(Name = "Transforms.ImageLoader", Desc = ImageLoaderTransform.Summary, UserName = ImageLoaderTransform.UserName, ShortName = ImageLoaderTransform.LoaderSignature)] public static CommonOutputs.TransformOutput ImageLoader(IHostEnvironment env, ImageLoaderTransform.Arguments input) @@ -19,7 +19,7 @@ public static CommonOutputs.TransformOutput ImageLoader(IHostEnvironment env, Im var xf = ImageLoaderTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { - Model = new TransformModel(h, xf, input.Data), + Model = new TransformModelImpl(h, xf, input.Data), OutputData = xf }; } @@ -32,7 +32,7 @@ public static CommonOutputs.TransformOutput ImageResizer(IHostEnvironment env, I var xf = ImageResizerTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { - Model = new TransformModel(h, xf, input.Data), + Model = new TransformModelImpl(h, xf, input.Data), OutputData = xf }; } @@ -45,7 +45,7 @@ public static CommonOutputs.TransformOutput ImagePixelExtractor(IHostEnvironment var xf = ImagePixelExtractorTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { - Model = new TransformModel(h, xf, input.Data), + Model = new TransformModelImpl(h, xf, input.Data), OutputData = xf }; } @@ -58,7 +58,7 @@ public static CommonOutputs.TransformOutput ImageGrayscale(IHostEnvironment env, var xf = ImageGrayscaleTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() 
{ - Model = new TransformModel(h, xf, input.Data), + Model = new TransformModelImpl(h, xf, input.Data), OutputData = xf }; } @@ -71,7 +71,7 @@ public static CommonOutputs.TransformOutput VectorToImage(IHostEnvironment env, var xf = new VectorToImageTransform(h, input, input.Data); return new CommonOutputs.TransformOutput() { - Model = new TransformModel(h, xf, input.Data), + Model = new TransformModelImpl(h, xf, input.Data), OutputData = xf }; } diff --git a/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs b/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs index 0c76c08384..ebc9ee1ff3 100644 --- a/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs +++ b/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs @@ -2,9 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.ImageAnalytics; +using Microsoft.ML.Data; +using Microsoft.ML.ImageAnalytics; namespace Microsoft.ML { diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs index 7b89bdb173..5cb7642893 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs @@ -2,23 +2,22 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.ImageAnalytics; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.StaticPipe; -using Microsoft.ML.StaticPipe.Runtime; using System; using System.Collections.Generic; using System.Drawing; using System.Drawing.Imaging; using System.Linq; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.ImageAnalytics; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.StaticPipe; +using Microsoft.ML.StaticPipe.Runtime; [assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(IDataTransform), typeof(ImageGrayscaleTransform), typeof(ImageGrayscaleTransform.Arguments), typeof(SignatureDataTransform), ImageGrayscaleTransform.UserName, "ImageGrayscaleTransform", "ImageGrayscale")] @@ -32,7 +31,7 @@ [assembly: LoadableClass(typeof(IRowMapper), typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadRowMapper), ImageGrayscaleTransform.UserName, ImageGrayscaleTransform.LoaderSignature)] -namespace Microsoft.ML.Runtime.ImageAnalytics +namespace Microsoft.ML.ImageAnalytics { // REVIEW: Rewrite as LambdaTransform to simplify. // REVIEW: Should it be separate transform or part of ImageResizerTransform? @@ -127,8 +126,8 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, => Create(env, ctx).MakeDataTransform(input); // Factory method for SignatureLoadRowMapper. 
- private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); public override void Save(ModelSaveContext ctx) { @@ -152,12 +151,12 @@ public override void Save(ModelSaveContext ctx) new float[] {0, 0, 0, 0, 1} }); - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); - protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) + protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) { - if (!(inputSchema.GetColumnType(srcCol) is ImageType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, "image", inputSchema.GetColumnType(srcCol).ToString()); + if (!(inputSchema[srcCol].Type is ImageType)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, "image", inputSchema[srcCol].Type.ToString()); } private sealed class Mapper : OneToOneMapperBase @@ -173,7 +172,7 @@ public Mapper(ImageGrayscaleTransform parent, Schema inputSchema) protected override Schema.DetachedColumn[] GetOutputColumnsCore() => _parent.ColumnPairs.Select((x, idx) => new Schema.DetachedColumn(x.output, InputSchema[ColMapNewToOld[idx]].Type, null)).ToArray(); - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -227,7 +226,7 @@ public ImageGrayscalingEstimator(IHostEnvironment env, params (string input, str public override SchemaShape 
GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.input, out var col)) diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs index 81b3d339e5..a79085b606 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs @@ -2,23 +2,22 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.ImageAnalytics; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.StaticPipe; -using Microsoft.ML.StaticPipe.Runtime; using System; using System.Collections.Generic; using System.Drawing; using System.IO; using System.Linq; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.ImageAnalytics; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.StaticPipe; +using Microsoft.ML.StaticPipe.Runtime; [assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(IDataTransform), typeof(ImageLoaderTransform), typeof(ImageLoaderTransform.Arguments), typeof(SignatureDataTransform), ImageLoaderTransform.UserName, "ImageLoaderTransform", "ImageLoader")] @@ -30,7 +29,7 @@ [assembly: LoadableClass(typeof(IRowMapper), typeof(ImageLoaderTransform), null, typeof(SignatureLoadRowMapper), "", 
ImageLoaderTransform.LoaderSignature)] -namespace Microsoft.ML.Runtime.ImageAnalytics +namespace Microsoft.ML.ImageAnalytics { /// /// Transform which takes one or many columns of type ReadOnlyMemory and loads them as @@ -111,13 +110,13 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, => Create(env, ctx).MakeDataTransform(input); // Factory method for SignatureLoadRowMapper. - private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); - protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) + protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) { - if (!inputSchema.GetColumnType(srcCol).IsText) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, TextType.Instance.ToString(), inputSchema.GetColumnType(srcCol).ToString()); + if (!inputSchema[srcCol].Type.IsText) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, TextType.Instance.ToString(), inputSchema[srcCol].Type.ToString()); } public override void Save(ModelSaveContext ctx) @@ -147,7 +146,7 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(ImageLoaderTransform).Assembly.FullName); } - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase { @@ -161,7 +160,7 @@ public Mapper(ImageLoaderTransform parent, Schema inputSchema) _parent = parent; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int 
iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -230,7 +229,7 @@ public ImageLoadingEstimator(IHostEnvironment env, ImageLoaderTransform transfor public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var (input, output) in Transformer.Columns) { if (!inputSchema.TryFindColumn(input, out var col)) diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs index ceecacb0f9..dd76d71da7 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs @@ -8,15 +8,14 @@ using System.Linq; using System.Runtime.InteropServices; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.ImageAnalytics; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.ImageAnalytics; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; @@ -32,7 +31,7 @@ [assembly: LoadableClass(typeof(IRowMapper), typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadRowMapper), ImagePixelExtractorTransform.UserName, ImagePixelExtractorTransform.LoaderSignature)] -namespace Microsoft.ML.Runtime.ImageAnalytics +namespace Microsoft.ML.ImageAnalytics { /// /// Transform which takes one or many columns of and convert them into vector 
representation. @@ -378,8 +377,8 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, => Create(env, ctx).MakeDataTransform(input); // Factory method for SignatureLoadRowMapper. - private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); public override void Save(ModelSaveContext ctx) { @@ -400,14 +399,14 @@ public override void Save(ModelSaveContext ctx) info.Save(ctx); } - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); - protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) + protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) { var inputColName = _columns[col].Input; - var imageType = inputSchema.GetColumnType(srcCol) as ImageType; + var imageType = inputSchema[srcCol].Type as ImageType; if (imageType == null) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColName, "image", inputSchema.GetColumnType(srcCol).ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColName, "image", inputSchema[srcCol].Type.ToString()); if (imageType.Height <= 0 || imageType.Width <= 0) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColName, "known-size image", "unknown-size image"); if ((long)imageType.Height * imageType.Width > int.MaxValue / 4) @@ -429,7 +428,7 @@ public Mapper(ImagePixelExtractorTransform parent, Schema inputSchema) protected override Schema.DetachedColumn[] GetOutputColumnsCore() => _parent._columns.Select((x, idx) => new Schema.DetachedColumn(x.Output, _types[idx], null)).ToArray(); - protected override Delegate 
MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent._columns.Length); @@ -440,7 +439,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac } //REVIEW Rewrite it to where TValue : IConvertible - private ValueGetter> GetGetterCore(IRow input, int iinfo, out Action disposer) + private ValueGetter> GetGetterCore(Row input, int iinfo, out Action disposer) where TValue : struct { var type = _types[iinfo]; @@ -483,7 +482,9 @@ private ValueGetter> GetGetterCore(IRow input, int iinfo return; } - Host.Check(src.PixelFormat == System.Drawing.Imaging.PixelFormat.Format32bppArgb); + Host.Check(src.PixelFormat == System.Drawing.Imaging.PixelFormat.Format32bppArgb + || src.PixelFormat == System.Drawing.Imaging.PixelFormat.Format24bppRgb, + "Transform only supports pixel formats Format24bppRgb and Format32bppArgb"); Host.Check(src.Height == height && src.Width == width); var editor = VBufferEditor.Create(ref dst, size); @@ -542,28 +543,6 @@ private ValueGetter> GetGetterCore(IRow input, int iinfo else { int idstMin = 0; - if (ex.Alpha) - { - // The image only has rgb but we need to supply alpha as well, so fake it up, - // assuming that it is 0xFF. - if (!vf.IsEmpty) - { - Single v = (0xFF - offset) * scale; - for (int i = 0; i < cpix; i++) - vf[i] = v; - } - else - { - for (int i = 0; i < cpix; i++) - vb[i] = 0xFF; - } - idstMin = cpix; - - // We've preprocessed alpha, avoid it in the - // scan operation below. 
- a = false; - } - for (int y = 0; y < h; ++y) { int idstBase = idstMin + y * w; @@ -655,7 +634,7 @@ public ImagePixelExtractingEstimator(IHostEnvironment env, params ImagePixelExtr public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs index e5054b515f..3d795846ba 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs @@ -7,16 +7,15 @@ using System.Drawing; using System.Linq; using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.ImageAnalytics; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.ImageAnalytics; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; @@ -32,7 +31,7 @@ [assembly: LoadableClass(typeof(IRowMapper), typeof(ImageResizerTransform), null, typeof(SignatureLoadRowMapper), ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)] -namespace Microsoft.ML.Runtime.ImageAnalytics +namespace Microsoft.ML.ImageAnalytics { // REVIEW: Rewrite as LambdaTransform to simplify. 
/// @@ -253,8 +252,8 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, => Create(env, ctx).MakeDataTransform(input); // Factory method for SignatureLoadRowMapper. - private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); public override void Save(ModelSaveContext ctx) { @@ -285,12 +284,12 @@ public override void Save(ModelSaveContext ctx) } } - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); - protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) + protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) { - if (!(inputSchema.GetColumnType(srcCol) is ImageType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _columns[col].Input, "image", inputSchema.GetColumnType(srcCol).ToString()); + if (!(inputSchema[srcCol].Type is ImageType)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _columns[col].Input, "image", inputSchema[srcCol].Type.ToString()); } private sealed class Mapper : OneToOneMapperBase @@ -306,7 +305,7 @@ public Mapper(ImageResizerTransform parent, Schema inputSchema) protected override Schema.DetachedColumn[] GetOutputColumnsCore() => _parent._columns.Select(x => new Schema.DetachedColumn(x.Output, x.Type, null)).ToArray(); - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent._columns.Length); @@ -411,7 +410,7 @@ 
protected override Delegate MakeGetter(IRow input, int iinfo, Func ac destWidth = (int)(sourceWidth * aspect); destHeight = (int)(sourceHeight * aspect); } - dst = new Bitmap(info.Width, info.Height); + dst = new Bitmap(info.Width, info.Height, src.PixelFormat); var srcRectangle = new Rectangle(sourceX, sourceY, sourceWidth, sourceHeight); var destRectangle = new Rectangle(destX, destY, destWidth, destHeight); using (var g = Graphics.FromImage(dst)) @@ -447,7 +446,7 @@ public ImageResizingEstimator(IHostEnvironment env, ImageResizerTransform transf public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) diff --git a/src/Microsoft.ML.ImageAnalytics/ImageStaticPipe.cs b/src/Microsoft.ML.ImageAnalytics/ImageStaticPipe.cs index cf67103d86..aaf99431ee 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageStaticPipe.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageStaticPipe.cs @@ -6,7 +6,7 @@ using System.Drawing; using Microsoft.ML.StaticPipe; -namespace Microsoft.ML.Runtime.ImageAnalytics +namespace Microsoft.ML.ImageAnalytics { /// /// A type used in the generic argument to . We must simultaneously distinguish diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs index fd31302808..270e86a411 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageType.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs @@ -2,12 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; using System.Drawing; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime.ImageAnalytics +namespace Microsoft.ML.ImageAnalytics { public sealed class ImageType : StructuredType { diff --git a/src/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.csproj b/src/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.csproj index 9829bcf26f..0134a050fb 100644 --- a/src/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.csproj +++ b/src/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.csproj @@ -2,8 +2,8 @@ netstandard2.0 - Microsoft.ML.Runtime.ImageAnalytics - Microsoft.ML.Runtime.ImageAnalytics + Microsoft.ML.ImageAnalytics + Microsoft.ML.ImageAnalytics Microsoft.ML.ImageAnalytics diff --git a/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs index 5468714696..0d2ccb53da 100644 --- a/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs @@ -5,13 +5,13 @@ using System; using System.Drawing; using System.Text; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.ImageAnalytics; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.ImageAnalytics; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; [assembly: LoadableClass(VectorToImageTransform.Summary, typeof(VectorToImageTransform), typeof(VectorToImageTransform.Arguments), typeof(SignatureDataTransform), VectorToImageTransform.UserName, "VectorToImageTransform", "VectorToImage")] @@ -19,7 +19,7 @@ [assembly: 
LoadableClass(VectorToImageTransform.Summary, typeof(VectorToImageTransform), null, typeof(SignatureLoadDataTransform), VectorToImageTransform.UserName, VectorToImageTransform.LoaderSignature)] -namespace Microsoft.ML.Runtime.ImageAnalytics +namespace Microsoft.ML.ImageAnalytics { // REVIEW: Rewrite as LambdaTransform to simplify. @@ -330,7 +330,7 @@ protected override ColumnType GetColumnTypeCore(int iinfo) return _types[iinfo]; } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + protected override Delegate GetGetterCore(IChannel ch, Row input, int iinfo, out Action disposer) { Host.AssertValueOrNull(ch); Host.AssertValue(input); @@ -340,7 +340,7 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou var ex = _exes[iinfo]; bool needScale = ex.Offset != 0 || ex.Scale != 1; disposer = null; - var sourceType = Schema.GetColumnType(Infos[iinfo].Source); + var sourceType = InputSchema[Infos[iinfo].Source].Type; if (sourceType.ItemType == NumberType.R4 || sourceType.ItemType == NumberType.R8) return GetterFromType(input, iinfo, ex, needScale); else @@ -351,7 +351,7 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou } - private ValueGetter GetterFromType(IRow input, int iinfo, ColInfoEx ex, bool needScale) where TValue : IConvertible + private ValueGetter GetterFromType(Row input, int iinfo, ColInfoEx ex, bool needScale) where TValue : IConvertible { var getSrc = GetSrcGetter>(input, iinfo); var src = default(VBuffer); diff --git a/src/Microsoft.ML.KMeansClustering/KMeansCatalog.cs b/src/Microsoft.ML.KMeansClustering/KMeansCatalog.cs index c096750ae0..3fa9f16741 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansCatalog.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansCatalog.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Trainers.KMeans; using System; +using Microsoft.ML.Data; +using Microsoft.ML.Trainers.KMeans; namespace Microsoft.ML { diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs b/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs similarity index 87% rename from src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs rename to src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs index 38b5116da4..9bbf803924 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs @@ -2,33 +2,37 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - using System; +using System.Collections.Generic; using System.IO; -using Microsoft.ML.Runtime.Numeric; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model; +using Microsoft.ML.Model.Onnx; +using Microsoft.ML.Numeric; using Microsoft.ML.Trainers.KMeans; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Internal.Internallearn; -using System.Collections.Generic; +using Float = System.Single; -[assembly: LoadableClass(typeof(KMeansPredictor), null, typeof(SignatureLoadModel), - "KMeans predictor", KMeansPredictor.LoaderSignature)] +[assembly: LoadableClass(typeof(KMeansModelParameters), null, typeof(SignatureLoadModel), + "KMeans predictor", KMeansModelParameters.LoaderSignature)] namespace Microsoft.ML.Trainers.KMeans { - public sealed class KMeansPredictor : - PredictorBase>, + /// + /// + /// + /// + public sealed class KMeansModelParameters : + ModelParametersBase>, IValueMapper, 
ICanSaveInTextFormat, - ICanSaveModel, ISingleCanSaveOnnx { - public const string LoaderSignature = "KMeansPredictor"; + internal const string LoaderSignature = "KMeansPredictor"; /// /// Version information to be saved in binary format @@ -42,12 +46,16 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(KMeansPredictor).Assembly.FullName); + loaderAssemblyName: typeof(KMeansModelParameters).Assembly.FullName); } + // REVIEW: Leaving this public for now until we figure out the correct way to remove it. public override PredictionKind PredictionKind => PredictionKind.Clustering; - public ColumnType InputType { get; } - public ColumnType OutputType { get; } + + private readonly ColumnType _inputType; + private readonly ColumnType _outputType; + ColumnType IValueMapper.InputType => _inputType; + ColumnType IValueMapper.OutputType => _outputType; bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; @@ -66,7 +74,7 @@ private static VersionInfo GetVersionInfo() /// a deep copy, if false then this constructor will take ownership of the passed in centroid vectors. /// If false then the caller must take care to not use or modify the input vectors once this object /// is constructed, and should probably remove all references. 
- public KMeansPredictor(IHostEnvironment env, int k, VBuffer[] centroids, bool copyIn) + public KMeansModelParameters(IHostEnvironment env, int k, VBuffer[] centroids, bool copyIn) : base(env, LoaderSignature) { Host.CheckParam(k > 0, nameof(k), "Need at least one cluster"); @@ -92,8 +100,8 @@ public KMeansPredictor(IHostEnvironment env, int k, VBuffer[] centroids, InitPredictor(); - InputType = new VectorType(NumberType.Float, _dimensionality); - OutputType = new VectorType(NumberType.Float, _k); + _inputType = new VectorType(NumberType.Float, _dimensionality); + _outputType = new VectorType(NumberType.Float, _k); } /// @@ -101,7 +109,7 @@ public KMeansPredictor(IHostEnvironment env, int k, VBuffer[] centroids, /// /// The load context /// The host environment - private KMeansPredictor(IHostEnvironment env, ModelLoadContext ctx) + private KMeansModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature, ctx) { // *** Binary format *** @@ -134,11 +142,11 @@ private KMeansPredictor(IHostEnvironment env, ModelLoadContext ctx) InitPredictor(); - InputType = new VectorType(NumberType.Float, _dimensionality); - OutputType = new VectorType(NumberType.Float, _k); + _inputType = new VectorType(NumberType.Float, _dimensionality); + _outputType = new VectorType(NumberType.Float, _k); } - public ValueMapper GetMapper() + ValueMapper IValueMapper.GetMapper() { Host.Check(typeof(TIn) == typeof(VBuffer)); Host.Check(typeof(TOut) == typeof(VBuffer)); @@ -169,7 +177,7 @@ private void Map(in VBuffer src, Span distances) } } - public void SaveAsText(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { writer.WriteLine("K: {0}", _k); writer.WriteLine("Dimensionality: {0}", _dimensionality); @@ -215,7 +223,7 @@ public void SaveAsText(TextWriter writer, RoleMappedSchema schema) /// Save the predictor in binary format. 
/// /// The context to save to - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -247,12 +255,12 @@ protected override void SaveCore(ModelSaveContext ctx) /// /// This method is called by reflection to instantiate a predictor. /// - public static KMeansPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static KMeansModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new KMeansPredictor(env, ctx); + return new KMeansModelParameters(env, ctx); } /// diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs index 41084bd5ab..35f64309d6 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs @@ -2,21 +2,20 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Core.Data; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.CpuMath; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Numeric; -using Microsoft.ML.Runtime.Training; -using Microsoft.ML.Trainers.KMeans; using System; using System.Linq; using System.Threading.Tasks; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.CpuMath; +using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Numeric; +using Microsoft.ML.Trainers.KMeans; +using Microsoft.ML.Training; [assembly: LoadableClass(KMeansPlusPlusTrainer.Summary, typeof(KMeansPlusPlusTrainer), typeof(KMeansPlusPlusTrainer.Arguments), new[] { typeof(SignatureClusteringTrainer), typeof(SignatureTrainer) }, @@ -29,7 +28,7 @@ namespace Microsoft.ML.Trainers.KMeans { /// - public class KMeansPlusPlusTrainer : TrainerEstimatorBase, KMeansPredictor> + public class KMeansPlusPlusTrainer : TrainerEstimatorBase, KMeansModelParameters> { public const string LoadNameValue = "KMeansPlusPlus"; internal const string UserNameValue = "KMeans++ Clustering"; @@ -61,7 +60,7 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight [Argument(ArgumentType.AtMostOnce, HelpText = "Cluster initialization algorithm", ShortName = "init")] public InitAlgorithm InitAlgorithm = InitAlgorithm.KMeansParallel; - [Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance parameter for trainer convergence. Lower = slower, more accurate", + [Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance parameter for trainer convergence. 
Low = slower, more accurate", ShortName = "ot")] [TGUI(Label = "Optimization Tolerance", Description = "Threshold for trainer convergence")] public float OptTol = (float)1e-7; @@ -121,7 +120,7 @@ internal KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args) } private KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args, Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), null, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), default, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) { Host.CheckValue(args, nameof(args)); @@ -151,7 +150,7 @@ private KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args, Action MakeTransformer(KMeansPredictor model, Schema trainSchema) - => new ClusteringPredictionTransformer(Host, model, trainSchema, _featureColumn); + protected override ClusteringPredictionTransformer MakeTransformer(KMeansModelParameters model, Schema trainSchema) + => new ClusteringPredictionTransformer(Host, model, trainSchema, _featureColumn); } internal static class KMeansPlusPlusInit @@ -423,7 +422,7 @@ internal sealed class KMeansAcceleratedRowMap // each row. Instead the RowCursor provides a stable ID across multiple // cursorings. We map those IDs into an index to poke into the per instance // structures. - private readonly HashArray _parallelIndexLookup; + private readonly HashArray _parallelIndexLookup; public KMeansAcceleratedRowMap(FeatureFloatVectorCursor.Factory factory, IChannel ch, long baseMaxInstancesToAccelerate, long totalTrainingInstances, bool isParallel) @@ -469,11 +468,11 @@ public KMeansAcceleratedRowMap(FeatureFloatVectorCursor.Factory factory, IChanne /// preinitialize the HashArray so we can perform lock-free lookup operations during /// the primary KMeans pass. 
/// - private HashArray BuildParallelIndexLookup(FeatureFloatVectorCursor.Factory factory) + private HashArray BuildParallelIndexLookup(FeatureFloatVectorCursor.Factory factory) { Contracts.AssertValue(factory); - HashArray lookup = new HashArray(); + HashArray lookup = new HashArray(); int n = 0; using (var cursor = factory.Create()) { @@ -875,7 +874,7 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl KMeansUtils.ParallelMapReduce( numThreads, host, cursorFactory, initializationState.RowIndexGetter, (ref float[] weights) => weights = new float[totalSamples], - (ref VBuffer point, int pointRowIndex, float[] weights, IRandom rand) => + (ref VBuffer point, int pointRowIndex, float[] weights, Random rand) => { int bestCluster; float discardBestWeight; @@ -887,7 +886,7 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl #endif weights[bestCluster]++; }, - (float[][] workStateWeights, IRandom rand, ref float[] weights) => + (float[][] workStateWeights, Random rand, ref float[] weights) => { weights = new float[totalSamples]; for (int i = 0; i < workStateWeights.Length; i++) @@ -902,8 +901,8 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl KMeansUtils.ParallelMapReduce( numThreads, host, cursorFactory, (FeatureFloatVectorCursor cur) => -1, (ref float[] weights) => weights = new float[totalSamples], - (ref VBuffer point, int discard, float[] weights, IRandom rand) => weights[KMeansUtils.FindBestCluster(in point, clusters, clustersL2s)]++, - (float[][] workStateWeights, IRandom rand, ref float[] weights) => + (ref VBuffer point, int discard, float[] weights, Random rand) => weights[KMeansUtils.FindBestCluster(in point, clusters, clustersL2s)]++, + (float[][] workStateWeights, Random rand, ref float[] weights) => { weights = new float[totalSamples]; for (int i = 0; i < workStateWeights.Length; i++) @@ -1572,7 +1571,7 @@ public static RowStats ParallelWeightedReservoirSample( 
else heap.Clear(); }, - (ref VBuffer point, int pointRowIndex, Heap heap, IRandom rand) => + (ref VBuffer point, int pointRowIndex, Heap heap, Random rand) => { // We use distance as a proxy for 'is the same point'. By excluding // all points that lie within a very small distance of our current set of @@ -1608,7 +1607,7 @@ public static RowStats ParallelWeightedReservoirSample( Utils.Swap(ref wRow.Point, ref point); heap.Add(wRow); }, - (Heap[] heaps, IRandom rand, ref Heap finalHeap) => + (Heap[] heaps, Random rand, ref Heap finalHeap) => { host.Assert(finalHeap == null); finalHeap = new Heap((x, y) => x.Weight > y.Weight, numSamples); @@ -1647,13 +1646,13 @@ public static RowStats ParallelWeightedReservoirSample( public delegate void InitAction(ref TPartitionState val); public delegate int RowIndexGetter(FeatureFloatVectorCursor cur); - public delegate void MapAction(ref VBuffer point, int rowIndex, TPartitionState state, IRandom rand); - public delegate void ReduceAction(TPartitionState[] intermediates, IRandom rand, ref TGlobalState result); + public delegate void MapAction(ref VBuffer point, int rowIndex, TPartitionState state, Random rand); + public delegate void ReduceAction(TPartitionState[] intermediates, Random rand, ref TGlobalState result); /// /// Takes a data cursor and perform an in-memory parallel aggregation operation on it. This /// helper wraps some of the behavior common to parallel operations over a IRowCursor set, - /// including building the set, creating separate IRandom instances, and IRowCursor disposal. + /// including building the set, creating separate Random instances, and IRowCursor disposal. /// /// The type that each parallel cursor will be expected to aggregate to. /// The type of the final output from combining each per-thread instance of TInterAgg. 
@@ -1688,7 +1687,7 @@ public static RowStats ParallelMapReduce( var cur = set[i]; initChunk(ref buffer[i]); var innerWorkState = buffer[i]; - IRandom rand = RandomUtils.Create(baseHost.Rand); + Random rand = RandomUtils.Create(baseHost.Rand); workArr[i] = () => { using (cur) diff --git a/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs b/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs index f5fe8a91d9..7a7aef03ba 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs @@ -2,11 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Trainers.KMeans; -using Microsoft.ML.StaticPipe.Runtime; using System; +using Microsoft.ML.StaticPipe.Runtime; +using Microsoft.ML.Trainers.KMeans; namespace Microsoft.ML.StaticPipe { @@ -33,7 +31,7 @@ public static (Vector score, Key predictedLabel) KMeans(this Cluste Vector features, Scalar weights = null, int clustersCount = KMeansPlusPlusTrainer.Defaults.K, Action advancedSettings = null, - Action onFit = null) + Action onFit = null) { Contracts.CheckValue(features, nameof(features)); Contracts.CheckValueOrNull(weights); diff --git a/src/Microsoft.ML.Legacy/AssemblyRegistration.cs b/src/Microsoft.ML.Legacy/AssemblyRegistration.cs index 022645b069..72e2debc77 100644 --- a/src/Microsoft.ML.Legacy/AssemblyRegistration.cs +++ b/src/Microsoft.ML.Legacy/AssemblyRegistration.cs @@ -2,18 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Runtime.Data; +using System; +using System.Reflection; +using Microsoft.ML.Data; +using Microsoft.ML.Ensemble; +using Microsoft.ML.Sweeper; +using Microsoft.ML.Tools; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.KMeans; using Microsoft.ML.Trainers.PCA; -using Microsoft.ML.Runtime.Sweeper; -using Microsoft.ML.Runtime.Tools; using Microsoft.ML.Transforms.Categorical; -using System; -using System.Reflection; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML { internal static class AssemblyRegistration { @@ -38,24 +38,23 @@ public static void RegisterAssemblies(IHostEnvironment environment) /// private static bool LoadStandardAssemblies() { - Assembly apiAssembly = typeof(LambdaTransform).Assembly; // ML.Api - AssemblyName apiAssemblyName = apiAssembly.GetName(); + Assembly dataAssembly = typeof(TextLoader).Assembly; // ML.Data + AssemblyName dataAssemblyName = dataAssembly.GetName(); - _ = typeof(TextLoader).Assembly; // ML.Data - //_ = typeof(EnsemblePredictor).Assembly); // ML.Ensemble BUG https://github.com/dotnet/machinelearning/issues/1078 Ensemble isn't in a NuGet package - _ = typeof(FastTreeBinaryPredictor).Assembly; // ML.FastTree - _ = typeof(KMeansPredictor).Assembly; // ML.KMeansClustering + _ = typeof(EnsembleModelParameters).Assembly; // ML.Ensemble + _ = typeof(FastTreeBinaryModelParameters).Assembly; // ML.FastTree + _ = typeof(KMeansModelParameters).Assembly; // ML.KMeansClustering _ = typeof(Maml).Assembly; // ML.Maml - _ = typeof(PcaPredictor).Assembly; // ML.PCA + _ = typeof(PcaModelParameters).Assembly; // ML.PCA _ = typeof(SweepCommand).Assembly; // ML.Sweeper _ = typeof(OneHotEncodingTransformer).Assembly; // ML.Transforms // The following assemblies reference this assembly, so we can't directly reference them - //_ = typeof(Microsoft.ML.Runtime.Data.LinearPredictor).Assembly); // ML.StandardLearners + //_ = typeof(Microsoft.ML.Data.LinearPredictor).Assembly); // 
ML.StandardLearners _ = Assembly.Load(new AssemblyName() { Name = "Microsoft.ML.StandardLearners", - Version = apiAssemblyName.Version, //assume the same version as ML.Api + Version = dataAssemblyName.Version, //assume the same version as ML.Data }); return true; diff --git a/src/Microsoft.ML.Legacy/CSharpApi.cs b/src/Microsoft.ML.Legacy/CSharpApi.cs index c7e5c3ff89..b607908914 100644 --- a/src/Microsoft.ML.Legacy/CSharpApi.cs +++ b/src/Microsoft.ML.Legacy/CSharpApi.cs @@ -8,1967 +8,1936 @@ //------------------------------------------------------------------------------ #pragma warning disable using System.Collections.Generic; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; using Newtonsoft.Json; using System; using System.Linq; -using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.CommandLine; namespace Microsoft.ML { - namespace Runtime + public sealed partial class Experiment { - public sealed partial class Experiment + [Obsolete] + public Microsoft.ML.Legacy.Data.CustomTextLoader.Output Add(Microsoft.ML.Legacy.Data.CustomTextLoader input) { - [Obsolete] - public Microsoft.ML.Legacy.Data.CustomTextLoader.Output Add(Microsoft.ML.Legacy.Data.CustomTextLoader input) - { - var output = new Microsoft.ML.Legacy.Data.CustomTextLoader.Output(); - Add(input, output); - return output; - } - - [Obsolete] - public void Add(Microsoft.ML.Legacy.Data.CustomTextLoader input, Microsoft.ML.Legacy.Data.CustomTextLoader.Output output) - { - _jsonNodes.Add(Serialize("Data.CustomTextLoader", input, output)); - } + var output = new Microsoft.ML.Legacy.Data.CustomTextLoader.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Data.DataViewReference.Output Add(Microsoft.ML.Legacy.Data.DataViewReference input) - { - var output = new Microsoft.ML.Legacy.Data.DataViewReference.Output(); - Add(input, 
output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Data.CustomTextLoader input, Microsoft.ML.Legacy.Data.CustomTextLoader.Output output) + { + _jsonNodes.Add(Serialize("Data.CustomTextLoader", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Data.DataViewReference input, Microsoft.ML.Legacy.Data.DataViewReference.Output output) - { - _jsonNodes.Add(Serialize("Data.DataViewReference", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Data.DataViewReference.Output Add(Microsoft.ML.Legacy.Data.DataViewReference input) + { + var output = new Microsoft.ML.Legacy.Data.DataViewReference.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Data.IDataViewArrayConverter.Output Add(Microsoft.ML.Legacy.Data.IDataViewArrayConverter input) - { - var output = new Microsoft.ML.Legacy.Data.IDataViewArrayConverter.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Data.DataViewReference input, Microsoft.ML.Legacy.Data.DataViewReference.Output output) + { + _jsonNodes.Add(Serialize("Data.DataViewReference", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Data.IDataViewArrayConverter input, Microsoft.ML.Legacy.Data.IDataViewArrayConverter.Output output) - { - _jsonNodes.Add(Serialize("Data.IDataViewArrayConverter", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Data.IDataViewArrayConverter.Output Add(Microsoft.ML.Legacy.Data.IDataViewArrayConverter input) + { + var output = new Microsoft.ML.Legacy.Data.IDataViewArrayConverter.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Data.PredictorModelArrayConverter.Output Add(Microsoft.ML.Legacy.Data.PredictorModelArrayConverter input) - { - var output = new Microsoft.ML.Legacy.Data.PredictorModelArrayConverter.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void 
Add(Microsoft.ML.Legacy.Data.IDataViewArrayConverter input, Microsoft.ML.Legacy.Data.IDataViewArrayConverter.Output output) + { + _jsonNodes.Add(Serialize("Data.IDataViewArrayConverter", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Data.PredictorModelArrayConverter input, Microsoft.ML.Legacy.Data.PredictorModelArrayConverter.Output output) - { - _jsonNodes.Add(Serialize("Data.PredictorModelArrayConverter", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Data.PredictorModelArrayConverter.Output Add(Microsoft.ML.Legacy.Data.PredictorModelArrayConverter input) + { + var output = new Microsoft.ML.Legacy.Data.PredictorModelArrayConverter.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Data.TextLoader.Output Add(Microsoft.ML.Legacy.Data.TextLoader input) - { - var output = new Microsoft.ML.Legacy.Data.TextLoader.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Data.PredictorModelArrayConverter input, Microsoft.ML.Legacy.Data.PredictorModelArrayConverter.Output output) + { + _jsonNodes.Add(Serialize("Data.PredictorModelArrayConverter", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Data.TextLoader input, Microsoft.ML.Legacy.Data.TextLoader.Output output) - { - _jsonNodes.Add(Serialize("Data.TextLoader", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Data.TextLoader.Output Add(Microsoft.ML.Legacy.Data.TextLoader input) + { + var output = new Microsoft.ML.Legacy.Data.TextLoader.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Data.TransformModelArrayConverter.Output Add(Microsoft.ML.Legacy.Data.TransformModelArrayConverter input) - { - var output = new Microsoft.ML.Legacy.Data.TransformModelArrayConverter.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Data.TextLoader input, 
Microsoft.ML.Legacy.Data.TextLoader.Output output) + { + _jsonNodes.Add(Serialize("Data.TextLoader", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Data.TransformModelArrayConverter input, Microsoft.ML.Legacy.Data.TransformModelArrayConverter.Output output) - { - _jsonNodes.Add(Serialize("Data.TransformModelArrayConverter", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Models.AnomalyDetectionEvaluator.Output Add(Microsoft.ML.Legacy.Models.AnomalyDetectionEvaluator input) + { + var output = new Microsoft.ML.Legacy.Models.AnomalyDetectionEvaluator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Models.AnomalyDetectionEvaluator.Output Add(Microsoft.ML.Legacy.Models.AnomalyDetectionEvaluator input) - { - var output = new Microsoft.ML.Legacy.Models.AnomalyDetectionEvaluator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.AnomalyDetectionEvaluator input, Microsoft.ML.Legacy.Models.AnomalyDetectionEvaluator.Output output) + { + _jsonNodes.Add(Serialize("Models.AnomalyDetectionEvaluator", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.AnomalyDetectionEvaluator input, Microsoft.ML.Legacy.Models.AnomalyDetectionEvaluator.Output output) - { - _jsonNodes.Add(Serialize("Models.AnomalyDetectionEvaluator", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Models.AnomalyPipelineEnsemble.Output Add(Microsoft.ML.Legacy.Models.AnomalyPipelineEnsemble input) + { + var output = new Microsoft.ML.Legacy.Models.AnomalyPipelineEnsemble.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Models.AnomalyPipelineEnsemble.Output Add(Microsoft.ML.Legacy.Models.AnomalyPipelineEnsemble input) - { - var output = new Microsoft.ML.Legacy.Models.AnomalyPipelineEnsemble.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void 
Add(Microsoft.ML.Legacy.Models.AnomalyPipelineEnsemble input, Microsoft.ML.Legacy.Models.AnomalyPipelineEnsemble.Output output) + { + _jsonNodes.Add(Serialize("Models.AnomalyPipelineEnsemble", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.AnomalyPipelineEnsemble input, Microsoft.ML.Legacy.Models.AnomalyPipelineEnsemble.Output output) - { - _jsonNodes.Add(Serialize("Models.AnomalyPipelineEnsemble", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Models.BinaryClassificationEvaluator.Output Add(Microsoft.ML.Legacy.Models.BinaryClassificationEvaluator input) + { + var output = new Microsoft.ML.Legacy.Models.BinaryClassificationEvaluator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Models.BinaryClassificationEvaluator.Output Add(Microsoft.ML.Legacy.Models.BinaryClassificationEvaluator input) - { - var output = new Microsoft.ML.Legacy.Models.BinaryClassificationEvaluator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.BinaryClassificationEvaluator input, Microsoft.ML.Legacy.Models.BinaryClassificationEvaluator.Output output) + { + _jsonNodes.Add(Serialize("Models.BinaryClassificationEvaluator", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.BinaryClassificationEvaluator input, Microsoft.ML.Legacy.Models.BinaryClassificationEvaluator.Output output) - { - _jsonNodes.Add(Serialize("Models.BinaryClassificationEvaluator", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Models.BinaryEnsemble.Output Add(Microsoft.ML.Legacy.Models.BinaryEnsemble input) + { + var output = new Microsoft.ML.Legacy.Models.BinaryEnsemble.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Models.BinaryCrossValidator.Output Add(Microsoft.ML.Legacy.Models.BinaryCrossValidator input) - { - var output = new 
Microsoft.ML.Legacy.Models.BinaryCrossValidator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.BinaryEnsemble input, Microsoft.ML.Legacy.Models.BinaryEnsemble.Output output) + { + _jsonNodes.Add(Serialize("Models.BinaryEnsemble", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.BinaryCrossValidator input, Microsoft.ML.Legacy.Models.BinaryCrossValidator.Output output) - { - _jsonNodes.Add(Serialize("Models.BinaryCrossValidator", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Models.BinaryPipelineEnsemble.Output Add(Microsoft.ML.Legacy.Models.BinaryPipelineEnsemble input) + { + var output = new Microsoft.ML.Legacy.Models.BinaryPipelineEnsemble.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Models.BinaryEnsemble.Output Add(Microsoft.ML.Legacy.Models.BinaryEnsemble input) - { - var output = new Microsoft.ML.Legacy.Models.BinaryEnsemble.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.BinaryPipelineEnsemble input, Microsoft.ML.Legacy.Models.BinaryPipelineEnsemble.Output output) + { + _jsonNodes.Add(Serialize("Models.BinaryPipelineEnsemble", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.BinaryEnsemble input, Microsoft.ML.Legacy.Models.BinaryEnsemble.Output output) - { - _jsonNodes.Add(Serialize("Models.BinaryEnsemble", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Models.ClassificationEvaluator.Output Add(Microsoft.ML.Legacy.Models.ClassificationEvaluator input) + { + var output = new Microsoft.ML.Legacy.Models.ClassificationEvaluator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Models.BinaryPipelineEnsemble.Output Add(Microsoft.ML.Legacy.Models.BinaryPipelineEnsemble input) - { - var output = new Microsoft.ML.Legacy.Models.BinaryPipelineEnsemble.Output(); 
- Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.ClassificationEvaluator input, Microsoft.ML.Legacy.Models.ClassificationEvaluator.Output output) + { + _jsonNodes.Add(Serialize("Models.ClassificationEvaluator", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.BinaryPipelineEnsemble input, Microsoft.ML.Legacy.Models.BinaryPipelineEnsemble.Output output) - { - _jsonNodes.Add(Serialize("Models.BinaryPipelineEnsemble", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Models.ClusterEvaluator.Output Add(Microsoft.ML.Legacy.Models.ClusterEvaluator input) + { + var output = new Microsoft.ML.Legacy.Models.ClusterEvaluator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Models.ClassificationEvaluator.Output Add(Microsoft.ML.Legacy.Models.ClassificationEvaluator input) - { - var output = new Microsoft.ML.Legacy.Models.ClassificationEvaluator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.ClusterEvaluator input, Microsoft.ML.Legacy.Models.ClusterEvaluator.Output output) + { + _jsonNodes.Add(Serialize("Models.ClusterEvaluator", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.ClassificationEvaluator input, Microsoft.ML.Legacy.Models.ClassificationEvaluator.Output output) - { - _jsonNodes.Add(Serialize("Models.ClassificationEvaluator", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Models.CrossValidationResultsCombiner.Output Add(Microsoft.ML.Legacy.Models.CrossValidationResultsCombiner input) + { + var output = new Microsoft.ML.Legacy.Models.CrossValidationResultsCombiner.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Models.ClusterEvaluator.Output Add(Microsoft.ML.Legacy.Models.ClusterEvaluator input) - { - var output = new Microsoft.ML.Legacy.Models.ClusterEvaluator.Output(); - 
Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.CrossValidationResultsCombiner input, Microsoft.ML.Legacy.Models.CrossValidationResultsCombiner.Output output) + { + _jsonNodes.Add(Serialize("Models.CrossValidationResultsCombiner", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.ClusterEvaluator input, Microsoft.ML.Legacy.Models.ClusterEvaluator.Output output) - { - _jsonNodes.Add(Serialize("Models.ClusterEvaluator", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Models.CrossValidator.Output Add(Microsoft.ML.Legacy.Models.CrossValidator input) + { + var output = new Microsoft.ML.Legacy.Models.CrossValidator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Models.CrossValidationResultsCombiner.Output Add(Microsoft.ML.Legacy.Models.CrossValidationResultsCombiner input) - { - var output = new Microsoft.ML.Legacy.Models.CrossValidationResultsCombiner.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.CrossValidator input, Microsoft.ML.Legacy.Models.CrossValidator.Output output) + { + _jsonNodes.Add(Serialize("Models.CrossValidator", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.CrossValidationResultsCombiner input, Microsoft.ML.Legacy.Models.CrossValidationResultsCombiner.Output output) - { - _jsonNodes.Add(Serialize("Models.CrossValidationResultsCombiner", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Models.CrossValidatorDatasetSplitter.Output Add(Microsoft.ML.Legacy.Models.CrossValidatorDatasetSplitter input) + { + var output = new Microsoft.ML.Legacy.Models.CrossValidatorDatasetSplitter.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Models.CrossValidator.Output Add(Microsoft.ML.Legacy.Models.CrossValidator input) - { - var output = new 
Microsoft.ML.Legacy.Models.CrossValidator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.CrossValidatorDatasetSplitter input, Microsoft.ML.Legacy.Models.CrossValidatorDatasetSplitter.Output output) + { + _jsonNodes.Add(Serialize("Models.CrossValidatorDatasetSplitter", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.CrossValidator input, Microsoft.ML.Legacy.Models.CrossValidator.Output output) - { - _jsonNodes.Add(Serialize("Models.CrossValidator", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Models.DatasetTransformer.Output Add(Microsoft.ML.Legacy.Models.DatasetTransformer input) + { + var output = new Microsoft.ML.Legacy.Models.DatasetTransformer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Models.CrossValidatorDatasetSplitter.Output Add(Microsoft.ML.Legacy.Models.CrossValidatorDatasetSplitter input) - { - var output = new Microsoft.ML.Legacy.Models.CrossValidatorDatasetSplitter.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.DatasetTransformer input, Microsoft.ML.Legacy.Models.DatasetTransformer.Output output) + { + _jsonNodes.Add(Serialize("Models.DatasetTransformer", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.CrossValidatorDatasetSplitter input, Microsoft.ML.Legacy.Models.CrossValidatorDatasetSplitter.Output output) - { - _jsonNodes.Add(Serialize("Models.CrossValidatorDatasetSplitter", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Models.EnsembleSummary.Output Add(Microsoft.ML.Legacy.Models.EnsembleSummary input) + { + var output = new Microsoft.ML.Legacy.Models.EnsembleSummary.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Models.DatasetTransformer.Output Add(Microsoft.ML.Legacy.Models.DatasetTransformer input) - { - var output = new 
Microsoft.ML.Legacy.Models.DatasetTransformer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.EnsembleSummary input, Microsoft.ML.Legacy.Models.EnsembleSummary.Output output) + { + _jsonNodes.Add(Serialize("Models.EnsembleSummary", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.DatasetTransformer input, Microsoft.ML.Legacy.Models.DatasetTransformer.Output output) - { - _jsonNodes.Add(Serialize("Models.DatasetTransformer", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Models.FixedPlattCalibrator.Output Add(Microsoft.ML.Legacy.Models.FixedPlattCalibrator input) + { + var output = new Microsoft.ML.Legacy.Models.FixedPlattCalibrator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Models.EnsembleSummary.Output Add(Microsoft.ML.Legacy.Models.EnsembleSummary input) - { - var output = new Microsoft.ML.Legacy.Models.EnsembleSummary.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.FixedPlattCalibrator input, Microsoft.ML.Legacy.Models.FixedPlattCalibrator.Output output) + { + _jsonNodes.Add(Serialize("Models.FixedPlattCalibrator", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.EnsembleSummary input, Microsoft.ML.Legacy.Models.EnsembleSummary.Output output) - { - _jsonNodes.Add(Serialize("Models.EnsembleSummary", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Models.MultiClassPipelineEnsemble.Output Add(Microsoft.ML.Legacy.Models.MultiClassPipelineEnsemble input) + { + var output = new Microsoft.ML.Legacy.Models.MultiClassPipelineEnsemble.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Models.FixedPlattCalibrator.Output Add(Microsoft.ML.Legacy.Models.FixedPlattCalibrator input) - { - var output = new Microsoft.ML.Legacy.Models.FixedPlattCalibrator.Output(); - 
Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.MultiClassPipelineEnsemble input, Microsoft.ML.Legacy.Models.MultiClassPipelineEnsemble.Output output) + { + _jsonNodes.Add(Serialize("Models.MultiClassPipelineEnsemble", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.FixedPlattCalibrator input, Microsoft.ML.Legacy.Models.FixedPlattCalibrator.Output output) - { - _jsonNodes.Add(Serialize("Models.FixedPlattCalibrator", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Models.MultiOutputRegressionEvaluator.Output Add(Microsoft.ML.Legacy.Models.MultiOutputRegressionEvaluator input) + { + var output = new Microsoft.ML.Legacy.Models.MultiOutputRegressionEvaluator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Models.MultiClassPipelineEnsemble.Output Add(Microsoft.ML.Legacy.Models.MultiClassPipelineEnsemble input) - { - var output = new Microsoft.ML.Legacy.Models.MultiClassPipelineEnsemble.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.MultiOutputRegressionEvaluator input, Microsoft.ML.Legacy.Models.MultiOutputRegressionEvaluator.Output output) + { + _jsonNodes.Add(Serialize("Models.MultiOutputRegressionEvaluator", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.MultiClassPipelineEnsemble input, Microsoft.ML.Legacy.Models.MultiClassPipelineEnsemble.Output output) - { - _jsonNodes.Add(Serialize("Models.MultiClassPipelineEnsemble", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Models.NaiveCalibrator.Output Add(Microsoft.ML.Legacy.Models.NaiveCalibrator input) + { + var output = new Microsoft.ML.Legacy.Models.NaiveCalibrator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Models.MultiOutputRegressionEvaluator.Output Add(Microsoft.ML.Legacy.Models.MultiOutputRegressionEvaluator 
input) - { - var output = new Microsoft.ML.Legacy.Models.MultiOutputRegressionEvaluator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.NaiveCalibrator input, Microsoft.ML.Legacy.Models.NaiveCalibrator.Output output) + { + _jsonNodes.Add(Serialize("Models.NaiveCalibrator", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.MultiOutputRegressionEvaluator input, Microsoft.ML.Legacy.Models.MultiOutputRegressionEvaluator.Output output) - { - _jsonNodes.Add(Serialize("Models.MultiOutputRegressionEvaluator", input, output)); - } - - [Obsolete] - public Microsoft.ML.Legacy.Models.NaiveCalibrator.Output Add(Microsoft.ML.Legacy.Models.NaiveCalibrator input) - { - var output = new Microsoft.ML.Legacy.Models.NaiveCalibrator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Models.OneVersusAll.Output Add(Microsoft.ML.Legacy.Models.OneVersusAll input) + { + var output = new Microsoft.ML.Legacy.Models.OneVersusAll.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.NaiveCalibrator input, Microsoft.ML.Legacy.Models.NaiveCalibrator.Output output) - { - _jsonNodes.Add(Serialize("Models.NaiveCalibrator", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.OneVersusAll input, Microsoft.ML.Legacy.Models.OneVersusAll.Output output) + { + _jsonNodes.Add(Serialize("Models.OneVersusAll", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Models.OneVersusAll.Output Add(Microsoft.ML.Legacy.Models.OneVersusAll input) - { - var output = new Microsoft.ML.Legacy.Models.OneVersusAll.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Models.OnnxConverter.Output Add(Microsoft.ML.Legacy.Models.OnnxConverter input) + { + var output = new Microsoft.ML.Legacy.Models.OnnxConverter.Output(); + Add(input, output); + return output; 
+ } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.OneVersusAll input, Microsoft.ML.Legacy.Models.OneVersusAll.Output output) - { - _jsonNodes.Add(Serialize("Models.OneVersusAll", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.OnnxConverter input, Microsoft.ML.Legacy.Models.OnnxConverter.Output output) + { + _jsonNodes.Add(Serialize("Models.OnnxConverter", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Models.OnnxConverter.Output Add(Microsoft.ML.Legacy.Models.OnnxConverter input) - { - var output = new Microsoft.ML.Legacy.Models.OnnxConverter.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Models.OvaModelCombiner.Output Add(Microsoft.ML.Legacy.Models.OvaModelCombiner input) + { + var output = new Microsoft.ML.Legacy.Models.OvaModelCombiner.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.OnnxConverter input, Microsoft.ML.Legacy.Models.OnnxConverter.Output output) - { - _jsonNodes.Add(Serialize("Models.OnnxConverter", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.OvaModelCombiner input, Microsoft.ML.Legacy.Models.OvaModelCombiner.Output output) + { + _jsonNodes.Add(Serialize("Models.OvaModelCombiner", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Models.OvaModelCombiner.Output Add(Microsoft.ML.Legacy.Models.OvaModelCombiner input) - { - var output = new Microsoft.ML.Legacy.Models.OvaModelCombiner.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Models.PAVCalibrator.Output Add(Microsoft.ML.Legacy.Models.PAVCalibrator input) + { + var output = new Microsoft.ML.Legacy.Models.PAVCalibrator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.OvaModelCombiner input, Microsoft.ML.Legacy.Models.OvaModelCombiner.Output output) - { - 
_jsonNodes.Add(Serialize("Models.OvaModelCombiner", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.PAVCalibrator input, Microsoft.ML.Legacy.Models.PAVCalibrator.Output output) + { + _jsonNodes.Add(Serialize("Models.PAVCalibrator", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Models.PAVCalibrator.Output Add(Microsoft.ML.Legacy.Models.PAVCalibrator input) - { - var output = new Microsoft.ML.Legacy.Models.PAVCalibrator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Models.PlattCalibrator.Output Add(Microsoft.ML.Legacy.Models.PlattCalibrator input) + { + var output = new Microsoft.ML.Legacy.Models.PlattCalibrator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.PAVCalibrator input, Microsoft.ML.Legacy.Models.PAVCalibrator.Output output) - { - _jsonNodes.Add(Serialize("Models.PAVCalibrator", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.PlattCalibrator input, Microsoft.ML.Legacy.Models.PlattCalibrator.Output output) + { + _jsonNodes.Add(Serialize("Models.PlattCalibrator", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Models.PlattCalibrator.Output Add(Microsoft.ML.Legacy.Models.PlattCalibrator input) - { - var output = new Microsoft.ML.Legacy.Models.PlattCalibrator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Models.QuantileRegressionEvaluator.Output Add(Microsoft.ML.Legacy.Models.QuantileRegressionEvaluator input) + { + var output = new Microsoft.ML.Legacy.Models.QuantileRegressionEvaluator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.PlattCalibrator input, Microsoft.ML.Legacy.Models.PlattCalibrator.Output output) - { - _jsonNodes.Add(Serialize("Models.PlattCalibrator", input, output)); - } + [Obsolete] + public void 
Add(Microsoft.ML.Legacy.Models.QuantileRegressionEvaluator input, Microsoft.ML.Legacy.Models.QuantileRegressionEvaluator.Output output) + { + _jsonNodes.Add(Serialize("Models.QuantileRegressionEvaluator", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Models.QuantileRegressionEvaluator.Output Add(Microsoft.ML.Legacy.Models.QuantileRegressionEvaluator input) - { - var output = new Microsoft.ML.Legacy.Models.QuantileRegressionEvaluator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Models.RankerEvaluator.Output Add(Microsoft.ML.Legacy.Models.RankerEvaluator input) + { + var output = new Microsoft.ML.Legacy.Models.RankerEvaluator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.QuantileRegressionEvaluator input, Microsoft.ML.Legacy.Models.QuantileRegressionEvaluator.Output output) - { - _jsonNodes.Add(Serialize("Models.QuantileRegressionEvaluator", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.RankerEvaluator input, Microsoft.ML.Legacy.Models.RankerEvaluator.Output output) + { + _jsonNodes.Add(Serialize("Models.RankerEvaluator", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Models.RankerEvaluator.Output Add(Microsoft.ML.Legacy.Models.RankerEvaluator input) - { - var output = new Microsoft.ML.Legacy.Models.RankerEvaluator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Models.RegressionEnsemble.Output Add(Microsoft.ML.Legacy.Models.RegressionEnsemble input) + { + var output = new Microsoft.ML.Legacy.Models.RegressionEnsemble.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.RankerEvaluator input, Microsoft.ML.Legacy.Models.RankerEvaluator.Output output) - { - _jsonNodes.Add(Serialize("Models.RankerEvaluator", input, output)); - } + [Obsolete] + public void 
Add(Microsoft.ML.Legacy.Models.RegressionEnsemble input, Microsoft.ML.Legacy.Models.RegressionEnsemble.Output output) + { + _jsonNodes.Add(Serialize("Models.RegressionEnsemble", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Models.RegressionEnsemble.Output Add(Microsoft.ML.Legacy.Models.RegressionEnsemble input) - { - var output = new Microsoft.ML.Legacy.Models.RegressionEnsemble.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Models.RegressionEvaluator.Output Add(Microsoft.ML.Legacy.Models.RegressionEvaluator input) + { + var output = new Microsoft.ML.Legacy.Models.RegressionEvaluator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.RegressionEnsemble input, Microsoft.ML.Legacy.Models.RegressionEnsemble.Output output) - { - _jsonNodes.Add(Serialize("Models.RegressionEnsemble", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.RegressionEvaluator input, Microsoft.ML.Legacy.Models.RegressionEvaluator.Output output) + { + _jsonNodes.Add(Serialize("Models.RegressionEvaluator", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Models.RegressionEvaluator.Output Add(Microsoft.ML.Legacy.Models.RegressionEvaluator input) - { - var output = new Microsoft.ML.Legacy.Models.RegressionEvaluator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Models.RegressionPipelineEnsemble.Output Add(Microsoft.ML.Legacy.Models.RegressionPipelineEnsemble input) + { + var output = new Microsoft.ML.Legacy.Models.RegressionPipelineEnsemble.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.RegressionEvaluator input, Microsoft.ML.Legacy.Models.RegressionEvaluator.Output output) - { - _jsonNodes.Add(Serialize("Models.RegressionEvaluator", input, output)); - } + [Obsolete] + public void 
Add(Microsoft.ML.Legacy.Models.RegressionPipelineEnsemble input, Microsoft.ML.Legacy.Models.RegressionPipelineEnsemble.Output output) + { + _jsonNodes.Add(Serialize("Models.RegressionPipelineEnsemble", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Models.RegressionPipelineEnsemble.Output Add(Microsoft.ML.Legacy.Models.RegressionPipelineEnsemble input) - { - var output = new Microsoft.ML.Legacy.Models.RegressionPipelineEnsemble.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Models.Summarizer.Output Add(Microsoft.ML.Legacy.Models.Summarizer input) + { + var output = new Microsoft.ML.Legacy.Models.Summarizer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.RegressionPipelineEnsemble input, Microsoft.ML.Legacy.Models.RegressionPipelineEnsemble.Output output) - { - _jsonNodes.Add(Serialize("Models.RegressionPipelineEnsemble", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.Summarizer input, Microsoft.ML.Legacy.Models.Summarizer.Output output) + { + _jsonNodes.Add(Serialize("Models.Summarizer", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Models.Summarizer.Output Add(Microsoft.ML.Legacy.Models.Summarizer input) - { - var output = new Microsoft.ML.Legacy.Models.Summarizer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Models.TrainTestEvaluator.Output Add(Microsoft.ML.Legacy.Models.TrainTestEvaluator input) + { + var output = new Microsoft.ML.Legacy.Models.TrainTestEvaluator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.Summarizer input, Microsoft.ML.Legacy.Models.Summarizer.Output output) - { - _jsonNodes.Add(Serialize("Models.Summarizer", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.TrainTestEvaluator input, 
Microsoft.ML.Legacy.Models.TrainTestEvaluator.Output output) + { + _jsonNodes.Add(Serialize("Models.TrainTestEvaluator", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Models.TrainTestBinaryEvaluator.Output Add(Microsoft.ML.Legacy.Models.TrainTestBinaryEvaluator input) - { - var output = new Microsoft.ML.Legacy.Models.TrainTestBinaryEvaluator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.ExponentialAverage.Output Add(Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.ExponentialAverage input) + { + var output = new Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.ExponentialAverage.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.TrainTestBinaryEvaluator input, Microsoft.ML.Legacy.Models.TrainTestBinaryEvaluator.Output output) - { - _jsonNodes.Add(Serialize("Models.TrainTestBinaryEvaluator", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.ExponentialAverage input, Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.ExponentialAverage.Output output) + { + _jsonNodes.Add(Serialize("TimeSeriesProcessingEntryPoints.ExponentialAverage", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Models.TrainTestEvaluator.Output Add(Microsoft.ML.Legacy.Models.TrainTestEvaluator input) - { - var output = new Microsoft.ML.Legacy.Models.TrainTestEvaluator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.IidChangePointDetector.Output Add(Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.IidChangePointDetector input) + { + var output = new Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.IidChangePointDetector.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.TrainTestEvaluator input, 
Microsoft.ML.Legacy.Models.TrainTestEvaluator.Output output) - { - _jsonNodes.Add(Serialize("Models.TrainTestEvaluator", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.IidChangePointDetector input, Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.IidChangePointDetector.Output output) + { + _jsonNodes.Add(Serialize("TimeSeriesProcessingEntryPoints.IidChangePointDetector", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.TimeSeriesProcessing.ExponentialAverage.Output Add(Microsoft.ML.Legacy.TimeSeriesProcessing.ExponentialAverage input) - { - var output = new Microsoft.ML.Legacy.TimeSeriesProcessing.ExponentialAverage.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.IidSpikeDetector.Output Add(Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.IidSpikeDetector input) + { + var output = new Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.IidSpikeDetector.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.TimeSeriesProcessing.ExponentialAverage input, Microsoft.ML.Legacy.TimeSeriesProcessing.ExponentialAverage.Output output) - { - _jsonNodes.Add(Serialize("TimeSeriesProcessing.ExponentialAverage", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.IidSpikeDetector input, Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.IidSpikeDetector.Output output) + { + _jsonNodes.Add(Serialize("TimeSeriesProcessingEntryPoints.IidSpikeDetector", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.TimeSeriesProcessing.IidChangePointDetector.Output Add(Microsoft.ML.Legacy.TimeSeriesProcessing.IidChangePointDetector input) - { - var output = new Microsoft.ML.Legacy.TimeSeriesProcessing.IidChangePointDetector.Output(); - Add(input, output); - return output; - } + [Obsolete] + public 
Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.PercentileThresholdTransform.Output Add(Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.PercentileThresholdTransform input) + { + var output = new Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.PercentileThresholdTransform.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.TimeSeriesProcessing.IidChangePointDetector input, Microsoft.ML.Legacy.TimeSeriesProcessing.IidChangePointDetector.Output output) - { - _jsonNodes.Add(Serialize("TimeSeriesProcessing.IidChangePointDetector", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.PercentileThresholdTransform input, Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.PercentileThresholdTransform.Output output) + { + _jsonNodes.Add(Serialize("TimeSeriesProcessingEntryPoints.PercentileThresholdTransform", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.TimeSeriesProcessing.IidSpikeDetector.Output Add(Microsoft.ML.Legacy.TimeSeriesProcessing.IidSpikeDetector input) - { - var output = new Microsoft.ML.Legacy.TimeSeriesProcessing.IidSpikeDetector.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.PValueTransform.Output Add(Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.PValueTransform input) + { + var output = new Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.PValueTransform.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.TimeSeriesProcessing.IidSpikeDetector input, Microsoft.ML.Legacy.TimeSeriesProcessing.IidSpikeDetector.Output output) - { - _jsonNodes.Add(Serialize("TimeSeriesProcessing.IidSpikeDetector", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.PValueTransform input, 
Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.PValueTransform.Output output) + { + _jsonNodes.Add(Serialize("TimeSeriesProcessingEntryPoints.PValueTransform", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.TimeSeriesProcessing.PercentileThresholdTransform.Output Add(Microsoft.ML.Legacy.TimeSeriesProcessing.PercentileThresholdTransform input) - { - var output = new Microsoft.ML.Legacy.TimeSeriesProcessing.PercentileThresholdTransform.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.SlidingWindowTransform.Output Add(Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.SlidingWindowTransform input) + { + var output = new Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.SlidingWindowTransform.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.TimeSeriesProcessing.PercentileThresholdTransform input, Microsoft.ML.Legacy.TimeSeriesProcessing.PercentileThresholdTransform.Output output) - { - _jsonNodes.Add(Serialize("TimeSeriesProcessing.PercentileThresholdTransform", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.SlidingWindowTransform input, Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.SlidingWindowTransform.Output output) + { + _jsonNodes.Add(Serialize("TimeSeriesProcessingEntryPoints.SlidingWindowTransform", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.TimeSeriesProcessing.PValueTransform.Output Add(Microsoft.ML.Legacy.TimeSeriesProcessing.PValueTransform input) - { - var output = new Microsoft.ML.Legacy.TimeSeriesProcessing.PValueTransform.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.SsaChangePointDetector.Output Add(Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.SsaChangePointDetector input) + { + var output = new 
Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.SsaChangePointDetector.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.TimeSeriesProcessing.PValueTransform input, Microsoft.ML.Legacy.TimeSeriesProcessing.PValueTransform.Output output) - { - _jsonNodes.Add(Serialize("TimeSeriesProcessing.PValueTransform", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.SsaChangePointDetector input, Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.SsaChangePointDetector.Output output) + { + _jsonNodes.Add(Serialize("TimeSeriesProcessingEntryPoints.SsaChangePointDetector", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.TimeSeriesProcessing.SlidingWindowTransform.Output Add(Microsoft.ML.Legacy.TimeSeriesProcessing.SlidingWindowTransform input) - { - var output = new Microsoft.ML.Legacy.TimeSeriesProcessing.SlidingWindowTransform.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.SsaSpikeDetector.Output Add(Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.SsaSpikeDetector input) + { + var output = new Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.SsaSpikeDetector.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.TimeSeriesProcessing.SlidingWindowTransform input, Microsoft.ML.Legacy.TimeSeriesProcessing.SlidingWindowTransform.Output output) - { - _jsonNodes.Add(Serialize("TimeSeriesProcessing.SlidingWindowTransform", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.SsaSpikeDetector input, Microsoft.ML.Legacy.TimeSeriesProcessingEntryPoints.SsaSpikeDetector.Output output) + { + _jsonNodes.Add(Serialize("TimeSeriesProcessingEntryPoints.SsaSpikeDetector", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.TimeSeriesProcessing.SsaChangePointDetector.Output 
Add(Microsoft.ML.Legacy.TimeSeriesProcessing.SsaChangePointDetector input) - { - var output = new Microsoft.ML.Legacy.TimeSeriesProcessing.SsaChangePointDetector.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.AveragedPerceptronBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.AveragedPerceptronBinaryClassifier input) + { + var output = new Microsoft.ML.Legacy.Trainers.AveragedPerceptronBinaryClassifier.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.TimeSeriesProcessing.SsaChangePointDetector input, Microsoft.ML.Legacy.TimeSeriesProcessing.SsaChangePointDetector.Output output) - { - _jsonNodes.Add(Serialize("TimeSeriesProcessing.SsaChangePointDetector", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.AveragedPerceptronBinaryClassifier input, Microsoft.ML.Legacy.Trainers.AveragedPerceptronBinaryClassifier.Output output) + { + _jsonNodes.Add(Serialize("Trainers.AveragedPerceptronBinaryClassifier", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.TimeSeriesProcessing.SsaSpikeDetector.Output Add(Microsoft.ML.Legacy.TimeSeriesProcessing.SsaSpikeDetector input) - { - var output = new Microsoft.ML.Legacy.TimeSeriesProcessing.SsaSpikeDetector.Output(); - Add(input, output); - return output; - } - - [Obsolete] - public void Add(Microsoft.ML.Legacy.TimeSeriesProcessing.SsaSpikeDetector input, Microsoft.ML.Legacy.TimeSeriesProcessing.SsaSpikeDetector.Output output) - { - _jsonNodes.Add(Serialize("TimeSeriesProcessing.SsaSpikeDetector", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.EnsembleBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.EnsembleBinaryClassifier input) + { + var output = new Microsoft.ML.Legacy.Trainers.EnsembleBinaryClassifier.Output(); + Add(input, output); + return output; + } - [Obsolete] - public 
Microsoft.ML.Legacy.Trainers.AveragedPerceptronBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.AveragedPerceptronBinaryClassifier input) - { - var output = new Microsoft.ML.Legacy.Trainers.AveragedPerceptronBinaryClassifier.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.EnsembleBinaryClassifier input, Microsoft.ML.Legacy.Trainers.EnsembleBinaryClassifier.Output output) + { + _jsonNodes.Add(Serialize("Trainers.EnsembleBinaryClassifier", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.AveragedPerceptronBinaryClassifier input, Microsoft.ML.Legacy.Trainers.AveragedPerceptronBinaryClassifier.Output output) - { - _jsonNodes.Add(Serialize("Trainers.AveragedPerceptronBinaryClassifier", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.EnsembleClassification.Output Add(Microsoft.ML.Legacy.Trainers.EnsembleClassification input) + { + var output = new Microsoft.ML.Legacy.Trainers.EnsembleClassification.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.EnsembleBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.EnsembleBinaryClassifier input) - { - var output = new Microsoft.ML.Legacy.Trainers.EnsembleBinaryClassifier.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.EnsembleClassification input, Microsoft.ML.Legacy.Trainers.EnsembleClassification.Output output) + { + _jsonNodes.Add(Serialize("Trainers.EnsembleClassification", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.EnsembleBinaryClassifier input, Microsoft.ML.Legacy.Trainers.EnsembleBinaryClassifier.Output output) - { - _jsonNodes.Add(Serialize("Trainers.EnsembleBinaryClassifier", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.EnsembleRegression.Output Add(Microsoft.ML.Legacy.Trainers.EnsembleRegression input) + { 
+ var output = new Microsoft.ML.Legacy.Trainers.EnsembleRegression.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.EnsembleClassification.Output Add(Microsoft.ML.Legacy.Trainers.EnsembleClassification input) - { - var output = new Microsoft.ML.Legacy.Trainers.EnsembleClassification.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.EnsembleRegression input, Microsoft.ML.Legacy.Trainers.EnsembleRegression.Output output) + { + _jsonNodes.Add(Serialize("Trainers.EnsembleRegression", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.EnsembleClassification input, Microsoft.ML.Legacy.Trainers.EnsembleClassification.Output output) - { - _jsonNodes.Add(Serialize("Trainers.EnsembleClassification", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.FastForestBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.FastForestBinaryClassifier input) + { + var output = new Microsoft.ML.Legacy.Trainers.FastForestBinaryClassifier.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.EnsembleRegression.Output Add(Microsoft.ML.Legacy.Trainers.EnsembleRegression input) - { - var output = new Microsoft.ML.Legacy.Trainers.EnsembleRegression.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.FastForestBinaryClassifier input, Microsoft.ML.Legacy.Trainers.FastForestBinaryClassifier.Output output) + { + _jsonNodes.Add(Serialize("Trainers.FastForestBinaryClassifier", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.EnsembleRegression input, Microsoft.ML.Legacy.Trainers.EnsembleRegression.Output output) - { - _jsonNodes.Add(Serialize("Trainers.EnsembleRegression", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.FastForestRegressor.Output 
Add(Microsoft.ML.Legacy.Trainers.FastForestRegressor input) + { + var output = new Microsoft.ML.Legacy.Trainers.FastForestRegressor.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.FastForestBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.FastForestBinaryClassifier input) - { - var output = new Microsoft.ML.Legacy.Trainers.FastForestBinaryClassifier.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.FastForestRegressor input, Microsoft.ML.Legacy.Trainers.FastForestRegressor.Output output) + { + _jsonNodes.Add(Serialize("Trainers.FastForestRegressor", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.FastForestBinaryClassifier input, Microsoft.ML.Legacy.Trainers.FastForestBinaryClassifier.Output output) - { - _jsonNodes.Add(Serialize("Trainers.FastForestBinaryClassifier", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.FastTreeBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.FastTreeBinaryClassifier input) + { + var output = new Microsoft.ML.Legacy.Trainers.FastTreeBinaryClassifier.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.FastForestRegressor.Output Add(Microsoft.ML.Legacy.Trainers.FastForestRegressor input) - { - var output = new Microsoft.ML.Legacy.Trainers.FastForestRegressor.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.FastTreeBinaryClassifier input, Microsoft.ML.Legacy.Trainers.FastTreeBinaryClassifier.Output output) + { + _jsonNodes.Add(Serialize("Trainers.FastTreeBinaryClassifier", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.FastForestRegressor input, Microsoft.ML.Legacy.Trainers.FastForestRegressor.Output output) - { - _jsonNodes.Add(Serialize("Trainers.FastForestRegressor", input, output)); - } + [Obsolete] 
+ public Microsoft.ML.Legacy.Trainers.FastTreeRanker.Output Add(Microsoft.ML.Legacy.Trainers.FastTreeRanker input) + { + var output = new Microsoft.ML.Legacy.Trainers.FastTreeRanker.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.FastTreeBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.FastTreeBinaryClassifier input) - { - var output = new Microsoft.ML.Legacy.Trainers.FastTreeBinaryClassifier.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.FastTreeRanker input, Microsoft.ML.Legacy.Trainers.FastTreeRanker.Output output) + { + _jsonNodes.Add(Serialize("Trainers.FastTreeRanker", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.FastTreeBinaryClassifier input, Microsoft.ML.Legacy.Trainers.FastTreeBinaryClassifier.Output output) - { - _jsonNodes.Add(Serialize("Trainers.FastTreeBinaryClassifier", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.FastTreeRegressor.Output Add(Microsoft.ML.Legacy.Trainers.FastTreeRegressor input) + { + var output = new Microsoft.ML.Legacy.Trainers.FastTreeRegressor.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.FastTreeRanker.Output Add(Microsoft.ML.Legacy.Trainers.FastTreeRanker input) - { - var output = new Microsoft.ML.Legacy.Trainers.FastTreeRanker.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.FastTreeRegressor input, Microsoft.ML.Legacy.Trainers.FastTreeRegressor.Output output) + { + _jsonNodes.Add(Serialize("Trainers.FastTreeRegressor", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.FastTreeRanker input, Microsoft.ML.Legacy.Trainers.FastTreeRanker.Output output) - { - _jsonNodes.Add(Serialize("Trainers.FastTreeRanker", input, output)); - } + [Obsolete] + public 
Microsoft.ML.Legacy.Trainers.FastTreeTweedieRegressor.Output Add(Microsoft.ML.Legacy.Trainers.FastTreeTweedieRegressor input) + { + var output = new Microsoft.ML.Legacy.Trainers.FastTreeTweedieRegressor.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.FastTreeRegressor.Output Add(Microsoft.ML.Legacy.Trainers.FastTreeRegressor input) - { - var output = new Microsoft.ML.Legacy.Trainers.FastTreeRegressor.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.FastTreeTweedieRegressor input, Microsoft.ML.Legacy.Trainers.FastTreeTweedieRegressor.Output output) + { + _jsonNodes.Add(Serialize("Trainers.FastTreeTweedieRegressor", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.FastTreeRegressor input, Microsoft.ML.Legacy.Trainers.FastTreeRegressor.Output output) - { - _jsonNodes.Add(Serialize("Trainers.FastTreeRegressor", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.FieldAwareFactorizationMachineBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.FieldAwareFactorizationMachineBinaryClassifier input) + { + var output = new Microsoft.ML.Legacy.Trainers.FieldAwareFactorizationMachineBinaryClassifier.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.FastTreeTweedieRegressor.Output Add(Microsoft.ML.Legacy.Trainers.FastTreeTweedieRegressor input) - { - var output = new Microsoft.ML.Legacy.Trainers.FastTreeTweedieRegressor.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.FieldAwareFactorizationMachineBinaryClassifier input, Microsoft.ML.Legacy.Trainers.FieldAwareFactorizationMachineBinaryClassifier.Output output) + { + _jsonNodes.Add(Serialize("Trainers.FieldAwareFactorizationMachineBinaryClassifier", input, output)); + } - [Obsolete] - public void 
Add(Microsoft.ML.Legacy.Trainers.FastTreeTweedieRegressor input, Microsoft.ML.Legacy.Trainers.FastTreeTweedieRegressor.Output output) - { - _jsonNodes.Add(Serialize("Trainers.FastTreeTweedieRegressor", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelBinaryClassifier input) + { + var output = new Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelBinaryClassifier.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.FieldAwareFactorizationMachineBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.FieldAwareFactorizationMachineBinaryClassifier input) - { - var output = new Microsoft.ML.Legacy.Trainers.FieldAwareFactorizationMachineBinaryClassifier.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelBinaryClassifier input, Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelBinaryClassifier.Output output) + { + _jsonNodes.Add(Serialize("Trainers.GeneralizedAdditiveModelBinaryClassifier", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.FieldAwareFactorizationMachineBinaryClassifier input, Microsoft.ML.Legacy.Trainers.FieldAwareFactorizationMachineBinaryClassifier.Output output) - { - _jsonNodes.Add(Serialize("Trainers.FieldAwareFactorizationMachineBinaryClassifier", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelRegressor.Output Add(Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelRegressor input) + { + var output = new Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelRegressor.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelBinaryClassifier input) - { - 
var output = new Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelBinaryClassifier.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelRegressor input, Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelRegressor.Output output) + { + _jsonNodes.Add(Serialize("Trainers.GeneralizedAdditiveModelRegressor", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelBinaryClassifier input, Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelBinaryClassifier.Output output) - { - _jsonNodes.Add(Serialize("Trainers.GeneralizedAdditiveModelBinaryClassifier", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.KMeansPlusPlusClusterer.Output Add(Microsoft.ML.Legacy.Trainers.KMeansPlusPlusClusterer input) + { + var output = new Microsoft.ML.Legacy.Trainers.KMeansPlusPlusClusterer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelRegressor.Output Add(Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelRegressor input) - { - var output = new Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelRegressor.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.KMeansPlusPlusClusterer input, Microsoft.ML.Legacy.Trainers.KMeansPlusPlusClusterer.Output output) + { + _jsonNodes.Add(Serialize("Trainers.KMeansPlusPlusClusterer", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelRegressor input, Microsoft.ML.Legacy.Trainers.GeneralizedAdditiveModelRegressor.Output output) - { - _jsonNodes.Add(Serialize("Trainers.GeneralizedAdditiveModelRegressor", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.LightGbmBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.LightGbmBinaryClassifier input) + { + var output = new 
Microsoft.ML.Legacy.Trainers.LightGbmBinaryClassifier.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.KMeansPlusPlusClusterer.Output Add(Microsoft.ML.Legacy.Trainers.KMeansPlusPlusClusterer input) - { - var output = new Microsoft.ML.Legacy.Trainers.KMeansPlusPlusClusterer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.LightGbmBinaryClassifier input, Microsoft.ML.Legacy.Trainers.LightGbmBinaryClassifier.Output output) + { + _jsonNodes.Add(Serialize("Trainers.LightGbmBinaryClassifier", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.KMeansPlusPlusClusterer input, Microsoft.ML.Legacy.Trainers.KMeansPlusPlusClusterer.Output output) - { - _jsonNodes.Add(Serialize("Trainers.KMeansPlusPlusClusterer", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.LightGbmClassifier.Output Add(Microsoft.ML.Legacy.Trainers.LightGbmClassifier input) + { + var output = new Microsoft.ML.Legacy.Trainers.LightGbmClassifier.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.LightGbmBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.LightGbmBinaryClassifier input) - { - var output = new Microsoft.ML.Legacy.Trainers.LightGbmBinaryClassifier.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.LightGbmClassifier input, Microsoft.ML.Legacy.Trainers.LightGbmClassifier.Output output) + { + _jsonNodes.Add(Serialize("Trainers.LightGbmClassifier", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.LightGbmBinaryClassifier input, Microsoft.ML.Legacy.Trainers.LightGbmBinaryClassifier.Output output) - { - _jsonNodes.Add(Serialize("Trainers.LightGbmBinaryClassifier", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.LightGbmRanker.Output 
Add(Microsoft.ML.Legacy.Trainers.LightGbmRanker input) + { + var output = new Microsoft.ML.Legacy.Trainers.LightGbmRanker.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.LightGbmClassifier.Output Add(Microsoft.ML.Legacy.Trainers.LightGbmClassifier input) - { - var output = new Microsoft.ML.Legacy.Trainers.LightGbmClassifier.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.LightGbmRanker input, Microsoft.ML.Legacy.Trainers.LightGbmRanker.Output output) + { + _jsonNodes.Add(Serialize("Trainers.LightGbmRanker", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.LightGbmClassifier input, Microsoft.ML.Legacy.Trainers.LightGbmClassifier.Output output) - { - _jsonNodes.Add(Serialize("Trainers.LightGbmClassifier", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.LightGbmRegressor.Output Add(Microsoft.ML.Legacy.Trainers.LightGbmRegressor input) + { + var output = new Microsoft.ML.Legacy.Trainers.LightGbmRegressor.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.LightGbmRanker.Output Add(Microsoft.ML.Legacy.Trainers.LightGbmRanker input) - { - var output = new Microsoft.ML.Legacy.Trainers.LightGbmRanker.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.LightGbmRegressor input, Microsoft.ML.Legacy.Trainers.LightGbmRegressor.Output output) + { + _jsonNodes.Add(Serialize("Trainers.LightGbmRegressor", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.LightGbmRanker input, Microsoft.ML.Legacy.Trainers.LightGbmRanker.Output output) - { - _jsonNodes.Add(Serialize("Trainers.LightGbmRanker", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.LinearSvmBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.LinearSvmBinaryClassifier input) + { + 
var output = new Microsoft.ML.Legacy.Trainers.LinearSvmBinaryClassifier.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.LightGbmRegressor.Output Add(Microsoft.ML.Legacy.Trainers.LightGbmRegressor input) - { - var output = new Microsoft.ML.Legacy.Trainers.LightGbmRegressor.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.LinearSvmBinaryClassifier input, Microsoft.ML.Legacy.Trainers.LinearSvmBinaryClassifier.Output output) + { + _jsonNodes.Add(Serialize("Trainers.LinearSvmBinaryClassifier", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.LightGbmRegressor input, Microsoft.ML.Legacy.Trainers.LightGbmRegressor.Output output) - { - _jsonNodes.Add(Serialize("Trainers.LightGbmRegressor", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.LogisticRegressionBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.LogisticRegressionBinaryClassifier input) + { + var output = new Microsoft.ML.Legacy.Trainers.LogisticRegressionBinaryClassifier.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.LinearSvmBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.LinearSvmBinaryClassifier input) - { - var output = new Microsoft.ML.Legacy.Trainers.LinearSvmBinaryClassifier.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.LogisticRegressionBinaryClassifier input, Microsoft.ML.Legacy.Trainers.LogisticRegressionBinaryClassifier.Output output) + { + _jsonNodes.Add(Serialize("Trainers.LogisticRegressionBinaryClassifier", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.LinearSvmBinaryClassifier input, Microsoft.ML.Legacy.Trainers.LinearSvmBinaryClassifier.Output output) - { - _jsonNodes.Add(Serialize("Trainers.LinearSvmBinaryClassifier", input, output)); - } + [Obsolete] 
+ public Microsoft.ML.Legacy.Trainers.LogisticRegressionClassifier.Output Add(Microsoft.ML.Legacy.Trainers.LogisticRegressionClassifier input) + { + var output = new Microsoft.ML.Legacy.Trainers.LogisticRegressionClassifier.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.LogisticRegressionBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.LogisticRegressionBinaryClassifier input) - { - var output = new Microsoft.ML.Legacy.Trainers.LogisticRegressionBinaryClassifier.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.LogisticRegressionClassifier input, Microsoft.ML.Legacy.Trainers.LogisticRegressionClassifier.Output output) + { + _jsonNodes.Add(Serialize("Trainers.LogisticRegressionClassifier", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.LogisticRegressionBinaryClassifier input, Microsoft.ML.Legacy.Trainers.LogisticRegressionBinaryClassifier.Output output) - { - _jsonNodes.Add(Serialize("Trainers.LogisticRegressionBinaryClassifier", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.NaiveBayesClassifier.Output Add(Microsoft.ML.Legacy.Trainers.NaiveBayesClassifier input) + { + var output = new Microsoft.ML.Legacy.Trainers.NaiveBayesClassifier.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.LogisticRegressionClassifier.Output Add(Microsoft.ML.Legacy.Trainers.LogisticRegressionClassifier input) - { - var output = new Microsoft.ML.Legacy.Trainers.LogisticRegressionClassifier.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.NaiveBayesClassifier input, Microsoft.ML.Legacy.Trainers.NaiveBayesClassifier.Output output) + { + _jsonNodes.Add(Serialize("Trainers.NaiveBayesClassifier", input, output)); + } - [Obsolete] - public void 
Add(Microsoft.ML.Legacy.Trainers.LogisticRegressionClassifier input, Microsoft.ML.Legacy.Trainers.LogisticRegressionClassifier.Output output) - { - _jsonNodes.Add(Serialize("Trainers.LogisticRegressionClassifier", input, output)); - } - - [Obsolete] - public Microsoft.ML.Legacy.Trainers.NaiveBayesClassifier.Output Add(Microsoft.ML.Legacy.Trainers.NaiveBayesClassifier input) - { - var output = new Microsoft.ML.Legacy.Trainers.NaiveBayesClassifier.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.OnlineGradientDescentRegressor.Output Add(Microsoft.ML.Legacy.Trainers.OnlineGradientDescentRegressor input) + { + var output = new Microsoft.ML.Legacy.Trainers.OnlineGradientDescentRegressor.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.NaiveBayesClassifier input, Microsoft.ML.Legacy.Trainers.NaiveBayesClassifier.Output output) - { - _jsonNodes.Add(Serialize("Trainers.NaiveBayesClassifier", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.OnlineGradientDescentRegressor input, Microsoft.ML.Legacy.Trainers.OnlineGradientDescentRegressor.Output output) + { + _jsonNodes.Add(Serialize("Trainers.OnlineGradientDescentRegressor", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.OnlineGradientDescentRegressor.Output Add(Microsoft.ML.Legacy.Trainers.OnlineGradientDescentRegressor input) - { - var output = new Microsoft.ML.Legacy.Trainers.OnlineGradientDescentRegressor.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.OrdinaryLeastSquaresRegressor.Output Add(Microsoft.ML.Legacy.Trainers.OrdinaryLeastSquaresRegressor input) + { + var output = new Microsoft.ML.Legacy.Trainers.OrdinaryLeastSquaresRegressor.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.OnlineGradientDescentRegressor input, 
Microsoft.ML.Legacy.Trainers.OnlineGradientDescentRegressor.Output output) - { - _jsonNodes.Add(Serialize("Trainers.OnlineGradientDescentRegressor", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.OrdinaryLeastSquaresRegressor input, Microsoft.ML.Legacy.Trainers.OrdinaryLeastSquaresRegressor.Output output) + { + _jsonNodes.Add(Serialize("Trainers.OrdinaryLeastSquaresRegressor", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.OrdinaryLeastSquaresRegressor.Output Add(Microsoft.ML.Legacy.Trainers.OrdinaryLeastSquaresRegressor input) - { - var output = new Microsoft.ML.Legacy.Trainers.OrdinaryLeastSquaresRegressor.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.PcaAnomalyDetector.Output Add(Microsoft.ML.Legacy.Trainers.PcaAnomalyDetector input) + { + var output = new Microsoft.ML.Legacy.Trainers.PcaAnomalyDetector.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.OrdinaryLeastSquaresRegressor input, Microsoft.ML.Legacy.Trainers.OrdinaryLeastSquaresRegressor.Output output) - { - _jsonNodes.Add(Serialize("Trainers.OrdinaryLeastSquaresRegressor", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.PcaAnomalyDetector input, Microsoft.ML.Legacy.Trainers.PcaAnomalyDetector.Output output) + { + _jsonNodes.Add(Serialize("Trainers.PcaAnomalyDetector", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.PcaAnomalyDetector.Output Add(Microsoft.ML.Legacy.Trainers.PcaAnomalyDetector input) - { - var output = new Microsoft.ML.Legacy.Trainers.PcaAnomalyDetector.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.PoissonRegressor.Output Add(Microsoft.ML.Legacy.Trainers.PoissonRegressor input) + { + var output = new Microsoft.ML.Legacy.Trainers.PoissonRegressor.Output(); + Add(input, output); + return output; + } 
- [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.PcaAnomalyDetector input, Microsoft.ML.Legacy.Trainers.PcaAnomalyDetector.Output output) - { - _jsonNodes.Add(Serialize("Trainers.PcaAnomalyDetector", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.PoissonRegressor input, Microsoft.ML.Legacy.Trainers.PoissonRegressor.Output output) + { + _jsonNodes.Add(Serialize("Trainers.PoissonRegressor", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.PoissonRegressor.Output Add(Microsoft.ML.Legacy.Trainers.PoissonRegressor input) - { - var output = new Microsoft.ML.Legacy.Trainers.PoissonRegressor.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier input) + { + var output = new Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.PoissonRegressor input, Microsoft.ML.Legacy.Trainers.PoissonRegressor.Output output) - { - _jsonNodes.Add(Serialize("Trainers.PoissonRegressor", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier input, Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier.Output output) + { + _jsonNodes.Add(Serialize("Trainers.StochasticDualCoordinateAscentBinaryClassifier", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier input) - { - var output = new Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier.Output(); - Add(input, output); - return output; - } + [Obsolete] + public 
Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentClassifier.Output Add(Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentClassifier input) + { + var output = new Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentClassifier.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier input, Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier.Output output) - { - _jsonNodes.Add(Serialize("Trainers.StochasticDualCoordinateAscentBinaryClassifier", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentClassifier input, Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentClassifier.Output output) + { + _jsonNodes.Add(Serialize("Trainers.StochasticDualCoordinateAscentClassifier", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentClassifier.Output Add(Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentClassifier input) - { - var output = new Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentClassifier.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentRegressor.Output Add(Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentRegressor input) + { + var output = new Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentRegressor.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentClassifier input, Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentClassifier.Output output) - { - _jsonNodes.Add(Serialize("Trainers.StochasticDualCoordinateAscentClassifier", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentRegressor input, 
Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentRegressor.Output output) + { + _jsonNodes.Add(Serialize("Trainers.StochasticDualCoordinateAscentRegressor", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentRegressor.Output Add(Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentRegressor input) - { - var output = new Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentRegressor.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.StochasticGradientDescentBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.StochasticGradientDescentBinaryClassifier input) + { + var output = new Microsoft.ML.Legacy.Trainers.StochasticGradientDescentBinaryClassifier.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentRegressor input, Microsoft.ML.Legacy.Trainers.StochasticDualCoordinateAscentRegressor.Output output) - { - _jsonNodes.Add(Serialize("Trainers.StochasticDualCoordinateAscentRegressor", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.StochasticGradientDescentBinaryClassifier input, Microsoft.ML.Legacy.Trainers.StochasticGradientDescentBinaryClassifier.Output output) + { + _jsonNodes.Add(Serialize("Trainers.StochasticGradientDescentBinaryClassifier", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.StochasticGradientDescentBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.StochasticGradientDescentBinaryClassifier input) - { - var output = new Microsoft.ML.Legacy.Trainers.StochasticGradientDescentBinaryClassifier.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Trainers.SymSgdBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.SymSgdBinaryClassifier input) + { + var output = new Microsoft.ML.Legacy.Trainers.SymSgdBinaryClassifier.Output(); + 
Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.StochasticGradientDescentBinaryClassifier input, Microsoft.ML.Legacy.Trainers.StochasticGradientDescentBinaryClassifier.Output output) - { - _jsonNodes.Add(Serialize("Trainers.StochasticGradientDescentBinaryClassifier", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Trainers.SymSgdBinaryClassifier input, Microsoft.ML.Legacy.Trainers.SymSgdBinaryClassifier.Output output) + { + _jsonNodes.Add(Serialize("Trainers.SymSgdBinaryClassifier", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Trainers.SymSgdBinaryClassifier.Output Add(Microsoft.ML.Legacy.Trainers.SymSgdBinaryClassifier input) - { - var output = new Microsoft.ML.Legacy.Trainers.SymSgdBinaryClassifier.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.ApproximateBootstrapSampler.Output Add(Microsoft.ML.Legacy.Transforms.ApproximateBootstrapSampler input) + { + var output = new Microsoft.ML.Legacy.Transforms.ApproximateBootstrapSampler.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Trainers.SymSgdBinaryClassifier input, Microsoft.ML.Legacy.Trainers.SymSgdBinaryClassifier.Output output) - { - _jsonNodes.Add(Serialize("Trainers.SymSgdBinaryClassifier", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.ApproximateBootstrapSampler input, Microsoft.ML.Legacy.Transforms.ApproximateBootstrapSampler.Output output) + { + _jsonNodes.Add(Serialize("Transforms.ApproximateBootstrapSampler", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.ApproximateBootstrapSampler.Output Add(Microsoft.ML.Legacy.Transforms.ApproximateBootstrapSampler input) - { - var output = new Microsoft.ML.Legacy.Transforms.ApproximateBootstrapSampler.Output(); - Add(input, output); - return output; - } + [Obsolete] + public 
Microsoft.ML.Legacy.Transforms.BinaryPredictionScoreColumnsRenamer.Output Add(Microsoft.ML.Legacy.Transforms.BinaryPredictionScoreColumnsRenamer input) + { + var output = new Microsoft.ML.Legacy.Transforms.BinaryPredictionScoreColumnsRenamer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.ApproximateBootstrapSampler input, Microsoft.ML.Legacy.Transforms.ApproximateBootstrapSampler.Output output) - { - _jsonNodes.Add(Serialize("Transforms.ApproximateBootstrapSampler", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.BinaryPredictionScoreColumnsRenamer input, Microsoft.ML.Legacy.Transforms.BinaryPredictionScoreColumnsRenamer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.BinaryPredictionScoreColumnsRenamer", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.BinaryPredictionScoreColumnsRenamer.Output Add(Microsoft.ML.Legacy.Transforms.BinaryPredictionScoreColumnsRenamer input) - { - var output = new Microsoft.ML.Legacy.Transforms.BinaryPredictionScoreColumnsRenamer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.BinNormalizer.Output Add(Microsoft.ML.Legacy.Transforms.BinNormalizer input) + { + var output = new Microsoft.ML.Legacy.Transforms.BinNormalizer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.BinaryPredictionScoreColumnsRenamer input, Microsoft.ML.Legacy.Transforms.BinaryPredictionScoreColumnsRenamer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.BinaryPredictionScoreColumnsRenamer", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.BinNormalizer input, Microsoft.ML.Legacy.Transforms.BinNormalizer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.BinNormalizer", input, output)); + } - [Obsolete] - public 
Microsoft.ML.Legacy.Transforms.BinNormalizer.Output Add(Microsoft.ML.Legacy.Transforms.BinNormalizer input) - { - var output = new Microsoft.ML.Legacy.Transforms.BinNormalizer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.CategoricalHashOneHotVectorizer.Output Add(Microsoft.ML.Legacy.Transforms.CategoricalHashOneHotVectorizer input) + { + var output = new Microsoft.ML.Legacy.Transforms.CategoricalHashOneHotVectorizer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.BinNormalizer input, Microsoft.ML.Legacy.Transforms.BinNormalizer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.BinNormalizer", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.CategoricalHashOneHotVectorizer input, Microsoft.ML.Legacy.Transforms.CategoricalHashOneHotVectorizer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.CategoricalHashOneHotVectorizer", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.CategoricalHashOneHotVectorizer.Output Add(Microsoft.ML.Legacy.Transforms.CategoricalHashOneHotVectorizer input) - { - var output = new Microsoft.ML.Legacy.Transforms.CategoricalHashOneHotVectorizer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.CategoricalOneHotVectorizer.Output Add(Microsoft.ML.Legacy.Transforms.CategoricalOneHotVectorizer input) + { + var output = new Microsoft.ML.Legacy.Transforms.CategoricalOneHotVectorizer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.CategoricalHashOneHotVectorizer input, Microsoft.ML.Legacy.Transforms.CategoricalHashOneHotVectorizer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.CategoricalHashOneHotVectorizer", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.CategoricalOneHotVectorizer 
input, Microsoft.ML.Legacy.Transforms.CategoricalOneHotVectorizer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.CategoricalOneHotVectorizer", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.CategoricalOneHotVectorizer.Output Add(Microsoft.ML.Legacy.Transforms.CategoricalOneHotVectorizer input) - { - var output = new Microsoft.ML.Legacy.Transforms.CategoricalOneHotVectorizer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.CharacterTokenizer.Output Add(Microsoft.ML.Legacy.Transforms.CharacterTokenizer input) + { + var output = new Microsoft.ML.Legacy.Transforms.CharacterTokenizer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.CategoricalOneHotVectorizer input, Microsoft.ML.Legacy.Transforms.CategoricalOneHotVectorizer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.CategoricalOneHotVectorizer", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.CharacterTokenizer input, Microsoft.ML.Legacy.Transforms.CharacterTokenizer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.CharacterTokenizer", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.CharacterTokenizer.Output Add(Microsoft.ML.Legacy.Transforms.CharacterTokenizer input) - { - var output = new Microsoft.ML.Legacy.Transforms.CharacterTokenizer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.ColumnConcatenator.Output Add(Microsoft.ML.Legacy.Transforms.ColumnConcatenator input) + { + var output = new Microsoft.ML.Legacy.Transforms.ColumnConcatenator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.CharacterTokenizer input, Microsoft.ML.Legacy.Transforms.CharacterTokenizer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.CharacterTokenizer", input, 
output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.ColumnConcatenator input, Microsoft.ML.Legacy.Transforms.ColumnConcatenator.Output output) + { + _jsonNodes.Add(Serialize("Transforms.ColumnConcatenator", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.ColumnConcatenator.Output Add(Microsoft.ML.Legacy.Transforms.ColumnConcatenator input) - { - var output = new Microsoft.ML.Legacy.Transforms.ColumnConcatenator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.ColumnCopier.Output Add(Microsoft.ML.Legacy.Transforms.ColumnCopier input) + { + var output = new Microsoft.ML.Legacy.Transforms.ColumnCopier.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.ColumnConcatenator input, Microsoft.ML.Legacy.Transforms.ColumnConcatenator.Output output) - { - _jsonNodes.Add(Serialize("Transforms.ColumnConcatenator", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.ColumnCopier input, Microsoft.ML.Legacy.Transforms.ColumnCopier.Output output) + { + _jsonNodes.Add(Serialize("Transforms.ColumnCopier", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.ColumnCopier.Output Add(Microsoft.ML.Legacy.Transforms.ColumnCopier input) - { - var output = new Microsoft.ML.Legacy.Transforms.ColumnCopier.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.ColumnSelector.Output Add(Microsoft.ML.Legacy.Transforms.ColumnSelector input) + { + var output = new Microsoft.ML.Legacy.Transforms.ColumnSelector.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.ColumnCopier input, Microsoft.ML.Legacy.Transforms.ColumnCopier.Output output) - { - _jsonNodes.Add(Serialize("Transforms.ColumnCopier", input, output)); - } + [Obsolete] + public void 
Add(Microsoft.ML.Legacy.Transforms.ColumnSelector input, Microsoft.ML.Legacy.Transforms.ColumnSelector.Output output) + { + _jsonNodes.Add(Serialize("Transforms.ColumnSelector", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.ColumnSelector.Output Add(Microsoft.ML.Legacy.Transforms.ColumnSelector input) - { - var output = new Microsoft.ML.Legacy.Transforms.ColumnSelector.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.ColumnTypeConverter.Output Add(Microsoft.ML.Legacy.Transforms.ColumnTypeConverter input) + { + var output = new Microsoft.ML.Legacy.Transforms.ColumnTypeConverter.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.ColumnSelector input, Microsoft.ML.Legacy.Transforms.ColumnSelector.Output output) - { - _jsonNodes.Add(Serialize("Transforms.ColumnSelector", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.ColumnTypeConverter input, Microsoft.ML.Legacy.Transforms.ColumnTypeConverter.Output output) + { + _jsonNodes.Add(Serialize("Transforms.ColumnTypeConverter", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.ColumnTypeConverter.Output Add(Microsoft.ML.Legacy.Transforms.ColumnTypeConverter input) - { - var output = new Microsoft.ML.Legacy.Transforms.ColumnTypeConverter.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.CombinerByContiguousGroupId.Output Add(Microsoft.ML.Legacy.Transforms.CombinerByContiguousGroupId input) + { + var output = new Microsoft.ML.Legacy.Transforms.CombinerByContiguousGroupId.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.ColumnTypeConverter input, Microsoft.ML.Legacy.Transforms.ColumnTypeConverter.Output output) - { - _jsonNodes.Add(Serialize("Transforms.ColumnTypeConverter", input, output)); - } + 
[Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.CombinerByContiguousGroupId input, Microsoft.ML.Legacy.Transforms.CombinerByContiguousGroupId.Output output) + { + _jsonNodes.Add(Serialize("Transforms.CombinerByContiguousGroupId", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.CombinerByContiguousGroupId.Output Add(Microsoft.ML.Legacy.Transforms.CombinerByContiguousGroupId input) - { - var output = new Microsoft.ML.Legacy.Transforms.CombinerByContiguousGroupId.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.ConditionalNormalizer.Output Add(Microsoft.ML.Legacy.Transforms.ConditionalNormalizer input) + { + var output = new Microsoft.ML.Legacy.Transforms.ConditionalNormalizer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.CombinerByContiguousGroupId input, Microsoft.ML.Legacy.Transforms.CombinerByContiguousGroupId.Output output) - { - _jsonNodes.Add(Serialize("Transforms.CombinerByContiguousGroupId", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.ConditionalNormalizer input, Microsoft.ML.Legacy.Transforms.ConditionalNormalizer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.ConditionalNormalizer", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.ConditionalNormalizer.Output Add(Microsoft.ML.Legacy.Transforms.ConditionalNormalizer input) - { - var output = new Microsoft.ML.Legacy.Transforms.ConditionalNormalizer.Output(); - Add(input, output); - return output; - } - - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.ConditionalNormalizer input, Microsoft.ML.Legacy.Transforms.ConditionalNormalizer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.ConditionalNormalizer", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.DataCache.Output Add(Microsoft.ML.Legacy.Transforms.DataCache input) + { + var 
output = new Microsoft.ML.Legacy.Transforms.DataCache.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.DataCache.Output Add(Microsoft.ML.Legacy.Transforms.DataCache input) - { - var output = new Microsoft.ML.Legacy.Transforms.DataCache.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.DataCache input, Microsoft.ML.Legacy.Transforms.DataCache.Output output) + { + _jsonNodes.Add(Serialize("Transforms.DataCache", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.DataCache input, Microsoft.ML.Legacy.Transforms.DataCache.Output output) - { - _jsonNodes.Add(Serialize("Transforms.DataCache", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.DatasetScorer.Output Add(Microsoft.ML.Legacy.Transforms.DatasetScorer input) + { + var output = new Microsoft.ML.Legacy.Transforms.DatasetScorer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.DatasetScorer.Output Add(Microsoft.ML.Legacy.Transforms.DatasetScorer input) - { - var output = new Microsoft.ML.Legacy.Transforms.DatasetScorer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.DatasetScorer input, Microsoft.ML.Legacy.Transforms.DatasetScorer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.DatasetScorer", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.DatasetScorer input, Microsoft.ML.Legacy.Transforms.DatasetScorer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.DatasetScorer", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.DatasetTransformScorer.Output Add(Microsoft.ML.Legacy.Transforms.DatasetTransformScorer input) + { + var output = new Microsoft.ML.Legacy.Transforms.DatasetTransformScorer.Output(); + Add(input, output); + return output; + } - 
[Obsolete] - public Microsoft.ML.Legacy.Transforms.DatasetTransformScorer.Output Add(Microsoft.ML.Legacy.Transforms.DatasetTransformScorer input) - { - var output = new Microsoft.ML.Legacy.Transforms.DatasetTransformScorer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.DatasetTransformScorer input, Microsoft.ML.Legacy.Transforms.DatasetTransformScorer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.DatasetTransformScorer", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.DatasetTransformScorer input, Microsoft.ML.Legacy.Transforms.DatasetTransformScorer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.DatasetTransformScorer", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.Dictionarizer.Output Add(Microsoft.ML.Legacy.Transforms.Dictionarizer input) + { + var output = new Microsoft.ML.Legacy.Transforms.Dictionarizer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.Dictionarizer.Output Add(Microsoft.ML.Legacy.Transforms.Dictionarizer input) - { - var output = new Microsoft.ML.Legacy.Transforms.Dictionarizer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.Dictionarizer input, Microsoft.ML.Legacy.Transforms.Dictionarizer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.Dictionarizer", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.Dictionarizer input, Microsoft.ML.Legacy.Transforms.Dictionarizer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.Dictionarizer", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.FeatureCombiner.Output Add(Microsoft.ML.Legacy.Transforms.FeatureCombiner input) + { + var output = new Microsoft.ML.Legacy.Transforms.FeatureCombiner.Output(); + Add(input, output); + return output; + } - [Obsolete] - 
public Microsoft.ML.Legacy.Transforms.FeatureCombiner.Output Add(Microsoft.ML.Legacy.Transforms.FeatureCombiner input) - { - var output = new Microsoft.ML.Legacy.Transforms.FeatureCombiner.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.FeatureCombiner input, Microsoft.ML.Legacy.Transforms.FeatureCombiner.Output output) + { + _jsonNodes.Add(Serialize("Transforms.FeatureCombiner", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.FeatureCombiner input, Microsoft.ML.Legacy.Transforms.FeatureCombiner.Output output) - { - _jsonNodes.Add(Serialize("Transforms.FeatureCombiner", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.FeatureContributionCalculationTransformer.Output Add(Microsoft.ML.Legacy.Transforms.FeatureContributionCalculationTransformer input) + { + var output = new Microsoft.ML.Legacy.Transforms.FeatureContributionCalculationTransformer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.FeatureSelectorByCount.Output Add(Microsoft.ML.Legacy.Transforms.FeatureSelectorByCount input) - { - var output = new Microsoft.ML.Legacy.Transforms.FeatureSelectorByCount.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.FeatureContributionCalculationTransformer input, Microsoft.ML.Legacy.Transforms.FeatureContributionCalculationTransformer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.FeatureContributionCalculationTransformer", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.FeatureSelectorByCount input, Microsoft.ML.Legacy.Transforms.FeatureSelectorByCount.Output output) - { - _jsonNodes.Add(Serialize("Transforms.FeatureSelectorByCount", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.FeatureSelectorByCount.Output 
Add(Microsoft.ML.Legacy.Transforms.FeatureSelectorByCount input) + { + var output = new Microsoft.ML.Legacy.Transforms.FeatureSelectorByCount.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.FeatureSelectorByMutualInformation.Output Add(Microsoft.ML.Legacy.Transforms.FeatureSelectorByMutualInformation input) - { - var output = new Microsoft.ML.Legacy.Transforms.FeatureSelectorByMutualInformation.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.FeatureSelectorByCount input, Microsoft.ML.Legacy.Transforms.FeatureSelectorByCount.Output output) + { + _jsonNodes.Add(Serialize("Transforms.FeatureSelectorByCount", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.FeatureSelectorByMutualInformation input, Microsoft.ML.Legacy.Transforms.FeatureSelectorByMutualInformation.Output output) - { - _jsonNodes.Add(Serialize("Transforms.FeatureSelectorByMutualInformation", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.FeatureSelectorByMutualInformation.Output Add(Microsoft.ML.Legacy.Transforms.FeatureSelectorByMutualInformation input) + { + var output = new Microsoft.ML.Legacy.Transforms.FeatureSelectorByMutualInformation.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.GlobalContrastNormalizer.Output Add(Microsoft.ML.Legacy.Transforms.GlobalContrastNormalizer input) - { - var output = new Microsoft.ML.Legacy.Transforms.GlobalContrastNormalizer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.FeatureSelectorByMutualInformation input, Microsoft.ML.Legacy.Transforms.FeatureSelectorByMutualInformation.Output output) + { + _jsonNodes.Add(Serialize("Transforms.FeatureSelectorByMutualInformation", input, output)); + } - [Obsolete] - public void 
Add(Microsoft.ML.Legacy.Transforms.GlobalContrastNormalizer input, Microsoft.ML.Legacy.Transforms.GlobalContrastNormalizer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.GlobalContrastNormalizer", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.GlobalContrastNormalizer.Output Add(Microsoft.ML.Legacy.Transforms.GlobalContrastNormalizer input) + { + var output = new Microsoft.ML.Legacy.Transforms.GlobalContrastNormalizer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.HashConverter.Output Add(Microsoft.ML.Legacy.Transforms.HashConverter input) - { - var output = new Microsoft.ML.Legacy.Transforms.HashConverter.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.GlobalContrastNormalizer input, Microsoft.ML.Legacy.Transforms.GlobalContrastNormalizer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.GlobalContrastNormalizer", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.HashConverter input, Microsoft.ML.Legacy.Transforms.HashConverter.Output output) - { - _jsonNodes.Add(Serialize("Transforms.HashConverter", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.HashConverter.Output Add(Microsoft.ML.Legacy.Transforms.HashConverter input) + { + var output = new Microsoft.ML.Legacy.Transforms.HashConverter.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.ImageGrayscale.Output Add(Microsoft.ML.Legacy.Transforms.ImageGrayscale input) - { - var output = new Microsoft.ML.Legacy.Transforms.ImageGrayscale.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.HashConverter input, Microsoft.ML.Legacy.Transforms.HashConverter.Output output) + { + _jsonNodes.Add(Serialize("Transforms.HashConverter", input, output)); + } - [Obsolete] - public void 
Add(Microsoft.ML.Legacy.Transforms.ImageGrayscale input, Microsoft.ML.Legacy.Transforms.ImageGrayscale.Output output) - { - _jsonNodes.Add(Serialize("Transforms.ImageGrayscale", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.ImageGrayscale.Output Add(Microsoft.ML.Legacy.Transforms.ImageGrayscale input) + { + var output = new Microsoft.ML.Legacy.Transforms.ImageGrayscale.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.ImageLoader.Output Add(Microsoft.ML.Legacy.Transforms.ImageLoader input) - { - var output = new Microsoft.ML.Legacy.Transforms.ImageLoader.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.ImageGrayscale input, Microsoft.ML.Legacy.Transforms.ImageGrayscale.Output output) + { + _jsonNodes.Add(Serialize("Transforms.ImageGrayscale", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.ImageLoader input, Microsoft.ML.Legacy.Transforms.ImageLoader.Output output) - { - _jsonNodes.Add(Serialize("Transforms.ImageLoader", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.ImageLoader.Output Add(Microsoft.ML.Legacy.Transforms.ImageLoader input) + { + var output = new Microsoft.ML.Legacy.Transforms.ImageLoader.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.ImagePixelExtractor.Output Add(Microsoft.ML.Legacy.Transforms.ImagePixelExtractor input) - { - var output = new Microsoft.ML.Legacy.Transforms.ImagePixelExtractor.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.ImageLoader input, Microsoft.ML.Legacy.Transforms.ImageLoader.Output output) + { + _jsonNodes.Add(Serialize("Transforms.ImageLoader", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.ImagePixelExtractor input, 
Microsoft.ML.Legacy.Transforms.ImagePixelExtractor.Output output) - { - _jsonNodes.Add(Serialize("Transforms.ImagePixelExtractor", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.ImagePixelExtractor.Output Add(Microsoft.ML.Legacy.Transforms.ImagePixelExtractor input) + { + var output = new Microsoft.ML.Legacy.Transforms.ImagePixelExtractor.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.ImageResizer.Output Add(Microsoft.ML.Legacy.Transforms.ImageResizer input) - { - var output = new Microsoft.ML.Legacy.Transforms.ImageResizer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.ImagePixelExtractor input, Microsoft.ML.Legacy.Transforms.ImagePixelExtractor.Output output) + { + _jsonNodes.Add(Serialize("Transforms.ImagePixelExtractor", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.ImageResizer input, Microsoft.ML.Legacy.Transforms.ImageResizer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.ImageResizer", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.ImageResizer.Output Add(Microsoft.ML.Legacy.Transforms.ImageResizer input) + { + var output = new Microsoft.ML.Legacy.Transforms.ImageResizer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.KeyToTextConverter.Output Add(Microsoft.ML.Legacy.Transforms.KeyToTextConverter input) - { - var output = new Microsoft.ML.Legacy.Transforms.KeyToTextConverter.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.ImageResizer input, Microsoft.ML.Legacy.Transforms.ImageResizer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.ImageResizer", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.KeyToTextConverter input, 
Microsoft.ML.Legacy.Transforms.KeyToTextConverter.Output output) - { - _jsonNodes.Add(Serialize("Transforms.KeyToTextConverter", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.KeyToTextConverter.Output Add(Microsoft.ML.Legacy.Transforms.KeyToTextConverter input) + { + var output = new Microsoft.ML.Legacy.Transforms.KeyToTextConverter.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.LabelColumnKeyBooleanConverter.Output Add(Microsoft.ML.Legacy.Transforms.LabelColumnKeyBooleanConverter input) - { - var output = new Microsoft.ML.Legacy.Transforms.LabelColumnKeyBooleanConverter.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.KeyToTextConverter input, Microsoft.ML.Legacy.Transforms.KeyToTextConverter.Output output) + { + _jsonNodes.Add(Serialize("Transforms.KeyToTextConverter", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.LabelColumnKeyBooleanConverter input, Microsoft.ML.Legacy.Transforms.LabelColumnKeyBooleanConverter.Output output) - { - _jsonNodes.Add(Serialize("Transforms.LabelColumnKeyBooleanConverter", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.LabelColumnKeyBooleanConverter.Output Add(Microsoft.ML.Legacy.Transforms.LabelColumnKeyBooleanConverter input) + { + var output = new Microsoft.ML.Legacy.Transforms.LabelColumnKeyBooleanConverter.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.LabelIndicator.Output Add(Microsoft.ML.Legacy.Transforms.LabelIndicator input) - { - var output = new Microsoft.ML.Legacy.Transforms.LabelIndicator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.LabelColumnKeyBooleanConverter input, Microsoft.ML.Legacy.Transforms.LabelColumnKeyBooleanConverter.Output output) + { + 
_jsonNodes.Add(Serialize("Transforms.LabelColumnKeyBooleanConverter", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.LabelIndicator input, Microsoft.ML.Legacy.Transforms.LabelIndicator.Output output) - { - _jsonNodes.Add(Serialize("Transforms.LabelIndicator", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.LabelIndicator.Output Add(Microsoft.ML.Legacy.Transforms.LabelIndicator input) + { + var output = new Microsoft.ML.Legacy.Transforms.LabelIndicator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.LabelToFloatConverter.Output Add(Microsoft.ML.Legacy.Transforms.LabelToFloatConverter input) - { - var output = new Microsoft.ML.Legacy.Transforms.LabelToFloatConverter.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.LabelIndicator input, Microsoft.ML.Legacy.Transforms.LabelIndicator.Output output) + { + _jsonNodes.Add(Serialize("Transforms.LabelIndicator", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.LabelToFloatConverter input, Microsoft.ML.Legacy.Transforms.LabelToFloatConverter.Output output) - { - _jsonNodes.Add(Serialize("Transforms.LabelToFloatConverter", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.LabelToFloatConverter.Output Add(Microsoft.ML.Legacy.Transforms.LabelToFloatConverter input) + { + var output = new Microsoft.ML.Legacy.Transforms.LabelToFloatConverter.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.LightLda.Output Add(Microsoft.ML.Legacy.Transforms.LightLda input) - { - var output = new Microsoft.ML.Legacy.Transforms.LightLda.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.LabelToFloatConverter input, Microsoft.ML.Legacy.Transforms.LabelToFloatConverter.Output output) + { + 
_jsonNodes.Add(Serialize("Transforms.LabelToFloatConverter", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.LightLda input, Microsoft.ML.Legacy.Transforms.LightLda.Output output) - { - _jsonNodes.Add(Serialize("Transforms.LightLda", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.LightLda.Output Add(Microsoft.ML.Legacy.Transforms.LightLda input) + { + var output = new Microsoft.ML.Legacy.Transforms.LightLda.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.LogMeanVarianceNormalizer.Output Add(Microsoft.ML.Legacy.Transforms.LogMeanVarianceNormalizer input) - { - var output = new Microsoft.ML.Legacy.Transforms.LogMeanVarianceNormalizer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.LightLda input, Microsoft.ML.Legacy.Transforms.LightLda.Output output) + { + _jsonNodes.Add(Serialize("Transforms.LightLda", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.LogMeanVarianceNormalizer input, Microsoft.ML.Legacy.Transforms.LogMeanVarianceNormalizer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.LogMeanVarianceNormalizer", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.LogMeanVarianceNormalizer.Output Add(Microsoft.ML.Legacy.Transforms.LogMeanVarianceNormalizer input) + { + var output = new Microsoft.ML.Legacy.Transforms.LogMeanVarianceNormalizer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.LpNormalizer.Output Add(Microsoft.ML.Legacy.Transforms.LpNormalizer input) - { - var output = new Microsoft.ML.Legacy.Transforms.LpNormalizer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.LogMeanVarianceNormalizer input, Microsoft.ML.Legacy.Transforms.LogMeanVarianceNormalizer.Output output) + { + 
_jsonNodes.Add(Serialize("Transforms.LogMeanVarianceNormalizer", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.LpNormalizer input, Microsoft.ML.Legacy.Transforms.LpNormalizer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.LpNormalizer", input, output)); - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.LpNormalizer.Output Add(Microsoft.ML.Legacy.Transforms.LpNormalizer input) + { + var output = new Microsoft.ML.Legacy.Transforms.LpNormalizer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.ManyHeterogeneousModelCombiner.Output Add(Microsoft.ML.Legacy.Transforms.ManyHeterogeneousModelCombiner input) - { - var output = new Microsoft.ML.Legacy.Transforms.ManyHeterogeneousModelCombiner.Output(); - Add(input, output); - return output; - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.LpNormalizer input, Microsoft.ML.Legacy.Transforms.LpNormalizer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.LpNormalizer", input, output)); + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.ManyHeterogeneousModelCombiner input, Microsoft.ML.Legacy.Transforms.ManyHeterogeneousModelCombiner.Output output) - { - _jsonNodes.Add(Serialize("Transforms.ManyHeterogeneousModelCombiner", input, output)); - } - - [Obsolete] - public Microsoft.ML.Legacy.Transforms.MeanVarianceNormalizer.Output Add(Microsoft.ML.Legacy.Transforms.MeanVarianceNormalizer input) - { - var output = new Microsoft.ML.Legacy.Transforms.MeanVarianceNormalizer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.ManyHeterogeneousModelCombiner.Output Add(Microsoft.ML.Legacy.Transforms.ManyHeterogeneousModelCombiner input) + { + var output = new Microsoft.ML.Legacy.Transforms.ManyHeterogeneousModelCombiner.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void 
Add(Microsoft.ML.Legacy.Transforms.MeanVarianceNormalizer input, Microsoft.ML.Legacy.Transforms.MeanVarianceNormalizer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.MeanVarianceNormalizer", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.ManyHeterogeneousModelCombiner input, Microsoft.ML.Legacy.Transforms.ManyHeterogeneousModelCombiner.Output output) + { + _jsonNodes.Add(Serialize("Transforms.ManyHeterogeneousModelCombiner", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.MinMaxNormalizer.Output Add(Microsoft.ML.Legacy.Transforms.MinMaxNormalizer input) - { - var output = new Microsoft.ML.Legacy.Transforms.MinMaxNormalizer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.MeanVarianceNormalizer.Output Add(Microsoft.ML.Legacy.Transforms.MeanVarianceNormalizer input) + { + var output = new Microsoft.ML.Legacy.Transforms.MeanVarianceNormalizer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.MinMaxNormalizer input, Microsoft.ML.Legacy.Transforms.MinMaxNormalizer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.MinMaxNormalizer", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.MeanVarianceNormalizer input, Microsoft.ML.Legacy.Transforms.MeanVarianceNormalizer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.MeanVarianceNormalizer", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.MissingValueHandler.Output Add(Microsoft.ML.Legacy.Transforms.MissingValueHandler input) - { - var output = new Microsoft.ML.Legacy.Transforms.MissingValueHandler.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.MinMaxNormalizer.Output Add(Microsoft.ML.Legacy.Transforms.MinMaxNormalizer input) + { + var output = new 
Microsoft.ML.Legacy.Transforms.MinMaxNormalizer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.MissingValueHandler input, Microsoft.ML.Legacy.Transforms.MissingValueHandler.Output output) - { - _jsonNodes.Add(Serialize("Transforms.MissingValueHandler", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.MinMaxNormalizer input, Microsoft.ML.Legacy.Transforms.MinMaxNormalizer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.MinMaxNormalizer", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.MissingValueIndicator.Output Add(Microsoft.ML.Legacy.Transforms.MissingValueIndicator input) - { - var output = new Microsoft.ML.Legacy.Transforms.MissingValueIndicator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.MissingValueHandler.Output Add(Microsoft.ML.Legacy.Transforms.MissingValueHandler input) + { + var output = new Microsoft.ML.Legacy.Transforms.MissingValueHandler.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.MissingValueIndicator input, Microsoft.ML.Legacy.Transforms.MissingValueIndicator.Output output) - { - _jsonNodes.Add(Serialize("Transforms.MissingValueIndicator", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.MissingValueHandler input, Microsoft.ML.Legacy.Transforms.MissingValueHandler.Output output) + { + _jsonNodes.Add(Serialize("Transforms.MissingValueHandler", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.MissingValuesDropper.Output Add(Microsoft.ML.Legacy.Transforms.MissingValuesDropper input) - { - var output = new Microsoft.ML.Legacy.Transforms.MissingValuesDropper.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.MissingValueIndicator.Output 
Add(Microsoft.ML.Legacy.Transforms.MissingValueIndicator input) + { + var output = new Microsoft.ML.Legacy.Transforms.MissingValueIndicator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.MissingValuesDropper input, Microsoft.ML.Legacy.Transforms.MissingValuesDropper.Output output) - { - _jsonNodes.Add(Serialize("Transforms.MissingValuesDropper", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.MissingValueIndicator input, Microsoft.ML.Legacy.Transforms.MissingValueIndicator.Output output) + { + _jsonNodes.Add(Serialize("Transforms.MissingValueIndicator", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.MissingValuesRowDropper.Output Add(Microsoft.ML.Legacy.Transforms.MissingValuesRowDropper input) - { - var output = new Microsoft.ML.Legacy.Transforms.MissingValuesRowDropper.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.MissingValuesDropper.Output Add(Microsoft.ML.Legacy.Transforms.MissingValuesDropper input) + { + var output = new Microsoft.ML.Legacy.Transforms.MissingValuesDropper.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.MissingValuesRowDropper input, Microsoft.ML.Legacy.Transforms.MissingValuesRowDropper.Output output) - { - _jsonNodes.Add(Serialize("Transforms.MissingValuesRowDropper", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.MissingValuesDropper input, Microsoft.ML.Legacy.Transforms.MissingValuesDropper.Output output) + { + _jsonNodes.Add(Serialize("Transforms.MissingValuesDropper", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.MissingValueSubstitutor.Output Add(Microsoft.ML.Legacy.Transforms.MissingValueSubstitutor input) - { - var output = new Microsoft.ML.Legacy.Transforms.MissingValueSubstitutor.Output(); - Add(input, output); - 
return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.MissingValuesRowDropper.Output Add(Microsoft.ML.Legacy.Transforms.MissingValuesRowDropper input) + { + var output = new Microsoft.ML.Legacy.Transforms.MissingValuesRowDropper.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.MissingValueSubstitutor input, Microsoft.ML.Legacy.Transforms.MissingValueSubstitutor.Output output) - { - _jsonNodes.Add(Serialize("Transforms.MissingValueSubstitutor", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.MissingValuesRowDropper input, Microsoft.ML.Legacy.Transforms.MissingValuesRowDropper.Output output) + { + _jsonNodes.Add(Serialize("Transforms.MissingValuesRowDropper", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.ModelCombiner.Output Add(Microsoft.ML.Legacy.Transforms.ModelCombiner input) - { - var output = new Microsoft.ML.Legacy.Transforms.ModelCombiner.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.MissingValueSubstitutor.Output Add(Microsoft.ML.Legacy.Transforms.MissingValueSubstitutor input) + { + var output = new Microsoft.ML.Legacy.Transforms.MissingValueSubstitutor.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.ModelCombiner input, Microsoft.ML.Legacy.Transforms.ModelCombiner.Output output) - { - _jsonNodes.Add(Serialize("Transforms.ModelCombiner", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.MissingValueSubstitutor input, Microsoft.ML.Legacy.Transforms.MissingValueSubstitutor.Output output) + { + _jsonNodes.Add(Serialize("Transforms.MissingValueSubstitutor", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.NGramTranslator.Output Add(Microsoft.ML.Legacy.Transforms.NGramTranslator input) - { - var output = new 
Microsoft.ML.Legacy.Transforms.NGramTranslator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.ModelCombiner.Output Add(Microsoft.ML.Legacy.Transforms.ModelCombiner input) + { + var output = new Microsoft.ML.Legacy.Transforms.ModelCombiner.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.NGramTranslator input, Microsoft.ML.Legacy.Transforms.NGramTranslator.Output output) - { - _jsonNodes.Add(Serialize("Transforms.NGramTranslator", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.ModelCombiner input, Microsoft.ML.Legacy.Transforms.ModelCombiner.Output output) + { + _jsonNodes.Add(Serialize("Transforms.ModelCombiner", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.NoOperation.Output Add(Microsoft.ML.Legacy.Transforms.NoOperation input) - { - var output = new Microsoft.ML.Legacy.Transforms.NoOperation.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.NGramTranslator.Output Add(Microsoft.ML.Legacy.Transforms.NGramTranslator input) + { + var output = new Microsoft.ML.Legacy.Transforms.NGramTranslator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.NoOperation input, Microsoft.ML.Legacy.Transforms.NoOperation.Output output) - { - _jsonNodes.Add(Serialize("Transforms.NoOperation", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.NGramTranslator input, Microsoft.ML.Legacy.Transforms.NGramTranslator.Output output) + { + _jsonNodes.Add(Serialize("Transforms.NGramTranslator", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.OptionalColumnCreator.Output Add(Microsoft.ML.Legacy.Transforms.OptionalColumnCreator input) - { - var output = new Microsoft.ML.Legacy.Transforms.OptionalColumnCreator.Output(); - Add(input, 
output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.NoOperation.Output Add(Microsoft.ML.Legacy.Transforms.NoOperation input) + { + var output = new Microsoft.ML.Legacy.Transforms.NoOperation.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.OptionalColumnCreator input, Microsoft.ML.Legacy.Transforms.OptionalColumnCreator.Output output) - { - _jsonNodes.Add(Serialize("Transforms.OptionalColumnCreator", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.NoOperation input, Microsoft.ML.Legacy.Transforms.NoOperation.Output output) + { + _jsonNodes.Add(Serialize("Transforms.NoOperation", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.PcaCalculator.Output Add(Microsoft.ML.Legacy.Transforms.PcaCalculator input) - { - var output = new Microsoft.ML.Legacy.Transforms.PcaCalculator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.OptionalColumnCreator.Output Add(Microsoft.ML.Legacy.Transforms.OptionalColumnCreator input) + { + var output = new Microsoft.ML.Legacy.Transforms.OptionalColumnCreator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.PcaCalculator input, Microsoft.ML.Legacy.Transforms.PcaCalculator.Output output) - { - _jsonNodes.Add(Serialize("Transforms.PcaCalculator", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.OptionalColumnCreator input, Microsoft.ML.Legacy.Transforms.OptionalColumnCreator.Output output) + { + _jsonNodes.Add(Serialize("Transforms.OptionalColumnCreator", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.PredictedLabelColumnOriginalValueConverter.Output Add(Microsoft.ML.Legacy.Transforms.PredictedLabelColumnOriginalValueConverter input) - { - var output = new 
Microsoft.ML.Legacy.Transforms.PredictedLabelColumnOriginalValueConverter.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.PcaCalculator.Output Add(Microsoft.ML.Legacy.Transforms.PcaCalculator input) + { + var output = new Microsoft.ML.Legacy.Transforms.PcaCalculator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.PredictedLabelColumnOriginalValueConverter input, Microsoft.ML.Legacy.Transforms.PredictedLabelColumnOriginalValueConverter.Output output) - { - _jsonNodes.Add(Serialize("Transforms.PredictedLabelColumnOriginalValueConverter", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.PcaCalculator input, Microsoft.ML.Legacy.Transforms.PcaCalculator.Output output) + { + _jsonNodes.Add(Serialize("Transforms.PcaCalculator", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.RandomNumberGenerator.Output Add(Microsoft.ML.Legacy.Transforms.RandomNumberGenerator input) - { - var output = new Microsoft.ML.Legacy.Transforms.RandomNumberGenerator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.PredictedLabelColumnOriginalValueConverter.Output Add(Microsoft.ML.Legacy.Transforms.PredictedLabelColumnOriginalValueConverter input) + { + var output = new Microsoft.ML.Legacy.Transforms.PredictedLabelColumnOriginalValueConverter.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.RandomNumberGenerator input, Microsoft.ML.Legacy.Transforms.RandomNumberGenerator.Output output) - { - _jsonNodes.Add(Serialize("Transforms.RandomNumberGenerator", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.PredictedLabelColumnOriginalValueConverter input, Microsoft.ML.Legacy.Transforms.PredictedLabelColumnOriginalValueConverter.Output output) + { + 
_jsonNodes.Add(Serialize("Transforms.PredictedLabelColumnOriginalValueConverter", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.RowRangeFilter.Output Add(Microsoft.ML.Legacy.Transforms.RowRangeFilter input) - { - var output = new Microsoft.ML.Legacy.Transforms.RowRangeFilter.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.RandomNumberGenerator.Output Add(Microsoft.ML.Legacy.Transforms.RandomNumberGenerator input) + { + var output = new Microsoft.ML.Legacy.Transforms.RandomNumberGenerator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.RowRangeFilter input, Microsoft.ML.Legacy.Transforms.RowRangeFilter.Output output) - { - _jsonNodes.Add(Serialize("Transforms.RowRangeFilter", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.RandomNumberGenerator input, Microsoft.ML.Legacy.Transforms.RandomNumberGenerator.Output output) + { + _jsonNodes.Add(Serialize("Transforms.RandomNumberGenerator", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.RowSkipAndTakeFilter.Output Add(Microsoft.ML.Legacy.Transforms.RowSkipAndTakeFilter input) - { - var output = new Microsoft.ML.Legacy.Transforms.RowSkipAndTakeFilter.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.RowRangeFilter.Output Add(Microsoft.ML.Legacy.Transforms.RowRangeFilter input) + { + var output = new Microsoft.ML.Legacy.Transforms.RowRangeFilter.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.RowSkipAndTakeFilter input, Microsoft.ML.Legacy.Transforms.RowSkipAndTakeFilter.Output output) - { - _jsonNodes.Add(Serialize("Transforms.RowSkipAndTakeFilter", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.RowRangeFilter input, 
Microsoft.ML.Legacy.Transforms.RowRangeFilter.Output output) + { + _jsonNodes.Add(Serialize("Transforms.RowRangeFilter", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.RowSkipFilter.Output Add(Microsoft.ML.Legacy.Transforms.RowSkipFilter input) - { - var output = new Microsoft.ML.Legacy.Transforms.RowSkipFilter.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.RowSkipAndTakeFilter.Output Add(Microsoft.ML.Legacy.Transforms.RowSkipAndTakeFilter input) + { + var output = new Microsoft.ML.Legacy.Transforms.RowSkipAndTakeFilter.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.RowSkipFilter input, Microsoft.ML.Legacy.Transforms.RowSkipFilter.Output output) - { - _jsonNodes.Add(Serialize("Transforms.RowSkipFilter", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.RowSkipAndTakeFilter input, Microsoft.ML.Legacy.Transforms.RowSkipAndTakeFilter.Output output) + { + _jsonNodes.Add(Serialize("Transforms.RowSkipAndTakeFilter", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.RowTakeFilter.Output Add(Microsoft.ML.Legacy.Transforms.RowTakeFilter input) - { - var output = new Microsoft.ML.Legacy.Transforms.RowTakeFilter.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.RowSkipFilter.Output Add(Microsoft.ML.Legacy.Transforms.RowSkipFilter input) + { + var output = new Microsoft.ML.Legacy.Transforms.RowSkipFilter.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.RowTakeFilter input, Microsoft.ML.Legacy.Transforms.RowTakeFilter.Output output) - { - _jsonNodes.Add(Serialize("Transforms.RowTakeFilter", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.RowSkipFilter input, Microsoft.ML.Legacy.Transforms.RowSkipFilter.Output output) 
+ { + _jsonNodes.Add(Serialize("Transforms.RowSkipFilter", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.ScoreColumnSelector.Output Add(Microsoft.ML.Legacy.Transforms.ScoreColumnSelector input) - { - var output = new Microsoft.ML.Legacy.Transforms.ScoreColumnSelector.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.RowTakeFilter.Output Add(Microsoft.ML.Legacy.Transforms.RowTakeFilter input) + { + var output = new Microsoft.ML.Legacy.Transforms.RowTakeFilter.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.ScoreColumnSelector input, Microsoft.ML.Legacy.Transforms.ScoreColumnSelector.Output output) - { - _jsonNodes.Add(Serialize("Transforms.ScoreColumnSelector", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.RowTakeFilter input, Microsoft.ML.Legacy.Transforms.RowTakeFilter.Output output) + { + _jsonNodes.Add(Serialize("Transforms.RowTakeFilter", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.Scorer.Output Add(Microsoft.ML.Legacy.Transforms.Scorer input) - { - var output = new Microsoft.ML.Legacy.Transforms.Scorer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.ScoreColumnSelector.Output Add(Microsoft.ML.Legacy.Transforms.ScoreColumnSelector input) + { + var output = new Microsoft.ML.Legacy.Transforms.ScoreColumnSelector.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.Scorer input, Microsoft.ML.Legacy.Transforms.Scorer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.Scorer", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.ScoreColumnSelector input, Microsoft.ML.Legacy.Transforms.ScoreColumnSelector.Output output) + { + _jsonNodes.Add(Serialize("Transforms.ScoreColumnSelector", input, 
output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.Segregator.Output Add(Microsoft.ML.Legacy.Transforms.Segregator input) - { - var output = new Microsoft.ML.Legacy.Transforms.Segregator.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.Scorer.Output Add(Microsoft.ML.Legacy.Transforms.Scorer input) + { + var output = new Microsoft.ML.Legacy.Transforms.Scorer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.Segregator input, Microsoft.ML.Legacy.Transforms.Segregator.Output output) - { - _jsonNodes.Add(Serialize("Transforms.Segregator", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.Scorer input, Microsoft.ML.Legacy.Transforms.Scorer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.Scorer", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.SentimentAnalyzer.Output Add(Microsoft.ML.Legacy.Transforms.SentimentAnalyzer input) - { - var output = new Microsoft.ML.Legacy.Transforms.SentimentAnalyzer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.Segregator.Output Add(Microsoft.ML.Legacy.Transforms.Segregator input) + { + var output = new Microsoft.ML.Legacy.Transforms.Segregator.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.SentimentAnalyzer input, Microsoft.ML.Legacy.Transforms.SentimentAnalyzer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.SentimentAnalyzer", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.Segregator input, Microsoft.ML.Legacy.Transforms.Segregator.Output output) + { + _jsonNodes.Add(Serialize("Transforms.Segregator", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.TensorFlowScorer.Output Add(Microsoft.ML.Legacy.Transforms.TensorFlowScorer input) - { - 
var output = new Microsoft.ML.Legacy.Transforms.TensorFlowScorer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.SentimentAnalyzer.Output Add(Microsoft.ML.Legacy.Transforms.SentimentAnalyzer input) + { + var output = new Microsoft.ML.Legacy.Transforms.SentimentAnalyzer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.TensorFlowScorer input, Microsoft.ML.Legacy.Transforms.TensorFlowScorer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.TensorFlowScorer", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.SentimentAnalyzer input, Microsoft.ML.Legacy.Transforms.SentimentAnalyzer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.SentimentAnalyzer", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.TextFeaturizer.Output Add(Microsoft.ML.Legacy.Transforms.TextFeaturizer input) - { - var output = new Microsoft.ML.Legacy.Transforms.TextFeaturizer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.TensorFlowScorer.Output Add(Microsoft.ML.Legacy.Transforms.TensorFlowScorer input) + { + var output = new Microsoft.ML.Legacy.Transforms.TensorFlowScorer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.TextFeaturizer input, Microsoft.ML.Legacy.Transforms.TextFeaturizer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.TextFeaturizer", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.TensorFlowScorer input, Microsoft.ML.Legacy.Transforms.TensorFlowScorer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.TensorFlowScorer", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.TextToKeyConverter.Output Add(Microsoft.ML.Legacy.Transforms.TextToKeyConverter input) - { - var output = new 
Microsoft.ML.Legacy.Transforms.TextToKeyConverter.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.TextFeaturizer.Output Add(Microsoft.ML.Legacy.Transforms.TextFeaturizer input) + { + var output = new Microsoft.ML.Legacy.Transforms.TextFeaturizer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.TextToKeyConverter input, Microsoft.ML.Legacy.Transforms.TextToKeyConverter.Output output) - { - _jsonNodes.Add(Serialize("Transforms.TextToKeyConverter", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.TextFeaturizer input, Microsoft.ML.Legacy.Transforms.TextFeaturizer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.TextFeaturizer", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.TrainTestDatasetSplitter.Output Add(Microsoft.ML.Legacy.Transforms.TrainTestDatasetSplitter input) - { - var output = new Microsoft.ML.Legacy.Transforms.TrainTestDatasetSplitter.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.TextToKeyConverter.Output Add(Microsoft.ML.Legacy.Transforms.TextToKeyConverter input) + { + var output = new Microsoft.ML.Legacy.Transforms.TextToKeyConverter.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.TrainTestDatasetSplitter input, Microsoft.ML.Legacy.Transforms.TrainTestDatasetSplitter.Output output) - { - _jsonNodes.Add(Serialize("Transforms.TrainTestDatasetSplitter", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.TextToKeyConverter input, Microsoft.ML.Legacy.Transforms.TextToKeyConverter.Output output) + { + _jsonNodes.Add(Serialize("Transforms.TextToKeyConverter", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.TreeLeafFeaturizer.Output Add(Microsoft.ML.Legacy.Transforms.TreeLeafFeaturizer 
input) - { - var output = new Microsoft.ML.Legacy.Transforms.TreeLeafFeaturizer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.TrainTestDatasetSplitter.Output Add(Microsoft.ML.Legacy.Transforms.TrainTestDatasetSplitter input) + { + var output = new Microsoft.ML.Legacy.Transforms.TrainTestDatasetSplitter.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.TreeLeafFeaturizer input, Microsoft.ML.Legacy.Transforms.TreeLeafFeaturizer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.TreeLeafFeaturizer", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.TrainTestDatasetSplitter input, Microsoft.ML.Legacy.Transforms.TrainTestDatasetSplitter.Output output) + { + _jsonNodes.Add(Serialize("Transforms.TrainTestDatasetSplitter", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.TwoHeterogeneousModelCombiner.Output Add(Microsoft.ML.Legacy.Transforms.TwoHeterogeneousModelCombiner input) - { - var output = new Microsoft.ML.Legacy.Transforms.TwoHeterogeneousModelCombiner.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.TreeLeafFeaturizer.Output Add(Microsoft.ML.Legacy.Transforms.TreeLeafFeaturizer input) + { + var output = new Microsoft.ML.Legacy.Transforms.TreeLeafFeaturizer.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.TwoHeterogeneousModelCombiner input, Microsoft.ML.Legacy.Transforms.TwoHeterogeneousModelCombiner.Output output) - { - _jsonNodes.Add(Serialize("Transforms.TwoHeterogeneousModelCombiner", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.TreeLeafFeaturizer input, Microsoft.ML.Legacy.Transforms.TreeLeafFeaturizer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.TreeLeafFeaturizer", input, output)); + } - [Obsolete] - 
public Microsoft.ML.Legacy.Transforms.VectorToImage.Output Add(Microsoft.ML.Legacy.Transforms.VectorToImage input) - { - var output = new Microsoft.ML.Legacy.Transforms.VectorToImage.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.TwoHeterogeneousModelCombiner.Output Add(Microsoft.ML.Legacy.Transforms.TwoHeterogeneousModelCombiner input) + { + var output = new Microsoft.ML.Legacy.Transforms.TwoHeterogeneousModelCombiner.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.VectorToImage input, Microsoft.ML.Legacy.Transforms.VectorToImage.Output output) - { - _jsonNodes.Add(Serialize("Transforms.VectorToImage", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.TwoHeterogeneousModelCombiner input, Microsoft.ML.Legacy.Transforms.TwoHeterogeneousModelCombiner.Output output) + { + _jsonNodes.Add(Serialize("Transforms.TwoHeterogeneousModelCombiner", input, output)); + } - [Obsolete] - public Microsoft.ML.Legacy.Transforms.WordEmbeddings.Output Add(Microsoft.ML.Legacy.Transforms.WordEmbeddings input) - { - var output = new Microsoft.ML.Legacy.Transforms.WordEmbeddings.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.VectorToImage.Output Add(Microsoft.ML.Legacy.Transforms.VectorToImage input) + { + var output = new Microsoft.ML.Legacy.Transforms.VectorToImage.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.WordEmbeddings input, Microsoft.ML.Legacy.Transforms.WordEmbeddings.Output output) - { - _jsonNodes.Add(Serialize("Transforms.WordEmbeddings", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.VectorToImage input, Microsoft.ML.Legacy.Transforms.VectorToImage.Output output) + { + _jsonNodes.Add(Serialize("Transforms.VectorToImage", input, output)); + } - [Obsolete] - 
public Microsoft.ML.Legacy.Transforms.WordTokenizer.Output Add(Microsoft.ML.Legacy.Transforms.WordTokenizer input) - { - var output = new Microsoft.ML.Legacy.Transforms.WordTokenizer.Output(); - Add(input, output); - return output; - } + [Obsolete] + public Microsoft.ML.Legacy.Transforms.WordEmbeddings.Output Add(Microsoft.ML.Legacy.Transforms.WordEmbeddings input) + { + var output = new Microsoft.ML.Legacy.Transforms.WordEmbeddings.Output(); + Add(input, output); + return output; + } - [Obsolete] - public void Add(Microsoft.ML.Legacy.Transforms.WordTokenizer input, Microsoft.ML.Legacy.Transforms.WordTokenizer.Output output) - { - _jsonNodes.Add(Serialize("Transforms.WordTokenizer", input, output)); - } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.WordEmbeddings input, Microsoft.ML.Legacy.Transforms.WordEmbeddings.Output output) + { + _jsonNodes.Add(Serialize("Transforms.WordEmbeddings", input, output)); + } + + [Obsolete] + public Microsoft.ML.Legacy.Transforms.WordTokenizer.Output Add(Microsoft.ML.Legacy.Transforms.WordTokenizer input) + { + var output = new Microsoft.ML.Legacy.Transforms.WordTokenizer.Output(); + Add(input, output); + return output; + } + [Obsolete] + public void Add(Microsoft.ML.Legacy.Transforms.WordTokenizer input, Microsoft.ML.Legacy.Transforms.WordTokenizer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.WordTokenizer", input, output)); } + } namespace Legacy.Data { @@ -1976,7 +1945,7 @@ namespace Legacy.Data /// /// Import a dataset from a text file /// - [Obsolete("Use TextLoader instead.")] + [Obsolete] public sealed partial class CustomTextLoader { @@ -1985,7 +1954,7 @@ public sealed partial class CustomTextLoader /// Location of the input file /// [Obsolete] - public Var InputFile { get; set; } = new Var(); + public Var InputFile { get; set; } = new Var(); /// /// Custom schema to use for parsing @@ -2000,7 +1969,7 @@ public sealed class Output /// /// The resulting data view /// - public Var Data { get; set; 
} = new Var(); + public Var Data { get; set; } = new Var(); } } @@ -2021,7 +1990,7 @@ public sealed partial class DataViewReference /// Pointer to IDataView in memory /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] @@ -2030,7 +1999,7 @@ public sealed class Output /// /// The resulting data view /// - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); } } @@ -2051,7 +2020,7 @@ public sealed partial class IDataViewArrayConverter /// The data sets /// [Obsolete] - public ArrayVar Data { get; set; } = new ArrayVar(); + public ArrayVar Data { get; set; } = new ArrayVar(); [Obsolete] @@ -2060,7 +2029,7 @@ public sealed class Output /// /// The data set array /// - public ArrayVar OutputData { get; set; } = new ArrayVar(); + public ArrayVar OutputData { get; set; } = new ArrayVar(); } } @@ -2070,7 +2039,7 @@ namespace Legacy.Data { /// - /// Create an array variable of IPredictorModel + /// Create an array variable of PredictorModel /// [Obsolete] public sealed partial class PredictorModelArrayConverter @@ -2081,7 +2050,7 @@ public sealed partial class PredictorModelArrayConverter /// The models /// [Obsolete] - public ArrayVar Model { get; set; } = new ArrayVar(); + public ArrayVar Model { get; set; } = new ArrayVar(); [Obsolete] @@ -2090,7 +2059,7 @@ public sealed class Output /// /// The model array /// - public ArrayVar OutputModel { get; set; } = new ArrayVar(); + public ArrayVar OutputModel { get; set; } = new ArrayVar(); } } @@ -2329,14 +2298,14 @@ public TextLoaderPipelineStep (Output output) } public Var Data { get; } - public Var Model { get; } + public Var Model { get; } } /// /// Location of the input file /// [Obsolete] - public Var InputFile { get; set; } = new Var(); + public Var InputFile { get; set; } = new Var(); /// /// Arguments @@ -2351,37 +2320,7 @@ public sealed class Output /// /// The resulting data view /// - public Var Data { get; set; } = 
new Var(); - - } - } - } - - namespace Legacy.Data - { - - /// - /// Create an array variable of ITransformModel - /// - [Obsolete] - public sealed partial class TransformModelArrayConverter - { - - - /// - /// The models - /// - [Obsolete] - public ArrayVar TransformModel { get; set; } = new ArrayVar(); - - - [Obsolete] - public sealed class Output - { - /// - /// The model array - /// - public ArrayVar OutputModel { get; set; } = new ArrayVar(); + public Var Data { get; set; } = new Var(); } } @@ -2394,7 +2333,7 @@ namespace Legacy.Models /// Evaluates an anomaly detection scored dataset. /// [Obsolete] - public sealed partial class AnomalyDetectionEvaluator : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IEvaluatorInput + public sealed partial class AnomalyDetectionEvaluator : Microsoft.ML.EntryPoints.CommonInputs.IEvaluatorInput { @@ -2456,7 +2395,7 @@ public sealed partial class AnomalyDetectionEvaluator : Microsoft.ML.Runtime.Ent /// The data to be used for evaluation. /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); /// /// Name column name. 
@@ -2466,22 +2405,22 @@ public sealed partial class AnomalyDetectionEvaluator : Microsoft.ML.Runtime.Ent [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IEvaluatorOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IEvaluatorOutput { /// /// Warning dataset /// - public Var Warnings { get; set; } = new Var(); + public Var Warnings { get; set; } = new Var(); /// /// Overall metrics dataset /// - public Var OverallMetrics { get; set; } = new Var(); + public Var OverallMetrics { get; set; } = new Var(); /// /// Per instance metrics dataset /// - public Var PerInstanceMetrics { get; set; } = new Var(); + public Var PerInstanceMetrics { get; set; } = new Var(); } } @@ -2515,16 +2454,16 @@ public sealed partial class AnomalyPipelineEnsemble /// The models to combine into an ensemble /// [Obsolete] - public ArrayVar Models { get; set; } = new ArrayVar(); + public ArrayVar Models { get; set; } = new ArrayVar(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IAnomalyDetectionOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IAnomalyDetectionOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } } @@ -2537,7 +2476,7 @@ namespace Legacy.Models /// Evaluates a binary classification scored dataset. /// [Obsolete] - public sealed partial class BinaryClassificationEvaluator : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IEvaluatorInput + public sealed partial class BinaryClassificationEvaluator : Microsoft.ML.EntryPoints.CommonInputs.IEvaluatorInput { @@ -2605,7 +2544,7 @@ public sealed partial class BinaryClassificationEvaluator : Microsoft.ML.Runtime /// The data to be used for evaluation. 
/// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); /// /// Name column name. @@ -2615,129 +2554,27 @@ public sealed partial class BinaryClassificationEvaluator : Microsoft.ML.Runtime [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IClassificationEvaluatorOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IEvaluatorOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IClassificationEvaluatorOutput, Microsoft.ML.EntryPoints.CommonOutputs.IEvaluatorOutput { /// /// Confusion matrix dataset /// - public Var ConfusionMatrix { get; set; } = new Var(); - - /// - /// Warning dataset - /// - public Var Warnings { get; set; } = new Var(); - - /// - /// Overall metrics dataset - /// - public Var OverallMetrics { get; set; } = new Var(); - - /// - /// Per instance metrics dataset - /// - public Var PerInstanceMetrics { get; set; } = new Var(); - - } - } - } - - namespace Legacy.Models - { - - [Obsolete] - public sealed partial class CrossValidationBinaryMacroSubGraphInput - { - /// - /// The data to be used for training - /// - [Obsolete] - public Var Data { get; set; } = new Var(); - - } - - [Obsolete] - public sealed partial class CrossValidationBinaryMacroSubGraphOutput - { - /// - /// The model - /// - [Obsolete] - public Var Model { get; set; } = new Var(); - - } - - /// - /// Cross validation for binary classification - /// - [Obsolete] - public sealed partial class BinaryCrossValidator - { - - - /// - /// The data set - /// - [Obsolete] - public Var Data { get; set; } = new Var(); - - /// - /// The training subgraph - /// - [Obsolete] - public Experiment Nodes { get; set; } - - /// - /// The training subgraph inputs - /// - [Obsolete] - public CrossValidationBinaryMacroSubGraphInput Inputs { get; set; } = new CrossValidationBinaryMacroSubGraphInput(); - - /// - /// The training subgraph outputs - /// - [Obsolete] - public 
CrossValidationBinaryMacroSubGraphOutput Outputs { get; set; } = new CrossValidationBinaryMacroSubGraphOutput(); - - /// - /// Column to use for stratification - /// - [Obsolete] - public string StratificationColumn { get; set; } - - /// - /// Number of folds in k-fold cross-validation - /// - [Obsolete] - public int NumFolds { get; set; } = 2; - - - [Obsolete] - public sealed class Output - { - /// - /// The trained model - /// - public ArrayVar PredictorModel { get; set; } = new ArrayVar(); + public Var ConfusionMatrix { get; set; } = new Var(); /// /// Warning dataset /// - public ArrayVar Warnings { get; set; } = new ArrayVar(); + public Var Warnings { get; set; } = new Var(); /// /// Overall metrics dataset /// - public ArrayVar OverallMetrics { get; set; } = new ArrayVar(); + public Var OverallMetrics { get; set; } = new Var(); /// /// Per instance metrics dataset /// - public ArrayVar PerInstanceMetrics { get; set; } = new ArrayVar(); - - /// - /// Confusion matrix dataset - /// - public ArrayVar ConfusionMatrix { get; set; } = new ArrayVar(); + public Var PerInstanceMetrics { get; set; } = new Var(); } } @@ -2772,7 +2609,7 @@ public sealed partial class BinaryEnsemble /// The models to combine into an ensemble /// [Obsolete] - public ArrayVar Models { get; set; } = new ArrayVar(); + public ArrayVar Models { get; set; } = new ArrayVar(); /// /// Whether to validate that all the pipelines are identical @@ -2782,12 +2619,12 @@ public sealed partial class BinaryEnsemble [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } } @@ -2814,16 +2651,16 
@@ public sealed partial class BinaryPipelineEnsemble /// The models to combine into an ensemble /// [Obsolete] - public ArrayVar Models { get; set; } = new ArrayVar(); + public ArrayVar Models { get; set; } = new ArrayVar(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } } @@ -2836,7 +2673,7 @@ namespace Legacy.Models /// Evaluates a multi class classification scored dataset. /// [Obsolete] - public sealed partial class ClassificationEvaluator : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IEvaluatorInput + public sealed partial class ClassificationEvaluator : Microsoft.ML.EntryPoints.CommonInputs.IEvaluatorInput { @@ -2892,7 +2729,7 @@ public sealed partial class ClassificationEvaluator : Microsoft.ML.Runtime.Entry /// The data to be used for evaluation. /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); /// /// Name column name. 
@@ -2902,27 +2739,27 @@ public sealed partial class ClassificationEvaluator : Microsoft.ML.Runtime.Entry [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IClassificationEvaluatorOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IEvaluatorOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IClassificationEvaluatorOutput, Microsoft.ML.EntryPoints.CommonOutputs.IEvaluatorOutput { /// /// Confusion matrix dataset /// - public Var ConfusionMatrix { get; set; } = new Var(); + public Var ConfusionMatrix { get; set; } = new Var(); /// /// Warning dataset /// - public Var Warnings { get; set; } = new Var(); + public Var Warnings { get; set; } = new Var(); /// /// Overall metrics dataset /// - public Var OverallMetrics { get; set; } = new Var(); + public Var OverallMetrics { get; set; } = new Var(); /// /// Per instance metrics dataset /// - public Var PerInstanceMetrics { get; set; } = new Var(); + public Var PerInstanceMetrics { get; set; } = new Var(); } } @@ -2935,7 +2772,7 @@ namespace Legacy.Models /// Evaluates a clustering scored dataset. /// [Obsolete] - public sealed partial class ClusterEvaluator : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IEvaluatorInput + public sealed partial class ClusterEvaluator : Microsoft.ML.EntryPoints.CommonInputs.IEvaluatorInput { @@ -2985,7 +2822,7 @@ public sealed partial class ClusterEvaluator : Microsoft.ML.Runtime.EntryPoints. /// The data to be used for evaluation. /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); /// /// Name column name. @@ -2995,22 +2832,22 @@ public sealed partial class ClusterEvaluator : Microsoft.ML.Runtime.EntryPoints. 
[Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IEvaluatorOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IEvaluatorOutput { /// /// Warning dataset /// - public Var Warnings { get; set; } = new Var(); + public Var Warnings { get; set; } = new Var(); /// /// Overall metrics dataset /// - public Var OverallMetrics { get; set; } = new Var(); + public Var OverallMetrics { get; set; } = new Var(); /// /// Per instance metrics dataset /// - public Var PerInstanceMetrics { get; set; } = new Var(); + public Var PerInstanceMetrics { get; set; } = new Var(); } } @@ -3043,25 +2880,25 @@ public sealed partial class CrossValidationResultsCombiner /// Overall metrics datasets /// [Obsolete] - public ArrayVar OverallMetrics { get; set; } = new ArrayVar(); + public ArrayVar OverallMetrics { get; set; } = new ArrayVar(); /// /// Per instance metrics datasets /// [Obsolete] - public ArrayVar PerInstanceMetrics { get; set; } = new ArrayVar(); + public ArrayVar PerInstanceMetrics { get; set; } = new ArrayVar(); /// /// Confusion matrix datasets /// [Obsolete] - public ArrayVar ConfusionMatrix { get; set; } = new ArrayVar(); + public ArrayVar ConfusionMatrix { get; set; } = new ArrayVar(); /// /// Warning datasets /// [Obsolete] - public ArrayVar Warnings { get; set; } = new ArrayVar(); + public ArrayVar Warnings { get; set; } = new ArrayVar(); /// /// The label column name @@ -3073,19 +2910,19 @@ public sealed partial class CrossValidationResultsCombiner /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for grouping /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional GroupColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional GroupColumn { get; set; } /// /// Name column name /// [Obsolete] - public 
Microsoft.ML.Runtime.EntryPoints.Optional NameColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional NameColumn { get; set; } /// /// Specifies the trainer kind, which determines the evaluator to be used. @@ -3100,22 +2937,22 @@ public sealed class Output /// /// Warning dataset /// - public Var Warnings { get; set; } = new Var(); + public Var Warnings { get; set; } = new Var(); /// /// Overall metrics dataset /// - public Var OverallMetrics { get; set; } = new Var(); + public Var OverallMetrics { get; set; } = new Var(); /// /// Per instance metrics dataset /// - public Var PerInstanceMetrics { get; set; } = new Var(); + public Var PerInstanceMetrics { get; set; } = new Var(); /// /// Confusion matrix dataset /// - public Var ConfusionMatrix { get; set; } = new Var(); + public Var ConfusionMatrix { get; set; } = new Var(); } } @@ -3131,7 +2968,7 @@ public sealed partial class CrossValidationMacroSubGraphInput /// The data to be used for training /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); } @@ -3142,13 +2979,7 @@ public sealed partial class CrossValidationMacroSubGraphOutput /// The predictor model /// [Obsolete] - public Var PredictorModel { get; set; } = new Var(); - - /// - /// The transform model - /// - [Obsolete] - public Var TransformModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } @@ -3164,13 +2995,13 @@ public sealed partial class CrossValidator /// The data set /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); /// /// The transform model from the pipeline before this command. It gets included in the Output.PredictorModel. 
/// [Obsolete] - public Var TransformModel { get; set; } = new Var(); + public Var TransformModel { get; set; } = new Var(); /// /// The training subgraph @@ -3218,19 +3049,19 @@ public sealed partial class CrossValidator /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for grouping /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional GroupColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional GroupColumn { get; set; } /// /// Name column name /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional NameColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional NameColumn { get; set; } [Obsolete] @@ -3239,32 +3070,27 @@ public sealed class Output /// /// The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel. /// - public ArrayVar PredictorModel { get; set; } = new ArrayVar(); - - /// - /// The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel. 
- /// - public ArrayVar TransformModel { get; set; } = new ArrayVar(); + public ArrayVar PredictorModel { get; set; } = new ArrayVar(); /// /// Warning dataset /// - public Var Warnings { get; set; } = new Var(); + public Var Warnings { get; set; } = new Var(); /// /// Overall metrics dataset /// - public Var OverallMetrics { get; set; } = new Var(); + public Var OverallMetrics { get; set; } = new Var(); /// /// Per instance metrics dataset /// - public Var PerInstanceMetrics { get; set; } = new Var(); + public Var PerInstanceMetrics { get; set; } = new Var(); /// /// Confusion matrix dataset /// - public Var ConfusionMatrix { get; set; } = new Var(); + public Var ConfusionMatrix { get; set; } = new Var(); } } @@ -3285,7 +3111,7 @@ public sealed partial class CrossValidatorDatasetSplitter /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); /// /// Number of folds to split into @@ -3306,12 +3132,12 @@ public sealed class Output /// /// Training data (one dataset per fold) /// - public ArrayVar TrainData { get; set; } = new ArrayVar(); + public ArrayVar TrainData { get; set; } = new ArrayVar(); /// /// Testing data (one dataset per fold) /// - public ArrayVar TestData { get; set; } = new ArrayVar(); + public ArrayVar TestData { get; set; } = new ArrayVar(); } } @@ -3324,7 +3150,7 @@ namespace Legacy.Models /// Applies a TransformModel to a dataset. 
/// [Obsolete] - public sealed partial class DatasetTransformer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class DatasetTransformer : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -3332,13 +3158,13 @@ public sealed partial class DatasetTransformer : Microsoft.ML.Runtime.EntryPoint /// Transform model /// [Obsolete] - public Var TransformModel { get; set; } = new Var(); + public Var TransformModel { get; set; } = new Var(); /// /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] @@ -3347,7 +3173,7 @@ public sealed class Output /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); } [Obsolete] @@ -3381,7 +3207,7 @@ public DatasetTransformerPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -3401,7 +3227,7 @@ public sealed partial class EnsembleSummary /// The predictor to summarize /// [Obsolete] - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); [Obsolete] @@ -3410,12 +3236,12 @@ public sealed class Output /// /// The summaries of the individual predictors /// - public ArrayVar Summaries { get; set; } = new ArrayVar(); + public ArrayVar Summaries { get; set; } = new ArrayVar(); /// /// The model statistics of the individual predictors /// - public ArrayVar Stats { get; set; } = new ArrayVar(); + public ArrayVar Stats { get; set; } = new ArrayVar(); } } @@ -3428,7 +3254,7 @@ namespace Legacy.Models /// Apply a Platt calibrator with a fixed slope and offset to an input model /// [Obsolete] - public sealed partial class FixedPlattCalibrator : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ICalibratorInput, 
Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class FixedPlattCalibrator : Microsoft.ML.EntryPoints.CommonInputs.ICalibratorInput, Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -3448,7 +3274,7 @@ public sealed partial class FixedPlattCalibrator : Microsoft.ML.Runtime.EntryPoi /// The predictor to calibrate /// [Obsolete] - public Var UncalibratedPredictorModel { get; set; } = new Var(); + public Var UncalibratedPredictorModel { get; set; } = new Var(); /// /// The maximum number of examples to train the calibrator on @@ -3461,16 +3287,16 @@ public sealed partial class FixedPlattCalibrator : Microsoft.ML.Runtime.EntryPoi /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ICalibratorOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ICalibratorOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -3502,7 +3328,7 @@ public FixedPlattCalibratorPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -3528,16 +3354,16 @@ public sealed partial class MultiClassPipelineEnsemble /// The models to combine into an ensemble /// [Obsolete] - public ArrayVar Models { get; set; } = new ArrayVar(); + public ArrayVar Models { get; set; } = new ArrayVar(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IMulticlassClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : 
Microsoft.ML.EntryPoints.CommonOutputs.IMulticlassClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } } @@ -3550,7 +3376,7 @@ namespace Legacy.Models /// Evaluates a multi output regression scored dataset. /// [Obsolete] - public sealed partial class MultiOutputRegressionEvaluator : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IEvaluatorInput + public sealed partial class MultiOutputRegressionEvaluator : Microsoft.ML.EntryPoints.CommonInputs.IEvaluatorInput { @@ -3595,7 +3421,7 @@ public sealed partial class MultiOutputRegressionEvaluator : Microsoft.ML.Runtim /// The data to be used for evaluation. /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); /// /// Name column name. @@ -3605,22 +3431,22 @@ public sealed partial class MultiOutputRegressionEvaluator : Microsoft.ML.Runtim [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IEvaluatorOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IEvaluatorOutput { /// /// Warning dataset /// - public Var Warnings { get; set; } = new Var(); + public Var Warnings { get; set; } = new Var(); /// /// Overall metrics dataset /// - public Var OverallMetrics { get; set; } = new Var(); + public Var OverallMetrics { get; set; } = new Var(); /// /// Per instance metrics dataset /// - public Var PerInstanceMetrics { get; set; } = new Var(); + public Var PerInstanceMetrics { get; set; } = new Var(); } } @@ -3633,7 +3459,7 @@ namespace Legacy.Models /// Apply a Naive calibrator to an input model /// [Obsolete] - public sealed partial class NaiveCalibrator : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ICalibratorInput, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class 
NaiveCalibrator : Microsoft.ML.EntryPoints.CommonInputs.ICalibratorInput, Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -3641,7 +3467,7 @@ public sealed partial class NaiveCalibrator : Microsoft.ML.Runtime.EntryPoints.C /// The predictor to calibrate /// [Obsolete] - public Var UncalibratedPredictorModel { get; set; } = new Var(); + public Var UncalibratedPredictorModel { get; set; } = new Var(); /// /// The maximum number of examples to train the calibrator on @@ -3654,16 +3480,16 @@ public sealed partial class NaiveCalibrator : Microsoft.ML.Runtime.EntryPoints.C /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ICalibratorOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ICalibratorOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -3695,7 +3521,7 @@ public NaiveCalibratorPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -3728,13 +3554,13 @@ public sealed partial class OneVersusAllMacroSubGraphOutput /// The predictor model for the subgraph exemplar. 
/// [Obsolete] - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } /// [Obsolete] - public sealed partial class OneVersusAll : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class OneVersusAll : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -3760,7 +3586,7 @@ public sealed partial class OneVersusAll : Microsoft.ML.Runtime.EntryPoints.Comm /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -3772,7 +3598,7 @@ public sealed partial class OneVersusAll : Microsoft.ML.Runtime.EntryPoints.Comm /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -3799,7 +3625,7 @@ public sealed class Output /// /// The trained multiclass model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -3831,7 +3657,7 @@ public OneVersusAllPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -3894,7 +3720,7 @@ public sealed partial class OnnxConverter /// Model that needs to be converted to ONNX format. /// [Obsolete] - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); /// /// The targeted ONNX version. It can be either "Stable" or "Experimental". 
If "Experimental" is used, produced model can contain components that is not officially supported in ONNX standard. @@ -3923,7 +3749,7 @@ namespace Legacy.Models /// Combines a sequence of PredictorModels into a single model /// [Obsolete] - public sealed partial class OvaModelCombiner : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class OvaModelCombiner : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -3931,7 +3757,7 @@ public sealed partial class OvaModelCombiner : Microsoft.ML.Runtime.EntryPoints. /// Input models /// [Obsolete] - public ArrayVar ModelArray { get; set; } = new ArrayVar(); + public ArrayVar ModelArray { get; set; } = new ArrayVar(); /// /// Use probabilities from learners instead of raw values. @@ -3943,7 +3769,7 @@ public sealed partial class OvaModelCombiner : Microsoft.ML.Runtime.EntryPoints. /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -3955,7 +3781,7 @@ public sealed partial class OvaModelCombiner : Microsoft.ML.Runtime.EntryPoints. 
/// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -3982,7 +3808,7 @@ public sealed class Output /// /// Predictor model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -4014,7 +3840,7 @@ public OvaModelCombinerPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -4026,7 +3852,7 @@ namespace Legacy.Models /// Apply a PAV calibrator to an input model /// [Obsolete] - public sealed partial class PAVCalibrator : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ICalibratorInput, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class PAVCalibrator : Microsoft.ML.EntryPoints.CommonInputs.ICalibratorInput, Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -4034,7 +3860,7 @@ public sealed partial class PAVCalibrator : Microsoft.ML.Runtime.EntryPoints.Com /// The predictor to calibrate /// [Obsolete] - public Var UncalibratedPredictorModel { get; set; } = new Var(); + public Var UncalibratedPredictorModel { get; set; } = new Var(); /// /// The maximum number of examples to train the calibrator on @@ -4047,16 +3873,16 @@ public sealed partial class PAVCalibrator : Microsoft.ML.Runtime.EntryPoints.Com /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ICalibratorOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ICalibratorOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var 
PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -4088,7 +3914,7 @@ public PAVCalibratorPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -4100,7 +3926,7 @@ namespace Legacy.Models /// Apply a Platt calibrator to an input model /// [Obsolete] - public sealed partial class PlattCalibrator : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ICalibratorInput, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class PlattCalibrator : Microsoft.ML.EntryPoints.CommonInputs.ICalibratorInput, Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -4108,7 +3934,7 @@ public sealed partial class PlattCalibrator : Microsoft.ML.Runtime.EntryPoints.C /// The predictor to calibrate /// [Obsolete] - public Var UncalibratedPredictorModel { get; set; } = new Var(); + public Var UncalibratedPredictorModel { get; set; } = new Var(); /// /// The maximum number of examples to train the calibrator on @@ -4121,16 +3947,16 @@ public sealed partial class PlattCalibrator : Microsoft.ML.Runtime.EntryPoints.C /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ICalibratorOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ICalibratorOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -4162,7 +3988,7 @@ public PlattCalibratorPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -4174,7 +4000,7 @@ 
namespace Legacy.Models /// Evaluates a quantile regression scored dataset. /// [Obsolete] - public sealed partial class QuantileRegressionEvaluator : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IEvaluatorInput + public sealed partial class QuantileRegressionEvaluator : Microsoft.ML.EntryPoints.CommonInputs.IEvaluatorInput { @@ -4219,7 +4045,7 @@ public sealed partial class QuantileRegressionEvaluator : Microsoft.ML.Runtime.E /// The data to be used for evaluation. /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); /// /// Name column name. @@ -4229,22 +4055,22 @@ public sealed partial class QuantileRegressionEvaluator : Microsoft.ML.Runtime.E [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IEvaluatorOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IEvaluatorOutput { /// /// Warning dataset /// - public Var Warnings { get; set; } = new Var(); + public Var Warnings { get; set; } = new Var(); /// /// Overall metrics dataset /// - public Var OverallMetrics { get; set; } = new Var(); + public Var OverallMetrics { get; set; } = new Var(); /// /// Per instance metrics dataset /// - public Var PerInstanceMetrics { get; set; } = new Var(); + public Var PerInstanceMetrics { get; set; } = new Var(); } } @@ -4257,7 +4083,7 @@ namespace Legacy.Models /// Evaluates a ranking scored dataset. /// [Obsolete] - public sealed partial class RankerEvaluator : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IEvaluatorInput + public sealed partial class RankerEvaluator : Microsoft.ML.EntryPoints.CommonInputs.IEvaluatorInput { @@ -4307,7 +4133,7 @@ public sealed partial class RankerEvaluator : Microsoft.ML.Runtime.EntryPoints.C /// The data to be used for evaluation. /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); /// /// Name column name. 
@@ -4317,22 +4143,22 @@ public sealed partial class RankerEvaluator : Microsoft.ML.Runtime.EntryPoints.C [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IEvaluatorOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IEvaluatorOutput { /// /// Warning dataset /// - public Var Warnings { get; set; } = new Var(); + public Var Warnings { get; set; } = new Var(); /// /// Overall metrics dataset /// - public Var OverallMetrics { get; set; } = new Var(); + public Var OverallMetrics { get; set; } = new Var(); /// /// Per instance metrics dataset /// - public Var PerInstanceMetrics { get; set; } = new Var(); + public Var PerInstanceMetrics { get; set; } = new Var(); } } @@ -4359,7 +4185,7 @@ public sealed partial class RegressionEnsemble /// The models to combine into an ensemble /// [Obsolete] - public ArrayVar Models { get; set; } = new ArrayVar(); + public ArrayVar Models { get; set; } = new ArrayVar(); /// /// Whether to validate that all the pipelines are identical @@ -4369,12 +4195,12 @@ public sealed partial class RegressionEnsemble [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } } @@ -4387,7 +4213,7 @@ namespace Legacy.Models /// Evaluates a regression scored dataset. 
/// [Obsolete] - public sealed partial class RegressionEvaluator : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IEvaluatorInput + public sealed partial class RegressionEvaluator : Microsoft.ML.EntryPoints.CommonInputs.IEvaluatorInput { @@ -4426,7 +4252,7 @@ public sealed partial class RegressionEvaluator : Microsoft.ML.Runtime.EntryPoin /// The data to be used for evaluation. /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); /// /// Name column name. @@ -4436,22 +4262,22 @@ public sealed partial class RegressionEvaluator : Microsoft.ML.Runtime.EntryPoin [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IEvaluatorOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IEvaluatorOutput { /// /// Warning dataset /// - public Var Warnings { get; set; } = new Var(); + public Var Warnings { get; set; } = new Var(); /// /// Overall metrics dataset /// - public Var OverallMetrics { get; set; } = new Var(); + public Var OverallMetrics { get; set; } = new Var(); /// /// Per instance metrics dataset /// - public Var PerInstanceMetrics { get; set; } = new Var(); + public Var PerInstanceMetrics { get; set; } = new Var(); } } @@ -4478,16 +4304,16 @@ public sealed partial class RegressionPipelineEnsemble /// The models to combine into an ensemble /// [Obsolete] - public ArrayVar Models { get; set; } = new ArrayVar(); + public ArrayVar Models { get; set; } = new ArrayVar(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } } @@ -4508,7 +4334,7 @@ public sealed partial 
class Summarizer /// The predictor to summarize /// [Obsolete] - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); [Obsolete] @@ -4517,12 +4343,12 @@ public sealed class Output /// /// The summary of a predictor /// - public Var Summary { get; set; } = new Var(); + public Var Summary { get; set; } = new Var(); /// /// The training set statistics. Note that this output can be null. /// - public Var Stats { get; set; } = new Var(); + public Var Stats { get; set; } = new Var(); } } @@ -4532,32 +4358,32 @@ namespace Legacy.Models { [Obsolete] - public sealed partial class TrainTestBinaryMacroSubGraphInput + public sealed partial class TrainTestMacroSubGraphInput { /// /// The data to be used for training /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); } [Obsolete] - public sealed partial class TrainTestBinaryMacroSubGraphOutput + public sealed partial class TrainTestMacroSubGraphOutput { /// - /// The model + /// The predictor model /// [Obsolete] - public Var Model { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } /// - /// Train test for binary classification + /// General train test for any supported evaluator /// [Obsolete] - public sealed partial class TrainTestBinaryEvaluator + public sealed partial class TrainTestEvaluator { @@ -4565,121 +4391,19 @@ public sealed partial class TrainTestBinaryEvaluator /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// The data to be used for testing /// [Obsolete] - public Var TestingData { get; set; } = new Var(); + public Var TestingData { get; set; } = new Var(); /// - /// The training subgraph + /// The aggregated transform model from the pipeline before this command, to apply to the test data, and also include in the final model, together with the predictor model. 
/// [Obsolete] - public Experiment Nodes { get; set; } - - /// - /// The training subgraph inputs - /// - [Obsolete] - public TrainTestBinaryMacroSubGraphInput Inputs { get; set; } = new TrainTestBinaryMacroSubGraphInput(); - - /// - /// The training subgraph outputs - /// - [Obsolete] - public TrainTestBinaryMacroSubGraphOutput Outputs { get; set; } = new TrainTestBinaryMacroSubGraphOutput(); - - - [Obsolete] - public sealed class Output - { - /// - /// The trained model - /// - public Var PredictorModel { get; set; } = new Var(); - - /// - /// Warning dataset - /// - public Var Warnings { get; set; } = new Var(); - - /// - /// Overall metrics dataset - /// - public Var OverallMetrics { get; set; } = new Var(); - - /// - /// Per instance metrics dataset - /// - public Var PerInstanceMetrics { get; set; } = new Var(); - - /// - /// Confusion matrix dataset - /// - public Var ConfusionMatrix { get; set; } = new Var(); - - } - } - } - - namespace Legacy.Models - { - - [Obsolete] - public sealed partial class TrainTestMacroSubGraphInput - { - /// - /// The data to be used for training - /// - [Obsolete] - public Var Data { get; set; } = new Var(); - - } - - [Obsolete] - public sealed partial class TrainTestMacroSubGraphOutput - { - /// - /// The predictor model - /// - [Obsolete] - public Var PredictorModel { get; set; } = new Var(); - - /// - /// Transform model - /// - [Obsolete] - public Var TransformModel { get; set; } = new Var(); - - } - - /// - /// General train test for any supported evaluator - /// - [Obsolete] - public sealed partial class TrainTestEvaluator - { - - - /// - /// The data to be used for training - /// - [Obsolete] - public Var TrainingData { get; set; } = new Var(); - - /// - /// The data to be used for testing - /// - [Obsolete] - public Var TestingData { get; set; } = new Var(); - - /// - /// The aggregated transform model from the pipeline before this command, to apply to the test data, and also include in the final model, together with the 
predictor model. - /// - [Obsolete] - public Var TransformModel { get; set; } = new Var(); + public Var TransformModel { get; set; } = new Var(); /// /// The training subgraph @@ -4727,19 +4451,19 @@ public sealed partial class TrainTestEvaluator /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for grouping /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional GroupColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional GroupColumn { get; set; } /// /// Name column name /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional NameColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional NameColumn { get; set; } [Obsolete] @@ -4748,65 +4472,60 @@ public sealed class Output /// /// The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel. /// - public Var PredictorModel { get; set; } = new Var(); - - /// - /// The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel. 
- /// - public Var TransformModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); /// /// Warning dataset /// - public Var Warnings { get; set; } = new Var(); + public Var Warnings { get; set; } = new Var(); /// /// Overall metrics dataset /// - public Var OverallMetrics { get; set; } = new Var(); + public Var OverallMetrics { get; set; } = new Var(); /// /// Per instance metrics dataset /// - public Var PerInstanceMetrics { get; set; } = new Var(); + public Var PerInstanceMetrics { get; set; } = new Var(); /// /// Confusion matrix dataset /// - public Var ConfusionMatrix { get; set; } = new Var(); + public Var ConfusionMatrix { get; set; } = new Var(); /// /// Warning dataset for training /// - public Var TrainingWarnings { get; set; } = new Var(); + public Var TrainingWarnings { get; set; } = new Var(); /// /// Overall metrics dataset for training /// - public Var TrainingOverallMetrics { get; set; } = new Var(); + public Var TrainingOverallMetrics { get; set; } = new Var(); /// /// Per instance metrics dataset for training /// - public Var TrainingPerInstanceMetrics { get; set; } = new Var(); + public Var TrainingPerInstanceMetrics { get; set; } = new Var(); /// /// Confusion matrix dataset for training /// - public Var TrainingConfusionMatrix { get; set; } = new Var(); + public Var TrainingConfusionMatrix { get; set; } = new Var(); } } } - namespace Legacy.TimeSeriesProcessing + namespace Legacy.TimeSeriesProcessingEntryPoints { /// /// Applies a Exponential average on a time series. 
/// [Obsolete] - public sealed partial class ExponentialAverage : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class ExponentialAverage : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -4832,21 +4551,21 @@ public sealed partial class ExponentialAverage : Microsoft.ML.Runtime.EntryPoint /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -4881,12 +4600,12 @@ public ExponentialAveragePipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } - namespace Legacy.TimeSeriesProcessing + namespace Legacy.TimeSeriesProcessingEntryPoints { [Obsolete] public enum SequentialAnomalyDetectionTransformBaseSingleIidAnomalyDetectionBaseStateMartingaleType : byte @@ -4901,7 +4620,7 @@ public enum SequentialAnomalyDetectionTransformBaseSingleIidAnomalyDetectionBase /// This transform detects the change-points in an i.i.d. sequence using adaptive kernel density estimation and martingales. 
/// [Obsolete] - public sealed partial class IidChangePointDetector : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class IidChangePointDetector : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -4945,21 +4664,21 @@ public sealed partial class IidChangePointDetector : Microsoft.ML.Runtime.EntryP /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -4994,12 +4713,12 @@ public IidChangePointDetectorPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } - namespace Legacy.TimeSeriesProcessing + namespace Legacy.TimeSeriesProcessingEntryPoints { [Obsolete] public enum SequentialAnomalyDetectionTransformBaseSingleIidAnomalyDetectionBaseStateAnomalySide : byte @@ -5014,7 +4733,7 @@ public enum SequentialAnomalyDetectionTransformBaseSingleIidAnomalyDetectionBase /// This transform detects the spikes in a i.i.d. sequence using adaptive kernel density estimation. 
/// [Obsolete] - public sealed partial class IidSpikeDetector : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class IidSpikeDetector : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -5052,21 +4771,21 @@ public sealed partial class IidSpikeDetector : Microsoft.ML.Runtime.EntryPoints. /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -5101,19 +4820,19 @@ public IidSpikeDetectorPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } - namespace Legacy.TimeSeriesProcessing + namespace Legacy.TimeSeriesProcessingEntryPoints { /// /// Detects the values of time-series that are in the top percentile of the sliding window. /// [Obsolete] - public sealed partial class PercentileThresholdTransform : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class PercentileThresholdTransform : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -5145,21 +4864,21 @@ public sealed partial class PercentileThresholdTransform : Microsoft.ML.Runtime. 
/// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -5194,19 +4913,19 @@ public PercentileThresholdTransformPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } - namespace Legacy.TimeSeriesProcessing + namespace Legacy.TimeSeriesProcessingEntryPoints { /// /// This P-Value transform calculates the p-value of the current input in the sequence with regard to the values in the sliding window. 
/// [Obsolete] - public sealed partial class PValueTransform : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class PValueTransform : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -5250,21 +4969,21 @@ public sealed partial class PValueTransform : Microsoft.ML.Runtime.EntryPoints.C /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -5299,12 +5018,12 @@ public PValueTransformPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } - namespace Legacy.TimeSeriesProcessing + namespace Legacy.TimeSeriesProcessingEntryPoints { [Obsolete] public enum SlidingWindowTransformBaseSingleBeginOptions : byte @@ -5318,7 +5037,7 @@ public enum SlidingWindowTransformBaseSingleBeginOptions : byte /// Returns the last values for a time series [y(t-d-l+1), y(t-d-l+2), ..., y(t-l-1), y(t-l)] where d is the size of the window, l the lag and y is a Float. 
/// [Obsolete] - public sealed partial class SlidingWindowTransform : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class SlidingWindowTransform : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -5356,21 +5075,21 @@ public sealed partial class SlidingWindowTransform : Microsoft.ML.Runtime.EntryP /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -5405,12 +5124,12 @@ public SlidingWindowTransformPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } - namespace Legacy.TimeSeriesProcessing + namespace Legacy.TimeSeriesProcessingEntryPoints { [Obsolete] public enum ErrorFunctionUtilsErrorFunction : byte @@ -5435,7 +5154,7 @@ public enum SequentialAnomalyDetectionTransformBaseSingleSsaAnomalyDetectionBase /// This transform detects the change-points in a seasonal time-series using Singular Spectrum Analysis (SSA). 
/// [Obsolete] - public sealed partial class SsaChangePointDetector : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class SsaChangePointDetector : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -5497,21 +5216,21 @@ public sealed partial class SsaChangePointDetector : Microsoft.ML.Runtime.EntryP /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -5546,12 +5265,12 @@ public SsaChangePointDetectorPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } - namespace Legacy.TimeSeriesProcessing + namespace Legacy.TimeSeriesProcessingEntryPoints { [Obsolete] public enum SequentialAnomalyDetectionTransformBaseSingleSsaAnomalyDetectionBaseStateAnomalySide : byte @@ -5566,7 +5285,7 @@ public enum SequentialAnomalyDetectionTransformBaseSingleSsaAnomalyDetectionBase /// This transform detects the spikes in a seasonal time-series using Singular Spectrum Analysis (SSA). 
/// [Obsolete] - public sealed partial class SsaSpikeDetector : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class SsaSpikeDetector : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -5622,21 +5341,21 @@ public sealed partial class SsaSpikeDetector : Microsoft.ML.Runtime.EntryPoints. /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -5671,7 +5390,7 @@ public SsaSpikeDetectorPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -5682,7 +5401,7 @@ namespace Legacy.Trainers /// /// [Obsolete] - public sealed partial class AveragedPerceptronBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class AveragedPerceptronBinaryClassifier : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -5806,7 +5525,7 @@ public sealed partial class AveragedPerceptronBinaryClassifier : Microsoft.ML.Ru /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column 
to use for features @@ -5828,12 +5547,12 @@ public sealed partial class AveragedPerceptronBinaryClassifier : Microsoft.ML.Ru [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -5865,7 +5584,7 @@ public AveragedPerceptronBinaryClassifierPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -5877,7 +5596,7 @@ namespace Legacy.Trainers /// Train binary ensemble. /// [Obsolete] - public sealed partial class EnsembleBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class EnsembleBinaryClassifier : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -5936,7 +5655,7 @@ public sealed partial class EnsembleBinaryClassifier : Microsoft.ML.Runtime.Entr /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -5958,12 +5677,12 @@ public sealed partial class EnsembleBinaryClassifier : Microsoft.ML.Runtime.Entr [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IBinaryClassificationOutput, 
Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -5995,7 +5714,7 @@ public EnsembleBinaryClassifierPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -6007,7 +5726,7 @@ namespace Legacy.Trainers /// Train multiclass ensemble. /// [Obsolete] - public sealed partial class EnsembleClassification : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class EnsembleClassification : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -6066,7 +5785,7 @@ public sealed partial class EnsembleClassification : Microsoft.ML.Runtime.EntryP /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -6088,12 +5807,12 @@ public sealed partial class EnsembleClassification : Microsoft.ML.Runtime.EntryP [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IMulticlassClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IMulticlassClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -6125,7 +5844,7 @@ public EnsembleClassificationPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -6137,7 +5856,7 @@ namespace Legacy.Trainers 
/// Train regression ensemble. /// [Obsolete] - public sealed partial class EnsembleRegression : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class EnsembleRegression : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -6196,7 +5915,7 @@ public sealed partial class EnsembleRegression : Microsoft.ML.Runtime.EntryPoint /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -6218,12 +5937,12 @@ public sealed partial class EnsembleRegression : Microsoft.ML.Runtime.EntryPoint [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -6255,7 +5974,7 @@ public EnsembleRegressionPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -6274,7 +5993,7 @@ public enum Bundle : byte /// /// [Obsolete] - public sealed partial class FastForestBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class FastForestBinaryClassifier : 
Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -6533,13 +6252,13 @@ public sealed partial class FastForestBinaryClassifier : Microsoft.ML.Runtime.En /// Column to use for example groupId /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional GroupIdColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional GroupIdColumn { get; set; } /// /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -6551,7 +6270,7 @@ public sealed partial class FastForestBinaryClassifier : Microsoft.ML.Runtime.En /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -6573,12 +6292,12 @@ public sealed partial class FastForestBinaryClassifier : Microsoft.ML.Runtime.En [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -6610,7 +6329,7 @@ public FastForestBinaryClassifierPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -6621,7 +6340,7 @@ namespace Legacy.Trainers /// /// [Obsolete] - public sealed partial class FastForestRegressor : 
Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class FastForestRegressor : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -6867,13 +6586,13 @@ public sealed partial class FastForestRegressor : Microsoft.ML.Runtime.EntryPoin /// Column to use for example groupId /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional GroupIdColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional GroupIdColumn { get; set; } /// /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -6885,7 +6604,7 @@ public sealed partial class FastForestRegressor : Microsoft.ML.Runtime.EntryPoin /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -6907,12 +6626,12 @@ public sealed partial class FastForestRegressor : Microsoft.ML.Runtime.EntryPoin [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); 
+ public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -6944,7 +6663,7 @@ public FastForestRegressorPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -6963,7 +6682,7 @@ public enum BoostedTreeArgsOptimizationAlgorithmType /// /// [Obsolete] - public sealed partial class FastTreeBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class FastTreeBinaryClassifier : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -7339,13 +7058,13 @@ public sealed partial class FastTreeBinaryClassifier : Microsoft.ML.Runtime.Entr /// Column to use for example groupId /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional GroupIdColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional GroupIdColumn { get; set; } /// /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -7357,7 +7076,7 @@ public sealed partial class FastTreeBinaryClassifier : Microsoft.ML.Runtime.Entr /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -7379,12 +7098,12 @@ public sealed partial class FastTreeBinaryClassifier : Microsoft.ML.Runtime.Entr [Obsolete] - public 
sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -7416,7 +7135,7 @@ public FastTreeBinaryClassifierPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -7427,7 +7146,7 @@ namespace Legacy.Trainers /// /// [Obsolete] - public sealed partial class FastTreeRanker : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class FastTreeRanker : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -7845,13 +7564,13 @@ public sealed partial class FastTreeRanker : Microsoft.ML.Runtime.EntryPoints.Co /// Column to use for example groupId /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional GroupIdColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional GroupIdColumn { get; set; } /// /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -7863,7 +7582,7 @@ public sealed partial class FastTreeRanker : 
Microsoft.ML.Runtime.EntryPoints.Co /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -7885,12 +7604,12 @@ public sealed partial class FastTreeRanker : Microsoft.ML.Runtime.EntryPoints.Co [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IRankingOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IRankingOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -7922,7 +7641,7 @@ public FastTreeRankerPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -7933,7 +7652,7 @@ namespace Legacy.Trainers /// /// [Obsolete] - public sealed partial class FastTreeRegressor : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class FastTreeRegressor : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -8303,13 +8022,13 @@ public sealed partial class FastTreeRegressor : Microsoft.ML.Runtime.EntryPoints /// Column to use for example groupId /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional GroupIdColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional 
GroupIdColumn { get; set; } /// /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -8321,7 +8040,7 @@ public sealed partial class FastTreeRegressor : Microsoft.ML.Runtime.EntryPoints /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -8343,12 +8062,12 @@ public sealed partial class FastTreeRegressor : Microsoft.ML.Runtime.EntryPoints [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -8380,7 +8099,7 @@ public FastTreeRegressorPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -8390,7 +8109,7 @@ namespace Legacy.Trainers /// [Obsolete] - public sealed partial class FastTreeTweedieRegressor : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class FastTreeTweedieRegressor : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, 
Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -8766,13 +8485,13 @@ public sealed partial class FastTreeTweedieRegressor : Microsoft.ML.Runtime.Entr /// Column to use for example groupId /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional GroupIdColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional GroupIdColumn { get; set; } /// /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -8784,7 +8503,7 @@ public sealed partial class FastTreeTweedieRegressor : Microsoft.ML.Runtime.Entr /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -8806,12 +8525,12 @@ public sealed partial class FastTreeTweedieRegressor : Microsoft.ML.Runtime.Entr [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -8843,7 +8562,7 @@ public FastTreeTweedieRegressorPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -8854,7 +8573,7 @@ namespace Legacy.Trainers /// /// [Obsolete] - public sealed partial class FieldAwareFactorizationMachineBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public 
sealed partial class FieldAwareFactorizationMachineBinaryClassifier : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -8928,7 +8647,7 @@ public sealed partial class FieldAwareFactorizationMachineBinaryClassifier : Mic /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -8950,12 +8669,12 @@ public sealed partial class FieldAwareFactorizationMachineBinaryClassifier : Mic [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -8987,7 +8706,7 @@ public FieldAwareFactorizationMachineBinaryClassifierPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -8999,7 +8718,7 @@ namespace Legacy.Trainers /// Trains a gradient boosted stump per feature, on all features simultaneously, to fit target values using least-squares. It mantains no interactions between features. 
/// [Obsolete] - public sealed partial class GeneralizedAdditiveModelBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class GeneralizedAdditiveModelBinaryClassifier : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -9094,7 +8813,7 @@ public sealed partial class GeneralizedAdditiveModelBinaryClassifier : Microsoft /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -9106,7 +8825,7 @@ public sealed partial class GeneralizedAdditiveModelBinaryClassifier : Microsoft /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -9128,12 +8847,12 @@ public sealed partial class GeneralizedAdditiveModelBinaryClassifier : Microsoft [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -9165,7 +8884,7 @@ public GeneralizedAdditiveModelBinaryClassifierPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public 
Var Model { get; } } } } @@ -9177,7 +8896,7 @@ namespace Legacy.Trainers /// Trains a gradient boosted stump per feature, on all features simultaneously, to fit target values using least-squares. It mantains no interactions between features. /// [Obsolete] - public sealed partial class GeneralizedAdditiveModelRegressor : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class GeneralizedAdditiveModelRegressor : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -9272,7 +8991,7 @@ public sealed partial class GeneralizedAdditiveModelRegressor : Microsoft.ML.Run /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -9284,7 +9003,7 @@ public sealed partial class GeneralizedAdditiveModelRegressor : Microsoft.ML.Run /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -9306,12 +9025,12 @@ public sealed partial class GeneralizedAdditiveModelRegressor : Microsoft.ML.Run [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + 
public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -9343,7 +9062,7 @@ public GeneralizedAdditiveModelRegressorPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -9362,7 +9081,7 @@ public enum KMeansPlusPlusTrainerInitAlgorithm /// /// [Obsolete] - public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IUnsupervisedTrainerWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.EntryPoints.CommonInputs.IUnsupervisedTrainerWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -9380,7 +9099,7 @@ public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.Entry public KMeansPlusPlusTrainerInitAlgorithm InitAlgorithm { get; set; } = KMeansPlusPlusTrainerInitAlgorithm.KMeansParallel; /// - /// Tolerance parameter for trainer convergence. Lower = slower, more accurate + /// Tolerance parameter for trainer convergence. 
Low = slower, more accurate /// [Obsolete] public float OptTol { get; set; } = 1E-07f; @@ -9407,13 +9126,13 @@ public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.Entry /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -9435,12 +9154,12 @@ public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.Entry [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IClusteringOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IClusteringOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -9472,7 +9191,7 @@ public KMeansPlusPlusClustererPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -9498,7 +9217,7 @@ public enum LightGbmArgumentsEvalMetricType /// /// [Obsolete] - public sealed partial class LightGbmBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class LightGbmBinaryClassifier : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, 
Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -9655,13 +9374,13 @@ public sealed partial class LightGbmBinaryClassifier : Microsoft.ML.Runtime.Entr /// Column to use for example groupId /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional GroupIdColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional GroupIdColumn { get; set; } /// /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -9673,7 +9392,7 @@ public sealed partial class LightGbmBinaryClassifier : Microsoft.ML.Runtime.Entr /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -9695,12 +9414,12 @@ public sealed partial class LightGbmBinaryClassifier : Microsoft.ML.Runtime.Entr [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -9732,7 +9451,7 @@ public LightGbmBinaryClassifierPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -9743,7 +9462,7 @@ namespace Legacy.Trainers /// /// [Obsolete] - public sealed partial class LightGbmClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithGroupId, 
Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class LightGbmClassifier : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -9900,13 +9619,13 @@ public sealed partial class LightGbmClassifier : Microsoft.ML.Runtime.EntryPoint /// Column to use for example groupId /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional GroupIdColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional GroupIdColumn { get; set; } /// /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -9918,7 +9637,7 @@ public sealed partial class LightGbmClassifier : Microsoft.ML.Runtime.EntryPoint /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -9940,12 +9659,12 @@ public sealed partial class LightGbmClassifier : Microsoft.ML.Runtime.EntryPoint [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IMulticlassClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IMulticlassClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = 
new Var(); } [Obsolete] @@ -9977,7 +9696,7 @@ public LightGbmClassifierPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -9988,7 +9707,7 @@ namespace Legacy.Trainers /// /// [Obsolete] - public sealed partial class LightGbmRanker : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class LightGbmRanker : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -10145,13 +9864,13 @@ public sealed partial class LightGbmRanker : Microsoft.ML.Runtime.EntryPoints.Co /// Column to use for example groupId /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional GroupIdColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional GroupIdColumn { get; set; } /// /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -10163,7 +9882,7 @@ public sealed partial class LightGbmRanker : Microsoft.ML.Runtime.EntryPoints.Co /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -10185,12 +9904,12 @@ public sealed partial class LightGbmRanker : Microsoft.ML.Runtime.EntryPoints.Co [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IRankingOutput, 
Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IRankingOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -10222,7 +9941,7 @@ public LightGbmRankerPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -10233,7 +9952,7 @@ namespace Legacy.Trainers /// /// [Obsolete] - public sealed partial class LightGbmRegressor : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class LightGbmRegressor : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithGroupId, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -10390,13 +10109,13 @@ public sealed partial class LightGbmRegressor : Microsoft.ML.Runtime.EntryPoints /// Column to use for example groupId /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional GroupIdColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional GroupIdColumn { get; set; } /// /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -10408,7 +10127,7 @@ public sealed partial class LightGbmRegressor : Microsoft.ML.Runtime.EntryPoints /// The data to be used for training /// [Obsolete] - public Var 
TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -10430,12 +10149,12 @@ public sealed partial class LightGbmRegressor : Microsoft.ML.Runtime.EntryPoints [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -10467,7 +10186,7 @@ public LightGbmRegressorPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -10479,7 +10198,7 @@ namespace Legacy.Trainers /// Train a linear SVM. /// [Obsolete] - public sealed partial class LinearSvmBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class LinearSvmBinaryClassifier : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -10566,7 +10285,7 @@ public sealed partial class LinearSvmBinaryClassifier : Microsoft.ML.Runtime.Ent /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -10588,12 +10307,12 @@ public sealed partial class LinearSvmBinaryClassifier : Microsoft.ML.Runtime.Ent [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class 
Output : Microsoft.ML.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -10625,7 +10344,7 @@ public LinearSvmBinaryClassifierPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -10636,7 +10355,7 @@ namespace Legacy.Trainers /// /// [Obsolete] - public sealed partial class LogisticRegressionBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class LogisticRegressionBinaryClassifier : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -10661,14 +10380,14 @@ public sealed partial class LogisticRegressionBinaryClassifier : Microsoft.ML.Ru public float L1Weight { get; set; } = 1f; /// - /// Tolerance parameter for optimization convergence. Lower = slower, more accurate + /// Tolerance parameter for optimization convergence. Low = slower, more accurate /// [TlcModule.SweepableDiscreteParamAttribute("OptTol", new object[]{0.0001f, 1E-07f})] [Obsolete] public float OptTol { get; set; } = 1E-07f; /// - /// Memory size for L-BFGS. Lower=faster, less accurate + /// Memory size for L-BFGS. 
Low=faster, less accurate /// [TlcModule.SweepableDiscreteParamAttribute("MemorySize", new object[]{5, 20, 50})] [Obsolete] @@ -10729,7 +10448,7 @@ public sealed partial class LogisticRegressionBinaryClassifier : Microsoft.ML.Ru /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -10741,7 +10460,7 @@ public sealed partial class LogisticRegressionBinaryClassifier : Microsoft.ML.Ru /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -10763,12 +10482,12 @@ public sealed partial class LogisticRegressionBinaryClassifier : Microsoft.ML.Ru [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -10800,7 +10519,7 @@ public LogisticRegressionBinaryClassifierPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -10811,7 +10530,7 @@ namespace Legacy.Trainers /// /// [Obsolete] - public sealed partial class LogisticRegressionClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class LogisticRegressionClassifier : 
Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -10836,14 +10555,14 @@ public sealed partial class LogisticRegressionClassifier : Microsoft.ML.Runtime. public float L1Weight { get; set; } = 1f; /// - /// Tolerance parameter for optimization convergence. Lower = slower, more accurate + /// Tolerance parameter for optimization convergence. Low = slower, more accurate /// [TlcModule.SweepableDiscreteParamAttribute("OptTol", new object[]{0.0001f, 1E-07f})] [Obsolete] public float OptTol { get; set; } = 1E-07f; /// - /// Memory size for L-BFGS. Lower=faster, less accurate + /// Memory size for L-BFGS. Low=faster, less accurate /// [TlcModule.SweepableDiscreteParamAttribute("MemorySize", new object[]{5, 20, 50})] [Obsolete] @@ -10904,7 +10623,7 @@ public sealed partial class LogisticRegressionClassifier : Microsoft.ML.Runtime. /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -10916,7 +10635,7 @@ public sealed partial class LogisticRegressionClassifier : Microsoft.ML.Runtime. /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -10938,12 +10657,12 @@ public sealed partial class LogisticRegressionClassifier : Microsoft.ML.Runtime. 
[Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IMulticlassClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IMulticlassClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -10975,7 +10694,7 @@ public LogisticRegressionClassifierPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -10986,7 +10705,7 @@ namespace Legacy.Trainers /// /// [Obsolete] - public sealed partial class NaiveBayesClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class NaiveBayesClassifier : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -11000,7 +10719,7 @@ public sealed partial class NaiveBayesClassifier : Microsoft.ML.Runtime.EntryPoi /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -11022,12 +10741,12 @@ public sealed partial class NaiveBayesClassifier : Microsoft.ML.Runtime.EntryPoi [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IMulticlassClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IMulticlassClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + 
public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -11059,7 +10778,7 @@ public NaiveBayesClassifierPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -11070,7 +10789,7 @@ namespace Legacy.Trainers /// /// [Obsolete] - public sealed partial class OnlineGradientDescentRegressor : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class OnlineGradientDescentRegressor : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -11181,7 +10900,7 @@ public sealed partial class OnlineGradientDescentRegressor : Microsoft.ML.Runtim /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -11203,12 +10922,12 @@ public sealed partial class OnlineGradientDescentRegressor : Microsoft.ML.Runtim [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -11240,7 +10959,7 @@ public OnlineGradientDescentRegressorPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -11250,7 +10969,7 @@ namespace Legacy.Trainers /// [Obsolete] - public sealed partial class OrdinaryLeastSquaresRegressor : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, 
Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class OrdinaryLeastSquaresRegressor : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -11271,7 +10990,7 @@ public sealed partial class OrdinaryLeastSquaresRegressor : Microsoft.ML.Runtime /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -11283,7 +11002,7 @@ public sealed partial class OrdinaryLeastSquaresRegressor : Microsoft.ML.Runtime /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -11305,12 +11024,12 @@ public sealed partial class OrdinaryLeastSquaresRegressor : Microsoft.ML.Runtime [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -11342,7 +11061,7 @@ public OrdinaryLeastSquaresRegressorPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -11353,7 +11072,7 @@ namespace Legacy.Trainers /// /// [Obsolete] - public sealed partial class PcaAnomalyDetector : 
Microsoft.ML.Runtime.EntryPoints.CommonInputs.IUnsupervisedTrainerWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class PcaAnomalyDetector : Microsoft.ML.EntryPoints.CommonInputs.IUnsupervisedTrainerWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -11388,13 +11107,13 @@ public sealed partial class PcaAnomalyDetector : Microsoft.ML.Runtime.EntryPoint /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -11416,12 +11135,12 @@ public sealed partial class PcaAnomalyDetector : Microsoft.ML.Runtime.EntryPoint [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IAnomalyDetectionOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IAnomalyDetectionOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -11453,7 +11172,7 @@ public PcaAnomalyDetectorPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -11464,7 +11183,7 @@ namespace Legacy.Trainers /// /// [Obsolete] - public sealed partial class PoissonRegressor : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, 
Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class PoissonRegressor : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -11483,14 +11202,14 @@ public sealed partial class PoissonRegressor : Microsoft.ML.Runtime.EntryPoints. public float L1Weight { get; set; } = 1f; /// - /// Tolerance parameter for optimization convergence. Lower = slower, more accurate + /// Tolerance parameter for optimization convergence. Low = slower, more accurate /// [TlcModule.SweepableDiscreteParamAttribute("OptTol", new object[]{0.0001f, 1E-07f})] [Obsolete] public float OptTol { get; set; } = 1E-07f; /// - /// Memory size for L-BFGS. Lower=faster, less accurate + /// Memory size for L-BFGS. Low=faster, less accurate /// [TlcModule.SweepableDiscreteParamAttribute("MemorySize", new object[]{5, 20, 50})] [Obsolete] @@ -11551,7 +11270,7 @@ public sealed partial class PoissonRegressor : Microsoft.ML.Runtime.EntryPoints. /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -11563,7 +11282,7 @@ public sealed partial class PoissonRegressor : Microsoft.ML.Runtime.EntryPoints. /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -11585,12 +11304,12 @@ public sealed partial class PoissonRegressor : Microsoft.ML.Runtime.EntryPoints. 
[Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -11622,7 +11341,7 @@ public PoissonRegressorPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -11633,7 +11352,7 @@ namespace Legacy.Trainers /// /// [Obsolete] - public sealed partial class StochasticDualCoordinateAscentBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class StochasticDualCoordinateAscentBinaryClassifier : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -11727,7 +11446,7 @@ public sealed partial class StochasticDualCoordinateAscentBinaryClassifier : Mic /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -11749,12 +11468,12 @@ public sealed partial class StochasticDualCoordinateAscentBinaryClassifier : Mic [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new 
Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -11786,7 +11505,7 @@ public StochasticDualCoordinateAscentBinaryClassifierPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -11797,7 +11516,7 @@ namespace Legacy.Trainers /// /// [Obsolete] - public sealed partial class StochasticDualCoordinateAscentClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class StochasticDualCoordinateAscentClassifier : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -11872,7 +11591,7 @@ public sealed partial class StochasticDualCoordinateAscentClassifier : Microsoft /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -11894,12 +11613,12 @@ public sealed partial class StochasticDualCoordinateAscentClassifier : Microsoft [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IMulticlassClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IMulticlassClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -11931,7 +11650,7 @@ public StochasticDualCoordinateAscentClassifierPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -11942,7 +11661,7 @@ namespace Legacy.Trainers /// /// [Obsolete] - public sealed partial class 
StochasticDualCoordinateAscentRegressor : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class StochasticDualCoordinateAscentRegressor : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -12017,7 +11736,7 @@ public sealed partial class StochasticDualCoordinateAscentRegressor : Microsoft. /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -12039,12 +11758,12 @@ public sealed partial class StochasticDualCoordinateAscentRegressor : Microsoft. [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -12076,7 +11795,7 @@ public StochasticDualCoordinateAscentRegressorPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -12088,7 +11807,7 @@ namespace Legacy.Trainers /// Train an Hogwild SGD binary model. 
/// [Obsolete] - public sealed partial class StochasticGradientDescentBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class StochasticGradientDescentBinaryClassifier : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -12168,7 +11887,7 @@ public sealed partial class StochasticGradientDescentBinaryClassifier : Microsof /// Column to use for example weight /// [Obsolete] - public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + public Microsoft.ML.EntryPoints.Optional WeightColumn { get; set; } /// /// Column to use for labels @@ -12180,7 +11899,7 @@ public sealed partial class StochasticGradientDescentBinaryClassifier : Microsof /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -12202,12 +11921,12 @@ public sealed partial class StochasticGradientDescentBinaryClassifier : Microsof [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -12239,7 +11958,7 @@ public StochasticGradientDescentBinaryClassifierPipelineStep(Output output) } [Obsolete] - public Var Model { get; 
} + public Var Model { get; } } } } @@ -12249,7 +11968,7 @@ namespace Legacy.Trainers /// [Obsolete] - public sealed partial class SymSgdBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class SymSgdBinaryClassifier : Microsoft.ML.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -12321,7 +12040,7 @@ public sealed partial class SymSgdBinaryClassifier : Microsoft.ML.Runtime.EntryP /// The data to be used for training /// [Obsolete] - public Var TrainingData { get; set; } = new Var(); + public Var TrainingData { get; set; } = new Var(); /// /// Column to use for features @@ -12343,12 +12062,12 @@ public sealed partial class SymSgdBinaryClassifier : Microsoft.ML.Runtime.EntryP [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model /// - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); } [Obsolete] @@ -12380,7 +12099,7 @@ public SymSgdBinaryClassifierPipelineStep(Output output) } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -12392,7 +12111,7 @@ namespace Legacy.Transforms /// Approximate bootstrap sampling. 
/// [Obsolete] - public sealed partial class ApproximateBootstrapSampler : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class ApproximateBootstrapSampler : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -12424,21 +12143,21 @@ public sealed partial class ApproximateBootstrapSampler : Microsoft.ML.Runtime.E /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -12473,7 +12192,7 @@ public ApproximateBootstrapSamplerPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -12485,7 +12204,7 @@ namespace Legacy.Transforms /// For binary prediction, it renames the PredictedLabel and Score columns to include the name of the positive class. 
/// [Obsolete] - public sealed partial class BinaryPredictionScoreColumnsRenamer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class BinaryPredictionScoreColumnsRenamer : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -12493,27 +12212,27 @@ public sealed partial class BinaryPredictionScoreColumnsRenamer : Microsoft.ML.R /// The predictor model used in scoring /// [Obsolete] - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); /// /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -12548,7 +12267,7 @@ public BinaryPredictionScoreColumnsRenamerPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -12595,7 +12314,7 @@ public sealed partial class NormalizeTransformBinColumn : OneToOneColumn [Obsolete] - public sealed partial class BinNormalizer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class BinNormalizer : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { public BinNormalizer() @@ -12667,21 +12386,21 @@ public void AddColumn(string outputColumn, string inputColumn) /// Input dataset /// [Obsolete] - public Var Data { get; set; } 
= new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -12716,7 +12435,7 @@ public BinNormalizerPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -12734,7 +12453,7 @@ public enum OneHotEncodingTransformerOutputKind : byte [Obsolete] - public sealed partial class OneHotHashEncodingTransformerColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class OneHotHashEncodingColumn : OneToOneColumn, IOneToOneColumn { /// /// The number of bits to hash into. Must be between 1 and 30, inclusive. @@ -12783,7 +12502,7 @@ public sealed partial class OneHotHashEncodingTransformerColumn : OneToOneColumn /// /// [Obsolete] - public sealed partial class CategoricalHashOneHotVectorizer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class CategoricalHashOneHotVectorizer : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { public CategoricalHashOneHotVectorizer() @@ -12814,15 +12533,15 @@ public CategoricalHashOneHotVectorizer(params (string inputColumn, string output public void AddColumn(string inputColumn) { - var list = Column == null ? new List() : new List(Column); - list.Add(OneToOneColumn.Create(inputColumn)); + var list = Column == null ? 
new List() : new List(Column); + list.Add(OneToOneColumn.Create(inputColumn)); Column = list.ToArray(); } public void AddColumn(string outputColumn, string inputColumn) { - var list = Column == null ? new List() : new List(Column); - list.Add(OneToOneColumn.Create(outputColumn, inputColumn)); + var list = Column == null ? new List() : new List(Column); + list.Add(OneToOneColumn.Create(outputColumn, inputColumn)); Column = list.ToArray(); } @@ -12831,7 +12550,7 @@ public void AddColumn(string outputColumn, string inputColumn) /// New column definition(s) (optional form: name:hashBits:src) /// [Obsolete] - public OneHotHashEncodingTransformerColumn[] Column { get; set; } + public OneHotHashEncodingColumn[] Column { get; set; } /// /// Number of bits to hash into. Must be between 1 and 30, inclusive. @@ -12867,21 +12586,21 @@ public void AddColumn(string outputColumn, string inputColumn) /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -12916,7 +12635,7 @@ public CategoricalHashOneHotVectorizerPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -12981,7 +12700,7 @@ public sealed partial class OneHotEncodingTransformerColumn : OneToOneColumn /// [Obsolete] - public sealed partial class CategoricalOneHotVectorizer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class 
CategoricalOneHotVectorizer : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { public CategoricalOneHotVectorizer() @@ -13065,21 +12784,21 @@ public void AddColumn(string outputColumn, string inputColumn) /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -13114,7 +12833,7 @@ public CategoricalOneHotVectorizerPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -13141,7 +12860,7 @@ public sealed partial class TokenizingByCharactersTransformerColumn : OneToOneCo /// [Obsolete] - public sealed partial class CharacterTokenizer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class CharacterTokenizer : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { public CharacterTokenizer() @@ -13201,21 +12920,21 @@ public void AddColumn(string outputColumn, string inputColumn) /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var 
OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -13250,7 +12969,7 @@ public CharacterTokenizerPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -13279,7 +12998,7 @@ public sealed partial class ColumnConcatenatingTransformerColumn : ManyToOneColu /// Concatenates one or more columns of the same item type. /// [Obsolete] - public sealed partial class ColumnConcatenator : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class ColumnConcatenator : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { public ColumnConcatenator() @@ -13309,21 +13028,21 @@ public void AddColumn(string name, params string[] source) /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -13358,7 +13077,7 @@ public ColumnConcatenatorPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -13387,7 +13106,7 @@ public sealed partial class ColumnCopyingTransformerColumn : OneToOneColumn [Obsolete] - public sealed partial class ColumnCopier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public 
sealed partial class ColumnCopier : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { public ColumnCopier() @@ -13441,21 +13160,21 @@ public void AddColumn(string outputColumn, string inputColumn) /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -13490,7 +13209,7 @@ public ColumnCopierPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -13502,7 +13221,7 @@ namespace Legacy.Transforms /// Selects a set of columns, dropping all others /// [Obsolete] - public sealed partial class ColumnSelector : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class ColumnSelector : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -13534,21 +13253,21 @@ public sealed partial class ColumnSelector : Microsoft.ML.Runtime.EntryPoints.Co /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// 
Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -13583,7 +13302,7 @@ public ColumnSelectorPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -13624,7 +13343,7 @@ public sealed partial class TypeConvertingTransformerColumn : OneToOneColumn [Obsolete] - public sealed partial class ColumnTypeConverter : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class ColumnTypeConverter : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { public ColumnTypeConverter() @@ -13690,21 +13409,21 @@ public void AddColumn(string outputColumn, string inputColumn) /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -13739,7 +13458,7 @@ public ColumnTypeConverterPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -13749,7 +13468,7 @@ namespace Legacy.Transforms /// [Obsolete] - public sealed partial class CombinerByContiguousGroupId : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class CombinerByContiguousGroupId : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, 
Microsoft.ML.Legacy.ILearningPipelineItem { @@ -13769,21 +13488,21 @@ public sealed partial class CombinerByContiguousGroupId : Microsoft.ML.Runtime.E /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -13818,7 +13537,7 @@ public CombinerByContiguousGroupIdPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -13859,7 +13578,7 @@ public sealed partial class NormalizeTransformAffineColumn : OneToOneColumn [Obsolete] - public sealed partial class ConditionalNormalizer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class ConditionalNormalizer : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { public ConditionalNormalizer() @@ -13925,7 +13644,7 @@ public void AddColumn(string outputColumn, string inputColumn) /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] @@ -13934,12 +13653,12 @@ public sealed class Output /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -13974,7 +13693,7 @@ public ConditionalNormalizerPipelineStep(Output 
output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -13993,7 +13712,7 @@ public enum CacheCachingType /// Caches using the specified cache option. /// [Obsolete] - public sealed partial class DataCache : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class DataCache : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -14007,7 +13726,7 @@ public sealed partial class DataCache : Microsoft.ML.Runtime.EntryPoints.CommonI /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] @@ -14016,7 +13735,7 @@ public sealed class Output /// /// Dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); } [Obsolete] @@ -14050,7 +13769,7 @@ public DataCachePipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -14070,13 +13789,13 @@ public sealed partial class DatasetScorer /// The dataset to be scored /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); /// /// The predictor model to apply to data /// [Obsolete] - public Var PredictorModel { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); /// /// Suffix to append to the score columns @@ -14091,12 +13810,12 @@ public sealed class Output /// /// The scored dataset /// - public Var ScoredData { get; set; } = new Var(); + public Var ScoredData { get; set; } = new Var(); /// /// The scoring transform /// - public Var ScoringTransform { get; set; } = new Var(); + public Var ScoringTransform { get; set; } = new Var(); } } @@ -14117,13 +13836,13 @@ public sealed partial class DatasetTransformScorer /// The dataset to be scored /// [Obsolete] - 
public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); /// /// The transform model to apply to data /// [Obsolete] - public Var TransformModel { get; set; } = new Var(); + public Var TransformModel { get; set; } = new Var(); [Obsolete] @@ -14132,12 +13851,12 @@ public sealed class Output /// /// The scored dataset /// - public Var ScoredData { get; set; } = new Var(); + public Var ScoredData { get; set; } = new Var(); /// /// The scoring transform /// - public Var ScoringTransform { get; set; } = new Var(); + public Var ScoringTransform { get; set; } = new Var(); } } @@ -14191,7 +13910,7 @@ public sealed partial class ValueToKeyMappingTransformerColumn : OneToOneColumn< /// Converts input values (words, numbers, etc.) to index in a dictionary. /// [Obsolete] - public sealed partial class Dictionarizer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class Dictionarizer : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { public Dictionarizer() @@ -14269,21 +13988,21 @@ public void AddColumn(string outputColumn, string inputColumn) /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -14318,7 +14037,7 @@ public DictionarizerPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -14330,7 +14049,7 @@ 
namespace Legacy.Transforms /// Combines all the features into one feature column. /// [Obsolete] - public sealed partial class FeatureCombiner : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class FeatureCombiner : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -14344,21 +14063,21 @@ public sealed partial class FeatureCombiner : Microsoft.ML.Runtime.EntryPoints.C /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -14393,7 +14112,106 @@ public FeatureCombinerPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } + } + } + } + + namespace Legacy.Transforms + { + + /// + /// For each data point, calculates the contribution of individual features to the model prediction. 
+ /// + [Obsolete] + public sealed partial class FeatureContributionCalculationTransformer : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + { + + + /// + /// The predictor model to apply to data + /// + [Obsolete] + public Var PredictorModel { get; set; } = new Var(); + + /// + /// Name of feature column + /// + [Obsolete] + public string FeatureColumn { get; set; } = "Features"; + + /// + /// Number of top contributions + /// + [Obsolete] + public int Top { get; set; } = 10; + + /// + /// Number of bottom contributions + /// + [Obsolete] + public int Bottom { get; set; } = 10; + + /// + /// Whether or not output of Features contribution should be normalized + /// + [Obsolete] + public bool Normalize { get; set; } = true; + + /// + /// Input dataset + /// + [Obsolete] + public Var Data { get; set; } = new Var(); + + + [Obsolete] + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput + { + /// + /// Transformed dataset + /// + public Var OutputData { get; set; } = new Var(); + + /// + /// Transform model + /// + public Var Model { get; set; } = new Var(); + + } + [Obsolete] + public Var GetInputData() => Data; + + [Obsolete] + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) + { + if (previousStep != null) + { + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(FeatureContributionCalculationTransformer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } + + Data = dataStep.Data; + } + Output output = experiment.Add(this); + return new FeatureContributionCalculationTransformerPipelineStep(output); + } + + [Obsolete] + private class FeatureContributionCalculationTransformerPipelineStep : ILearningPipelineDataStep + { + [Obsolete] + public FeatureContributionCalculationTransformerPipelineStep(Output output) + { + Data = output.OutputData; + Model = 
output.Model; + } + + [Obsolete] + public Var Data { get; } + [Obsolete] + public Var Model { get; } } } } @@ -14404,7 +14222,7 @@ namespace Legacy.Transforms /// /// [Obsolete] - public sealed partial class FeatureSelectorByCount : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class FeatureSelectorByCount : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -14424,21 +14242,21 @@ public sealed partial class FeatureSelectorByCount : Microsoft.ML.Runtime.EntryP /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -14473,7 +14291,7 @@ public FeatureSelectorByCountPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -14484,7 +14302,7 @@ namespace Legacy.Transforms /// /// [Obsolete] - public sealed partial class FeatureSelectorByMutualInformation : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class FeatureSelectorByMutualInformation : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -14516,21 +14334,21 @@ public sealed partial class FeatureSelectorByMutualInformation : Microsoft.ML.Ru /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } 
= new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -14565,7 +14383,7 @@ public FeatureSelectorByMutualInformationPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -14610,7 +14428,7 @@ public sealed partial class LpNormalizingTransformerGcnColumn : OneToOneColumn [Obsolete] - public sealed partial class GlobalContrastNormalizer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class GlobalContrastNormalizer : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { public GlobalContrastNormalizer() @@ -14682,21 +14500,21 @@ public void AddColumn(string outputColumn, string inputColumn) /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -14731,7 +14549,7 @@ public GlobalContrastNormalizerPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { 
get; } } } } @@ -14789,7 +14607,7 @@ public sealed partial class HashJoiningTransformColumn : OneToOneColumn /// [Obsolete] - public sealed partial class HashConverter : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class HashConverter : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { public HashConverter() @@ -14867,21 +14685,21 @@ public void AddColumn(string outputColumn, string inputColumn) /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -14916,7 +14734,7 @@ public HashConverterPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -14945,7 +14763,7 @@ public sealed partial class ImageGrayscaleTransformColumn : OneToOneColumn [Obsolete] - public sealed partial class ImageGrayscale : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class ImageGrayscale : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { public ImageGrayscale() @@ -14999,21 +14817,21 @@ public void AddColumn(string outputColumn, string inputColumn) /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : 
Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -15048,7 +14866,7 @@ public ImageGrayscalePipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -15077,7 +14895,7 @@ public sealed partial class ImageLoaderTransformColumn : OneToOneColumn [Obsolete] - public sealed partial class ImageLoader : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class ImageLoader : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { public ImageLoader() @@ -15137,21 +14955,21 @@ public void AddColumn(string outputColumn, string inputColumn) /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -15186,7 +15004,7 @@ public ImageLoaderPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -15263,7 +15081,7 @@ public sealed partial class ImagePixelExtractorTransformColumn : OneToOneColumn< /// Extract color 
plane(s) from an image. Options include scaling, offset and conversion to floating point. /// [Obsolete] - public sealed partial class ImagePixelExtractor : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class ImagePixelExtractor : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { public ImagePixelExtractor() @@ -15365,21 +15183,21 @@ public void AddColumn(string outputColumn, string inputColumn) /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -15414,7 +15232,7 @@ public ImagePixelExtractorPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -15484,7 +15302,7 @@ public sealed partial class ImageResizerTransformColumn : OneToOneColumn [Obsolete] - public sealed partial class ImageResizer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class ImageResizer : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { public ImageResizer() @@ -15562,21 +15380,21 @@ public void AddColumn(string outputColumn, string inputColumn) /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : 
Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -15611,7 +15429,7 @@ public ImageResizerPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -15638,7 +15456,7 @@ public sealed partial class KeyToValueMappingTransformerColumn : OneToOneColumn< /// [Obsolete] - public sealed partial class KeyToTextConverter : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class KeyToTextConverter : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { public KeyToTextConverter() @@ -15692,21 +15510,21 @@ public void AddColumn(string outputColumn, string inputColumn) /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -15741,7 +15559,7 @@ public KeyToTextConverterPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -15753,7 +15571,7 @@ namespace Legacy.Transforms /// Transforms the label to 
either key or bool (if needed) to make it suitable for classification. /// [Obsolete] - public sealed partial class LabelColumnKeyBooleanConverter : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem + public sealed partial class LabelColumnKeyBooleanConverter : Microsoft.ML.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.Legacy.ILearningPipelineItem { @@ -15773,21 +15591,21 @@ public sealed partial class LabelColumnKeyBooleanConverter : Microsoft.ML.Runtim /// Input dataset /// [Obsolete] - public Var Data { get; set; } = new Var(); + public Var Data { get; set; } = new Var(); [Obsolete] - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + public sealed class Output : Microsoft.ML.EntryPoints.CommonOutputs.ITransformOutput { /// /// Transformed dataset /// - public Var OutputData { get; set; } = new Var(); + public Var OutputData { get; set; } = new Var(); /// /// Transform model /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); } [Obsolete] @@ -15822,7 +15640,7 @@ public LabelColumnKeyBooleanConverterPipelineStep(Output output) [Obsolete] public Var Data { get; } [Obsolete] - public Var Model { get; } + public Var Model { get; } } } } @@ -15857,7 +15675,7 @@ public sealed partial class LabelIndicatorTransformColumn : OneToOneColumn