{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Keras Exercise\n",
"\n",
"## Predict political party based on votes\n",
"\n",
"As a fun little example, we'll use a public data set of how US congressmen voted on 16 different issues in the year 1984. Let's see if we can figure out their political party based on their votes alone, using a deep neural network!\n",
"\n",
"For those outside the United States, our two main political parties are \"Democrat\" and \"Republican.\" In modern times they represent progressive and conservative ideologies, respectively.\n",
"\n",
"Politics in 1984 weren't quite as polarized as they are today, but you should still be able to get over 90% accuracy without much trouble.\n",
"\n",
"Since the point of this exercise is implementing neural networks in Keras, I'll help you to load and prepare the data.\n",
"\n",
"Let's start by importing the raw CSV file using Pandas, and make a DataFrame out of it with nice column labels:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>party</th>\n",
" <th>handicapped-infants</th>\n",
" <th>water-project-cost-sharing</th>\n",
" <th>adoption-of-the-budget-resolution</th>\n",
" <th>physician-fee-freeze</th>\n",
" <th>el-salvador-aid</th>\n",
" <th>religious-groups-in-schools</th>\n",
" <th>anti-satellite-test-ban</th>\n",
" <th>aid-to-nicaraguan-contras</th>\n",
" <th>mx-missle</th>\n",
" <th>immigration</th>\n",
" <th>synfuels-corporation-cutback</th>\n",
" <th>education-spending</th>\n",
" <th>superfund-right-to-sue</th>\n",
" <th>crime</th>\n",
" <th>duty-free-exports</th>\n",
" <th>export-administration-act-south-africa</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>republican</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>NaN</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>republican</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>democrat</td>\n",
" <td>NaN</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>NaN</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>democrat</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>NaN</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>democrat</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>NaN</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" party handicapped-infants water-project-cost-sharing \\\n",
"0 republican n y \n",
"1 republican n y \n",
"2 democrat NaN y \n",
"3 democrat n y \n",
"4 democrat y y \n",
"\n",
" adoption-of-the-budget-resolution physician-fee-freeze el-salvador-aid \\\n",
"0 n y y \n",
"1 n y y \n",
"2 y NaN y \n",
"3 y n NaN \n",
"4 y n y \n",
"\n",
" religious-groups-in-schools anti-satellite-test-ban \\\n",
"0 y n \n",
"1 y n \n",
"2 y n \n",
"3 y n \n",
"4 y n \n",
"\n",
" aid-to-nicaraguan-contras mx-missle immigration \\\n",
"0 n n y \n",
"1 n n n \n",
"2 n n n \n",
"3 n n n \n",
"4 n n n \n",
"\n",
" synfuels-corporation-cutback education-spending superfund-right-to-sue \\\n",
"0 NaN y y \n",
"1 n y y \n",
"2 y n y \n",
"3 y n y \n",
"4 y NaN y \n",
"\n",
" crime duty-free-exports export-administration-act-south-africa \n",
"0 y n y \n",
"1 y n NaN \n",
"2 y n n \n",
"3 n n y \n",
"4 y y y "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"feature_names = ['party', 'handicapped-infants', 'water-project-cost-sharing',\n",
"                 'adoption-of-the-budget-resolution', 'physician-fee-freeze',\n",
"                 'el-salvador-aid', 'religious-groups-in-schools',\n",
"                 'anti-satellite-test-ban', 'aid-to-nicaraguan-contras',\n",
"                 'mx-missle', 'immigration', 'synfuels-corporation-cutback',\n",
"                 'education-spending', 'superfund-right-to-sue', 'crime',\n",
"                 'duty-free-exports', 'export-administration-act-south-africa']\n",
"\n",
"voting_data = pd.read_csv('house-votes-84.data.txt', na_values=['?'],\n",
"                          names=feature_names)\n",
"voting_data.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can use `describe()` to get a feel for how the data looks in aggregate:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>party</th>\n",
" <th>handicapped-infants</th>\n",
" <th>water-project-cost-sharing</th>\n",
" <th>adoption-of-the-budget-resolution</th>\n",
" <th>physician-fee-freeze</th>\n",
" <th>el-salvador-aid</th>\n",
" <th>religious-groups-in-schools</th>\n",
" <th>anti-satellite-test-ban</th>\n",
" <th>aid-to-nicaraguan-contras</th>\n",
" <th>mx-missle</th>\n",
" <th>immigration</th>\n",
" <th>synfuels-corporation-cutback</th>\n",
" <th>education-spending</th>\n",
" <th>superfund-right-to-sue</th>\n",
" <th>crime</th>\n",
" <th>duty-free-exports</th>\n",
" <th>export-administration-act-south-africa</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>435</td>\n",
" <td>423</td>\n",
" <td>387</td>\n",
" <td>424</td>\n",
" <td>424</td>\n",
" <td>420</td>\n",
" <td>424</td>\n",
" <td>421</td>\n",
" <td>420</td>\n",
" <td>413</td>\n",
" <td>428</td>\n",
" <td>414</td>\n",
" <td>404</td>\n",
" <td>410</td>\n",
" <td>418</td>\n",
" <td>407</td>\n",
" <td>331</td>\n",
" </tr>\n",
" <tr>\n",
" <th>unique</th>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>top</th>\n",
" <td>democrat</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" </tr>\n",
" <tr>\n",
" <th>freq</th>\n",
" <td>267</td>\n",
" <td>236</td>\n",
" <td>195</td>\n",
" <td>253</td>\n",
" <td>247</td>\n",
" <td>212</td>\n",
" <td>272</td>\n",
" <td>239</td>\n",
" <td>242</td>\n",
" <td>207</td>\n",
" <td>216</td>\n",
" <td>264</td>\n",
" <td>233</td>\n",
" <td>209</td>\n",
" <td>248</td>\n",
" <td>233</td>\n",
" <td>269</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" party handicapped-infants water-project-cost-sharing \\\n",
"count 435 423 387 \n",
"unique 2 2 2 \n",
"top democrat n y \n",
"freq 267 236 195 \n",
"\n",
" adoption-of-the-budget-resolution physician-fee-freeze el-salvador-aid \\\n",
"count 424 424 420 \n",
"unique 2 2 2 \n",
"top y n y \n",
"freq 253 247 212 \n",
"\n",
" religious-groups-in-schools anti-satellite-test-ban \\\n",
"count 424 421 \n",
"unique 2 2 \n",
"top y y \n",
"freq 272 239 \n",
"\n",
" aid-to-nicaraguan-contras mx-missle immigration \\\n",
"count 420 413 428 \n",
"unique 2 2 2 \n",
"top y y y \n",
"freq 242 207 216 \n",
"\n",
" synfuels-corporation-cutback education-spending superfund-right-to-sue \\\n",
"count 414 404 410 \n",
"unique 2 2 2 \n",
"top n n y \n",
"freq 264 233 209 \n",
"\n",
" crime duty-free-exports export-administration-act-south-africa \n",
"count 418 407 331 \n",
"unique 2 2 2 \n",
"top y n y \n",
"freq 248 233 269 "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"voting_data.describe()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see there's some missing data to deal with here: the counts fall short of 435 in every vote column, because some politicians abstained on some votes or just weren't present when the vote was taken. We will just drop the rows with missing data to keep it simple, even though that leaves us with only 232 of the 435 rows. In practice you'd want to first make sure that doing so doesn't introduce any sort of bias into your analysis (if one party abstains more than the other, for example, that could be problematic)."
]
},
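{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch of that bias check (it is not part of the original exercise), the cell below counts how many members of each party have at least one missing vote, and what share of the party that is. It assumes `voting_data` still holds the raw DataFrame loaded above, before any rows are dropped; `has_missing` is just an illustrative name. If the two shares differ sharply, dropping rows will skew the party balance of what's left."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Flag rows (members) with at least one missing vote; the label column is excluded.\n",
"has_missing = voting_data.drop(columns='party').isna().any(axis=1)\n",
"\n",
"# Per party: 'sum' is the number of flagged rows, 'mean' is their share of that party.\n",
"has_missing.groupby(voting_data['party']).agg(['sum', 'mean'])"
]
},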
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>party</th>\n",
" <th>handicapped-infants</th>\n",
" <th>water-project-cost-sharing</th>\n",
" <th>adoption-of-the-budget-resolution</th>\n",
" <th>physician-fee-freeze</th>\n",
" <th>el-salvador-aid</th>\n",
" <th>religious-groups-in-schools</th>\n",
" <th>anti-satellite-test-ban</th>\n",
" <th>aid-to-nicaraguan-contras</th>\n",
" <th>mx-missle</th>\n",
" <th>immigration</th>\n",
" <th>synfuels-corporation-cutback</th>\n",
" <th>education-spending</th>\n",
" <th>superfund-right-to-sue</th>\n",
" <th>crime</th>\n",
" <th>duty-free-exports</th>\n",
" <th>export-administration-act-south-africa</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>232</td>\n",
" <td>232</td>\n",
" <td>232</td>\n",
" <td>232</td>\n",
" <td>232</td>\n",
" <td>232</td>\n",
" <td>232</td>\n",
" <td>232</td>\n",
" <td>232</td>\n",
" <td>232</td>\n",
" <td>232</td>\n",
" <td>232</td>\n",
" <td>232</td>\n",
" <td>232</td>\n",
" <td>232</td>\n",
" <td>232</td>\n",
" <td>232</td>\n",
" </tr>\n",
" <tr>\n",
" <th>unique</th>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>top</th>\n",
" <td>democrat</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" </tr>\n",
" <tr>\n",
" <th>freq</th>\n",
" <td>124</td>\n",
" <td>136</td>\n",
" <td>125</td>\n",
" <td>123</td>\n",
" <td>119</td>\n",
" <td>128</td>\n",
" <td>149</td>\n",
" <td>124</td>\n",
" <td>119</td>\n",
" <td>119</td>\n",
" <td>128</td>\n",
" <td>152</td>\n",
" <td>124</td>\n",
" <td>127</td>\n",
" <td>149</td>\n",
" <td>146</td>\n",
" <td>189</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" party handicapped-infants water-project-cost-sharing \\\n",
"count 232 232 232 \n",
"unique 2 2 2 \n",
"top democrat n n \n",
"freq 124 136 125 \n",
"\n",
" adoption-of-the-budget-resolution physician-fee-freeze el-salvador-aid \\\n",
"count 232 232 232 \n",
"unique 2 2 2 \n",
"top y n y \n",
"freq 123 119 128 \n",
"\n",
" religious-groups-in-schools anti-satellite-test-ban \\\n",
"count 232 232 \n",
"unique 2 2 \n",
"top y y \n",
"freq 149 124 \n",
"\n",
" aid-to-nicaraguan-contras mx-missle immigration \\\n",
"count 232 232 232 \n",
"unique 2 2 2 \n",
"top y n y \n",
"freq 119 119 128 \n",
"\n",
" synfuels-corporation-cutback education-spending superfund-right-to-sue \\\n",
"count 232 232 232 \n",
"unique 2 2 2 \n",
"top n n y \n",
"freq 152 124 127 \n",
"\n",
" crime duty-free-exports export-administration-act-south-africa \n",
"count 232 232 232 \n",
"unique 2 2 2 \n",
"top y n y \n",
"freq 149 146 189 "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"voting_data.dropna(inplace=True)\n",
"voting_data.describe()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
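"Our neural network needs normalized numbers, not strings, to work. So let's replace all the y's and n's with 1's and 0's, and represent the parties as 1's and 0's as well (democrat as 1, republican as 0). Since every value ends up as either 0 or 1, no further scaling is needed."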
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>party</th>\n",
" <th>handicapped-infants</th>\n",
" <th>water-project-cost-sharing</th>\n",
" <th>adoption-of-the-budget-resolution</th>\n",
" <th>physician-fee-freeze</th>\n",
" <th>el-salvador-aid</th>\n",
" <th>religious-groups-in-schools</th>\n",
" <th>anti-satellite-test-ban</th>\n",
" <th>aid-to-nicaraguan-contras</th>\n",
" <th>mx-missle</th>\n",
" <th>immigration</th>\n",
" <th>synfuels-corporation-cutback</th>\n",
" <th>education-spending</th>\n",
" <th>superfund-right-to-sue</th>\n",
" <th>crime</th>\n",
" <th>duty-free-exports</th>\n",
" <th>export-administration-act-south-africa</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>423</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>426</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>427</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>430</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>431</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>232 rows × 17 columns</p>\n",
"</div>"
],
"text/plain": [
" party handicapped-infants water-project-cost-sharing \\\n",
"5 1 0 1 \n",
"8 0 0 1 \n",
"19 1 1 1 \n",
"23 1 1 1 \n",
"25 1 1 0 \n",
".. ... ... ... \n",
"423 1 0 1 \n",
"426 1 1 0 \n",
"427 0 0 0 \n",
"430 0 0 0 \n",
"431 1 0 0 \n",
"\n",
" adoption-of-the-budget-resolution physician-fee-freeze el-salvador-aid \\\n",
"5 1 0 1 \n",
"8 0 1 1 \n",
"19 1 0 0 \n",
"23 1 0 0 \n",
"25 1 0 0 \n",
".. ... ... ... \n",
"423 1 0 0 \n",
"426 1 0 0 \n",
"427 0 1 1 \n",
"430 1 1 1 \n",
"431 1 0 0 \n",
"\n",
" religious-groups-in-schools anti-satellite-test-ban \\\n",
"5 1 0 \n",
"8 1 0 \n",
"19 0 1 \n",
"23 0 1 \n",
"25 0 1 \n",
".. ... ... \n",
"423 1 1 \n",
"426 0 1 \n",
"427 1 1 \n",
"430 1 0 \n",
"431 0 1 \n",
"\n",
" aid-to-nicaraguan-contras mx-missle immigration \\\n",
"5 0 0 0 \n",
"8 0 0 0 \n",
"19 1 1 0 \n",
"23 1 1 0 \n",
"25 1 1 1 \n",
".. ... ... ... \n",
"423 1 1 0 \n",
"426 1 1 1 \n",
"427 1 0 1 \n",
"430 0 1 1 \n",
"431 1 1 1 \n",
"\n",
" synfuels-corporation-cutback education-spending superfund-right-to-sue \\\n",
"5 0 0 1 \n",
"8 0 1 1 \n",
"19 1 0 0 \n",
"23 0 0 0 \n",
"25 0 0 0 \n",
".. ... ... ... \n",
"423 1 0 0 \n",
"426 0 0 0 \n",
"427 0 1 1 \n",
"430 0 1 1 \n",
"431 0 0 0 \n",
"\n",
" crime duty-free-exports export-administration-act-south-africa \n",
"5 1 1 1 \n",
"8 1 0 1 \n",
"19 0 1 1 \n",
"23 0 1 1 \n",
"25 0 1 1 \n",
".. ... ... ... \n",
"423 1 1 1 \n",
"426 0 1 1 \n",
"427 1 0 1 \n",
"430 1 0 1 \n",
"431 0 0 1 \n",
"\n",
"[232 rows x 17 columns]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
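"# Encode votes (y/n) and parties (democrat/republican) as 1/0.\n",
"voting_data = voting_data.replace(('y', 'n'), (1, 0))\n",
"voting_data = voting_data.replace(('democrat', 'republican'), (1, 0))\n",
"# astype() returns a new DataFrame, so assign it back to keep the integer dtype.\n",
"voting_data = voting_data.astype(int)\n",
"voting_data"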
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>party</th>\n",
" <th>handicapped-infants</th>\n",
" <th>water-project-cost-sharing</th>\n",
" <th>adoption-of-the-budget-resolution</th>\n",
" <th>physician-fee-freeze</th>\n",
" <th>el-salvador-aid</th>\n",
" <th>religious-groups-in-schools</th>\n",
" <th>anti-satellite-test-ban</th>\n",
" <th>aid-to-nicaraguan-contras</th>\n",
" <th>mx-missle</th>\n",
" <th>immigration</th>\n",
" <th>synfuels-corporation-cutback</th>\n",
" <th>education-spending</th>\n",
" <th>superfund-right-to-sue</th>\n",
" <th>crime</th>\n",
" <th>duty-free-exports</th>\n",
" <th>export-administration-act-south-africa</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" party handicapped-infants water-project-cost-sharing \\\n",
"5 1 0 1 \n",
"8 0 0 1 \n",
"19 1 1 1 \n",
"23 1 1 1 \n",
"25 1 1 0 \n",
"\n",
" adoption-of-the-budget-resolution physician-fee-freeze el-salvador-aid \\\n",
"5 1 0 1 \n",
"8 0 1 1 \n",
"19 1 0 0 \n",
"23 1 0 0 \n",
"25 1 0 0 \n",
"\n",
" religious-groups-in-schools anti-satellite-test-ban \\\n",
"5 1 0 \n",
"8 1 0 \n",
"19 0 1 \n",
"23 0 1 \n",
"25 0 1 \n",
"\n",
" aid-to-nicaraguan-contras mx-missle immigration \\\n",
"5 0 0 0 \n",
"8 0 0 0 \n",
"19 1 1 0 \n",
"23 1 1 0 \n",
"25 1 1 1 \n",
"\n",
" synfuels-corporation-cutback education-spending superfund-right-to-sue \\\n",
"5 0 0 1 \n",
"8 0 1 1 \n",
"19 1 0 0 \n",
"23 0 0 0 \n",
"25 0 0 0 \n",
"\n",
" crime duty-free-exports export-administration-act-south-africa \n",
"5 1 1 1 \n",
"8 1 0 1 \n",
"19 0 1 1 \n",
"23 0 1 1 \n",
"25 0 1 1 "
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"voting_data.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, let's extract the features and labels in the form that Keras will expect:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"all_features = voting_data[feature_names].drop('party', axis=1).values\n",
"all_classes = voting_data['party'].values"
]
},
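{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you'd like a quick sanity check before building the model, something like the following should do. It assumes the cells above ran in order, so that 232 rows survived `dropna()`; in that case the features should come out as a (232, 16) array and the labels as a (232,) array, ideally with a numeric dtype rather than `object`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One row per remaining member, one column per vote.\n",
"print(all_features.shape, all_features.dtype)\n",
"# One 0/1 party label per remaining member.\n",
"print(all_classes.shape, all_classes.dtype)"
]
},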
{
"cell_type": "markdown",
"metadata": {},
"source": [
"OK, so have a go at it! You'll want to refer back to the slide on using Keras with binary classification - there are only two parties, so this is a binary problem. This also saves us the hassle of representing classes with \"one-hot\" format like we had to do with MNIST; our output is just a single 0 or 1 value.\n",
"\n",
"Also refer to the scikit_learn integration slide, and use cross_val_score to evaluate your resulting model with 10-fold cross-validation.\n",
"\n",
"**If you're using tensorflow-gpu on a Windows machine** by the way, you probably *do* want to peek a little bit at my solution - if you run into memory allocation errors, there's a workaround there you can use.\n",
"\n",
"Try out your code here; be sure to have scikeras installed if you don't already (you may need to launch this notebook with admin privileges, or just install it from your Anaconda prompt):"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: scikeras in c:\\users\\frank\\anaconda3\\lib\\site-packages (0.13.0)\n",
"Requirement already satisfied: keras>=3.2.0 in c:\\users\\frank\\anaconda3\\lib\\site-packages (from scikeras) (3.8.0)\n",
"Requirement already satisfied: scikit-learn>=1.4.2 in c:\\users\\frank\\anaconda3\\lib\\site-packages (from scikeras) (1.5.1)\n",
"Requirement already satisfied: absl-py in c:\\users\\frank\\anaconda3\\lib\\site-packages (from keras>=3.2.0->scikeras) (2.1.0)\n",
"Requirement already satisfied: numpy in c:\\users\\frank\\anaconda3\\lib\\site-packages (from keras>=3.2.0->scikeras) (1.26.4)\n",
"Requirement already satisfied: rich in c:\\users\\frank\\anaconda3\\lib\\site-packages (from keras>=3.2.0->scikeras) (13.7.1)\n",
"Requirement already satisfied: namex in c:\\users\\frank\\anaconda3\\lib\\site-packages (from keras>=3.2.0->scikeras) (0.0.8)\n",
"Requirement already satisfied: h5py in c:\\users\\frank\\anaconda3\\lib\\site-packages (from keras>=3.2.0->scikeras) (3.11.0)\n",
"Requirement already satisfied: optree in c:\\users\\frank\\anaconda3\\lib\\site-packages (from keras>=3.2.0->scikeras) (0.14.0)\n",
"Requirement already satisfied: ml-dtypes in c:\\users\\frank\\anaconda3\\lib\\site-packages (from keras>=3.2.0->scikeras) (0.4.1)\n",
"Requirement already satisfied: packaging in c:\\users\\frank\\anaconda3\\lib\\site-packages (from keras>=3.2.0->scikeras) (24.1)\n",
"Requirement already satisfied: scipy>=1.6.0 in c:\\users\\frank\\anaconda3\\lib\\site-packages (from scikit-learn>=1.4.2->scikeras) (1.13.1)\n",
"Requirement already satisfied: joblib>=1.2.0 in c:\\users\\frank\\anaconda3\\lib\\site-packages (from scikit-learn>=1.4.2->scikeras) (1.4.2)\n",
"Requirement already satisfied: threadpoolctl>=3.1.0 in c:\\users\\frank\\anaconda3\\lib\\site-packages (from scikit-learn>=1.4.2->scikeras) (3.5.0)\n",
"Requirement already satisfied: typing-extensions>=4.5.0 in c:\\users\\frank\\anaconda3\\lib\\site-packages (from optree->keras>=3.2.0->scikeras) (4.11.0)\n",
"Requirement already satisfied: markdown-it-py>=2.2.0 in c:\\users\\frank\\anaconda3\\lib\\site-packages (from rich->keras>=3.2.0->scikeras) (2.2.0)\n",
"Requirement already satisfied: pygments<3.0.0,>=2.13.0 in c:\\users\\frank\\anaconda3\\lib\\site-packages (from rich->keras>=3.2.0->scikeras) (2.15.1)\n",
"Requirement already satisfied: mdurl~=0.1 in c:\\users\\frank\\anaconda3\\lib\\site-packages (from markdown-it-py>=2.2.0->rich->keras>=3.2.0->scikeras) (0.1.0)\n"
]
}
],
"source": [
"!pip install scikeras"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## My implementation is below\n",
"\n",
"# No peeking!\n",
"\n",
""
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9438405797101449\n"
]
}
],
"source": [
"from tensorflow.keras.layers import Dense, Input\n",
"from tensorflow.keras.models import Sequential\n",
"from sklearn.model_selection import cross_val_score\n",
"from scikeras.wrappers import KerasClassifier\n",
"\n",
"def create_model():\n",
"    # 16 vote inputs -> two hidden ReLU layers -> one sigmoid output\n",
"    model = Sequential([\n",
"        Input(shape=(16,)),\n",
"        Dense(32, kernel_initializer='normal', activation='relu'),\n",
"        Dense(16, kernel_initializer='normal', activation='relu'),\n",
"        Dense(1, kernel_initializer='normal', activation='sigmoid')\n",
"    ])\n",
"    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
"    return model\n",
"\n",
"# Wrap our Keras model with SciKeras KerasClassifier\n",
"estimator = KerasClassifier(model=create_model, epochs=100, verbose=0)\n",
"\n",
"# Evaluate with 10-fold cross-validation, using the features and labels extracted above\n",
"cv_scores = cross_val_score(estimator, all_features, all_classes, cv=10)\n",
"print(cv_scores.mean())\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"94% without even trying too hard! Did you do better? Maybe more neurons, more layers, or Dropout layers would help even more."
]
},
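{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's one way you might try the Dropout idea; treat it as a sketch under assumptions, not a tuned answer. It reuses `all_features` and `all_classes` from above. The names `create_dropout_model` and `dropout_rate`, and the 0.3 rate itself, are arbitrary illustrative choices. Whether it actually beats the plain network on a data set this small is something you'd have to run and see."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.layers import Dense, Dropout, Input\n",
"from tensorflow.keras.models import Sequential\n",
"from sklearn.model_selection import cross_val_score\n",
"from scikeras.wrappers import KerasClassifier\n",
"\n",
"def create_dropout_model(dropout_rate=0.3):\n",
"    # Same layer sizes as before, with a Dropout layer after each hidden layer.\n",
"    model = Sequential([\n",
"        Input(shape=(16,)),\n",
"        Dense(32, activation='relu'),\n",
"        Dropout(dropout_rate),\n",
"        Dense(16, activation='relu'),\n",
"        Dropout(dropout_rate),\n",
"        Dense(1, activation='sigmoid')\n",
"    ])\n",
"    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
"    return model\n",
"\n",
"# Same 10-fold evaluation as above, so the two mean accuracies are comparable.\n",
"dropout_estimator = KerasClassifier(model=create_dropout_model, epochs=100, verbose=0)\n",
"dropout_scores = cross_val_score(dropout_estimator, all_features, all_classes, cv=10)\n",
"print(dropout_scores.mean())"
]
},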
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:base] *",
"language": "python",
"name": "conda-base-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}