Skip to content

Commit

Permalink
First tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AKuederle committed Jul 18, 2023
1 parent 6b71f7d commit 740d2ab
Show file tree
Hide file tree
Showing 9 changed files with 715 additions and 937 deletions.
17 changes: 5 additions & 12 deletions examples/stride_segmentation/segmentation_hmm_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,24 +87,20 @@

stride_model = SimpleHmm(
n_states=20,
n_gmm_components=6,
algo_train="baum-welch",
n_gmm_components=5,
stop_threshold=1e-9,
max_iterations=5,
architecture="left-right-strict",
verbose=True,
name="stride_model",
)

transition_model = SimpleHmm(
n_states=5,
n_gmm_components=3,
algo_train="baum-welch",
n_states=2,
n_gmm_components=2,
stop_threshold=1e-9,
max_iterations=5,
architecture="left-right-loose",
verbose=True,
name="transition_model",
)

# %%
Expand All @@ -121,13 +117,10 @@
stride_model=stride_model,
transition_model=transition_model,
feature_transform=feature_transform,
algo_predict="viterbi",
algo_train="baum-welch",
stop_threshold=1e-9,
max_iterations=1,
initialization="labels",
verbose=True,
name="segmentation_model",
)

# %%
Expand Down Expand Up @@ -170,9 +163,9 @@

np.set_printoptions(precision=3, linewidth=180, suppress=True)

print(segmentation_model.model.dense_transition_matrix()[0:-2, 0:-2])
print(np.e ** segmentation_model.model.edges)

print(segmentation_model.model.states[10])
print(segmentation_model.model.distributions[segmentation_model.n_states - 1])

# %%
# Applying the Model to a Sequence
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""HMM based stride segmentation by Roth et al. 2021."""
from contextlib import suppress
from importlib.resources import open_text
from pathlib import Path
from typing import Dict, Generic, Optional, TypeVar, Union

import numpy as np
Expand Down Expand Up @@ -39,13 +37,13 @@ class PreTrainedRothSegmentationModel(RothSegmentationHmm):
"""

def __new__(cls):
# try to load models
with open_text(
"gaitmap_mad.stride_segmentation.hmm._pre_trained_models", "fallriskpd_at_lab_model.json"
) as test_data, Path(test_data.name).open(encoding="utf8") as f:
model_json = f.read()
return RothSegmentationHmm.from_json(model_json)
# def __new__(cls):
# # try to load models
# with open_text(
# "gaitmap_mad.stride_segmentation.hmm._pre_trained_models", "fallriskpd_at_lab_model.json"
# ) as test_data, Path(test_data.name).open(encoding="utf8") as f:
# model_json = f.read()
# return RothSegmentationHmm.from_json(model_json)


BaseSegmentationHmmT = TypeVar("BaseSegmentationHmmT", bound=BaseSegmentationHmm)
Expand Down
129 changes: 44 additions & 85 deletions gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_segmentation_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
"""Segmentation _model base classes and helper."""
import copy
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Set, cast

import numpy as np
import pandas as pd
import pomegranate as pg
from pomegranate import HiddenMarkovModel as pgHMM
from pomegranate.hmm import History
from pomegranate.hmm import DenseHMM as pgHMM
from tpcp import OptiPara, cf, make_optimize_safe
from typing_extensions import Self

Expand All @@ -19,26 +17,19 @@
from gaitmap_mad.stride_segmentation.hmm._simple_model import SimpleHmm
from gaitmap_mad.stride_segmentation.hmm._utils import (
ShortenedHMMPrint,
_clone_model,
_DataToShortError,
_HackyClonableHMMFix,
add_transition,
check_history_for_training_failure,
convert_stride_list_to_transition_list,
create_transition_matrix_fully_connected,
extract_transitions_starts_stops_from_hidden_state_sequence,
fix_model_names,
get_model_distributions,
freeze_nested_distribution,
get_train_data_sequences_strides,
get_train_data_sequences_transitions,
labels_to_strings,
labels_to_prior,
predict,
)


