|
@@ -0,0 +1,1262 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "bf68dfaf-37f5-463d-b72e-2120d27f0b0d",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Imports for this Notebook:"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 234,
|
|
|
+ "id": "63528a79-842d-44ec-9109-29df1621cef8",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# the basic imports\n",
|
|
|
+ "import matplotlib.pyplot as plt\n",
|
|
|
+ "import numpy as np\n",
|
|
|
+ "import pandas as pd\n",
|
|
|
+ "import torch\n",
|
|
|
+ "\n",
|
|
|
+ "# DataLoader wraps a dataset and serves it in (optionally shuffled) mini-batches\n",
|
|
|
+ "from torch.utils.data import DataLoader"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "ee773350-56e9-4601-a33e-b97c46a2cafc",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Load and split the wine quality dataset with Pandas."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 243,
|
|
|
+ "id": "d106a41b-e75c-4ca0-96e1-0e5b117992bf",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# load the data, dropping only the Id column\n",
|
|
|
+ "wine_data = pd.read_csv(\"./WineQT.csv\", delimiter=\",\").drop(\"Id\", axis=1)\n",
|
|
|
+ "\n",
|
|
|
+ "# hold out 30% of the rows for testing; the sampled 70% is used for modelling\n",
|
|
|
+ "wine_model = wine_data.sample(frac=0.7, random_state=123)\n",
|
|
|
+ "# drop wine_model's rows here: wine_train is only created further down, and the\n",
|
|
|
+ "# test set has to exclude the validation rows as well as the training rows\n",
|
|
|
+ "wine_test = wine_data.drop(wine_model.index)\n",
|
|
|
+ "\n",
|
|
|
+ "# further split the modelling set into train (70%) and validation (30%)\n",
|
|
|
+ "wine_train = wine_model.sample(frac=0.7, random_state=123)\n",
|
|
|
+ "wine_validate = wine_model.drop(wine_train.index)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 244,
|
|
|
+ "id": "29301f87-1919-48ae-8c9e-1a55f4fb1495",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "source: 1143\n",
|
|
|
+ "train: 560 validate: 240 test: 343 == 1143\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# verify that the three splits together account for every source row;\n",
|
|
|
+ "# the sizes are summed because adding DataFrames aligns them on the index,\n",
|
|
|
+ "# which would count shared rows only once and hide overlapping splits\n",
|
|
|
+ "print(\"source:\", len(wine_data))\n",
|
|
|
+ "total_rows = len(wine_train) + len(wine_validate) + len(wine_test)\n",
|
|
|
+ "print(\"train:\", len(wine_train), \"validate:\", len(wine_validate), \"test:\", len(wine_test), \"==\", total_rows)\n",
|
|
|
+ "assert total_rows == len(wine_data), \"train/validate/test splits overlap or miss rows\""
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 245,
|
|
|
+ "id": "3087bcdc-5c50-4b79-a2d8-f442a8d81825",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "data": {
|
|
|
+ "text/html": [
|
|
|
+ "<div>\n",
|
|
|
+ "<style scoped>\n",
|
|
|
+ " .dataframe tbody tr th:only-of-type {\n",
|
|
|
+ " vertical-align: middle;\n",
|
|
|
+ " }\n",
|
|
|
+ "\n",
|
|
|
+ " .dataframe tbody tr th {\n",
|
|
|
+ " vertical-align: top;\n",
|
|
|
+ " }\n",
|
|
|
+ "\n",
|
|
|
+ " .dataframe thead th {\n",
|
|
|
+ " text-align: right;\n",
|
|
|
+ " }\n",
|
|
|
+ "</style>\n",
|
|
|
+ "<table border=\"1\" class=\"dataframe\">\n",
|
|
|
+ " <thead>\n",
|
|
|
+ " <tr style=\"text-align: right;\">\n",
|
|
|
+ " <th></th>\n",
|
|
|
+ " <th>fixed acidity</th>\n",
|
|
|
+ " <th>volatile acidity</th>\n",
|
|
|
+ " <th>citric acid</th>\n",
|
|
|
+ " <th>residual sugar</th>\n",
|
|
|
+ " <th>chlorides</th>\n",
|
|
|
+ " <th>free sulfur dioxide</th>\n",
|
|
|
+ " <th>total sulfur dioxide</th>\n",
|
|
|
+ " <th>density</th>\n",
|
|
|
+ " <th>pH</th>\n",
|
|
|
+ " <th>sulphates</th>\n",
|
|
|
+ " <th>alcohol</th>\n",
|
|
|
+ " <th>quality</th>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " </thead>\n",
|
|
|
+ " <tbody>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>711</th>\n",
|
|
|
+ " <td>7.5</td>\n",
|
|
|
+ " <td>0.71</td>\n",
|
|
|
+ " <td>0.00</td>\n",
|
|
|
+ " <td>1.6</td>\n",
|
|
|
+ " <td>0.092</td>\n",
|
|
|
+ " <td>22.0</td>\n",
|
|
|
+ " <td>31.0</td>\n",
|
|
|
+ " <td>0.99635</td>\n",
|
|
|
+ " <td>3.38</td>\n",
|
|
|
+ " <td>0.58</td>\n",
|
|
|
+ " <td>10.0</td>\n",
|
|
|
+ " <td>6</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>936</th>\n",
|
|
|
+ " <td>9.1</td>\n",
|
|
|
+ " <td>0.34</td>\n",
|
|
|
+ " <td>0.42</td>\n",
|
|
|
+ " <td>1.8</td>\n",
|
|
|
+ " <td>0.058</td>\n",
|
|
|
+ " <td>9.0</td>\n",
|
|
|
+ " <td>18.0</td>\n",
|
|
|
+ " <td>0.99392</td>\n",
|
|
|
+ " <td>3.18</td>\n",
|
|
|
+ " <td>0.55</td>\n",
|
|
|
+ " <td>11.4</td>\n",
|
|
|
+ " <td>5</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>683</th>\n",
|
|
|
+ " <td>10.4</td>\n",
|
|
|
+ " <td>0.26</td>\n",
|
|
|
+ " <td>0.48</td>\n",
|
|
|
+ " <td>1.9</td>\n",
|
|
|
+ " <td>0.066</td>\n",
|
|
|
+ " <td>6.0</td>\n",
|
|
|
+ " <td>10.0</td>\n",
|
|
|
+ " <td>0.99724</td>\n",
|
|
|
+ " <td>3.33</td>\n",
|
|
|
+ " <td>0.87</td>\n",
|
|
|
+ " <td>10.9</td>\n",
|
|
|
+ " <td>6</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>904</th>\n",
|
|
|
+ " <td>8.5</td>\n",
|
|
|
+ " <td>0.40</td>\n",
|
|
|
+ " <td>0.40</td>\n",
|
|
|
+ " <td>6.3</td>\n",
|
|
|
+ " <td>0.050</td>\n",
|
|
|
+ " <td>3.0</td>\n",
|
|
|
+ " <td>10.0</td>\n",
|
|
|
+ " <td>0.99566</td>\n",
|
|
|
+ " <td>3.28</td>\n",
|
|
|
+ " <td>0.56</td>\n",
|
|
|
+ " <td>12.0</td>\n",
|
|
|
+ " <td>4</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>198</th>\n",
|
|
|
+ " <td>8.9</td>\n",
|
|
|
+ " <td>0.40</td>\n",
|
|
|
+ " <td>0.32</td>\n",
|
|
|
+ " <td>5.6</td>\n",
|
|
|
+ " <td>0.087</td>\n",
|
|
|
+ " <td>10.0</td>\n",
|
|
|
+ " <td>47.0</td>\n",
|
|
|
+ " <td>0.99910</td>\n",
|
|
|
+ " <td>3.38</td>\n",
|
|
|
+ " <td>0.77</td>\n",
|
|
|
+ " <td>10.5</td>\n",
|
|
|
+ " <td>7</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>...</th>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>1076</th>\n",
|
|
|
+ " <td>7.5</td>\n",
|
|
|
+ " <td>0.38</td>\n",
|
|
|
+ " <td>0.57</td>\n",
|
|
|
+ " <td>2.3</td>\n",
|
|
|
+ " <td>0.106</td>\n",
|
|
|
+ " <td>5.0</td>\n",
|
|
|
+ " <td>12.0</td>\n",
|
|
|
+ " <td>0.99605</td>\n",
|
|
|
+ " <td>3.36</td>\n",
|
|
|
+ " <td>0.55</td>\n",
|
|
|
+ " <td>11.4</td>\n",
|
|
|
+ " <td>6</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>709</th>\n",
|
|
|
+ " <td>8.9</td>\n",
|
|
|
+ " <td>0.32</td>\n",
|
|
|
+ " <td>0.31</td>\n",
|
|
|
+ " <td>2.0</td>\n",
|
|
|
+ " <td>0.088</td>\n",
|
|
|
+ " <td>12.0</td>\n",
|
|
|
+ " <td>19.0</td>\n",
|
|
|
+ " <td>0.99570</td>\n",
|
|
|
+ " <td>3.17</td>\n",
|
|
|
+ " <td>0.55</td>\n",
|
|
|
+ " <td>10.4</td>\n",
|
|
|
+ " <td>6</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>486</th>\n",
|
|
|
+ " <td>8.1</td>\n",
|
|
|
+ " <td>0.78</td>\n",
|
|
|
+ " <td>0.23</td>\n",
|
|
|
+ " <td>2.6</td>\n",
|
|
|
+ " <td>0.059</td>\n",
|
|
|
+ " <td>5.0</td>\n",
|
|
|
+ " <td>15.0</td>\n",
|
|
|
+ " <td>0.99700</td>\n",
|
|
|
+ " <td>3.37</td>\n",
|
|
|
+ " <td>0.56</td>\n",
|
|
|
+ " <td>11.3</td>\n",
|
|
|
+ " <td>5</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>182</th>\n",
|
|
|
+ " <td>7.7</td>\n",
|
|
|
+ " <td>0.41</td>\n",
|
|
|
+ " <td>0.76</td>\n",
|
|
|
+ " <td>1.8</td>\n",
|
|
|
+ " <td>0.611</td>\n",
|
|
|
+ " <td>8.0</td>\n",
|
|
|
+ " <td>45.0</td>\n",
|
|
|
+ " <td>0.99680</td>\n",
|
|
|
+ " <td>3.06</td>\n",
|
|
|
+ " <td>1.26</td>\n",
|
|
|
+ " <td>9.4</td>\n",
|
|
|
+ " <td>5</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>206</th>\n",
|
|
|
+ " <td>8.7</td>\n",
|
|
|
+ " <td>0.52</td>\n",
|
|
|
+ " <td>0.09</td>\n",
|
|
|
+ " <td>2.5</td>\n",
|
|
|
+ " <td>0.091</td>\n",
|
|
|
+ " <td>20.0</td>\n",
|
|
|
+ " <td>49.0</td>\n",
|
|
|
+ " <td>0.99760</td>\n",
|
|
|
+ " <td>3.34</td>\n",
|
|
|
+ " <td>0.86</td>\n",
|
|
|
+ " <td>10.6</td>\n",
|
|
|
+ " <td>7</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " </tbody>\n",
|
|
|
+ "</table>\n",
|
|
|
+ "<p>560 rows × 12 columns</p>\n",
|
|
|
+ "</div>"
|
|
|
+ ],
|
|
|
+ "text/plain": [
|
|
|
+ " fixed acidity volatile acidity citric acid residual sugar chlorides \\\n",
|
|
|
+ "711 7.5 0.71 0.00 1.6 0.092 \n",
|
|
|
+ "936 9.1 0.34 0.42 1.8 0.058 \n",
|
|
|
+ "683 10.4 0.26 0.48 1.9 0.066 \n",
|
|
|
+ "904 8.5 0.40 0.40 6.3 0.050 \n",
|
|
|
+ "198 8.9 0.40 0.32 5.6 0.087 \n",
|
|
|
+ "... ... ... ... ... ... \n",
|
|
|
+ "1076 7.5 0.38 0.57 2.3 0.106 \n",
|
|
|
+ "709 8.9 0.32 0.31 2.0 0.088 \n",
|
|
|
+ "486 8.1 0.78 0.23 2.6 0.059 \n",
|
|
|
+ "182 7.7 0.41 0.76 1.8 0.611 \n",
|
|
|
+ "206 8.7 0.52 0.09 2.5 0.091 \n",
|
|
|
+ "\n",
|
|
|
+ " free sulfur dioxide total sulfur dioxide density pH sulphates \\\n",
|
|
|
+ "711 22.0 31.0 0.99635 3.38 0.58 \n",
|
|
|
+ "936 9.0 18.0 0.99392 3.18 0.55 \n",
|
|
|
+ "683 6.0 10.0 0.99724 3.33 0.87 \n",
|
|
|
+ "904 3.0 10.0 0.99566 3.28 0.56 \n",
|
|
|
+ "198 10.0 47.0 0.99910 3.38 0.77 \n",
|
|
|
+ "... ... ... ... ... ... \n",
|
|
|
+ "1076 5.0 12.0 0.99605 3.36 0.55 \n",
|
|
|
+ "709 12.0 19.0 0.99570 3.17 0.55 \n",
|
|
|
+ "486 5.0 15.0 0.99700 3.37 0.56 \n",
|
|
|
+ "182 8.0 45.0 0.99680 3.06 1.26 \n",
|
|
|
+ "206 20.0 49.0 0.99760 3.34 0.86 \n",
|
|
|
+ "\n",
|
|
|
+ " alcohol quality \n",
|
|
|
+ "711 10.0 6 \n",
|
|
|
+ "936 11.4 5 \n",
|
|
|
+ "683 10.9 6 \n",
|
|
|
+ "904 12.0 4 \n",
|
|
|
+ "198 10.5 7 \n",
|
|
|
+ "... ... ... \n",
|
|
|
+ "1076 11.4 6 \n",
|
|
|
+ "709 10.4 6 \n",
|
|
|
+ "486 11.3 5 \n",
|
|
|
+ "182 9.4 5 \n",
|
|
|
+ "206 10.6 7 \n",
|
|
|
+ "\n",
|
|
|
+ "[560 rows x 12 columns]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "execution_count": 245,
|
|
|
+ "metadata": {},
|
|
|
+ "output_type": "execute_result"
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# preview the training split (pandas abbreviates the middle rows of a long frame)\n",
|
|
|
+ "wine_train"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 246,
|
|
|
+ "id": "241f2cdb-a7fb-4165-be11-95ca6510082b",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "data": {
|
|
|
+ "text/html": [
|
|
|
+ "<div>\n",
|
|
|
+ "<style scoped>\n",
|
|
|
+ " .dataframe tbody tr th:only-of-type {\n",
|
|
|
+ " vertical-align: middle;\n",
|
|
|
+ " }\n",
|
|
|
+ "\n",
|
|
|
+ " .dataframe tbody tr th {\n",
|
|
|
+ " vertical-align: top;\n",
|
|
|
+ " }\n",
|
|
|
+ "\n",
|
|
|
+ " .dataframe thead th {\n",
|
|
|
+ " text-align: right;\n",
|
|
|
+ " }\n",
|
|
|
+ "</style>\n",
|
|
|
+ "<table border=\"1\" class=\"dataframe\">\n",
|
|
|
+ " <thead>\n",
|
|
|
+ " <tr style=\"text-align: right;\">\n",
|
|
|
+ " <th></th>\n",
|
|
|
+ " <th>fixed acidity</th>\n",
|
|
|
+ " <th>volatile acidity</th>\n",
|
|
|
+ " <th>citric acid</th>\n",
|
|
|
+ " <th>residual sugar</th>\n",
|
|
|
+ " <th>chlorides</th>\n",
|
|
|
+ " <th>free sulfur dioxide</th>\n",
|
|
|
+ " <th>total sulfur dioxide</th>\n",
|
|
|
+ " <th>density</th>\n",
|
|
|
+ " <th>pH</th>\n",
|
|
|
+ " <th>sulphates</th>\n",
|
|
|
+ " <th>alcohol</th>\n",
|
|
|
+ " <th>quality</th>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " </thead>\n",
|
|
|
+ " <tbody>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>309</th>\n",
|
|
|
+ " <td>7.0</td>\n",
|
|
|
+ " <td>0.620</td>\n",
|
|
|
+ " <td>0.18</td>\n",
|
|
|
+ " <td>1.5</td>\n",
|
|
|
+ " <td>0.062</td>\n",
|
|
|
+ " <td>7.0</td>\n",
|
|
|
+ " <td>50.0</td>\n",
|
|
|
+ " <td>0.99510</td>\n",
|
|
|
+ " <td>3.08</td>\n",
|
|
|
+ " <td>0.60</td>\n",
|
|
|
+ " <td>9.3</td>\n",
|
|
|
+ " <td>5</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>528</th>\n",
|
|
|
+ " <td>8.3</td>\n",
|
|
|
+ " <td>0.760</td>\n",
|
|
|
+ " <td>0.29</td>\n",
|
|
|
+ " <td>4.2</td>\n",
|
|
|
+ " <td>0.075</td>\n",
|
|
|
+ " <td>12.0</td>\n",
|
|
|
+ " <td>16.0</td>\n",
|
|
|
+ " <td>0.99650</td>\n",
|
|
|
+ " <td>3.45</td>\n",
|
|
|
+ " <td>0.68</td>\n",
|
|
|
+ " <td>11.5</td>\n",
|
|
|
+ " <td>6</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>150</th>\n",
|
|
|
+ " <td>8.2</td>\n",
|
|
|
+ " <td>0.570</td>\n",
|
|
|
+ " <td>0.26</td>\n",
|
|
|
+ " <td>2.2</td>\n",
|
|
|
+ " <td>0.060</td>\n",
|
|
|
+ " <td>28.0</td>\n",
|
|
|
+ " <td>65.0</td>\n",
|
|
|
+ " <td>0.99590</td>\n",
|
|
|
+ " <td>3.30</td>\n",
|
|
|
+ " <td>0.43</td>\n",
|
|
|
+ " <td>10.1</td>\n",
|
|
|
+ " <td>5</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>278</th>\n",
|
|
|
+ " <td>6.6</td>\n",
|
|
|
+ " <td>0.735</td>\n",
|
|
|
+ " <td>0.02</td>\n",
|
|
|
+ " <td>7.9</td>\n",
|
|
|
+ " <td>0.122</td>\n",
|
|
|
+ " <td>68.0</td>\n",
|
|
|
+ " <td>124.0</td>\n",
|
|
|
+ " <td>0.99940</td>\n",
|
|
|
+ " <td>3.47</td>\n",
|
|
|
+ " <td>0.53</td>\n",
|
|
|
+ " <td>9.9</td>\n",
|
|
|
+ " <td>5</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>490</th>\n",
|
|
|
+ " <td>8.6</td>\n",
|
|
|
+ " <td>0.490</td>\n",
|
|
|
+ " <td>0.51</td>\n",
|
|
|
+ " <td>2.0</td>\n",
|
|
|
+ " <td>0.422</td>\n",
|
|
|
+ " <td>16.0</td>\n",
|
|
|
+ " <td>62.0</td>\n",
|
|
|
+ " <td>0.99790</td>\n",
|
|
|
+ " <td>3.03</td>\n",
|
|
|
+ " <td>1.17</td>\n",
|
|
|
+ " <td>9.0</td>\n",
|
|
|
+ " <td>5</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>...</th>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>896</th>\n",
|
|
|
+ " <td>10.4</td>\n",
|
|
|
+ " <td>0.430</td>\n",
|
|
|
+ " <td>0.50</td>\n",
|
|
|
+ " <td>2.3</td>\n",
|
|
|
+ " <td>0.068</td>\n",
|
|
|
+ " <td>13.0</td>\n",
|
|
|
+ " <td>19.0</td>\n",
|
|
|
+ " <td>0.99600</td>\n",
|
|
|
+ " <td>3.10</td>\n",
|
|
|
+ " <td>0.87</td>\n",
|
|
|
+ " <td>11.4</td>\n",
|
|
|
+ " <td>6</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>922</th>\n",
|
|
|
+ " <td>7.6</td>\n",
|
|
|
+ " <td>1.580</td>\n",
|
|
|
+ " <td>0.00</td>\n",
|
|
|
+ " <td>2.1</td>\n",
|
|
|
+ " <td>0.137</td>\n",
|
|
|
+ " <td>5.0</td>\n",
|
|
|
+ " <td>9.0</td>\n",
|
|
|
+ " <td>0.99476</td>\n",
|
|
|
+ " <td>3.50</td>\n",
|
|
|
+ " <td>0.40</td>\n",
|
|
|
+ " <td>10.9</td>\n",
|
|
|
+ " <td>3</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>219</th>\n",
|
|
|
+ " <td>8.4</td>\n",
|
|
|
+ " <td>0.650</td>\n",
|
|
|
+ " <td>0.60</td>\n",
|
|
|
+ " <td>2.1</td>\n",
|
|
|
+ " <td>0.112</td>\n",
|
|
|
+ " <td>12.0</td>\n",
|
|
|
+ " <td>90.0</td>\n",
|
|
|
+ " <td>0.99730</td>\n",
|
|
|
+ " <td>3.20</td>\n",
|
|
|
+ " <td>0.52</td>\n",
|
|
|
+ " <td>9.2</td>\n",
|
|
|
+ " <td>5</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>970</th>\n",
|
|
|
+ " <td>7.3</td>\n",
|
|
|
+ " <td>0.740</td>\n",
|
|
|
+ " <td>0.08</td>\n",
|
|
|
+ " <td>1.7</td>\n",
|
|
|
+ " <td>0.094</td>\n",
|
|
|
+ " <td>10.0</td>\n",
|
|
|
+ " <td>45.0</td>\n",
|
|
|
+ " <td>0.99576</td>\n",
|
|
|
+ " <td>3.24</td>\n",
|
|
|
+ " <td>0.50</td>\n",
|
|
|
+ " <td>9.8</td>\n",
|
|
|
+ " <td>5</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>288</th>\n",
|
|
|
+ " <td>8.8</td>\n",
|
|
|
+ " <td>0.520</td>\n",
|
|
|
+ " <td>0.34</td>\n",
|
|
|
+ " <td>2.7</td>\n",
|
|
|
+ " <td>0.087</td>\n",
|
|
|
+ " <td>24.0</td>\n",
|
|
|
+ " <td>122.0</td>\n",
|
|
|
+ " <td>0.99820</td>\n",
|
|
|
+ " <td>3.26</td>\n",
|
|
|
+ " <td>0.61</td>\n",
|
|
|
+ " <td>9.5</td>\n",
|
|
|
+ " <td>5</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " </tbody>\n",
|
|
|
+ "</table>\n",
|
|
|
+ "<p>240 rows × 12 columns</p>\n",
|
|
|
+ "</div>"
|
|
|
+ ],
|
|
|
+ "text/plain": [
|
|
|
+ " fixed acidity volatile acidity citric acid residual sugar chlorides \\\n",
|
|
|
+ "309 7.0 0.620 0.18 1.5 0.062 \n",
|
|
|
+ "528 8.3 0.760 0.29 4.2 0.075 \n",
|
|
|
+ "150 8.2 0.570 0.26 2.2 0.060 \n",
|
|
|
+ "278 6.6 0.735 0.02 7.9 0.122 \n",
|
|
|
+ "490 8.6 0.490 0.51 2.0 0.422 \n",
|
|
|
+ ".. ... ... ... ... ... \n",
|
|
|
+ "896 10.4 0.430 0.50 2.3 0.068 \n",
|
|
|
+ "922 7.6 1.580 0.00 2.1 0.137 \n",
|
|
|
+ "219 8.4 0.650 0.60 2.1 0.112 \n",
|
|
|
+ "970 7.3 0.740 0.08 1.7 0.094 \n",
|
|
|
+ "288 8.8 0.520 0.34 2.7 0.087 \n",
|
|
|
+ "\n",
|
|
|
+ " free sulfur dioxide total sulfur dioxide density pH sulphates \\\n",
|
|
|
+ "309 7.0 50.0 0.99510 3.08 0.60 \n",
|
|
|
+ "528 12.0 16.0 0.99650 3.45 0.68 \n",
|
|
|
+ "150 28.0 65.0 0.99590 3.30 0.43 \n",
|
|
|
+ "278 68.0 124.0 0.99940 3.47 0.53 \n",
|
|
|
+ "490 16.0 62.0 0.99790 3.03 1.17 \n",
|
|
|
+ ".. ... ... ... ... ... \n",
|
|
|
+ "896 13.0 19.0 0.99600 3.10 0.87 \n",
|
|
|
+ "922 5.0 9.0 0.99476 3.50 0.40 \n",
|
|
|
+ "219 12.0 90.0 0.99730 3.20 0.52 \n",
|
|
|
+ "970 10.0 45.0 0.99576 3.24 0.50 \n",
|
|
|
+ "288 24.0 122.0 0.99820 3.26 0.61 \n",
|
|
|
+ "\n",
|
|
|
+ " alcohol quality \n",
|
|
|
+ "309 9.3 5 \n",
|
|
|
+ "528 11.5 6 \n",
|
|
|
+ "150 10.1 5 \n",
|
|
|
+ "278 9.9 5 \n",
|
|
|
+ "490 9.0 5 \n",
|
|
|
+ ".. ... ... \n",
|
|
|
+ "896 11.4 6 \n",
|
|
|
+ "922 10.9 3 \n",
|
|
|
+ "219 9.2 5 \n",
|
|
|
+ "970 9.8 5 \n",
|
|
|
+ "288 9.5 5 \n",
|
|
|
+ "\n",
|
|
|
+ "[240 rows x 12 columns]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "execution_count": 246,
|
|
|
+ "metadata": {},
|
|
|
+ "output_type": "execute_result"
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# preview the validation split (pandas abbreviates the middle rows of a long frame)\n",
|
|
|
+ "wine_validate"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 247,
|
|
|
+ "id": "b1e11d58-2415-4a17-9883-410e1bd15c21",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "data": {
|
|
|
+ "text/html": [
|
|
|
+ "<div>\n",
|
|
|
+ "<style scoped>\n",
|
|
|
+ " .dataframe tbody tr th:only-of-type {\n",
|
|
|
+ " vertical-align: middle;\n",
|
|
|
+ " }\n",
|
|
|
+ "\n",
|
|
|
+ " .dataframe tbody tr th {\n",
|
|
|
+ " vertical-align: top;\n",
|
|
|
+ " }\n",
|
|
|
+ "\n",
|
|
|
+ " .dataframe thead th {\n",
|
|
|
+ " text-align: right;\n",
|
|
|
+ " }\n",
|
|
|
+ "</style>\n",
|
|
|
+ "<table border=\"1\" class=\"dataframe\">\n",
|
|
|
+ " <thead>\n",
|
|
|
+ " <tr style=\"text-align: right;\">\n",
|
|
|
+ " <th></th>\n",
|
|
|
+ " <th>fixed acidity</th>\n",
|
|
|
+ " <th>volatile acidity</th>\n",
|
|
|
+ " <th>citric acid</th>\n",
|
|
|
+ " <th>residual sugar</th>\n",
|
|
|
+ " <th>chlorides</th>\n",
|
|
|
+ " <th>free sulfur dioxide</th>\n",
|
|
|
+ " <th>total sulfur dioxide</th>\n",
|
|
|
+ " <th>density</th>\n",
|
|
|
+ " <th>pH</th>\n",
|
|
|
+ " <th>sulphates</th>\n",
|
|
|
+ " <th>alcohol</th>\n",
|
|
|
+ " <th>quality</th>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " </thead>\n",
|
|
|
+ " <tbody>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>1</th>\n",
|
|
|
+ " <td>7.8</td>\n",
|
|
|
+ " <td>0.88</td>\n",
|
|
|
+ " <td>0.00</td>\n",
|
|
|
+ " <td>2.6</td>\n",
|
|
|
+ " <td>0.098</td>\n",
|
|
|
+ " <td>25.0</td>\n",
|
|
|
+ " <td>67.0</td>\n",
|
|
|
+ " <td>0.99680</td>\n",
|
|
|
+ " <td>3.20</td>\n",
|
|
|
+ " <td>0.68</td>\n",
|
|
|
+ " <td>9.8</td>\n",
|
|
|
+ " <td>5</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>2</th>\n",
|
|
|
+ " <td>7.8</td>\n",
|
|
|
+ " <td>0.76</td>\n",
|
|
|
+ " <td>0.04</td>\n",
|
|
|
+ " <td>2.3</td>\n",
|
|
|
+ " <td>0.092</td>\n",
|
|
|
+ " <td>15.0</td>\n",
|
|
|
+ " <td>54.0</td>\n",
|
|
|
+ " <td>0.99700</td>\n",
|
|
|
+ " <td>3.26</td>\n",
|
|
|
+ " <td>0.65</td>\n",
|
|
|
+ " <td>9.8</td>\n",
|
|
|
+ " <td>5</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>3</th>\n",
|
|
|
+ " <td>11.2</td>\n",
|
|
|
+ " <td>0.28</td>\n",
|
|
|
+ " <td>0.56</td>\n",
|
|
|
+ " <td>1.9</td>\n",
|
|
|
+ " <td>0.075</td>\n",
|
|
|
+ " <td>17.0</td>\n",
|
|
|
+ " <td>60.0</td>\n",
|
|
|
+ " <td>0.99800</td>\n",
|
|
|
+ " <td>3.16</td>\n",
|
|
|
+ " <td>0.58</td>\n",
|
|
|
+ " <td>9.8</td>\n",
|
|
|
+ " <td>6</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>6</th>\n",
|
|
|
+ " <td>7.9</td>\n",
|
|
|
+ " <td>0.60</td>\n",
|
|
|
+ " <td>0.06</td>\n",
|
|
|
+ " <td>1.6</td>\n",
|
|
|
+ " <td>0.069</td>\n",
|
|
|
+ " <td>15.0</td>\n",
|
|
|
+ " <td>59.0</td>\n",
|
|
|
+ " <td>0.99640</td>\n",
|
|
|
+ " <td>3.30</td>\n",
|
|
|
+ " <td>0.46</td>\n",
|
|
|
+ " <td>9.4</td>\n",
|
|
|
+ " <td>5</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>8</th>\n",
|
|
|
+ " <td>7.8</td>\n",
|
|
|
+ " <td>0.58</td>\n",
|
|
|
+ " <td>0.02</td>\n",
|
|
|
+ " <td>2.0</td>\n",
|
|
|
+ " <td>0.073</td>\n",
|
|
|
+ " <td>9.0</td>\n",
|
|
|
+ " <td>18.0</td>\n",
|
|
|
+ " <td>0.99680</td>\n",
|
|
|
+ " <td>3.36</td>\n",
|
|
|
+ " <td>0.57</td>\n",
|
|
|
+ " <td>9.5</td>\n",
|
|
|
+ " <td>7</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>...</th>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " <td>...</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>1128</th>\n",
|
|
|
+ " <td>6.2</td>\n",
|
|
|
+ " <td>0.70</td>\n",
|
|
|
+ " <td>0.15</td>\n",
|
|
|
+ " <td>5.1</td>\n",
|
|
|
+ " <td>0.076</td>\n",
|
|
|
+ " <td>13.0</td>\n",
|
|
|
+ " <td>27.0</td>\n",
|
|
|
+ " <td>0.99622</td>\n",
|
|
|
+ " <td>3.54</td>\n",
|
|
|
+ " <td>0.60</td>\n",
|
|
|
+ " <td>11.9</td>\n",
|
|
|
+ " <td>6</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>1133</th>\n",
|
|
|
+ " <td>6.7</td>\n",
|
|
|
+ " <td>0.32</td>\n",
|
|
|
+ " <td>0.44</td>\n",
|
|
|
+ " <td>2.4</td>\n",
|
|
|
+ " <td>0.061</td>\n",
|
|
|
+ " <td>24.0</td>\n",
|
|
|
+ " <td>34.0</td>\n",
|
|
|
+ " <td>0.99484</td>\n",
|
|
|
+ " <td>3.29</td>\n",
|
|
|
+ " <td>0.80</td>\n",
|
|
|
+ " <td>11.6</td>\n",
|
|
|
+ " <td>7</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>1135</th>\n",
|
|
|
+ " <td>5.8</td>\n",
|
|
|
+ " <td>0.61</td>\n",
|
|
|
+ " <td>0.11</td>\n",
|
|
|
+ " <td>1.8</td>\n",
|
|
|
+ " <td>0.066</td>\n",
|
|
|
+ " <td>18.0</td>\n",
|
|
|
+ " <td>28.0</td>\n",
|
|
|
+ " <td>0.99483</td>\n",
|
|
|
+ " <td>3.55</td>\n",
|
|
|
+ " <td>0.66</td>\n",
|
|
|
+ " <td>10.9</td>\n",
|
|
|
+ " <td>6</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>1137</th>\n",
|
|
|
+ " <td>5.4</td>\n",
|
|
|
+ " <td>0.74</td>\n",
|
|
|
+ " <td>0.09</td>\n",
|
|
|
+ " <td>1.7</td>\n",
|
|
|
+ " <td>0.089</td>\n",
|
|
|
+ " <td>16.0</td>\n",
|
|
|
+ " <td>26.0</td>\n",
|
|
|
+ " <td>0.99402</td>\n",
|
|
|
+ " <td>3.67</td>\n",
|
|
|
+ " <td>0.56</td>\n",
|
|
|
+ " <td>11.6</td>\n",
|
|
|
+ " <td>6</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " <tr>\n",
|
|
|
+ " <th>1140</th>\n",
|
|
|
+ " <td>6.2</td>\n",
|
|
|
+ " <td>0.60</td>\n",
|
|
|
+ " <td>0.08</td>\n",
|
|
|
+ " <td>2.0</td>\n",
|
|
|
+ " <td>0.090</td>\n",
|
|
|
+ " <td>32.0</td>\n",
|
|
|
+ " <td>44.0</td>\n",
|
|
|
+ " <td>0.99490</td>\n",
|
|
|
+ " <td>3.45</td>\n",
|
|
|
+ " <td>0.58</td>\n",
|
|
|
+ " <td>10.5</td>\n",
|
|
|
+ " <td>5</td>\n",
|
|
|
+ " </tr>\n",
|
|
|
+ " </tbody>\n",
|
|
|
+ "</table>\n",
|
|
|
+ "<p>343 rows × 12 columns</p>\n",
|
|
|
+ "</div>"
|
|
|
+ ],
|
|
|
+ "text/plain": [
|
|
|
+ " fixed acidity volatile acidity citric acid residual sugar chlorides \\\n",
|
|
|
+ "1 7.8 0.88 0.00 2.6 0.098 \n",
|
|
|
+ "2 7.8 0.76 0.04 2.3 0.092 \n",
|
|
|
+ "3 11.2 0.28 0.56 1.9 0.075 \n",
|
|
|
+ "6 7.9 0.60 0.06 1.6 0.069 \n",
|
|
|
+ "8 7.8 0.58 0.02 2.0 0.073 \n",
|
|
|
+ "... ... ... ... ... ... \n",
|
|
|
+ "1128 6.2 0.70 0.15 5.1 0.076 \n",
|
|
|
+ "1133 6.7 0.32 0.44 2.4 0.061 \n",
|
|
|
+ "1135 5.8 0.61 0.11 1.8 0.066 \n",
|
|
|
+ "1137 5.4 0.74 0.09 1.7 0.089 \n",
|
|
|
+ "1140 6.2 0.60 0.08 2.0 0.090 \n",
|
|
|
+ "\n",
|
|
|
+ " free sulfur dioxide total sulfur dioxide density pH sulphates \\\n",
|
|
|
+ "1 25.0 67.0 0.99680 3.20 0.68 \n",
|
|
|
+ "2 15.0 54.0 0.99700 3.26 0.65 \n",
|
|
|
+ "3 17.0 60.0 0.99800 3.16 0.58 \n",
|
|
|
+ "6 15.0 59.0 0.99640 3.30 0.46 \n",
|
|
|
+ "8 9.0 18.0 0.99680 3.36 0.57 \n",
|
|
|
+ "... ... ... ... ... ... \n",
|
|
|
+ "1128 13.0 27.0 0.99622 3.54 0.60 \n",
|
|
|
+ "1133 24.0 34.0 0.99484 3.29 0.80 \n",
|
|
|
+ "1135 18.0 28.0 0.99483 3.55 0.66 \n",
|
|
|
+ "1137 16.0 26.0 0.99402 3.67 0.56 \n",
|
|
|
+ "1140 32.0 44.0 0.99490 3.45 0.58 \n",
|
|
|
+ "\n",
|
|
|
+ " alcohol quality \n",
|
|
|
+ "1 9.8 5 \n",
|
|
|
+ "2 9.8 5 \n",
|
|
|
+ "3 9.8 6 \n",
|
|
|
+ "6 9.4 5 \n",
|
|
|
+ "8 9.5 7 \n",
|
|
|
+ "... ... ... \n",
|
|
|
+ "1128 11.9 6 \n",
|
|
|
+ "1133 11.6 7 \n",
|
|
|
+ "1135 10.9 6 \n",
|
|
|
+ "1137 11.6 6 \n",
|
|
|
+ "1140 10.5 5 \n",
|
|
|
+ "\n",
|
|
|
+ "[343 rows x 12 columns]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "execution_count": 247,
|
|
|
+ "metadata": {},
|
|
|
+ "output_type": "execute_result"
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# preview the held-out test split (pandas abbreviates the middle rows of a long frame)\n",
|
|
|
+ "wine_test"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 248,
|
|
|
+ "id": "88bc3d8e-10c7-4887-83c8-21e0274349a2",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# split every wine frame into its input columns and the 'quality' label, then\n",
|
|
|
+ "# convert from pandas to the tensor types used below (float32 inputs, int64 labels)\n",
|
|
|
+ "def _frame_to_tensors(frame):\n",
|
|
|
+ "    \"\"\"Return the (features, target) tensors for one wine data frame.\"\"\"\n",
|
|
|
+ "    feature_values = frame.drop('quality', axis=1).values.astype(np.float32)\n",
|
|
|
+ "    target_values = frame['quality'].values.astype(np.int64)\n",
|
|
|
+ "    return torch.tensor(feature_values), torch.tensor(target_values)\n",
|
|
|
+ "\n",
|
|
|
+ "train_features, train_target = _frame_to_tensors(wine_train)\n",
|
|
|
+ "validate_features, validate_target = _frame_to_tensors(wine_validate)\n",
|
|
|
+ "test_features, test_target = _frame_to_tensors(wine_test)\n",
|
|
|
+ "\n",
|
|
|
+ "# bundle each feature tensor with its labels so samples can be indexed as (features, target)\n",
|
|
|
+ "train_data = torch.utils.data.TensorDataset(train_features, train_target)\n",
|
|
|
+ "validate_data = torch.utils.data.TensorDataset(validate_features, validate_target)\n",
|
|
|
+ "test_data = torch.utils.data.TensorDataset(test_features, test_target)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 249,
|
|
|
+ "id": "21a01ec6-44d9-4f50-b093-e80816345dd0",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "data": {
|
|
|
+ "text/plain": [
|
|
|
+ "(tensor([ 7.5000, 0.7100, 0.0000, 1.6000, 0.0920, 22.0000, 31.0000, 0.9963,\n",
|
|
|
+ " 3.3800, 0.5800, 10.0000]),\n",
|
|
|
+ " tensor(6))"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "execution_count": 249,
|
|
|
+ "metadata": {},
|
|
|
+ "output_type": "execute_result"
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# first (features, target) pair of the training dataset\n",
|
|
|
+ "train_data[0]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 250,
|
|
|
+ "id": "84198c1f-5a45-45ad-94f8-9e3de5a4efd5",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "data": {
|
|
|
+ "text/plain": [
|
|
|
+ "(tensor([ 7.0000, 0.6200, 0.1800, 1.5000, 0.0620, 7.0000, 50.0000, 0.9951,\n",
|
|
|
+ " 3.0800, 0.6000, 9.3000]),\n",
|
|
|
+ " tensor(5))"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "execution_count": 250,
|
|
|
+ "metadata": {},
|
|
|
+ "output_type": "execute_result"
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# first (features, target) pair of the validation dataset\n",
|
|
|
+ "validate_data[0]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 251,
|
|
|
+ "id": "c0685118-8baf-48d1-a554-a8b94a85d39a",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "data": {
|
|
|
+ "text/plain": [
|
|
|
+ "(tensor([ 7.8000, 0.8800, 0.0000, 2.6000, 0.0980, 25.0000, 67.0000, 0.9968,\n",
|
|
|
+ " 3.2000, 0.6800, 9.8000]),\n",
|
|
|
+ " tensor(5))"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "execution_count": 251,
|
|
|
+ "metadata": {},
|
|
|
+ "output_type": "execute_result"
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# first (features, target) pair of the test dataset\n",
|
|
|
+ "test_data[0]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 252,
|
|
|
+ "id": "c59d990a-3f85-4981-a288-8877ef1fa34c",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# create data loaders; only the training data is reshuffled every epoch -\n",
|
|
|
+ "# validation and test batches keep a fixed order so evaluation runs are repeatable\n",
|
|
|
+ "loader_options = dict(batch_size=32, num_workers=4, pin_memory=True)\n",
|
|
|
+ "train_loader = DataLoader(train_data, shuffle=True, **loader_options)\n",
|
|
|
+ "validate_loader = DataLoader(validate_data, shuffle=False, **loader_options)\n",
|
|
|
+ "test_loader = DataLoader(test_data, shuffle=False, **loader_options)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 224,
|
|
|
+ "id": "e3aaa169-e086-4a19-865d-557c0c2aa210",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "accelerated by mps\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# pick the compute device: the current accelerator's type when one is available, otherwise the CPU\n",
|
|
|
+ "from torch.accelerator import is_available, current_accelerator\n",
|
|
|
+ "if not is_available():\n",
|
|
|
+ "    device = \"cpu\"\n",
|
|
|
+ "    print(\"no accelerator\")\n",
|
|
|
+ "else:\n",
|
|
|
+ "    device = current_accelerator().type\n",
|
|
|
+ "    print(f\"accelerated by {device}\")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 229,
|
|
|
+ "id": "a1ee2d09-caed-4d39-ab51-cc27504753ef",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "**** DEFINING Linear/ReLU layer stack ****\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# define the model\n",
|
|
|
+ "print(\"**** DEFINING Linear/ReLU layer stack ****\")\n",
|
|
|
+ "\n",
|
|
|
+ "from torch import nn\n",
|
|
|
+ "\n",
|
|
|
+ "class SeqNN(nn.Module):\n",
|
|
|
+ "    \"\"\"Fully connected classifier: 11 input features -> 512 hidden units -> 10 class scores.\n",
|
|
|
+ "\n",
|
|
|
+ "    The network returns raw, unnormalised logits. No Softmax layer is applied\n",
|
|
|
+ "    because nn.CrossEntropyLoss already applies log-softmax internally; feeding\n",
|
|
|
+ "    it probabilities flattens the gradients and stalls training.\n",
|
|
|
+ "    \"\"\"\n",
|
|
|
+ "\n",
|
|
|
+ "    def __init__(self):\n",
|
|
|
+ "        super().__init__()\n",
|
|
|
+ "        self.linear_relu_stack = nn.Sequential(\n",
|
|
|
+ "            nn.Linear(11, 512),\n",
|
|
|
+ "            nn.ReLU(),\n",
|
|
|
+ "            nn.Linear(512, 10),\n",
|
|
|
+ "        )\n",
|
|
|
+ "\n",
|
|
|
+ "    def forward(self, x):\n",
|
|
|
+ "        \"\"\"Return the class logits for `x`, whose last dimension must hold the 11 features.\"\"\"\n",
|
|
|
+ "        logits = self.linear_relu_stack(x)\n",
|
|
|
+ "        return logits\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 230,
|
|
|
+ "id": "9e95ca59-e228-4b84-b3ec-f12c6e6dcb02",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "SeqNN(\n",
|
|
|
+ " (linear_relu_stack): Sequential(\n",
|
|
|
+ " (0): Linear(in_features=11, out_features=512, bias=True)\n",
|
|
|
+ " (1): ReLU()\n",
|
|
|
+ " (2): Linear(in_features=512, out_features=10, bias=True)\n",
|
|
|
+ " (3): Softmax(dim=None)\n",
|
|
|
+ " )\n",
|
|
|
+ ")\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# instantiate the network and print its layer structure\n",
|
|
|
+ "# NOTE(review): the .to(device) call below is commented out, so the model is not moved to `device` here - confirm this is intended\n",
|
|
|
+ "model = SeqNN() #.to(device)\n",
|
|
|
+ "print(model)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 227,
|
|
|
+ "id": "6df0edaa-bf1a-4ff5-ae41-2f1c03811c7c",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# define the loss function and optimizer\n",
|
|
|
+ "# the optimizer receives model.parameters() when it is constructed, so this cell has to be\n",
|
|
|
+ "# re-run every time `model` is re-created - otherwise it keeps updating the previous instance\n",
|
|
|
+ "loss_fn = torch.nn.CrossEntropyLoss()\n",
|
|
|
+ "optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.90)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 257,
|
|
|
+ "id": "102c4fc5-0798-40c5-bb32-9c87fc7707d5",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def train_epoch(ep_num):\n",
|
|
|
+ " running_loss = 0.\n",
|
|
|
+ " last_loss = 0.\n",
|
|
|
+ "\n",
|
|
|
+ " report_every = 5\n",
|
|
|
+ "\n",
|
|
|
+ " for idx, batch in enumerate(train_loader):\n",
|
|
|
+ " # unpack the next batch\n",
|
|
|
+ " features, labels = batch\n",
|
|
|
+ " \n",
|
|
|
+ " # ensure features are on the accelerator device (if any)\n",
|
|
|
+ " features.to(device)\n",
|
|
|
+ " \n",
|
|
|
+ " # zero the optimizer gradients\n",
|
|
|
+ " optimizer.zero_grad()\n",
|
|
|
+ " \n",
|
|
|
+ " # feed it to model\n",
|
|
|
+ " outputs = model(features)\n",
|
|
|
+ " \n",
|
|
|
+ " # calculate the loss against the expected label\n",
|
|
|
+ " loss = loss_fn(outputs, labels)\n",
|
|
|
+ " loss.backward()\n",
|
|
|
+ " \n",
|
|
|
+ " # update the model parameters using the gradients computed by backward()\n",
|
|
|
+ " optimizer.step()\n",
|
|
|
+ " \n",
|
|
|
+ " # accumulate this batch's loss for the running average\n",
|
|
|
+ " running_loss += loss.item()\n",
|
|
|
+ "\n",
|
|
|
+ " # report the average loss once every report_every batches\n",
|
|
|
+ " if idx % report_every == (report_every - 1):\n",
|
|
|
+ " last_loss = running_loss / report_every\n",
|
|
|
+ " print(\" batch {} loss: {}\".format(idx + 1, last_loss))\n",
|
|
|
+ " tb_x = ep_num * len(train_loader) + idx + 1\n",
|
|
|
+ " print(\" loss/train: {}/{}\".format(last_loss, tb_x))\n",
|
|
|
+ " running_loss = 0.\n",
|
|
|
+ "\n",
|
|
|
+ " return last_loss"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 258,
|
|
|
+ "id": "cdf2c8ae-d024-4687-b1df-23804f3d9816",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "**** TRAINING NN on supplied data ****\n",
|
|
|
+ "EPOCH 1\n",
|
|
|
+ " batch 5 loss: 2.4407766819000245\n",
|
|
|
+ " loss/train: 2.4407766819000245/5\n",
|
|
|
+ " batch 10 loss: 2.451565170288086\n",
|
|
|
+ " loss/train: 2.451565170288086/10\n",
|
|
|
+ " batch 15 loss: 2.43599009513855\n",
|
|
|
+ " loss/train: 2.43599009513855/15\n",
|
|
|
+ "LOSS train 2.43599009513855 / valid 2.4329328536987305\n",
|
|
|
+ "EPOCH 2\n",
|
|
|
+ " batch 5 loss: 2.444344425201416\n",
|
|
|
+ " loss/train: 2.444344425201416/23\n",
|
|
|
+ " batch 10 loss: 2.4506516456604004\n",
|
|
|
+ " loss/train: 2.4506516456604004/28\n",
|
|
|
+ " batch 15 loss: 2.4280431270599365\n",
|
|
|
+ " loss/train: 2.4280431270599365/33\n",
|
|
|
+ "LOSS train 2.4280431270599365 / valid 2.4371285438537598\n",
|
|
|
+ "EPOCH 3\n",
|
|
|
+ " batch 5 loss: 2.444630241394043\n",
|
|
|
+ " loss/train: 2.444630241394043/41\n",
|
|
|
+ " batch 10 loss: 2.4294994354248045\n",
|
|
|
+ " loss/train: 2.4294994354248045/46\n",
|
|
|
+ " batch 15 loss: 2.4486732959747313\n",
|
|
|
+ " loss/train: 2.4486732959747313/51\n",
|
|
|
+ "LOSS train 2.4486732959747313 / valid 2.437175989151001\n",
|
|
|
+ "EPOCH 4\n",
|
|
|
+ " batch 5 loss: 2.449464464187622\n",
|
|
|
+ " loss/train: 2.449464464187622/59\n",
|
|
|
+ " batch 10 loss: 2.435447835922241\n",
|
|
|
+ " loss/train: 2.435447835922241/64\n",
|
|
|
+ " batch 15 loss: 2.4380492210388183\n",
|
|
|
+ " loss/train: 2.4380492210388183/69\n",
|
|
|
+ "LOSS train 2.4380492210388183 / valid 2.437028408050537\n",
|
|
|
+ "EPOCH 5\n",
|
|
|
+ " batch 5 loss: 2.4409915447235107\n",
|
|
|
+ " loss/train: 2.4409915447235107/77\n",
|
|
|
+ " batch 10 loss: 2.443865346908569\n",
|
|
|
+ " loss/train: 2.443865346908569/82\n",
|
|
|
+ " batch 15 loss: 2.4377002716064453\n",
|
|
|
+ " loss/train: 2.4377002716064453/87\n",
|
|
|
+ "LOSS train 2.4377002716064453 / valid 2.4371585845947266\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# train the model\n",
|
|
|
+ "print(\"**** TRAINING NN on supplied data ****\")\n",
|
|
|
+ "\n",
|
|
|
+ "epoch_no = 0\n",
|
|
|
+ "num_epochs = 5\n",
|
|
|
+ "best_vloss = 1000000\n",
|
|
|
+ "\n",
|
|
|
+ "for epoch in range(num_epochs):\n",
|
|
|
+ " print(\"EPOCH\", epoch + 1)\n",
|
|
|
+ "\n",
|
|
|
+ " model.train(True)\n",
|
|
|
+ " avg_loss = train_epoch(epoch_no)\n",
|
|
|
+ "\n",
|
|
|
+ " running_vloss = 0.0\n",
|
|
|
+ " model.eval()\n",
|
|
|
+ " with torch.no_grad():\n",
|
|
|
+ " for idx, vdata in enumerate(validate_loader):\n",
|
|
|
+ " vfeatures, vlabels = vdata\n",
|
|
|
+ " voutputs = model(vfeatures)\n",
|
|
|
+ " vloss = loss_fn(voutputs, vlabels)\n",
|
|
|
+ " running_vloss += vloss\n",
|
|
|
+ "\n",
|
|
|
+ " avg_vloss = running_vloss / (idx + 1)\n",
|
|
|
+ " print(\"LOSS train {} / valid {}\".format(avg_loss, avg_vloss))\n",
|
|
|
+ "\n",
|
|
|
+ " if avg_vloss < best_vloss:\n",
|
|
|
+ " best_vloss = avg_vloss\n",
|
|
|
+ "\n",
|
|
|
+ " epoch_no += 1"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 259,
|
|
|
+ "id": "4ef4f484-5664-45dd-a5aa-2909276cedc8",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "data": {
|
|
|
+ "text/plain": [
|
|
|
+ "SeqNN(\n",
|
|
|
+ " (linear_relu_stack): Sequential(\n",
|
|
|
+ " (0): Linear(in_features=11, out_features=512, bias=True)\n",
|
|
|
+ " (1): ReLU()\n",
|
|
|
+ " (2): Linear(in_features=512, out_features=10, bias=True)\n",
|
|
|
+ " (3): Softmax(dim=None)\n",
|
|
|
+ " )\n",
|
|
|
+ ")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "execution_count": 259,
|
|
|
+ "metadata": {},
|
|
|
+ "output_type": "execute_result"
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "model.eval()"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "1ff5d2c3-e979-434f-a6ac-355c68832952",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": []
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "pytorch-26",
|
|
|
+ "language": "python",
|
|
|
+ "name": "pytorch-26"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.12.9"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 5
|
|
|
+}
|