diff --git a/evofr/data/hier_frequencies.py b/evofr/data/hier_frequencies.py index 5af17c0..8a470b7 100644 --- a/evofr/data/hier_frequencies.py +++ b/evofr/data/hier_frequencies.py @@ -57,6 +57,14 @@ def __init__( self.var_names = format_var_names(raw_var_names, pivot=pivot) self.pivot = self.var_names[-1] + # Loop each group + grouped = raw_seq.groupby(group) + self.names = [name for name, _ in grouped] + self.groups = [ + VariantFrequencies(group, self.date_to_index, self.var_names) + for _, group in grouped + ] + # Aggregate counts into larger windows self.aggregation_frequency = aggregation_frequency if self.aggregation_frequency is not None: @@ -68,14 +76,6 @@ def __init__( self.groups, self.dates, self.aggregation_frequency ) - # Loop each group - grouped = raw_seq.groupby(group) - self.names = [name for name, _ in grouped] - self.groups = [ - VariantFrequencies(group, self.date_to_index, self.var_names) - for _, group in grouped - ] - self.seq_counts = np.stack([g.seq_counts for g in self.groups], axis=-1) def make_data_dict(self, data: Optional[dict] = None) -> dict: