Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scarliles/splitter injection #61

Open
wants to merge 20 commits into
base: submodulev3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions sklearn/tree/_splitter.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,60 @@
# Jacob Schreiber <[email protected]>
# Adam Li <[email protected]>
# Jong Shin <[email protected]>
# Samuel Carliles <[email protected]>
#
# License: BSD 3 clause

# See _splitter.pyx for details.
cimport numpy as cnp

from libcpp.vector cimport vector
from libc.stdlib cimport malloc

from ..utils._typedefs cimport float32_t, float64_t, intp_t, int32_t
from ._utils cimport UINT32_t
from ._criterion cimport BaseCriterion, Criterion


# NICE IDEAS THAT DON'T APPEAR POSSIBLE
# - accessing elements of a memory view of cython extension types in a nogil block/function
# - storing cython extension types in cpp vectors
#
# despite the fact that we can access scalar extension type properties in such a context,
# as for instance node_split_best does with Criterion and Partition,
# and we can access the elements of a memory view of primitive types in such a context
#
# SO WHERE DOES THAT LEAVE US
# - we can transform these into cpp vectors of structs
# and with some minor casting irritations everything else works ok
ctypedef void* SplitConditionParameters
ctypedef bint (*SplitConditionFunction)(
Splitter splitter,
SplitRecord* current_split,
intp_t n_missing,
bint missing_go_to_left,
float64_t lower_bound,
float64_t upper_bound,
SplitConditionParameters split_condition_parameters
) noexcept nogil

cdef struct SplitConditionTuple:
SplitConditionFunction f
SplitConditionParameters p

cdef class SplitCondition:
cdef SplitConditionTuple t

cdef class MinSamplesLeafCondition(SplitCondition):
pass

cdef class MinWeightLeafCondition(SplitCondition):
pass

cdef class MonotonicConstraintCondition(SplitCondition):
pass


cdef struct SplitRecord:
# Data to track sample split
intp_t feature # Which feature to split on.
Expand Down Expand Up @@ -112,6 +153,12 @@ cdef class Splitter(BaseSplitter):
cdef const cnp.int8_t[:] monotonic_cst
cdef bint with_monotonic_cst

cdef list _presplit_conditions
cdef list _postsplit_conditions

cdef vector[SplitConditionTuple*] presplit_conditions
cdef vector[SplitConditionTuple*] postsplit_conditions

cdef int init(
self,
object X,
Expand Down
Loading
Loading