Skip to content

Commit

Permalink
test plotting and catch quilt error (#79)
Browse files Browse the repository at this point in the history
* add plotting tests

* catch error when quilt data have been updated on the server

* typo in requirements_tests

* include geosnap image
  • Loading branch information
knaaptime authored Apr 29, 2019
1 parent 1cc562a commit dd7c9f5
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 36 deletions.
Binary file added doc/geosnap.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
82 changes: 46 additions & 36 deletions geosnap/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
import quilt

import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.insert(0,
os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from util import adjust_inflation, convert_gdf


try:
from quilt.data.spatialucr import census
except ImportError:
Expand Down Expand Up @@ -46,13 +46,20 @@ def __dir__(self):

_package_directory = os.path.dirname(os.path.abspath(__file__))
_cbsa = pd.read_parquet(os.path.join(_package_directory, 'cbsas.parquet'))

dictionary = pd.read_csv(os.path.join(_package_directory, "variables.csv"))

states = census.states()
counties = census.counties()
tracts = census.tracts_2010
metros = convert_gdf(census.msas())
try: # if any of these aren't found, the user needs to refresh the quilt data package
states = census.states()
counties = census.counties()
tracts = census.tracts_2010
metros = convert_gdf(census.msas())
except AttributeError:
warn(
'Quilt data is outdated... rebuilding\n'
' You will need to restart your Python kernel once downloading has completed'
)
quilt.install("spatialucr/census", force=True)
quilt.install("spatialucr/census_cartographic", force=True)


def _db_checker(database):
Expand All @@ -69,12 +76,11 @@ def _db_checker(database):


#: A dict containing tabular data available to geosnap
db = Bunch(census_90=census.variables_1990(),
census_00=census.variables_2000(),
ltdb=_db_checker('ltdb'),
ncdb=_db_checker('ncdb')
)

db = Bunch(
census_90=census.variables_1990(),
census_00=census.variables_2000(),
ltdb=_db_checker('ltdb'),
ncdb=_db_checker('ncdb'))

# LTDB importer

Expand Down Expand Up @@ -135,14 +141,16 @@ def _ltdb_reader(path, file, year, dropcols=None):

df["year"] = year

inflate_cols = ["mhmval", "mrent", "incpc",
"hinc", "hincw", "hincb", "hinch", "hinca"]
inflate_cols = [
"mhmval", "mrent", "incpc", "hinc", "hincw", "hincb", "hinch",
"hinca"
]

inflate_available = list(set(df.columns).intersection(set(
inflate_cols)))
inflate_available = list(
set(df.columns).intersection(set(inflate_cols)))

if len(inflate_available):
# try:
# try:
df = adjust_inflation(df, inflate_available, year)
# except KeyError: # half the dfs don't have these variables
# pass
Expand Down Expand Up @@ -206,8 +214,8 @@ def _ltdb_reader(path, file, year, dropcols=None):
fullcount00.iloc[:, 7:], how="left")
ltdb_2010 = sample10

df = pd.concat(
[ltdb_1970, ltdb_1980, ltdb_1990, ltdb_2000, ltdb_2010], sort=True)
df = pd.concat([ltdb_1970, ltdb_1980, ltdb_1990, ltdb_2000, ltdb_2010],
sort=True)

renamer = dict(
zip(dictionary['ltdb'].tolist(), dictionary['variable'].tolist()))
Expand All @@ -218,7 +226,8 @@ def _ltdb_reader(path, file, year, dropcols=None):
for row in dictionary['formula'].dropna().tolist():
df.eval(row, inplace=True)

keeps = df.columns[df.columns.isin(dictionary['variable'].tolist() + ['year'])]
keeps = df.columns[df.columns.isin(dictionary['variable'].tolist() +
['year'])]
df = df[keeps]

data_store._set(['ltdb'], df)
Expand Down Expand Up @@ -325,7 +334,8 @@ def read_ncdb(filepath):

df = df.round(0)

keeps = df.columns[df.columns.isin(dictionary['variable'].tolist() + ['year'])]
keeps = df.columns[df.columns.isin(dictionary['variable'].tolist() +
['year'])]

df = df[keeps]

Expand Down Expand Up @@ -416,12 +426,13 @@ def __init__(self,
self.boundary = boundary
if boundary.crs != self.tracts.crs:
if not boundary.crs:
raise('Boundary must have a CRS to ensure valid spatial \
raise ('Boundary must have a CRS to ensure valid spatial \
selection')
self.tracts = self.tracts.to_crs(boundary.crs)

self.tracts = self.tracts[self.tracts.representative_point()
.within(self.boundary.unary_union)]
self.tracts = self.tracts[
self.tracts.representative_point().within(
self.boundary.unary_union)]
self.counties = convert_gdf(self.counties[counties.geoid.isin(
self.tracts.geoid.str[0:5])])
self.states = convert_gdf(self.states[states.geoid.isin(
Expand Down Expand Up @@ -471,7 +482,9 @@ def __init__(self,
if source in ['ltdb', 'ncdb']:
_df = _db_checker(source)
if len(_df) == 0:
raise ValueError("Unable to locate {source} data. Please import the database with the `read_{source}` function".format(source=source))
raise ValueError(
"Unable to locate {source} data. Please import the database with the `read_{source}` function"
.format(source=source))
elif source == "external":
_df = data
else:
Expand All @@ -481,8 +494,8 @@ def __init__(self,
if cbsafips:
if not add_indices:
add_indices = []
add_indices += _cbsa[_cbsa['CBSA Code'] == cbsafips][
'stcofips'].tolist()
add_indices += _cbsa[_cbsa['CBSA Code'] ==
cbsafips]['stcofips'].tolist()
if add_indices:
for index in add_indices:

Expand All @@ -492,8 +505,8 @@ def __init__(self,
convert_gdf(counties[counties.geoid.str.startswith(
index[0:5])]))
self.tracts = self.tracts[~self.tracts.geoid.duplicated(keep='first')]
self.counties = self.counties[
~self.counties.geoid.duplicated(keep='first')]
self.counties = self.counties[~self.counties.geoid.duplicated(
keep='first')]
self.census = _df[_df.index.isin(self.tracts.geoid)]

def plot(self,
Expand Down Expand Up @@ -545,8 +558,8 @@ def plot(self,
else:
if self.name:
plt.title(
self.name + " " + str(year) + '\n' +
colname, fontsize=20)
self.name + " " + str(year) + '\n' + colname,
fontsize=20)
else:
plt.title(colname + " " + str(year), fontsize=20)
plt.axis("off")
Expand All @@ -561,10 +574,7 @@ def plot(self,

if plot_counties is True:
self.counties.plot(
edgecolor="#5c5353",
linewidth=0.8,
facecolor="none",
ax=ax)
edgecolor="#5c5353", linewidth=0.8, facecolor="none", ax=ax)

return ax

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
15 changes: 15 additions & 0 deletions geosnap/tests/test_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from context import analyze, data
import matplotlib
import pytest
#matplotlib.use('agg')

from matplotlib.testing.decorators import image_comparison
import matplotlib.pyplot as plt

@pytest.mark.mpl_image_compare
def test_plot():
    """Plot k-means cluster labels for one community and return the figure.

    The figure is returned instead of being asserted on: the
    ``mpl_image_compare`` marker (pytest-mpl) consumes the returned figure
    for its image comparison.
    """
    cluster_columns = [
        'median_household_income',
        'p_poverty_rate',
        'p_edu_college_greater',
        'p_unemployment_rate',
    ]

    fig, ax = plt.subplots()
    community = data.Community(name='sd', source='ltdb', cbsafips='41740')
    clustered = analyze.cluster(
        community, columns=cluster_columns, method='kmeans')
    # Draw onto the axes created above so the returned figure holds the plot.
    clustered.plot(column='kmeans', ax=ax)
    return fig
1 change: 1 addition & 0 deletions requirements_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ nose-exclude
coverage
coveralls
matplotlib
pytest-mpl

0 comments on commit dd7c9f5

Please sign in to comment.