-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathapp.py
253 lines (212 loc) · 18.5 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import streamlit as st
from pathlib import Path
import weaviate as wv
import utils
import math
import pandas as pd
# ===== START STREAMLIT APP =====
# Use the wide page layout and render the page title.
st.set_page_config(layout="wide")
st.header("Data importer wizard 🧙♀️🪄 POC")
# Collapsible banner. The markdown image link below has an empty URL.
with st.expander("Turn users into Wizards!"):
    st.markdown("![Image]()")
# ===== CONNECT TO A DATABASE =====
st.write("#### Connect to Weaviate")
# The server URL is kept in session state; it starts out as just the
# "http://" scheme prefix.
if 'wv_url' not in st.session_state:
    st.session_state['wv_url'] = "http://"
wv_url = st.text_input("Server URL:", value=st.session_state['wv_url'])
st.session_state['wv_url'] = wv_url
# `client` stays False unless a client object is created below; later code
# gates on its truthiness.
client = False
# Only attempt a connection once the URL is longer than the 7-character
# "http://" default.
if len(st.session_state['wv_url']) > 7:
    try:
        client = wv.Client(st.session_state['wv_url'])
        client.batch.configure(
            batch_size=100,
            dynamic=False,
            timeout_retries=3,
        )
    except Exception as e:
        # Report the failure in the page instead of letting it propagate.
        # NOTE(review): if Client() succeeds but configure() raises, `client`
        # is already set and the connected section still runs.
        st.write("No server found.")
        st.write("**Connection error:**")
        st.warning(e)
# ===== QUERY THE DATABASE =====
if client:
    st.write(f"**CONNECTED** TO {st.session_state['wv_url']}")
    # TODO - add text input for class name
    # TODO - check class name validation
    # Each key is populated only when missing, so the schema and the object
    # count are fetched here at most once per session.
    if 'schema' not in st.session_state:
        st.session_state['schema'] = utils.get_schema(client)
    if 'obj_count' not in st.session_state:
        st.session_state['obj_count'] = utils.get_tot_object_count(client)
    # Initial values for the two progress bars created further down.
    if 'custom_import_progress' not in st.session_state:
        st.session_state['custom_import_progress'] = 0
    if 'import_progress' not in st.session_state:
        st.session_state['import_progress'] = 0
# ===== FETCH CURRENT STATUS =====
summ_c1, summ_c2 = st.columns(2)
class_objs = st.session_state['schema']['classes']
n_classes = len(class_objs)
# With no classes there is nothing to list, and the object count is reported
# as zero regardless of what is stored in session state.
class_names = [class_def['class'] for class_def in class_objs]
n_objs = st.session_state['obj_count'] if n_classes > 0 else 0
with summ_c1:
    def update_dbstats(client):
        """Re-fetch the schema and total object count into session state.

        Used as the on_click callback of the "Update database stats" button.
        """
        st.session_state['schema'] = utils.get_schema(client)
        st.session_state['obj_count'] = utils.get_tot_object_count(client)
    st.write(f"{n_classes} class(es) and {n_objs} object(s) found in the database!")
    # The stored schema/count are refreshed by this button's callback.
    st.button("Update database stats", on_click=update_dbstats, args=[client])
# with summ_c2:
# print(st.session_state['obj_count'])
# st.write(f"{st.session_state['obj_count']} object(s) in database")
# Collapsible preview: the schema of one selected class on the left, example
# objects from the database on the right.
with st.expander("Preview schema / data here"):
    state_c1, state_c2 = st.columns(2)
    schema_classes = st.session_state['schema']['classes']
    if schema_classes:
        with state_c1:
            # isinstance (rather than an exact type comparison) also accepts
            # dict subclasses.
            if isinstance(st.session_state['schema'], dict):
                wv_classes = [i['class'] for i in schema_classes]
                sch_select = st.selectbox("Pick a class", wv_classes)
                st.write("**Schema**")
                # Look the selected name back up to show its full definition.
                st.write(schema_classes[wv_classes.index(sch_select)])
        with state_c2:
            st.write("**Example objects**")
            st.write(client.data_object.get())
    else:
        st.write("Nothing to see here! Try adding a schema and some data!")
st.markdown("-----")
# ===== SET UP FILE PARSING =====
st.write("#### Add Data")
st.write("You can import data in just a few clicks. Import a demo dataset just one click, or choose your own data.")
# t1, t2 = st.tabs(['Demo Datasets', 'Your Own Data'])
# with t1:
# # st.write("**Create your own dataset**")
# # manual_data = st.text_area("Specify your own data as a JSON", value="example JSON goes here")
with st.expander("Demo Datasets"):
    # ===== Dataset =====
    st.write("**Import a demo dataset**")
    st.write("Import any of the below datasets in just one click.")
    # This value is handed to the demo import callback as its max_objs
    # argument (see the demo dataset `args` lists).
    demo_max_objs = st.number_input("Max objects to import", value=500)
    st.markdown("-----")
