Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add as_sklearn and from_sklearn APIs to serialize to CPU sklearn-estimators for supported models #6102

Open
wants to merge 12 commits into
base: branch-25.02
Choose a base branch
from

Conversation

dantegd
Copy link
Member

@dantegd dantegd commented Oct 8, 2024

No description provided.

@github-actions github-actions bot added the Cython / Python Cython or Python issue label Oct 8, 2024
@betatim
Copy link
Member

betatim commented Nov 19, 2024

Why have the methods do both the conversion cuml<>sklearn and the serialisation? Having a way to convert to and from scikit-learn seems like a useful thing by itself. Maybe because you have your own way of serialising the model, or because you need a particular type of model or who-knows-what.

So to serialise it you'd do something like pickle.dumps(cuml_est.to_sklearn()) (or dill, joblib, ...)

How hard would it be to have cuml.from_sklearn(estimator)? As in one top level function that takes a scikit-learn estimator and converts it to the cuml equivalent? It seems like it should be easy to figure out the estimator's class name: "just look at .__class__.__name__" but I wonder if there is a trap here?


Name bike shedding: if we don't save things to a file, how about as_sklearn? A bit like other functions that do type conversion like astype.

If we do save to a file, then save_sklearn and load_sklearn? Basically getting words like "save" and "load" in there to make it clear that this is about storing things (to a file).

Copy link

copy-pr-bot bot commented Dec 19, 2024

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@dantegd dantegd changed the title Add to_sklearn and from_sklearn APIs to serialize to CPU sklearn-estimators for supported models Add as_sklearn and from_sklearn APIs to serialize to CPU sklearn-estimators for supported models Dec 19, 2024
@dantegd dantegd changed the base branch from branch-24.12 to branch-25.02 December 19, 2024 23:57
@dantegd
Copy link
Member Author

dantegd commented Dec 20, 2024

@betatim just implemented your suggestions and changed the functionality to not save a file but return an estimator.

I think the idea of cuml.from_sklearn is fantastic, but requires some additional logic and testing since it requires to also validate the estimator requested has the functionality needed, so perhaps it can be a follow up to keep this PR small and succint?

Copy link
Contributor

@viclafargue viclafargue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just have two comments

return self._cpu_model

@classmethod
def from_sklearn(cls, model):
Copy link
Contributor

@viclafargue viclafargue Dec 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be great to have a global conversion table, so that we don't need to provide the class as a parameter.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a class method, so we get the class from that, it's not something the user passes (like self in non class methods)

A global conversion table will be useful for a follow up to add cuml.from_sklearn library type of functionality though

"""
estimator = cls()
estimator.import_cpu_model()
estimator._cpu_model = model
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be interesting to add an optional parameter to this function to allow a deepcopy of the sklearn model.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lol I asked the same thing before reading this suggestion :)

self.import_cpu_model()
self.build_cpu_model()
self.gpu_to_cpu()
return self._cpu_model
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we not return a deep copy here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my education, why would we want to deepcopy? Mostly asking because in my experience 99% of cases where someone uses deepcopy there is something else that we can do instead or just not do it. Mostly Python "just works" without deepcopy'ing, hence my interest

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we simply return a reference to the internal model, any modification to one (additional training or something else) would affect the other. This might create a situation in which the CPU and GPU attributes are out of sync in the cuML estimator. Or inversely, the sklearn estimator returned by the function might silently be updated by the cuML estimator.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps it should be a parameter, by default most users probably won't care about needing a deep copy, so I wouldn't do it by default, but if a user needs it then they can request it, what do you guys think?

Comment on lines 70 to 71
else:
raise ValueError(f"Serializer {format} not supported.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does click not take care of invalid values being passed :(

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does, I added it by accident based on habits :P

python/cuml/cuml/experimental/accel/__main__.py Outdated Show resolved Hide resolved
python/cuml/cuml/experimental/accel/__main__.py Outdated Show resolved Hide resolved
Comment on lines 74 to 81
# Convert to sklearn estimator
sklearn_estimator = accelerated_estimator.as_sklearn()

# Save using chosen format
with open(output, "wb") as f:
serializer.dump(sklearn_estimator, f)

# Exit after conversion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the comments? They seem like they repeat what the code says. I like comments that explain why the code is the way it is, but I don't think we need that here as it is pretty straightforward

python/cuml/cuml/internals/base.pyx Show resolved Hide resolved
python/cuml/cuml/internals/base.pyx Show resolved Hide resolved
python/cuml/cuml/internals/base.pyx Show resolved Hide resolved
python/cuml/cuml/internals/base.pyx Show resolved Hide resolved
python/cuml/cuml/internals/base.pyx Show resolved Hide resolved
Comment on lines +54 to +56
@pytest.fixture
def random_state():
return 42
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need this instead of using 42 in the tests directly?

We could have a global version of this that allows us to run the tests with several seeds, but maybe something to tackle in the future/new PR

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't really need it at all

@dantegd dantegd added feature request New feature or request non-breaking Non-breaking change labels Dec 20, 2024
@dantegd dantegd marked this pull request as ready for review December 20, 2024 19:35
@dantegd dantegd requested a review from a team as a code owner December 20, 2024 19:35
@dantegd dantegd requested review from csadorf and wphicks December 20, 2024 19:35
Copy link
Contributor

@wphicks wphicks left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pre-approving this given vacation schedules. Looks great to me once current discussion is resolved.

Copy link
Contributor

@viclafargue viclafargue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, addding the deepcopy optional parameter to both functions is probably optimal, but both solutions work for me.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Cython / Python Cython or Python issue feature request New feature or request non-breaking Non-breaking change
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants