From d689e540dd8ed232ccea5e3f23a6089e5a1834b7 Mon Sep 17 00:00:00 2001 From: georgedouzas Date: Tue, 24 Dec 2024 08:21:17 +0200 Subject: [PATCH] refactor: Split dataloader funcionality to multiple pages --- gui/.gitignore | 1 + gui/gui/components/common.py | 85 ++++-- gui/gui/components/dataloader/__init__.py | 0 gui/gui/components/dataloader/creation.py | 303 +++++++++++++++++++ gui/gui/components/dataloader/loading.py | 122 ++++++++ gui/gui/components/parameters.py | 72 ----- gui/gui/components/training_parameters.py | 34 --- gui/gui/gui.py | 13 + gui/gui/pages/create/creation.py | 306 +++++++++++++++++++ gui/gui/pages/dataloader/__init__.py | 0 gui/gui/pages/dataloader/creation.py | 302 +++++++++++++++++++ gui/gui/pages/dataloader/loading.py | 103 +++++++ gui/gui/pages/index.py | 350 ++++++++-------------- gui/rxconfig.py | 2 + 14 files changed, 1336 insertions(+), 357 deletions(-) create mode 100644 gui/gui/components/dataloader/__init__.py create mode 100644 gui/gui/components/dataloader/creation.py create mode 100644 gui/gui/components/dataloader/loading.py delete mode 100644 gui/gui/components/parameters.py delete mode 100644 gui/gui/components/training_parameters.py create mode 100644 gui/gui/pages/create/creation.py create mode 100644 gui/gui/pages/dataloader/__init__.py create mode 100644 gui/gui/pages/dataloader/creation.py create mode 100644 gui/gui/pages/dataloader/loading.py diff --git a/gui/.gitignore b/gui/.gitignore index 139ed98..810ed72 100644 --- a/gui/.gitignore +++ b/gui/.gitignore @@ -3,3 +3,4 @@ .web __pycache__/ assets/external/ +dataloader.pkl diff --git a/gui/gui/components/common.py b/gui/gui/components/common.py index dda8e64..a06b170 100644 --- a/gui/gui/components/common.py +++ b/gui/gui/components/common.py @@ -1,23 +1,23 @@ """Page common components.""" -from collections.abc import Callable - import reflex as rx - -def home(reset_state: Callable) -> rx.Component: - """Home title.""" - return rx.link(rx.hstack(rx.icon('home'), rx.text('Home', size='4', weight='bold')), on_click=reset_state) +SIDEBAR_OPTIONS = { + 'spacing': '1', + 'position': 'fixed', + 'left': '50px', + 'top': '50px', + 'padding_x': '1em', + 'padding_y': "1.5em", + 'bg': rx.color('blue', 3), + 'height': '620px', + 'width': '20em', +} -def header() -> rx.Component: - """Header of page.""" - return rx.vstack( - rx.heading("Sports Betting", size='9', align='center'), - rx.heading("Application", size='4', color_scheme='blue'), - spacing='1', - align='center', - ) +def home() -> rx.Component: + """Home title.""" + return rx.text('Sports Betting', size='4', weight='bold') def title(name: str, icon_name: str) -> rx.Component: @@ -31,12 +31,53 @@ def title(name: str, icon_name: str) -> rx.Component: ) -def selection(items: list[str], value: str, disabled: bool, on_change: Callable) -> rx.Component: - """The selection component.""" - return rx.select( - items=items, - value=value, - disabled=disabled, - on_change=on_change, - width='50%', +def select_mode(state: rx.State, content: str) -> rx.Component: + """Selection of mode component.""" + return rx.vstack( + rx.text(content, size='1'), + rx.hstack( + rx.select( + items=['Data', 'Modelling'], + value=state.mode_category, + disabled=state.visibility_level > 1, + width='120px', + on_change=state.set_mode_category, + ), + rx.select( + ['Create', 'Load'], + value=state.mode_type, + disabled=state.visibility_level > 1, + width='120px', + on_change=state.set_mode_type, + ), + ), + ) + + +def control_buttons(state: rx.State, disabled: bool) -> rx.Component: + """Control buttons of UI.""" + return rx.vstack( + rx.divider(top='600px', position='fixed', width='18em'), + rx.hstack( + rx.button( + 'Submit', + on_click=state.submit_state, + disabled=disabled, + loading=state.loading, + position='fixed', + top='620px', + width='70px', + ), + rx.link( + rx.button( + 'Reset', + on_click=state.reset_state, + position='fixed', + top='620px', + left='150px', + width='70px', + ), + href='/', + ), + ), ) diff --git a/gui/gui/components/dataloader/__init__.py b/gui/gui/components/dataloader/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gui/gui/components/dataloader/creation.py b/gui/gui/components/dataloader/creation.py new file mode 100644 index 0000000..126cbb4 --- /dev/null +++ b/gui/gui/components/dataloader/creation.py @@ -0,0 +1,303 @@ +"""Create dataloader components.""" + +from collections.abc import Callable + +import reflex as rx +from reflex_ag_grid import ag_grid + +from ..common import SIDEBAR_OPTIONS, control_buttons, home, select_mode, title + + +def checkboxes(row: list[str], state: rx.State) -> rx.Component: + """Checkbox of parameter value.""" + + def _in_leagues(name: str) -> rx.Var: + return state.default_param_checked['leagues'].contains(name.to_string()) + + def _in_years(name: str) -> rx.Var: + return state.default_param_checked['years'].contains(name.to_string()) + + def _in_divisions(name: str) -> rx.Var: + return state.default_param_checked['divisions'].contains(name.to_string()) + + return rx.vstack( + rx.foreach( + row, + lambda name: rx.checkbox( + name, + default_checked=rx.cond( + _in_leagues(name), True, rx.cond(_in_years(name), True, rx.cond(_in_divisions(name), True, False)) + ), + checked=state.param_checked[name.to_string()], + name=name.to_string(), + on_change=lambda checked: state.update_param_checked(name, checked), + ), + ), + ) + + +def dialog(name: str, icon_name: str, state: rx.State) -> Callable: + """Dialog component.""" + + def _dialog(rows: list[list[str]], on_submit: Callable) -> rx.Component: + """The dialog component.""" + return rx.dialog.root( + rx.dialog.trigger( + rx.button( + rx.tooltip(rx.icon(icon_name), content=name), + size='4', + variant='outline', + disabled=state.visibility_level > 3, + ), + ), + rx.dialog.content( + rx.form.root( + rx.dialog.title(name), + rx.dialog.description( + f'Select the {name.lower()} to include in the training data.', + size="2", + margin_bottom="16px", + ), + rx.hstack(rx.foreach(rows, lambda row: checkboxes(row, state))), + rx.flex( + rx.dialog.close(rx.button('Submit', type='submit')), + justify='end', + spacing="3", + margin_top="50px", + ), + on_submit=on_submit, + reset_on_submit=False, + width="100%", + ), + ), + ) + + return _dialog + + +def training_parameters_selection(state: rx.State) -> rx.Component: + """The training parameters selection component.""" + return rx.vstack( + rx.vstack( + rx.text('Odds type', size='1'), + rx.select( + state.odds_types, + default_value=state.odds_types[0], + on_change=state.handle_odds_type, + disabled=state.visibility_level > 4, + width='100%', + ), + ), + rx.vstack( + rx.text('Drop NA threshold of columns', size='1'), + rx.slider( + min=0.0, + max=1.0, + step=0.01, + default_value=0.0, + on_change=state.handle_drop_na_thres, + disabled=state.visibility_level > 4, + ), + style={ + 'margin-top': '15px', + 'width': '100%', + }, + ), + ) + + +def parameters_selection(state: rx.State) -> rx.Component: + """The parameters title.""" + return rx.hstack( + dialog('Leagues', 'earth', state)(state.all_leagues, state.handle_submit_leagues), + dialog('Years', 'calendar', state)(state.all_years, state.handle_submit_years), + dialog('Divisions', 'gauge', state)(state.all_divisions, state.handle_submit_divisions), + ) + + +def main(state: rx.State) -> rx.Component: + """Main container of UI.""" + return rx.container( + rx.vstack( + home(), + rx.divider(), + # Mode selection + title('Mode', 'blend'), + select_mode(state, 'Create a dataloader'), + # Sport selection + rx.cond( + state.visibility_level > 1, + title('Sport', 'medal'), + ), + rx.cond( + state.visibility_level > 1, + rx.text('Select a sport', size='1'), + ), + rx.cond( + state.visibility_level > 1, + rx.select( + items=['Soccer'], + value='Soccer', + disabled=state.visibility_level > 2, + on_change=state.set_sport_selection, + width='120px', + ), + ), + # Parameters selection + rx.cond( + state.visibility_level > 2, + title('Parameters', 'proportions'), + ), + rx.cond( + state.visibility_level > 2, + rx.text('Select parameters', size='1'), + ), + rx.cond( + state.visibility_level > 2, + parameters_selection(state), + ), + # Training parameters selection + rx.cond( + state.visibility_level > 3, + training_parameters_selection(state), + ), + rx.cond( + state.visibility_level > 4, + rx.button( + 'Save', + position='fixed', + top='620px', + left='275px', + width='70px', + on_click=state.download_dataloader, + ), + ), + # Control + control_buttons(state, state.visibility_level == 5), + **SIDEBAR_OPTIONS, + ), + rx.vstack( + rx.cond( + state.visibility_level == 5, + rx.hstack( + rx.heading( + 'Training data', size='7', position='fixed', left='450px', top='50px', color_scheme='blue' + ) + ), + ), + rx.hstack( + rx.vstack( + rx.cond(state.visibility_level == 5, rx.heading('Input')), + rx.cond( + state.visibility_level == 5, + ag_grid( + id='X_train', + row_data=state.X_train, + column_defs=state.X_train_cols, + height='200px', + width='250px', + theme='balham', + ), + ), + ), + rx.vstack( + rx.cond(state.visibility_level == 5, rx.heading('Output')), + rx.cond( + state.visibility_level == 5, + ag_grid( + id='Y_train', + row_data=state.Y_train, + column_defs=state.Y_train_cols, + height='200px', + width='250px', + theme='balham', + ), + ), + ), + rx.vstack( + rx.cond(state.visibility_level == 5, rx.heading('Odds')), + rx.cond( + state.visibility_level == 5, + ag_grid( + id='O_train', + row_data=state.O_train, + column_defs=state.O_train_cols, + height='200px', + width='250px', + theme='balham', + ), + ), + ), + position='fixed', + left='450px', + top='100px', + ), + ), + rx.vstack( + rx.cond( + state.visibility_level == 5, + rx.hstack( + rx.heading( + 'Fixtures data', size='7', position='fixed', left='450px', top='370px', color_scheme='blue' + ) + ), + ), + rx.cond( + state.visibility_level == 5, + rx.cond( + state.X_fix, + rx.hstack( + rx.vstack( + rx.cond(state.visibility_level == 5, rx.heading('Input')), + rx.cond( + state.visibility_level == 5, + ag_grid( + id='X_fix', + row_data=state.X_fix, + column_defs=state.X_fix_cols, + height='200px', + width='250px', + theme='balham', + ), + ), + ), + rx.vstack( + rx.cond(state.visibility_level == 5, rx.heading('Output')), + rx.cond( + state.visibility_level == 5, + ag_grid( + id='Y_fix', + row_data=[], + column_defs=[], + height='200px', + width='250px', + theme='balham', + ), + ), + ), + rx.vstack( + rx.cond(state.visibility_level == 5, rx.heading('Odds')), + rx.cond( + state.visibility_level == 5, + ag_grid( + id='O_fix', + row_data=state.O_fix, + column_defs=state.O_fix_cols, + height='200px', + width='250px', + theme='balham', + ), + ), + ), + position='fixed', + left='450px', + top='420px', + ), + rx.tooltip( + rx.icon('ban', position='fixed', left='450px', top='420px', size=60), + content='No fixtures were found. Try again later.', + ), + ), + ), + ), + ) diff --git a/gui/gui/components/dataloader/loading.py b/gui/gui/components/dataloader/loading.py new file mode 100644 index 0000000..7d6a674 --- /dev/null +++ b/gui/gui/components/dataloader/loading.py @@ -0,0 +1,122 @@ +"""Load dataloader components.""" + +from collections.abc import Callable + +import reflex as rx + +from ..common import SIDEBAR_OPTIONS, control_buttons, home, select_mode, title + + +def checkboxes(row: list[str], state: rx.State) -> rx.Component: + """Checkbox of parameter value.""" + + return rx.vstack( + rx.foreach( + row, + lambda name: rx.checkbox( + name, + disabled=True, + default_checked=state.param_checked[name.to_string()], + name=name.to_string(), + ), + ), + ) + + +def dialog(name: str, icon_name: str, state: rx.State) -> Callable: + """Dialog component.""" + + def _dialog(rows: list[list[str]]) -> rx.Component: + """The dialog component.""" + return rx.dialog.root( + rx.dialog.trigger( + rx.button( + rx.tooltip(rx.icon(icon_name), content=name), + size='4', + variant='outline', + disabled=state.visibility_level > 3, + ), + ), + rx.dialog.content( + rx.form.root( + rx.dialog.title(name), + rx.dialog.description( + f'{name} included in the training data.', + size="2", + margin_bottom="16px", + ), + rx.hstack(rx.foreach(rows, lambda row: checkboxes(row, state))), + width="100%", + ), + ), + ) + + return _dialog + + +def main(state: rx.State) -> rx.Component: + """Main container of UI.""" + return rx.container( + rx.vstack( + home(), + rx.divider(), + # Mode selection + title('Mode', 'blend'), + select_mode(state, 'Load a dataloader'), + # Dataloader selection + rx.cond( + state.visibility_level > 1, + title('Dataloader', 'database'), + ), + rx.cond( + state.visibility_level > 1, + rx.upload( + rx.vstack( + rx.button( + 'Select File', + bg='white', + color='rgb(107,99,246)', + border=f'1px solid rgb(107,99,246)', + disabled=state.dataloader_serialized.bool(), + ), + rx.text('Drag and drop', size='2'), + ), + id='dataloader', + multiple=False, + no_keyboard=True, + no_drag=state.dataloader_serialized.bool(), + on_drop=state.handle_upload(rx.upload_files(upload_id='dataloader')), + border='1px dotted blue', + padding='35px', + ), + ), + rx.cond( + state.dataloader_serialized, + rx.text(f'Dataloader: {state.dataloader_filename}', size='1'), + ), + # Parameters presentation + rx.cond( + state.visibility_level > 2, + title('Parameters', 'proportions'), + ), + rx.cond( + state.visibility_level > 2, + rx.hstack( + dialog('Leagues', 'earth', state)(state.all_leagues), + dialog('Years', 'calendar', state)(state.all_years), + dialog('Divisions', 'gauge', state)(state.all_divisions), + ), + ), + rx.cond( + state.visibility_level > 2, + rx.text(f'Odds type: {state.odds_type}', size='1'), + ), + rx.cond( + state.visibility_level > 2, + rx.text(f'Drop NA threshold of columns: {state.drop_na_thres}', size='1'), + ), + # Control + control_buttons(state, (~state.dataloader_serialized.bool()) | (state.visibility_level > 2)), + **SIDEBAR_OPTIONS, + ), + ) diff --git a/gui/gui/components/parameters.py b/gui/gui/components/parameters.py deleted file mode 100644 index 1c1491a..0000000 --- a/gui/gui/components/parameters.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Page parameters components.""" - -from collections.abc import Callable - -import reflex as rx - - -def checkboxes(row: list[str], state: rx.State) -> rx.Component: - """Checkbox of parameter value.""" - return rx.vstack( - rx.foreach( - row, - lambda name: rx.checkbox( - name, - default_checked=state.default_param_checked[name.to_string()], - checked=state.param_checked[name.to_string()], - name=name.to_string(), - on_change=lambda checked: state.update_param_checked(name, checked), - ), - ), - ) - - -def dialog(name: str, icon_name: str, state: rx.State) -> Callable: - """Dialog component.""" - - def _dialog(rows: list[list[str]], on_submit: Callable) -> rx.Component: - """The dialog component.""" - return rx.dialog.root( - rx.dialog.trigger( - rx.button( - rx.tooltip(rx.icon(icon_name), content=name), - size='4', - variant='outline', - disabled=state.parameters_disabled, - ) - ), - rx.dialog.content( - rx.form.root( - rx.dialog.title(name), - rx.dialog.description( - f'Select the {name.lower()} to include in the training data.', - size="2", - margin_bottom="16px", - ), - rx.hstack(rx.foreach(rows, lambda row: checkboxes(row, state))), - rx.flex( - rx.dialog.close(rx.button('Submit', type='submit')), - justify='end', - spacing="3", - margin_top="50px", - ), - on_submit=on_submit, - reset_on_submit=False, - width="100%", - ), - ), - ) - - return _dialog - - -def parameters_selection(state: rx.State) -> rx.Component: - """The parameters title.""" - return rx.vstack( - rx.text('Leagues, years and divisions selection', size='1', hidden=state.parameters_disabled), - rx.hstack( - dialog('Leagues', 'earth', state)(state.all_leagues, state.handle_submit_leagues), - dialog('Years', 'calendar', state)(state.all_years, state.handle_submit_years), - dialog('Divisions', 'gauge', state)(state.all_divisions, state.handle_submit_divisions), - ), - ) diff --git a/gui/gui/components/training_parameters.py b/gui/gui/components/training_parameters.py deleted file mode 100644 index 3044cb7..0000000 --- a/gui/gui/components/training_parameters.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Page training components.""" - -import reflex as rx - - -def training_parameters_selection(state: rx.State) -> rx.Component: - """The trianing parameters selection component.""" - return rx.vstack( - rx.vstack( - rx.text('Selection of odds type', size='1', hidden=state.training_disabled), - rx.select( - state.odds_types, - default_value=state.odds_types[0], - on_change=state.handle_odds_type, - disabled=state.training_disabled, - width='100%', - ), - ), - rx.vstack( - rx.text('Selection of NA columns threshold', size='1', hidden=state.training_disabled), - rx.slider( - min=0.0, - max=1.0, - step=0.01, - default_value=0.0, - on_change=state.handle_drop_na_thres, - disabled=state.training_disabled, - ), - style={ - 'margin-top': '15px', - 'width': '100%', - }, - ), - ) diff --git a/gui/gui/gui.py b/gui/gui/gui.py index 43b249a..004affb 100644 --- a/gui/gui/gui.py +++ b/gui/gui/gui.py @@ -1,5 +1,18 @@ """GUI of sports betting.""" import reflex as rx +from fastapi.responses import FileResponse + +from .pages.dataloader.creation import dataloader_creation +from .pages.dataloader.loading import dataloader_loading + + +async def dataloader() -> FileResponse: + """Dataloader endpoint.""" + return FileResponse(filename='dataloader.pkl', path='dataloader.pkl', media_type='application/octet-stream') + app = rx.App() +app.api.add_api_route("/dataloader", dataloader) +app.api.add_api_route("/dataloader/creation", dataloader_creation) +app.api.add_api_route("/dataloader/loading", dataloader_loading) diff --git a/gui/gui/pages/create/creation.py b/gui/gui/pages/create/creation.py new file mode 100644 index 0000000..ee3c52f --- /dev/null +++ b/gui/gui/pages/create/creation.py @@ -0,0 +1,306 @@ +"""Index page.""" + +from itertools import batched +from typing import Any, Self + +import cloudpickle +import nest_asyncio +import reflex as rx +from reflex.event import EventSpec +from reflex_ag_grid import ag_grid + +from sportsbet.datasets import SoccerDataLoader + +from ...components.dataloader.creation import main + +DATALOADERS = { + 'Soccer': SoccerDataLoader, +} +DEFAULT_PARAM_CHECKED = { + 'leagues': [ + '"England"', + '"Scotland"', + '"Germany"', + '"Italy"', + '"Spain"', + '"France"', + '"Netherlands"', + '"Belgium"', + '"Portugal"', + '"Turkey"', + '"Greece"', + ], + 'years': [ + '2020', + '2021', + '2022', + '2023', + '2024', + '2025', + ], + 'divisions': ['1', '2'], +} +DEFAULT_STATE_VALS = { + 'mode': { + 'category': 'Data', + 'type': 'Create', + }, + 'sport': { + 'selection': 'Soccer', + 'all_params': [], + 'all_leagues': [], + 'all_years': [], + 'all_divisions': [], + 'leagues': [], + 'years': [], + 'divisions': [], + 'params': [], + }, + 'parameters': { + 'checked': {}, + 'default_checked': DEFAULT_PARAM_CHECKED, + 'odds_types': [], + 'param_grid': [], + }, + 'training_parameters': { + 'odds_type': 'market_average', + 'drop_na_thres': [0.0], + }, + 'data': { + 'X_train': None, + 'Y_train': None, + 'O_train': None, + 'X_train_cols': None, + 'Y_train_cols': None, + 'O_train_cols': None, + 'X_fix': None, + 'O_fix': None, + 'X_fix_cols': None, + 'O_fix_cols': None, + }, +} + +nest_asyncio.apply() + + +class DataloaderCreationState(rx.State): + """The toolbox state.""" + + # Elements + visibility_level: int = 1 + loading: bool = False + + # Mode + mode_category: str = DEFAULT_STATE_VALS['mode']['category'] + mode_type: str = DEFAULT_STATE_VALS['mode']['type'] + + # Sport + sport_selection: str = DEFAULT_STATE_VALS['sport']['selection'] + all_params: list[dict[str, Any]] = DEFAULT_STATE_VALS['sport']['all_params'] + all_leagues: list[list[str]] = DEFAULT_STATE_VALS['sport']['all_leagues'] + all_years: list[list[str]] = DEFAULT_STATE_VALS['sport']['all_years'] + all_divisions: list[list[str]] = DEFAULT_STATE_VALS['sport']['all_divisions'] + leagues: list[str] = DEFAULT_STATE_VALS['sport']['leagues'] + years: list[str] = DEFAULT_STATE_VALS['sport']['years'] + divisions: list[str] = DEFAULT_STATE_VALS['sport']['divisions'] + params: list[dict[str, Any]] = DEFAULT_STATE_VALS['sport']['params'] + + # Parameters + param_checked: dict[str, bool] = DEFAULT_STATE_VALS['parameters']['checked'] + default_param_checked: dict[str, list[str]] = DEFAULT_STATE_VALS['parameters']['default_checked'] + odds_types: list[str] = DEFAULT_STATE_VALS['parameters']['odds_types'] + param_grid: list[dict] = DEFAULT_STATE_VALS['parameters']['param_grid'] + + # Training parameters + odds_type: str = DEFAULT_STATE_VALS['training_parameters']['odds_type'] + drop_na_thres: list = DEFAULT_STATE_VALS['training_parameters']['drop_na_thres'] + + # Data + dataloader_serialized: str | None = None + X_train: list | None = DEFAULT_STATE_VALS['data']['X_train'] + Y_train: list | None = DEFAULT_STATE_VALS['data']['Y_train'] + O_train: list | None = DEFAULT_STATE_VALS['data']['O_train'] + X_train_cols: list | None = DEFAULT_STATE_VALS['data']['X_train_cols'] + Y_train_cols: list | None = DEFAULT_STATE_VALS['data']['Y_train_cols'] + O_train_cols: list | None = DEFAULT_STATE_VALS['data']['O_train_cols'] + X_fix: list | None = DEFAULT_STATE_VALS['data']['X_fix'] + O_fix: list | None = DEFAULT_STATE_VALS['data']['O_fix'] + X_fix_cols: list | None = DEFAULT_STATE_VALS['data']['X_fix_cols'] + O_fix_cols: list | None = DEFAULT_STATE_VALS['data']['O_fix_cols'] + + def set_mode_category(self: Self, mode_category: str) -> None: + """Set the mode category.""" + self.mode_category = mode_category + + def set_mode_type(self: Self, mode_type: str) -> None: + """Set the mode category.""" + self.mode_type = mode_type + + def set_sport_selection(self: Self, sport_selection: str) -> None: + """Set the sport.""" + self.sport_selection = sport_selection + + @rx.event + def download_dataloader(self: Self) -> EventSpec: + """Download the dataloader.""" + dataloader = bytes(self.dataloader_serialized, 'iso8859_16') + return rx.download(data=dataloader, filename='dataloader.pkl') + + @staticmethod + def process_cols(col: str) -> str: + """Proces a column.""" + return " ".join([" ".join(token.split('_')).title() for token in col.split('__')]) + + @staticmethod + def process_form_data(form_data: dict[str, str]) -> list[str]: + """Process the form data.""" + return [key.replace('"', '') for key in form_data] + + def update_param_checked(self: Self, name: str | int, checked: bool) -> None: + """Update the parameters.""" + if isinstance(name, str): + name = f'"{name}"' + self.param_checked[name] = checked + + def update_params(self: Self) -> None: + """Update the parameters grid.""" + self.params = [ + params + for params in self.all_params + if params['league'] in self.leagues + and params['year'] in self.years + and params['division'] in self.divisions + ] + + def handle_submit_leagues(self: Self, leagues_form_data: dict) -> None: + """Handle the form submit.""" + self.leagues = self.process_form_data(leagues_form_data) + self.update_params() + + def handle_submit_years(self: Self, years_form_data: dict) -> None: + """Handle the form submit.""" + self.years = [int(year) for year in self.process_form_data(years_form_data)] + self.update_params() + + def handle_submit_divisions(self: Self, divisions_form_data: dict) -> None: + """Handle the form submit.""" + self.divisions = [int(division) for division in self.process_form_data(divisions_form_data)] + self.update_params() + + def handle_odds_type(self, odds_type: str) -> None: + """Handle the odds type selection.""" + self.odds_type = odds_type + + def handle_drop_na_thres(self, drop_na_thres: list) -> None: + """Handle the drop NA threshold selection.""" + self.drop_na_thres = drop_na_thres + + async def submit_state(self: Self) -> None: + """Submit handler.""" + self.loading = True + yield + if self.visibility_level == 1: + self.loading = False + yield + elif self.visibility_level == 2: + self.all_params = DATALOADERS[self.sport_selection].get_all_params() + self.all_leagues = list(batched(sorted({params['league'] for params in self.all_params}), 6)) + self.all_years = list(batched(sorted({params['year'] for params in self.all_params}), 5)) + self.all_divisions = list(batched(sorted({params['division'] for params in self.all_params}), 1)) + self.leagues = [league.replace('"', '') for league in DEFAULT_PARAM_CHECKED['leagues']] + self.years = [int(year) for year in DEFAULT_PARAM_CHECKED['years']] + self.divisions = [int(division) for division in DEFAULT_PARAM_CHECKED['divisions']] + self.loading = False + yield + elif self.visibility_level == 3: + self.update_params() + self.param_grid = [{k: [v] for k, v in param.items()} for param in self.params] + self.odds_types = DATALOADERS[self.sport_selection](self.param_grid).get_odds_types() + self.loading = False + yield + elif self.visibility_level == 4: + dataloader = DATALOADERS[self.sport_selection](self.param_grid) + X_train, Y_train, O_train = dataloader.extract_train_data( + odds_type=self.odds_type, + drop_na_thres=self.drop_na_thres[0], + ) + X_fix, _, O_fix = dataloader.extract_fixtures_data() + self.X_train = X_train.reset_index().to_dict('records') + self.X_train_cols = [ag_grid.column_def(field='date', header_name='Date')] + [ + ag_grid.column_def(field=col, header_name=self.process_cols(col)) for col in X_train.columns + ] + self.Y_train = Y_train.to_dict('records') + self.Y_train_cols = [ + ag_grid.column_def(field=col, header_name=self.process_cols(col)) for col in Y_train.columns + ] + self.O_train = O_train.to_dict('records') if O_train is not None else None + self.O_train_cols = ( + [ag_grid.column_def(field=col, header_name=self.process_cols(col)) for col in O_train.columns] + if O_train is not None + else None + ) + self.X_fix = X_fix.reset_index().to_dict('records') + self.X_fix_cols = [ag_grid.column_def(field='date', header_name='Date')] + [ + ag_grid.column_def(field=col, header_name=self.process_cols(col)) for col in X_fix.columns + ] + self.O_fix = O_fix.to_dict('records') if O_fix is not None else None + self.O_fix_cols = ( + [ag_grid.column_def(field=col, header_name=self.process_cols(col)) for col in O_fix.columns] + if O_fix is not None + else None + ) + self.dataloader_serialized = str(cloudpickle.dumps(dataloader), 'iso8859_16') + self.loading = False + yield + self.visibility_level += 1 + + def reset_state(self: Self) -> None: + """Reset handler.""" + + # Elements visibility + self.visibility_level = 1 + self.loading: bool = False + + # Mode + self.mode_category = DEFAULT_STATE_VALS['mode']['category'] + self.mode_type = DEFAULT_STATE_VALS['mode']['type'] + + # Sport + self.sport_selection = DEFAULT_STATE_VALS['sport']['selection'] + self.all_params = DEFAULT_STATE_VALS['sport']['all_params'] + self.all_leagues = DEFAULT_STATE_VALS['sport']['all_leagues'] + self.all_years = DEFAULT_STATE_VALS['sport']['all_years'] + self.all_divisions = DEFAULT_STATE_VALS['sport']['all_divisions'] + self.leagues = DEFAULT_STATE_VALS['sport']['leagues'] + self.years = DEFAULT_STATE_VALS['sport']['years'] + self.divisions = DEFAULT_STATE_VALS['sport']['divisions'] + self.params = DEFAULT_STATE_VALS['sport']['params'] + + # Parameters + self.param_checked = DEFAULT_STATE_VALS['parameters']['checked'] + self.default_param_checked = DEFAULT_STATE_VALS['parameters']['default_checked'] + self.odds_types = DEFAULT_STATE_VALS['parameters']['odds_types'] + self.param_grid = DEFAULT_STATE_VALS['parameters']['param_grid'] + + # Training + self.odds_type = DEFAULT_STATE_VALS['training_parameters']['odds_type'] + self.drop_na_thres = DEFAULT_STATE_VALS['training_parameters']['drop_na_thres'] + + # Data + self.dataloader_serialized = None + self.X_train = DEFAULT_STATE_VALS['data']['X_train'] + self.Y_train = DEFAULT_STATE_VALS['data']['Y_train'] + self.O_train = DEFAULT_STATE_VALS['data']['O_train'] + self.X_train_cols = DEFAULT_STATE_VALS['data']['X_train_cols'] + self.Y_train_cols = DEFAULT_STATE_VALS['data']['Y_train_cols'] + self.O_train_cols = DEFAULT_STATE_VALS['data']['O_train_cols'] + self.X_fix = DEFAULT_STATE_VALS['data']['X_fix'] + self.O_fix = DEFAULT_STATE_VALS['data']['O_fix'] + self.X_fix_cols = DEFAULT_STATE_VALS['data']['X_fix_cols'] + self.O_fix_cols = DEFAULT_STATE_VALS['data']['O_fix_cols'] + + +@rx.page(route="/dataloader/creation") +def dataloader_creation() -> rx.Component: + """Main page.""" + return main(DataloaderCreationState) diff --git a/gui/gui/pages/dataloader/__init__.py b/gui/gui/pages/dataloader/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gui/gui/pages/dataloader/creation.py b/gui/gui/pages/dataloader/creation.py new file mode 100644 index 0000000..e06fa92 --- /dev/null +++ b/gui/gui/pages/dataloader/creation.py @@ -0,0 +1,302 @@ +"""Index page.""" + +from itertools import batched +from typing import Any, Self + +import cloudpickle +import nest_asyncio +import reflex as rx +from reflex.event import EventSpec +from reflex_ag_grid import ag_grid + +from sportsbet.datasets import SoccerDataLoader + +from ...components.dataloader.creation import main +from ..index import State + +DATALOADERS = { + 'Soccer': SoccerDataLoader, +} +DEFAULT_PARAM_CHECKED = { + 'leagues': [ + '"England"', + '"Scotland"', + '"Germany"', + '"Italy"', + '"Spain"', + '"France"', + '"Netherlands"', + '"Belgium"', + '"Portugal"', + '"Turkey"', + '"Greece"', + ], + 'years': [ + '2020', + '2021', + '2022', + '2023', + '2024', + '2025', + ], + 'divisions': ['1', '2'], +} +DEFAULT_STATE_VALS = { + 'mode': { + 'category': 'Data', + 'type': 'Create', + }, + 'sport': { + 'selection': 'Soccer', + 'all_params': [], + 'all_leagues': [], + 'all_years': [], + 'all_divisions': [], + 'leagues': [], + 'years': [], + 'divisions': [], + 'params': [], + }, + 'parameters': { + 'checked': {}, + 'default_checked': DEFAULT_PARAM_CHECKED, + 'odds_types': [], + 'param_grid': [], + }, + 'training_parameters': { + 'odds_type': 'market_average', + 'drop_na_thres': [0.0], + }, + 'data': { + 'X_train': None, + 'Y_train': None, + 'O_train': None, + 'X_train_cols': None, + 'Y_train_cols': None, + 'O_train_cols': None, + 'X_fix': None, + 'O_fix': None, + 'X_fix_cols': None, + 'O_fix_cols': None, + }, +} + +nest_asyncio.apply() + + +class DataloaderCreationState(State): + """The toolbox state.""" + + # Sport + sport_selection: str = DEFAULT_STATE_VALS['sport']['selection'] + all_params: list[dict[str, Any]] = DEFAULT_STATE_VALS['sport']['all_params'] + all_leagues: list[list[str]] = DEFAULT_STATE_VALS['sport']['all_leagues'] + all_years: list[list[str]] = DEFAULT_STATE_VALS['sport']['all_years'] + all_divisions: list[list[str]] = DEFAULT_STATE_VALS['sport']['all_divisions'] + leagues: list[str] = DEFAULT_STATE_VALS['sport']['leagues'] + years: list[str] = DEFAULT_STATE_VALS['sport']['years'] + divisions: list[str] = DEFAULT_STATE_VALS['sport']['divisions'] + params: list[dict[str, Any]] = DEFAULT_STATE_VALS['sport']['params'] + + # Parameters + param_checked: dict[str, bool] = DEFAULT_STATE_VALS['parameters']['checked'] + default_param_checked: dict[str, list[str]] = DEFAULT_STATE_VALS['parameters']['default_checked'] + odds_types: list[str] = DEFAULT_STATE_VALS['parameters']['odds_types'] + param_grid: list[dict] = DEFAULT_STATE_VALS['parameters']['param_grid'] + + # Training parameters + odds_type: str = DEFAULT_STATE_VALS['training_parameters']['odds_type'] + drop_na_thres: list = DEFAULT_STATE_VALS['training_parameters']['drop_na_thres'] + + # Data + dataloader_serialized: str | None = None + X_train: list | None = DEFAULT_STATE_VALS['data']['X_train'] + Y_train: list | None = DEFAULT_STATE_VALS['data']['Y_train'] + O_train: list | None = DEFAULT_STATE_VALS['data']['O_train'] + X_train_cols: list | None = DEFAULT_STATE_VALS['data']['X_train_cols'] + Y_train_cols: list | None = DEFAULT_STATE_VALS['data']['Y_train_cols'] + O_train_cols: list | None = DEFAULT_STATE_VALS['data']['O_train_cols'] + X_fix: list | None = DEFAULT_STATE_VALS['data']['X_fix'] + O_fix: list | None = DEFAULT_STATE_VALS['data']['O_fix'] + X_fix_cols: list | None = DEFAULT_STATE_VALS['data']['X_fix_cols'] + O_fix_cols: list | None = DEFAULT_STATE_VALS['data']['O_fix_cols'] + + def set_mode_category(self: Self, mode_category: str) -> None: + """Set the mode category.""" + self.mode_category = mode_category + + def set_mode_type(self: Self, mode_type: str) -> None: + """Set the mode category.""" + self.mode_type = mode_type + + def set_sport_selection(self: Self, sport_selection: str) -> None: + """Set the sport.""" + self.sport_selection = sport_selection + + @rx.event + def download_dataloader(self: Self) -> EventSpec: + """Download the dataloader.""" + dataloader = bytes(self.dataloader_serialized, 'iso8859_16') + return rx.download(data=dataloader, filename='dataloader.pkl') + + @staticmethod + def process_cols(col: str) -> str: + """Proces a column.""" + return " ".join([" ".join(token.split('_')).title() for token in col.split('__')]) + + @staticmethod + def process_form_data(form_data: dict[str, str]) -> list[str]: + """Process the form data.""" + return [key.replace('"', '') for key in form_data] + + def update_param_checked(self: Self, name: str | int, checked: bool) -> None: + """Update the parameters.""" + if isinstance(name, str): + name = f'"{name}"' + self.param_checked[name] = checked + + def update_params(self: Self) -> None: + """Update the parameters grid.""" + self.params = [ + params + for params in self.all_params + if params['league'] in self.leagues + and params['year'] in self.years + and params['division'] in self.divisions + ] + + def handle_submit_leagues(self: Self, leagues_form_data: dict) -> None: + """Handle the form submit.""" + self.leagues = self.process_form_data(leagues_form_data) + self.update_params() + + def handle_submit_years(self: Self, years_form_data: dict) -> None: + """Handle the form submit.""" + self.years = [int(year) for year in self.process_form_data(years_form_data)] + self.update_params() + + def handle_submit_divisions(self: Self, divisions_form_data: dict) -> None: + """Handle the form submit.""" + self.divisions = [int(division) for division in self.process_form_data(divisions_form_data)] + self.update_params() + + def handle_odds_type(self, odds_type: str) -> None: + """Handle the odds type selection.""" + self.odds_type = odds_type + + def handle_drop_na_thres(self, drop_na_thres: list) -> None: + """Handle the drop NA threshold selection.""" + self.drop_na_thres = drop_na_thres + + def submit_state(self: Self) -> None: + """Submit handler.""" + + self.loading = True + yield + if self.visibility_level == 1: + self.loading = False + yield + elif self.visibility_level == 2: + self.all_params = DATALOADERS[self.sport_selection].get_all_params() + self.all_leagues = list(batched(sorted({params['league'] for params in self.all_params}), 6)) + self.all_years = list(batched(sorted({params['year'] for params in self.all_params}), 5)) + self.all_divisions = list(batched(sorted({params['division'] for params in self.all_params}), 1)) + self.leagues = [league.replace('"', '') for league in DEFAULT_PARAM_CHECKED['leagues']] + self.years = [int(year) for year in DEFAULT_PARAM_CHECKED['years']] + self.divisions = [int(division) for division in DEFAULT_PARAM_CHECKED['divisions']] + self.loading = False + yield + elif self.visibility_level == 3: + self.update_params() + self.param_grid = [{k: [v] for k, v in param.items()} for param in self.params] + self.odds_types = DATALOADERS[self.sport_selection](self.param_grid).get_odds_types() + self.loading = False + yield + elif self.visibility_level == 4: + dataloader = DATALOADERS[self.sport_selection](self.param_grid) + X_train, Y_train, O_train = dataloader.extract_train_data( + odds_type=self.odds_type, + drop_na_thres=self.drop_na_thres[0], + ) + X_fix, _, O_fix = dataloader.extract_fixtures_data() + self.X_train = X_train.reset_index().to_dict('records') + self.X_train_cols = [ag_grid.column_def(field='date', header_name='Date')] + [ + ag_grid.column_def(field=col, header_name=self.process_cols(col)) for col in X_train.columns + ] + self.Y_train = Y_train.to_dict('records') + self.Y_train_cols = [ + ag_grid.column_def(field=col, header_name=self.process_cols(col)) for col in Y_train.columns + ] + self.O_train = O_train.to_dict('records') if O_train is not None else None + self.O_train_cols = ( + [ag_grid.column_def(field=col, header_name=self.process_cols(col)) for col in O_train.columns] + if O_train is not None + else None + ) + self.X_fix = X_fix.reset_index().to_dict('records') + self.X_fix_cols = [ag_grid.column_def(field='date', header_name='Date')] + [ + ag_grid.column_def(field=col, header_name=self.process_cols(col)) for col in X_fix.columns + ] + self.O_fix = O_fix.to_dict('records') if O_fix is not None else None + self.O_fix_cols = ( + [ag_grid.column_def(field=col, header_name=self.process_cols(col)) for col in O_fix.columns] + if O_fix is not None + else None + ) + self.dataloader_serialized = str(cloudpickle.dumps(dataloader), 'iso8859_16') + self.loading = False + yield + self.visibility_level += 1 + + def reset_state(self: Self) -> None: + """Reset handler.""" + + # Elements visibility + self.visibility_level = 1 + self.loading: bool = False + + # Mode + self.mode_category = 'Data' + self.mode_type = 'Create' + + # Data + self.dataloader_serialized = None + + # Sport + self.sport_selection = DEFAULT_STATE_VALS['sport']['selection'] + self.all_params = DEFAULT_STATE_VALS['sport']['all_params'] + self.all_leagues = DEFAULT_STATE_VALS['sport']['all_leagues'] + self.all_years = DEFAULT_STATE_VALS['sport']['all_years'] + self.all_divisions = DEFAULT_STATE_VALS['sport']['all_divisions'] + self.leagues = DEFAULT_STATE_VALS['sport']['leagues'] + self.years = DEFAULT_STATE_VALS['sport']['years'] + self.divisions = DEFAULT_STATE_VALS['sport']['divisions'] + self.params = DEFAULT_STATE_VALS['sport']['params'] + + # Parameters + self.param_checked = DEFAULT_STATE_VALS['parameters']['checked'] + self.default_param_checked = DEFAULT_STATE_VALS['parameters']['default_checked'] + self.odds_types = DEFAULT_STATE_VALS['parameters']['odds_types'] + self.param_grid = DEFAULT_STATE_VALS['parameters']['param_grid'] + + # Training + self.odds_type = DEFAULT_STATE_VALS['training_parameters']['odds_type'] + self.drop_na_thres = DEFAULT_STATE_VALS['training_parameters']['drop_na_thres'] + + # Data + self.X_train = DEFAULT_STATE_VALS['data']['X_train'] + self.Y_train = DEFAULT_STATE_VALS['data']['Y_train'] + self.O_train = DEFAULT_STATE_VALS['data']['O_train'] + self.X_train_cols = DEFAULT_STATE_VALS['data']['X_train_cols'] + self.Y_train_cols = DEFAULT_STATE_VALS['data']['Y_train_cols'] + self.O_train_cols = DEFAULT_STATE_VALS['data']['O_train_cols'] + self.X_fix = DEFAULT_STATE_VALS['data']['X_fix'] + self.O_fix = DEFAULT_STATE_VALS['data']['O_fix'] + self.X_fix_cols = DEFAULT_STATE_VALS['data']['X_fix_cols'] + self.O_fix_cols = DEFAULT_STATE_VALS['data']['O_fix_cols'] + + +@rx.page(route="/dataloader/creation") +def dataloader_creation() -> rx.Component: + """Main page.""" + return main(DataloaderCreationState) diff --git a/gui/gui/pages/dataloader/loading.py b/gui/gui/pages/dataloader/loading.py new file mode 100644 index 0000000..fdf163e --- /dev/null +++ b/gui/gui/pages/dataloader/loading.py @@ -0,0 +1,103 @@ +"""Index page.""" + +from itertools import batched +from pathlib import Path +from typing import Self + +import cloudpickle +import nest_asyncio +import reflex as rx + +from sportsbet.datasets import SoccerDataLoader + +from ...components.dataloader.loading import dialog, main +from ..index import State + +DATALOADERS = { + 'Soccer': SoccerDataLoader, +} +DEFAULT_STATE_VALS = { + 'mode': { + 'category': 'Data', + 'type': 'Create', + }, +} + +nest_asyncio.apply() + + +class DataloaderLoadingState(State): + """The toolbox state.""" + + # Data + dataloader_serialized: str | None = None + dataloader_filename: str | None = None + all_leagues: list[list[str]] = [] + all_years: list[list[str]] = [] + all_divisions: list[list[str]] = [] + param_checked: dict[str, bool] = {} + odds_type: str | None = None + drop_na_thres: float | None = None + + @rx.event + async def handle_upload(self: Self, files: list[rx.UploadFile]) -> None: + """Handle the upload of files.""" + self.loading = True + yield + for file in files: + dataloader = await file.read() + self.dataloader_serialized = str(dataloader, 'iso8859_16') + self.dataloader_filename = Path(file.filename).name + self.loading = False + yield + + def submit_state(self: Self) -> None: + """Submit handler.""" + self.loading = True + yield + if self.visibility_level == 1: + self.loading = False + yield + elif self.visibility_level == 2: + dataloader = cloudpickle.loads(bytes(self.dataloader_serialized, 'iso8859_16')) + all_params = dataloader.get_all_params() + self.all_leagues = list(batched(sorted({params['league'] for params in all_params}), 6)) + self.all_years = list(batched(sorted({params['year'] for params in all_params}), 5)) + self.all_divisions = list(batched(sorted({params['division'] for params in all_params}), 1)) + self.loading = False + self.param_checked = { + **{f'"{key}"': True for key in {params['league'] for params in dataloader.param_grid_}}, + **{key: True for key in {params['year'] for params in dataloader.param_grid_}}, + **{key: True for key in {params['division'] for params in dataloader.param_grid_}}, + } + self.odds_type = dataloader.odds_type_ + self.drop_na_thres = dataloader.drop_na_thres_ + yield + self.visibility_level += 1 + + def reset_state(self: Self) -> None: + """Reset handler.""" + + # Elements visibility + self.visibility_level = 1 + self.loading: bool = False + + # Mode + self.mode_category = 'Data' + self.mode_type = 'Create' + + # Data + self.dataloader_serialized = None + self.dataloader_filename = None + self.all_leagues = [] + self.all_years = [] + self.all_divisions = [] + self.param_checked = {} + self.odds_type = None + self.drop_na_thres = None + + +@rx.page(route="/dataloader/loading") +def dataloader_loading() -> rx.Component: + """Main page.""" + return main(DataloaderLoadingState) diff --git a/gui/gui/pages/index.py b/gui/gui/pages/index.py index 9a4b905..23fd55b 100644 --- a/gui/gui/pages/index.py +++ b/gui/gui/pages/index.py @@ -1,268 +1,160 @@ """Index page.""" -from itertools import batched -from typing import Any, Self +from typing import Self -import nest_asyncio import reflex as rx -from sportsbet.datasets import SoccerDataLoader - -from ..components.common import header, home, selection, title -from ..components.parameters import parameters_selection -from ..components.training_parameters import training_parameters_selection - -SIDEBAR_OPTIONS = { - 'spacing': '5', - 'position': 'fixed', - 'left': '50px', - 'top': '200px', - 'padding_x': '1em', - 'padding_y': "1.5em", - 'bg': rx.color('blue', 3), - 'height': '750px', - 'width': '35em', -} -DATALOADERS = { - 'Soccer': SoccerDataLoader, -} -DEFAULT_PARAM_CHECKED = { - '"England"': True, - '"Scotland"': True, - '"Germany"': True, - '"Italy"': True, - '"Spain"': True, - '"France"': True, - '"Netherlands"': True, - '"Belgium"': True, - '"Portugal"': True, - '"Turkey"': True, - '"Greece"': True, - 2018: True, - 2019: True, - 2020: True, - 2021: True, - 2022: True, - 2023: True, - 2024: True, - 2025: True, - 1: True, - 2: True, -} - -nest_asyncio.apply() +from ..components.common import SIDEBAR_OPTIONS, home, title class State(rx.State): """The toolbox state.""" - # Task - task: str | None = None - task_disabled: bool = False - - # Sport - sport: str | None = None - sport_disabled: bool = False - all_params: list[dict[str, Any]] = [] - all_leagues: list[list[str]] = [] - all_years: list[list[str]] = [] - all_divisions: list[list[str]] = [] - leagues: list[str] = [] - years: list[str] = [] - divisions: list[str] = [] - params: list[dict[str, Any]] = [] - - # Parameters - parameters_disabled: bool = False - parameters_loading: bool = False - param_checked: dict[str, bool] = {} - default_param_checked: dict[str, bool] = DEFAULT_PARAM_CHECKED - odds_types: list[str] = [] - param_grid: list[dict] = [] - - # Training - training_disabled = False - training_loading: bool = False - odds_type: str | None = None - drop_na_thres: list | None = [0.0] - - @staticmethod - def process_form_data(form_data: dict[str, str]) -> list[str]: - """Process the form data.""" - return [key.replace('"', '') for key in form_data] - - def update_param_checked(self: Self, name: str | int, checked: bool) -> None: - """Update the parameters.""" - if isinstance(name, str): - name = f'"{name}"' - self.param_checked[name] = checked - - def update_params(self: Self) -> None: - """Update the parameters grid.""" - self.params = [ - params - for params in self.all_params - if params['league'] in self.leagues - and params['year'] in self.years - and params['division'] in self.divisions - ] - - def handle_submit_leagues(self: Self, leagues_form_data: dict) -> None: - """Handle the form submit.""" - self.leagues = self.process_form_data(leagues_form_data) - self.update_params() - - def handle_submit_years(self: Self, years_form_data: dict) -> None: - """Handle the form submit.""" - self.years = [int(year) for year in self.process_form_data(years_form_data)] - self.update_params() + # Elements + visibility_level: int = 1 + loading: bool = False - def handle_submit_divisions(self: Self, divisions_form_data: dict) -> None: - """Handle the form submit.""" - self.divisions = [int(division) for division in self.process_form_data(divisions_form_data)] - self.update_params() + # Mode + mode_category: str = 'Data' + mode_type: str = 'Create' - def handle_odds_type(self, odds_type: str) -> None: - """Handle the odds type selection.""" - self.odds_type = odds_type - - def handle_drop_na_thres(self, drop_na_thres: list) -> None: - """Handle the drop NA threshold selection.""" - self.drop_na_thres = drop_na_thres + def submit_state(self: Self) -> None: + """Submit handler.""" + self.loading = True + yield + if self.visibility_level == 1: + self.loading = False + yield + self.visibility_level += 1 def reset_state(self: Self) -> None: - """Reset the dataloader state.""" - - # Task - self.task = None - self.task_disabled = False - - # Sport - self.sport = None - self.sport_disabled = False - self.all_params = [] - self.all_leagues = [] - self.all_years = [] - self.all_divisions = [] - - # Parameters - self.parameters_disabled = False - self.parameters_loading = False - self.leagues = [] - self.years = [] - self.divisions = [] - self.params = [] - self.param_checked = {} - self.odds_types = [] - self.param_grid = [] - - # Training - self.training_disabled = False - self.training_loading = False - self.odds_type = None - self.drop_na_thres = None + """Reset handler.""" - def set_task(self: Self, task: str) -> None: - """Set the task.""" - self.task = task - self.task_disabled = True + # Elements visibility + self.visibility_level = 1 + self.loading: bool = False - def set_sport(self: Self, sport: str) -> None: - """Set the sport.""" - self.sport = sport - self.sport_disabled = True - yield - self.all_params = DATALOADERS[self.sport].get_all_params() - self.all_leagues = list(batched(sorted({params['league'] for params in self.all_params}), 6)) - self.all_years = list(batched(sorted({params['year'] for params in self.all_params}), 5)) - self.all_divisions = list(batched(sorted({params['division'] for params in self.all_params}), 1)) - self.leagues = [league for row in self.all_leagues for league in row] - self.years = [year for row in self.all_years for year in row] - self.divisions = [division for row in self.all_divisions for division in row] - - def set_parameters(self: Self) -> None: - """Set the parameters.""" - self.parameters_disabled = True - self.parameters_loading = True - yield - self.update_params() - self.param_grid = [{k: [v] for k, v in param.items()} for param in self.params] - self.odds_types = DATALOADERS[self.sport](self.param_grid).get_odds_types() - - def set_training_parameters(self: Self) -> None: - """Set the training parameters.""" - self.training_disabled = True - self.training_loading = True - yield - X_train, Y_train, O_train = DATALOADERS[self.sport](self.param_grid).extract_train_data( - odds_type=self.odds_type, drop_na_thres=self.drop_na_thres[0] - ) - self.training_loading = False + # Mode + self.mode_category = 'Data' + self.mode_type = 'Create' @rx.page(route="/") def index() -> rx.Component: """Main page.""" - return rx.box( - header(), - rx.box( + return rx.container( + rx.vstack( + home(), + rx.divider(), + # Mode selection + title('Mode', 'blend'), rx.vstack( - home(State.reset_state), - rx.divider(), - # Task selection - title('Task', 'arrow-up-down'), - rx.text('Data or modelling task selection', size='1', hidden=State.task_disabled), - selection(['Data', 'Modelling'], State.task, State.task_disabled, State.set_task), - # Sport selection rx.cond( - State.task == 'Data', - title('Sport', 'medal'), + (State.mode_category == 'Data') & (State.mode_type == 'Create'), + rx.text('Create a dataloader', size='1'), ), rx.cond( - State.task == 'Data', - rx.text('Sport selection', size='1', hidden=State.sport_disabled), + (State.mode_category == 'Data') & (State.mode_type == 'Load'), + rx.text('Load a dataloader', size='1'), ), rx.cond( - State.task == 'Data', - selection(['Soccer'], State.sport, State.sport_disabled, State.set_sport), + (State.mode_category == 'Modelling') & (State.mode_type == 'Create'), + rx.text('Create a model', size='1'), ), - # Parameters selection rx.cond( - State.sport == 'Soccer', - title('Parameters', 'proportions'), + (State.mode_category == 'Modelling') & (State.mode_type == 'Load'), + rx.text('Load a model', size='1'), ), - rx.cond( - State.sport == 'Soccer', - parameters_selection(State), + rx.hstack( + rx.select( + items=['Data', 'Modelling'], + value=State.mode_category, + disabled=State.visibility_level > 1, + width='120px', + on_change=State.set_mode_category, + ), + rx.select( + ['Create', 'Load'], + value=State.mode_type, + disabled=State.visibility_level > 1, + width='120px', + on_change=State.set_mode_type, + ), ), - rx.cond( - State.sport == 'Soccer', + ), + # Control + rx.vstack( + rx.divider(top='600px', position='fixed', width='18em'), + rx.hstack( rx.cond( - State.odds_types.to_string() == '[]', - rx.button( - 'Submit', - on_click=State.set_parameters, - loading=State.parameters_loading, - disabled=State.parameters_disabled, + (State.mode_category == 'Data') & (State.mode_type == 'Create'), + rx.link( + rx.button( + 'Submit', + on_click=State.submit_state, + loading=State.loading, + position='fixed', + top='620px', + width='70px', + ), + href='/dataloader/creation', ), ), - ), - # Training parameters selection - rx.cond( - State.odds_types, - training_parameters_selection(State), - ), - rx.cond( - State.odds_types, - rx.button( - 'Submit', - on_click=State.set_training_parameters, - loading=State.training_loading, - disabled=State.training_disabled, + rx.cond( + (State.mode_category == 'Data') & (State.mode_type == 'Load'), + rx.link( + rx.button( + 'Submit', + on_click=State.submit_state, + loading=State.loading, + position='fixed', + top='620px', + width='70px', + ), + href='/dataloader/loading', + ), + ), + rx.cond( + (State.mode_category == 'Modelling') & (State.mode_type == 'Create'), + rx.link( + rx.button( + 'Submit', + on_click=State.submit_state, + loading=State.loading, + position='fixed', + top='620px', + width='70px', + ), + href='/model/creation', + ), + ), + rx.cond( + (State.mode_category == 'Modelling') & (State.mode_type == 'Load'), + rx.link( + rx.button( + 'Submit', + on_click=State.submit_state, + loading=State.loading, + position='fixed', + top='620px', + width='70px', + ), + href='/model/loading', + ), + ), + rx.link( + rx.button( + 'Reset', + on_click=State.reset_state, + position='fixed', + top='620px', + left='150px', + width='70px', + ), + href='/', ), ), - # Options - **SIDEBAR_OPTIONS, ), + **SIDEBAR_OPTIONS, ), ) diff --git a/gui/rxconfig.py b/gui/rxconfig.py index 800d562..7718623 100644 --- a/gui/rxconfig.py +++ b/gui/rxconfig.py @@ -1,3 +1,5 @@ +"""Configuration of the application.""" + import reflex as rx config = rx.Config(