Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[0.2.dev7] MergedChoiceTable check for duplicate column names #54

Merged
merged 2 commits into from
Dec 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion choicemodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

from .mnl import MultinomialLogit, MultinomialLogitResults

version = __version__ = '0.2.dev6'
version = __version__ = '0.2.dev7'
23 changes: 19 additions & 4 deletions choicemodels/tools/mergedchoicetable.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ class MergedChoiceTable(object):
observations : pandas.DataFrame
Table with one row for each chooser or choice scenario, with unique ID's in the
index field. Additional columns can contain fixed attributes of the choosers.
Index name is set to 'obs_id' if none provided.
Index name is set to 'obs_id' if none provided. All observation/alternative
column names must be unique except for the join key.

alternatives : pandas.DataFrame
Table with one row for each alternative, with unique ID's in the index field.
Additional columns can contain fixed attributes of the alternatives. Index name
is set to 'alt_id' if none provided.
is set to 'alt_id' if none provided. All observation/alternative column names
must be unique except for the join key.

chosen_alternatives : str or pandas.Series, optional
List of the alternative ID selected in each choice scenario. (This is required for
Expand Down Expand Up @@ -121,12 +123,25 @@ def __init__(self, observations, alternatives, chosen_alternatives=None,
# TO DO - check that dfs have unique indexes
# TO DO - check that chosen_alternatives correspond correctly to other dfs
# TO DO - same with weights (could join onto other tables and then split off)
# TO DO - check for overlapping column names

# Normalize chosen_alternatives to a pd.Series
if (chosen_alternatives is not None) & isinstance(chosen_alternatives, str):
chosen_alternatives = observations[chosen_alternatives]
chosen_alternatives = observations[chosen_alternatives].copy()
observations = observations.drop(chosen_alternatives.name, axis='columns')
chosen_alternatives.name = '_' + alternatives.index.name # avoids conflicts

# Check for duplicate column names
obs_cols = list(observations.columns) + list(observations.index.names)
alt_cols = list(alternatives.columns) + list(alternatives.index.names)
dupes = [c for c in obs_cols if c in alt_cols]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this using lists (O(a*o)) instead of sets (O(min(a,o)))? like:

        obs_cols = set(observations.columns) + set(observations.index.names)
        alt_cols = set(alternatives.columns) + set(alternatives.index.names)
        dupes = obs_cols & in alt_cols

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Eh2406 Thanks, that's definitely better! Unfortunately i just merged this PR, but i'll update this in the next one

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


if len(dupes) == 1:
raise ValueError("Column '{}' appears in both input tables. Please ensure "
"column names are unique before merging".format(dupes[0]))
elif len(dupes) > 1:
raise ValueError("Columns '{}' appear in both input tables. Please ensure "
"column names are unique before merging"\
.format("', '".join(dupes)))

# Normalize weights to a pd.Series
if (weights is not None) & isinstance(weights, str):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

setup(
name='choicemodels',
version='0.2.dev6',
version='0.2.dev7',
description='Tools for discrete choice estimation',
long_description=long_description,
author='UDST',
Expand Down
69 changes: 58 additions & 11 deletions tests/test_mct.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,25 @@
import choicemodels
from choicemodels.tools import MergedChoiceTable

d1 = {'oid': [0,1],
'obsval': [6,8],
'choice': [1,2]}

d2 = {'aid': [0,1,2],
'altval': [10,20,30],
'w': [1,1,100]}
@pytest.fixture
def obs():
d1 = {'oid': [0,1],
'obsval': [6,8],
'choice': [1,2]}

return pd.DataFrame(d1).set_index('oid')

@pytest.fixture
def alts():
d2 = {'aid': [0,1,2],
'altval': [10,20,30],
'w': [1,1,100]}

obs = pd.DataFrame(d1).set_index('oid')
alts = pd.DataFrame(d2).set_index('aid')
return pd.DataFrame(d2).set_index('aid')


def test_mergedchoicetable():
def test_mergedchoicetable(obs, alts):
# NO SAMPLING, TABLE FOR SIMULATION

mct = choicemodels.tools.MergedChoiceTable(obs, alts).to_frame()
Expand Down Expand Up @@ -163,7 +169,7 @@ def test_mergedchoicetable():
chosen_alternatives = 'choice').to_frame()


def test_no_alternatives():
def test_no_alternatives(obs, alts):
"""
Empty alternatives should produce empty choice table.

Expand All @@ -172,12 +178,53 @@ def test_no_alternatives():
assert len(mct) == 0


def test_no_choosers():
def test_no_choosers(obs, alts):
"""
Empty observations should produce empty choice table.

"""
mct = MergedChoiceTable(pd.DataFrame(), alts).to_frame()
assert len(mct) == 0


def test_dupe_column(obs, alts):
"""
Duplicate column names should raise an error.

"""
obs['save_the_whales'] = None
alts['save_the_whales'] = None

try:
MergedChoiceTable(obs, alts)
except ValueError as e:
print(e)


def test_multiple_dupe_columns(obs, alts):
"""
Duplicate column names should raise an error. This covers the case of multiple
columns, and the case of an index conflicting with a non-index.

"""
obs['save_the_whales'] = None
alts['save_the_whales'] = None
alts[obs.index.name] = None

try:
MergedChoiceTable(obs, alts)
except ValueError as e:
print(e)


def test_join_key_name_conflict(obs, alts):
"""
Duplicate column names are not allowed, except for the join key -- it's fine for the
chosen_alternatives column in the observations to have the same name as the index of
the alternatives. This test should run without raising an error.

"""
obs[alts.index.name] = obs.choice
MergedChoiceTable(obs, alts, chosen_alternatives=alts.index.name)