Skip to content

Commit

Permalink
feat: Implement basic cross_validate functionality (#443)
Browse files Browse the repository at this point in the history
  • Loading branch information
augustebaum authored Oct 21, 2024
1 parent ef1bfec commit 904b534
Show file tree
Hide file tree
Showing 16 changed files with 882 additions and 152 deletions.
26 changes: 26 additions & 0 deletions examples/01_basic_usage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,32 @@
"project.put(\"my_fitted_pipeline\", my_pipeline)"
]
},
{
"cell_type": "markdown",
"id": "59aaa",
"metadata": {},
"source": [
"---\n",
"## Cross-validation with skore"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "58aaaa",
"metadata": {},
"outputs": [],
"source": [
"from sklearn import datasets, linear_model\n",
"from skore.cross_validate import cross_validate\n",
"diabetes = datasets.load_diabetes()\n",
"X = diabetes.data[:150]\n",
"y = diabetes.target[:150]\n",
"lasso = linear_model.Lasso()\n",
"\n",
"cv_results = cross_validate(lasso, X, y, cv=3, project=project)"
]
},
{
"cell_type": "markdown",
"id": "59",
Expand Down
14 changes: 14 additions & 0 deletions examples/01_basic_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,20 @@ def my_func(x):

project.put("my_fitted_pipeline", my_pipeline)

# %% [markdown]
# ---
# ## Cross-validation with skore

# %%
from sklearn import datasets, linear_model
from skore.cross_validate import cross_validate
diabetes = datasets.load_diabetes()
X = diabetes.data[:150]
y = diabetes.target[:150]
lasso = linear_model.Lasso()

cv_results = cross_validate(lasso, X, y, cv=3, project=project)

# %% [markdown]
# _Stay tuned for some new features!_

Expand Down
44 changes: 21 additions & 23 deletions skore-ui/src/components/VegaWidget.vue
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
<script setup lang="ts">
import { isUserInDarkMode } from "@/services/utils";
import { isDeepEqual, isUserInDarkMode } from "@/services/utils";
import { View as VegaView } from "vega";
import embed, { type Config, type VisualizationSpec } from "vega-embed";
import { onBeforeUnmount, onMounted, ref } from "vue";
import { onBeforeUnmount, onMounted, ref, watch } from "vue";
const props = defineProps<{ spec: VisualizationSpec }>();
Expand All @@ -17,38 +17,36 @@ const vegaConfig: Config = {
background: "transparent",
};
let vegaView: VegaView | null = null;
const resizeObserver = new ResizeObserver(async () => {
const w = container.value?.clientWidth || 0;
await vegaView?.width(w).runAsync();
});
async function makePlot(spec: VisualizationSpec) {
const mySpec = { ...spec, width: "container" } as VisualizationSpec;
const r = await embed(container.value!, mySpec, {
theme: isUserInDarkMode() ? "dark" : undefined,
config: vegaConfig,
actions: false,
});
vegaView = r.view;
}
onMounted(async () => {
if (container.value) {
const r = await embed(
container.value,
{
width: container.value?.clientWidth || 0,
...props.spec,
},
{
theme: isUserInDarkMode() ? "dark" : undefined,
config: vegaConfig,
actions: false,
}
);
vegaView = r.view;
resizeObserver.observe(container.value);
makePlot(props.spec);
}
});
onBeforeUnmount(() => {
if (container.value) {
resizeObserver.unobserve(container.value);
}
if (vegaView) {
vegaView.finalize();
}
});
watch(
() => props.spec,
async (newSpec, oldSpec) => {
if (!isDeepEqual(newSpec, oldSpec)) {
makePlot(newSpec);
}
}
);
</script>

<template>
Expand Down
3 changes: 2 additions & 1 deletion skore/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ maintainers = [{name = "skore developers", email="[email protected]"}]
dependencies = [
"diskcache",
"fastapi",
"plotly>=5,<6",
"rich",
"skops",
"uvicorn",
Expand Down Expand Up @@ -61,7 +62,7 @@ artifacts = ["src/skore/ui/static/"]

[project.optional-dependencies]
test = [
"altair",
"altair>=5,<6",
"httpx",
"matplotlib",
"pandas",
Expand Down
2 changes: 2 additions & 0 deletions skore/src/skore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

import rich.logging

from skore.cross_validate import cross_validate
from skore.project import Project, load

from .utils._show_versions import show_versions

__all__ = [
"cross_validate",
"load",
"show_versions",
"Project",
Expand Down
Loading

0 comments on commit 904b534

Please sign in to comment.