def create_fully_labeled_gait_sequences(
data_train_sequence, stride_list_sequence, transition_model, stride_model, algo_predict
):
def create_fully_labeled_gait_sequences(data_train_sequence, stride_list_sequence, transition_model, stride_model):
"""Create fully labeled gait sequence.
To find the "actual" hidden-state labels for "labeled-training" with the given training data set, we will again
Expand All @@ -60,9 +51,7 @@ def create_fully_labeled_gait_sequences(
for start, end in transition_start_end_list[["start", "end"]].to_numpy():
transition_data_train = data[start:end]
try:
labels_train[start:end] = transition_model.predict_hidden_state_sequence(
transition_data_train, algorithm=algo_predict
)
labels_train[start:end] = transition_model.predict_hidden_state_sequence(transition_data_train)
except _DataToShortError:
# This happens if a transition is too short to be predicted by the transition model
continue
Expand All @@ -72,8 +61,7 @@ def create_fully_labeled_gait_sequences(
stride_data_train = data[start:end]
try:
labels_train[start:end] = (
stride_model.predict_hidden_state_sequence(stride_data_train, algorithm=algo_predict)
+ transition_model.n_states
stride_model.predict_hidden_state_sequence(stride_data_train) + transition_model.n_states
)
except _DataToShortError:
# This happens if a stride is too short to be predicted by the stride model
Expand Down Expand Up @@ -180,7 +168,7 @@ def self_optimize_with_info(
raise NotImplementedError


class RothSegmentationHmm(BaseSegmentationHmm, _HackyClonableHMMFix, ShortenedHMMPrint):
class RothSegmentationHmm(BaseSegmentationHmm, ShortenedHMMPrint):
"""A hierarchical HMM model for stride segmentation proposed by Roth et al. [1]_.
This model differentiates between strides and transitions.
Expand Down Expand Up @@ -287,14 +275,11 @@ class RothSegmentationHmm(BaseSegmentationHmm, _HackyClonableHMMFix, ShortenedHM
transition_model__model: OptiPara
transition_model__data_columns: OptiPara
feature_transform: BaseHmmFeatureTransformer
algo_predict: Literal["viterbi", "baum-welch"]
algo_train: Literal["viterbi", "baum-welch"]
stop_threshold: float
max_iterations: int
initialization: Literal["labels", "fully-connected"]
verbose: bool
n_jobs: int
name: Optional[str]
model: OptiPara[Optional[pgHMM]]
data_columns: OptiPara[Optional[Tuple[str, ...]]]

Expand All @@ -307,48 +292,38 @@ def __init__(
SimpleHmm(
n_states=20,
n_gmm_components=6,
algo_train="baum-welch",
stop_threshold=1e-9,
max_iterations=10,
architecture="left-right-strict",
name="stride_model",
)
),
transition_model: SimpleHmm = cf(
SimpleHmm(
n_states=5,
n_gmm_components=3,
algo_train="baum-welch",
stop_threshold=1e-9,
max_iterations=10,
architecture="left-right-loose",
name="transition_model",
)
),
feature_transform: RothHmmFeatureTransformer = cf(RothHmmFeatureTransformer()),
*,
algo_predict: Literal["viterbi", "map"] = "viterbi",
algo_train: Literal["viterbi", "baum-welch"] = "baum-welch",
stop_threshold: float = 1e-9,
max_iterations: int = 1,
initialization: Literal["labels", "fully-connected"] = "labels",
verbose: bool = True,
n_jobs: int = 1,
name: str = "segmentation_model",
model: Optional[pgHMM] = None,
data_columns: Optional[Tuple[str, ...]] = None,
):
self.stride_model = stride_model
self.transition_model = transition_model
self.feature_transform = feature_transform
self.algo_predict = algo_predict
self.algo_train = algo_train
self.stop_threshold = stop_threshold
self.max_iterations = max_iterations
self.initialization = initialization
self.verbose = verbose
self.n_jobs = n_jobs
self.name = name
self.model = model
self.data_columns = data_columns

Expand Down Expand Up @@ -406,7 +381,7 @@ def predict(self, data: SingleSensorData, sampling_rate_hz: float) -> Self:

# pomegranate always adds a label for the start- and end-state, which can be ignored here!
self.hidden_state_sequence_feature_space_ = predict(
self.model, feature_data, expected_columns=self.data_columns, algorithm=self.algo_predict
self.model, feature_data, expected_columns=self.data_columns
)
self.hidden_state_sequence_ = self.feature_transform.inverse_transform_state_sequence(
self.hidden_state_sequence_feature_space_, sampling_rate_hz=sampling_rate_hz
Expand Down Expand Up @@ -476,7 +451,7 @@ def self_optimize_with_info(
data_sequence: Sequence[SingleSensorData],
stride_list_sequence: Sequence[SingleSensorStrideList],
sampling_rate_hz: float,
) -> Tuple[Self, Dict[Literal["self", "transition_model", "stride_model"], History]]:
) -> Tuple[Self, None]:
"""Create and train the HMM model based on the given data and labels.
This is identical to `self_optimize`, but returns additional information about the training process.
Expand Down Expand Up @@ -536,41 +511,43 @@ def self_optimize_with_info(
n_states_stride = stride_model_trained.n_states

# extract fitted distributions from both separate trained models
distributions = get_model_distributions(transition_model_trained.model) + get_model_distributions(
stride_model_trained.model
distributions = copy.deepcopy(
[*transition_model_trained.model.distributions, *stride_model_trained.model.distributions]
)
# Freeze all distributions
for dist in distributions:
freeze_nested_distribution(dist)

# predict hidden state labels for complete walking bouts
labels_train_sequence = create_fully_labeled_gait_sequences(
data_sequence_feature_space,
stride_list_feature_space,
transition_model_trained,
stride_model_trained,
self.algo_predict,
)

# Now that we have a fully labeled dataset, we use our already fitted distributions as input for the new model
if self.initialization == "fully-connected":
trans_mat, start_probs, end_probs = create_transition_matrix_fully_connected(self.n_states)

new_model = pg.HiddenMarkovModel.from_matrix(
transition_probabilities=copy.deepcopy(trans_mat),
distributions=copy.deepcopy(distributions),
# TODO: Update to new API
new_model = pgHMM.from_matrix(
copy.deepcopy(distributions),
edges=trans_mat,
starts=start_probs,
ends=None,
state_names=None,
verbose=self.verbose,
)

elif self.initialization == "labels":
# combine already trained transition matrices -> zero pad "stride" transition matrix to the left
trans_mat_stride = stride_model_trained.model.dense_transition_matrix()[:-2, :-2]
trans_mat_stride = (np.e**stride_model_trained.model.edges).numpy()
transmat_stride = np.pad(
trans_mat_stride, [(n_states_transition, 0), (n_states_transition, 0)], mode="constant"
)

# zero-pad "transition" transition matrix to the right
trans_mat_transition = transition_model_trained.model.dense_transition_matrix()[:-2, :-2]
trans_mat_transition = (np.e**transition_model_trained.model.edges).numpy()
transmat_trans = np.pad(trans_mat_transition, [(0, n_states_stride), (0, n_states_stride)], mode="constant")

# after correct zero padding we can combine both transition matrices just by "adding" them together!
Expand All @@ -581,71 +558,53 @@ def self_optimize_with_info(
labels_train_sequence
)

existing_transitions = cast(Set[Tuple[int, int]], set(zip(*np.argwhere(trans_mat > 0).T)))
missing_transitions = transitions - existing_transitions
# Add missing transitions which will "connect" transition-hmm and stride-hmm
# We initialize with a very small probability, so that the model can learn the correct values in the next
# step.
# Note: We sort the transitions to enforce consistent order and reproducibility.
for i, j in sorted(missing_transitions):
trans_mat[i, j] = 0.1

start_probs = np.zeros(self.n_states)
start_probs[starts] = 1.0
end_probs = np.zeros(self.n_states)
end_probs[ends] = 1.0

new_model = pg.HiddenMarkovModel.from_matrix(
transition_probabilities=copy.deepcopy(trans_mat),
# We create the new model here.
# Note that we don't need to care about an initialization, as we froze all the distributions above.
new_model = pgHMM(
distributions=copy.deepcopy(distributions),
edges=trans_mat,
starts=start_probs,
ends=None,
state_names=None,
ends=end_probs,
verbose=self.verbose,
max_iter=self.max_iterations,
tol=self.stop_threshold,
)

existing_transitions = {(start.name, end.name) for start, end in new_model.graph.edges()}
missing_transitions = transitions - existing_transitions
# Add missing transitions which will "connect" transition-hmm and stride-hmm
# We initialize with a very small probability, so that the model can learn the correct values in the next
# step.
# Note: We sort the transitions to enforce consistent order and reproducibility.
for trans in sorted(missing_transitions):
add_transition(new_model, trans, 0.1)
else:
# Can not be reached, as we perform the check beforehand, but just to be sure and make the linter happy
raise RuntimeError()
# pomegranate seems to have a strange sorting bug where state names >= 10 (e.g. s10 get sorted in a bad order
# like s0, s1, s10, s2 usw..)
new_model = fix_model_names(new_model)
new_model.bake()

# make sure we do not change our distributions anymore!
new_model.freeze_distributions()

# We clone the model here, as this changes the order of edges to be sorted somehow...
new_model = _clone_model(new_model, assert_correct=False)

# convert labels to state-names
labels_train_sequence_str = labels_to_strings(labels_train_sequence)

self.data_columns = tuple(data_sequence_feature_space[0].columns)

# make sure data is in an pomegranate compatible format!
data_train_sequence = [
np.ascontiguousarray(feature_data[list(self.data_columns)].to_numpy().copy())
np.ascontiguousarray(feature_data[list(self.data_columns)].to_numpy())
for feature_data in data_sequence_feature_space
]

_, history = new_model.fit(
sequences=np.array(data_train_sequence, dtype=object),
labels=np.array(labels_train_sequence_str, dtype=object).copy(),
algorithm=self.algo_train,
stop_threshold=self.stop_threshold,
max_iterations=self.max_iterations,
return_history=True,
verbose=self.verbose,
n_jobs=self.n_jobs,
multiple_check_input=False,
# We froze all distributions above.
# This means this fit step should only update the transition probabilities.
new_model.fit(
data_train_sequence,
priors=labels_to_prior(labels_train_sequence, len(new_model.distributions)),
)
check_history_for_training_failure(history)

new_model.name = self.name

self.model = new_model

return (
self,
{"self": history, "transition_model": transition_model_history, "stride_model": stride_model_history},
None,
)
Loading

0 comments on commit 740d2ab

Please sign in to comment.