diff --git a/timely_beliefs/beliefs/classes.py b/timely_beliefs/beliefs/classes.py index 34d9a6c5..f8031ca6 100644 --- a/timely_beliefs/beliefs/classes.py +++ b/timely_beliefs/beliefs/classes.py @@ -35,11 +35,13 @@ Interval, and_, func, + select, ) from sqlalchemy.ext.declarative import declared_attr, has_inherited_table from sqlalchemy.ext.hybrid import hybrid_method, hybrid_property from sqlalchemy.orm import Session, backref, relationship from sqlalchemy.orm.util import AliasedClass +from sqlalchemy.engine import Engine from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql.elements import BinaryExpression from sqlalchemy.sql.expression import Selectable @@ -326,9 +328,11 @@ def query(cls, *args, **kwargs): @classmethod def search_session( # noqa: C901 # todo: remove after removing deprecated arguments cls, + engine: Engine, session: Session, sensor: Union[SensorDBMixin, int], sensor_class: Optional[Type[SensorDBMixin]] = DBSensor, + source_class = DBBeliefSource, event_starts_after: Optional[datetime] = None, event_ends_after: Optional[datetime] = None, event_starts_before: Optional[datetime] = None, @@ -517,7 +521,8 @@ def apply_belief_timing_filters(q): return q # Main query - q = session.query(cls).filter(cls.sensor_id == sensor.id) + # q = session.query(cls).filter(cls.sensor_id == sensor.id) + q = select(cls).filter(cls.sensor_id == sensor.id) # Apply event time filter if not pd.isnull(event_starts_after): @@ -557,7 +562,7 @@ def apply_belief_timing_filters(q): most_recent_beliefs_only and not most_recent_beliefs_only_incompatible_criteria ): - subq = session.query( + subq = select( cls.event_start, cls.source_id, func.min(cls.belief_horizon).label("most_recent_belief_horizon"), @@ -581,7 +586,7 @@ def apply_belief_timing_filters(q): # Apply most recent events filter if most_recent_events_only: subq_most_recent_events = ( - session.query( + select( cls.source_id, func.max(cls.event_start).label("most_recent_event_start"), ) @@ -597,9 +602,27 @@ def apply_belief_timing_filters(q): == subq_most_recent_events.c.most_recent_event_start, ), ) - # Build our DataFrame of beliefs - df = BeliefsDataFrame(sensor=sensor, beliefs=q.all()) + beliefs = engine.execute(q).all() + + print(beliefs) + unique_sensor_ids = set([b[4] for b in beliefs]) + unique_source_ids = set([b[5] for b in beliefs]) + + unique_sources = session.query(source_class).filter(source_class.id.in_(unique_source_ids)).all() + unique_sensors = session.query(sensor_class).filter(sensor_class.id.in_(unique_sensor_ids)).all() + sensor_mapping = {sid: s for sid, s in zip(unique_sensor_ids, unique_sensors)} + + source_mapping = {sid: s for sid, s in zip(unique_source_ids, unique_sources)} + beliefs = [{"sensor": sensor_mapping[belief[4]], + "source": source_mapping[belief[5]], + "event_value": belief[3], + "cumulative_probability": belief[2], + "event_start": belief[0], + "belief_horizon": belief[1]} for belief in beliefs] + beliefs = [TimedBelief(**b) for b in beliefs] + + df = BeliefsDataFrame(sensor=sensor, beliefs=beliefs) # Actually filter by belief time if beliefs_after is not None: @@ -878,21 +901,23 @@ def __init__( # noqa: C901 todo: refactor, e.g. by detecting initialization met kwargs["columns"] = columns # Check for different sensors - unique_sensors = set(belief.sensor for belief in beliefs) - if len(unique_sensors) != 1: - raise ValueError("BeliefsDataFrame cannot describe multiple sensors.") - sensor = list(unique_sensors)[0] - - # Check for different sources with the same name - unique_sources = set(str(belief.source) for belief in beliefs) - unique_source_string_representations = set( - str(source) for source in unique_sources - ) - if len(unique_source_string_representations) != len(unique_sources): - raise ValueError( - "String representations of sources must be unique. Cannot initialise BeliefsDataFrame given the following unique sources:\n%s" - % unique_sources - ) + # unique_sensors = set(belief.sensor for belief in beliefs) + # unique_sensors = set(beliefs) + # # unique_sensors = {sensor} + # if len(unique_sensors) != 1: + # raise ValueError("BeliefsDataFrame cannot describe multiple sensors.") + # sensor = list(unique_sensors)[0] + # + # # Check for different sources with the same name + # unique_sources = set(str(belief.source) for belief in beliefs) + # unique_source_string_representations = set( + # str(source) for source in unique_sources + # ) + # if len(unique_source_string_representations) != len(unique_sources): + # raise ValueError( + # "String representations of sources must be unique. Cannot initialise BeliefsDataFrame given the following unique sources:\n%s" + # % unique_sources + # ) # Construct data and index from beliefs before calling super class beliefs = sorted( @@ -2256,4 +2281,4 @@ def downsample_beliefs_data_frame( for col, att in col_att_dict.items() ], axis=1, - ).set_index([belief_timing_col, "source", "cumulative_probability"], append=True) + ).set_index([belief_timing_col, "source", "cumulative_probability"], append=True) \ No newline at end of file