def import_csv_data(client, fpath, class_name, max_objs=1000, skip_schema=False):
    """Batch-import rows of a CSV file into a Weaviate class.

    Args:
        client: Weaviate client; its ``batch`` context manager is used.
        fpath: Path of the CSV file. Its first column is read as the index.
        class_name: Passed to ``utils.get_csv_to_class`` to derive the class
            definition.
        max_objs: Maximum number of rows to add in this call.
        skip_schema: When true the class is not (re)created before importing.

    Returns:
        True once the import has finished.

    Side effects:
        Updates the module-level ``import_progbar`` and sets
        ``st.session_state['import_progress']`` to 100.
    """
    class_obj = utils.get_csv_to_class(fpath, class_name)
    target_class = class_obj["class"]
    if not skip_schema:
        utils.build_schema(client, class_obj)
    # Always resume after what is already stored. Previously the "import
    # more" path (skip_schema=True) restarted at row 0 and re-imported the
    # same rows.
    # NOTE(review): assumes get_object_count returns the number of objects of
    # this class and that earlier imports followed file order - confirm.
    obj_count = utils.get_object_count(client, target_class)

    df = pd.read_csv(fpath, index_col=0)
    # Fill gaps with each column's most frequent value; mode() is empty for an
    # empty or all-NaN frame, in which case there is nothing to fill with.
    modes = df.mode()
    if not modes.empty:
        df = df.fillna(modes.iloc[0])

    # Select the rows for this call positionally so progress never depends on
    # the index labels and can never exceed 1.0.
    limit = max(int(max_objs), 0)
    pending = df.iloc[obj_count:obj_count + limit]
    total = len(pending)
    with client.batch as batch:
        for done, (_, rowdata) in enumerate(pending.iterrows(), start=1):
            batch.add_data_object(
                data_object=rowdata.to_dict(),
                class_name=target_class,
            )
            import_progbar.progress(done / total)
    st.session_state['import_progress'] = 100
    return True
# Demo dataset catalogue. `args` is the positional argument list for
# `callback`; args[2] is the target class name.
demo_datasets = [
    {
        "name": "Wine reviews",
        "description": "Wine reviews from [Kaggle](https://www.kaggle.com/datasets/zynicide/wine-reviews) containing 150k reviews.",
        "callback": import_csv_data,
        "args": [client, "demodata/winemag-data-130k-v2.csv", "WineReview", demo_max_objs],
        "disabled": False,
    },
    {
        "name": "Yelp reviews",
        "description": "Placeholder: 7M Yelp Reviews.",
        "callback": import_csv_data,
        "args": [client, "data/winemag-data-130k-v2.csv", "YelpReview", demo_max_objs],
        "disabled": True,
    },
    {
        "name": "Tiny ImageNet",
        "description": "Placholder: 100000 images of 200 classes.",
        "callback": import_csv_data,
        "args": [client, "data/winemag-data-130k-v2.csv", "ImageNetImage", demo_max_objs],
        "disabled": True,
    },
]
# One layout column per dataset.
bcs = st.columns(len(demo_datasets))
for i, (bc, dataset) in enumerate(zip(bcs, demo_datasets)):
    with bc:
        st.write(f"**{dataset['name']}**")
        st.write(f"{dataset['description']}")
        if dataset['args'][2] not in class_names:
            st.button(
                "Import",
                key=f"default_import_button_{i}",
                on_click=dataset['callback'],
                args=dataset['args'],
                disabled=dataset['disabled'],
            )
        else:
            # The class already exists: pass skip_schema=True as the extra
            # positional argument. An explicit key keeps several of these
            # buttons from colliding on an identical auto-generated widget ID,
            # and placeholder datasets stay disabled here as well.
            st.button(
                "Import more data",
                key=f"default_import_more_button_{i}",
                on_click=dataset['callback'],
                args=dataset['args'] + [True],
                disabled=dataset['disabled'],
            )
# Progress bar that import_csv_data updates while it runs.
import_progbar = st.progress(st.session_state['import_progress'])
st.markdown("-----")
# - If the class does not exist, enable the import button
with st.expander("Your own data"):
    st.write("**Import data from file**")
    datadir = "./data"
    path = Path(datadir)
    # "." is a placeholder first option meaning "no file chosen"; it is
    # followed by the JSON and CSV files found directly in the data directory.
    data_files = ["."] + [f for f in path.glob('*.json')] + [f for f in path.glob('*.csv')]
    # Store the selected option's position in session state and use it as the
    # selectbox's initial index.
    if 'datafile_ind' not in st.session_state:
        st.session_state['datafile_ind'] = 0
    fpath = st.selectbox('Select file to import', data_files, st.session_state['datafile_ind'], key=3)
    st.session_state['datafile_ind'] = data_files.index(fpath)
    if fpath != ".":
        col1, col2 = st.columns(2)
        # ===== Build DataFrame for schema building
        st.write("**Data Preview**")
        st.markdown("Adjust the detected datatypes as appropriate for the schema.")
        df = utils.get_preview_df(fpath)
        st.dataframe(df)
        # ===== Select data types for schema
        st.write("**Schema builder**")
        n_cols = 3
        cols = st.columns(n_cols)
        # Split the per-column dtypes into n_cols consecutive chunks of at
        # most n_elms entries, one chunk per layout column.
        n_elms = math.ceil(len(df.dtypes) / n_cols)
        dtype_series_parts = [df.dtypes[n_elms * i:n_elms * (i+1)] for i in range(n_cols)]
def build_schema_boxes(dtype_sers):
    """Render one datatype selectbox per column and return the selections.

    Args:
        dtype_sers: pandas Series whose index holds column names and whose
            values are the detected dtypes (a slice of ``df.dtypes``).

    Returns:
        List with the value chosen in each selectbox, in Series order.
    """
    # Series.iteritems() was removed in pandas 2.0; items() yields the same
    # (label, value) pairs and also exists in older releases.
    return [
        st.selectbox(
            f"{colname}",
            utils.dtype_maps.values(),
            utils.get_dtype_index(datatype),
            key=colname,
        )
        for colname, datatype in dtype_sers.items()
    ]
# Render each chunk of selectboxes inside its own layout column, then flatten
# the selections back into original column order. Works for any number of
# layout columns instead of assuming exactly three.
wv_dtypes_parts = []
for layout_col, dtype_part in zip(cols, dtype_series_parts):
    with layout_col:
        wv_dtypes_parts.append(build_schema_boxes(dtype_part))
wv_dtypes = [wv_dtype for part in wv_dtypes_parts for wv_dtype in part]
# ===== Add Schema
custom_classname = st.text_input("Class name", value="UserData")
# One property per DataFrame column, typed with the datatype selected above.
props = [
    {
        "dataType": [selected_dtype],
        "name": column_name,
        "description": f"Contains_{column_name}"
    }
    for column_name, selected_dtype in zip(df.columns, wv_dtypes)
]
class_obj = {
    "class": custom_classname,
    "description": f"Contains {custom_classname} data",
    "properties": props,
}
# Offer to create the class only while it is absent from the database.
if custom_classname in class_names:
    st.button("Class found", disabled=True, key="found")
else:
    st.button("Add this ☝️ class", on_click=utils.build_schema, args=[client, class_obj])
# ===== Add Data
st.write("**Import data**")
def add_data(fpath, progbar, max_objs=200):
    """Batch-import up to ``max_objs`` records from a JSON or CSV file.

    Args:
        fpath: ``pathlib.Path`` of the file; a ``.json`` suffix selects
            ``utils.parse_json``, anything else is read with pandas as CSV.
        progbar: Streamlit progress element updated as records are added.
        max_objs: Maximum number of records to add; values <= 0 add nothing.

    Returns:
        True once the import has finished.

    Side effects:
        Adds objects to the class named by the enclosing script's
        ``custom_classname`` through the enclosing ``client`` and sets
        ``st.session_state['custom_import_progress']`` to 100.
    """
    if fpath.suffix == '.json':
        records = utils.parse_json(fpath)  # TODO - add CSV parsing option
    else:
        records = (row.to_dict() for _, row in pd.read_csv(fpath).iterrows())

    limit = max(int(max_objs), 0)
    progbar.progress(0)
    if limit > 0:
        with client.batch as batch:
            for i, record in enumerate(records):
                batch.add_data_object(
                    data_object=record,
                    class_name=custom_classname,
                )
                # Report on the bar passed in by the caller; the CSV path used
                # to update the unrelated demo-import bar instead.
                progbar.progress((i + 1) / limit)
                if i + 1 >= limit:
                    break
    st.session_state['custom_import_progress'] = 100
    return True
# Controls for the custom import: the number entered here is passed to
# add_data as max_objs, together with the selected file and this progress bar.
cust_max_objs = st.number_input("Max objects", value=500)
upload_progbar = st.progress(st.session_state['custom_import_progress'])
st.button("Add data", on_click=add_data, args=[fpath, upload_progbar, cust_max_objs], key="custom_data")