-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcollisions.R
167 lines (137 loc) · 4.79 KB
/
collisions.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#-------------------------------------------------
# Set working directory
#-------------------------------------------------
# All data paths below are relative to the project root. The root defaults to
# the original checkout location; set the SWITRS_DIR environment variable to
# run the script on another machine without editing this file.
project_dir <- Sys.getenv("SWITRS_DIR")
if (!nzchar(project_dir)) {
  project_dir <- "/Users/witold/GitHub/SWITRS"
}
if (!dir.exists(project_dir)) {
  stop(
    "Project directory not found: ", project_dir,
    " (set the SWITRS_DIR environment variable to the project root)",
    call. = FALSE
  )
}
setwd(project_dir)
#-------------------------------------------------
# Load libraries
#-------------------------------------------------
# Libraries
library(tidyverse)
library(lubridate)
library(openxlsx)
library(RSQLite)
library(rpart.plot)
#-------------------------------------------------
# 1. Load data
#-------------------------------------------------
# dbConnect() with RSQLite's default flags creates an empty database when the
# file does not exist, which would surface later as a confusing "no such
# table" error and leave a stray file behind, so check for the file first.
db_path <- "data/switrs.sqlite"
if (!file.exists(db_path)) {
  stop(
    "SQLite database not found: ", db_path,
    " (working directory: ", getwd(), ")",
    call. = FALSE
  )
}
# Create connection
con <- dbConnect(SQLite(), db_path)
# Read inside tryCatch() so the connection is closed even if a read fails
parties <- tryCatch(
  {
    # View available tables (explicit print: nothing auto-prints inside braces)
    print(as.data.frame(dbListTables(con)))
    # Get tables (only `parties` is read at the moment)
    # case_ids <- dbReadTable(con, "case_ids") %>% as_tibble()
    # collisions <- dbReadTable(con, "collisions") %>% as_tibble()
    # victims <- dbReadTable(con, "victims") %>% as_tibble()
    dbReadTable(con, "parties") %>% as_tibble()
  },
  # Data is fetched (or the read failed), so disconnect either way
  finally = dbDisconnect(con)
)
# Join
# df <- left_join(collisions, parties)
df <- parties
#-------------------------------------------------
# 2. Clean data set
#-------------------------------------------------
# Shuffle rows so the positional train / test split further down is random.
# sample.int() also handles a zero-row table, where 1:nrow(df) would be c(1, 0).
# NOTE(review): no seed is set, so the split and the reported accuracy change
# from run to run; call set.seed() beforehand if reproducibility matters.
shuffle_index <- sample.int(nrow(df))
df <- df[shuffle_index, ]
# Response (at_fault) plus manually selected, reasonable independent variables,
# listed in the column order used for modelling
model_vars <- c(
  "at_fault",
  "party_type",
  "party_sex",
  "party_age",
  "party_sobriety",
  "party_drug_physical",
  "party_safety_equipment_1",
  "party_safety_equipment_2",
  "financial_responsibility",
  "hazardous_materials",
  "cellphone_use",
  "other_associate_factor_1",
  "other_associate_factor_2",
  "movement_preceding_collision",
  "vehicle_year",
  "vehicle_make",
  "statewide_vehicle_type",
  "party_race"
)
# Columns kept exactly as read; every other selected column becomes a factor
passthrough_vars <- c("at_fault", "party_age", "vehicle_year")
clean_df <- df %>%
  # Select drivers only to test for fault
  # filter(
  #   party_type == "driver"
  # ) %>%
  select(all_of(model_vars)) %>%
  # Remove troublesome values (rows with a missing movement code are kept)
  filter(
    !movement_preceding_collision %in% c("S", "0", "4")
  ) %>%
  # Convert the categorical columns to factors with NA as an explicit level.
  # Missing values are NOT dropped. Each column is converted from itself; the
  # earlier hand-written list derived party_sex from party_type by mistake.
  mutate(across(-all_of(passthrough_vars), addNA))
#-------------------------------------------------
# 3. Create train / test subsets
#-------------------------------------------------
# Train / Test function
# Split a data set into a leading training part and a trailing test part.
#
# The first floor(train_proportion * nrow(data)) rows form the training part
# and all remaining rows form the test part. Rows are taken in their current
# order, so shuffle `data` beforehand if a random split is wanted.
#
# Args:
#   data:             data frame or tibble (needs nrow() and two-dimensional `[`).
#   train_proportion: single number in [0, 1]; share of rows used for training.
#   train:            TRUE returns the training rows, FALSE the test rows.
#
# Returns: the selected rows of `data` with all columns kept (never collapsed
#   to a vector, even for a single-column data frame).
create_train_test <- function(data, train_proportion, train) {
  stopifnot(
    is.numeric(train_proportion),
    length(train_proportion) == 1L,
    !is.na(train_proportion),
    train_proportion >= 0,
    train_proportion <= 1
  )
  n_row <- nrow(data)
  total_row <- floor(train_proportion * n_row)
  # A logical mask replaces 1:total_row and negative indexing: 1:0 is c(1, 0)
  # and data[-integer(0), ] selects nothing, so both forms return the wrong
  # rows when the training part is empty.
  is_train <- seq_len(n_row) <= total_row
  if (train) {
    data[is_train, , drop = FALSE]
  } else {
    data[!is_train, , drop = FALSE]
  }
}
# Training subset: leading 80% of the cleaned rows; test subset: the remainder
data_train <- create_train_test(
  data = clean_df,
  train_proportion = 0.8,
  train = TRUE
)
data_test <- create_train_test(
  data = clean_df,
  train_proportion = 0.8,
  train = FALSE
)
# Sanity check: the share of each at_fault value should be about the same in
# the training and the test subset
prop.table(table(data_train$at_fault))
prop.table(table(data_test$at_fault))
#-------------------------------------------------
# 4. Build model
#-------------------------------------------------
# Classification tree: at_fault explained by every other column of data_train
fit <- rpart(
  formula = at_fault ~ .,
  data = data_train,
  method = "class"
)
# Plot the fitted tree. Per the rpart.plot manual, extra = 6 shows the
# probability of the second class (meant for binary responses) and adding 100
# also shows the percentage of observations in each node:
# https://cran.r-project.org/web/packages/rpart.plot/rpart.plot.pdf
rpart.plot(
  fit,
  nn = TRUE,             # display the node numbers
  shadow.col = "gray",   # shadows under the node boxes
  branch.lty = 3,        # dotted branch lines
  box.palette = "GnBu",  # color scheme
  extra = 106            # class probability plus node percentage
)
#-------------------------------------------------
# 5. Make prediction
#-------------------------------------------------
# Predicted class for every row of the held-out test subset
predict_unseen <- predict(fit, newdata = data_test, type = "class")
# Confusion matrix: rows are the observed at_fault values, columns the
# predicted classes
table_mat <- table(data_test$at_fault, predict_unseen)
# Matching row / column labels are correct classifications; every other cell
# counts rows the model misclassified
table_mat
#-------------------------------------------------
# 6. Calculate model accuracy / measure performance
#-------------------------------------------------
# Accuracy = share of test rows whose predicted class equals the observed one.
# Labels are compared directly instead of summing diag(table_mat): the diagonal
# is only meaningful for a square confusion matrix, and the matrix loses a row
# when one at_fault value never occurs in the test rows. Rows with a missing
# at_fault are skipped, matching what table() does.
is_correct <- as.character(predict_unseen) == as.character(data_test$at_fault)
accuracy_Test <- mean(is_correct, na.rm = TRUE)
# Print model accuracy
print(paste('Accuracy for test', accuracy_Test))